In recent years, deep learning has driven major advances in fields such as image generation, image transformation, and image segmentation. Among these techniques, the GAN (Generative Adversarial Network) has opened up new possibilities for image generation. A GAN consists of two networks, the Generator and the Discriminator, which compete against each other and improve as a result. In this post, we give an overview of GANs and explain how to set up an environment for implementing one with the PyTorch framework.
1. Overview of GAN
The GAN was proposed by Ian Goodfellow and colleagues in 2014. It trains two neural networks against each other: the generator creates data that resembles real data, while the discriminator judges whether a given sample is real or generated. Each network's improvement forces the other to improve as well.
1.1 Structure of GAN
GAN consists of the following two components:
- Generator: Takes a random noise vector as input and generates fake data.
- Discriminator: Determines whether the received data is real or fake.
1.2 Mathematical Definition of GAN
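The goal of a GAN can be expressed as a two-player minimax game over a value function V(D, G): the discriminator D tries to maximize V, while the generator G tries to minimize it:
G^{*} = \arg\min_{G} \max_{D} V(D, G), \quad V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_{z}(z)}[\log(1 - D(G(z)))]
Here, G is the generator, D is the discriminator, D(x) is the probability the discriminator assigns to x being real, p_data(x) is the distribution of real data, and p_z(z) is the noise distribution from which the generator's input z is drawn.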
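To make the value function concrete, the sketch below estimates V(D, G) from a batch of discriminator outputs by replacing each expectation with a sample mean. The function name gan_value and the toy inputs are illustrative assumptions, not part of any library. When the generator matches the data distribution, the optimal discriminator outputs 1/2 everywhere, so V(D, G) = log(1/2) + log(1/2) = -log 4.

```python
import math

def gan_value(d_real, d_fake):
    """Monte Carlo estimate of V(D, G) (illustrative helper).

    d_real: discriminator outputs D(x) for real samples x ~ p_data
    d_fake: discriminator outputs D(G(z)) for noise samples z ~ p_z
    All values must lie strictly between 0 and 1.
    """
    term_real = sum(math.log(p) for p in d_real) / len(d_real)        # E[log D(x)]
    term_fake = sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)  # E[log(1 - D(G(z)))]
    return term_real + term_fake

# A discriminator that cannot tell real from fake outputs 1/2 everywhere.
print(gan_value([0.5] * 4, [0.5] * 4))    # approximately -1.386, i.e. -log 4
# A confident, correct discriminator pushes V toward its maximum of 0.
print(gan_value([0.99] * 4, [0.01] * 4))  # approximately -0.020
```

The discriminator's maximization pushes this estimate toward 0, and the generator's minimization pushes it back down toward -log 4.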
2. Setting Up the PyTorch Environment
PyTorch is an open-source machine learning library that provides tensor operations with GPU acceleration, automatic differentiation, and building blocks for deep learning models. The following subsections describe how to install PyTorch and the other libraries needed to implement a GAN.
2.1 Installing PyTorch
PyTorch supports CUDA, allowing it to operate efficiently on NVIDIA GPUs. You can install it using the following command:
pip install torch torchvision torchaudio
If you are using CUDA, please check the official PyTorch website for the installation commands that match your environment.
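After installation, a quick check confirms that PyTorch imports correctly and whether a CUDA GPU is visible. The device variable is an assumption of this sketch: the training code in this post runs on the CPU as written, so using a GPU would also require moving the models and each batch to this device with .to(device).

```python
import torch

print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Small tensor operation on the chosen device as a smoke test.
x = torch.ones(2, 3, device=device)
print((x * 2).sum().item())  # 12.0
```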
2.2 Installing Additional Libraries
You will also need Matplotlib, for plotting generated images, and NumPy. Install them with the command below:
pip install matplotlib numpy
2.3 Setting Up the Basic Directory Structure
Create the project directory with the following structure:
gan_project/
├── dataset/
├── models/
├── results/
└── train.py
Here, dataset/ stores the downloaded MNIST data, models/ holds saved model weights, and results/ holds generated images. The train.py file contains the script that trains the GAN and visualizes its output.
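The same layout can be created from Python with the standard-library pathlib module. The helper name create_project_layout is hypothetical; creating the folders by hand gives the same result.

```python
from pathlib import Path

def create_project_layout(root="gan_project"):
    """Create the gan_project/ layout described above and return its root path."""
    root = Path(root)
    for sub in ("dataset", "models", "results"):
        # parents=True creates root as needed; exist_ok=True makes reruns safe
        (root / sub).mkdir(parents=True, exist_ok=True)
    (root / "train.py").touch(exist_ok=True)  # empty training script to fill in
    return root
```

Calling create_project_layout() from the parent folder creates gan_project/ with dataset/, models/, results/, and an empty train.py.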
3. Code Examples Needed to Implement GAN
Now, let's write the basic code for a GAN. The code below defines the generator and discriminator, prepares the MNIST dataset, and trains both networks.
3.1 Defining the Model
First, we define the generator and discriminator networks. The code below builds both as simple multilayer perceptrons (stacks of fully connected nn.Linear layers) sized for 28x28 MNIST images:
import torch
import torch.nn as nn
# Generator model: maps a 100-dimensional noise vector to a 28x28 image
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 28 * 28),  # MNIST image size
            nn.Tanh()  # Output values in [-1, 1]
        )
    def forward(self, z):
        # Reshape flat outputs to (batch, channels=1, height=28, width=28)
        return self.model(z).view(-1, 1, 28, 28)
# Discriminator model: maps a 28x28 image to the probability that it is real
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Flatten(),  # (batch, 1, 28, 28) -> (batch, 784)
            nn.Linear(28 * 28, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()  # Probability in (0, 1)
        )
    def forward(self, img):
        return self.model(img)
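Before training, it is worth checking that the two networks fit together: the generator should map a batch of 100-dimensional noise vectors to (1, 28, 28) images in [-1, 1], and the discriminator should map each image to one probability. So that the snippet runs on its own, it rebuilds the same layer stacks as plain nn.Sequential models (nn.Unflatten stands in for the .view call in Generator.forward); with the classes above you would call Generator() and Discriminator() instead.

```python
import torch
import torch.nn as nn

latent_dim = 100
# Same layers as the Generator class, with the reshape done by nn.Unflatten.
generator = nn.Sequential(
    nn.Linear(latent_dim, 256), nn.ReLU(),
    nn.Linear(256, 512), nn.ReLU(),
    nn.Linear(512, 1024), nn.ReLU(),
    nn.Linear(1024, 28 * 28), nn.Tanh(),
    nn.Unflatten(1, (1, 28, 28)),
)
# Same layers as the Discriminator class.
discriminator = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 512), nn.LeakyReLU(0.2),
    nn.Linear(512, 256), nn.LeakyReLU(0.2),
    nn.Linear(256, 1), nn.Sigmoid(),
)

with torch.no_grad():
    z = torch.randn(16, latent_dim)   # batch of 16 noise vectors
    fake = generator(z)               # values in [-1, 1]
    scores = discriminator(fake)      # one probability per image
print(tuple(fake.shape), tuple(scores.shape))  # (16, 1, 28, 28) (16, 1)
```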
3.2 Preparing the Dataset
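We load and preprocess the MNIST dataset with the torchvision library. ToTensor scales pixel values to [0, 1], and Normalize((0.5,), (0.5,)) then maps them to [-1, 1] so that real images match the range of the generator's Tanh output.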
import torch
from torchvision import datasets, transforms
# Data preprocessing: convert to tensors in [0, 1], then normalize to [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
# Load the MNIST training set (downloaded into dataset/ on first run)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST('dataset/', train=True, download=True, transform=transform),
    batch_size=64,
    shuffle=True
)
3.3 GAN Training Code
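Now, let's write the loop that trains the discriminator and the generator in alternation. In the generator step, fake images are labeled as real, so the generator maximizes log D(G(z)) instead of minimizing log(1 - D(G(z))); the original GAN paper recommends this non-saturating variant because it gives stronger gradients early in training.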
import torch
import torch.nn as nn
import torch.optim as optim
# Initialize models
generator = Generator()
discriminator = Discriminator()
# Loss function and optimizers
criterion = nn.BCELoss()
optimizer_g = optim.Adam(generator.parameters(), lr=0.0002)
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002)
num_epochs = 50
for epoch in range(num_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        batch_size = imgs.size(0)
        # Target labels: 1 for real images, 0 for generated images
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)
        # ---- Train the discriminator ----
        optimizer_d.zero_grad()
        outputs = discriminator(imgs)  # nn.Flatten inside D handles (batch, 1, 28, 28)
        d_loss_real = criterion(outputs, real_labels)
        z = torch.randn(batch_size, 100)
        fake_images = generator(z)
        # detach() keeps the discriminator update from backpropagating into the generator
        outputs = discriminator(fake_images.detach())
        d_loss_fake = criterion(outputs, fake_labels)
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_d.step()
        # ---- Train the generator ----
        optimizer_g.zero_grad()
        z = torch.randn(batch_size, 100)
        fake_images = generator(z)
        outputs = discriminator(fake_images)
        # Non-saturating loss: label fakes as real so G maximizes log D(G(z))
        g_loss = criterion(outputs, real_labels)
        g_loss.backward()
        optimizer_g.step()
        if i % 100 == 0:
            print(f"[Epoch {epoch + 1}/{num_epochs}] [Batch {i}/{len(dataloader)}] "
                  f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")
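The generator step uses real_labels, so BCELoss has the generator minimize -log D(G(z)) rather than the minimax term log(1 - D(G(z))). Early in training the discriminator easily rejects fakes, so D(G(z)) is close to 0. The sketch below compares the two gradients with respect to the discriminator's logit a, where D = sigmoid(a); the derivatives are worked out by hand in the comments, and the helper names are illustrative, not library functions.

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def grad_minimax(a):
    # d/da log(1 - sigmoid(a)) = -sigmoid(a); tiny when D(G(z)) is near 0
    return -sigmoid(a)

def grad_non_saturating(a):
    # d/da [-log sigmoid(a)] = -(1 - sigmoid(a)); large when D(G(z)) is near 0
    return -(1.0 - sigmoid(a))

for a in (-6.0, 0.0, 6.0):
    print(f"D(G(z)) = {sigmoid(a):.4f}  "
          f"minimax grad = {grad_minimax(a):+.4f}  "
          f"non-saturating grad = {grad_non_saturating(a):+.4f}")
```

At a = -6 (the discriminator confidently rejects the sample) the minimax gradient has magnitude below 0.01, while the non-saturating gradient is close to 1, which is why the training loop labels fakes as real in the generator step.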
4. Visualizing Results
After training, visualize samples from the generator to judge their quality. The function below draws a grid of generated images with matplotlib.
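import math
import matplotlib.pyplot as plt
import torch
def generate_and_plot_images(generator, n_samples=25, latent_dim=100):
    generator.eval()  # switch to inference mode
    with torch.no_grad():  # no gradients are needed for sampling
        z = torch.randn(n_samples, latent_dim)
        generated_images = generator(z).cpu().numpy()
    generated_images = (generated_images + 1) / 2  # map Tanh output [-1, 1] back to [0, 1]
    grid = math.ceil(math.sqrt(n_samples))  # square grid large enough for all samples
    plt.figure(figsize=(5, 5))
    for i in range(n_samples):
        plt.subplot(grid, grid, i + 1)
        plt.imshow(generated_images[i][0], cmap='gray', vmin=0, vmax=1)
        plt.axis('off')
    plt.show()
generate_and_plot_images(generator)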
5. Conclusion
In this post, we explained the principles and basic structure of GAN, as well as provided the setup and code examples needed to implement GAN using PyTorch. GAN is a very powerful generative model with various applications. We encourage you to try various projects utilizing GAN in the future.