Deep Learning with GANs in PyTorch: WGAN-GP

Generative Adversarial Networks (GANs) are a powerful class of generative models proposed by Ian Goodfellow in 2014. A GAN consists of two neural networks, a Generator and a Discriminator, which learn by competing with each other. The Generator tries to create new data that resembles real data, while the Discriminator attempts to distinguish real data from generated data. As both networks improve, the Generator ultimately becomes capable of reliably producing very realistic data.

This article explains Wasserstein GAN with Gradient Penalty (WGAN-GP), a variant of GAN, and demonstrates how to implement it using PyTorch. WGAN-GP replaces the standard GAN objective with the Wasserstein distance and adds a gradient penalty term to the critic's loss to improve training stability.

1. Basic Structure of GAN

The basic structure of GAN is as follows.

  • Generator: Receives random noise as input and generates fake data.
  • Discriminator: Receives real data and fake data produced by the Generator as input and judges whether each sample is real or fake.

The learning process of GAN consists of the following two steps.

  1. The Generator produces fake data from random noise.
  2. The Discriminator learns to distinguish real data from the generated data, while the Generator learns to fool it.

This process is repeated, and both networks improve together; a minimal training step is sketched below. However, traditional GANs often suffer from training instability and mode collapse, which has prompted research into various approaches for more stable training.
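
To make this concrete, here is a minimal, hypothetical sketch of one alternating training step for a standard GAN. It assumes models named generator and discriminator, optimizers optimizer_d and optimizer_g, and a batch real_images already exist (all names are illustrative); the original GAN trains both networks with a binary cross-entropy objective.

# Hypothetical sketch of one standard GAN training step (names are illustrative).
import torch
import torch.nn as nn

bce = nn.BCEWithLogitsLoss()
real_labels = torch.ones(real_images.size(0), 1)
fake_labels = torch.zeros(real_images.size(0), 1)

# Step 1: the Generator produces fake data from random noise.
z = torch.randn(real_images.size(0), 100)
fake_images = generator(z)

# Step 2: the Discriminator learns to separate real from generated data...
optimizer_d.zero_grad()
d_loss = bce(discriminator(real_images), real_labels) + \
         bce(discriminator(fake_images.detach()), fake_labels)
d_loss.backward()
optimizer_d.step()

# ...while the Generator learns to make the Discriminator call its fakes real.
optimizer_g.zero_grad()
g_loss = bce(discriminator(fake_images), real_labels)
g_loss.backward()
optimizer_g.step()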

2. Introduction to WGAN-GP

WGAN addresses the inherent problems of GAN by introducing the Wasserstein distance (also known as the Earth Mover's distance). The Wasserstein distance gives a smoother, better-behaved measure of the difference between two distributions, which makes network training easier. The key idea of WGAN is to replace the Discriminator with a “critic”: rather than classifying samples as real or fake, the critic assigns each sample an unbounded realness score, and both networks are updated with the Wasserstein loss instead of the binary cross-entropy loss of the original GAN, as sketched below.
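
In code, the Wasserstein losses are remarkably simple. As a minimal sketch (assuming critic, generator, a real batch real_images, and a noise batch z already exist), the critic widens the score gap between real and generated samples, and the generator raises the critic's score on its own output:

# Critic: maximize E[D(real)] - E[D(fake)], i.e. minimize the negation.
critic_loss = -(torch.mean(critic(real_images)) - torch.mean(critic(generator(z).detach())))

# Generator: maximize E[D(fake)], i.e. minimize the negation.
generator_loss = -torch.mean(critic(generator(z)))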

The Wasserstein formulation requires the critic to be 1-Lipschitz. The original WGAN enforced this by clipping the critic's weights, which can limit its capacity; WGAN-GP instead adds a Gradient Penalty (GP) that softly encourages the Lipschitz condition, further improving training stability. The Gradient Penalty is defined as follows:

GP = λ · E_x̂[ (‖∇_x̂ D(x̂)‖₂ − 1)² ]

Here, λ is the penalty coefficient (the WGAN-GP paper uses λ = 10), D is the critic, and x̂ is sampled uniformly along straight lines between pairs of real and generated samples. The penalty pushes the norm of the critic's gradient toward 1 wherever it is evaluated, softly enforcing the 1-Lipschitz condition. This approach lets WGAN-GP overcome the instability of traditional GANs and train much more stably.
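
For reference, the full objective minimized by the critic combines the Wasserstein loss with the penalty above:

L_critic = E[D(x̃)] − E[D(x)] + λ · E_x̂[ (‖∇_x̂ D(x̂)‖₂ − 1)² ]

where x is a real sample, x̃ = G(z) is a generated sample, and x̂ = ε · x + (1 − ε) · x̃ with ε ~ U[0, 1] is a random point on the line between them. The generator, in turn, simply minimizes −E[D(x̃)].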

3. Implementing WGAN-GP in PyTorch

Now, let’s implement WGAN-GP using PyTorch in the following steps:

  1. Install necessary libraries and load the dataset
  2. Define the Generator and Discriminator models
  3. Implement the WGAN-GP training loop
  4. Visualize the results

3.1 Installing Libraries and Loading the Dataset

First, install the necessary libraries and import the modules we will use. (The MNIST dataset itself is downloaded and loaded just before training, in Section 3.3.)

!pip install torch torchvision matplotlib

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

3.2 Defining the Generator and Discriminator Models

Define the Generator and Critic models. The Generator takes a random noise vector as input and transforms it into a 28×28 image, while the critic takes an image as input and outputs a single unbounded score indicating how real it looks. Unlike a standard Discriminator, the critic has no sigmoid at its output.

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # Map a 100-dimensional noise vector to a 28x28 image.
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(True),
            nn.Linear(256, 512),
            nn.ReLU(True),
            nn.Linear(512, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 28 * 28),
            nn.Tanh()  # outputs in [-1, 1]; the data must be normalized to match
        )

    def forward(self, z):
        return self.model(z).reshape(-1, 1, 28, 28)

class Critic(nn.Module):
    def __init__(self):
        super(Critic, self).__init__()
        # Map a flattened 28x28 image to a single unbounded score.
        self.model = nn.Sequential(
            nn.Linear(28 * 28, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1)  # no sigmoid: the critic outputs a raw score
        )

    def forward(self, x):
        return self.model(x.view(-1, 28 * 28))
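
Before wiring up the training loop, a quick throwaway forward pass can confirm that the tensor shapes line up. This check is purely illustrative and not part of the final script:

# Sanity check (illustrative): verify that Generator and Critic shapes agree.
g = Generator()
c = Critic()
z = torch.randn(4, 100)   # batch of 4 noise vectors
imgs = g(z)               # expected shape: (4, 1, 28, 28)
scores = c(imgs)          # expected shape: (4, 1)
print(imgs.shape, scores.shape)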

3.3 Implementing the WGAN-GP Training Loop

Now, let’s implement the WGAN-GP training loop. During training, the critic is updated several times (critic_iterations) for every Generator update, and the Gradient Penalty is added to the critic's loss.

def compute_gradient_penalty(critic, real_samples, fake_samples):
    # Sample one interpolation coefficient per example, on the same device as the data.
    alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=real_samples.device)
    alpha = alpha.expand_as(real_samples)
    interpolated_samples = alpha * real_samples + (1 - alpha) * fake_samples
    interpolated_samples.requires_grad_(True)

    d_interpolated = critic(interpolated_samples)

    # Gradient of the critic's output with respect to the interpolated inputs.
    gradients = torch.autograd.grad(outputs=d_interpolated, inputs=interpolated_samples,
                                    grad_outputs=torch.ones_like(d_interpolated),
                                    create_graph=True, retain_graph=True)[0]

    # Flatten per-sample gradients before taking the norm, then penalize deviation from 1.
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

generator = Generator().to(device)
critic = Critic().to(device)

learning_rate = 0.0001   # the WGAN-GP paper uses 1e-4 with Adam
num_epochs = 100
critic_iterations = 5    # critic updates per generator update
lambda_gp = 10           # gradient penalty coefficient λ

# The WGAN-GP paper recommends Adam with betas = (0, 0.9).
optimizer_generator = optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.0, 0.9))
optimizer_critic = optim.Adam(critic.parameters(), lr=learning_rate, betas=(0.0, 0.9))

# Normalize MNIST to [-1, 1] to match the Tanh output of the Generator.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
train_data = dsets.MNIST(root='./data', train=True, transform=transform, download=True)
data_loader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True)

for epoch in range(num_epochs):
    for i, (real_images, _) in enumerate(data_loader):
        real_images = real_images.to(device)

        # Train the critic several times per generator step.
        for _ in range(critic_iterations):
            optimizer_critic.zero_grad()

            # Generate fake images
            z = torch.randn(real_images.size(0), 100, device=device)
            fake_images = generator(z)

            # Get critic scores (detach fakes so the generator is not updated here)
            real_validity = critic(real_images)
            fake_validity = critic(fake_images.detach())
            gradient_penalty = compute_gradient_penalty(critic, real_images, fake_images.detach())

            # Critic loss: maximize the score gap, subject to the gradient penalty
            critic_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty
            critic_loss.backward()
            optimizer_critic.step()

        # Update generator: maximize the critic's score on fresh fake images
        optimizer_generator.zero_grad()
        z = torch.randn(real_images.size(0), 100, device=device)
        fake_images = generator(z)
        validity = critic(fake_images)
        generator_loss = -torch.mean(validity)
        generator_loss.backward()
        optimizer_generator.step()

    if epoch % 10 == 0:
        print(f"Epoch: {epoch}/{num_epochs}, Critic Loss: {critic_loss.item():.4f}, Generator Loss: {generator_loss.item():.4f}")
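
Since 100 epochs can take a while, you may also want to save checkpoints periodically. As a small optional addition (the file names here are arbitrary), the following could sit inside the epoch loop next to the logging code:

# Optional: checkpoint both models every 10 epochs (file names are illustrative).
if epoch % 10 == 0:
    torch.save(generator.state_dict(), f'generator_epoch_{epoch}.pth')
    torch.save(critic.state_dict(), f'critic_epoch_{epoch}.pth')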

3.4 Visualizing Results

Finally, let’s visualize the generated images. This is a good way to verify how well the Generator has learned during the training process.

def show_generated_images(generator, num_images=25):
    z = torch.randn(num_images, 100, device=device)
    with torch.no_grad():  # no gradients are needed for visualization
        generated_images = generator(z).cpu().numpy()

    # Rescale from the Tanh range [-1, 1] back to [0, 1] for display.
    generated_images = (generated_images + 1) / 2

    plt.figure(figsize=(5, 5))
    for i in range(num_images):
        plt.subplot(5, 5, i + 1)
        plt.imshow(generated_images[i][0], cmap='gray')
        plt.axis('off')
    plt.show()

show_generated_images(generator)

4. Conclusion

This article discussed WGAN-GP, a variant of GAN, and demonstrated how to implement it using PyTorch. WGAN-GP offers the advantage of more stable training by leveraging the Wasserstein distance and Gradient Penalty. These GAN-based models can be applied in various fields, including image generation, image translation, and style transfer.

As deep learning continues to advance, GANs and their variants keep attracting attention, and further developments are highly anticipated. I encourage you to try out your own projects using GANs and WGAN-GP!