# Lab exercise: Real NVP

In [ ]:
import numpy as np
import scipy.stats
import matplotlib.pyplot as plt
import itertools
import random
import math
import time

import torch
import torch.nn as nn
import torch.nn.functional as F

%matplotlib inline
In [ ]:
# samples1-2 shape must be (n samples, 2)
def plot_samples(samples1, samples2=None):
fig, ax = plt.subplots()

ax.scatter(samples1[:,0], samples1[:,1], marker="x", color="blue")
if samples2 is not None:
ax.scatter(samples2[:,0], samples2[:,1], marker="x", color="red")
In [ ]:
import sklearn.datasets

# Target distribution: 1000 noisy points from scikit-learn's two-moons
# dataset, converted to a float32 tensor (PyTorch's default floating dtype).
target_samples, target_classes = sklearn.datasets.make_moons(n_samples=1000, noise=0.1)
target_samples = torch.as_tensor(target_samples, dtype=torch.float32)

plot_samples(target_samples)
In [ ]:
class RealNVPLayer(nn.Module):
    """One affine coupling layer of a Real NVP flow.

    A fixed binary mask splits each input vector in two parts. The masked part
    (``mask == 1``) is copied through unchanged and conditions an affine map
    applied to the other part:

        x = mask * z + (1 - mask) * (z * exp(s(mask * z)) + t(mask * z))

    With ``reverse=False`` the first ``size // 2`` coordinates are kept and the
    remaining ones are transformed; ``reverse=True`` swaps the two roles, so
    stacking layers with alternating ``reverse`` transforms every coordinate.

    Args:
        size: dimensionality of the vectors the layer operates on.
        reverse: selects which half of the vector is kept unchanged.
    """

    def __init__(self, size, reverse=False):
        super().__init__()

        mid = size // 2
        mask = torch.zeros(size)
        if reverse:
            mask[mid:] = 1.0
        else:
            mask[:mid] = 1.0
        # A buffer rather than a parameter: it is never trained, but it follows
        # the module across .to(device) calls. Non-persistent because it is
        # fully determined by the constructor arguments.
        self.register_buffer("mask", mask, persistent=False)

        ## the two operations
        # s(.): log-scale network.
        self.scale = nn.Sequential(
            nn.Linear(size, 10),
            nn.Tanh(),
            nn.Linear(10, size),
        )
        # t(.): translation network (attribute name kept as `transpose` so
        # existing state-dict keys stay valid).
        self.transpose = nn.Sequential(
            nn.Linear(size, 10),
            nn.Tanh(),
            nn.Linear(10, size),
        )

    # project from the latent space to the observed space,
    # i.e. x = g(z)
    def forward(self, z):
        """Return ``x = g(z)``; the kept coordinates of ``z`` pass through."""
        keep = self.mask
        change = 1.0 - keep

        # Only the kept coordinates are fed to s and t, so the map stays
        # invertible: they are available unchanged on the way back.
        z_keep = keep * z
        x = z_keep + change * (z * torch.exp(self.scale(z_keep)) + self.transpose(z_keep))
        return x

    # project from the observed space to the latent space,
    # this function also return the log det jacobian of this inv function
    def inv(self, x):
        """Return ``(z, log_det_jacobian)`` with ``z = g^{-1}(x)``.

        ``log_det_jacobian`` is log|det dz/dx|, summed over the last dimension
        (one value per input vector).
        """
        keep = self.mask
        change = 1.0 - keep

        x_keep = keep * x
        log_scale = self.scale(x_keep)
        z = x_keep + change * ((x - self.transpose(x_keep)) * torch.exp(-log_scale))

        # dz/dx is triangular: 1 on the kept coordinates, exp(-s) on the
        # transformed ones, so log|det| is the sum of -s over transformed coords.
        log_det_jacobian = (change * -log_scale).sum(dim=-1)

        return z, log_det_jacobian
In [ ]:
# Test! For each mask orientation, check that the coupling layer is exactly
# invertible and that the conditioning ("kept") coordinate passes through
# unchanged.
for reverse, kept_name, kept_idx in ((False, "first", 0), (True, "second", 1)):
    layer = RealNVPLayer(2, reverse=reverse)

    x = torch.rand(1, 2)
    z, _ = layer.inv(x)
    xx = layer(z)

    print(f"In the 3 vectors below, the {kept_name} element must be equal")
    print("This two vectors should be equal:")
    print(x)
    print(xx)
    print("This vector should be different to the two above")
    print(z)

    # Fail loudly instead of relying on visual inspection of the printout.
    assert torch.allclose(x, xx, atol=1e-5), "forward(inv(x)) must recover x"
    assert torch.equal(x[:, kept_idx], z[:, kept_idx]), (
        f"the {kept_name} element must be left unchanged by the layer"
    )

    # Blank line between the two reports (none after the last one).
    if not reverse:
        print()
In [ ]:
class RealNVP(nn.Module):
    """Real NVP normalizing flow: a stack of affine coupling layers over a
    factorized standard normal prior.

    Args:
        size: dimensionality of the data (and latent) vectors.
        n_layers: number of ``RealNVPLayer`` coupling layers; consecutive
            layers alternate ``reverse`` so that every coordinate is
            transformed by some layer.
    """

    def __init__(self, size, n_layers):
        super().__init__()

        # Prior p(z) = N(0, I) over a `size`-dimensional latent space.
        # NOTE(review): these tensors are plain attributes, not buffers, so they
        # stay on the CPU if the model is moved with .to(device).
        self.prior = torch.distributions.normal.Normal(torch.zeros(size), torch.ones(size))

        self.layers = nn.ModuleList(
            RealNVPLayer(size, reverse=(i % 2 == 0)) for i in range(n_layers)
        )

    def forward(self, z):
        """Map latent ``z`` to observed space, x = g(z), applying the layers
        from first to last."""
        x = z
        for layer in self.layers:
            x = layer(x)
        return x

    def inv(self, x):
        """Map observed ``x`` to latent ``z`` (layers from last to first).

        Returns:
            ``(z, log_det_jacobian)``, where ``log_det_jacobian`` is the sum of
            the per-layer log-determinants: the Jacobian of a composition is
            the product of the layer Jacobians.
        """
        z = x
        # One accumulator per vector (all leading dimensions of x).
        log_det_jacobian = x.new_zeros(x.shape[:-1])
        for layer in reversed(self.layers):
            z, layer_log_det = layer.inv(z)
            log_det_jacobian = log_det_jacobian + layer_log_det
        return z, log_det_jacobian

    def sample(self, n_samples):
        """Draw ``n_samples`` points from the model by sampling the prior and
        pushing the draws through ``forward``."""
        z = self.prior.sample((n_samples,))
        return self(z)

    def log_prior(self, z):
        """Return the model log-density log p(x) of each observed point.

        Despite its name, ``z`` holds points in *observed* space (the parameter
        name is kept for backward compatibility). By change of variables,
        log p(x) = log p_prior(inv(x)) + log|det d inv(x) / dx|.
        """
        latent, log_det_jacobian = self.inv(z)
        return self.prior.log_prob(latent).sum(-1) + log_det_jacobian
In [ ]:
trained_distrib = RealNVP(2, 50)
# Adam with a conventional default learning rate; tune if training diverges.
optimizer = torch.optim.Adam(trained_distrib.parameters(), lr=1e-3)

n_epochs = 500
batch_size = 1000
losses = list()
for _ in range(n_epochs):
    # Reshuffle every epoch so mini-batches differ when batch_size < dataset size.
    order = torch.randperm(target_samples.shape[0])
    for i in range(0, target_samples.shape[0], batch_size):
        batch = target_samples[order[i:i + batch_size]]

        # Maximum likelihood: minimize the mean negative log-density of the batch.
        loss = -trained_distrib.log_prior(batch).mean()
        losses.append(loss.item())

        # Clear gradients from the previous step; backward() accumulates.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

plt.plot(np.arange(len(losses)), losses)
In [ ]:
# sample from the model