MALA MCMC on conjugate prior-likelihood pairs
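This notebook sanity-checks the MALA sampler in `blackbirds.infer.mcmc` against conjugate prior-likelihood pairs, for which the posterior is available in closed form. As a reminder, MALA (the Metropolis-adjusted Langevin algorithm) proposes moves by following the gradient of the log-posterior and corrects with a Metropolis-Hastings accept/reject step:

$$\theta' = \theta_t + \frac{\epsilon}{2}\,\nabla \log \pi(\theta_t \mid y) + \sqrt{\epsilon}\,\xi, \qquad \xi \sim \mathcal{N}(0, I).$$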
In [1]:
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as scistats
import torch
import torch.distributions
from blackbirds.infer import mcmc
Conjugate Normal
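We place a $\mathcal{N}(\mu_0, \sigma_0^2)$ prior on the unknown mean $\theta$ of a Gaussian likelihood with known standard deviation $\sigma$:

$$\theta \sim \mathcal{N}(\mu_0, \sigma_0^2), \qquad y_i \mid \theta \sim \mathcal{N}(\theta, \sigma^2), \quad i = 1, \dots, n.$$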
In [2]:
mu_0, sigma_0 = 1., 2.
prior = torch.distributions.normal.Normal(mu_0, sigma_0)
In [3]:
sigma = 1.

def negative_log_likelihood(theta, data):
    dist = torch.distributions.normal.Normal(theta, sigma)
    return -dist.log_prob(data).sum()
In [4]:
data_size = 3
data = torch.distributions.normal.Normal(-1., sigma).sample((data_size,))
In [5]:
data
Out[5]:
tensor([-1.5064, -0.6129, 0.4516])
Sampling
In [6]:
mala = mcmc.MALA(prior, negative_log_likelihood, w=1.)
In [7]:
sampler = mcmc.MCMC(mala, 10_000)
In [8]:
trial_samples = sampler.run(torch.tensor([-.5]), data)
100%|██████████| 10000/10000 [00:17<00:00, 574.66it/s, Acceptance rate=0.288]
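The acceptance rate of this untuned pilot run is fairly low, and successive draws are strongly correlated, so we thin the chain before using it to estimate the posterior scale. As a rough check that a thinning interval of 100 is reasonable, one could eyeball the empirical autocorrelation. A minimal sketch, reusing the numpy and torch imports from the first cell and assuming `trial_samples` is the list of one-dimensional tensors returned by `sampler.run` above:

# Empirical autocorrelation of the pilot chain at a few lags
# (assumes trial_samples comes from the pilot run in the previous cell).
chain = torch.stack(trial_samples).numpy().ravel()
chain = chain - chain.mean()
acf = np.correlate(chain, chain, mode="full")[chain.size - 1:]
acf /= acf[0]  # normalise so that lag 0 has autocorrelation 1
print(acf[[1, 10, 100]])  # should have decayed towards 0 by lag 100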
In [9]:
thinned_trial_samples = torch.stack(trial_samples)[::100].T
scale = torch.cov(thinned_trial_samples)
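The pilot variance is then used to rescale the proposal in a second, longer run. Preconditioning MALA with an estimate $\Sigma$ of the posterior covariance is standard practice; assuming the `scale`/`covariance` argument of `sampler.run` plays this role (an assumption about blackbirds internals, not confirmed here), the proposal becomes

$$\theta' = \theta_t + \frac{\epsilon}{2}\,\Sigma\,\nabla \log \pi(\theta_t \mid y) + \sqrt{\epsilon}\,\Sigma^{1/2}\,\xi, \qquad \xi \sim \mathcal{N}(0, I),$$

which should noticeably improve the acceptance rate, as seen below.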
In [10]:
mala = mcmc.MALA(prior, negative_log_likelihood, 1.)
sampler = mcmc.MCMC(mala, 20_000)
post_samples = sampler.run(torch.tensor([thinned_trial_samples.mean()]), data, scale=scale)
100%|██████████| 20000/20000 [00:35<00:00, 571.00it/s, Acceptance rate=0.759]
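By normal-normal conjugacy, the exact posterior is available in closed form,

$$\theta \mid y \sim \mathcal{N}\!\left(\frac{\mu_0/\sigma_0^2 + \sum_{i=1}^n y_i/\sigma^2}{1/\sigma_0^2 + n/\sigma^2},\; \left(\frac{1}{\sigma_0^2} + \frac{n}{\sigma^2}\right)^{-1}\right),$$

and the cell below overlays this density on a histogram of the (thinned) MALA samples.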
In [11]:
plt.hist(torch.stack(post_samples).T.numpy()[0, ::100], density=True)
x = np.linspace(-2, 3., 1000)
plt.plot(x, scistats.norm.pdf(x,
                              (mu_0/sigma_0**2 + data.sum()/sigma**2)/(1/sigma_0**2 + data_size/sigma**2),
                              1/np.sqrt(1/sigma_0**2 + data_size/sigma**2)))
Out[11]:
[<matplotlib.lines.Line2D at 0x7f80a8c14640>]
Conjugate multivariate Normal
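The same experiment in two dimensions: the unknown is now the mean vector $\theta \in \mathbb{R}^2$ of a multivariate Gaussian likelihood with known covariance $\Sigma$, under a Gaussian prior:

$$\theta \sim \mathcal{N}(\mu_0, \Sigma_0), \qquad y_i \mid \theta \sim \mathcal{N}(\theta, \Sigma), \quad i = 1, \dots, n.$$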
In [12]:
mu_0, sigma_0 = torch.tensor([2., 0.]), torch.tensor([[2., 0.], [0., 1.]])
prior = torch.distributions.multivariate_normal.MultivariateNormal(mu_0, sigma_0)
In [13]:
prior_samples = prior.sample((1000,))
plt.scatter(prior_samples[:, 0], prior_samples[:, 1], alpha=0.5)
plt.ylim([-4, 4])
plt.xlim([-2, 6])
Out[13]:
(-2.0, 6.0)
In [14]:
sigma = torch.tensor([[1., 0.4], [0.4, 2.]])

def negative_log_likelihood(theta, data):
    dist = torch.distributions.multivariate_normal.MultivariateNormal(theta, sigma)
    return -dist.log_prob(data).sum()
In [15]:
data_size = 3
true_mean = torch.tensor([-1., 2.])
true_density = torch.distributions.multivariate_normal.MultivariateNormal(true_mean, sigma)
data = true_density.sample((data_size,))
In [16]:
true_density_samples = true_density.sample((1000,))
plt.scatter(true_density_samples[:, 0], true_density_samples[:, 1], alpha=0.5)
Out[16]:
<matplotlib.collections.PathCollection at 0x7f80a819af50>
In [17]:
data
Out[17]:
tensor([[-1.7573,  0.1746],
        [-1.2320,  4.3385],
        [-2.6572,  0.2604]])
Sampling
In [18]:
mala = mcmc.MALA(prior, negative_log_likelihood, w=1.)
In [19]:
sampler = mcmc.MCMC(mala, 20_000)
In [20]:
trial_samples = sampler.run(torch.tensor([-2., -.5]), data)
100%|██████████| 20000/20000 [00:45<00:00, 434.89it/s, Acceptance rate=0.106]
In [21]:
thinned_trial_samples = torch.stack(trial_samples)[::100].T
cov = torch.cov(thinned_trial_samples)
In [22]:
init_state = thinned_trial_samples.mean(dim=1)
In [23]:
mala = mcmc.MALA(prior, negative_log_likelihood, 1.)
sampler = mcmc.MCMC(mala, 20_000)
post_samples = sampler.run(init_state, data, covariance=cov)
100%|██████████| 20000/20000 [00:44<00:00, 450.00it/s, Acceptance rate=0.642]
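As in the univariate case, conjugacy gives the exact posterior in closed form. With sample mean $\bar{y} = \frac{1}{n}\sum_i y_i$,

$$\Sigma_n = \left(\Sigma_0^{-1} + n\,\Sigma^{-1}\right)^{-1}, \qquad \mu_n = \Sigma_n \left(\Sigma_0^{-1}\mu_0 + n\,\Sigma^{-1}\bar{y}\right), \qquad \theta \mid y \sim \mathcal{N}(\mu_n, \Sigma_n),$$

which the next cell computes and samples from for comparison.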
In [24]:
inv_sigma_0 = torch.inverse(sigma_0)
inv_sigma = torch.inverse(sigma)
true_cov = torch.inverse(inv_sigma_0 + data_size * inv_sigma)
true_mean = torch.matmul(true_cov,
                         torch.matmul(inv_sigma_0, mu_0) + data_size * torch.matmul(inv_sigma, data.mean(dim=0)))
true_post = torch.distributions.multivariate_normal.MultivariateNormal(true_mean, true_cov)
true_post_samples = true_post.sample((1000,))
In [25]:
post_samples_numpy = torch.stack(post_samples).T.numpy()
plt.scatter(post_samples_numpy[0, ::100], post_samples_numpy[1, ::100], alpha=0.5, c='b')
plt.scatter(true_post_samples[:, 0], true_post_samples[:, 1], alpha=0.5, c='r', marker='x')
Out[25]:
<matplotlib.collections.PathCollection at 0x7f80a81e5930>