Deep Learning with GAN using PyTorch, Environment Setup

In recent years, deep learning has driven major advances in image generation, transformation, and segmentation. Among these techniques, the GAN (Generative Adversarial Network) has opened new possibilities for image generation. A GAN consists of two networks, the Generator and the Discriminator, which compete against each other and improve as a result. In this post, we give an overview of GANs and detail how to set up the environment to implement one using the PyTorch framework.

1. Overview of GAN

The GAN was proposed by Ian Goodfellow and colleagues in 2014 as a model in which two neural networks are trained against each other. The generator creates data that resembles real data, while the discriminator judges whether a given sample is real or generated. As the discriminator gets better at spotting fakes, the generator is pushed to produce more convincing samples, and vice versa.

1.1 Structure of GAN

GAN consists of the following two components:

  • Generator: Takes a random noise vector as input and generates fake data.
  • Discriminator: Determines whether the received data is real or fake.

1.2 Mathematical Definition of GAN

The goal of GAN training can be expressed as a minimax game over a value function V(D, G). The discriminator tries to maximize V, while the generator tries to minimize it:

\min_{G} \max_{D} V(D, G) = E_{x \sim p_{data}(x)}[\log D(x)] + E_{z \sim p_{z}(z)}[\log(1 - D(G(z)))]

Here, G represents the generator, D represents the discriminator, p_{data}(x) is the distribution of real data, and p_{z}(z) is the noise distribution from which the generator's input z is sampled.
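To make the value function concrete, the sketch below estimates V(D, G) from a batch of discriminator outputs and checks that the binary cross-entropy loss with labels 1 (real) and 0 (fake) is its negative. The output values and the helper name `value_function` are made up for illustration; they are not taken from a trained model.

```python
import torch
import torch.nn as nn

def value_function(d_real, d_fake):
    """Batch estimate of V(D, G) from discriminator outputs in (0, 1)."""
    return torch.log(d_real).mean() + torch.log(1.0 - d_fake).mean()

# Hypothetical discriminator outputs for 4 real and 4 generated samples
d_real = torch.tensor([0.9, 0.8, 0.95, 0.7])
d_fake = torch.tensor([0.1, 0.3, 0.2, 0.05])

v = value_function(d_real, d_fake)

# BCE with label 1 gives -mean(log D(x)); with label 0 gives -mean(log(1 - D(G(z))))
bce = nn.BCELoss()
d_loss = bce(d_real, torch.ones_like(d_real)) + bce(d_fake, torch.zeros_like(d_fake))

print(f"V(D, G) estimate:       {v.item():.4f}")
print(f"Discriminator BCE loss: {d_loss.item():.4f}")
```

Minimizing the summed BCE loss is therefore the same as maximizing V with respect to the discriminator, which is how the discriminator step is usually implemented in practice.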

2. Setting Up the PyTorch Environment

PyTorch is an open-source machine learning library that provides tools for tensor operations, automatic differentiation, and building deep learning models. The following outlines how to install PyTorch and set up the libraries needed to implement a GAN.

2.1 Installing PyTorch

PyTorch supports CUDA, allowing it to operate efficiently on NVIDIA GPUs. You can install it using the following command:

pip install torch torchvision torchaudio

If you are using CUDA, please check the official PyTorch website for the installation commands that match your environment.
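After installation, a quick check along the following lines can confirm that PyTorch imports correctly and whether a CUDA device is visible. This is a minimal sketch; the `device` variable name is our own choice, and the code in the rest of this post runs on the CPU unless you move the models and tensors to `device` yourself.

```python
import torch

# Report the installed version and whether a CUDA-capable GPU is visible
print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())

# Pick the GPU when available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create a small tensor on the chosen device as a smoke test
x = torch.randn(2, 3, device=device)
print("Tensor device:", x.device)
```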

2.2 Installing Additional Libraries

You will also need matplotlib for plotting the generated images and NumPy for array handling. Install them using the command below:

pip install matplotlib numpy

2.3 Setting Up the Basic Directory Structure

Create the project directory with the following structure:


    gan_project/
    ├── dataset/
    ├── models/
    ├── results/
    └── train.py
    

The dataset/, models/, and results/ directories store the datasets, trained models, and generated results, respectively. The train.py file contains the script for training and evaluating the GAN.
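If you prefer to create this layout from Python rather than by hand, something like the following would work. The helper name `create_project_layout` is hypothetical, and it creates train.py as an empty placeholder file.

```python
from pathlib import Path

def create_project_layout(root="gan_project"):
    """Create the dataset/, models/, and results/ folders and an empty train.py."""
    root = Path(root)
    for name in ("dataset", "models", "results"):
        (root / name).mkdir(parents=True, exist_ok=True)
    # touch() leaves an existing train.py untouched
    (root / "train.py").touch(exist_ok=True)
    return root

if __name__ == "__main__":
    project_root = create_project_layout()
    print(sorted(p.name for p in project_root.iterdir()))
```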

3. Code Examples Needed to Implement GAN

Now, let’s write the basic code to implement GAN. This code defines the generator and discriminator and includes the process of training GAN.

3.1 Defining the Model

First, we define the generator and discriminator networks. The code below builds both from simple fully connected (linear) layers, sized for 28x28 MNIST images:

import torch
import torch.nn as nn

# Generator model
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 28 * 28),  # MNIST image size
            nn.Tanh()  # Normalized to [-1, 1]
        )
    
    def forward(self, z):
        return self.model(z).view(-1, 1, 28, 28)

# Discriminator model
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()  # Normalized to [0, 1]
        )
    
    def forward(self, img):
        return self.model(img)
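Before training, it can help to confirm that the two networks fit together: the generator should map a batch of 100-dimensional noise vectors to images of shape (N, 1, 28, 28), and the discriminator should map those images to one probability per sample. The sketch below uses a hypothetical helper, `check_shapes`, and small stand-in networks with the same input and output shapes so that it runs on its own; in practice you would pass in the Generator() and Discriminator() defined above.

```python
import torch
import torch.nn as nn

def check_shapes(generator, discriminator, latent_dim=100, batch_size=8):
    """Run one forward pass and return the generator and discriminator output shapes."""
    z = torch.randn(batch_size, latent_dim)
    with torch.no_grad():
        fake = generator(z)
        scores = discriminator(fake)
    return tuple(fake.shape), tuple(scores.shape)

# Stand-in networks (not the models from this post) with matching shapes
stand_in_g = nn.Sequential(
    nn.Linear(100, 28 * 28),
    nn.Tanh(),
    nn.Unflatten(1, (1, 28, 28)),
)
stand_in_d = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 1), nn.Sigmoid())

img_shape, score_shape = check_shapes(stand_in_g, stand_in_d)
print(img_shape)    # (8, 1, 28, 28)
print(score_shape)  # (8, 1)
```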
    

3.2 Preparing the Dataset

We load and preprocess the MNIST dataset. Using the torchvision library makes it easy to load the dataset.

from torchvision import datasets, transforms

# Data preprocessing
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load MNIST dataset
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST('dataset/', download=True, transform=transform),
    batch_size=64,
    shuffle=True
)  
    

3.3 GAN Training Code

Now, let’s create a loop that can train the generator and discriminator.

import torch.optim as optim

# Initialize models
generator = Generator()
discriminator = Discriminator()

# Loss function and optimizer
criterion = nn.BCELoss()
optimizer_g = optim.Adam(generator.parameters(), lr=0.0002)
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002)

num_epochs = 50
for epoch in range(num_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        batch_size = imgs.size(0)
        imgs = imgs.view(batch_size, -1)

        # Generate real and fake labels
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)

        # Train the discriminator
        optimizer_d.zero_grad()
        
        outputs = discriminator(imgs)
        d_loss_real = criterion(outputs, real_labels)
        d_loss_real.backward()
        
        z = torch.randn(batch_size, 100)
        fake_images = generator(z)
        # Detach so the discriminator step does not backpropagate into the generator
        outputs = discriminator(fake_images.detach())
        d_loss_fake = criterion(outputs, fake_labels)
        d_loss_fake.backward()
        
        optimizer_d.step()

        # Train the generator
        optimizer_g.zero_grad()
        z = torch.randn(batch_size, 100)
        fake_images = generator(z)
        outputs = discriminator(fake_images)
        g_loss = criterion(outputs, real_labels)
        g_loss.backward()
        
        optimizer_g.step()

        if i % 100 == 0:
            print(f"[Epoch {epoch}/{num_epochs}] [Batch {i}/{len(dataloader)}] "
                  f"[D loss: {d_loss_real.item() + d_loss_fake.item()}] "
                  f"[G loss: {g_loss.item()}]")
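The training loop above does not write anything to the models/ directory. One way to do so is to store each network's state_dict after training, or every few epochs. The helpers below, `save_checkpoint` and `load_checkpoint`, are hypothetical names of our own, and the file name generator.pth is an assumption; a small stand-in layer takes the place of the generator so the sketch runs on its own.

```python
import os
import torch
import torch.nn as nn

def save_checkpoint(model, path):
    """Save only the model's parameters (its state_dict) to path."""
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    torch.save(model.state_dict(), path)

def load_checkpoint(model, path):
    """Load parameters saved by save_checkpoint into an existing model instance."""
    model.load_state_dict(torch.load(path, map_location="cpu"))
    return model

if __name__ == "__main__":
    # Stand-in for the trained generator
    net = nn.Linear(100, 10)
    save_checkpoint(net, "models/generator.pth")
    restored = load_checkpoint(nn.Linear(100, 10), "models/generator.pth")
    print(torch.equal(net.weight, restored.weight))
```

Saving the state_dict rather than the whole model object keeps the file independent of how the class is pickled, at the cost of having to construct the model before loading.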
    

4. Visualizing Results

After training is complete, it is also important to visualize the generated images. You can output the generated images using matplotlib.

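import math
import matplotlib.pyplot as plt

def generate_and_plot_images(generator, n_samples=25):
    # Sample without tracking gradients
    with torch.no_grad():
        z = torch.randn(n_samples, 100)
        generated_images = generator(z).cpu().numpy()

    # Use the smallest square grid that fits n_samples images
    grid = math.ceil(math.sqrt(n_samples))
    plt.figure(figsize=(5, 5))
    for i in range(n_samples):
        plt.subplot(grid, grid, i + 1)
        plt.imshow(generated_images[i][0], cmap='gray')
        plt.axis('off')
    plt.show()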

generate_and_plot_images(generator)
    

5. Conclusion

In this post, we explained the principles and basic structure of GANs and walked through the environment setup and code needed to implement one using PyTorch. GANs are powerful generative models with applications such as image generation and image transformation. We encourage you to try them in your own projects.
