Deep Probabilistic Generative Models - Variational Auto-Encoders

For this lab exercise, you have to use Pytorch. If you don't know Pytorch or you don't know how to use it, please start by reading the following link:

When you use a Pytorch function, it is important to carefully read the documentation to understand how the different functions work, what kind of transformations they apply, etc. Searching on stackoverflow is not enough. :) You have to take the time to read it carefully to understand what you are doing.

The most important tip: keep track of the dimensions of your tensors and what information they contain!

For examples, you can look at the GMM optimization via gradient descent notebook from the first course and at the MNIST classification example below. Don't just run them, try to understand them! Could you recode them from scratch without internet/example access?

0. Preprocessing

Import the libraries we will use. You can download the file dataset_loader from the course website.

In [ ]:
import torch
import torch.nn as nn
import torch.optim
import torch.nn.functional as F

import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import math

# To load the data we will use the script of Gaetan Marceau Caron
# You can download it from the course website and move it to the same directory that contains this ipynb file
import dataset_loader

%matplotlib inline

Next, we download and preprocess the MNIST dataset, see http://yann.lecun.com/exdb/mnist/ . Our observed random variables will follow Bernoulli distributions, so we need to binarize the dataset, i.e. each pixel can be either black or white.

In [ ]:
# Download mnist dataset 
if("mnist.pkl.gz" not in os.listdir(".")):
    !wget http://deeplearning.net/data/mnist/mnist.pkl.gz

# if you have it somewhere else, you can comment the lines above
# and overwrite the path below
mnist_path = "./mnist.pkl.gz"

# the dataset contains 3 splits (train/dev/test),
# each one containing two vectors (pixels and classes)
(train_data_pixels, train_data_classes), \
(dev_data_pixels, dev_data_classes), _ = dataset_loader.load_mnist(mnist_path)
In [ ]:
# to binarize, we set pixels with a value >= 0.5 to 1, and to 0 otherwise
train_data_pixels = 1*(train_data_pixels >= 0.5)
dev_data_pixels = 1*(dev_data_pixels >= 0.5)
In [ ]:
# dataset_loader returns numpy arrays,
# we convert them directly to pytorch tensors
train_data_pixels = torch.from_numpy(train_data_pixels).float()
train_data_classes = torch.from_numpy(train_data_classes).long()

dev_data_pixels = torch.from_numpy(dev_data_pixels).float()
dev_data_classes = torch.from_numpy(dev_data_classes).long()
In [ ]:
# there are 50000 images, each image is of size 28x28
# each image has been reshaped as a vector of size 784
print("Shape of the pixel tensor:", train_data_pixels.shape)
print("Shape of the class tensor:", train_data_classes.shape)
In [ ]:
# we will use this function to print images generated by our network
# the parameter images must be a tensor of shape (n images, 784)
def display_images(images):
    # convert to numpy tensor
    images = images.detach().numpy()

    for i in range(images.shape[0]):
        plt.imshow(images[i].reshape(28,28), cmap='Greys')
        plt.show()
        
# display the first 5 images in the dataset
display_images(train_data_pixels[:5])

1. MNIST classification with Pytorch

In this first part, you need to study the code to understand how Pytorch works. :) The example is a simple classifier for MNIST data.

TODO

  • can you write mathematically the function computed by this network?
  • what are the dimensions of each parameter?

It is very important that you are able to write this down and to understand which tensor operations it performs. This is the root of machine learning and especially deep learning. To help you, draw the matrices/vectors!
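As a sanity check for these questions: ignoring dropout, the forward pass below computes (a sketch; $W_1, b_1$ and $W_2, b_2$ denote the parameters of proj1 and proj2)

$$\hat{y} = W_2 \, \mathrm{relu}(W_1 x + b_1) + b_2$$

With the instantiation used in the training cell below, $W_1 \in \mathbb R^{200 \times 784}$, $b_1 \in \mathbb R^{200}$, $W_2 \in \mathbb R^{10 \times 200}$ and $b_2 \in \mathbb R^{10}$ (note that nn.Linear stores its weight with shape (output dim, input dim)).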

In [ ]:
# to create a neural network (or a part of a neural network),
# we create a class that extends nn.Module
class Classifier(nn.Module):
    # the constructor, parameters are self-documented by their name :)
    # input_dim: in our case it will be the size of a single input image
    # hidden_dim: the hidden representation dim
    # output_dim: the number of classes, in our case there are 10 digits
    def __init__(self, input_dim, hidden_dim, output_dim, dropout_ratio=0.5):
        super().__init__()
        
        # we have two projections in the network,
        # the first one project the input in the hidden space,
        # the second one project the hidden representation in the output space.
        # have a look at the documentation of nn.Linear!
        # nn.Linear extends nn.Module
        self.proj1 = nn.Linear(input_dim, hidden_dim)
        self.proj2 = nn.Linear(hidden_dim, output_dim)
        
        # dropout can be used to prevent overfitting,
        # however, it may be better to start training VAEs without it!
        # I just put it here for the example
        
        # let's assume you create a network with:
        # network = Classifier(10, 10, 10)
        # then, to enable dropout you must call:
        # network.train()
        # and to disable dropout:
        # network.eval()
        # basically these calls set the network
        # in training and evaluation modes, respectively
        self.dropout = nn.Dropout(dropout_ratio)
        
    # this is the function that will be called
    # when we pass data to the network to compute the output
    # shape of batch: (n samples, input_dim)
    # where n samples is the number of samples in the batch
    def forward(self, batch):
        hid = self.proj1(batch)
        
        # non-linearity, it could have been a nn.Module object stored in this module,
        # here we use the functional version
        hid = F.relu(hid)
        
        # note: dropout is used to prevent overfitting
        hid = self.dropout(hid)
        
        # return the output, i.e. logits of each class for each sample in the batch
        # shape: (n samples, output_dim)
        return self.proj2(hid)
In [ ]:
n_epoch = 20
batch_size = 100

# build the network!
network = Classifier(28*28, 200, 10)

# optimizer that will be used to update the parameters!
optimizer = torch.optim.Adam(network.parameters())

all_epoch_losses = list()
all_epoch_accuracies = list()
for epoch in range(n_epoch):
    # we keep track of the loss at each epoch
    epoch_loss = 0
    
    # Note: one trick often used is to shuffle the data at each epoch,
    # we won't do that here.
    # if you want to do it, be careful:
    # you need to shuffle both train_data_pixels and train_data_classes
    # so that they match, i.e. you cannot shuffle them independently!
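    # for example, one standard way to do it (an illustration, not used here):
    # perm = torch.randperm(train_data_pixels.shape[0])
    # shuffled_pixels = train_data_pixels[perm]
    # shuffled_classes = train_data_classes[perm]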

    network.train() # set the network in train mode (i.e. enable dropout)
    
    # loop over the data
    for i in range(0, train_data_pixels.shape[0], batch_size):
        # IMPORTANT:
        # we must always call zero_grad() because each call to .backward()
        # will accumulate the gradient!
        optimizer.zero_grad()
        
        # batch of data
        # shape: (n samples, 784)
        # do you understand why n samples can be different from batch_size?
        batch = train_data_pixels[i:i+batch_size]
        
        # compute the weights associated with each class for each picture
        # logits shape: (n samples, 10)
        logits = network(batch)
        
        # compute the loss
        # look at the documentation of the function!
        loss = F.cross_entropy(
            logits,
            train_data_classes[i:i+batch_size],
            reduction="mean"
        )
        
        # to accumulate to the epoch loss,
        # we need to call loss.item() so the value
        # is detached from the computational graph,
        # i.e. loss.item() is a python float, not a pytorch object.
        # can you understand why we multiply by batch.shape[0] ?
        epoch_loss += loss.item() * batch.shape[0]
        
        loss.backward() # compute the gradient
        optimizer.step() # update parameters of the model
        
    # at the end of each epoch, we eval on the dev data
    network.eval()
    n_dev_correct = 0
    
    # we don't need gradient computation during evaluation,
    # so we enclose computation with torch.no_grad(),
    # this makes computation more efficient
    with torch.no_grad():
        for i in range(0, dev_data_pixels.shape[0], batch_size):
            batch = dev_data_pixels[i:i+batch_size]
            logits = network(batch)
            
            # compute the class of maximum weight for each sample,
            # look at the documentation of argmax!
            # pred shape: (n samples,)
            pred = logits.argmax(1)
            
            # compare prediction with gold classes
            n_dev_correct += (pred == dev_data_classes[i:i+batch_size]).sum().item()
    
    all_epoch_losses.append(epoch_loss)
    all_epoch_accuracies.append(n_dev_correct / dev_data_classes.shape[0]) 
    print(
        epoch, ":\t",
        epoch_loss,
        "\t / dev accuracy:\t",
        n_dev_correct / dev_data_classes.shape[0],
        flush=True
    )
In [ ]:
# print graphs to show how the loss and the accuracy evolve during training,
# for example to see if we overfit the training data.
# Note that MNIST is a really simple dataset, so it's not really interesting in this case :)
plt.plot(np.arange(n_epoch), all_epoch_losses)
plt.show()

plt.plot(np.arange(n_epoch), all_epoch_accuracies)
plt.show()

2. VAE with continuous latent space and binary observed space

The first VAE that we will develop is based on the following generative story:

  1. $z \sim p(z)$
  2. $x \sim p(x | z ; \theta)$

where the latent random variable $Z$ takes values in $\mathbb R^n$. The prior distribution $p(z)$ is a multivariate Gaussian where each coordinate is independent. We fix the mean and variance of each coordinate to 0 and 1, respectively. The conditional distribution $p(x | z ; \theta)$ is parameterized by a neural network. The random variables $X$ are $m$ independent Bernoulli random variables.

Note: this kind of VAE will be quite bad at generating MNIST pictures. Therefore, when you do your experiments, you should both generate pictures and show the distribution $p(x | z ; \theta)$ as a picture. This is a well-known problem of VAEs; you can try to play with the network architecture and the parameters to improve generation. However, hyperparameter optimization is not the goal of this lab exercise. :) Another solution would be to use Gaussian random variables for the output distribution, but we won't do that here.
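As a reminder from the course, training maximizes the evidence lower bound (ELBO):

$$\log p(x ; \theta) \geq \mathbb E_{q(z | x ; \phi)}\left[ \log p(x | z ; \theta) \right] - \mathrm{KL}\left( q(z | x ; \phi) \,\|\, p(z) \right)$$

where $q(z | x ; \phi)$ is the approximate posterior computed by the encoder. In the training loop below, the expectation is estimated with a single Monte-Carlo sample.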

KL divergence

The first function we have to implement is the one that computes the KL divergence with the prior. You can find the formula in appendix B of the following paper: https://arxiv.org/pdf/1312.6114.pdf

WARNING: In the paper they give the formula for the -KL, and here we want the KL, so don't forget to put a minus sign in front of their formula. :)
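With our standard normal prior and a diagonal Gaussian approximate posterior, the formula (with the minus sign added) reduces to:

$$\mathrm{KL}\left( q(z | x) \,\|\, p(z) \right) = -\frac{1}{2} \sum_{j=1}^{n} \left( 1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2 \right)$$

where the sum runs over the latent dimensions, which is consistent with the expected return shape (n batch,) and with the test values below.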

In [ ]:
# the shapes of mu and log_sigma_squared are: (n batch, n latent)
# the return vector must be of shape: (n batch,)
def KL_with_gaussian_prior(mu, log_sigma_squared):
    #
    #
    # TODO TODO TODO TODO TODO
    #
    #
    
# test for shape
mu = torch.rand((10, 30))
log_sigma_squared = torch.rand((10, 30))

if KL_with_gaussian_prior(mu, log_sigma_squared).shape != (10, ):
    raise RuntimeError("Invalid return dim!")
else:
    print("return dim is ok.")
    
# test for return values
# the first result should be 0 as the posterior is equal to the prior
print("Should be 0:", KL_with_gaussian_prior(torch.tensor([[0.]]), torch.tensor([[0.]])).item())
print("Should be 0.5:", KL_with_gaussian_prior(torch.tensor([[1.]]), torch.tensor([[0.]])).item())
print("Should be 2.6945...:", KL_with_gaussian_prior(torch.tensor([[1.]]), torch.tensor([[2.]])).item())

Network

Next you need to code the encoder and the decoder. The encoder should return a tuple of tensors (mu, log_sigma_squared); both are unconstrained, so they should each be the output of a linear projection. The name log_sigma_squared means that the standard deviation can be computed via sigma=exp(0.5 * log_sigma_squared). Be sure you understand this.
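If this relationship is unclear, the following standalone check may help (an illustration, not part of the solution):

In [ ]:
log_sigma_squared = torch.tensor([0., 2.])
# exp(0.5 * log(sigma^2)) = exp(log(sigma)) = sigma
sigma = torch.exp(0.5 * log_sigma_squared)
print(sigma)  # tensor([1.0000, 2.7183])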

The structure of the encoder should be as follows:

  • one projection and non linearity (e.g. relu) to compute a hidden representation of the input,
  • two projections to compute mu and log_sigma_squared from the same hidden representation.

The structure of the decoder should be similar to the classifier above. Similarly to the encoder, the decoder should output unconstrained values, i.e. it does not directly return the mu parameters of the Bernoulli distributions, but values in $\mathbb R$. In your code, if you need to get the distribution, you can apply the sigmoid to these unconstrained values to obtain the mu parameters.

  • one projection and non linearity (e.g. relu) to compute a hidden representation of the input,
  • one projection to the output space

The decoder has two additional functions:

  • sample_distributions: it samples from the prior and returns the conditional output distributions (so here it means that you must apply the sigmoid before returning)
  • sample_images: call sample_distributions and then sample images from the output distributions.

To sample from a Bernoulli you can use the following function: https://pytorch.org/docs/stable/generated/torch.bernoulli.html To sample from the Gaussian prior, you can create an empty tensor with torch.empty((n_samples, input_dim)) and then fill it with standard normal samples using .normal_(mean=0., std=1.), see https://pytorch.org/docs/stable/generated/torch.Tensor.normal_.html

This second function changes the tensor in-place (hence the _ at the end of the name)
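As a quick standalone illustration of the two pixel-sampling modes needed by sample_images (the probability values here are arbitrary):

In [ ]:
probs = torch.tensor([0.9, 0.1, 0.5])
print(torch.bernoulli(probs))  # random binary sample from each Bernoulli
print((probs >= 0.5).float())  # "argmax" mode: most probable value of each variable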

In [ ]:
class GaussianEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        #
        #
        # TODO TODO TODO TODO TODO
        #
        #
        
    def forward(self, batch):
        #
        #
        # TODO TODO TODO TODO TODO
        #
        #
In [ ]:
class GaussianPriorDecoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        
        #
        #
        # TODO TODO TODO TODO TODO
        #
        #
        
    def forward(self, batch):
        #
        #
        # TODO TODO TODO TODO TODO
        #
        #
    
    # n_samples: number of "image output distributions" to compute
    def sample_distributions(self, n_samples):
        # first you must sample from the prior,
        # then you must pass the data through this network, i.e. self(z)
        
        # shape of z: (n_samples, input_dim)
        z = # TODO....
        #
        #
        # TODO TODO TODO TODO TODO
        #
        #
        
    
    # sample images
    def sample_images(self, n_samples, argmax=False):
        # compute the output distributions for each image
        # probs shape: (n_samples, output_dim)
        probs = self.sample_distributions(n_samples)
        
        # then sample output pixels
        # there are two modes:
        # argmax=False: sample from the output distribution
        # argmax=True: instead of sampling, return the most probable value for each random variable
        # (check how the images are converted to binary images above!)
   
        #
        #
        # TODO TODO TODO TODO TODO
        #
        #

Training loop

I did most of the work, you just need to implement one thing: the reparameterization trick! :)
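As a reminder, the reparameterization trick rewrites a sample from $\mathcal N(\mu, \sigma^2)$ as a deterministic function of the parameters and an auxiliary noise variable:

$$z = \mu + \sigma \odot \epsilon, \qquad \epsilon \sim \mathcal N(0, I)$$

so that gradients can flow through $\mu$ and $\sigma$ while the randomness is confined to $\epsilon$.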

In [ ]:
dim_input = 28*28
latent_dim = 2 # we use only 2 dimensions so we can plot them!
hidden_dim = 400
n_epoch = 20
batch_dim = 256

encoder = GaussianEncoder(dim_input, hidden_dim, latent_dim)
decoder = GaussianPriorDecoder(latent_dim, hidden_dim, dim_input)

# we must pass both networks' parameters to the optimizer.
# sometimes it's also useful to use two different optimizers,
# for example if we want to use different learning rates.
optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()))

for epoch in range(n_epoch):
    epoch_reconstruction = 0
    epoch_kl = 0

    decoder.train()
    encoder.train()
    for i in range(0, train_data_pixels.shape[0], batch_dim): 
        optimizer.zero_grad()
        
        batch = train_data_pixels[i:i+batch_dim]
        
        # compute parameters of the approximate posterior
        mu, log_sigma_squared = encoder(batch)
        
        # for MC sampling of the reconstruction loss,
        # we use one sample from the approximate posterior.
        # you need to use the reparameterization trick for this!
        # in order to sample from the N(0, 1) distribution,
        # you can use:
        # torch.empty_like(mu).normal_(mean=0.,std=1.)
        # (check the documentation if you don't understand this!)
        

        #
        #
        # TODO TODO TODO TODO TODO
        #
        #
        
        z =  # TODO
        
        # reconstruction_logits shape: (n samples, output_dim)
        # note that the output of the network is in R,
        # to get Bernoulli parameters you must use the sigmoid.
        # For the reconstruction loss, you can either use:
        # - F.binary_cross_entropy_with_logits (takes the unconstrained logits)
        # - F.binary_cross_entropy (takes probabilities, i.e. outputs of the sigmoid)
        # Read the doc in order to understand the difference!
        # Warning: use the reduction="none" parameter
        reconstruction_logits = decoder(z)
        reconstruction_loss = F.binary_cross_entropy_with_logits(
            reconstruction_logits,
            batch,
            reduction="none"
        )
        
        kl_div = KL_with_gaussian_prior(mu, log_sigma_squared)
        
        # loss shape: (batch dim,)
        # do you understand why we sum along this dimension only?
        # there is a minus sign in front of the reconstruction term because
        # the loss formula is -log p(x|z), but in the ELBO it appears as +log p(x|z)
        elbo = -reconstruction_loss.sum(1) - kl_div
        
        epoch_reconstruction += -reconstruction_loss.sum().item()
        epoch_kl += kl_div.sum().item()
        
        # we want to maximize the ELBO,
        # but Pytorch does minimization,
        # so we define our loss as -ELBO
        # and we average over the instances in the batch
        loss = -elbo.mean()
        loss.backward()
        
        # optimization trick :)
        # gradient clipping, so we don't take update steps that are too big
        torch.nn.utils.clip_grad_norm_(encoder.parameters(), 5)
        torch.nn.utils.clip_grad_norm_(decoder.parameters(), 5)
        
        optimizer.step()
        
    print(
        epoch, ":\t",
        epoch_reconstruction / train_data_pixels.shape[0],
        "\t",
        epoch_kl / train_data_pixels.shape[0]
    )

We can now visualize the output distributions from random prior samples and generate new images.

In [ ]:
decoder.eval()

with torch.no_grad():
    # here we display the output distributions p(x|z)
    # (the Bernoulli parameters) as grayscale images
    samples = decoder.sample_distributions(5)
    display_images(samples)
In [ ]:
decoder.eval()

with torch.no_grad():
    # you can try true samples,
    # or use argmax
    samples = decoder.sample_images(5, argmax=False)
    display_images(samples)

Using the approximate posterior, we can visualize the latent space by assigning one color per digit. What we display is the mean value. For a well-trained model, we should see all points close to 0 and each class should be well delimited, i.e. there must be clusters of the same color.

If your samples and latent space are ugly, you could try to improve your model by:

  • changing the learning hyperparameters,
  • adding/removing dropout,
  • changing the network structure (number and size of layers...)

Note that if you want to visualize a latent space when its dimension is greater than two, you can project it in 2 dimensions using PCA (it's already implemented in scikit-learn!)
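For example, a minimal sketch (assuming points is a (n samples, latent dim) array of latent means with more than 2 dimensions):

In [ ]:
from sklearn.decomposition import PCA

points_2d = PCA(n_components=2).fit_transform(points)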

In [ ]:
encoder.eval()

# array that will contain all latent points (the posterior means)
points = np.empty((train_data_pixels.shape[0], 2))
with torch.no_grad():
    for i in range(0, train_data_pixels.shape[0], batch_dim): 
        batch = train_data_pixels[i:i+batch_dim]
        mu, _ = encoder(batch)
        points[i:i+batch_dim] = mu.numpy()

plt.scatter(
    points[:,0], points[:, 1],
    # colormap is between 0 and 1, and we have 10 classes
    # so we just divide by 10 :)
    # https://matplotlib.org/3.1.1/tutorials/colors/colormaps.html
    c=matplotlib.cm.get_cmap("tab10")(train_data_classes / 10.)
)
plt.show()

3. VAE with binary latent space and binary observed space

The second VAE that we will develop is based on the same generative story:

  1. $z \sim p(z)$
  2. $x \sim p(x | z ; \theta)$

but where the latent variables $Z$ are independent Bernoulli random variables. See the last video from the course to understand how this changes the model.

Architecture

How many values should the encoder return? What is the difference with the previous encoder?

In [ ]:
class BernoulliEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()

        #
        #
        # TODO TODO TODO TODO TODO
        #
        #
        
    def forward(self, batch):
        #
        #
        # TODO TODO TODO TODO TODO
        #
        #
In [ ]:
class BernoulliPriorDecoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        
        #
        #
        # TODO TODO TODO TODO TODO
        #
        #
        
    def forward(self, batch):
        #
        #
        # TODO TODO TODO TODO TODO
        #
        #
    
    def sample_distributions(self, n_samples):
        #
        #
        # TODO TODO TODO TODO TODO
        #
        #
    
    def sample_images(self, n_samples, argmax=False):
        #
        #
        # TODO TODO TODO TODO TODO
        #
        #

KL divergence

We assume the priors are Bernoullis with parameter 0.5, i.e. the probabilities of sampling 0 or 1 are equal ("fair coin" probability). What is the formula of the KL divergence? How does it simplify?

Hint: you should see an entropy term in it that you can compute via https://pytorch.org/docs/stable/distributions.html#torch.distributions.bernoulli.Bernoulli.entropy
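If you work out the formula for a single latent variable with posterior parameter $q$ and a fair-coin prior, you should find something of the form:

$$\mathrm{KL}\left( \mathrm{Bern}(q) \,\|\, \mathrm{Bern}(0.5) \right) = q \log \frac{q}{0.5} + (1 - q) \log \frac{1 - q}{0.5} = \log 2 - H(q)$$

where $H(q)$ is the entropy of the Bernoulli with parameter $q$; summing over the latent dimensions gives the expected (n batch,) shape, and the test values below are consistent with this form.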

In [ ]:
# the shape of probs is: (n batch, n latent)
# the return vector must be of shape: (n batch,)
import math

def KL_with_bernoulli_prior(probs):
    #
    #
    # TODO TODO TODO TODO TODO
    #
    #
    
# test for shape
probs = torch.rand((10, 30))

if KL_with_bernoulli_prior(probs).shape != (10, ):
    raise RuntimeError("Invalid return dim!")
else:
    print("return dim is ok.")
    
# test for return values
# the first result should be 0 as the posterior is equal to the prior
print("Should be 0:", KL_with_bernoulli_prior(torch.tensor([[0.5, 0.5]])).item())
print("Should be 0.693...:", KL_with_bernoulli_prior(torch.tensor([[1.]])).item())

Training loop

This time you are on your own! :) Remember that you need to use the Score Function Estimator trick! (see the last video of the course)
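As a reminder, the score function estimator relies on the identity

$$\nabla_\phi \, \mathbb E_{q(z ; \phi)}[f(z)] = \mathbb E_{q(z ; \phi)}\left[ f(z) \, \nabla_\phi \log q(z ; \phi) \right]$$

and subtracting a baseline $b$ that does not depend on $z$ (i.e. replacing $f(z)$ with $f(z) - b$) leaves the expectation unchanged while reducing the variance of the estimator.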

You need to use a baseline to reduce the variance of the estimator. Here, we will rely on the average reconstruction value from all previous updates. You can easily store and compute it using the running average trick:

In [ ]:
n1 = 5
n2 = 10
n3 = 1
n4 = 12

print((n1 + n2 + n3 + n4) / 4)

running_average = 0

# after the first parameter update
n_updates = 1
running_average += (n1 - running_average) / n_updates

# after the second parameter update
n_updates += 1
running_average += (n2 - running_average) / n_updates

# etc etc
n_updates += 1
running_average += (n3 - running_average) / n_updates

n_updates += 1
running_average += (n4 - running_average) / n_updates

# should be equal to the previous printed value!
print(running_average)
In [ ]:
dim_input = 28*28
latent = 64
hidden_dim = 400
n_epoch = 50
batch_dim = 256

encoder = BernoulliEncoder(dim_input, hidden_dim, latent)
decoder = BernoulliPriorDecoder(latent, hidden_dim, dim_input)
optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()))

running_avg = 0  # to control the variance of the SFE
n_updates = 0
for epoch in range(n_epoch):
    epoch_reconstruction = 0
    epoch_kl = 0
    
    decoder.train()
    encoder.train()
    for i in range(0, train_data_pixels.shape[0], batch_dim): 
        optimizer.zero_grad()
        
        batch = train_data_pixels[i:i+batch_dim]
        
        #
        #
        # TODO TODO TODO TODO TODO
        #
        #
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(encoder.parameters(), 5)
        torch.nn.utils.clip_grad_norm_(decoder.parameters(), 5)
        optimizer.step()
        n_updates += 1
        
        # update the running_avg
        #
        #
        # TODO TODO TODO TODO TODO
        #
        #

        
    print(
        epoch, ":\t",
        epoch_reconstruction / train_data_pixels.shape[0],
        "\t",
        epoch_kl / train_data_pixels.shape[0],
    )

We can now visualize the trained model. :)

In [ ]:
decoder.eval()

with torch.no_grad():
    samples = decoder.sample_distributions(5)
    display_images(samples)
In [ ]:
decoder.eval()

with torch.no_grad():
    # you can try true samples,
    # or use argmax
    samples = decoder.sample_images(5, argmax=True)
    display_images(samples)

4. Turning a Deterministic Auto-Encoder into a generative model

Last exercise! Here I want to show you a trick to turn a deterministic model into a probabilistic generative model. This is a two-step process.

First, you train an auto-encoder - not a VAE, just a plain auto-encoder. The encoder takes an image and projects it in the latent space; here we will use a 2-dimensional latent space. It is a deterministic computation: the output of the encoder is not the parameters of the approximate posterior distribution $q(z|x)$, it is a point in the $\mathbb R^2$ space. The decoder reconstructs the input image from the encoder output. Therefore, during training, the main difference is that you don't have a KL divergence term! :)

Second, we turn this deterministic model into a probabilistic generative model. To do this, we look at where each image of the training set is mapped in the latent space. You can train a Gaussian Mixture Model that will generate points that look like the latent representations of the training set! Then, to generate a new data point, you can:

  1. sample a cluster from the GMM: $y \sim p(y)$
  2. sample a point from the bivariate Gaussian associated with this cluster: $z \sim p(z|y)$
  3. use the decoder to sample an image: $x \sim p(x | z)$

To fit and sample from the GMM, use the scikit-learn library! https://scikit-learn.org/stable/modules/generated/sklearn.mixture.GaussianMixture.html
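For reference, a minimal sketch of the fit-and-sample step (assuming points is a (n samples, 2) array of encoder outputs, computed as in the visualization cell below; 20 components is an arbitrary choice):

In [ ]:
from sklearn.mixture import GaussianMixture

# fit the mixture on the latent points of the training set
gmm = GaussianMixture(n_components=20)
gmm.fit(points)

# steps 1 and 2 of the generative story: sample clusters, then latent points
z_new, _ = gmm.sample(5)
z_new = torch.from_numpy(z_new).float()
# step 3: pass z_new through your decoder to sample images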

Architecture and training loop of the auto-encoder

In [ ]:
# You are on your own! good luck :)

We can now visualize the latent space of the deterministic auto-encoder. What does it look like compared to the VAE one? Does it look like we can fit it with a GMM?

In [ ]:
encoder.eval()

# array that will contain all latent points
points = np.empty((train_data_pixels.shape[0], 2))
with torch.no_grad():
    for i in range(0, train_data_pixels.shape[0], batch_dim): 
        batch = train_data_pixels[i:i+batch_dim]
        # the deterministic encoder directly returns a point in the latent space
        z = encoder(batch)
        points[i:i+batch_dim] = z.numpy()

plt.scatter(
    points[:,0], points[:, 1],
    # colormap is between 0 and 1, and we have 10 classes
    # so we just divide by 10 :)
    # https://matplotlib.org/3.1.1/tutorials/colors/colormaps.html
    c=matplotlib.cm.get_cmap("tab10")(train_data_classes / 10.)
)
plt.show()

Turning the deterministic auto-encoder into a generative model

Todo:

  1. fit the points from the encoder with a GMM (for example with 20 clusters, called components in scikit-learn): https://scikit-learn.org/stable/modules/generated/sklearn.mixture.GaussianMixture.html
  2. Sample many points from this GMM and plot them (see the sketch below): does the resulting point cloud look like the one above? (or, in other words, did the GMM correctly fit the latent space?)
  3. Sample new data from the model! What do the samples look like compared to the VAE ones? What does the output distribution look like? Does it work?

Note: it's an unsupervised problem, do not use class information in this part!
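For step 2, a possible sketch (assuming gmm was fitted as in the sketch above):

In [ ]:
z_gen, _ = gmm.sample(2000)
plt.scatter(z_gen[:, 0], z_gen[:, 1], s=2)
plt.show()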

In [ ]:
# I'll let you do this alone, again...