1. What Is a GAN?
A Generative Adversarial Network (GAN) is a foundational deep learning model introduced by Ian Goodfellow in 2014. A GAN consists of two neural networks:
a Generator and a Discriminator. The generator tries to create fake data, while the discriminator attempts to determine whether a given sample is real or fake.
These two networks compete against each other during training, hence the term “Adversarial.” The process is unsupervised, allowing the generator to learn to produce data that increasingly resembles the real data.
2. Structure of a GAN
A GAN is organized as follows:
- Generator Network: takes a random noise vector as input and generates fake images.
- Discriminator Network: distinguishes between real images and the generator’s fakes.
- During training, the generator learns to produce images that the discriminator cannot reliably classify, which drives both networks to improve as they compete with each other.
3. How a GAN Works
Training a GAN consists of the following iterative steps:
- Training the Discriminator: the discriminator receives real images and fake images produced by the generator, and updates its parameters to classify both correctly.
- Training the Generator: the generator’s outputs are scored by the freshly updated discriminator, and the generator updates its parameters so that the discriminator judges its images to be real. (In practice, as in the code below, the generator is trained against the “real” label; this non-saturating loss from the original paper provides stronger gradients early in training.)
- These two steps are repeated, allowing both networks to become progressively stronger against each other, as formalized below.
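Formally, the original 2014 paper expresses this competition as a two-player minimax game over a value function V(D, G):

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]$$

Here D(x) is the discriminator’s estimated probability that x came from the real data, and G(z) is the sample the generator produces from noise z.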
4. Implementing a GAN in PyTorch
Now, let’s implement a GAN in PyTorch. Our goal is to generate handwritten digit images using the MNIST dataset.
4.1 Installing Required Libraries
pip install torch torchvision matplotlib
4.2 Preparing the Dataset
The MNIST dataset is easy to retrieve through PyTorch’s torchvision library. The code below loads the data and normalizes each pixel to the range [-1, 1], which matches the Tanh output of the generator we define later.
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# Data transformation settings
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
# Load MNIST dataset
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
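Before moving on, it can help to sanity-check one batch; the shapes in the comments below follow from the MNIST defaults and the batch size of 64 chosen above:

# Optional sanity check: inspect one batch from the loader
images, labels = next(iter(train_loader))
print(images.shape)  # torch.Size([64, 1, 28, 28]) -- normalized image batch
print(images.min().item(), images.max().item())  # approximately -1.0 and 1.0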
4.3 Defining the Generator and Discriminator Networks
Next we define the GAN’s two networks. For simplicity, we build both the generator and the discriminator as small fully connected networks.
import torch.nn as nn

# Define Generator network: maps a 100-dim noise vector to a 28x28 image
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 28 * 28),
            nn.Tanh()  # output in [-1, 1], matching the normalized real images
        )

    def forward(self, z):
        return self.model(z).view(-1, 1, 28, 28)

# Define Discriminator network: maps a 28x28 image to a real/fake probability
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()  # probability that the input image is real
        )

    def forward(self, img):
        return self.model(img)
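As a quick sketch to verify the two forward passes fit together (the tensor shapes in the comments follow from the layer definitions above):

# Quick shape check with untrained networks
g = Generator()
d = Discriminator()
z = torch.randn(4, 100)   # 4 noise vectors of dimension 100
fake = g(z)               # -> torch.Size([4, 1, 28, 28])
scores = d(fake)          # -> torch.Size([4, 1]), probabilities in (0, 1)
print(fake.shape, scores.shape)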
4.4 Setting Loss Function and Optimization Algorithm
Both the generator and the discriminator are trained with binary cross-entropy (BCE) loss. For optimization we adopt Adam, with a learning rate of 0.0002 and betas of (0.5, 0.999), values commonly used for GAN training.
import torch.optim as optim
# Define loss function and optimization algorithm
criterion = nn.BCELoss()
generator = Generator()
discriminator = Discriminator()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
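If you have a GPU available, you can optionally move both models to it. This is a minimal sketch; note that if you use it, the images, labels, and noise tensors created in the training loop below must also be moved with .to(device):

# Optional: select a device and move the models onto it
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = generator.to(device)
discriminator = discriminator.to(device)
# In the loop below you would then also call .to(device) on images, labels, and noise.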
4.5 Implementing the Training Process
Training proceeds as follows:
num_epochs = 50

for epoch in range(num_epochs):
    for i, (images, _) in enumerate(train_loader):
        batch_size = images.size(0)

        # Set real and fake labels
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)

        # Train Discriminator on real images...
        outputs = discriminator(images)
        d_loss_real = criterion(outputs, real_labels)

        # ...and on fake images (detach so gradients don't flow into the generator)
        noise = torch.randn(batch_size, 100)
        fake_images = generator(noise)
        outputs = discriminator(fake_images.detach())
        d_loss_fake = criterion(outputs, fake_labels)

        d_loss = d_loss_real + d_loss_fake
        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()

        # Train Generator: push the discriminator to score the fakes as real
        outputs = discriminator(fake_images)
        g_loss = criterion(outputs, real_labels)
        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()

        # Output training status
        if (i + 1) % 100 == 0:
            print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], '
                  f'd_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}')
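Once training finishes, you may want to persist the generator so you can sample from it later without retraining. A minimal sketch (the file name here is an arbitrary choice for illustration):

# Optional: save the trained generator's weights
torch.save(generator.state_dict(), 'generator_mnist.pt')

# ...and restore them later:
# generator = Generator()
# generator.load_state_dict(torch.load('generator_mnist.pt'))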
4.6 Visualizing Generated Images
After training is complete, we visualize the images the generator produces.
import matplotlib.pyplot as plt

def visualize_generator(num_images):
    noise = torch.randn(num_images, 100)
    with torch.no_grad():  # no gradients needed for sampling
        generated_images = generator(noise)

    plt.figure(figsize=(10, 10))
    for i in range(num_images):
        plt.subplot(5, 5, i + 1)  # assumes num_images <= 25 for the 5x5 grid
        # Tanh output is in [-1, 1]; imshow rescales single-channel images automatically
        plt.imshow(generated_images[i][0].cpu().numpy(), cmap='gray')
        plt.axis('off')
    plt.show()

visualize_generator(25)
5. Applications of GAN
GANs are useful in many fields beyond plain image generation. For example, they are employed in style transfer, image restoration, and video generation, and they have attracted significant attention across artificial intelligence research.
The continued development of GANs keeps revealing new possibilities for generative models.
6. Conclusion
In this tutorial, we covered the basics of implementing GANs with PyTorch and saw how a GAN operates through working code. GANs are a technology that will continue to evolve and hold great potential across many applications.
As a next step, I recommend exploring the many variant models built on the basic GAN, such as DCGAN and conditional GANs.