Deep Learning with PyTorch, GAN, WGAN – Wasserstein GAN

With the advancement of deep learning, Generative Adversarial Networks (GANs) are increasingly used in fields such as image generation, image-to-image translation, image composition, and reinforcement learning. A GAN learns to generate realistic data, such as high-resolution images, through a competition between two networks: the Generator and the Discriminator. This article covers the basic concepts of GANs and the structure and operation of WGAN (Wasserstein GAN), with example PyTorch code for implementation.

1. Basic Concept of GAN

GAN is a model proposed by Ian Goodfellow in 2014, composed of two neural networks: the Generator and the Discriminator. The Generator takes a random noise vector as input and produces data that resembles real data, while the Discriminator judges whether its input is real or generated. The two networks are trained against each other: as the Discriminator gets better at spotting generated data, the Generator is pushed to produce increasingly realistic data.

1.1 Structure of GAN

  • Generator (G): A network that takes random noise as input to generate data.
  • Discriminator (D): A network that distinguishes between real data and generated data.

1.2 Loss Function of GAN

The loss function of GAN is as follows:

    L(D) = -\mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] - \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]
    L(G) = -\mathbb{E}_{z \sim p_z}[\log D(G(z))]

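Here, D(x) is the probability the Discriminator assigns to x being real data, G(z) is the sample the Generator produces from the noise vector z, and the expectations are taken over real data x ~ p_data and noise z ~ p_z. By minimizing L(D), the Discriminator learns to output values near 1 for real data and near 0 for generated data. L(G) is the commonly used "non-saturating" Generator loss, which gives stronger gradients early in training than directly minimizing log(1 - D(G(z))).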
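To make the two formulas concrete, here is a minimal sketch that computes L(D) and L(G) in PyTorch. It assumes the Discriminator outputs raw logits (so D(x) = sigmoid(logit)), which lets us use `F.binary_cross_entropy_with_logits` for numerical stability; the logit values are made-up examples, and `gan_losses` is a hypothetical helper name.

```python
import torch
import torch.nn.functional as F

def gan_losses(real_logits: torch.Tensor, fake_logits: torch.Tensor):
    """Standard GAN losses from Discriminator logits, where D(x) = sigmoid(logit)."""
    # L(D) = -E[log D(x)] - E[log(1 - D(G(z)))]: real -> label 1, fake -> label 0
    d_loss = (F.binary_cross_entropy_with_logits(real_logits, torch.ones_like(real_logits))
              + F.binary_cross_entropy_with_logits(fake_logits, torch.zeros_like(fake_logits)))
    # L(G) = -E[log D(G(z))]: the Generator wants its samples labeled as real
    g_loss = F.binary_cross_entropy_with_logits(fake_logits, torch.ones_like(fake_logits))
    return d_loss, g_loss

real_logits = torch.tensor([2.0, 1.5, 0.5])   # example Discriminator outputs on real data
fake_logits = torch.tensor([-1.0, 0.0, -2.0])  # example Discriminator outputs on generated data
d_loss, g_loss = gan_losses(real_logits, fake_logits)

# Cross-check against the formulas written with explicit probabilities.
d_real, d_fake = torch.sigmoid(real_logits), torch.sigmoid(fake_logits)
d_loss_ref = -torch.log(d_real).mean() - torch.log(1 - d_fake).mean()
g_loss_ref = -torch.log(d_fake).mean()
print(torch.allclose(d_loss, d_loss_ref), torch.allclose(g_loss, g_loss_ref))  # True True
```

Working with logits rather than probabilities avoids taking `log(0)` when the Discriminator becomes very confident.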

2. WGAN – Wasserstein GAN

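Standard GAN training is often unstable: if the Discriminator becomes too good, the Generator's gradients can vanish, and the Generator may collapse to producing only a few kinds of samples (mode collapse). In addition, the Discriminator's loss says little about how good the generated samples are. WGAN addresses these issues by training on an estimate of the Wasserstein distance instead. The Wasserstein distance (also called the Earth Mover's Distance) is the minimum cost of transporting probability mass to turn one probability distribution into the other, where the cost is the amount of mass moved times the distance it travels.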
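For one-dimensional samples the "optimal transport" in this definition is easy to compute, which makes it a handy way to build intuition. The sketch below (a hypothetical helper, assuming both sample sets have the same size) uses the fact that in 1-D the optimal plan matches the i-th smallest point of one sample with the i-th smallest point of the other.

```python
import torch

def wasserstein_1d(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """W1 distance between two 1-D empirical distributions with equal sample counts."""
    assert x.numel() == y.numel(), "this shortcut assumes equal sample sizes"
    x_sorted, _ = torch.sort(x.flatten())
    y_sorted, _ = torch.sort(y.flatten())
    # Move the i-th smallest point of x onto the i-th smallest point of y.
    return (x_sorted - y_sorted).abs().mean()

torch.manual_seed(0)
x = torch.randn(1000)
for shift in [0.5, 2.0, 10.0]:
    # y is x moved by `shift`, so every unit of mass travels exactly `shift`.
    print(shift, wasserstein_1d(x, x + shift).item())
```

The distance grows in proportion to the shift even when the two distributions no longer overlap at all. The Jensen-Shannon divergence behind the original GAN loss, by contrast, stays at the constant log 2 once the distributions are disjoint, which is one reason the original GAN can stop receiving a useful gradient signal.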

2.1 Improvements of WGAN

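  • WGAN replaces the Discriminator with a 'Critic'. The Critic has no final sigmoid and outputs an unbounded real-valued score rather than a probability, so it behaves more like a regression model than a classifier.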
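  • The WGAN losses, minimized by the Critic D and the Generator G respectively, are:
                L(D) = \mathbb{E}_{z \sim p_z}[D(G(z))] - \mathbb{E}_{x \sim p_{\text{data}}}[D(x)]
                L(G) = -\mathbb{E}_{z \sim p_z}[D(G(z))]
    Minimizing L(D) is the same as maximizing E[D(x)] - E[D(G(z))], which is the Critic's estimate of the Wasserstein distance (up to a constant factor).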
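  • The original WGAN keeps the Critic K-Lipschitz for some constant K, as the Wasserstein estimate requires, by clipping its weights to a small range such as [-0.01, 0.01] after every update (Weight Clipping).
  • A later variant, WGAN-GP, replaces Weight Clipping with a Gradient Penalty that softly enforces the Lipschitz constraint by pushing the norm of the Critic's gradient toward 1.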
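The Gradient Penalty comes from WGAN-GP (Gulrajani et al., 2017, "Improved Training of Wasserstein GANs"), not from the original WGAN, and the training code in this article does not use it. As a minimal sketch of the idea, assuming a hypothetical helper named `gradient_penalty`, the penalty is evaluated at random points on straight lines between real and generated samples:

```python
import torch
import torch.nn as nn

def gradient_penalty(critic: nn.Module, real: torch.Tensor, fake: torch.Tensor) -> torch.Tensor:
    """WGAN-GP penalty: mean of (||grad of critic at x_hat||_2 - 1)^2 over random interpolates."""
    batch_size = real.size(0)
    # One mixing coefficient per sample, broadcast over the remaining dimensions.
    eps = torch.rand(batch_size, *([1] * (real.dim() - 1)), device=real.device)
    # Interpolate between real and generated samples; detach so x_hat is a fresh leaf tensor.
    x_hat = (eps * real.detach() + (1 - eps) * fake.detach()).requires_grad_(True)
    scores = critic(x_hat)
    (grads,) = torch.autograd.grad(
        outputs=scores,
        inputs=x_hat,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,  # keep the graph so the penalty can be backpropagated into the critic
    )
    grad_norm = grads.reshape(batch_size, -1).norm(2, dim=1)
    return ((grad_norm - 1) ** 2).mean()

# Check on a linear critic D(x) = w.x + b: its gradient is w everywhere,
# so the penalty must equal (||w|| - 1)^2.
torch.manual_seed(0)
critic = nn.Linear(4, 1)
with torch.no_grad():
    critic.weight.copy_(torch.tensor([[3.0, 0.0, 4.0, 0.0]]))  # ||w|| = 5
gp = gradient_penalty(critic, torch.randn(8, 4), torch.randn(8, 4))
print(gp.item())  # 16.0
```

In a WGAN-GP training loop, the penalty would be added to the Critic loss with a weight λ (the WGAN-GP paper uses λ = 10), e.g. `c_loss = c_fake.mean() - c_real.mean() + 10 * gradient_penalty(critic, imgs, fake_imgs)`, and the weight-clipping step would be dropped.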

2.2 Structure of WGAN

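WGAN keeps the two-network structure of a GAN but changes the second network and how it is trained:

  • The Discriminator is replaced by a Critic whose last layer is linear (no sigmoid), so it outputs a score instead of a probability.
  • The Critic is updated several times (typically 5) for each Generator update, and its weights are clipped after every Critic update.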
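The Critic-side changes listed above can be sketched as a few small helpers. The names `critic_loss`, `generator_loss`, and `clip_weights` are hypothetical and chosen for illustration; the clipping range 0.01 matches the value used in the training code later in this article.

```python
import torch
import torch.nn as nn

def critic_loss(c_real: torch.Tensor, c_fake: torch.Tensor) -> torch.Tensor:
    # L(D) = E[D(G(z))] - E[D(x)]; minimizing it maximizes the Wasserstein estimate
    return c_fake.mean() - c_real.mean()

def generator_loss(c_fake: torch.Tensor) -> torch.Tensor:
    # L(G) = -E[D(G(z))]; the Generator wants high Critic scores for its samples
    return -c_fake.mean()

def clip_weights(critic: nn.Module, c: float = 0.01) -> None:
    # Original WGAN weight clipping: force every parameter into [-c, c]
    with torch.no_grad():
        for p in critic.parameters():
            p.clamp_(-c, c)

# Output heads: a Discriminator squashes its output into [0, 1], a Critic does not.
disc_head = nn.Sequential(nn.Linear(8, 1), nn.Sigmoid())
critic_head = nn.Linear(8, 1)
features = torch.randn(5, 8) * 10
probs = disc_head(features)     # always within [0, 1]
scores = critic_head(features)  # any real number

c_real = torch.tensor([1.0, 2.0, 3.0])  # example Critic scores for real samples
c_fake = torch.tensor([0.5, 0.5, 2.0])  # example Critic scores for generated samples
print(critic_loss(c_real, c_fake).item())  # -1.0
print(generator_loss(c_fake).item())       # -1.0

clip_weights(critic_head)
print(max(p.abs().max().item() for p in critic_head.parameters()) <= 0.01)  # True
```

Because the Critic's score is unbounded, its loss can become negative; only the difference between real and fake scores matters.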

3. WGAN Implementation Using PyTorch

Now we will implement WGAN using PyTorch. This example will build a model to generate handwritten digits using the MNIST dataset.

3.1 Preparing the Dataset

First, we load and preprocess the dataset.


import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

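# Load and preprocess the dataset.
transform = transforms.Compose([
    transforms.Resize(28),                # MNIST images are already 28x28; kept for clarity
    transforms.ToTensor(),                # PIL image -> float tensor with values in [0, 1]
    transforms.Normalize((0.5,), (0.5,))  # [0, 1] -> [-1, 1], matching the Generator's Tanh output
])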

train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
    

3.2 Defining the WGAN Model

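Now it's time to define the Generator and Critic models. Both are simple fully connected networks: the Generator maps a 100-dimensional noise vector to a 28x28 image with values in [-1, 1] (via Tanh), and the Critic maps an image to a single unbounded score. Note that the Critic has no sigmoid at the end.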


# Define Generator model
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 784),
            nn.Tanh()
        )
        
    def forward(self, z):
        return self.model(z).view(-1, 1, 28, 28)  # Reshape to 28x28 image

# Define Critic model
class Critic(nn.Module):
    def __init__(self):
        super(Critic, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1)
        )
    
    def forward(self, img):
        return self.model(img.view(-1, 784))  # Reshape to 784 dimensions
    

3.3 Training Process of WGAN

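Now we define the training process for WGAN. Following the defaults of the original WGAN paper, both networks use RMSprop with a learning rate of 0.00005, the Critic's weights are clipped to [-0.01, 0.01] after each update, and the Generator is updated once for every 5 Critic updates.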


def train_wgan(num_epochs):
    # Move both networks to the same device as the data.
    generator = Generator().to(device)
    critic = Critic().to(device)
    
    # Set optimizers (RMSprop, as in the original WGAN paper)
    optimizer_G = optim.RMSprop(generator.parameters(), lr=0.00005)
    optimizer_C = optim.RMSprop(critic.parameters(), lr=0.00005)

    for epoch in range(num_epochs):
        for i, (imgs, _) in enumerate(train_loader):
            imgs = imgs.to(device)

            # Update Critic: minimize E[D(G(z))] - E[D(x)]
            optimizer_C.zero_grad()
            z = torch.randn(imgs.size(0), 100, device=device)
            fake_imgs = generator(z)
            c_real = critic(imgs)
            c_fake = critic(fake_imgs.detach())  # detach: no gradients flow into the Generator here
            c_loss = c_fake.mean() - c_real.mean()
            c_loss.backward()
            optimizer_C.step()

            # Weight Clipping: keep every Critic parameter in [-0.01, 0.01]
            for p in critic.parameters():
                p.data.clamp_(-0.01, 0.01)

            # Update Generator once every 5 Critic updates: minimize -E[D(G(z))]
            if i % 5 == 0:
                optimizer_G.zero_grad()
                g_loss = -critic(fake_imgs).mean()
                g_loss.backward()
                optimizer_G.step()
            
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss C: {c_loss.item():.4f}, Loss G: {g_loss.item():.4f}')

    # Return the trained networks so they can be used after training.
    return generator, critic

# Set GPU usage
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator, critic = train_wgan(num_epochs=50)
    

3.4 Visualizing the Results

After training is complete, we visualize the generated images to check the results.


import matplotlib.pyplot as plt

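def show_generated_images(num_images):
    generator.eval()  # switch to inference mode
    with torch.no_grad():  # no gradients are needed for sampling
        z = torch.randn(num_images, 100, device=device)
        generated_imgs = generator(z).cpu()
    
    # squeeze=False keeps axes 2-D, so indexing also works when num_images == 1
    fig, axes = plt.subplots(1, num_images, figsize=(2 * num_images, 2), squeeze=False)
    for i in range(num_images):
        axes[0, i].imshow(generated_imgs[i][0], cmap='gray')
        axes[0, i].axis('off')
    plt.show()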

# Visualize the results
show_generated_images(5)
    

4. Conclusion

WGAN makes GAN training more stable by replacing the Discriminator's classification objective with an estimate of the Wasserstein distance, which overcomes several issues of traditional GANs. This article showed how to implement WGAN in PyTorch, and we hope it deepens your understanding of generative adversarial networks. GANs and their variants are powerful tools whose applications extend well beyond image generation.

5. References

  • Ian J. Goodfellow et al., “Generative Adversarial Nets”, 2014.
  • Martin Arjovsky et al., “Wasserstein Generative Adversarial Networks”, 2017.
  • PyTorch Documentation: https://pytorch.org/docs/stable/index.html