Deep Learning with GANs in PyTorch: The Literary Club for Notorious Offenders

Generative Adversarial Networks (GANs) are considered one of the most innovative advances in deep learning. A GAN consists of two neural networks: a Generator and a Discriminator. The Generator creates data, while the Discriminator judges whether that data is real or fake. This adversarial setup pushes each network to improve the other. In this course, we will build a GAN with PyTorch and generate data around the theme of 'The Literary Club for Notorious Offenders'.

1. Basic Structure and Principles of a GAN

A GAN operates as follows:

  • Generator: Takes random noise (z) as input and generates realistic data.
  • Discriminator: Determines whether the input data is real or generated by the Generator.
  • The Generator tries to deceive the Discriminator, while the Discriminator tries to tell real data from fake. As this competition continues, both networks improve; the objective behind this game is written out below.
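
In formula form, this competition is the standard minimax objective from the original GAN paper (Goodfellow et al., 2014), where D(x) is the Discriminator's probability that x is real and G(z) is a sample produced from noise z:

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]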

2. Preparing Required Libraries and Datasets

First, install PyTorch and the other required libraries, then prepare a dataset. In this example we use the MNIST dataset, a collection of handwritten digit images, to generate images of digits.

2.1 Setting Up the Environment

pip install torch torchvision
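
To verify the installation, and to check whether a CUDA GPU is available (the course runs on CPU as well, just more slowly), a quick check like this can help:

import torch

# Print the installed PyTorch version and whether a CUDA GPU is available
print(torch.__version__)
print('cuda available:', torch.cuda.is_available())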

2.2 Loading the Dataset

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Load MNIST Dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

mnist_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
data_loader = DataLoader(dataset=mnist_dataset, batch_size=64, shuffle=True)
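
To confirm the loader works as expected, you can inspect a single batch (a quick sanity check, not part of the training code):

# Fetch one batch and check its shape and value range
imgs, labels = next(iter(data_loader))
print(imgs.shape)                            # torch.Size([64, 1, 28, 28])
print(imgs.min().item(), imgs.max().item())  # roughly -1.0 and 1.0 after normalization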

3. Constructing the GAN Model

We define the Generator and Discriminator models to implement the Generative Adversarial Network.

3.1 Generator Model

import torch.nn as nn

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 28 * 28),  # MNIST image size
            nn.Tanh()  # Adjust the pixel value range of generated images to -1 ~ 1
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), 1, 28, 28)  # Transform into image shape
        return img
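
Before wiring the Generator into training, it can be worth a quick smoke test to confirm that a batch of noise flows through and comes out with the expected image shape (a throwaway check, not part of the course code):

# Smoke test: 16 noise vectors of dimension 100 -> 16 single-channel 28x28 images
g = Generator()
z = torch.randn(16, 100)
print(g(z).shape)  # torch.Size([16, 1, 28, 28])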

3.2 Discriminator Model

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()  # Output value between 0 and 1
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)  # Flatten the image
        validity = self.model(img_flat)
        return validity
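
The Discriminator can be checked the same way; its output should be one score per image, kept between 0 and 1 by the final Sigmoid:

# Smoke test: scores for a batch of 16 dummy "images", each score in (0, 1)
d = Discriminator()
dummy = torch.randn(16, 1, 28, 28)  # stand-in for a batch of MNIST-shaped images
print(d(dummy).shape)  # torch.Size([16, 1])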

4. Training Process of a GAN

GAN training proceeds as follows:

  • Provide real images and generated images to the Discriminator to calculate its loss.
  • Update the Generator to make the generated images closer to the real ones.
  • Repeat this process so that both networks improve (the corresponding loss functions are shown below).
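
With binary cross-entropy (the BCELoss used in section 4.1), these two steps correspond to the following losses for a real image x and noise z; note that the code below averages the Discriminator's two terms:

\mathcal{L}_D = -\tfrac{1}{2}\bigl[\log D(x) + \log(1 - D(G(z)))\bigr], \qquad \mathcal{L}_G = -\log D(G(z))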

4.1 Defining Loss Function and Optimizers

import torch.optim as optim

# Create instances of Generator and Discriminator
generator = Generator()
discriminator = Discriminator()

# Set loss function and optimizer
adversarial_loss = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
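
Training on a GPU is noticeably faster if one is available. A minimal optional sketch (nn.Module.to moves parameters in place, so the optimizers above keep working):

# Optional: run on GPU when available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator.to(device)
discriminator.to(device)
# Note: tensors created inside the training loop (imgs, z, valid, fake) would then
# also need .to(device) before being passed to the models.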

4.2 Training Loop

num_epochs = 200
for epoch in range(num_epochs):
    for i, (imgs, _) in enumerate(data_loader):
        # Label for real data: 1
        real_imgs = imgs
        valid = torch.ones(imgs.size(0), 1)  # Ground truth for real images
        fake = torch.zeros(imgs.size(0), 1)  # Ground truth for fake images

        # Train Discriminator
        optimizer_D.zero_grad()
        z = torch.randn(imgs.size(0), 100)  # Sample random noise
        generated_imgs = generator(z)  # Generated images
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(generated_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # Train Generator
        optimizer_G.zero_grad()
        g_loss = adversarial_loss(discriminator(generated_imgs), valid)
        g_loss.backward()
        optimizer_G.step()

    print(f'Epoch {epoch + 1}/{num_epochs} | D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f}')
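
Since 200 epochs can take a while, it may be worth persisting the trained weights once training finishes. A minimal sketch using PyTorch's standard state_dict serialization (the file names here are arbitrary):

# Save the trained weights so images can be generated later without retraining
torch.save(generator.state_dict(), 'generator.pth')
torch.save(discriminator.state_dict(), 'discriminator.pth')

# To reload later:
# generator = Generator()
# generator.load_state_dict(torch.load('generator.pth'))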

5. Visualizing Results

After training is completed, we visualize the generated images to check the results.

import matplotlib.pyplot as plt

# Visualize generated images
def show_generated_images(generator, num_images=16):
    z = torch.randn(num_images, 100)  # Sample random noise
    generated_images = generator(z)
    generated_images = generated_images.detach().numpy()

    fig, axs = plt.subplots(4, 4, figsize=(10, 10))  # 4x4 grid; assumes num_images == 16
    for i in range(4):
        for j in range(4):
            axs[i, j].imshow(generated_images[i * 4 + j, 0], cmap='gray')
            axs[i, j].axis('off')
    plt.show()

show_generated_images(generator)
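
Instead of (or in addition to) displaying the grid, you can write the generated digits to disk with torchvision's grid utility (an optional extra; the file name is arbitrary):

from torchvision.utils import save_image

# Save a 4x4 grid; normalize=True maps the Tanh output range [-1, 1] back to [0, 1]
with torch.no_grad():
    z = torch.randn(16, 100)
    save_image(generator(z), 'generated_digits.png', nrow=4, normalize=True)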

In this way, you can build and train a GAN and inspect the images it generates. GANs have a wide range of potential applications and can support creative work; you have now taken a first step into their world!

6. Conclusion

Generative Adversarial Networks are a fascinating area of deep learning and are actively used in many research and development projects. In this course, we explored the basic principles and structure of GANs with PyTorch and walked through building and training the models. I hope this course has deepened your understanding of and interest in GANs, and that it serves you well on your deep learning journey.