Deep Learning with GANs in PyTorch: Structured and Unstructured Data

1. Overview of GAN

A GAN (Generative Adversarial Network) is a deep learning model proposed by Ian Goodfellow et al. in 2014, in which a generator model and a discriminator model compete against each other during training. The generator tries to create data that resembles real data, while the discriminator judges whether a given sample is real or generated. Through this competition, the generator produces data that becomes increasingly difficult to distinguish from the real thing.

2. Key Components of GAN

2.1 Generator

The generator is a neural network that takes a random noise vector as input and produces data resembling the real data. It is typically implemented as a multilayer perceptron or a convolutional neural network.

2.2 Discriminator

The discriminator is a neural network that judges whether the input data is real or generated. Its ability to effectively distinguish between fake data created by the generator and real data is crucial.

2.3 Loss Function

The GAN loss is split into a generator loss and a discriminator loss. The discriminator is trained to correctly classify real and generated samples, while the generator is trained to fool the discriminator.
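
For reference, both losses derive from the single minimax objective of Goodfellow et al. (2014), where D(x) is the discriminator's estimated probability that x is real and G(z) is the generator's output for noise z:

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]

The discriminator maximizes V by assigning high probability to real samples and low probability to generated ones; the generator minimizes it by making D(G(z)) as large as possible.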

3. Training Process of GAN

The training process of GAN is iterative and consists of the following steps (implemented concretely in Section 6.5):

  1. Sample random noise vectors and pass them through the generator to produce fake data.
  2. Feed both the fake data and real data into the discriminator.
  3. Compute the discriminator loss from how well it separated real from fake, and update the discriminator.
  4. Update the generator so that its outputs fool the discriminator.

4. Applications of GAN

GAN is used in various fields, including:

  • Image generation
  • Video generation
  • Voice synthesis
  • Text generation
  • Data augmentation

5. Difference Between Structured Data and Unstructured Data

Structured data is organized in a format, such as rows and columns, that can easily be represented in a relational database. Unstructured data, by contrast, has no predefined schema or structure; text, images, and videos are typical examples. GANs are primarily used with unstructured data, but the same adversarial setup can also be applied to structured (tabular) data, as sketched below.
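
As an illustration of that last point, here is a minimal sketch (not from any standard library; the dataset, n_features, and layer sizes are all hypothetical) showing that adapting the adversarial setup to tabular data mostly amounts to changing the input and output dimensionality, assuming numeric columns scaled to [-1, 1]:

import torch
import torch.nn as nn

n_features = 10   # hypothetical number of numeric columns, scaled to [-1, 1]
latent_dim = 32   # size of the noise vector

# Generator: noise vector -> one synthetic table row
tabular_generator = nn.Sequential(
    nn.Linear(latent_dim, 64),
    nn.ReLU(),
    nn.Linear(64, n_features),
    nn.Tanh(),    # matches the assumed [-1, 1] column scaling
)

# Discriminator: table row -> probability that the row is real
tabular_discriminator = nn.Sequential(
    nn.Linear(n_features, 64),
    nn.LeakyReLU(0.2),
    nn.Linear(64, 1),
    nn.Sigmoid(),
)

row = tabular_generator(torch.randn(1, latent_dim))
print(row.shape)  # torch.Size([1, 10])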

6. Example of GAN Implementation Using PyTorch

Below is a simple GAN implementation in PyTorch. In this example, we will generate handwritten digits using the MNIST dataset.

6.1 Setting Up the Environment

First, install and import the necessary libraries.

!pip install torch torchvision matplotlib
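
Optionally, check whether a GPU is available. The example in this article runs on the CPU as written; using a GPU would additionally require moving the models and tensors with .to(device).

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)  # 'cuda' if a GPU is available, otherwise 'cpu'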

6.2 Preparing the Dataset

Load the MNIST dataset, convert the images to tensors, and normalize them to [-1, 1] so they match the range of the generator's Tanh output.

import torch
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
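
As a quick sanity check (illustrative, not part of the original walkthrough), confirm that the loader yields batches of 64 grayscale 28x28 images normalized to [-1, 1]:

images, labels = next(iter(train_loader))
print(images.shape)                               # torch.Size([64, 1, 28, 28])
print(images.min().item(), images.max().item())   # approximately -1.0 and 1.0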

6.3 Model Definition

Define the generator and discriminator models.

import torch.nn as nn

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # Maps a 100-dim noise vector to a flattened 28x28 image
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 784),  # 784 = 28x28, the size of an MNIST image
            nn.Tanh(),             # outputs in [-1, 1], matching the normalized data
        )

    def forward(self, x):
        return self.model(x)

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        # Maps a flattened 28x28 image to the probability that it is real
        self.model = nn.Sequential(
            nn.Linear(784, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),          # probability in (0, 1)
        )

    def forward(self, x):
        return self.model(x)
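
Before wiring up the training loop, a small smoke test (illustrative) verifies that noise flows through both untrained networks with the expected shapes:

g, d = Generator(), Discriminator()
z = torch.randn(4, 100)   # batch of 4 noise vectors
fake = g(z)               # shape (4, 784), values in (-1, 1) from Tanh
prob = d(fake)            # shape (4, 1), values in (0, 1) from Sigmoid
print(fake.shape, prob.shape)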

6.4 Setting Loss Function and Optimizer

Use BCELoss (binary cross-entropy) as the loss function and train both networks with the Adam optimizer. The learning rate of 0.0002 and betas of (0.5, 0.999) follow the values popularized by the DCGAN paper.

criterion = nn.BCELoss()
generator = Generator()
discriminator = Discriminator()

optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

6.5 Model Training

Run the training loop. Every 10 epochs, the losses are printed and a grid of generated samples is displayed.

import matplotlib.pyplot as plt

num_epochs = 100
for epoch in range(num_epochs):
    for i, (real_images, _) in enumerate(train_loader):
        batch_size = real_images.size(0)
        
        # Labels: 1 for real images, 0 for fake images
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)

        # --- Train the discriminator ---
        optimizer_D.zero_grad()

        # Real images should be classified as real
        outputs = discriminator(real_images.view(batch_size, -1))
        d_loss_real = criterion(outputs, real_labels)
        d_loss_real.backward()

        # Fake images should be classified as fake; detach() keeps
        # gradients from flowing back into the generator in this step
        noise = torch.randn(batch_size, 100)
        fake_images = generator(noise)
        outputs = discriminator(fake_images.detach())
        d_loss_fake = criterion(outputs, fake_labels)
        d_loss_fake.backward()
        optimizer_D.step()

        # --- Train the generator ---
        # Update the generator so the discriminator labels its output as real
        optimizer_G.zero_grad()
        outputs = discriminator(fake_images)
        g_loss = criterion(outputs, real_labels)
        g_loss.backward()
        optimizer_G.step()

    if (epoch+1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss_real.item() + d_loss_fake.item():.4f}, g_loss: {g_loss.item():.4f}')
        # Check generated images
        with torch.no_grad():
            fake_images = generator(noise).view(-1, 1, 28, 28)
            grid = torchvision.utils.make_grid(fake_images, nrow=8, normalize=True)
            plt.imshow(grid.permute(1, 2, 0).numpy(), cmap='gray')
            plt.show()
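
After training, the generator can be saved and reloaded for later sampling using PyTorch's standard state_dict pattern (the file name generator.pth is arbitrary):

torch.save(generator.state_dict(), 'generator.pth')

# Later: rebuild the architecture, load the weights, and switch to eval mode
generator = Generator()
generator.load_state_dict(torch.load('generator.pth'))
generator.eval()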

6.6 Checking Results

As training progresses, the generator produces increasingly realistic handwritten digits, showing the adversarial training at work.

7. Conclusion

This article covered an overview of GANs, their components, the training process, and an implementation in PyTorch. GANs perform well at generating unstructured data and can also be applied to structured data, and their range of applications continues to grow as the technique matures.

8. References

1. Ian Goodfellow et al., “Generative Adversarial Networks”, NeurIPS 2014.

2. PyTorch Documentation – https://pytorch.org/docs/stable/index.html

3. torchvision Documentation – https://pytorch.org/vision/stable/index.html