import numpy as np
import scipy.stats
import matplotlib.pyplot as plt
import itertools
import random
import math
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
%matplotlib inline
def plot_samples(samples1, samples2=None):
    """Scatter two sample sets on a fresh figure.

    Both arguments must have shape (n_samples, 2); samples1 is drawn
    in blue, samples2 (optional) in red.
    """
    fig, ax = plt.subplots()
    for points, colour in ((samples1, "blue"), (samples2, "red")):
        if points is None:
            continue
        ax.scatter(points[:, 0], points[:, 1], marker="x", color=colour)
# Build the target distribution: a correlated bivariate Gaussian.
mean1, std1 = 3, 2
mean2, std2 = 1, 4
p = 0.7  # correlation coefficient between the two coordinates

mean = np.array([mean1, mean2])
# Covariance from standard deviations and correlation:
# Sigma = [[s1^2, p*s1*s2], [p*s1*s2, s2^2]]
off_diag = p * std1 * std2
cov = np.array([
    [std1 ** 2, off_diag],
    [off_diag, std2 ** 2],
])

# 1000 training points, converted to float32 torch tensors.
target_samples = torch.from_numpy(
    np.random.multivariate_normal(mean, cov, 1000)
).float()
plot_samples(target_samples)
class LinearFlow(nn.Module):
    """Normalizing flow with a single affine transform.

    Latent points z ~ N(0, I) in `size` dimensions are mapped to data
    space by x = z @ A + b.  Because the map is affine, the
    change-of-variable theorem gives an exact log-likelihood.
    """

    def __init__(self, size):
        super().__init__()
        # latent space is a size-dimensional Gaussian with independent
        # coordinates.  BUGFIX: the prior dimension was hard-coded to 2,
        # which broke LinearFlow(size) for any size != 2; it now follows
        # the `size` argument (identical behaviour for size == 2).
        self.prior = torch.distributions.normal.Normal(
            torch.zeros(size), torch.ones(size)
        )
        # we don't make any constraint on A (invertibility is not
        # enforced) and hope for the best...
        self.A = nn.Parameter(torch.empty((size, size)).normal_())
        self.b = nn.Parameter(torch.zeros((size,)))

    def forward(self, z):
        """Transform latent points to the target distribution (g)."""
        return z @ self.A + self.b

    def inv(self, x):
        """Transform data points back to the latent space (f)."""
        return (x - self.b) @ self.A.inverse()

    def sample(self, n_samples):
        """Draw `n_samples` points from the modelled distribution."""
        # first sample from the latent prior ...
        z = self.prior.sample((n_samples,))
        # ... then push the samples through the deterministic map g
        return self(z)

    def log_prior(self, x):
        """Log density of observations `x`, shape (n, size) -> (n,).

        Change-of-variable theorem:
            log p_x(x) = log p_z(f(x)) + log |det df/dx|
        where df/dx = A^{-1} is constant for an affine map.
        """
        # compute the inverse once and reuse it for both the transform
        # and its log absolute determinant (sign is discarded)
        A_inv = self.A.inverse()
        z = (x - self.b) @ A_inv
        _, logabsdet = torch.slogdet(A_inv)
        return self.prior.log_prob(z).sum(1) + logabsdet
# Maximum-likelihood training: minimise the mean negative log-density
# of each mini-batch of target samples under the flow.
trained_distrib = LinearFlow(2)
optimizer = torch.optim.SGD(trained_distrib.parameters(), lr=1e-3)
batch_size = 250
losses = []

for _ in range(1500):
    for start in range(0, target_samples.shape[0], batch_size):
        batch = target_samples[start:start + batch_size]
        optimizer.zero_grad()
        nll = -trained_distrib.log_prior(batch).mean()
        losses.append(nll.item())
        nll.backward()
        # clip gradients to keep SGD stable if A drifts near singular
        torch.nn.utils.clip_grad_norm_(trained_distrib.parameters(), 5)
        optimizer.step()

plt.plot(np.arange(len(losses)), losses)
# Visual check: model samples (red) overlaid on the target data (blue).
with torch.no_grad():
    generated = trained_distrib.sample(1000)
    plot_samples(target_samples, generated)