1. Overview of GAN
A GAN (Generative Adversarial Network) is a deep learning model proposed by Ian Goodfellow and colleagues in 2014. A GAN learns the distribution of a given dataset and can then generate new samples that follow that distribution.
A GAN is built from two neural networks: the Generator and the Discriminator. The Generator creates fake data that resembles real data, while the Discriminator decides whether a given sample is real or generated.
2. Structure of GAN
A GAN consists of two components:
- Generator (G): Takes random noise as input and generates fake data from it.
- Discriminator (D): Classifies its input as either real data or generated fake data.
2.1. Loss Function
During the training process of GAN, both the Generator and the Discriminator learn competitively by optimizing their respective loss functions. The goal of the Discriminator is to accurately distinguish real data from fake data, while the goal of the Generator is to fool the Discriminator. This can be expressed mathematically as follows:
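min_G max_D V(D, G) = E_{x ~ p_data}[log D(x)] + E_{z ~ p_z}[log(1 - D(G(z)))]
Here x is a real sample drawn from the data distribution p_data, z is a noise vector drawn from the prior p_z, and D(x) is the probability that the Discriminator assigns to x being real.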
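To connect the formula to the code that follows, here is a minimal sketch that estimates V(D, G) on a handful of made-up Discriminator outputs and compares it with binary cross-entropy. The numbers in `d_real` and `d_fake` are hypothetical values chosen only for illustration. The sketch also shows the generator loss used in the training loop of this article, which labels fake samples as real (often called the non-saturating loss) instead of minimizing log(1 - D(G(z))) directly.

```python
import torch
import torch.nn as nn

# Hypothetical Discriminator outputs (probabilities), for illustration only.
d_real = torch.tensor([[0.9], [0.8], [0.7]])  # D(x) on real samples
d_fake = torch.tensor([[0.2], [0.1], [0.4]])  # D(G(z)) on generated samples

# Sample estimate of V(D, G) = E[log D(x)] + E[log(1 - D(G(z)))]
value = torch.log(d_real).mean() + torch.log(1 - d_fake).mean()

# BCE with label 1 for real and label 0 for fake equals -V(D, G),
# so minimizing this loss maximizes V with respect to D.
bce = nn.BCELoss()
d_loss = bce(d_real, torch.ones_like(d_real)) + bce(d_fake, torch.zeros_like(d_fake))

# Generator loss with fakes labeled as real: -E[log D(G(z))].
g_loss = bce(d_fake, torch.ones_like(d_fake))

print(f"V = {value.item():.4f}, d_loss = {d_loss.item():.4f}, g_loss = {g_loss.item():.4f}")
```

The Discriminator loss is exactly the negative of the estimated value function, which is why `nn.BCELoss` can stand in for the max over D.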
3. Implementing GAN using PyTorch
In this section, we implement a simple GAN in PyTorch. It is trained on the MNIST dataset and generates images of handwritten digits.
3.1. Importing Libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
3.2. Setting Hyperparameters
# Setting hyperparameters
latent_size = 64
batch_size = 128
learning_rate = 0.0002
num_epochs = 50
3.3. Loading the Dataset
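# Loading MNIST dataset
# ToTensor() scales pixels to [0, 1]; Normalize then maps them to [-1, 1],
# which matches the range of the Generator's Tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])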
mnist = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(mnist, batch_size=batch_size, shuffle=True)
3.4. Defining the Generator and Discriminator
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # Fully connected network: noise vector -> 784 pixel values in [-1, 1]
        self.model = nn.Sequential(
            nn.Linear(latent_size, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 784),
            nn.Tanh()
        )

    def forward(self, z):
        # Reshape the flat output into image form: (batch, 1, 28, 28)
        return self.model(z).reshape(-1, 1, 28, 28)


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        # Fully connected network: image -> probability that the image is real
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(784, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        return self.model(img)
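Before training, it can help to confirm that the tensor shapes line up between the two networks. The sketch below is a shape check only: `tiny_generator` and `tiny_discriminator` are hypothetical stand-ins with the same input and output shapes as the classes above, with the hidden layers removed to keep the example short.

```python
import torch
import torch.nn as nn

latent_size = 64  # same value as in the hyperparameter section

# Hypothetical stand-ins with the same input/output shapes as the real networks.
tiny_generator = nn.Sequential(nn.Linear(latent_size, 784), nn.Tanh())
tiny_discriminator = nn.Sequential(nn.Flatten(), nn.Linear(784, 1), nn.Sigmoid())

z = torch.randn(16, latent_size)                 # batch of 16 noise vectors
fake = tiny_generator(z).reshape(-1, 1, 28, 28)  # image-shaped output in [-1, 1]
scores = tiny_discriminator(fake)                # one probability per image

print(fake.shape, scores.shape)
```

The same check can be run on the actual `Generator` and `Discriminator` instances by replacing the two stand-ins.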
3.5. Setting up the Model, Loss Function, and Optimizers
generator = Generator()
discriminator = Discriminator()
criterion = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate)
optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate)
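The optimizers above use Adam's default momentum coefficients. A variation seen in many GAN implementations, following the DCGAN setup, lowers the first coefficient from the default 0.9 to 0.5. The sketch below shows how that would be passed to `optim.Adam`. This is an optional assumption and not part of the setup above; `model` is a hypothetical stand-in for either network.

```python
import torch.nn as nn
import torch.optim as optim

learning_rate = 0.0002      # same value as in the hyperparameter section
model = nn.Linear(64, 784)  # hypothetical stand-in for generator or discriminator

# Assumption: beta1 = 0.5 instead of Adam's default 0.9 (a common GAN setting).
optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.5, 0.999))

print(optimizer.param_groups[0]["betas"])
```

Whether this helps for the small fully connected model in this article is something to verify by experiment.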
3.6. GAN Training Loop
for epoch in range(num_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        # Use the actual batch size: the last batch of an epoch can be smaller
        # than batch_size, which would cause a shape mismatch in BCELoss.
        n = imgs.size(0)
        real_imgs = imgs
        real_labels = torch.ones(n, 1)
        fake_labels = torch.zeros(n, 1)

        # Training the Discriminator
        optimizer_D.zero_grad()
        # Real images should be classified as 1.
        outputs = discriminator(real_imgs)
        d_loss_real = criterion(outputs, real_labels)
        d_loss_real.backward()
        # Generated images should be classified as 0. detach() keeps this
        # step from backpropagating into the Generator.
        z = torch.randn(n, latent_size)
        fake_imgs = generator(z)
        outputs = discriminator(fake_imgs.detach())
        d_loss_fake = criterion(outputs, fake_labels)
        d_loss_fake.backward()
        optimizer_D.step()

        # Training the Generator
        optimizer_G.zero_grad()
        # The Generator is rewarded when the Discriminator labels fakes as real.
        outputs = discriminator(fake_imgs)
        g_loss = criterion(outputs, real_labels)
        g_loss.backward()
        optimizer_G.step()

    # Report the losses of the last batch in each epoch.
    print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss_real.item() + d_loss_fake.item()}, g_loss: {g_loss.item()}')
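The role of `detach()` in the Discriminator step can be seen in isolation. The following sketch uses two toy networks, `g` and `d`, which are hypothetical stand-ins much smaller than the real models, and inspects which parameters receive gradients after each backward pass.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
g = nn.Linear(4, 8)                               # toy generator
d = nn.Sequential(nn.Linear(8, 1), nn.Sigmoid())  # toy discriminator
criterion = nn.BCELoss()

z = torch.randn(5, 4)
fake = g(z)

# Discriminator step: detach() cuts the graph, so no gradient reaches g.
d_loss_fake = criterion(d(fake.detach()), torch.zeros(5, 1))
d_loss_fake.backward()
g_grad_after_d_step = g.weight.grad  # None: g was not part of the graph

# Generator step: without detach(), the gradient flows through d into g.
g_loss = criterion(d(fake), torch.ones(5, 1))
g_loss.backward()
g_grad_after_g_step = g.weight.grad  # now a tensor

print(g_grad_after_d_step is None, g_grad_after_g_step is not None)
```

The Generator step also accumulates gradients in the Discriminator's parameters, which is one reason `optimizer_D.zero_grad()` is called at the start of every Discriminator step.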
3.7. Visualization of Results
After training, we visualize a batch of generated images to check the quality of the GAN's output by eye.
z = torch.randn(64, latent_size)
generated_images = generator(z).detach().numpy()
generated_images = (generated_images + 1) / 2  # Rescale from [-1, 1] to [0, 1]
fig, axs = plt.subplots(8, 8, figsize=(10, 10))
for i in range(8):
    for j in range(8):
        # Each image has shape (1, 28, 28); index 0 selects the single channel.
        axs[i, j].imshow(generated_images[i * 8 + j][0], cmap='gray')
        axs[i, j].axis('off')
plt.show()
4. Conclusion
In this article, we covered the basic concepts of GANs and implemented a simple GAN in PyTorch: a Generator and a Discriminator built from fully connected layers and trained against each other on MNIST. GANs perform well at data generation and are used in a wide range of applications, such as image synthesis.