{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Deep Probabilistic Generative Models - Restricted Boltzmann Machine\n",
"\n",
"In this lab exercise, you will code a Restricted Boltzmann Machine with Gaussian observed random variables and Bernoulli latent variables. The explanation of the model and the derivation of all formulas are given in the PDF available in the lab exercise section of the course website. Please read it carefully to understand what is going on!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import scipy.stats\n",
"import matplotlib.pyplot as plt\n",
"import itertools\n",
"import random\n",
"import math\n",
"import time\n",
"import sklearn.datasets\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.autograd as autograd\n",
"\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# samples1-2 shape must be (n samples, 2)\n",
"def plot_samples(samples1, samples2=None):\n",
"    fig, ax = plt.subplots()\n",
"\n",
"    ax.scatter(samples1[:,0], samples1[:,1], marker=\"x\", color=\"blue\")\n",
"    if samples2 is not None:\n",
"        ax.scatter(samples2[:,0], samples2[:,1], marker=\"x\", color=\"red\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data generation\n",
"\n",
"The data we will use in this lab exercise looks like a ring.\n",
"You can see that it would be difficult to fit this dataset with a Gaussian Mixture Model: we would need many Gaussians, and each one would only imperfectly fit a subset of the data space."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"target_samples, y = sklearn.datasets.make_circles(n_samples=1000, noise=0.1, factor=0.2)\n",
"target_samples = target_samples[y == 0]\n",
"target_samples = torch.from_numpy(target_samples).float()\n",
"\n",
"plot_samples(target_samples)"
]
},
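{
"cell_type": "markdown",
"metadata": {},
"source": [
"The filter `y == 0` in the cell above keeps only one of the two circles produced by `make_circles`. The following minimal sketch checks, on a noise-free copy of the dataset, that label 0 is the outer circle and that `factor` scales the inner one. The names `points`, `labels`, `radii`, `outer_radii` and `inner_radii` are illustrative and not used elsewhere in this notebook.\n",
"\n",
"```python\n",
"import numpy as np\n",
"import sklearn.datasets\n",
"\n",
"# Noise-free version of the dataset, so that the geometry is exact.\n",
"points, labels = sklearn.datasets.make_circles(n_samples=1000, noise=0.0, factor=0.2)\n",
"\n",
"# Distance of every point to the origin.\n",
"radii = np.linalg.norm(points, axis=1)\n",
"\n",
"outer_radii = radii[labels == 0]  # label 0: outer circle\n",
"inner_radii = radii[labels == 1]  # label 1: inner circle, scaled by factor\n",
"\n",
"print(outer_radii.min(), outer_radii.max())\n",
"print(inner_radii.min(), inner_radii.max())\n",
"```\n",
"\n",
"With `noise=0.1`, as used for the lab data, the kept points are scattered around the outer circle instead of lying exactly on it."
]
},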
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Restricted Boltzmann Machine code\n",
"\n",
"The class below represents a Restricted Boltzmann Machine.\n",
"There are 3 functions you need to code:\n",
"\n",
"- forward: it should return the log-partition function with fixed input x, i.e. c(x)\n",
"- p_x_given_z: it should return the mean and variance for each observed variable\n",
"- p_z_given_x: it should return the Bernoulli parameter for each latent variable\n",
"\n",
"Note that due to batching (i.e. each function takes as input a batch of data, not a single data point), computations that involve the matrix W will be tricky to implement!\n",
"Each time you have a sum with W appearing inside (look at the PDF!), it should be implemented as a matrix multiplication, i.e. with the operator @.\n",
"As an example, in the function p_x_given_z below you will have the following operation: \"z @ self.W.T\".\n",
"You should take the time to understand how broadcasting works and how to implement batched operations in this RBM. You will have to play with the transpose operator...\n",
"You should probably read this page on broadcasting: https://pytorch.org/docs/stable/notes/broadcasting.html\n",
"\n",
"**Advice:**\n",
"\n",
"- for each tensor, think about what each dimension contains\n",
"- how does it compare to the way we wrote it mathematically in the PDF?\n",
"\n",
"**Sampling functions in PyTorch:**\n",
"You will need the following sampling functions from PyTorch:\n",
"\n",
"- torch.normal\n",
"- torch.bernoulli\n",
"\n",
"**WARNING:** Read the documentation! torch.normal takes as argument the **standard deviation** and not the **variance**, so you should pass sigma_squared.sqrt()! Always be careful when you use functions from a library: take time to read the documentation.\n",
"\n",
"**Stable operation**:\n",
"\n",
"To compute the value log(1 + exp(...)), you should use the softplus function, which is more stable: https://pytorch.org/docs/stable/nn.functional.html?highlight=softplus#torch.nn.functional.softplus"
]
},
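{
"cell_type": "markdown",
"metadata": {},
"source": [
"The batching advice above can be checked on dummy tensors before touching the RBM. This is a minimal sketch of the shape bookkeeping only, not the solution of the exercise: `W`, `z`, `x` and `bias` are hypothetical random placeholders, with `W` laid out as (n_visible, n_hidden) like the weight matrix of the class.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"batch_size, n_visible, n_hidden = 4, 2, 3\n",
"\n",
"# Hypothetical placeholders; W has the same layout as the RBM weight matrix.\n",
"W = torch.randn(n_visible, n_hidden)\n",
"z = torch.randn(batch_size, n_hidden)   # one row per data point\n",
"x = torch.randn(batch_size, n_visible)\n",
"bias = torch.randn(n_visible)           # one value per visible unit\n",
"\n",
"# A sum over hidden units for every batch element is one matrix product:\n",
"# (batch_size, n_hidden) @ (n_hidden, n_visible) -> (batch_size, n_visible)\n",
"sum_over_hidden = z @ W.T\n",
"\n",
"# A sum over visible units needs no transpose:\n",
"# (batch_size, n_visible) @ (n_visible, n_hidden) -> (batch_size, n_hidden)\n",
"sum_over_visible = x @ W\n",
"\n",
"# Broadcasting: a vector of shape (n_visible,) is added to every row.\n",
"shifted = sum_over_hidden + bias\n",
"\n",
"# The matrix product agrees with the explicit sum for one batch element.\n",
"explicit = sum(z[0, j] * W[:, j] for j in range(n_hidden))\n",
"print(torch.allclose(sum_over_hidden[0], explicit, atol=1e-5))\n",
"```\n",
"\n",
"Writing the expected shape next to each line, as in the comments above, is usually the quickest way to decide where a transpose is needed."
]
},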
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class RBM(nn.Module):\n",
"    def __init__(self, n_visible=2, n_hidden=10):\n",
"        super().__init__()\n",
"        self.n_visible = n_visible\n",
"        self.n_hidden = n_hidden\n",
"\n",
"        # Create parameters:\n",
"        # See the PDF file for a description of each parameter\n",
"\n",
"        self.b = nn.Parameter(torch.empty(n_visible))\n",
"        # because sigma squared should be greater than zero (variance),\n",
"        # we instead take parameter log_sigma_squared which is unconstrained.\n",
"        # as usual, to get sigma_squared you can do:\n",
"        # sigma_squared = self.log_sigma_squared.exp()\n",
"        self.log_sigma_squared = nn.Parameter(torch.empty(n_visible))\n",
"\n",
"        self.W = nn.Parameter(torch.empty(n_visible, n_hidden))\n",
"        self.d = nn.Parameter(torch.empty(n_hidden))\n",
"\n",
"        self.init_params()\n",
"\n",
"    # Initialize parameters:\n",
"    # unfortunately, the RBM and training loop we are going to code\n",
"    # are really sensitive to initialization...\n",
"    # this is not perfect, but at least it works! :)\n",
"    def init_params(self):\n",
"        with torch.no_grad():\n",
"            self.b.fill_(0.)\n",
"            self.log_sigma_squared.fill_(math.log(0.1**2))\n",
"            self.W.normal_(0., 0.01)\n",
"            self.d.normal_(0., 0.1)\n",
"\n",
"    # returns the log-partition function with fixed x, i.e. c(x)\n",
"    # - shape of input x: (batch size, n_visible)\n",
"    # - shape of return tensor: (batch size,)\n",
"    def forward(self, x):\n",
"        # BEGIN TODO\n",
"        ...\n",
"        # END TODO\n",
"\n",
"    # compute the parameters of the conditional distribution p(x | z)\n",
"    # - shape of input z: (batch size, n_hidden)\n",
"    # - shapes of the two returned tensors:\n",
"    #   1. mean of the Gaussians mu: (batch size, n_visible)\n",
"    #   2. variance of the Gaussians sigma_squared: (batch size, n_visible)\n",
"    def p_x_given_z(self, z):\n",
"        sigma_squared = self.log_sigma_squared.exp()\n",
"        # BEGIN TODO\n",
"        ...\n",
"        # END TODO\n",
"\n",
"        # note the variance is independent of the hidden variable,\n",
"        # it is fixed, so we just \"extend\" the dimension of the vector\n",
"        # to fit the output shape\n",
"        return mu, sigma_squared.unsqueeze(0).repeat(z.shape[0], 1)\n",
"\n",
"    # compute the parameters of the Bernoulli conditional distribution p(z | x)\n",
"    # - shape of input x: (batch size, n_visible)\n",
"    # - output shape: (batch size, n_hidden)\n",
"    def p_z_given_x(self, x):\n",
"        # BEGIN TODO\n",
"        ...\n",
"        # END TODO"
]
},
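{
"cell_type": "markdown",
"metadata": {},
"source": [
"Two of the warnings given in the instructions are easy to verify numerically. The sketch below uses arbitrary illustrative values and is not part of the exercise solution: it shows that a naive log(1 + exp(a)) overflows for large a while `F.softplus` does not, and that `torch.normal` must be given the standard deviation, obtained here from a log-variance initialized like `log_sigma_squared` in `init_params`. The names `a`, `naive`, `stable`, `mean` and `samples` are illustrative.\n",
"\n",
"```python\n",
"import math\n",
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"# Naive log(1 + exp(a)) overflows in float32 once exp(a) is too large.\n",
"a = torch.tensor([-100.0, 0.0, 100.0])\n",
"naive = torch.log(1 + torch.exp(a))\n",
"stable = F.softplus(a)\n",
"print(naive)   # the last entry is inf\n",
"print(stable)  # the last entry is finite\n",
"\n",
"# From an unconstrained log-variance to the standard deviation.\n",
"log_sigma_squared = torch.full((2,), math.log(0.1 ** 2))\n",
"sigma_squared = log_sigma_squared.exp()  # variance, always positive\n",
"sigma = sigma_squared.sqrt()             # standard deviation\n",
"\n",
"# torch.normal expects (mean, std): pass sigma, not sigma_squared.\n",
"torch.manual_seed(0)\n",
"mean = torch.zeros(10000, 2)\n",
"samples = torch.normal(mean, sigma.expand_as(mean))\n",
"print(samples.std(dim=0))  # empirically close to 0.1\n",
"```\n",
"\n",
"Passing `sigma_squared` instead of `sigma` here would silently produce samples with a standard deviation of about 0.01, which is why the shape checks alone cannot catch this mistake."
]
},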
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# check shapes of your RBM!\n",
"batch_size = 15\n",
"\n",
"machine = RBM()\n",
"\n",
"x = torch.ones((batch_size, machine.n_visible))\n",
"f = machine(x)\n",
"\n",
"if f.shape != (batch_size,):\n",
"    print(\"Expected shape: \", (batch_size,))\n",
"    print(\"Returned shape: \", f.shape)\n",
"    raise RuntimeWarning(\"The shape of the log-partition tensor is wrong!\")\n",
"\n",
"\n",
"z = torch.ones((batch_size, machine.n_hidden))\n",
"mu, sigma_squared = machine.p_x_given_z(z)\n",
"\n",
"if mu.shape != (batch_size, machine.n_visible):\n",
"    print(\"Expected shape: \", (batch_size, machine.n_visible))\n",
"    print(\"Returned shape: \", mu.shape)\n",
"    raise RuntimeWarning(\"The shape of the Gaussian mean tensor is wrong!\")\n",
"\n",
"if sigma_squared.shape != (batch_size, machine.n_visible):\n",
"    print(\"Expected shape: \", (batch_size, machine.n_visible))\n",
"    print(\"Returned shape: \", sigma_squared.shape)\n",
"    raise RuntimeWarning(\"The shape of the Gaussian variance tensor is wrong!\")\n",
"\n",
"\n",
"mu = machine.p_z_given_x(x)\n",
"if mu.shape != (batch_size, machine.n_hidden):\n",
"    print(\"Expected shape: \", (batch_size, machine.n_hidden))\n",
"    print(\"Returned shape: \", mu.shape)\n",
"    raise RuntimeWarning(\"The shape of the Bernoulli parameter tensor is wrong!\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training loop\n",
"\n",
"During training, we must approximate the loss via Markov Chain Monte Carlo.\n",
"It is usual to rely on contrastive divergence when training an RBM:\n",
"\n",
"- we use a single sample per batch data point in the Monte Carlo approximation of the partition function contribution to the loss\n",
"- we start one Markov chain per data point in the batch\n",
"- we only do a few steps, in general just one! This is parameterized by k below\n",
"- we take the last sample from each chain"
]
},
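{
"cell_type": "markdown",
"metadata": {},
"source": [
"The chain structure described above can be sketched independently of the RBM formulas. In this minimal sketch, `run_gibbs_chain`, `toy_W`, `toy_sample_z_given_x` and `toy_sample_x_given_z` are hypothetical names, and the two toy conditionals are arbitrary stand-ins rather than the conditionals derived in the PDF. The chains are assumed to start at the data points, as is common in contrastive divergence.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"def run_gibbs_chain(x0, sample_z_given_x, sample_x_given_z, k=1):\n",
"    # Run k steps of block Gibbs sampling, one chain per row of x0,\n",
"    # and return only the last sample of each chain.\n",
"    x = x0\n",
"    with torch.no_grad():  # the samples are treated as constants by the loss\n",
"        for _ in range(k):\n",
"            z = sample_z_given_x(x)\n",
"            x = sample_x_given_z(z)\n",
"    return x\n",
"\n",
"# Toy stand-ins (NOT the RBM conditionals): 2 visible units, 3 binary hidden units.\n",
"toy_W = torch.tensor([[1.0, -1.0, 0.5], [0.5, 1.0, -1.0]])  # (n_visible, n_hidden)\n",
"\n",
"def toy_sample_z_given_x(x):\n",
"    # One Bernoulli parameter in (0, 1) per hidden unit and per chain.\n",
"    return torch.bernoulli(torch.sigmoid(x @ toy_W))\n",
"\n",
"def toy_sample_x_given_z(z):\n",
"    # Gaussian around a mean that depends on z; torch.normal takes the std.\n",
"    mean = z @ toy_W.T\n",
"    return torch.normal(mean, torch.full_like(mean, 0.1))\n",
"\n",
"data_batch = torch.randn(5, 2)  # placeholder batch of 5 data points\n",
"mc_samples = run_gibbs_chain(data_batch, toy_sample_z_given_x, toy_sample_x_given_z, k=1)\n",
"print(mc_samples.shape)\n",
"```\n",
"\n",
"In the exercise, the two stand-ins would be replaced by sampling from the distributions returned by `p_z_given_x` and `p_x_given_z`."
]
},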
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"batch_size = target_samples.shape[0] # we just feed all data at once, you can try with smaller batch size\n",
"n_epoch = 300 # number of epochs\n",
"k = 1 # number of steps to take in the Gibbs sampler for the contrastive loss\n",
"\n",
"n_visible = target_samples.shape[1]\n",
"n_hidden = 100\n",
"\n",
"\n",
"machine = RBM(n_visible, n_hidden)\n",
"# we fix the variance of the observed variables,\n",
"# it's easier to train this way!\n",
"machine.log_sigma_squared.requires_grad_(False)\n",
"# optimizer\n",
"optimizer = torch.optim.SGD(machine.parameters(), lr=1e-3)\n",
"\n",
"\n",
"losses = list()\n",
"for _ in range(n_epoch):\n",
"    for i in range(0, target_samples.shape[0], batch_size):\n",
"        optimizer.zero_grad()\n",
"        data_samples = target_samples[i:i+batch_size]\n",
"\n",
"        with torch.no_grad():\n",
"            # generate samples from your model!\n",
"            # and store them in a variable called mc_samples\n",
"            # BEGIN TODO\n",
"            ...\n",
"            # END TODO\n",
"\n",
"        # the loss function is quite simple\n",
"        loss = (machine(data_samples) - machine(mc_samples)).mean()\n",
"        losses.append(loss.item())\n",
"\n",
"        loss.backward()\n",
"        # we use gradient clipping to stabilize training\n",
"        torch.nn.utils.clip_grad_norm_(machine.parameters(), 1)\n",
"        optimizer.step()\n",
"\n",
"# plot the training loss\n",
"plt.plot(np.arange(len(losses)), losses)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Sampling from the trained RBM\n",
"\n",
"We now turn to sampling from our trained model to check whether it models the data correctly.\n",
"We will run many Markov chains in parallel and do several visualizations.\n",
"In the first one, we will just plot the samples at the last timestep of each Markov chain."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sample from the distribution!\n",
"\n",
"n_chains = 500\n",
"n_steps = 1000\n",
"with torch.no_grad():\n",
"    # the tensor x contains the samples of the observed random variables at each time step for each Markov chain\n",
"    # at each time step you must do an in-place operation, i.e. x[timestep] = ...\n",
"    x = torch.empty(n_steps, n_chains, machine.n_visible)\n",
"\n",
"    # we don't plot the hidden random variables so we won't keep track of them,\n",
"    # just use a temporary variable z\n",
"    # we start with a random vector of binary variables\n",
"    z = torch.bernoulli(torch.empty(n_chains, machine.n_hidden).fill_(0.5))\n",
"    for timestep in range(n_steps):\n",
"        # do one step of Gibbs sampling here and fill x!\n",
"\n",
"        # BEGIN TODO\n",
"        ...\n",
"        # END TODO\n",
"\n",
"# plot the samples at the last timestep of each Markov chain\n",
"plot_samples(target_samples, x[-1])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Visualization of the Gibbs sampling Markov Chain Monte Carlo process\n",
"\n",
"We can visualize the sampling process in two different ways:\n",
"\n",
"- first, we can plot the sample of each Markov chain at each time step\n",
"- second, we can focus on a single Markov chain and observe how it evolves\n",
"\n",
"Note that the animations reuse the samples x computed above."
]
},
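{
"cell_type": "markdown",
"metadata": {},
"source": [
"Both views are slices of the same tensor. The sketch below uses a hypothetical placeholder `x_demo` filled with random values, with the same (n_steps, n_chains, n_visible) layout as x, to show which slice feeds each visualization. The names `all_chains`, `single_chain` and `last_step` are illustrative.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"n_steps, n_chains, n_visible = 1000, 500, 2\n",
"\n",
"# Placeholder with the same layout as x: (time step, chain, visible unit).\n",
"x_demo = torch.randn(n_steps, n_chains, n_visible)\n",
"\n",
"# View 1: all chains at each of the first 20 time steps.\n",
"# Frame t of the animation is x_demo[t], a cloud of n_chains points.\n",
"all_chains = x_demo[:20]\n",
"\n",
"# View 2: the trajectory of the first chain over the first 300 time steps.\n",
"single_chain = x_demo[:300, 0]\n",
"\n",
"# Samples of every chain at the last time step, as used for a final scatter plot.\n",
"last_step = x_demo[-1]\n",
"\n",
"print(all_chains.shape, single_chain.shape, last_step.shape)\n",
"```\n",
"\n",
"Indexing the first dimension selects time steps and indexing the second selects chains, so swapping the two indices would animate a different quantity without raising any error."
]
},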
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import matplotlib.lines as mlines\n",
"from matplotlib import animation, rc\n",
"from IPython.display import HTML\n",
"\n",
"\n",
"# this is just code for building the animation\n",
"# http://louistiao.me/posts/notebooks/embedding-matplotlib-animations-in-jupyter-as-interactive-javascript-widgets/\n",
"def animate_mc(animation_data):\n",
"    # First set up the figure, the axis, and the plot element we want to animate\n",
"    fig, ax = plt.subplots()\n",
"\n",
"    ax.set_xlim((-1.5, 1.5))\n",
"    ax.set_ylim((-1.5, 1.5))\n",
"\n",
"    sct = ax.scatter([], [], lw=2)\n",
"\n",
"    # initialization function: plot the background of each frame\n",
"    def init():\n",
"        # use an explicit empty (0, 2) array: a bare [] is rejected by some matplotlib versions\n",
"        sct.set_offsets(np.empty((0, 2)))\n",
"        return (sct,)\n",
"\n",
"    # animation function. This is called sequentially\n",
"    def animate(i):\n",
"        #fig = plt.figure()\n",
"        #ax = plt.axes(xlim=(-1, 1), ylim=(-1, 1))\n",
"\n",
"        data = animation_data[i]\n",
"\n",
"        sct.set_offsets(data)\n",
"        #sct.set_array(dev_labels)\n",
"\n",
"        return sct,\n",
"\n",
"    # call the animator. blit=True means only re-draw the parts that have changed.\n",
"    anim = animation.FuncAnimation(fig, animate, init_func=init,\n",
"                                   frames=len(animation_data), interval=20, blit=True)\n",
"\n",
"    return anim.to_jshtml()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# show the samples of each Markov chain at each time step\n",
"# you can click on - to slow down the animation\n",
"HTML(animate_mc(x[:20].numpy()))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# this is just code for building the animation\n",
"# http://louistiao.me/posts/notebooks/embedding-matplotlib-animations-in-jupyter-as-interactive-javascript-widgets/\n",
"def animate_mc_persitent(animation_data):\n",
"    # First set up the figure, the axis, and the plot element we want to animate\n",
"    fig, ax = plt.subplots()\n",
"\n",
"    ax.set_xlim((-1.5, 1.5))\n",
"    ax.set_ylim((-1.5, 1.5))\n",
"\n",
"    sct = ax.scatter([], [], lw=2)\n",
"\n",
"    # initialization function: plot the background of each frame\n",
"    def init():\n",
"        # use an explicit empty (0, 2) array: a bare [] is rejected by some matplotlib versions\n",
"        sct.set_offsets(np.empty((0, 2)))\n",
"        return (sct,)\n",
"\n",
"    # animation function. This is called sequentially\n",
"    def animate(i):\n",
"        #fig = plt.figure()\n",
"        #ax = plt.axes(xlim=(-1, 1), ylim=(-1, 1))\n",
"\n",
"        # keep all the points visited so far, so that the trajectory accumulates\n",
"        data = animation_data[:i]\n",
"\n",
"        sct.set_offsets(data)\n",
"        #sct.set_array(dev_labels)\n",
"\n",
"        return sct,\n",
"\n",
"    # call the animator. blit=True means only re-draw the parts that have changed.\n",
"    anim = animation.FuncAnimation(fig, animate, init_func=init,\n",
"                                   frames=len(animation_data), interval=20, blit=True)\n",
"\n",
"    return anim.to_jshtml()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# animation of the evolution of a single chain\n",
"# we show the first 300 steps of the first chain\n",
"HTML(animate_mc_persitent(x[:300,0].numpy()))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}