@O957
Last active May 18, 2024 05:26
Example for declarative use of Numpyro to retrieve posterior samples and simulate data.

Given observations $x_i \sim \mathcal{N}(\mu = 0, \sigma = 3)$ for $i = 1, 2, \dotsc, 500$, and priors $\mu \sim \mathcal{U}(0, 3)$ and $\sigma \sim \mathcal{U}(0, 9)$, produce 1000 samples from the posterior distribution and also simulate 1000 samples from the priors.

"""
A short example guide on using numpyro declaratively, i.e.
with out specifying defining a log posterior explicitly.
"""

import jax.random as jr
import numpyro as npro
import numpyro.distributions as dist
from jax.typing import ArrayLike

# define Gaussian distribution
mu = 0.0
sigma = 3.0
gaussian_dist = dist.Normal(mu, sigma)

# simulate data
num_samples = 500
reproducible_seed = 3243432
jr_key = jr.PRNGKey(reproducible_seed)
data = gaussian_dist.sample(jr_key, (num_samples,))
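As a quick sanity check on the simulated data, the empirical mean and standard deviation of 500 draws from $\mathcal{N}(0, 3)$ should land near 0 and 3. A standalone NumPy sketch of the same check (independent of the JAX code above, so the exact draws differ):

```python
import numpy as np

# same seed value as above, but NumPy's RNG, so draws will differ from JAX's
rng = np.random.default_rng(3243432)
samples = rng.normal(loc=0.0, scale=3.0, size=500)

# with 500 draws, the sample mean has standard error 3/sqrt(500) ~ 0.13,
# so the empirical statistics should sit close to the true parameters
print(samples.mean(), samples.std())
```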


def model(observations: ArrayLike | None = None) -> None:
    """
    NumPyro model for the Gaussian example.

    Parameters
    ----------
    observations : ArrayLike | None
        An array of observed data. Defaults to None. If
        None, the "obs" site is sampled from the prior
        (useful for simulating data); otherwise the model
        is conditioned on the observations for posterior
        inference.

    Returns
    -------
    None
        NumPyro models do not return samples directly; the
        sample sites are recorded by NumPyro's effect
        handlers, and posterior samples are retrieved from
        the MCMC object via `get_samples()`.
    """

    # priors
    mu = npro.sample("mu", dist.Uniform(0.0, 3.0))
    sigma = npro.sample("sigma", dist.Uniform(0.0, 9.0))

    # likelihood
    npro.sample("obs", dist.Normal(mu, sigma), obs=observations)


# infer parameter values from observations
nuts_kernel = npro.infer.NUTS(model)
mcmc_with_obs = npro.infer.MCMC(nuts_kernel, num_warmup=1000, num_samples=1000)
mcmc_with_obs.run(jr_key, observations=data)
mcmc_with_obs.print_summary()
posterior_samples = mcmc_with_obs.get_samples()

# simulate data without observations (MCMC samples the priors,
# and the "obs" site yields prior predictive draws)
nuts_kernel = npro.infer.NUTS(model)
mcmc_no_obs = npro.infer.MCMC(nuts_kernel, num_warmup=1000, num_samples=1000)
mcmc_no_obs.run(jr_key, observations=None)
mcmc_no_obs.print_summary()
generated_samples = mcmc_no_obs.get_samples()

The output from inferring parameter values from observations:

                mean       std    median      5.0%     95.0%     n_eff     r_hat
        mu      0.19      0.11      0.17      0.01      0.36    559.51      1.00
     sigma      3.02      0.11      3.02      2.86      3.20    424.99      1.00

Number of divergences: 0

The output from simulating data without observations:

                mean       std    median      5.0%     95.0%     n_eff     r_hat
        mu      1.47      0.88      1.49      0.03      2.72    462.29      1.01
       obs      1.61      5.72      1.50     -8.07     11.21    326.14      1.00
     sigma      4.73      2.50      4.77      1.33      9.00    150.84      1.00

Number of divergences: 30