Linear flow example

We fit a linear flow x = g(z) = z @ A + b, with a standard Gaussian prior over z, to samples from a correlated bivariate Gaussian by maximizing the log-likelihood given by the change of variables formula.

In [ ]:
import numpy as np
import scipy.stats
import matplotlib.pyplot as plt
import itertools
import random
import math
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd

%matplotlib inline
In [ ]:
# samples1 and samples2 must each have shape (n_samples, 2)
def plot_samples(samples1, samples2=None):
    fig, ax = plt.subplots()

    ax.scatter(samples1[:,0], samples1[:,1], marker="x", color="blue")
    if samples2 is not None:
        ax.scatter(samples2[:,0], samples2[:,1], marker="x", color="red")
In [ ]:
# the target distribution: a correlated bivariate Gaussian

mean1 = 3
std1 = 2
mean2 = 1
std2 = 4
p = 0.7  # correlation coefficient
mean = np.array([mean1, mean2])
cov = np.array(
    [
        [std1 ** 2, p * std1 * std2],
        [p * std1 * std2, std2 ** 2]
    ]
)

target_samples = torch.from_numpy(np.random.multivariate_normal(mean, cov, 1000)).float()

plot_samples(target_samples)
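As a quick sanity check (my own sketch, not part of the original notebook), the empirical mean and covariance of the samples should be close to the parameters specified above:
In [ ]:
# sanity check (hypothetical addition): empirical statistics vs. the specified parameters
print("empirical mean:", target_samples.mean(0).numpy())
print("empirical cov:\n", np.cov(target_samples.numpy(), rowvar=False))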
In [ ]:
class LinearFlow(nn.Module):
    def __init__(self, size):
        super().__init__()
        
        # the prior over the latent space is a standard Gaussian with independent coordinates
        self.prior = torch.distributions.normal.Normal(torch.zeros(size), torch.ones(size))
        
        # A is left unconstrained: we simply hope it stays invertible during training
        self.A = nn.Parameter(torch.empty((size, size)).normal_())
        self.b = nn.Parameter(torch.zeros((size,)))
    
    # transform a point from prior to the target distrib
    # function g in the course
    def forward(self, z):
        return z @ self.A + self.b
    
    # transform a point from the target distrib to the prior
    # function f in the course
    def inv(self, x):
        return (x - self.b) @ self.A.inverse()
    
    # sample from the distribution
    def sample(self, n_samples):
        # we first sample from the latent space
        z = self.prior.sample((n_samples,))
        # and we transform these samples using a deterministic function
        x = self(z)
        return x
    
    # compute the log probability of observations under the model
    # using the change of variables theorem
    def log_prob(self, x):
        A_inv = self.A.inverse()
        z = (x - self.b) @ A_inv

        # log absolute determinant of the Jacobian of f, i.e. of A^-1
        _, t = torch.slogdet(A_inv)
        
        return self.prior.log_prob(z).sum(1) + t
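The log probability above is the change of variables formula: log p_x(x) = log p_z(f(x)) + log |det A^-1|, where f(x) = (x - b) @ A^-1 inverts g. A minimal round-trip check (my own sketch, not in the original notebook) confirms that inv really undoes forward:
In [ ]:
# hypothetical sanity check: f(g(z)) should recover z up to float32 rounding
flow = LinearFlow(2)
z = torch.randn(5, 2)
print(torch.allclose(flow.inv(flow(z)), z, atol=1e-4))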
In [ ]:
trained_distrib = LinearFlow(2)
optimizer = torch.optim.SGD(trained_distrib.parameters(), lr=1e-3)

batch_size = 250
losses = []
for _ in range(1500):  # epochs over the full sample set
    for i in range(0, target_samples.shape[0], batch_size):
        batch = target_samples[i:i+batch_size]
        optimizer.zero_grad()

        loss = -trained_distrib.log_prob(batch).mean()
        losses.append(loss.item())

        loss.backward()
        torch.nn.utils.clip_grad_norm_(trained_distrib.parameters(), 5)
        optimizer.step()
    
plt.plot(losses)
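Since x = z @ A + b with z standard normal, the fitted distribution is itself Gaussian with mean b and covariance A.T @ A, so we can compare the learned parameters directly to the true ones (a quick check of my own, not in the original notebook):
In [ ]:
# hypothetical check: learned mean and covariance vs. the true target parameters
A = trained_distrib.A.detach()
print("learned mean:", trained_distrib.b.detach().numpy(), "vs true:", mean)
print("learned cov:\n", (A.T @ A).numpy(), "\nvs true:\n", cov)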
In [ ]:
with torch.no_grad():
    samples = trained_distrib.sample(1000)
    plot_samples(target_samples, samples)
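Beyond the visual comparison, we can also compare the average log-likelihood the trained flow assigns to the data with the log-likelihood under the true Gaussian (a sketch using scipy.stats, already imported above):
In [ ]:
# hypothetical check: the model's average log-likelihood should approach the true one
with torch.no_grad():
    model_ll = trained_distrib.log_prob(target_samples).mean().item()
true_ll = scipy.stats.multivariate_normal(mean, cov).logpdf(target_samples.numpy()).mean()
print(f"model: {model_ll:.3f}  true: {true_ll:.3f}")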