Deep Learning with GAN Using PyTorch, First Deep Neural Network

1. What is GAN?

Generative Adversarial Networks (GANs) are a class of deep learning models proposed by Ian Goodfellow and colleagues in 2014. A GAN consists of two neural networks:
a generator and a discriminator. The generator tries to create fake data, while the discriminator tries to determine whether a given sample is real or fake.
These two networks compete against each other during training, hence the term “adversarial.” Training needs no labels, so the process is unsupervised, and as it proceeds the generator produces data that increasingly resembles real data.
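
The competition described above is commonly written as the minimax objective from the original GAN paper. The notation is an assumption introduced here for illustration: $G$ is the generator, $D$ is the discriminator (its output is read as the probability that the input is real), $p_{\text{data}}$ is the distribution of real data, and $p_z$ is the noise distribution.

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
```

The discriminator tries to make both terms large by scoring real samples near 1 and generated samples near 0, while the generator tries to make the second term small by pushing $D(G(z))$ toward 1.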

2. Structure of GAN

The structure of GAN operates as follows:

  • Generator Network: Takes a random noise vector as input and generates fake images.
  • Discriminator Network: Takes an image as input and outputs the probability that it is real rather than generated.
  • During training, the generator learns to produce images that the discriminator cannot reliably tell apart from real ones, while the discriminator learns to catch them. Each network's progress forces the other to improve.

3. How GAN Works

The training process of GAN consists of the following iterative steps:

  1. Training the Discriminator: The discriminator receives real images and fake images produced by the generator as input and updates its parameters to classify these images correctly.
  2. Training the Generator: The generator's images are passed through the discriminator, and the generator updates its parameters so that the discriminator classifies those images as real.
  3. This process is repeated, allowing both networks to become progressively stronger against each other.
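
Steps 1 and 2 can be read as alternately minimizing two losses. The sketch below uses assumed notation ($G$ for the generator, $D$ for the discriminator, $x$ for a real image, $z$ for a noise vector) and writes the generator loss in the commonly used non-saturating form, which is what labeling fake images as "real" under binary cross-entropy amounts to. It is one common formulation, not the only one.

```latex
\begin{aligned}
L_D &= -\,\mathbb{E}_{x}\big[\log D(x)\big] - \mathbb{E}_{z}\big[\log\big(1 - D(G(z))\big)\big] \\
L_G &= -\,\mathbb{E}_{z}\big[\log D(G(z))\big]
\end{aligned}
```

In step 1 only the discriminator's parameters change to reduce $L_D$. In step 2 only the generator's parameters change to reduce $L_G$.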

4. Implementing GAN in PyTorch

Now, let’s implement a GAN in PyTorch. Our goal is to generate handwritten digit images using the MNIST dataset.

4.1 Installing Required Libraries

        pip install torch torchvision matplotlib
    

4.2 Preparing the Dataset

The MNIST dataset can be downloaded through PyTorch’s torchvision library. The code below loads the data and rescales pixel values from [0, 1] to [-1, 1], which matches the range of the Tanh output used by the generator later on.

        
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Data transformation settings
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load MNIST dataset
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
        
        

4.3 Defining the Generator and Discriminator Networks

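Now we define the two networks of the GAN. Both the generator and the discriminator are simple fully connected networks: the generator maps a 100-dimensional noise vector to a 28 × 28 image, and the discriminator maps a 28 × 28 image to a single probability.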

        
import torch.nn as nn

# Define Generator Network
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 28 * 28),
            nn.Tanh()
        )

    def forward(self, z):
        return self.model(z).view(-1, 1, 28, 28)

# Define Discriminator Network
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        return self.model(img)
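
A quick shape check can catch wiring mistakes before training. The sketch below uses compact nn.Sequential stand-ins that mirror the layer sizes of the two classes above, so that it runs on its own. The helper `count_parameters` and the batch size of 8 are illustrative assumptions, not part of the tutorial code.

```python
import torch
import torch.nn as nn

# Compact stand-ins that mirror the layer sizes of the classes above
generator = nn.Sequential(
    nn.Linear(100, 256), nn.ReLU(),
    nn.Linear(256, 512), nn.ReLU(),
    nn.Linear(512, 1024), nn.ReLU(),
    nn.Linear(1024, 28 * 28), nn.Tanh(),
    nn.Unflatten(1, (1, 28, 28)),  # same effect as .view(-1, 1, 28, 28)
)
discriminator = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 512), nn.LeakyReLU(0.2),
    nn.Linear(512, 256), nn.LeakyReLU(0.2),
    nn.Linear(256, 1), nn.Sigmoid(),
)

def count_parameters(model):
    """Total number of weights and biases in a model (hypothetical helper)."""
    return sum(p.numel() for p in model.parameters())

with torch.no_grad():
    z = torch.randn(8, 100)       # batch of 8 noise vectors
    fake = generator(z)           # expected shape: (8, 1, 28, 28)
    scores = discriminator(fake)  # expected shape: (8, 1)

print(fake.shape, scores.shape)
print(count_parameters(generator), count_parameters(discriminator))
```

The same checks can be run against the Generator and Discriminator classes directly by instantiating them in place of the stand-ins.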
        
    

4.4 Setting Loss Function and Optimization Algorithm

The loss function used for both the generator and discriminator is binary cross-entropy loss. Both networks are optimized with Adam, using a learning rate of 0.0002 and betas of (0.5, 0.999).

        
import torch.optim as optim

# Define loss function and optimization algorithm
criterion = nn.BCELoss()
generator = Generator()
discriminator = Discriminator()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
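
To see what binary cross-entropy computes in each role, the following sketch compares nn.BCELoss with the formula written out by hand. The discriminator outputs in `d_out` are made-up values for illustration, and `manual_bce` is a hypothetical helper.

```python
import torch
import torch.nn as nn

criterion = nn.BCELoss()

# Hypothetical discriminator outputs for a batch of four fake images
d_out = torch.tensor([[0.9], [0.6], [0.3], [0.1]])
real_labels = torch.ones(4, 1)
fake_labels = torch.zeros(4, 1)

def manual_bce(p, y):
    """BCE written out: -[y*log(p) + (1-y)*log(1-p)], averaged over the batch."""
    return -(y * torch.log(p) + (1 - y) * torch.log(1 - p)).mean()

# Discriminator's view of fake images: penalized when it scores them high
d_loss_fake = criterion(d_out, fake_labels)

# Generator's view of the same outputs: penalized when they are scored low
g_loss = criterion(d_out, real_labels)

print(d_loss_fake.item(), g_loss.item())
```

The same discriminator outputs yield two different losses depending on the target labels, which is how a single criterion can serve both networks.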
        
    

4.5 Implementing the Training Process

Training proceeds as follows:

        
# Number of full passes over the training set
num_epochs = 50
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(train_loader):
        batch_size = images.size(0)
        
        # Set real and fake labels
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)

        # Train Discriminator
        outputs = discriminator(images)
        d_loss_real = criterion(outputs, real_labels)

        noise = torch.randn(batch_size, 100)
        fake_images = generator(noise)
        outputs = discriminator(fake_images.detach())
        d_loss_fake = criterion(outputs, fake_labels)

        d_loss = d_loss_real + d_loss_fake
        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()

        # Train Generator
        outputs = discriminator(fake_images)
        g_loss = criterion(outputs, real_labels)

        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()

        # Output training status
        if (i + 1) % 100 == 0:
            print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}')
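
The loop above calls detach() on the fake images during the discriminator step but not during the generator step. The sketch below illustrates why, using tiny stand-in networks whose sizes are arbitrary assumptions chosen to keep the demo fast.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Tiny stand-in networks (hypothetical sizes, for illustration only)
generator = nn.Sequential(nn.Linear(4, 8), nn.Tanh())
discriminator = nn.Sequential(nn.Linear(8, 1), nn.Sigmoid())
criterion = nn.BCELoss()

noise = torch.randn(16, 4)
fake = generator(noise)

# Discriminator step: detach() cuts the graph at the generator's output,
# so backward() leaves the generator's parameters without gradients
d_loss = criterion(discriminator(fake.detach()), torch.zeros(16, 1))
d_loss.backward()
g_grads_after_d_step = [p.grad for p in generator.parameters()]

# Generator step: no detach, so gradients flow through the discriminator
# and back into the generator
discriminator.zero_grad()
g_loss = criterion(discriminator(fake), torch.ones(16, 1))
g_loss.backward()
g_grads_after_g_step = [p.grad for p in generator.parameters()]

print([g is None for g in g_grads_after_d_step])
print([g is None for g in g_grads_after_g_step])
```

The generator step also leaves gradients on the discriminator's parameters. In the training loop these are cleared by optimizer_D.zero_grad() before the next discriminator update.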
        
    

4.6 Visualizing Generated Images

After training is complete, we visualize the images generated.

        
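import math
import matplotlib.pyplot as plt

def visualize_generator(num_images):
    # Sample noise and generate images without tracking gradients
    noise = torch.randn(num_images, 100)
    with torch.no_grad():
        generated_images = generator(noise)

    # Use the smallest square grid that fits num_images (5 x 5 for 25 images)
    grid_size = math.ceil(math.sqrt(num_images))

    plt.figure(figsize=(10, 10))
    for i in range(num_images):
        plt.subplot(grid_size, grid_size, i + 1)
        plt.imshow(generated_images[i][0].cpu().numpy(), cmap='gray')
        plt.axis('off')
    plt.show()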

visualize_generator(25)
        
    

5. Applications of GAN

GANs can be used in various fields beyond image generation. For example, they are employed in style transfer, image restoration, and video generation, and they have garnered significant attention in the field of artificial intelligence.
The continued development of GANs keeps opening up new possibilities for generative models.

6. Conclusion

In this tutorial, we covered the basics of implementing a GAN in PyTorch and saw how a GAN operates through working code: loading MNIST, defining the generator and discriminator, training them against each other, and visualizing the results. GANs continue to evolve and hold great potential in a range of applications.
As a next step, I recommend exploring the many variants that build on the basic GAN.