Given observations
"""
A short example guide on using numpyro declaratively, i.e.
without explicitly defining a log posterior: the model only
declares its sample sites and numpyro's inference routines
work from those declarations.
"""
import jax.random as jr
import numpyro as npro
import numpyro.distributions as dist
from jax.typing import ArrayLike
# define the Gaussian distribution that generates the synthetic data
mu = 0.0
sigma = 3.0
gaussian_dist = dist.Normal(mu, sigma)
# simulate data
num_samples = 500
reproducible_seed = 3243432
# JAX PRNG keys must never be reused: split off a dedicated key for the
# data and keep ``jr_key`` unused so later inference steps can consume it.
jr_key, data_key = jr.split(jr.PRNGKey(reproducible_seed))
data = gaussian_dist.sample(data_key, (num_samples,))
def model(observations: "ArrayLike | None" = None) -> None:
    """
    Numpyro model of data drawn from a Gaussian with unknown mean and
    standard deviation.

    The model only declares sample sites ("mu", "sigma", "obs") and
    returns nothing. Samples are obtained by passing the model to an
    inference algorithm (e.g. ``npro.infer.MCMC``), which records the
    values drawn at each site.

    Parameters
    ----------
    observations : ArrayLike, optional
        Observed data for the "obs" site. Defaults to None.

        - If given, "obs" is conditioned on ``observations``, and
          inference yields posterior samples of "mu" and "sigma".
        - If None, "obs" is an unobserved site. Inference then draws it
          together with "mu" and "sigma" from the priors, which gives
          the prior predictive distribution.

    Returns
    -------
    None
    """
    # priors
    # NOTE(review): the example data are simulated with mu = 0.0, which lies
    # on the lower edge of Uniform(0, 3). The posterior for mu is therefore
    # truncated at 0 and cannot centre on the true value; a prior whose
    # support contains 0 in its interior (e.g. Uniform(-3, 3)) would avoid this.
    mu = npro.sample("mu", dist.Uniform(0.0, 3.0))
    sigma = npro.sample("sigma", dist.Uniform(0.0, 9.0))
    # likelihood: the data are modelled as draws from Normal(mu, sigma)
    npro.sample("obs", dist.Normal(mu, sigma), obs=observations)
# infer parameter values from observations
nuts_kernel = npro.infer.NUTS(model)
mcmc_with_obs = npro.infer.MCMC(nuts_kernel, num_warmup=1000, num_samples=1000)
mcmc_with_obs.run(jr_key, observations=data)
mcmc_with_obs.print_summary()
posterior_samples = mcmc_with_obs.get_samples()
# simulate data wo/ observations (prior predictive distribution)
nuts_kernel = npro.infer.NUTS(model)
mcmc_no_obs = npro.infer.MCMC(nuts_kernel, num_warmup=1000, num_samples=1000)
mcmc_no_obs.run(jr_key, observations=None)
mcmc_no_obs.print_summary()
generated_samples = mcmc_no_obs.get_samples()The output from inferring parameter values from observations:
mean std median 5.0% 95.0% n_eff r_hat
mu 0.19 0.11 0.17 0.01 0.36 559.51 1.00
sigma 3.02 0.11 3.02 2.86 3.20 424.99 1.00
Number of divergences: 0
The output from simulating data without observations (prior predictive distribution):
mean std median 5.0% 95.0% n_eff r_hat
mu 1.47 0.88 1.49 0.03 2.72 462.29 1.01
obs 1.61 5.72 1.50 -8.07 11.21 326.14 1.00
sigma 4.73 2.50 4.77 1.33 9.00 150.84 1.00
Number of divergences: 30