
@O957
Last active May 21, 2024 17:43
An SEIR model for simulation, implemented in diffrax, with parameters geared towards influenza.

Consider the following compartmental SEIR model (without births or deaths) for influenza:

$$ \begin{aligned} S' &= -\beta S I \\ E' &= \beta S I - \alpha E\\ I' &= \alpha E - \gamma I \\ R' &= \gamma I \end{aligned} $$

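where $\beta$ is the transmission rate, $\alpha$ the rate at which exposed individuals become infectious (the reciprocal of the incubation period), and $\gamma$ the recovery rate (the reciprocal of the infectious period). In the implementation below, each compartment is expressed as a fraction of the total population $N$.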
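Two consequences of these equations are useful sanity checks on the simulation. First, summing the right-hand sides shows that $S + E + I + R$ is constant, so with the compartments scaled to fractions the total should stay at 1. Second, without deaths every exposed individual eventually becomes infectious. The standard next-generation result for this model therefore gives a basic reproduction number equal to the transmission rate times the mean infectious period $1/\gamma$. This assumes $S \approx 1$ at the start, which holds here because only 35 of 5,000,000 people are initially infected. With the values used in the code ($\beta = 0.4$, $\gamma = 1/7$):

$$ \begin{aligned} (S + E + I + R)' &= -\beta S I + (\beta S I - \alpha E) + (\alpha E - \gamma I) + \gamma I = 0, \\ \mathcal{R}_0 &= \frac{\beta}{\gamma} = 0.4 \times 7 = 2.8. \end{aligned} $$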

The implemented model:

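```python
"""
Using diffrax to implement a basic compartmental
SEIR (Susceptible, Exposed, Infected, Recovered)
model for simulation purposes.
"""

import diffrax
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt


@jax.jit
def ODE(t, y, params):
    """Return dy/dt for the SEIR model.

    `y` holds (S, E, I, R) as population fractions and `params` holds
    (beta, alpha, gamma); diffrax calls this as f(t, y, args).
    """
    S, E, I, R = y
    beta, alpha, gamma = params
    dS = -(beta * S * I)
    dE = (beta * S * I) - (alpha * E)
    dI = (alpha * E) - (gamma * I)
    dR = gamma * I
    return jnp.array([dS, dE, dI, dR])


# population size
N = 5_000_000

# transmission rate
beta = 0.4

# rate of progression from E to I (1 / incubation period of ~3 days)
alpha = 1 / 3

# recovery rate (1 / infectious period of ~7 days)
gamma = 1 / 7

# initial compartment sizes, scaled to population fractions
I0 = 35
S0 = N - I0
E0, R0 = 0, 0
init_state = jnp.array([S0 / N, E0 / N, I0 / N, R0 / N])

# parameters, passed to ODE as `args`
params = (beta, alpha, gamma)

# observation times (days 1 to 199) and a fixed step size of ~1 day
ts = jnp.arange(1.0, 200.0)
dt0 = (ts[-1] - ts[0]) / len(ts)

# retrieving solution
solution = diffrax.diffeqsolve(
    diffrax.ODETerm(ODE),
    solver=diffrax.Tsit5(),
    t0=ts[0],
    t1=ts[-1],
    dt0=dt0,
    args=params,
    y0=init_state,
    saveat=diffrax.SaveAt(ts=ts),
)

# plotting each compartment; solution.ys has shape (len(ts), 4)
for idx, label in enumerate(["S", "E", "I", "R"]):
    plt.plot(solution.ts, solution.ys[:, idx], label=label)
plt.xlabel("Time (days)")
plt.ylabel("Fraction of population")
plt.legend()
plt.show()
```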