Deep Learning with PyTorch: Building Your First GAN

Generative Adversarial Networks (GANs) are an innovative deep learning model proposed by Ian Goodfellow and colleagues in 2014, in which two neural networks are trained in opposition to each other. GANs have been applied to many tasks, most prominently image generation, as well as video and text generation. In this post, we explain the basic concepts of GANs and implement one step by step using PyTorch.

1. Basic Concepts of GAN

GAN consists of two neural networks: the Generator and the Discriminator. The role of the Generator is to create data that looks real, and the Discriminator’s role is to determine whether the given data is real or fake data produced by the Generator. These two networks learn simultaneously; the Generator evolves to create increasingly sophisticated data to fool the Discriminator, while the Discriminator evolves to identify fake data more accurately.

1.1 Structure of GAN

The structure of GAN can be described simply as follows:

  • Generator: Accepts random noise as input and generates data that looks real.
  • Discriminator: Accepts both real data and fake data from the Generator as input and outputs, for each sample, the probability that it is real.

1.2 Learning Process of GAN

The learning process of GAN proceeds as follows:

  1. The Generator (G) takes random noise as input and produces fake data.
  2. The fake data and a batch of real data are fed to the Discriminator (D), which outputs, for each sample, the probability that it is real.
  3. The Generator's loss is defined so that minimizing it maximizes the probability that the Discriminator classifies the fake data as real.
  4. The Discriminator's loss is defined so that minimizing it maximizes the probability of classifying real data as real and fake data as fake.
  5. This process is repeated so that both networks compete with each other, improving their performance.
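
The competition described above is usually written as the minimax objective from the original GAN paper, where $D(x)$ is the Discriminator's estimated probability that $x$ is real, $p_{\text{data}}$ is the distribution of real data, and $G(z)$ is the sample the Generator produces from noise $z \sim p_z$:

$$
\min_G \max_D \; V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
$$

In practice, the Generator is usually trained to maximize $\log D(G(z))$ instead of minimizing $\log(1 - D(G(z)))$, because the latter gives very weak gradients early in training, when the Discriminator rejects fakes easily. With binary cross-entropy, this amounts to labeling the fake images as real when computing the Generator's loss, which is exactly what the training code in section 2.5 does.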

2. Implementing GAN using PyTorch

Now, let’s implement a simple GAN using PyTorch. We will train it on the MNIST dataset of handwritten digits so that it learns to generate new 28×28 images of digits.

2.1 Environment Setup

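First, install the necessary libraries: PyTorch, torchvision (used to load the dataset), and matplotlib (used to plot the generated images). The leading ! runs the command from a Jupyter notebook cell; omit it when installing from a terminal.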

    
    !pip install torch torchvision matplotlib
    
    

2.2 Preparing the Dataset

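We will load the MNIST dataset and preprocess it. ToTensor() scales the pixel values to [0, 1], and Normalize((0.5,), (0.5,)) then maps them to [-1, 1], which matches the output range of the Generator's final Tanh layer. The DataLoader splits the data into shuffled batches of 64 images.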

    
    import torch
    from torchvision import datasets, transforms
    from torch.utils.data import DataLoader

    # Load and preprocess the data: [0, 1] after ToTensor, [-1, 1] after Normalize
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    # 60,000 training images, reshuffled every epoch and served in batches of 64
    dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
    
    

2.3 Defining the Generator and Discriminator

Next, we define the two key components of a GAN: the Generator and the Discriminator. Both are simple fully connected networks. The Generator maps a 100-dimensional random noise vector to a 28×28 image, and the Discriminator takes an image and outputs the probability that it is real.

    
    import torch.nn as nn

    class Generator(nn.Module):
        def __init__(self):
            super(Generator, self).__init__()
            self.model = nn.Sequential(
                nn.Linear(100, 256),
                nn.ReLU(),
                nn.Linear(256, 512),
                nn.ReLU(),
                nn.Linear(512, 1024),
                nn.ReLU(),
                nn.Linear(1024, 28 * 28),
                nn.Tanh()
            )

        def forward(self, z):
            return self.model(z).view(-1, 1, 28, 28)

    class Discriminator(nn.Module):
        def __init__(self):
            super(Discriminator, self).__init__()
            self.model = nn.Sequential(
                nn.Flatten(),
                nn.Linear(28 * 28, 512),
                nn.LeakyReLU(0.2),
                nn.Linear(512, 256),
                nn.LeakyReLU(0.2),
                nn.Linear(256, 1),
                nn.Sigmoid()
            )

        def forward(self, img):
            return self.model(img)
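
Before training, it is worth checking that the two networks fit together. The sketch below rebuilds the same layer stacks as plain nn.Sequential models so that it runs on its own (build_generator, build_discriminator and count_params are illustrative helper names, not part of PyTorch), then passes a batch of 16 noise vectors through both networks.

```python
import torch
import torch.nn as nn

LATENT_DIM = 100  # noise size used throughout this post


def build_generator() -> nn.Module:
    # Same layers as the Generator class above (the reshape happens below).
    return nn.Sequential(
        nn.Linear(LATENT_DIM, 256), nn.ReLU(),
        nn.Linear(256, 512), nn.ReLU(),
        nn.Linear(512, 1024), nn.ReLU(),
        nn.Linear(1024, 28 * 28), nn.Tanh(),
    )


def build_discriminator() -> nn.Module:
    # Same layers as the Discriminator class above.
    return nn.Sequential(
        nn.Flatten(),
        nn.Linear(28 * 28, 512), nn.LeakyReLU(0.2),
        nn.Linear(512, 256), nn.LeakyReLU(0.2),
        nn.Linear(256, 1), nn.Sigmoid(),
    )


def count_params(model: nn.Module) -> int:
    # Total number of trainable values (weights and biases).
    return sum(p.numel() for p in model.parameters())


gen, disc = build_generator(), build_discriminator()
z = torch.randn(16, LATENT_DIM)        # a batch of 16 noise vectors
imgs = gen(z).view(-1, 1, 28, 28)      # reshape, as Generator.forward does
scores = disc(imgs)                    # one probability per image

print(imgs.shape, scores.shape)
print(count_params(gen), count_params(disc))
```

With 100-dimensional noise, the Generator has 1,486,352 parameters and the Discriminator 533,505. The generated images lie in [-1, 1] because of Tanh, and the Discriminator returns one probability in (0, 1) per image because of Sigmoid.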
    
    

2.4 Initializing the Model and Setting Loss Function, Optimizer

We will initialize the Generator and Discriminator and set up the loss function and optimizers. We use binary cross-entropy loss (nn.BCELoss) and a separate Adam optimizer for each network.

    
    generator = Generator()
    discriminator = Discriminator()

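    # One Adam optimizer per network; lr=0.0002 and beta1=0.5 are common GAN settings
    ad = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))  # Discriminator
    ag = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))      # Generator

    criterion = nn.BCELoss()  # binary cross-entropy on the Discriminator's probabilities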
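
Since the Discriminator ends in a Sigmoid and outputs a single probability, binary cross-entropy fits naturally: for a prediction p and a target y in {0, 1}, the loss is -[y log p + (1 - y) log(1 - p)], averaged over the batch. The short check below uses made-up Discriminator outputs to confirm that nn.BCELoss matches this formula for both target values.

```python
import math

import torch
import torch.nn as nn

criterion = nn.BCELoss()  # mean of -[y*log(p) + (1-y)*log(1-p)] over the batch

# Hypothetical Discriminator outputs for two images.
probs = torch.tensor([[0.9], [0.2]])

real_targets = torch.ones(2, 1)   # label 1 = "real"
fake_targets = torch.zeros(2, 1)  # label 0 = "fake"

loss_if_real = criterion(probs, real_targets)
loss_if_fake = criterion(probs, fake_targets)

# The same losses computed by hand from the formula.
manual_real = -(math.log(0.9) + math.log(0.2)) / 2
manual_fake = -(math.log(1 - 0.9) + math.log(1 - 0.2)) / 2
print(loss_if_real.item(), manual_real)
print(loss_if_fake.item(), manual_fake)
```

The same criterion serves both networks: the Discriminator is trained with target 1 for real images and 0 for fakes, while the Generator is trained with target 1 for its own fakes.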
    
    

2.5 Training the GAN

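Now, let’s train the GAN. In each iteration, we first update the Discriminator on a batch of real images and a batch of fake ones, then update the Generator. Every 10 epochs, a grid of 64 generated images is saved to disk so that we can follow the progress.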

    
    import matplotlib.pyplot as plt
    import torch

    def train_gan(generator, discriminator, criterion, ag, ad, dataloader, epochs=50):
        for epoch in range(epochs):
            for real_imgs, _ in dataloader:
                batch_size = real_imgs.size(0)

                # Targets: 1 for real images, 0 for fake ones
                real_labels = torch.ones(batch_size, 1)
                fake_labels = torch.zeros(batch_size, 1)

                # Generate a batch of fake images from random noise
                noise = torch.randn(batch_size, 100)
                fake_imgs = generator(noise)

                # Train the Discriminator: real -> 1, fake -> 0.
                # detach() keeps this update from backpropagating into the Generator.
                discriminator.zero_grad()
                real_loss = criterion(discriminator(real_imgs), real_labels)
                fake_loss = criterion(discriminator(fake_imgs.detach()), fake_labels)
                d_loss = real_loss + fake_loss
                d_loss.backward()
                ad.step()

                # Train the Generator: its fakes are labeled as real, so it is
                # rewarded when the Discriminator is fooled
                generator.zero_grad()
                g_loss = criterion(discriminator(fake_imgs), real_labels)
                g_loss.backward()
                ag.step()

            # The reported losses are those of the last batch in the epoch
            print(f'Epoch [{epoch + 1}/{epochs}], D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}')

            # Save a grid of generated images every 10 epochs
            if (epoch + 1) % 10 == 0:
                save_generated_images(generator, epoch + 1)

    def save_generated_images(generator, epoch):
        # Sample 64 images without tracking gradients
        with torch.no_grad():
            noise = torch.randn(64, 100)
            generated_imgs = generator(noise).numpy()
        generated_imgs = (generated_imgs + 1) / 2  # Rescale from [-1, 1] to [0, 1]

        fig, axs = plt.subplots(8, 8, figsize=(8, 8))
        for i, ax in enumerate(axs.flat):
            ax.imshow(generated_imgs[i][0], cmap='gray')
            ax.axis('off')
        plt.savefig(f'generated_images_epoch_{epoch}.png')
        plt.close(fig)

    train_gan(generator, discriminator, criterion, ag, ad, dataloader, epochs=50)
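
A detail in the training loop that is easy to get wrong is the fake_imgs.detach() call in the Discriminator step. The self-contained sketch below uses tiny stand-in linear layers (hypothetical, not the MNIST models) to show its effect: with detach(), the Discriminator's loss produces no gradient for the Generator; without it, the Generator's loss flows back through the Discriminator into the Generator's weights.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Tiny stand-in networks, for illustration only.
gen = nn.Linear(4, 3)
disc = nn.Sequential(nn.Linear(3, 1), nn.Sigmoid())
criterion = nn.BCELoss()

noise = torch.randn(8, 4)
fake = gen(noise)

# Discriminator step: detach() cuts the graph, so no gradient reaches gen.
d_loss = criterion(disc(fake.detach()), torch.zeros(8, 1))
d_loss.backward()
gen_grad_after_d_step = gen.weight.grad  # still None

# Generator step: no detach, so gradients flow through disc into gen.
g_loss = criterion(disc(fake), torch.ones(8, 1))
g_loss.backward()
gen_grad_after_g_step = gen.weight.grad

print(gen_grad_after_d_step, gen_grad_after_g_step.shape)
```

Note that the Generator's backward pass also writes gradients into the Discriminator's parameters. In the training loop this is harmless: ag.step() only updates the Generator, and discriminator.zero_grad() clears those gradients before the next Discriminator update.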
    
    

2.6 Checking the Results

After training is completed, open the saved files generated_images_epoch_10.png through generated_images_epoch_50.png. As training progresses, the samples should look increasingly like handwritten digits. In this simple setup, visually inspecting the generated images is the main way to judge how well the GAN has learned. If training goes well, the Generator produces new digit images that resemble those in MNIST.

3. Conclusion

In this post, we covered the basic concepts of GANs and implemented a simple one in PyTorch. I hope working through the code gave you hands-on experience building your own GAN. GANs are powerful tools, but training them robustly takes careful tuning and further study. We invite you to explore the world of GANs and the creative results they can produce!
