Deep Learning with GANs using PyTorch, Deep Neural Networks

1. Overview of GAN

A GAN (Generative Adversarial Network) is a deep learning model proposed by Ian Goodfellow and colleagues in 2014. A GAN learns the distribution of a given dataset and can then generate new samples that resemble it.
A GAN consists of two neural networks: the Generator and the Discriminator. The Generator creates fake data that resembles real data, while the Discriminator decides whether a given sample is real or generated.

2. Structure of GAN

A GAN has two components:

  • Generator (G): Takes random noise as input and generates fake data from it.
  • Discriminator (D): Takes a sample as input and outputs the probability that it is real rather than generated.

2.1. Loss Function

During training, the Generator and the Discriminator compete, each optimizing its own loss function. The Discriminator tries to distinguish real data from fake data as accurately as possible, while the Generator tries to fool the Discriminator. This can be expressed as the following minimax objective, where x is a real sample, z is a random noise vector, and D outputs the estimated probability that its input is real:


    min_G max_D V(D, G) = E_{x ~ p_data(x)}[log D(x)] + E_{z ~ p_z(z)}[log(1 - D(G(z)))]
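
To make the objective concrete, the sketch below estimates V(D, G) from a handful of Discriminator outputs using only the Python standard library. The helper name `value_function` and the sample probabilities are illustrative assumptions, not part of the tutorial's model code.

```python
import math

def value_function(d_real, d_fake):
    """Sample-average estimate of V(D, G).

    d_real holds D(x) for real samples, d_fake holds D(G(z)) for fakes.
    """
    real_term = sum(math.log(p) for p in d_real) / len(d_real)
    fake_term = sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)
    return real_term + fake_term

# A confident, correct Discriminator pushes V toward 0, its upper bound.
v_confident = value_function([0.99, 0.98], [0.01, 0.02])

# A Discriminator that cannot tell real from fake outputs 0.5 everywhere,
# which gives V = log(0.5) + log(0.5) = -log(4).
v_confused = value_function([0.5, 0.5], [0.5, 0.5])

print(f"confident D: V = {v_confident:.4f}")
print(f"confused D:  V = {v_confused:.4f}")
```

The Discriminator wants V to be large, and the Generator wants it to be small. The value -log(4) is what the objective evaluates to when the Discriminator outputs 0.5 for every input, which is the theoretical equilibrium where generated data is indistinguishable from real data.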
    

3. Implementing GAN using PyTorch

In this section, we implement a simple GAN in PyTorch that learns to generate handwritten digit images from the MNIST dataset.

3.1. Importing Libraries


    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torchvision import datasets, transforms
    from torch.utils.data import DataLoader
    import matplotlib.pyplot as plt
    

3.2. Setting Hyperparameters


    # Setting hyperparameters
    latent_size = 64
    batch_size = 128
    learning_rate = 0.0002
    num_epochs = 50
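
As a rough sanity check on these settings, the sketch below works out how many parameter updates they imply. It assumes the standard MNIST training split of 60,000 images; the variable names other than `batch_size` and `num_epochs` are introduced here for illustration.

```python
import math

batch_size = 128
num_epochs = 50
num_train_images = 60_000  # assumption: standard MNIST training split

# The DataLoader keeps the final partial batch unless drop_last=True is set.
batches_per_epoch = math.ceil(num_train_images / batch_size)
last_batch_size = num_train_images % batch_size or batch_size
total_updates = batches_per_epoch * num_epochs

print(batches_per_epoch, last_batch_size, total_updates)  # 469 96 23450
```

Because the last batch of each epoch holds 96 images rather than 128, tensors created inside a training loop should take their size from the current batch rather than from `batch_size`.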
    

3.3. Loading the Dataset


    # Loading MNIST dataset
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    mnist = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    dataloader = DataLoader(mnist, batch_size=batch_size, shuffle=True)
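
The arithmetic behind the transform can be sketched in plain Python. `ToTensor` scales pixel values to the range [0, 1], and `Normalize((0.5,), (0.5,))` then subtracts 0.5 and divides by 0.5, which maps them to [-1, 1]. The helper names `normalize` and `denormalize` are hypothetical and only mirror that per-pixel formula.

```python
def normalize(pixel, mean=0.5, std=0.5):
    """Per-pixel operation applied by transforms.Normalize((0.5,), (0.5,))."""
    return (pixel - mean) / std

def denormalize(value, mean=0.5, std=0.5):
    """Inverse mapping, useful when displaying generated images."""
    return value * std + mean

for pixel in (0.0, 0.5, 1.0):
    scaled = normalize(pixel)
    print(f"{pixel:.1f} -> {scaled:+.1f} -> {denormalize(scaled):.1f}")
```

The [-1, 1] range matters because it matches the output range of a Tanh activation, so real and generated images reach the Discriminator on the same scale.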
    

3.4. Defining the Generator and Discriminator


    class Generator(nn.Module):
        def __init__(self):
            super(Generator, self).__init__()
            self.model = nn.Sequential(
                nn.Linear(latent_size, 128),
                nn.ReLU(),
                nn.Linear(128, 256),
                nn.ReLU(),
                nn.Linear(256, 512),
                nn.ReLU(),
                nn.Linear(512, 784),
                nn.Tanh()
            )

        def forward(self, z):
            return self.model(z).reshape(-1, 1, 28, 28)

    class Discriminator(nn.Module):
        def __init__(self):
            super(Discriminator, self).__init__()
            self.model = nn.Sequential(
                nn.Flatten(),
                nn.Linear(784, 512),
                nn.LeakyReLU(0.2),
                nn.Linear(512, 256),
                nn.LeakyReLU(0.2),
                nn.Linear(256, 1),
                nn.Sigmoid()
            )

        def forward(self, img):
            return self.model(img)
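
Before training, it can help to confirm that the two networks fit together. The sketch below rebuilds the same layer stacks as compact `nn.Sequential` stand-ins (an assumption made to keep the example self-contained) and checks tensor shapes and parameter counts; `count_params` is a hypothetical helper.

```python
import torch
import torch.nn as nn

latent_size = 64

# Compact stand-ins with the same layer sizes as the classes above.
generator = nn.Sequential(
    nn.Linear(latent_size, 128), nn.ReLU(),
    nn.Linear(128, 256), nn.ReLU(),
    nn.Linear(256, 512), nn.ReLU(),
    nn.Linear(512, 784), nn.Tanh(),
)
discriminator = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 512), nn.LeakyReLU(0.2),
    nn.Linear(512, 256), nn.LeakyReLU(0.2),
    nn.Linear(256, 1), nn.Sigmoid(),
)

def count_params(module):
    """Total number of parameters in a module."""
    return sum(p.numel() for p in module.parameters())

with torch.no_grad():
    z = torch.randn(16, latent_size)
    fake_imgs = generator(z).reshape(-1, 1, 28, 28)
    scores = discriminator(fake_imgs)

print(fake_imgs.shape)  # torch.Size([16, 1, 28, 28])
print(scores.shape)     # torch.Size([16, 1])
print(count_params(generator), count_params(discriminator))
```

The Generator's output has the same shape as an MNIST batch, and the Discriminator returns one probability per image, so either real or generated batches can be passed to it unchanged.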
    

3.5. Setting up the Models, Loss Function, and Optimizers


    generator = Generator()
    discriminator = Discriminator()

    criterion = nn.BCELoss()
    optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate)
    optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate)
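
The link between `nn.BCELoss` and the objective in section 2.1 can be checked numerically. In the sketch below, the Discriminator outputs are made-up values chosen for illustration. Binary cross-entropy with target 1 reduces to -log D(x), and with target 0 to -log(1 - D(G(z))), so the summed loss is the negative of the sample estimate of V(D, G).

```python
import torch
import torch.nn as nn

criterion = nn.BCELoss()

# Hypothetical Discriminator outputs for two real and two fake samples.
d_real = torch.tensor([[0.9], [0.8]])
d_fake = torch.tensor([[0.2], [0.1]])

# Loss as a training loop would compute it: real -> label 1, fake -> label 0.
bce_loss = (criterion(d_real, torch.ones_like(d_real))
            + criterion(d_fake, torch.zeros_like(d_fake)))

# The same quantity written directly from the value function.
value_estimate = torch.log(d_real).mean() + torch.log(1 - d_fake).mean()

print(bce_loss.item(), -value_estimate.item())
```

Minimizing this loss with respect to the Discriminator's parameters is therefore the same as maximizing V(D, G).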
    

3.6. GAN Training Loop


    for epoch in range(num_epochs):
        for i, (imgs, _) in enumerate(dataloader):
            # Use the actual batch size: the last batch of an epoch can be
            # smaller than batch_size, and BCELoss requires matching shapes.
            real_imgs = imgs
            current_batch = real_imgs.size(0)
            real_labels = torch.ones(current_batch, 1)
            fake_labels = torch.zeros(current_batch, 1)

            # Train the Discriminator: real images should score 1, fakes 0.
            optimizer_D.zero_grad()
            outputs = discriminator(real_imgs)
            d_loss_real = criterion(outputs, real_labels)
            d_loss_real.backward()

            z = torch.randn(current_batch, latent_size)
            fake_imgs = generator(z)
            # detach() keeps this step from backpropagating into the Generator.
            outputs = discriminator(fake_imgs.detach())
            d_loss_fake = criterion(outputs, fake_labels)
            d_loss_fake.backward()
            optimizer_D.step()

            # Train the Generator: it is rewarded when the Discriminator
            # labels its images as real.
            optimizer_G.zero_grad()
            outputs = discriminator(fake_imgs)
            g_loss = criterion(outputs, real_labels)
            g_loss.backward()
            optimizer_G.step()

        # Report the losses from the last batch of the epoch.
        print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss_real.item() + d_loss_fake.item():.4f}, g_loss: {g_loss.item():.4f}')
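
Note that the Generator step above labels fake images as real, so it minimizes -log D(G(z)) rather than the log(1 - D(G(z))) term from section 2.1. This is commonly called the non-saturating Generator loss. The sketch below, using only the standard library, compares the gradient of each form with respect to the Discriminator's pre-sigmoid score `a`; the function names are illustrative assumptions.

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def grad_saturating(a):
    """d/da of log(1 - sigmoid(a)), the Generator term in the minimax objective."""
    return -sigmoid(a)

def grad_non_saturating(a):
    """d/da of -log(sigmoid(a)), what BCELoss with 'real' labels computes."""
    return -(1.0 - sigmoid(a))

# a is the Discriminator's pre-sigmoid score for a generated sample.
for a in (-6.0, -2.0, 0.0, 2.0):
    print(f"D(G(z)) = {sigmoid(a):.4f}  "
          f"saturating: {grad_saturating(a):+.4f}  "
          f"non-saturating: {grad_non_saturating(a):+.4f}")
```

Early in training the Discriminator rejects fakes easily, so D(G(z)) is close to 0. In that regime the minimax form yields a gradient near zero, while the non-saturating form yields a gradient near -1, giving the Generator a much stronger learning signal.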
    

3.7. Visualization of Results

After training, we will visualize the generated images to evaluate the performance of the GAN.


    z = torch.randn(64, latent_size)
    generated_images = generator(z).detach().numpy()
    generated_images = (generated_images + 1) / 2  # Rescale from [-1, 1] to [0, 1]

    fig, axs = plt.subplots(8, 8, figsize=(10,10))
    for i in range(8):
        for j in range(8):
            axs[i,j].imshow(generated_images[i*8 + j][0], cmap='gray')
            axs[i,j].axis('off')
    plt.show()
    

4. Conclusion

In this article, we covered the basic concepts of GANs and implemented a simple fully connected GAN in PyTorch that generates MNIST digits. GANs are widely used for data generation, most prominently image synthesis.