1. Introduction
Deep learning has driven major advances in fields such as computer vision, natural language processing, and speech recognition. Among deep learning techniques, Generative Adversarial Networks (GANs) have attracted particular attention. A GAN consists of two neural networks, a Generator and a Discriminator, that compete against each other; through this competition the Generator learns to produce realistic data.
In this article, we take a detailed look at one GAN variant, the Conditional Generative Adversarial Network (cGAN). A cGAN supplies a condition, such as a class label, during generation, which makes it possible to generate images of a specific class. As a running example, we will generate images of specific handwritten digits using the MNIST dataset.
2. Overview of cGAN
2.1 Basic Structure of GAN
A GAN consists of two neural networks. The Generator takes a random noise vector as input and produces fake images, while the Discriminator judges whether an input image is real or fake. They interact as follows:
- The Generator creates images from random noise input
- The Discriminator is shown both the generated images and real images from the dataset
- The Discriminator is trained to output 1 for real images and 0 for fake images, while the Generator is trained to get its images classified as real
- This process repeats, gradually causing the Generator to produce more realistic images
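The back-and-forth described above is commonly written as a two-player minimax game over a value function V(D, G), where D(x) is the Discriminator's estimated probability that x is real and G(z) is the image generated from noise z. This is the standard formulation from the original GAN paper, shown here for reference rather than as something specific to this article's code:

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
```

In practice the Generator is usually trained to maximize log D(G(z)) instead of minimizing log(1 - D(G(z))), because the latter gives weak gradients early in training. The training loop in Section 6 follows this practice by scoring generated images against the "real" target.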
2.2 Structure of cGAN
A cGAN extends the GAN by feeding conditional information, such as a class label, to both the Generator and the Discriminator, which makes it possible to generate images of a specific class. For example, if the condition is set to the digit ‘3’, the Generator should produce an image of a ‘3’. The structure of a cGAN is as follows:
- The Generator takes random noise together with the conditional information as input to generate images
- The Discriminator accepts both the input image and the conditional information to determine real or fake
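With a condition y, the objective from the previous section simply conditions both networks on y (following the cGAN formulation of Mirza and Osindero). In this article's training loop, real images are paired with their true labels, and generated images with labels drawn uniformly at random from the 10 classes:

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{(x, y) \sim p_{\text{data}}}\big[\log D(x \mid y)\big]
  + \mathbb{E}_{z \sim p_z,\; y \sim p_y}\big[\log\big(1 - D(G(z \mid y) \mid y)\big)\big]
```

Because the Discriminator sees the label, a realistic image paired with the wrong label can still be judged fake, which pushes the Generator to respect the condition.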
3. Basic Setup for Implementing cGAN in PyTorch
3.1 Install Required Libraries
We first install the Python libraries needed to implement the cGAN: PyTorch (with torchvision, which provides the dataset), NumPy, and Matplotlib. They can be installed with the following command:
        
        pip install torch torchvision numpy matplotlib
        
    
3.2 Prepare Dataset
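We will use the MNIST dataset to implement the cGAN. MNIST consists of 28×28 grayscale images of handwritten digits from 0 to 9 (60,000 training and 10,000 test images), and it can be loaded directly through torchvision. The Normalize transform rescales pixel values from [0, 1] to [-1, 1], which matches the Tanh output range of the Generator defined later.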
        
import torch
from torchvision import datasets, transforms
# Load dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
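To see what the two transforms do to pixel values, the sketch below reproduces their arithmetic by hand on a tiny made-up tensor. It mimics ToTensor and Normalize rather than calling them, so no dataset download is needed.

```python
import torch

# A made-up 1x2x2 grayscale "image" with raw uint8 pixel values
raw = torch.tensor([[[0, 64], [128, 255]]], dtype=torch.uint8)

# What transforms.ToTensor() does to uint8 pixels: scale to floats in [0, 1]
as_float = raw.float() / 255.0

# What transforms.Normalize((0.5,), (0.5,)) does: (x - mean) / std per channel,
# which maps [0, 1] onto [-1, 1], the same range as the Generator's Tanh output
normalized = (as_float - 0.5) / 0.5

# Inverting the normalization, useful before plotting generated images
restored = normalized * 0.5 + 0.5

print(normalized)
print(torch.allclose(restored, as_float))
```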
        
    
4. Implementing cGAN Architecture
4.1 Generator
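The Generator takes random noise and the class label as input and produces an image. In this implementation it is a multilayer perceptron: an embedding layer maps the label to a learnable vector, which is concatenated with the noise vector and passed through linear layers with ReLU activations. A final Tanh squashes the output to [-1, 1], matching the normalized MNIST pixels.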
        
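import torch
import torch.nn as nn
class Generator(nn.Module):
    def __init__(self, z_dim, num_classes):
        super().__init__()
        # Maps each class label (0-9) to a learnable vector of size num_classes
        self.label_embedding = nn.Embedding(num_classes, num_classes)
        # MLP from (noise + label embedding) to a flattened 28x28 image
        self.model = nn.Sequential(
            nn.Linear(z_dim + num_classes, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1 * 28 * 28),
            nn.Tanh()  # outputs in [-1, 1], matching the normalized MNIST pixels
        )
    def forward(self, noise, labels):
        label_input = self.label_embedding(labels)
        # Condition the Generator by concatenating noise and label embedding
        gen_input = torch.cat((noise, label_input), dim=1)
        img = self.model(gen_input)
        # Reshape the flat output to (batch, channels, height, width)
        img = img.view(img.size(0), 1, 28, 28)
        return img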
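The conditioning happens at the start of forward. The sketch below isolates that step with a batch of four made-up labels: nn.Embedding turns each integer label into a learnable 10-dimensional vector, which is concatenated with the noise to form the input of the first Linear layer. It assumes z_dim=100, the value used in the training section.

```python
import torch
import torch.nn as nn

num_classes, z_dim, batch_size = 10, 100, 4

# Same kind of layer as in Generator: one learnable 10-dim vector per label
label_embedding = nn.Embedding(num_classes, num_classes)

labels = torch.tensor([3, 3, 7, 0])      # the digits we want to generate
noise = torch.randn(batch_size, z_dim)   # one noise vector per image

label_input = label_embedding(labels)                # shape: (4, 10)
gen_input = torch.cat((noise, label_input), dim=1)   # shape: (4, 110)

print(label_input.shape, gen_input.shape)
# Identical labels map to identical embedding vectors
print(torch.equal(label_input[0], label_input[1]))
```

Because the embedding is trained together with the rest of the network, the model learns its own representation of each digit class instead of relying on a fixed one-hot code.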
        
    
4.2 Discriminator
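The Discriminator receives an image together with its label and outputs the probability that the pair is real. It flattens the image, concatenates it with a label embedding, and passes the result through linear layers that narrow from 512 to 256 units to a single output, using LeakyReLU(0.2) activations and a final Sigmoid.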
        
class Discriminator(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        # Separate label embedding, same shape as the Generator's
        self.label_embedding = nn.Embedding(num_classes, num_classes)
        # MLP from (flattened image + label embedding) to a single probability
        self.model = nn.Sequential(
            nn.Linear(1 * 28 * 28 + num_classes, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()  # probability that the (image, label) pair is real
        )
    def forward(self, img, labels):
        label_input = self.label_embedding(labels)
        img_flat = img.view(img.size(0), -1)  # (batch, 784)
        disc_input = torch.cat((img_flat, label_input), dim=1)
        validity = self.model(disc_input)
        return validity
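Why LeakyReLU in the Discriminator? A common rationale, used for example in DCGAN, is that it keeps a small non-zero gradient for negative inputs, so the Generator, which is trained through the Discriminator, still receives a learning signal. The sketch below compares the gradients of the two activations on a few made-up values.

```python
import torch
import torch.nn as nn

x = torch.tensor([-2.0, -0.5, 1.0], requires_grad=True)

# Gradient of ReLU: zero for every negative input
nn.ReLU()(x).sum().backward()
relu_grad = x.grad.clone()

# Gradient of LeakyReLU(0.2): 0.2 for negative inputs, 1 for positive ones
x.grad = None
nn.LeakyReLU(0.2)(x).sum().backward()
leaky_grad = x.grad.clone()

print(relu_grad)   # tensor([0., 0., 1.])
print(leaky_grad)  # tensor([0.2000, 0.2000, 1.0000])
```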
        
    
5. Loss Function and Optimization
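Because the Discriminator performs binary classification (real vs. fake), the standard loss for a cGAN is binary cross-entropy (BCE). The Discriminator minimizes BCE against the true real/fake targets, while the Generator is trained with the opposite goal of having its images classified as real. Both networks are optimized with Adam using a learning rate of 0.0002 and beta1 = 0.5, settings popularized by DCGAN.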
        
import torch.optim as optim
def build_optimizers(generator, discriminator, lr=0.0002, beta1=0.5):
    g_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
    d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
    return g_optimizer, d_optimizer
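To make the loss concrete, the sketch below computes binary cross-entropy by hand for a few hypothetical Discriminator outputs and checks it against nn.BCELoss, which averages the per-sample losses by default.

```python
import torch
import torch.nn as nn

# Hypothetical Discriminator outputs (probabilities) and targets (1 = real, 0 = fake)
preds = torch.tensor([[0.9], [0.2], [0.6]])
targets = torch.tensor([[1.0], [0.0], [1.0]])

# Binary cross-entropy per sample: -[y * log(p) + (1 - y) * log(1 - p)],
# averaged over the batch
manual = -(targets * torch.log(preds) + (1 - targets) * torch.log(1 - preds)).mean()
builtin = nn.BCELoss()(preds, targets)

print(manual.item(), builtin.item())
```

For the Discriminator, real images use target 1 and generated images target 0; for the Generator, generated images are scored against target 1, which is how the opposing objectives are expressed with a single loss function.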
        
    
6. Training cGAN
The Generator and Discriminator are trained alternately. In each iteration, the Discriminator is first updated to output values close to 1 for real image-label pairs and close to 0 for generated ones; the Generator is then updated so that the Discriminator classifies its images as real. Below is an example of the training loop.
        
num_classes = 10
z_dim = 100
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = Generator(z_dim, num_classes).to(device)
discriminator = Discriminator(num_classes).to(device)
g_optimizer, d_optimizer = build_optimizers(generator, discriminator)
criterion = nn.BCELoss()
# Training loop
num_epochs = 200
for epoch in range(num_epochs):
    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        batch_size = imgs.size(0)
        # BCE targets (not class labels): 1 = real, 0 = fake
        real_targets = torch.ones(batch_size, 1, device=device)
        fake_targets = torch.zeros(batch_size, 1, device=device)

        # Train Discriminator
        d_optimizer.zero_grad()
        # Real images paired with their true class labels
        d_loss_real = criterion(discriminator(imgs, labels), real_targets)
        # Generated images conditioned on random class labels
        noise = torch.randn(batch_size, z_dim, device=device)
        random_labels = torch.randint(0, num_classes, (batch_size,), device=device)
        generated_imgs = generator(noise, random_labels)
        # detach() stops the Discriminator's loss from backpropagating into the Generator
        d_loss_fake = criterion(discriminator(generated_imgs.detach(), random_labels), fake_targets)
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        d_optimizer.step()

        # Train Generator
        g_optimizer.zero_grad()
        # The Generator wants its images (with their labels) to be judged real
        outputs = discriminator(generated_imgs, random_labels)
        g_loss = criterion(outputs, real_targets)
        g_loss.backward()
        g_optimizer.step()
    # Log once per epoch (losses from the last batch of the epoch)
    print(f'Epoch [{epoch + 1}/{num_epochs}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}')
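When the Discriminator is updated on generated images, a common practice (used, for example, in the official PyTorch DCGAN tutorial) is to pass the detached images, so the Discriminator's loss does not compute gradients for the Generator. The sketch below uses tiny hypothetical stand-in models, not the article's networks, to show the difference.

```python
import torch
import torch.nn as nn

# Tiny stand-in models (hypothetical, for illustration only)
gen = nn.Linear(4, 3)
disc = nn.Sequential(nn.Linear(3, 1), nn.Sigmoid())
criterion = nn.BCELoss()
fake = gen(torch.randn(2, 4))
fake_targets = torch.zeros(2, 1)

# Discriminator loss WITHOUT detach: gradients also flow into the generator
criterion(disc(fake), fake_targets).backward()
leaked = gen.weight.grad is not None
print("generator got gradients without detach:", leaked)

# Discriminator loss WITH detach: the generator is cut out of the graph
gen.zero_grad(set_to_none=True)
disc.zero_grad(set_to_none=True)
criterion(disc(fake.detach()), fake_targets).backward()
print("generator got gradients with detach:", gen.weight.grad is not None)
print("discriminator got gradients:", disc[0].weight.grad is not None)
```

Without detach the result is still correct as long as the Generator's gradients are cleared before its own update, but the extra backward pass through the Generator is wasted work.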
        
    
7. Visualizing Results
After training, we can visualize the generated images with Matplotlib. Because the Generator is conditioned on the class label, the digits that appear are determined by the labels passed to it.
        
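import math
import matplotlib.pyplot as plt
def generate_and_show_images(generator, z_dim=100, num_classes=10, num_images=10):
    # One image per class: labels 0, 1, ..., 9 (cycling if num_images > num_classes)
    labels = torch.arange(num_images) % num_classes
    noise = torch.randn(num_images, z_dim)
    device = next(generator.parameters()).device
    with torch.no_grad():  # no gradients needed for inference
        generated_images = generator(noise.to(device), labels.to(device)).cpu()
    # Undo Normalize((0.5,), (0.5,)): map [-1, 1] back to [0, 1]
    generated_images = generated_images * 0.5 + 0.5
    cols = 5
    rows = math.ceil(num_images / cols)
    for i in range(num_images):
        plt.subplot(rows, cols, i + 1)
        plt.imshow(generated_images[i].squeeze(0).numpy(), cmap='gray', vmin=0, vmax=1)
        plt.title(str(labels[i].item()))
        plt.axis('off')
    plt.show()
generate_and_show_images(generator)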
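Because the Generator is conditioned on the label, producing many variations of a single digit only requires fixing the label and varying the noise. The helper below is a minimal sketch; generate_digit and StandInGenerator are hypothetical names, the stand-in exists only so the snippet runs on its own, and the generator is assumed to live on the CPU.

```python
import torch
import torch.nn as nn

def generate_digit(generator, digit, n=8, z_dim=100):
    """Generate n images of one chosen digit, rescaled to [0, 1] for display."""
    labels = torch.full((n,), digit, dtype=torch.long)  # same label for every sample
    noise = torch.randn(n, z_dim)                        # different noise per sample
    with torch.no_grad():
        imgs = generator(noise, labels)
    return imgs * 0.5 + 0.5  # undo Normalize((0.5,), (0.5,))

# Hypothetical stand-in with the same call signature as the article's Generator;
# pass the trained generator instead in real use.
class StandInGenerator(nn.Module):
    def __init__(self, z_dim=100, num_classes=10):
        super().__init__()
        self.embed = nn.Embedding(num_classes, num_classes)
        self.fc = nn.Linear(z_dim + num_classes, 28 * 28)

    def forward(self, noise, labels):
        x = torch.cat((noise, self.embed(labels)), dim=1)
        return torch.tanh(self.fc(x)).view(-1, 1, 28, 28)

sevens = generate_digit(StandInGenerator(), digit=7)
print(sevens.shape)  # torch.Size([8, 1, 28, 28])
```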
        
    
8. Conclusion
In this article, we explored the concept and implementation of Conditional Generative Adversarial Networks (cGANs). By feeding a condition such as a class label to both the Generator and the Discriminator, a cGAN can generate images of a chosen class, as we did for MNIST digits. The same conditioning idea applies beyond class-conditional generation, for example to tasks such as image-to-image translation and style transfer. We hope the PyTorch implementation covered here serves as a starting point for building more advanced models and exploring other applications.