Variational Inference with classical posterior
In this notebook we set the VI loss to the negative log-likelihood, which recovers the classical Bayesian posterior.
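As a quick reminder of why this works (a sketch of the objective being minimised, not a verbatim statement of the library's internals): generalised variational inference minimises

$$
\mathcal{L}(q) \;=\; \mathbb{E}_{\theta \sim q}\big[\ell(\theta, y)\big] \;+\; w\,\mathrm{KL}\big(q \,\|\, \pi\big),
$$

where $\ell$ is the loss, $\pi$ the prior and $w$ the regularisation weight. With $\ell(\theta, y) = -\log p(y \mid \theta)$ and $w = 1$ this is the negative ELBO, whose minimiser over all distributions $q$ is the classical posterior $p(\theta \mid y)$.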
In [2]:
from blackbirds.models.random_walk import RandomWalk
from blackbirds.infer.vi import VI
from blackbirds.posterior_estimators import TrainableGaussian
from blackbirds.simulate import simulate_and_observe_model
import torch
import math
import matplotlib.pyplot as plt
import pandas as pd
In [9]:
rw = RandomWalk(n_timesteps=40)
In [10]:
true_p = torch.logit(torch.tensor(0.25))
data = rw.run_and_observe(torch.tensor([true_p]))
plt.plot(data[0].numpy())
Out[10]:
[<matplotlib.lines.Line2D at 0x14eb08a00>]
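A note on the parameterisation (a reading of this cell rather than of the RandomWalk documentation): the model takes its parameter on the unconstrained logit scale, so the value passed here corresponds to an up-step probability of 0.25, and the loss defined below maps it back with torch.sigmoid. A one-line round-trip check:

torch.sigmoid(true_p)  # maps the unconstrained parameter back to 0.25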
In [24]:
class LogLikelihoodLoss:
    """Negative log-likelihood of an observed random walk trajectory.

    Each step is treated as +1 with probability p = sigmoid(params[0]) and -1
    otherwise, so the trajectory log-likelihood is a sum of Bernoulli terms.
    """

    def __init__(self, model):
        self.model = model

    def __call__(self, params, data):
        N = self.model.n_timesteps
        p = torch.sigmoid(params[0])  # map the unconstrained parameter to (0, 1)
        lp = 0
        for n in range(1, N + 1):
            if data[0][n] == data[0][n - 1] + 1:  # upward step
                lp += p.log()
            else:  # downward step
                lp += (1 - p).log()
            # alternative closed-form binomial term, left commented out:
            # k = int(data[0][n].item())
            # likelihood = math.comb(n, (n+k)//2) * p**((n+k)//2) * (1-p)**((n-k)//2)
            # lp += likelihood.log()
        return -lp
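For reference, the per-step loop above can be written equivalently in vectorised form in terms of the walk increments (a sketch using the names from this notebook, assuming ±1 steps; not part of the blackbirds API):

def negative_log_likelihood(params, data):
    # equivalent to LogLikelihoodLoss: count up- and down-steps directly
    p = torch.sigmoid(params[0])
    steps = data[0].diff()           # +1 / -1 increments of the walk
    n_up = (steps > 0).sum()         # number of upward moves
    n_down = steps.numel() - n_up    # number of downward moves
    return -(n_up * p.log() + n_down * (1 - p).log())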
In [48]:
posterior_estimator = TrainableGaussian([0.], 1.0)
prior = torch.distributions.Normal(true_p + 0.1, 0.1)
optimizer = torch.optim.Adam(posterior_estimator.parameters(), 1e-2)
ll = LogLikelihoodLoss(rw)
# with the negative log-likelihood as loss and w = 1, VI targets the classical posterior
vi = VI(ll, posterior_estimator=posterior_estimator, prior=prior, optimizer=optimizer, w=1.0, n_samples_regularisation=1000)
In [49]:
# we can now train the estimator for up to 1000 epochs,
# stopping early after 100 epochs without improvement
vi.run(data, 1000, max_epochs_without_improvement=100)
28%|████████████████████████████████████████████████████████▌ | 284/1000 [00:05<00:13, 51.15it/s, loss=23.5, reg.=0.692, total=24.2, best loss=23.5, epochs since improv.=100]
In [50]:
df = pd.DataFrame(vi.losses_hist)
df.plot()
Out[50]:
<Axes: >
In [51]:
# We can now load the best model
posterior_estimator.load_state_dict(vi.best_estimator_state_dict)
Out[51]:
<All keys matched successfully>
In [52]:
# closed-form maximum-likelihood estimate of p from the observed increments
data_diff = data[0].diff()
p_hat = 0.5 * (1 + 1 / (len(data[0]) - 1) * data_diff.sum())
p_hat
Out[52]:
tensor(0.5000)
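Why this is the maximum-likelihood estimate (a quick derivation, assuming ±1 increments as in the loss above): if $k$ of the $N$ increments are $+1$, the log-likelihood $k \log p + (N - k)\log(1 - p)$ is maximised at $\hat{p} = k/N$, and since the increments sum to $k - (N - k) = 2k - N$,

$$
\hat{p} \;=\; \frac{k}{N} \;=\; \frac{1}{2}\left(1 + \frac{1}{N}\sum_{n=1}^{N} \Delta x_n\right),
$$

which is exactly the expression computed above, with $N = \texttt{len(data[0])} - 1$.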
In [53]:
# and plot the posterior
with torch.no_grad():
    samples = posterior_estimator.sample(20000)[0].flatten().cpu()
plt.hist(torch.sigmoid(samples), density=True, bins=100);
plt.axvline(torch.sigmoid(true_p), label="true value", color="black", linestyle=":")
Out[53]:
<matplotlib.lines.Line2D at 0x14fafa3e0>
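As a quick numerical check (a follow-up sketch reusing samples and p_hat from the cells above, not part of the original output), the fitted posterior can be summarised on the probability scale and compared with the closed-form estimate:

p_samples = torch.sigmoid(samples)  # posterior samples on the probability scale
print(f"posterior mean: {p_samples.mean().item():.3f}")
print(f"posterior std:  {p_samples.std().item():.3f}")
print(f"closed-form p_hat: {p_hat.item():.3f}")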
In [54]:
# compare the predictions to the synthetic data:
f, ax = plt.subplots()
for i in range(50):
    with torch.no_grad():
        sim_rw = rw.run_and_observe(posterior_estimator.sample(1)[0])[0].numpy()
    ax.plot(sim_rw, color="C0", alpha=0.5)
ax.plot([], [], color="C0", label="predicted")
ax.plot(data[0], color="black", linewidth=2, label="data")
ax.legend()
Out[54]:
<matplotlib.legend.Legend at 0x14ecafbe0>