NICE flow example
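
NICE (Dinh et al., 2014) builds an invertible map $f$ from latent variables $z$ to observations $x$ out of additive coupling layers followed by a diagonal scaling. Training maximizes the likelihood given by the change of variables formula:

$$\log p_X(x) = \log p_Z\!\left(f^{-1}(x)\right) + \log\left|\det \frac{\partial f^{-1}(x)}{\partial x}\right| = \sum_i \log p_Z(z_i) - \sum_i \log |s_i|$$

where $s$ is the learned diagonal scale. An additive coupling layer with binary mask $b$ computes $x = z + (1 - b) \odot h(b \odot z)$ and is inverted exactly by $z = x - (1 - b) \odot h(b \odot x)$; its Jacobian is triangular with unit diagonal, so only the scale layer contributes to the log-determinant.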

In [ ]:
import numpy as np
import scipy.stats
import matplotlib.pyplot as plt
import itertools
import random
import math
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd

%matplotlib inline
In [ ]:
# samples1 and samples2 must each have shape (n_samples, 2)
def plot_samples(samples1, samples2=None):
    fig, ax = plt.subplots()

    ax.scatter(samples1[:,0], samples1[:,1], marker="x", color="blue")
    if samples2 is not None:
        ax.scatter(samples2[:,0], samples2[:,1], marker="x", color="red")
In [ ]:
# A more complicated target than the linear flow example:
# we fit a mixture of two correlated bivariate Gaussians!

mean1 = 3
std1 = 2
mean2 = 1
std2 = 4
p = 0.7  # correlation coefficient
mean = np.array([mean1, mean2])
cov = np.array(
    [
        [std1 ** 2, p * std1 * std2],
        [p * std1 * std2, std2 ** 2]
    ]
)
target_samples1 = torch.from_numpy(np.random.multivariate_normal(mean, cov, 500)).float()


mean1 = -5
std1 = 1
mean2 = 2
std2 = 1
p = -0.5  # correlation coefficient
mean = np.array([mean1, mean2])
cov = np.array(
    [
        [std1 ** 2, p * std1 * std2],
        [p * std1 * std2, std2 ** 2]
    ]
)
target_samples2 = torch.from_numpy(np.random.multivariate_normal(mean, cov, 500)).float()

target_samples = torch.cat([target_samples1, target_samples2], 0)
plot_samples(target_samples)
In [ ]:
class NiceLayer(nn.Module):
    # An additive coupling layer: the masked half of the coordinates passes
    # through unchanged and parametrizes a shift of the other half.
    def __init__(self, size, reverse=False):
        super().__init__()

        mid = size // 2
        mask = torch.zeros(size)
        if reverse:
            mask[mid:] = 1.
        else:
            mask[:mid] = 1.
        # register as a buffer so the mask follows the module across devices
        self.register_buffer("mask", mask)

        self.h = nn.Sequential(
            nn.Linear(size, 10),
            nn.ReLU(),
            nn.Linear(10, size),
        )

    def forward(self, z):
        # x = z on the masked half, x = z + h(masked z) on the other half
        return z + (1 - self.mask) * self.h(z * self.mask)

    def inv(self, x):
        # exact inverse: the masked half is unchanged, so x * mask == z * mask
        return x - (1 - self.mask) * self.h(x * self.mask)


class Nice(nn.Module):
    def __init__(self, size, n_layers):
        super().__init__()

        self.prior = torch.distributions.Normal(torch.zeros(size), torch.ones(size))

        # alternate which half is masked so every coordinate gets transformed
        self.layers = nn.ModuleList(
                            NiceLayer(size, i % 2 == 0)
                            for i in range(n_layers)
        )

        # the last layer of generation is the diagonal scale layer,
        # initialized at ones so the flow starts near the identity
        # (torch.rand could start arbitrarily close to zero and blow up inv)
        self.scale = nn.Parameter(torch.ones(size))
    
    # transform a point from latent space to observation space
    def forward(self, z):
        x = z
        for layer in self.layers:
            x = layer(x)
        return x * self.scale

    # transform a point from observation space to latent space
    def inv(self, x):
        z = x / self.scale
        for layer in reversed(self.layers):
            z = layer.inv(z)
        return z

    
    # sample from the distribution
    def sample(self, n_samples):
        z = self.prior.sample((n_samples,))
        x = self(z)
        return x


    # compute the log-likelihood of observations
    # using the change of variables formula:
    # log p(x) = log p(z) - sum(log |scale|)
    def log_prob(self, x):
        z = self.inv(x)
        return self.prior.log_prob(z).sum(1) - self.scale.abs().log().sum()
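
Each coupling layer is invertible in closed form and the scale layer is a plain division, so a round trip through forward and inv should reproduce its input up to floating-point error. A quick sanity check on an untrained model:

In [ ]:
# round-trip check: inv should undo forward up to numerical error
with torch.no_grad():
    flow = Nice(2, 10)
    z = flow.prior.sample((5,))
    x = flow(z)
    assert torch.allclose(flow.inv(x), z, atol=1e-4)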
In [ ]:
trained_distrib = Nice(2, 10)
optimizer = torch.optim.Adam(trained_distrib.parameters(), lr=1e-3)

batch_size = 1000
losses = list()
for _ in range(1500):
    for i in range(0, target_samples.shape[0], batch_size):
        batch = target_samples[i:i+batch_size]
        optimizer.zero_grad()

        loss = -trained_distrib.log_prob(batch).mean()
        losses.append(loss.item())

        loss.backward()
        torch.nn.utils.clip_grad_norm_(trained_distrib.parameters(), 5)
        optimizer.step()
    
plt.plot(np.arange(len(losses)), losses)
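
As a rough quantitative check, the average negative log-likelihood (in nats) over the full training set can be computed directly; the training curve above should flatten near this value:

In [ ]:
# average negative log-likelihood over all target samples, in nats
with torch.no_grad():
    print(-trained_distrib.log_prob(target_samples).mean().item())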
In [ ]:
with torch.no_grad():
    samples = trained_distrib.sample(1000)
    plot_samples(target_samples, samples)
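
Because log_prob is tractable, the learned density can also be inspected directly. A minimal sketch that evaluates it on a grid; the plot range is an assumption, chosen by eye to cover both target blobs:

In [ ]:
# evaluate the learned density on a regular grid and show it as a heatmap
# (grid bounds are hand-picked to cover the two target clusters)
with torch.no_grad():
    xs = torch.linspace(-10, 10, 200)
    ys = torch.linspace(-12, 14, 200)
    gx, gy = torch.meshgrid(xs, ys, indexing="ij")
    grid = torch.stack([gx.reshape(-1), gy.reshape(-1)], dim=1)
    density = trained_distrib.log_prob(grid).exp().reshape(200, 200)

plt.imshow(density.T, origin="lower", extent=[-10, 10, -12, 14])
plt.colorbar()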
In [ ]:
# display latent space
with torch.no_grad():
    source_sample1 = trained_distrib.inv(target_samples1)
    source_sample2 = trained_distrib.inv(target_samples2)
    plot_samples(source_sample1, source_sample2)