Deep Learning with PyTorch, Introduction to GAN

1. Introduction to GAN (Generative Adversarial Network)

A GAN (Generative Adversarial Network) is a deep learning model proposed by Ian Goodfellow and colleagues in 2014.
It consists of two neural networks, a Generator and a Discriminator, that compete with each other.
The Generator creates fake data, while the Discriminator decides whether a given sample is real or fake.
The two networks are trained together, and each one improves in response to the other.

The core idea of GANs is “Adversarial Training”.
The Generator keeps producing more convincing fake data so that the Discriminator can no longer tell real and fake data apart.
The Discriminator, in turn, learns to judge more precisely whether a sample comes from the real dataset or from the Generator.
This competitive structure is the defining feature of GANs, which are used in fields such as image generation, video generation, and text generation.
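
The competition described above is usually written as a two-player minimax game. The formulation below follows the standard objective from the original 2014 paper. The symbols $p_{\text{data}}$ (the real data distribution) and $p_z$ (the noise distribution) are notation introduced here and are not used elsewhere in this tutorial.

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
```

The Discriminator D tries to maximize V by assigning high probability to real samples and low probability to generated ones, while the Generator G tries to minimize it.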

2. Structure and Learning Process of GANs

The learning process of GANs consists of the following stages:

  1. Data Collection: GANs typically need a large number of samples drawn from a real dataset.
  2. Training the Generator: The Generator takes a random noise vector (z) as input and generates fake images (or other data).
  3. Training the Discriminator: The Discriminator takes real images and fake images created by the Generator as input and predicts whether each one is real or fake.
  4. Loss Function Calculation: A loss is calculated for each network to measure how well it is doing.
    The Generator’s goal is to deceive the Discriminator, while the Discriminator’s goal is to classify both real and generated images correctly.
  5. Model Update: Based on these losses, the Generator and the Discriminator each update their parameters using an optimization algorithm.
  6. Iteration: Steps 2 to 5 are repeated so that both networks keep improving.

In this way, the Generator gradually produces better images, and the Discriminator becomes better at distinguishing them.
When training goes well, repeating this process brings the Generator to a point where it can produce realistic data.
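
Steps 2 to 5 can be sketched as a single training step on toy 2-D data. This is a minimal illustration, not the MNIST model built later: the network sizes, the noise dimension of 4, and the helper name `gan_step` are all assumptions made for this example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Tiny stand-in networks for 2-D toy data (hypothetical sizes, not the MNIST models).
G = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 2))
D = nn.Sequential(nn.Linear(2, 16), nn.LeakyReLU(0.2), nn.Linear(16, 1), nn.Sigmoid())

criterion = nn.BCELoss()
opt_G = torch.optim.Adam(G.parameters(), lr=2e-4)
opt_D = torch.optim.Adam(D.parameters(), lr=2e-4)


def gan_step(real):
    """Run one Discriminator update followed by one Generator update."""
    batch = real.size(0)
    ones = torch.ones(batch, 1)    # target label for "real"
    zeros = torch.zeros(batch, 1)  # target label for "fake"

    # Step 2: generate fake samples from noise.
    fake = G(torch.randn(batch, 4))

    # Steps 3-5 for the Discriminator: real -> 1, fake -> 0.
    # detach() keeps this loss from sending gradients into G.
    opt_D.zero_grad()
    d_loss = criterion(D(real), ones) + criterion(D(fake.detach()), zeros)
    d_loss.backward()
    opt_D.step()

    # Steps 4-5 for the Generator: try to make D output 1 for fakes.
    opt_G.zero_grad()
    g_loss = criterion(D(fake), ones)
    g_loss.backward()
    opt_G.step()

    return d_loss.item(), g_loss.item()


real_batch = torch.randn(32, 2) * 0.5 + 2.0  # toy "real" data centered at (2, 2)
d_loss, g_loss = gan_step(real_batch)
print(f"d_loss={d_loss:.4f}, g_loss={g_loss:.4f}")
```

Step 6 corresponds to calling `gan_step` in a loop over many batches.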

3. How to Implement GAN

Now, let’s implement a GAN using PyTorch.
In this example, we will create a simple GAN that works with MNIST, a dataset of handwritten digits.
MNIST consists of 70,000 grayscale images of the digits 0 to 9, each 28 x 28 pixels (60,000 for training and 10,000 for testing).
Our goal is to generate new images that look like these digits.

3.1. Install Required Libraries

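First, we need to install PyTorch and the other necessary libraries.
You can install the required packages using the command below.
The leading ! runs the command from a Jupyter or Colab notebook cell; omit it when typing the command in a regular shell.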

!pip install torch torchvision matplotlib

3.2. Load and Preprocess the Dataset

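Now, we will load the MNIST dataset, convert the images to tensors, and prepare them for training.
ToTensor() scales pixel values to the range 0 to 1, and Normalize((0.5,), (0.5,)) then maps them to -1 to 1, which matches the Tanh output range of the Generator defined later.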


import torch
from torchvision import datasets, transforms

# Data transformation settings
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Download and load MNIST dataset
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
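
The arithmetic behind the normalization step can be checked on a few hand-picked pixel values. The sketch below reproduces the formula (x - mean) / std with plain tensor operations; the helper names `normalize` and `denormalize` are hypothetical and are not part of torchvision.

```python
import torch


def normalize(x, mean=0.5, std=0.5):
    # Same arithmetic as transforms.Normalize((0.5,), (0.5,)): (x - mean) / std
    return (x - mean) / std


def denormalize(x, mean=0.5, std=0.5):
    # Inverse mapping, useful before plotting generated images
    return x * std + mean


pixels = torch.tensor([0.0, 0.25, 0.5, 1.0])  # ToTensor() output lies in [0, 1]
scaled = normalize(pixels)

print(scaled.tolist())               # [-1.0, -0.5, 0.0, 1.0]
print(denormalize(scaled).tolist())  # [0.0, 0.25, 0.5, 1.0]
```

The inverse mapping is what brings Generator outputs back into the 0 to 1 range for display.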

3.3. Define the Generator and Discriminator of GAN

We will define the Generator and Discriminator of the GAN.
The Generator takes random noise as input to generate images, while the Discriminator determines whether the given image is real or fake.


import torch.nn as nn

# Generator definition
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 28 * 28),
            nn.Tanh() # Normalize the output to -1 ~ 1
        )

    def forward(self, z):
        return self.model(z).view(-1, 1, 28, 28)

# Discriminator definition
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid() # Normalize the output to 0 ~ 1
        )

    def forward(self, img):
        return self.model(img)
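
Before training, it can help to confirm that the two networks fit together: the Generator should turn a batch of 100-dimensional noise vectors into images of shape (batch, 1, 28, 28), and the Discriminator should return one score per image. The sketch below is one way to check this. The helper `check_shapes` and the small stand-in networks are assumptions made so that the example runs on its own; in practice you would pass in Generator() and Discriminator() from above.

```python
import torch
import torch.nn as nn


def check_shapes(generator, discriminator, latent_dim=100, batch_size=4):
    """Return the output shapes of both networks for a random noise batch."""
    z = torch.randn(batch_size, latent_dim)
    with torch.no_grad():  # no gradients needed for a shape check
        fake = generator(z)
        score = discriminator(fake)
    return tuple(fake.shape), tuple(score.shape)


# Small stand-ins with the same input/output interface as the models above
# (hypothetical; substitute Generator() and Discriminator() in practice).
toy_G = nn.Sequential(nn.Linear(100, 28 * 28), nn.Tanh(), nn.Unflatten(1, (1, 28, 28)))
toy_D = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 1), nn.Sigmoid())

img_shape, score_shape = check_shapes(toy_G, toy_D)
print(img_shape, score_shape)  # (4, 1, 28, 28) (4, 1)
```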

3.4. Set Loss Function and Optimization Algorithm

The GAN objective is split into two losses, one for the Generator and one for the Discriminator.
Both are computed with binary cross-entropy (BCELoss), and each network gets its own Adam optimizer.


import torch.optim as optim

# Initialize models
generator = Generator()
discriminator = Discriminator()

# Set loss function and optimization algorithms
criterion = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
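
To see what BCELoss measures, consider a single Discriminator output of 0.9, meaning "90% likely real". The values below are a small numerical check, with the prediction 0.9 chosen arbitrarily for illustration: the loss is -log(0.9) if the target label is 1 and -log(1 - 0.9) if the target label is 0.

```python
import math

import torch
import torch.nn as nn

criterion = nn.BCELoss()

# Pretend the Discriminator assigned probability 0.9 ("real") to one sample.
pred = torch.tensor([[0.9]])

loss_if_real = criterion(pred, torch.ones(1, 1)).item()   # -log(0.9)
loss_if_fake = criterion(pred, torch.zeros(1, 1)).item()  # -log(1 - 0.9)

print(round(loss_if_real, 4))  # 0.1054
print(round(loss_if_fake, 4))  # 2.3026
```

A confident correct prediction gives a small loss, and a confident wrong one gives a large loss. This is why the Generator is trained against the "real" label for its fake images: its loss only becomes small when the Discriminator is fooled.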

3.5. Train the GAN

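Now, let’s train the GAN.
For every batch, the Discriminator is updated first on real and generated images, and then the Generator is updated.
Every 100 epochs, the loop prints the current losses and displays a grid of generated images.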


import matplotlib.pyplot as plt
import torchvision  # required for torchvision.utils.make_grid used below

def train_gan(num_epochs):
    for epoch in range(num_epochs):
        for i, (imgs, _) in enumerate(train_loader):
            # Real images and target labels (1 = real, 0 = fake)
            real_imgs = imgs
            real_labels = torch.ones(real_imgs.size(0), 1)
            fake_labels = torch.zeros(real_imgs.size(0), 1)

            # Train the Discriminator
            optimizer_D.zero_grad()
            outputs = discriminator(real_imgs)
            d_loss_real = criterion(outputs, real_labels)
            d_loss_real.backward()

            z = torch.randn(real_imgs.size(0), 100)
            fake_imgs = generator(z)
            outputs = discriminator(fake_imgs.detach())
            d_loss_fake = criterion(outputs, fake_labels)
            d_loss_fake.backward()
            optimizer_D.step()

            # Train the Generator
            optimizer_G.zero_grad()
            outputs = discriminator(fake_imgs)
            g_loss = criterion(outputs, real_labels)
            g_loss.backward()
            optimizer_G.step()

        if epoch % 100 == 0:
            print(f'Epoch [{epoch}/{num_epochs}], d_loss: {d_loss_real.item() + d_loss_fake.item():.4f}, g_loss: {g_loss.item():.4f}')

            # Display generated images
            with torch.no_grad():
                generated_images = generator(torch.randn(64, 100)).detach().cpu()
                plt.figure(figsize=(10, 10))
                plt.imshow(torchvision.utils.make_grid(generated_images, nrow=8, normalize=True).permute(1, 2, 0))
                plt.axis('off')
                plt.show()

train_gan(num_epochs=1000)
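
Once training has finished, new digits can be sampled from the Generator alone. The sketch below shows one way to do this while also using a GPU when one is available. The helper `sample_images` and the untrained stand-in network are assumptions made so that the example runs on its own; in practice you would pass in the trained `generator`.

```python
import torch
import torch.nn as nn

# Use the GPU if PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def sample_images(generator, n, latent_dim=100):
    """Generate n images on `device` and return them on the CPU in [0, 1]."""
    generator = generator.to(device)
    generator.eval()  # note: the model stays in eval mode afterwards
    with torch.no_grad():  # sampling does not need gradients
        z = torch.randn(n, latent_dim, device=device)
        imgs = generator(z).cpu()
    return imgs * 0.5 + 0.5  # undo Normalize((0.5,), (0.5,))


# Untrained stand-in with the same interface as Generator (hypothetical).
toy_G = nn.Sequential(nn.Linear(100, 28 * 28), nn.Tanh(), nn.Unflatten(1, (1, 28, 28)))

samples = sample_images(toy_G, n=16)
print(samples.shape)  # torch.Size([16, 1, 28, 28])
```

Training on a GPU follows the same pattern: move both models to `device` and create the image batches, labels, and noise tensors on that device as well.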

4. Conclusion

GANs are powerful generative models that are applied in various fields.
In this tutorial, we explored how to implement a simple GAN using PyTorch.
By learning through the competition between the Generator and the Discriminator, GANs can generate high-quality data.
For practical applications, more advanced variants (e.g., conditional GAN, StyleGAN, etc.) can be used to improve performance.

In the future, we will discuss more advanced GAN architectures and their applications.
GANs are still under active research, and new methods are continuously being introduced, so it is worth keeping an eye on developments in this area.