Generative Adversarial Networks (GANs) are considered one of the most innovative advances in deep learning. A GAN consists of two neural networks: a Generator and a Discriminator. The Generator creates data, while the Discriminator judges whether a given sample is real or fake. Because the two networks compete, each one's progress pushes the other to improve as well. In this course, we will build a GAN with PyTorch and have some fun generating data around the theme of ‘The Literary Club for Bad Criminals’.
1. Basic Structure and Principles of GAN
A GAN works as follows:
- Generator: takes random noise (z) as input and produces data meant to look realistic.
- Discriminator: decides whether its input is real data or was produced by the Generator.
- The Generator tries to fool the Discriminator, while the Discriminator tries to tell the two apart. As this competition continues, both networks improve.
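The competition described above is commonly written as a minimax objective over a value function V(D, G). This is the standard formulation from the original GAN paper (Goodfellow et al., 2014); here p_data denotes the distribution of real data and p_z the distribution of the input noise:

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right]
  + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]
```

The Discriminator maximizes V by pushing D(x) toward 1 on real data and D(G(z)) toward 0 on generated data; the Generator minimizes V by pushing D(G(z)) toward 1.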
2. Preparing Required Libraries and Datasets
First, install PyTorch and the other required libraries, then prepare a dataset. In this example, we use MNIST, a dataset of 28×28 grayscale images of handwritten digits, so our GAN will learn to generate images of digits.
2.1 Setting Up the Environment
pip install torch torchvision matplotlib
2.2 Loading the Dataset
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# Load MNIST Dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
mnist_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
data_loader = DataLoader(dataset=mnist_dataset, batch_size=64, shuffle=True)
3. Constructing the GAN Model
Next, we define the Generator and the Discriminator. Both are simple fully connected networks.
3.1 Generator Model
import torch.nn as nn
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),  # Input: 100-dimensional noise vector z
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 28 * 28),  # One output per MNIST pixel
            nn.Tanh()  # Squash pixel values into the range [-1, 1]
        )
    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), 1, 28, 28)  # Reshape to (batch, channels, height, width)
        return img
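To see how a noise vector becomes an image, the sketch below pushes a batch of 8 noise vectors through the same layer stack, printing the shape after each Linear layer, and then applies the reshape from forward(). The layers are rebuilt inline so the snippet runs on its own; the batch size of 8 is arbitrary.

```python
import torch
import torch.nn as nn

# Same layer stack as the Generator above, rebuilt so this snippet runs on its own
layers = nn.Sequential(
    nn.Linear(100, 256), nn.ReLU(),
    nn.Linear(256, 512), nn.ReLU(),
    nn.Linear(512, 1024), nn.ReLU(),
    nn.Linear(1024, 28 * 28), nn.Tanh(),
)

x = torch.randn(8, 100)  # A batch of 8 noise vectors z
with torch.no_grad():
    for layer in layers:
        x = layer(x)
        if isinstance(layer, nn.Linear):
            print("after Linear:", tuple(x.shape))
img = x.view(x.size(0), 1, 28, 28)  # Same reshape as Generator.forward
print("image:", tuple(img.shape))  # image: (8, 1, 28, 28)
print(img.min().item() >= -1.0, img.max().item() <= 1.0)  # True True (Tanh range)
```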
3.2 Discriminator Model
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()  # Output value between 0 and 1
        )
    def forward(self, img):
        img_flat = img.view(img.size(0), -1)  # Flatten the image
        validity = self.model(img_flat)
        return validity
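The Discriminator does the reverse of the Generator's reshape: it flattens each (1, 28, 28) image into a 784-dimensional vector and maps it to a single probability of being real. A shape check, assuming random tensors in [-1, 1] are an acceptable stand-in for images and rebuilding the layers inline:

```python
import torch
import torch.nn as nn

# Same layer stack as the Discriminator above, rebuilt so this snippet runs on its own
layers = nn.Sequential(
    nn.Linear(28 * 28, 512), nn.LeakyReLU(0.2),
    nn.Linear(512, 256), nn.LeakyReLU(0.2),
    nn.Linear(256, 1), nn.Sigmoid(),
)

imgs = torch.rand(8, 1, 28, 28) * 2 - 1  # Random stand-in images in [-1, 1]
img_flat = imgs.view(imgs.size(0), -1)  # Same flattening as Discriminator.forward
with torch.no_grad():
    validity = layers(img_flat)
print(tuple(img_flat.shape))  # (8, 784)
print(tuple(validity.shape))  # (8, 1)
print(bool(((validity >= 0) & (validity <= 1)).all()))  # True
```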
4. Training Process of GAN
Each training iteration proceeds as follows:
- Feed a batch of real images and a batch of generated images to the Discriminator, compute its loss, and update the Discriminator.
- Update the Generator so that the Discriminator becomes more likely to classify the generated images as real.
- Repeat these two steps over many batches so that each network keeps improving against the other.
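In code, both steps are usually expressed with binary cross-entropy (BCE): with target 1 it equals -log(p), and with target 0 it equals -log(1 - p). The sketch below checks this with nn.BCELoss on made-up Discriminator outputs (the probability values are illustrative only). Labeling generated images as 1 when training the Generator means it minimizes -log D(G(z)), the "non-saturating" variant suggested in the original GAN paper, rather than minimizing log(1 - D(G(z))) directly.

```python
import torch
import torch.nn as nn

bce = nn.BCELoss()
# Illustrative Discriminator outputs for a batch of 4 images
d_real = torch.tensor([[0.9], [0.8], [0.7], [0.6]])  # D(x) on real images
d_fake = torch.tensor([[0.2], [0.1], [0.3], [0.4]])  # D(G(z)) on generated images
ones = torch.ones(4, 1)
zeros = torch.zeros(4, 1)

# Discriminator: target 1 for real images, target 0 for generated ones
d_loss = (bce(d_real, ones) + bce(d_fake, zeros)) / 2
d_loss_manual = (-torch.log(d_real).mean() - torch.log(1 - d_fake).mean()) / 2

# Generator: target 1 for generated images, i.e. minimize -log D(G(z))
g_loss = bce(d_fake, ones)
g_loss_manual = -torch.log(d_fake).mean()

print(torch.allclose(d_loss, d_loss_manual), torch.allclose(g_loss, g_loss_manual))  # True True
```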
4.1 Defining Loss Function and Optimizers
import torch.optim as optim
# Create instances of Generator and Discriminator
generator = Generator()
discriminator = Discriminator()
# Set loss function and optimizer
adversarial_loss = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
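As written, the training loop runs on the CPU. If a GPU is available, moving the models, each batch, the labels, and the noise onto it usually speeds training up considerably. A minimal sketch using PyTorch's standard device-selection idiom; the tiny nn.Sequential models here are placeholders for Generator() and Discriminator() so the snippet runs on its own.

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Placeholder models standing in for Generator() and Discriminator()
generator = nn.Sequential(nn.Linear(100, 28 * 28), nn.Tanh()).to(device)
discriminator = nn.Sequential(nn.Linear(28 * 28, 1), nn.Sigmoid()).to(device)

# Construct the optimizers after moving the models, as the torch.optim docs advise
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# In the training loop, every tensor must be on the same device as the models,
# e.g. real_imgs = imgs.to(device) for each batch from data_loader
z = torch.randn(64, 100, device=device)
valid = torch.ones(64, 1, device=device)
scores = discriminator(generator(z))
loss = nn.BCELoss()(scores, valid)
print(device, scores.device.type, loss.item() > 0)
```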
4.2 Training Loop
num_epochs = 200  # 200 epochs can be slow on a CPU; reduce this for a quick test
for epoch in range(num_epochs):
    for i, (imgs, _) in enumerate(data_loader):
        real_imgs = imgs
        valid = torch.ones(imgs.size(0), 1)  # Target label 1 for real images
        fake = torch.zeros(imgs.size(0), 1)  # Target label 0 for generated images
        # ---- Train Discriminator ----
        optimizer_D.zero_grad()
        z = torch.randn(imgs.size(0), 100)  # Sample random noise
        generated_imgs = generator(z)  # Generated images
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        # detach() keeps the Discriminator loss from backpropagating into the Generator
        fake_loss = adversarial_loss(discriminator(generated_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()
        # ---- Train Generator ----
        optimizer_G.zero_grad()
        # The Generator improves when the Discriminator labels its images as real
        g_loss = adversarial_loss(discriminator(generated_imgs), valid)
        g_loss.backward()
        optimizer_G.step()
    # Losses shown are from the last batch of the epoch
    print(f'Epoch {epoch + 1}/{num_epochs} | D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f}')
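Because a full run over MNIST takes a while, it can help to smoke-test the update logic on a single batch first. A minimal sketch, assuming a hypothetical train_step helper that performs the same Discriminator-then-Generator update as the loop; the models are compact nn.Sequential equivalents of the classes above (nn.Flatten and nn.Unflatten take the place of view), and a random tensor stands in for an MNIST batch.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
latent_dim = 100
# Compact equivalents of the Generator and Discriminator classes above
generator = nn.Sequential(
    nn.Linear(latent_dim, 256), nn.ReLU(),
    nn.Linear(256, 512), nn.ReLU(),
    nn.Linear(512, 1024), nn.ReLU(),
    nn.Linear(1024, 28 * 28), nn.Tanh(),
    nn.Unflatten(1, (1, 28, 28)),  # (batch, 784) -> (batch, 1, 28, 28)
)
discriminator = nn.Sequential(
    nn.Flatten(),  # (batch, 1, 28, 28) -> (batch, 784)
    nn.Linear(28 * 28, 512), nn.LeakyReLU(0.2),
    nn.Linear(512, 256), nn.LeakyReLU(0.2),
    nn.Linear(256, 1), nn.Sigmoid(),
)
adversarial_loss = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

def train_step(real_imgs):
    """Hypothetical helper: one Discriminator update, then one Generator update."""
    n = real_imgs.size(0)
    valid, fake = torch.ones(n, 1), torch.zeros(n, 1)
    # Discriminator step (generated images detached from the Generator's graph)
    optimizer_D.zero_grad()
    generated = generator(torch.randn(n, latent_dim))
    d_loss = (adversarial_loss(discriminator(real_imgs), valid)
              + adversarial_loss(discriminator(generated.detach()), fake)) / 2
    d_loss.backward()
    optimizer_D.step()
    # Generator step: try to make the Discriminator output 1 for generated images
    optimizer_G.zero_grad()
    g_loss = adversarial_loss(discriminator(generated), valid)
    g_loss.backward()
    optimizer_G.step()
    return d_loss.item(), g_loss.item()

# Random stand-in for a real batch, in the same [-1, 1] range as normalized MNIST
stand_in_real = torch.rand(16, 1, 28, 28) * 2 - 1
d_loss, g_loss = train_step(stand_in_real)
print(f"D Loss: {d_loss:.4f} | G Loss: {g_loss:.4f}")
```

For actual training, the same helper would be called once per imgs batch drawn from data_loader.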
5. Visualizing Results
After training is completed, we visualize the generated images to check the results.
import matplotlib.pyplot as plt
# Visualize generated images in a 4x4 grid
def show_generated_images(generator, num_images=16):
    generator.eval()  # Switch to inference mode
    with torch.no_grad():  # Gradients are not needed for sampling
        z = torch.randn(num_images, 100)  # Sample random noise
        generated_images = generator(z).cpu().numpy()
    fig, axs = plt.subplots(4, 4, figsize=(10, 10))
    for i in range(4):
        for j in range(4):
            # vmin/vmax match the Tanh output range so all panels share one gray scale
            axs[i, j].imshow(generated_images[i * 4 + j, 0], cmap='gray', vmin=-1, vmax=1)
            axs[i, j].axis('off')
    plt.show()
show_generated_images(generator)
With that, you have built and trained a GAN and inspected the images it generates. GANs have many applications beyond generating digits and can be a powerful tool for creative tasks. This is your first step into the world of GANs!
6. Conclusion
Generative Adversarial Networks are a fascinating area of deep learning and are widely used in research and development. In this course, we used PyTorch to explore the basic principles and structure of a GAN and walked through building and training the models. I hope this course deepens your understanding of and interest in GANs and helps you on your future deep learning journey.