Deep Learning PyTorch Course: Variational Autoencoder (VAE)

Deep learning is a field of machine learning that uses neural networks to learn patterns from data. In this article, we will take an in-depth look at the Variational Autoencoder (VAE).

1. What is an Autoencoder?

An autoencoder is an unsupervised learning model that learns to compress input data and then reconstruct it. It consists of two parts, an encoder and a decoder; a minimal sketch of this structure follows the list below.

  • Encoder: Maps input data to a latent space.
  • Decoder: Reconstructs the original input data from the latent representation.
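
To make this structure concrete, here is a minimal autoencoder sketch in PyTorch. The layer sizes (a 784-dimensional input and a 32-dimensional bottleneck) are illustrative assumptions, not fixed requirements.

import torch.nn as nn

class SimpleAutoencoder(nn.Module):
    def __init__(self, input_dim=784, bottleneck_dim=32):
        super().__init__()
        # Encoder: compress the input into a low-dimensional latent code
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, bottleneck_dim),
        )
        # Decoder: reconstruct the input from the latent code
        self.decoder = nn.Sequential(
            nn.Linear(bottleneck_dim, 128),
            nn.ReLU(),
            nn.Linear(128, input_dim),
            nn.Sigmoid(),  # keep outputs in [0, 1], like normalized pixel values
        )

    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z)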

1.1 The Training Process of an Autoencoder

An autoencoder is trained to minimize the difference between the input data and its reconstruction. To do this, a loss function measures the difference between the original input and the decoder's output; the Mean Squared Error (MSE) loss is commonly used, as in the example below.
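
As a rough illustration, one training step with an MSE reconstruction loss could look like the following. The batch shape, learning rate, and the use of the SimpleAutoencoder sketch from above are assumptions for this example.

import torch
import torch.nn as nn

model = SimpleAutoencoder()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

x = torch.rand(128, 784)             # a dummy batch of flattened images
reconstruction = model(x)
loss = criterion(reconstruction, x)  # compare the reconstruction with the input itself

optimizer.zero_grad()
loss.backward()
optimizer.step()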

2. Variational Autoencoder (VAE)

The Variational Autoencoder extends the traditional autoencoder by estimating the probability distribution of the input data. As a generative model, a VAE can generate new data by sampling from this distribution.

2.1 Components of VAE

VAE consists of the following two main components:

  • Latent Variable: When encoding input data, the encoder outputs the mean (μ) and standard deviation (σ) (in practice, the log-variance) that parameterize the distribution of the latent variable; sampling from this distribution is sketched right after this list.
  • Reconstruction Loss: Measures the difference between the output generated by the decoder and the original input.
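
To illustrate how the encoder's outputs define the latent distribution, here is a small sketch of drawing a latent sample from μ and the log-variance. The dummy values and the 20-dimensional latent size are assumptions; the same trick is implemented as reparameterize in section 3.3.

import torch

mu = torch.zeros(1, 20)        # mean predicted by the encoder (dummy values)
logvar = torch.zeros(1, 20)    # log-variance predicted by the encoder (dummy values)

std = torch.exp(0.5 * logvar)  # convert log-variance to standard deviation
eps = torch.randn_like(std)    # noise from a standard normal distribution
z = mu + eps * std             # a sample from N(mu, sigma^2), differentiable w.r.t. mu and logvar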

2.2 Loss Function

VAE’s loss function can be divided into two parts:

  • Reconstruction Loss: Measures the loss between the actual input and the reconstructed input.
  • Kullback-Leibler Divergence: Measures the difference between the latent distribution q(z|x) and the prior, typically a standard normal distribution.

The VAE is trained by maximizing the evidence lower bound (ELBO); equivalently, the loss to minimize is its negative:

L = -E_{q(z|x)}[log p(x|z)] + D_{KL}(q(z|x) || p(z))
    

Where:

  • E_{q(z|x)}[log p(x|z)]: The expected log-likelihood of reconstructing the input x from the latent variable z, with the expectation taken over the encoder distribution q(z|x); its negative is the reconstruction loss.
  • D_{KL}: The Kullback-Leibler divergence, which measures how far the approximate posterior q(z|x) is from the prior p(z); its closed form for Gaussian distributions is given below.
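
When the encoder outputs a diagonal Gaussian q(z|x) = N(μ, σ²) and the prior is a standard normal p(z) = N(0, I), the KL term has a well-known closed form, which is exactly what we will implement in section 3.4:

D_{KL}(q(z|x) || p(z)) = -1/2 Σ_j (1 + log σ_j² - μ_j² - σ_j²)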

3. Implementing VAE with PyTorch

Now that we understand the basic components and loss function of the Variational Autoencoder, let’s implement VAE using PyTorch.

3.1 Install Libraries

pip install torch torchvision matplotlib
    

3.2 Prepare Dataset

We will train a VAE on handwritten digits using the MNIST dataset, which consists of 28×28 pixel grayscale images.

import torch
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))  # flatten each 28x28 image into a 784-dimensional vector
])

mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=128, shuffle=True)
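
As a quick sanity check (not part of the original pipeline), we can look at one batch to confirm that each image has been flattened to a 784-dimensional vector:

images, labels = next(iter(train_loader))
print(images.shape)  # torch.Size([128, 784])
print(labels.shape)  # torch.Size([128])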
    

3.3 Define Model

To construct the Variational Autoencoder, we define the encoder and decoder classes, plus a VAE class that combines them with the reparameterization trick.

import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, input_dim, latent_dim):
        super(Encoder, self).__init__()
        self.fc1 = nn.Linear(input_dim, 400)
        self.fc21 = nn.Linear(400, latent_dim)  # Mean
        self.fc22 = nn.Linear(400, latent_dim)  # Log Variance
        
    def forward(self, x):
        h1 = torch.relu(self.fc1(x))
        mu = self.fc21(h1)
        logvar = self.fc22(h1)
        return mu, logvar


class Decoder(nn.Module):
    def __init__(self, latent_dim, output_dim):
        super(Decoder, self).__init__()
        self.fc1 = nn.Linear(latent_dim, 400)
        self.fc2 = nn.Linear(400, output_dim)
        
    def forward(self, z):
        h2 = torch.relu(self.fc1(z))
        return torch.sigmoid(self.fc2(h2))  # outputs in [0, 1], matching normalized pixel values and the BCE loss
    
class VAE(nn.Module):
    def __init__(self, input_dim, latent_dim):
        super(VAE, self).__init__()
        self.encoder = Encoder(input_dim, latent_dim)
        self.decoder = Decoder(latent_dim, input_dim)
        
    def reparameterize(self, mu, logvar):
        # Reparameterization trick: z = mu + sigma * eps keeps sampling differentiable w.r.t. mu and logvar
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def forward(self, x):
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        return self.decoder(z), mu, logvar
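
Before training, a quick shape check (not in the original text) confirms that the model wires together; it reuses the torch import and the VAE class defined above, with the same dimensions used in the training section below:

model_check = VAE(784, 20)
dummy = torch.rand(4, 784)                  # a fake batch of 4 flattened images
recon, mu, logvar = model_check(dummy)
print(recon.shape, mu.shape, logvar.shape)  # torch.Size([4, 784]) torch.Size([4, 20]) torch.Size([4, 20])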
    

3.4 Define Loss Function

We define the VAE loss as the sum of the reconstruction loss (binary cross-entropy) and the KL divergence, implemented with PyTorch's built-in functions.

def vae_loss(recon_x, x, mu, logvar):
    # Reconstruction loss: binary cross-entropy between the reconstruction and the original input
    BCE = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')
    # KL divergence between q(z|x) = N(mu, sigma^2) and the standard normal prior (closed form)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD
    

3.5 Train the Model

We train the model with a standard training loop: for each batch we compute the loss, backpropagate, and update the weights with the optimizer.

import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VAE(784, 20).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

model.train()
for epoch in range(10):
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = vae_loss(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
    
    print(f'Epoch {epoch+1}, Loss: {train_loss / len(train_loader.dataset)}')
    

3.6 Check Results

Once the training is complete, we can use the model to generate new data and see how similar it is to the training data.

import matplotlib.pyplot as plt

def visualize_results(model, num_images=10):
    model.eval()
    with torch.no_grad():
        z = torch.randn(num_images, 20).to(device)  # sample latent vectors from the standard normal prior
        sample = model.decoder(z).cpu()             # decode each latent vector into a 784-dimensional image
        sample = sample.view(num_images, 1, 28, 28)
        
    plt.figure(figsize=(10, 1))
    for i in range(num_images):
        plt.subplot(1, num_images, i + 1)
        plt.imshow(sample[i].squeeze(), cmap='gray')
        plt.axis('off')
    plt.show()

visualize_results(model)
    

4. Conclusion

In this tutorial, we explored the concept of the Variational Autoencoder and how to implement it using PyTorch. A VAE learns the latent distribution of the data and can generate new samples from it, making it useful for generative modeling tasks such as generating images, text, and audio.

Furthermore, VAEs can be combined with other generative models, such as GANs, to build more powerful and diverse generative models. In particular, a VAE makes it easy to explore and sample from the latent space of high-dimensional data.
