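Consider the following compartmental SEIR model (without births or deaths) for influenza:

$$
\begin{aligned}
\frac{dS}{dt} &= -\beta S I, \\
\frac{dE}{dt} &= \beta S I - \alpha E, \\
\frac{dI}{dt} &= \alpha E - \gamma I, \\
\frac{dR}{dt} &= \gamma I,
\end{aligned}
$$

where $S$, $E$, $I$, and $R$ are the fractions of the population that are susceptible, exposed (infected but not yet infectious), infectious, and recovered; $\beta$ is the transmission rate; $\alpha$ is the rate at which exposed individuals become infectious (the reciprocal of the incubation period); and $\gamma$ is the recovery rate (the reciprocal of the infectious period).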
The implemented model:
"""
Using diffrax to implement a basic compartmental
SEIR (Susceptible, Exposed, Infected, Recovered)
model for simulation purposes.
"""
import diffrax
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
@jax.jit
def ODE(t, y, params):
    """
    Evaluate the right-hand side of the SEIR system.

    The call signature ``(t, y, args)`` is the one ``diffrax.ODETerm``
    uses when it evaluates the vector field.

    Parameters
    ----------
    t : float
        Current time. Not used by the equations (the system is autonomous),
        but required by the call signature.
    y : array-like
        Current state, unpacked in the order ``(S, E, I, R)``.
    params : tuple
        Rates ``(beta, alpha, gamma)``: transmission, progression from
        exposed to infectious, and recovery.

    Returns
    -------
    jax.Array
        The derivatives ``[dS/dt, dE/dt, dI/dt, dR/dt]``.
    """
    susceptible, exposed, infectious, _recovered = y
    transmission, progression, recovery = params

    # Flows between compartments: S -> E, E -> I, I -> R.
    new_infections = transmission * susceptible * infectious
    onset = progression * exposed
    recoveries = recovery * infectious

    return jnp.array(
        [
            -new_infections,
            new_infections - onset,
            onset - recoveries,
            recoveries,
        ]
    )
# Population size (used to scale the initial conditions to fractions).
N = 5000000
# Transmission rate beta (per day). For this model (no births or deaths)
# the basic reproduction number is beta / gamma, i.e. 2.8 here.
beta = 0.4
# Rate alpha at which exposed individuals become infectious
# (1 / incubation period of ~3 days).
alpha = 1 / 3
# Recovery rate gamma (1 / infectious period of ~7 days).
gamma = 1 / 7
# Initial compartment sizes as head counts; the state passed to the solver
# is these counts divided by N, so the system is solved in population
# fractions.
I0 = 35
S0 = N - I0
# NOTE: R0 is the initial number of recovered individuals here, not the
# basic reproduction number.
E0, R0 = 0, 0
init_state = jnp.array([S0 / N, E0 / N, I0 / N, R0 / N])
# Rates handed to ODE as its ``params`` argument, in the order it unpacks them.
params = (beta, alpha, gamma)
# Daily save times t = 1, 2, ..., 199.
ts = list(range(1, 200))
# Initial step size = spacing between consecutive save times. len(ts) points
# span len(ts) - 1 intervals; dividing by len(ts) (the previous fencepost
# error) gave 198/199 instead of 1, so diffrax's default constant-size steps
# drifted off the save grid.
dt0 = (ts[-1] - ts[0]) / (len(ts) - 1)
# Integrate from ts[0] to ts[-1] with the Tsit5 solver, recording the state
# at every time in ts.
solution = diffrax.diffeqsolve(
    diffrax.ODETerm(ODE),
    solver=diffrax.Tsit5(),
    t0=ts[0],
    t1=ts[-1],
    dt0=dt0,
    args=params,
    y0=init_state,
    saveat=diffrax.SaveAt(ts=ts),
)