# Linear flow example

import numpy as np
import scipy.stats
import matplotlib.pyplot as plt
import itertools
import random
import math
import time

import torch
import torch.nn as nn
import torch.nn.functional as F

%matplotlib inline

def plot_samples(samples1, samples2=None):
    """Scatter-plot one or two point clouds on a newly created figure.

    Args:
        samples1: points indexable as shape (n_samples, 2); column 0 is used as
            the x coordinate and column 1 as y. Drawn as blue crosses.
        samples2: optional second cloud with the same layout, drawn as red
            crosses. Skipped when None.
    """
    _, axes = plt.subplots()

    clouds = [(samples1, "blue")]
    if samples2 is not None:
        clouds.append((samples2, "red"))

    for points, colour in clouds:
        axes.scatter(points[:, 0], points[:, 1], marker="x", color=colour)

# Target distribution the flow will learn: a correlated bivariate Gaussian.
# Per-coordinate means / standard deviations, and the correlation p between them.
mean1, std1 = 3, 2
mean2, std2 = 1, 4
p = 0.7
mean = np.array([mean1, mean2])

# Covariance: variances on the diagonal, covariance p * std1 * std2 off it.
cov = np.array([
    [std1 ** 2, p * std1 * std2],
    [p * std1 * std2, std2 ** 2],
])

# 1000 draws from N(mean, cov) as a float32 tensor of shape (1000, 2).
target_samples = torch.tensor(
    np.random.multivariate_normal(mean=mean, cov=cov, size=1000),
    dtype=torch.float32,
)

# Visualise the target samples on their own figure (no second cloud).
plot_samples(target_samples, samples2=None)

class LinearFlow(nn.Module):
    """Affine normalizing flow x = z @ A + b with a standard-normal latent z.

    A is an unconstrained (size, size) matrix: the flow is only invertible while
    A stays non-singular, and nothing enforces that during training.

    Args:
        size: dimensionality of both the latent space and the data space.
    """

    def __init__(self, size):
        super().__init__()

        # Latent space: `size` independent standard-normal coordinates.
        # The dimension must follow `size`; a hard-coded 2 breaks every other size.
        self.prior = torch.distributions.normal.Normal(
            torch.zeros(size), torch.ones(size)
        )

        # We don't constrain A and hope a random Gaussian init stays invertible.
        self.A = nn.Parameter(torch.empty((size, size)).normal_())
        self.b = nn.Parameter(torch.zeros((size,)))

    def forward(self, z):
        """Map latent rows z to data space: g(z) = z @ A + b (function g in the course)."""
        return z @ self.A + self.b

    def inv(self, x):
        """Map data rows x back to latent space: f(x) = (x - b) @ A^-1 (function f in the course)."""
        return (x - self.b) @ self.A.inverse()

    def sample(self, n_samples):
        """Draw n_samples points from the model; returns shape (n_samples, size)."""
        # Sample the latent space first, then push the points through the
        # deterministic map g.
        z = self.prior.sample((n_samples,))
        return self(z)

    def log_prior(self, x):
        """Log-density of each row of x under the flow, via change of variables.

        log p(x) = log p_z(f(x)) + log|det J_f|. For f(x) = (x - b) @ A^-1 the
        Jacobian determinant satisfies |det J_f| = 1 / |det A|.

        Despite its name, this is the model log-likelihood of x, not the latent
        prior density.

        Returns:
            Tensor of shape (n,) for x of shape (n, size).
        """
        z = self.inv(x)
        # log|det A^-1| == -log|det A|; slogdet returns (sign, log|det|).
        _, log_abs_det_A = torch.slogdet(self.A)
        return self.prior.log_prob(z).sum(1) - log_abs_det_A

# Model to train (2-D flow) and a plain SGD optimizer over its parameters A and b.
trained_distrib = LinearFlow(2)
optimizer = torch.optim.SGD(params=trained_distrib.parameters(), lr=1e-3)

# Mini-batch size, and the history of per-batch losses collected during training.
batch_size = 250
losses = []
for _ in range(1500):
for i in range(0, target_samples.shape[0], batch_size):
batch = target_samples[i:i+batch_size]

loss = -trained_distrib.log_prior(batch).mean()
losses.append(loss.item())

loss.backward()

with torch.no_grad():