Deep Learning with GAN Using PyTorch, Training Process

A Generative Adversarial Network (GAN) is a neural network architecture introduced by Ian Goodfellow and colleagues in 2014. It consists of two competing neural networks: the Generator and the Discriminator. GANs are used mainly for image generation, transformation, and reconstruction, and are especially popular for creating realistic high-resolution photographs and artwork. In this article, we will take a detailed look at the overall structure and training process of GANs using PyTorch.

1. Structure of GAN

A GAN consists of two main components:

  • Generator (G): A network that takes a random noise vector as input and transforms it into a fake sample that resembles real data.
  • Discriminator (D): A network that takes a sample as input and outputs the probability that it is real data rather than a fake produced by the generator. Its goal is to tell real and fake samples apart as reliably as possible.

The two networks are trained against each other: the generator learns to produce increasingly plausible samples in order to fool the discriminator, while the discriminator learns to detect increasingly convincing fakes.
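For reference, the original 2014 paper formalizes this competition as a two-player minimax game. D tries to maximize the value function V (assigning high probability to real data x and low probability to generated samples G(z)), while G tries to minimize it. Here p_data is the data distribution and p_z is the noise distribution fed to the generator:

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x)\right]
  + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z))\right)\right]
```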

2. Training Process of GAN

GAN training repeats the following steps for every mini-batch of real data:

  1. A random noise vector is generated and input to the generator.
  2. The generator transforms the noise vector into a fake sample.
  3. The discriminator receives both the real data and the generated fake data as input.
  4. The discriminator predicts whether each sample is real data or fake data.
  5. The discriminator is updated via its loss function to better distinguish real data from fake data. Then the generator is updated via its own loss function so that the discriminator classifies its fake samples as real.
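Steps 4 and 5 can be expressed with binary cross-entropy (BCE), the loss used later in this article. The sketch below computes it by hand for a single real and a single fake sample; the discriminator outputs 0.9 and 0.2 are made-up values chosen only for illustration. Note that both networks use the same loss formula, only with different targets.

```python
import math

def bce(p, target):
    """Binary cross-entropy for one prediction p = D(x) in (0, 1) and a 0/1 target."""
    return -(target * math.log(p) + (1 - target) * math.log(1 - p))

# Hypothetical discriminator outputs for one real and one fake sample
d_on_real = 0.9   # D(x): fairly sure the real sample is real
d_on_fake = 0.2   # D(G(z)): fairly sure the fake sample is fake

# Discriminator update: real -> target 1, fake -> target 0
d_loss = bce(d_on_real, 1) + bce(d_on_fake, 0)

# Generator update: it wants D to output 1 for its fake sample
g_loss = bce(d_on_fake, 1)

print(f"d_loss = {d_loss:.4f}, g_loss = {g_loss:.4f}")  # d_loss = 0.3285, g_loss = 1.6094
```

PyTorch's nn.BCELoss computes the same quantity and, by default, averages it over the batch.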

3. Implementing GAN Using PyTorch

Now let’s implement a GAN in PyTorch. We will build a GAN that generates handwritten digits using the MNIST dataset.

3.1 Install Required Libraries

pip install torch torchvision matplotlib
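After installing, a quick check confirms that PyTorch imports correctly and shows whether a CUDA GPU is visible. The code in this article runs on the CPU as written; using a GPU would additionally require moving the models and every batch to the device with .to(device), which this sketch only prepares for.

```python
import torch

# Pick a CUDA GPU if PyTorch can see one, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"PyTorch {torch.__version__}, running on {device}")
```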

3.2 Prepare the Dataset

Let’s load the MNIST dataset. In PyTorch, data can be easily downloaded through the torchvision library.

import torch
from torchvision import datasets, transforms

# Convert images to tensors in [0, 1], then normalize them to [-1, 1] to match the generator's Tanh output range
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Download and load the dataset
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)

3.3 Implementing the Generator and Discriminator

Now let’s define the two core components of the GAN: the generator and the discriminator. We will use simple fully connected networks for both. The generator maps a 100-dimensional noise vector to a 28 × 28 image, and the discriminator flattens an image and outputs the probability that it is real.

import torch.nn as nn

# Generator model
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 28 * 28),
            nn.Tanh()
        )
    
    def forward(self, z):
        return self.model(z).view(-1, 1, 28, 28)

# Discriminator model
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        return self.model(x)
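To get a feel for the model sizes, we can count the trainable parameters by hand. Only the nn.Linear layers have parameters (a weight matrix plus a bias vector); ReLU, LeakyReLU, Tanh, Sigmoid and Flatten have none. The helper name below is hypothetical:

```python
def linear_stack_params(widths):
    """Weights + biases of consecutive nn.Linear layers with the given layer widths."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(widths, widths[1:]))

g_params = linear_stack_params([100, 256, 512, 1024, 28 * 28])  # Generator
d_params = linear_stack_params([28 * 28, 512, 256, 1])          # Discriminator
print(g_params, d_params)  # 1486352 533505
```

You can cross-check these numbers with sum(p.numel() for p in Generator().parameters()) and the same expression for Discriminator(). The generator is roughly three times larger than the discriminator in this setup.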

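3.4 Setting Up the Loss Function and Optimizers

To train the GAN, we define a loss function and an optimizer for each network. Both networks use binary cross-entropy (nn.BCELoss), but with opposite targets: the discriminator is trained to output 1 for real images and 0 for fake ones, while the generator is trained so that the discriminator outputs 1 for its fakes. Each network gets its own Adam optimizer with a learning rate of 0.0002 and β1 = 0.5, a common choice for GANs.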
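Why does the training loop train the generator with real labels (target 1) instead of directly minimizing log(1 - D(G(z))) from the minimax objective? Goodfellow et al. pointed out that early in training, when the discriminator rejects fakes with high confidence, log(1 - D(G(z))) saturates and gives the generator almost no gradient. Maximizing log D(G(z)), which is BCE with target 1, avoids this. The sketch below compares both gradients with respect to the discriminator's pre-sigmoid output a (the logit), assuming D(G(z)) = 0.01:

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def saturating_grad(a):
    # d/da of log(1 - sigmoid(a)): the original minimax generator loss
    return -sigmoid(a)

def non_saturating_grad(a):
    # d/da of -log(sigmoid(a)): BCE with target 1, as used in the training loop
    return -(1.0 - sigmoid(a))

a = math.log(0.01 / 0.99)  # logit at which D(G(z)) = 0.01 (discriminator is confident)
print(round(sigmoid(a), 4), round(saturating_grad(a), 4), round(non_saturating_grad(a), 4))
```

With the minimax form the gradient is only about -0.01, while the non-saturating form gives about -0.99, so the generator keeps learning even when the discriminator is winning.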

# Create the models first: each optimizer needs the parameters of a model instance
generator = Generator()
discriminator = Discriminator()

criterion = nn.BCELoss()
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

3.5 GAN Training Loop

Now let’s implement the loop for training the GAN. For every mini-batch, we first update the discriminator on real and fake images and then update the generator, repeating this for a specified number of epochs.

def train_gan(generator, discriminator, train_loader, num_epochs=100):
    for epoch in range(num_epochs):
        for real_images, _ in train_loader:
            batch_size = real_images.size(0)

            # Target labels: 1 for real images, 0 for generated ones
            real_labels = torch.ones(batch_size, 1)
            fake_labels = torch.zeros(batch_size, 1)

            # ---- Train the discriminator ----
            optimizer_D.zero_grad()
            outputs = discriminator(real_images)
            d_loss_real = criterion(outputs, real_labels)

            z = torch.randn(batch_size, 100)
            fake_images = generator(z)
            # detach() keeps this loss from backpropagating into the generator
            outputs = discriminator(fake_images.detach())
            d_loss_fake = criterion(outputs, fake_labels)

            d_loss = d_loss_real + d_loss_fake
            d_loss.backward()
            optimizer_D.step()

            # ---- Train the generator ----
            # The generator wants the discriminator to label its fakes as real (1)
            optimizer_G.zero_grad()
            outputs = discriminator(fake_images)
            g_loss = criterion(outputs, real_labels)
            g_loss.backward()
            optimizer_G.step()

        # Losses of the last mini-batch in this epoch
        print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}')

# Start GAN training (uses the generator and discriminator created above)
train_gan(generator, discriminator, train_loader)
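Before committing to a long training run, it can help to smoke-test the update logic on random data. The sketch below uses tiny stand-in networks (8-dimensional noise and 16-dimensional "images" are arbitrary, hypothetical sizes) and performs one discriminator step and one generator step with the same loss, optimizer settings and detach() pattern as the loop above. It checks that the discriminator step leaves the generator untouched and that the generator step actually changes the generator's weights.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Tiny stand-ins for the article's models (hypothetical sizes)
G = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 16), nn.Tanh())
D = nn.Sequential(nn.Linear(16, 32), nn.LeakyReLU(0.2), nn.Linear(32, 1), nn.Sigmoid())

criterion = nn.BCELoss()
opt_G = torch.optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
opt_D = torch.optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))

real = torch.rand(4, 16) * 2 - 1  # random "real" batch in [-1, 1]
real_labels, fake_labels = torch.ones(4, 1), torch.zeros(4, 1)
g_before = [p.detach().clone() for p in G.parameters()]

# Discriminator step: fakes are detached, so no gradient reaches G
opt_D.zero_grad()
fake = G(torch.randn(4, 8))
d_loss = criterion(D(real), real_labels) + criterion(D(fake.detach()), fake_labels)
d_loss.backward()
opt_D.step()
g_unchanged = all(torch.equal(old, new) for old, new in zip(g_before, G.parameters()))

# Generator step: target 1 for the fakes, gradients flow through D into G
opt_G.zero_grad()
g_loss = criterion(D(fake), real_labels)
g_loss.backward()
opt_G.step()
g_changed = any(not torch.equal(old, new) for old, new in zip(g_before, G.parameters()))

print(f"d_loss={d_loss.item():.4f} g_loss={g_loss.item():.4f} "
      f"G untouched by D step: {g_unchanged}, G updated by G step: {g_changed}")
```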

4. Visualizing Results

After training, we can visualize the generated images to check the results.

import matplotlib.pyplot as plt

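def show_generated_images(generator, num_images=25):
    # The 5 x 5 subplot grid below assumes num_images <= 25
    # Sample noise and generate images without tracking gradients
    with torch.no_grad():
        z = torch.randn(num_images, 100)
        generated_images = generator(z).numpy()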
    
    plt.figure(figsize=(10, 10))
    for i in range(num_images):
        plt.subplot(5, 5, i + 1)
        plt.imshow(generated_images[i][0], cmap='gray')
        plt.axis('off')
    plt.show()

# Show generated images
show_generated_images(generator)

5. Conclusion

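In this post, we explored the basic concepts of GANs and a simple implementation using PyTorch. GANs are powerful generative models that can be applied to a wide range of data generation problems. However, GAN training can be unstable: the discriminator may overpower the generator, or the generator may collapse to producing only a few kinds of output (mode collapse). As a result, careful hyperparameter tuning and stabilization techniques are often needed. Exploring more advanced GAN architectures, such as DCGAN (convolutional GAN) and WGAN (Wasserstein GAN), is a natural next step.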
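As a first step toward DCGAN-style models, the fully connected discriminator can be swapped for a small convolutional one. The class below is a hypothetical sketch, not part of the original code: two strided convolutions downsample a 28 × 28 grayscale image to 7 × 7 feature maps, followed by a linear layer and a sigmoid. It keeps the same input and output shapes as Discriminator, so it could be passed to train_gan in its place.

```python
import torch
import torch.nn as nn

class ConvDiscriminator(nn.Module):
    """Hypothetical convolutional drop-in for the fully connected Discriminator."""
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),    # -> 64 x 14 x 14
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # -> 128 x 7 x 7
            nn.LeakyReLU(0.2),
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 1),
            nn.Sigmoid(),  # probability that the input image is real
        )

    def forward(self, x):
        return self.model(x)

d = ConvDiscriminator()
scores = d(torch.randn(8, 1, 28, 28))
print(scores.shape)  # torch.Size([8, 1])
```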

You now know how a basic GAN works and how to implement one in PyTorch. Building on this, I encourage you to experiment with your own examples!