import numpy as np
import scipy.stats
import matplotlib.pyplot as plt
import itertools
import random
import math
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
%matplotlib inline
def plot_samples(samples1, samples2=None):
    """Scatter-plot one or two 2-D point clouds on a freshly created figure.

    Args:
        samples1: points drawn as blue crosses; must support ``[:, 0]`` and
            ``[:, 1]`` indexing, i.e. shape (n_samples, 2).
        samples2: optional second set of points with the same layout, drawn
            as red crosses. Skipped when ``None``.
    """
    _, axes = plt.subplots()
    clouds = [(samples1, "blue")]
    if samples2 is not None:
        clouds.append((samples2, "red"))
    for points, colour in clouds:
        xs = points[:, 0]
        ys = points[:, 1]
        axes.scatter(xs, ys, marker="x", color=colour)
# More complicated example than the linear flow one:
# we fit a mixture of two bivariate Gaussians!
def sample_bivariate_normal(mean1, std1, mean2, std2, rho, n_samples):
    """Draw samples from a 2-D Gaussian given per-axis moments and a correlation.

    Args:
        mean1, mean2: means of the first and second coordinate.
        std1, std2: standard deviations of the first and second coordinate.
        rho: Pearson correlation between the two coordinates.
        n_samples: number of points to draw.

    Returns:
        A float32 torch tensor of shape (n_samples, 2).

    Raises:
        ValueError: if ``rho`` lies outside [-1, 1], in which case the
            covariance matrix would not be positive semi-definite.
    """
    if not -1.0 <= rho <= 1.0:
        raise ValueError(f"correlation must lie in [-1, 1], got {rho}")
    mean = np.array([mean1, mean2])
    cross = rho * std1 * std2  # off-diagonal covariance term
    cov = np.array(
        [
            [std1 ** 2, cross],
            [cross, std2 ** 2],
        ]
    )
    samples = np.random.multivariate_normal(mean, cov, n_samples)
    return torch.from_numpy(samples).float()


# Component 1: broad, positively correlated blob centred at (3, 1).
target_samples1 = sample_bivariate_normal(3, 2, 1, 4, 0.7, 500)
# Component 2: tight, negatively correlated blob centred at (-5, 2).
target_samples2 = sample_bivariate_normal(-5, 1, 2, 1, -0.5, 500)
# Stack both components along the sample axis (component 1 rows first).
target_samples = torch.cat([target_samples1, target_samples2], 0)
# Scatter-plot the combined target sample cloud (both components, one colour).
plot_samples(target_samples)
class NiceLayer(nn.Module):
    """Additive coupling layer in the style of NICE.

    A fixed binary mask splits the coordinates into two halves. The masked
    half passes through unchanged. The other half is shifted by an MLP ``h``
    applied to the masked half only. This makes the layer exactly invertible
    and volume-preserving (unit Jacobian determinant).

    Args:
        size: dimensionality of the input vectors.
        reverse: if False, the first ``size // 2`` coordinates are the
            unchanged (conditioning) half; if True, the remaining coordinates
            are.
        hidden: width of the hidden layer of the coupling MLP (default 10).
    """

    def __init__(self, size, reverse=False, hidden=10):
        super().__init__()
        mid = size // 2
        mask = torch.zeros(size)
        if reverse:
            mask[mid:] = 1.0
        else:
            mask[:mid] = 1.0
        # A buffer (not a plain attribute) so `.to(device)` / `.cuda()` move
        # it together with the parameters. Non-persistent because it is fully
        # determined by the constructor arguments, so it stays out of
        # state_dict.
        self.register_buffer("mask", mask, persistent=False)
        self.h = nn.Sequential(
            nn.Linear(size, hidden),
            nn.ReLU(),
            nn.Linear(hidden, size),
        )

    def forward(self, z):
        """Map ``z`` forward: x = z + (1 - mask) * h(mask * z)."""
        return z + (1 - self.mask) * self.h(z * self.mask)

    def inv(self, x):
        """Invert ``forward``: z = x - (1 - mask) * h(mask * x).

        This is exact because the masked coordinates fed to ``h`` are
        identical in ``x`` and ``z``.
        """
        return x - (1 - self.mask) * self.h(x * self.mask)
class Nice(nn.Module):
    """Normalizing flow: stacked additive couplings followed by a learned scale.

    Generation maps a latent ``z`` drawn from a factorised standard-normal
    prior through ``n_layers`` ``NiceLayer`` couplings. The couplings
    alternate which half is conditioned on; even indices use
    ``reverse=True``. The result is then multiplied element-wise by the
    learned ``scale`` vector.

    Args:
        size: dimensionality of observations and latents.
        n_layers: number of coupling layers.
    """

    def __init__(self, size, n_layers):
        super().__init__()
        # Factorised standard-normal prior. Its dimension must follow `size`
        # so that `sample` and `log_prior` work for any dimensionality.
        self.prior = torch.distributions.normal.Normal(
            torch.zeros(size), torch.ones(size)
        )
        self.layers = nn.ModuleList(
            NiceLayer(size, i % 2 == 0)
            for i in range(n_layers)
        )
        # The last layer of generation is the per-dimension scale layer.
        # NOTE(review): torch.rand draws from [0, 1), so an entry may start
        # very close to 0, which makes `inv` divide by a tiny number.
        self.scale = nn.Parameter(torch.rand((size,)))

    def forward(self, z):
        """Transform latent points ``z`` to observation space."""
        x = z
        for layer in self.layers:
            x = layer(x)
        return x * self.scale

    def inv(self, x):
        """Transform observations ``x`` back to latent space."""
        z = x / self.scale
        for layer in reversed(self.layers):
            z = layer.inv(z)
        return z

    def sample(self, n_samples):
        """Draw ``n_samples`` observations by pushing prior samples forward."""
        z = self.prior.sample((n_samples,))
        return self(z)

    def log_prior(self, x):
        """Return the per-row log-likelihood of ``x`` via change of variables.

        log p(x) = sum_d log N(z_d) + log|det dz/dx|, where z = inv(x). The
        couplings are volume-preserving, so only ``scale`` contributes:
        log|det dz/dx| = -sum(log|scale|).

        Assumes ``x`` is 2-D (batch, size), since the prior term is summed
        over dim 1.
        """
        z = self.inv(x)
        # Summing logs instead of taking log(prod(1 / scale)) keeps the
        # product from overflowing or underflowing in float32.
        log_det = -self.scale.abs().log().sum()
        return self.prior.log_prob(z).sum(1) + log_det
trained_distrib = Nice(2, 10)
optimizer = torch.optim.Adam(trained_distrib.parameters(), lr=1e-3)
batch_size = 1000
n_epochs = 1500
losses = []
n_points = target_samples.shape[0]
for _ in range(n_epochs):
    # Reshuffle every epoch. target_samples is ordered (component 1 rows,
    # then component 2 rows), so un-shuffled mini-batches smaller than the
    # dataset would each contain only one mixture component.
    perm = torch.randperm(n_points)
    for i in range(0, n_points, batch_size):
        batch = target_samples[perm[i:i + batch_size]]
        optimizer.zero_grad()
        # Maximum likelihood: minimise the mean negative log-likelihood.
        loss = -trained_distrib.log_prior(batch).mean()
        losses.append(loss.item())
        loss.backward()
        # Clip the global gradient norm (over all parameters) to 5.
        torch.nn.utils.clip_grad_norm_(trained_distrib.parameters(), 5)
        optimizer.step()

# Give the loss curve its own figure. A bare plt.plot would draw on whatever
# figure is current (e.g. the target scatter plot).
_, loss_ax = plt.subplots()
loss_ax.plot(np.arange(len(losses)), losses)
with torch.no_grad():
    samples = trained_distrib.sample(1000)
plot_samples(target_samples, samples)
# Display the latent space: map each target component back through the
# inverse flow and plot the two images in different colours. After
# successful training they should resemble draws from the standard-normal
# prior.
with torch.no_grad():
    source_sample1, source_sample2 = (
        trained_distrib.inv(component)
        for component in (target_samples1, target_samples2)
    )
    plot_samples(source_sample1, source_sample2)