Using PyTorch for GAN Deep Learning: A Probabilistic Generative Model

In this post, we will take a closer look at Generative Adversarial Networks (GANs). A GAN is a generative model proposed by Ian Goodfellow in 2014 that uses two neural networks, a Generator and a Discriminator, to generate data. The key idea behind GANs is that the two networks compete with each other, and this adversarial competition pushes the Generator to produce increasingly realistic data.

1. Basic Structure of GAN

A GAN consists of the following two components:

  • Generator: responsible for generating new data. It takes random noise as input and outputs data that resembles real data.
  • Discriminator: distinguishes whether a given sample is real data or data produced by the Generator.

The Generator and Discriminator are trained through the following loss functions:

  • Generator Loss Function: trains the Generator so that the Discriminator classifies its output as real data.
  • Discriminator Loss Function: trains the Discriminator to distinguish real data from Generator output as accurately as possible.
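
These two losses come from a single minimax objective, the value function from Goodfellow's 2014 paper:

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]

The Discriminator ascends this objective while the Generator descends it. In practice (and in the code below), the Generator instead maximizes \log D(G(z)), the non-saturating variant, which provides stronger gradients early in training.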

2. Training Process of GAN

The training process of the GAN model consists of the following steps:

  1. Sample a mini-batch of real data from the dataset.
  2. Generate fake data by feeding random noise into the Generator.
  3. Pass both real and fake data through the Discriminator and compute the probability that each sample is real.
  4. Update the Discriminator and the Generator using their respective loss functions.
  5. Repeat until the generated samples are convincing. Each of these steps maps directly onto the PyTorch training loop in section 3.4.

3. Implementing GAN Using PyTorch

Now, let’s implement a simple GAN in PyTorch. In this example, we will build a GAN that generates handwritten digit images using the MNIST dataset.

3.1 Installing Required Libraries


# Install required libraries
!pip install torch torchvision matplotlib

3.2 Loading and Preprocessing the Dataset


import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# Download the MNIST dataset and scale pixels to [-1, 1] to match the Generator's Tanh output
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_set = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
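
As a quick sanity check (optional, not part of the original pipeline), you can pull one batch from the loader and verify its shape and value range:

# Inspect a single batch: 64 grayscale images of size 28x28
images, labels = next(iter(train_loader))
print(images.shape)                              # torch.Size([64, 1, 28, 28])
print(images.min().item(), images.max().item())  # approximately -1.0 and 1.0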

3.3 Defining Generator and Discriminator Models


import torch.nn as nn

# Generator: maps a 100-dimensional noise vector to a 1x28x28 image in [-1, 1]
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 28 * 28),
            nn.Tanh(),
        )

    def forward(self, x):
        x = self.fc(x)
        return x.view(-1, 1, 28, 28)

# Discriminator: maps a flattened 28x28 image to the probability that it is real
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        return self.fc(x)
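
Before training, it is worth verifying that the two networks fit together. This shape check is my addition, not in the original post:

# Run noise through untrained models and confirm the shapes line up
g, d = Generator(), Discriminator()
z = torch.randn(4, 100)
imgs = g(z)
print(imgs.shape)     # torch.Size([4, 1, 28, 28])
print(d(imgs).shape)  # torch.Size([4, 1]); values in (0, 1) from the Sigmoid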

3.4 Model Training


# Set hyperparameters
num_epochs = 200
learning_rate = 0.0002
beta1 = 0.5  # lower than Adam's default 0.9; a common choice for stabilizing GAN training

# Initialize models
generator = Generator()
discriminator = Discriminator()

# Define loss function and optimization algorithm
criterion = nn.BCELoss()
optimizerG = torch.optim.Adam(generator.parameters(), lr=learning_rate, betas=(beta1, 0.999))
optimizerD = torch.optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(beta1, 0.999))

# Training loop
for epoch in range(num_epochs):
    for i, (data, _) in enumerate(train_loader):
        # Labels: 1 for real data, 0 for fake data
        real_labels = torch.ones(data.size(0), 1)
        fake_labels = torch.zeros(data.size(0), 1)

        # Train the Discriminator: push D(real) toward 1 and D(fake) toward 0
        optimizerD.zero_grad()
        outputs = discriminator(data)
        lossD_real = criterion(outputs, real_labels)
        lossD_real.backward()

        noise = torch.randn(data.size(0), 100)
        fake_data = generator(noise)
        outputs = discriminator(fake_data.detach())
        lossD_fake = criterion(outputs, fake_labels)
        lossD_fake.backward()
        optimizerD.step()

        # Train the Generator: push D(fake) toward 1 (non-saturating loss)
        optimizerG.zero_grad()
        outputs = discriminator(fake_data)
        lossG = criterion(outputs, real_labels)
        lossG.backward()
        optimizerG.step()

    if (epoch+1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss D: {lossD_real.item() + lossD_fake.item():.4f}, Loss G: {lossG.item():.4f}')
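
The loop above runs on the CPU. If a GPU is available, a minimal adaptation (my addition, not shown in the original) is to move the models once, and then create each batch of tensors on the same device inside the loop:

# Optional GPU support
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = generator.to(device)
discriminator = discriminator.to(device)
# ...then, inside the loop:
# data = data.to(device)
# noise = torch.randn(data.size(0), 100, device=device)
# real_labels = torch.ones(data.size(0), 1, device=device)
# fake_labels = torch.zeros(data.size(0), 1, device=device)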

3.5 Visualizing Results


# Function to visualize generated images
def visualize(generator):
    noise = torch.randn(64, 100)
    with torch.no_grad():  # inference only; no gradients needed
        fake_data = generator(noise)
    fake_data = fake_data.numpy()
    fake_data = (fake_data + 1) / 2  # map Tanh output from [-1, 1] back to [0, 1]

    plt.figure(figsize=(8, 8))
    for i in range(fake_data.shape[0]):
        plt.subplot(8, 8, i+1)
        plt.axis('off')
        plt.imshow(fake_data[i][0], cmap='gray')
    plt.show()

# Visualize results
visualize(generator)
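
If you want to keep the samples rather than only display them, torchvision's save_image helper writes a grid of images to disk. This step is an optional extra, not in the original post:

from torchvision.utils import save_image

# Save an 8x8 grid of samples; normalize=True rescales [-1, 1] to [0, 1]
noise = torch.randn(64, 100)
with torch.no_grad():
    samples = generator(noise)
save_image(samples, 'gan_samples.png', nrow=8, normalize=True)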

4. Applications of GAN

GANs are applied across a variety of fields:

  • Image Generation: producing high-quality synthetic images.
  • Style Transfer: transforming the style of an image, for instance converting a daytime photo to a nighttime one.
  • Data Augmentation: generating new samples to enlarge a training dataset.

5. Conclusion

In this post, we explored the concept of GANs and a simple implementation in PyTorch. GANs are a family of generative models with broad potential applications, and with many new variants being actively proposed, learning to use them is a valuable skill.

I hope this post has helped you understand GANs and get started with a practical implementation. I will be back with more deep learning topics in the future!