Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Select an option

  • Save karpanGit/d108058de79873b11247263f1ed344ba to your computer and use it in GitHub Desktop.

Select an option

Save karpanGit/d108058de79873b11247263f1ed344ba to your computer and use it in GitHub Desktop.
PyMC compare treatment to control, individual measurements
# compare treatment vs control, individual animals
import numpy as np
import pymc as pm
import arviz as az
import matplotlib.pyplot as plt
# === REPLACE these with your 10 raw measurements per group ===
control = np.array([120, 118, 122, 121, 119, 117, 123, 116, 120, 119])
treatment = np.array([125, 128, 130, 127, 126, 124, 129, 125, 127, 126])

# Everything that defines or samples the model must live inside the
# `with pm.Model()` context: PyMC registers random variables, deterministics
# and the sampler call against the model that is active on the context stack.
with pm.Model() as model:
    # Priors on group means (weakly informative)
    mu_c = pm.Normal("mu_c", mu=120, sigma=20)
    mu_t = pm.Normal("mu_t", mu=120, sigma=20)

    # Robust likelihood: StudentT with a separate scale and nu per group.
    sigma_c = pm.HalfNormal("sigma_c", sigma=10)
    sigma_t = pm.HalfNormal("sigma_t", sigma=10)

    # The sampled variables "nu_minus_c" / "nu_minus_t" hold (nu - 1); the
    # +1 shift keeps the degrees of freedom strictly above 1. With an
    # Exponential rate of 1/29 the shifted prior has mean 30.
    nu_c = pm.Exponential("nu_minus_c", 1 / 29.0) + 1  # nu > 1
    nu_t = pm.Exponential("nu_minus_t", 1 / 29.0) + 1  # nu > 1

    # Likelihoods
    obs_c = pm.StudentT(
        "obs_c", mu=mu_c, sigma=sigma_c, nu=nu_c, observed=control
    )
    obs_t = pm.StudentT(
        "obs_t", mu=mu_t, sigma=sigma_t, nu=nu_t, observed=treatment
    )

    # Deterministic: difference of group means (treatment minus control)
    delta = pm.Deterministic("delta", mu_t - mu_c)

    # Pooled scale for Cohen's d.
    # NOTE(review): sigma_c / sigma_t are the Student-t *scale* parameters,
    # which only approach the standard deviation as nu grows, so this effect
    # size is on the scale-parameter basis — confirm that is what is wanted.
    pooled_sd = pm.math.sqrt((sigma_c**2 + sigma_t**2) / 2)
    cohens_d = pm.Deterministic("cohens_d", delta / pooled_sd)

    # Sample
    trace = pm.sample(
        2000, tune=2000, target_accept=0.95, return_inferencedata=True
    )
# Summary of the group means, their difference, the effect size, the scale
# parameters and the sampled (nu - 1) variables.
print(
    az.summary(
        trace,
        var_names=[
            "mu_c",
            "mu_t",
            "delta",
            "cohens_d",
            "sigma_c",
            "sigma_t",
            "nu_minus_c",
            "nu_minus_t",
        ],
    )
)

# Posterior probability that treatment > control: the fraction of posterior
# draws (across all chains) in which delta is positive.
p_gt = (trace.posterior["delta"] > 0).mean().item()

# 95% highest-density interval of delta. The bounds are pulled out into
# plain floats first so the f-string below needs no nested quotes; reusing
# the enclosing quote character inside an f-string expression is a
# SyntaxError on Python versions before 3.12.
hdi = az.hdi(trace, var_names=["delta"], hdi_prob=0.95)
hdi_lower = hdi["delta"].sel(hdi="lower").item()
hdi_higher = hdi["delta"].sel(hdi="higher").item()

print(f"P(delta > 0) = {p_gt:.3f}")
print(f"95% HDI of delta: [{hdi_lower:.2f}, {hdi_higher:.2f}]")

# Plot posterior of delta and of the effect size
az.plot_posterior(trace, var_names=["delta", "cohens_d"])
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment