Advances in deep learning are reshaping many fields, and generative modeling in particular is opening new horizons for data generation. Generative Adversarial Networks (GANs) are among the best-known generative models, excelling at producing new samples that resemble the training data. This article explains the main concepts behind GANs, shows how to implement one using PyTorch, and walks through a practical example.
1. Basics of GAN
A GAN consists of two neural networks, a generator and a discriminator. The two networks are adversaries and are trained simultaneously.
1.1 Generator
The generator's role is to turn random noise (the input noise vector) into data that resembles real data. By learning the data distribution, it creates new samples, with the goal of deceiving the discriminator.
1.2 Discriminator
The discriminator's role is to judge whether its input is real data or data produced by the generator. It is also implemented as a neural network, and its goal is to distinguish real from fake data as accurately as possible.
1.3 Adversarial Learning Process
The learning process of GAN consists of the following steps:
- The generator produces data from random noise.
- The discriminator receives real data and the fake data created by the generator and tries to distinguish between them.
- The generator is optimized to trick the discriminator into misjudging fake data as real.
- The discriminator is optimized to accurately distinguish fake data.
This process is repeated many times, gradually leading the generator to produce better data and the discriminator to make more refined judgments. The value function below makes this competition precise.
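Formally (following the original GAN paper, Goodfellow et al., 2014), the generator G and the discriminator D play a two-player minimax game over a single value function:

\[
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]
\]

Here \(p_{\text{data}}\) is the real data distribution, \(p_z\) is the noise distribution, \(D(x)\) is the probability the discriminator assigns to \(x\) being real, and \(G(z)\) is a generated sample. D is trained to maximize this value while G is trained to minimize it.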
2. Structure of GAN
A GAN has the following structure:
- Input Noise: Typically, a noise vector drawn from a normal distribution is used as input (see the sampling example after this list).
- Generator Network: Accepts input noise and generates fake samples from it.
- Discriminator Network: Accepts generated fake samples and real samples to determine whether they are real or fake.
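For instance, a batch of input noise can be sampled with torch.randn; the noise dimension (100 throughout this article) is a free design choice that simply has to match the generator's input size:

import torch

# Sample a batch of 64 noise vectors, each of dimension 100,
# drawn from a standard normal distribution
z = torch.randn(64, 100)
print(z.shape)  # torch.Size([64, 100])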
3. Implementing GAN Using PyTorch
Now, let's implement a GAN using PyTorch. PyTorch is a very useful library for building and training deep learning models.
3.1 Installing Required Libraries
!pip install torch torchvision matplotlib
3.2 Defining the Generator and Discriminator Networks
First, we define the generator and discriminator networks. These are designed based on their respective properties.
import torch
import torch.nn as nn

# Define the generator
class Generator(nn.Module):
    def __init__(self, input_size, output_size):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_size, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, output_size),
            nn.Tanh()  # Limit output values to the range [-1, 1]
        )

    def forward(self, z):
        return self.model(z)

# Define the discriminator
class Discriminator(nn.Module):
    def __init__(self, input_size):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_size, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 1),
            nn.Sigmoid()  # Output a probability between 0 and 1
        )

    def forward(self, x):
        return self.model(x)
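As a quick sanity check (not part of the original tutorial), we can instantiate both networks and confirm that the tensor shapes line up; the sizes 100 and 784 (= 28 × 28 MNIST pixels) match the values used later in this article:

# Shape check: 100-dim noise in, 784-dim (28x28) image out
g = Generator(input_size=100, output_size=784)
d = Discriminator(input_size=784)
z = torch.randn(8, 100)   # batch of 8 noise vectors
fake = g(z)               # shape: (8, 784)
score = d(fake)           # shape: (8, 1), values in (0, 1)
print(fake.shape, score.shape)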
3.3 Data Preparation
We will use the MNIST dataset to train the generative model. MNIST is a dataset of 28×28 grayscale images of handwritten digits from 0 to 9.
from torchvision import datasets, transforms

# Download and transform the dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Normalize pixel values to [-1, 1]
])

mnist = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(mnist, batch_size=64, shuffle=True)
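To verify the preprocessing, it can help to inspect a single batch (an optional check, not part of the original code); after normalization the pixel values should lie in [-1, 1], matching the range of the generator's Tanh output:

imgs, labels = next(iter(dataloader))
print(imgs.shape)                             # torch.Size([64, 1, 28, 28])
print(imgs.min().item(), imgs.max().item())   # approximately -1.0 and 1.0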
3.4 Defining Loss Functions and Optimization Techniques
Since a GAN pits a generator against a discriminator, we define a loss for each; both use binary cross-entropy (BCE) loss. The code below also selects a GPU if one is available and otherwise falls back to the CPU, so the example runs on either.
# Set up the loss function and optimizers
criterion = nn.BCELoss()
lr = 0.0002
beta1 = 0.5

# Use a GPU if available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

generator = Generator(input_size=100, output_size=784).to(device)
discriminator = Discriminator(input_size=784).to(device)

optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
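For reference, the binary cross-entropy loss computed by nn.BCELoss for a predicted probability \(\hat{y}\) and a target label \(y \in \{0, 1\}\) is:

\[
\ell(\hat{y}, y) = -\left[\, y \log \hat{y} + (1 - y) \log(1 - \hat{y}) \,\right]
\]

With label 1 the discriminator is pushed toward \(D(x) \to 1\) on real data, and with label 0 toward \(D(G(z)) \to 0\) on fakes. Note that in the training loop below the generator reuses the real label 1 for its own loss, which is the common "non-saturating" generator objective rather than the raw minimax formulation.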
3.5 Implementing the GAN Training Process
Now, let's implement the GAN training loop, which shows how the generator and the discriminator are updated for each batch.
num_epochs = 50

for epoch in range(num_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        # Real images, flattened into vectors, labeled 1
        real_imgs = imgs.view(imgs.size(0), -1).to(device)
        real_labels = torch.ones((imgs.size(0), 1), device=device)

        # Fake images generated from random noise, labeled 0
        noise = torch.randn((imgs.size(0), 100), device=device)
        fake_imgs = generator(noise)
        fake_labels = torch.zeros((imgs.size(0), 1), device=device)

        # Update the discriminator
        optimizer_D.zero_grad()
        outputs = discriminator(real_imgs)
        d_loss_real = criterion(outputs, real_labels)
        d_loss_real.backward()
        outputs = discriminator(fake_imgs.detach())  # detach: don't backprop into the generator here
        d_loss_fake = criterion(outputs, fake_labels)
        d_loss_fake.backward()
        optimizer_D.step()

        # Update the generator: it improves when the discriminator calls its fakes real
        optimizer_G.zero_grad()
        outputs = discriminator(fake_imgs)
        g_loss = criterion(outputs, real_labels)
        g_loss.backward()
        optimizer_G.step()

    print(f'Epoch [{epoch+1}/{num_epochs}], '
          f'd_loss: {d_loss_real.item() + d_loss_fake.item():.4f}, g_loss: {g_loss.item():.4f}')
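Once training finishes, it is often useful to save the learned weights so the generator can be reused without retraining. A minimal sketch (the file names here are arbitrary choices, not part of the original tutorial):

# Save the trained weights (hypothetical file names)
torch.save(generator.state_dict(), 'generator.pth')
torch.save(discriminator.state_dict(), 'discriminator.pth')

# Later, models with the same architecture can reload them, e.g.:
# generator.load_state_dict(torch.load('generator.pth', map_location=device))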
3.6 Visualizing Generated Images
After the training is complete, we will visualize the images generated by the generator.
import matplotlib.pyplot as plt

# Generate images (no gradients are needed at inference time)
with torch.no_grad():
    noise = torch.randn(16, 100, device=device)
    fake_imgs = generator(noise).view(-1, 1, 28, 28).cpu()

# Visualize the images in a 4x4 grid
plt.figure(figsize=(10, 10))
for i in range(16):
    plt.subplot(4, 4, i + 1)
    plt.imshow(fake_imgs[i].squeeze(), cmap='gray')
    plt.axis('off')
plt.show()
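If you would rather write the image grid straight to disk instead of (or in addition to) displaying it, torchvision ships a convenient helper; the file name here is an arbitrary choice:

from torchvision.utils import save_image

# Arrange the 16 generated images into a 4x4 grid and save them as a PNG;
# normalize=True rescales the pixel values to [0, 1] for display
save_image(fake_imgs, 'gan_samples.png', nrow=4, normalize=True)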
4. Conclusion
In this post, we explored the theoretical background of GANs as well as the basic process of implementing a GAN model in PyTorch. GANs have brought many innovations to the field of generative modeling, and further advances are highly anticipated. I hope this example has helped you understand the basic principles of GANs and how to implement them in PyTorch.
Advances in GANs will change the way we generate and process data, and I look forward to even more active research on generative models like GANs.