SIR model¶
Here we calibrate a differentiable version of the SIR model. We use the exact same model as https://ndlib.readthedocs.io/en/latest/reference/models/epidemics/SIR.html, but implemented in a differentiable way.
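The actual differentiable implementation lives in blackbirds.models.sir. As a rough illustration of the idea only (this is not the library's code), the sketch below shows how the discrete infection and recovery events of a network SIR model can be replaced by relaxed (Gumbel-Softmax) Bernoulli draws, so that simulated trajectories remain differentiable with respect to beta and gamma. The helper names and the dense-adjacency formulation are assumptions made just for this example.

```python
import networkx
import torch

def relaxed_bernoulli(probs, temperature=0.1):
    # Differentiable surrogate for a Bernoulli draw (Gumbel-Softmax relaxation).
    probs = probs.clamp(1e-6, 1 - 1e-6)
    return torch.distributions.RelaxedBernoulli(torch.tensor(temperature), probs=probs).rsample()

def sir_step(infected, recovered, adjacency, beta, gamma):
    # One relaxed SIR update; states are soft values in [0, 1] rather than hard 0/1.
    susceptible = (1 - infected) * (1 - recovered)
    n_infected_neighbours = adjacency @ infected
    prob_infection = 1 - (1 - beta) ** n_infected_neighbours
    new_infections = susceptible * relaxed_bernoulli(prob_infection)
    new_recoveries = infected * relaxed_bernoulli(gamma * torch.ones_like(infected))
    return infected + new_infections - new_recoveries, recovered + new_recoveries

# Tiny usage example: gradients of the trajectory flow back to beta and gamma.
graph = networkx.watts_strogatz_graph(100, 10, 0.1)
adjacency = torch.tensor(networkx.to_numpy_array(graph), dtype=torch.float32)
beta = torch.tensor(0.05, requires_grad=True)
gamma = torch.tensor(0.05, requires_grad=True)
infected = (torch.rand(100) < 0.05).float()
recovered = torch.zeros(100)
for _ in range(10):
    infected, recovered = sir_step(infected, recovered, adjacency, beta, gamma)
infected.sum().backward()  # beta.grad and gamma.grad are now populated
```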
In [1]:
from blackbirds.models.sir import SIR
from blackbirds.infer.vi import VI
from blackbirds.simulate import simulate_and_observe_model
import torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import networkx
import normflows as nf
import pygtc
Generating synthetic true data¶
In [2]:
device = "cpu"
device = "cpu"
In [3]:
# generate a random graph
n_agents = 1000
graph = networkx.watts_strogatz_graph(n_agents, 10, 0.1)
In [4]:
sir = SIR(graph, n_timesteps=100, device=device)
In [5]:
%%time
# the simulator takes as parameters the log10 of the initial fraction of infected agents, beta, and gamma
true_parameters = torch.log10(torch.tensor([0.05, 0.05, 0.05], device=device))
data = sir.run_and_observe(true_parameters)
true_infected, true_recovered = data
CPU times: user 43.5 ms, sys: 19.3 ms, total: 62.8 ms
Wall time: 48 ms
In [6]:
f, ax = plt.subplots()
ax.plot(true_infected.cpu(), label = "active infected")
ax.set_xlabel("Time-step")
ax.plot(true_recovered.cpu(), label = "cumulative recovered")
ax.legend()
Out[6]:
<matplotlib.legend.Legend at 0x2a951e380>
Approximating the posterior by a normalizing flow¶
We construct the flow using the normflows library (https://github.com/VincentStimper/normalizing-flows).
In this case we define a masked affine autoregressive flow with 16 transformations, each parametrised by a network with 2 blocks of 20 hidden units.
In [7]:
def make_flow(n_parameters, device):
    # Masked affine autoregressive flow with a diagonal Gaussian base distribution.
    K = 16
    torch.manual_seed(0)
    flows = []
    for i in range(K):
        flows.append(nf.flows.MaskedAffineAutoregressive(n_parameters, 20, num_blocks=2))
        flows.append(nf.flows.Permute(n_parameters, mode="swap"))
    q0 = nf.distributions.DiagGaussian(n_parameters)
    nfm = nf.NormalizingFlow(q0=q0, flows=flows)
    return nfm.to(device)
In [8]:
# Plot the initial flow:
flow = make_flow(len(true_parameters), device=device)
samples = flow.sample(10000)[0].cpu().detach().numpy()
pygtc.plotGTC([samples], truths=true_parameters.cpu().numpy(), figureSize=7, paramNames=[r"$I_0$", r"$\beta$", r"$\gamma$"]);
Let's also plot runs sampled from the untrained flow, to compare later with the trained flow.
In [9]:
f, ax = plt.subplots()
for i in range(15):
    with torch.no_grad():
        sim_sir = sir.run_and_observe(flow.sample(1)[0][0])
    ax.plot(sim_sir[0].cpu().numpy(), color = "C0", alpha=0.5)
    ax.plot(sim_sir[1].cpu().numpy(), color = "C1", alpha=0.5)
ax.plot([], [], color = "C0", label = "predicted infected")
ax.plot([], [], color = "C1", label = "predicted recovered")
ax.plot(data[0].cpu(), color = "black", linewidth=2, label = "data infected")
ax.plot(data[1].cpu(), color = "black", linewidth=2, label = "data recovered", linestyle="--")
ax.set_xlabel("Time-step")
ax.legend()
Out[9]:
<matplotlib.legend.Legend at 0x2aa81a590>
Train the flow¶
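Roughly speaking, VI trains the flow $q_\phi$ to minimise an expected simulation loss plus a prior-regularisation term weighted by w, i.e. an objective of the form

$$\mathbb{E}_{\theta \sim q_\phi}\left[\ell(\theta, y)\right] + w \, \mathrm{KL}\left(q_\phi \,\|\, \pi\right),$$

where $\ell$ is the L2Loss defined below, $y$ is the observed data, and $\pi$ is the prior (here a Gaussian centred at $-2$ in $\log_{10}$ space). This is only the intuition behind the loss, prior and w arguments; see the blackbirds documentation for the exact objective and its gradient estimator.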
In [10]:
torch.manual_seed(0)

class L2Loss:
    def __init__(self, model):
        self.model = model
        self.loss_fn = torch.nn.MSELoss()

    def __call__(self, params, data):
        observed_outputs = simulate_and_observe_model(self.model, params, gradient_horizon=0)
        return self.loss_fn(observed_outputs[0], data[0])

prior = torch.distributions.MultivariateNormal(-2.0 * torch.ones(3, device=device), torch.eye(3, device=device))
loss = L2Loss(sir)
optimizer = torch.optim.AdamW(flow.parameters(), lr=1e-3)
w = 100
vi = VI(loss = loss,
        posterior_estimator = flow,
        prior=prior,
        optimizer=optimizer,
        w=w,
        n_samples_per_epoch=10,
        log_tensorboard=True,
        device=device
        )
# and we run for 1000 epochs, stopping early if the loss doesn't improve for 50 epochs.
vi.run(data, n_epochs=1000, max_epochs_without_improvement=50);
22%|███████████████████████████████████████████ | 223/1000 [07:05<24:44, 1.91s/it, loss=6.05e+3, reg.=561, total=6.61e+3, best loss=2.73e+3, epochs since improv.=50]
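Since log_tensorboard=True, the training and regularisation losses are also written out for TensorBoard; assuming the default writer location, they can be inspected with tensorboard --logdir runs.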
In [11]:
# Let's have a look at the loss over epochs:
df = pd.DataFrame(vi.losses_hist)
df.plot(logy=True)
Out[11]:
<Axes: >
In [12]:
# We can load the best model to check the results
flow.load_state_dict(vi.best_estimator_state_dict)
Out[12]:
<All keys matched successfully>
In [13]:
# Plot the final flow posterior approximator and compare it to the real parameters:
samples = flow.sample(50000)[0].cpu().detach().numpy()
pygtc.plotGTC([samples], truths=true_parameters.cpu().numpy(), figureSize=10, priors=[(-2, 1) for i in range(3)], paramRanges=[(-3.0, -0.5) for i in range(3)]);
In [14]:
# compare the predictions to the synthetic data:
f, ax = plt.subplots()
for i in range(25):
    with torch.no_grad():
        sim_sir = sir.observe(sir.run(flow.sample(1)[0][0]))
    ax.plot(sim_sir[0].cpu().numpy(), color = "C0", alpha=0.5)
    ax.plot(sim_sir[1].cpu().numpy(), color = "C1", alpha=0.5)
ax.plot([], [], color = "C0", label = "predicted infected")
ax.plot([], [], color = "C1", label = "predicted recovered")
ax.plot(data[0].cpu(), color = "black", linewidth=2, label = "data infected")
ax.plot(data[1].cpu(), color = "black", linewidth=2, label = "data recovered", linestyle="--")
ax.legend()
Out[14]:
<matplotlib.legend.Legend at 0x2aaf95930>