GAN Deep Learning Using PyTorch: Art Exhibition

Introduction

A GAN (Generative Adversarial Network) is a type of generative model in which two neural networks compete to generate new data. GANs are used in fields such as image generation, text generation, and music generation. In this article, we will implement a GAN with the PyTorch library and walk through the process of generating artwork that could be shown at an art exhibition.

1. Basic Concept of GAN

A GAN consists of two models: a Generator and a Discriminator. The Generator creates data from random input, while the Discriminator distinguishes whether a given sample is real or generated. The training process works as follows.

  1. The Generator receives random noise as input and generates fake images.
  2. The Discriminator compares the generated images with real images.
  3. The Generator is updated so that the Discriminator judges its fake images as real, which pushes it to produce increasingly realistic images.

As this process repeats, the Discriminator becomes better at telling real images from fake ones, which in turn forces the Generator to produce ever more convincing images.
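
For reference, this adversarial game is usually written as the minimax objective of the original GAN formulation, where the Discriminator D tries to maximize the value V(D, G) and the Generator G tries to minimize it (p_data is the distribution of real images and p_z is the noise distribution the Generator samples from):

    \min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]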

2. Installing PyTorch

Before implementing the GAN, you first need to install PyTorch and torchvision. You can do this with the following command.

    
    pip install torch torchvision
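
After installation, you can optionally run a quick check that PyTorch imports correctly and see whether a GPU is available (the rest of this tutorial also runs fine on the CPU):

    import torch

    print(torch.__version__)          # installed PyTorch version
    print(torch.cuda.is_available())  # True if a CUDA-capable GPU can be used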
    
    

3. Preparing Data

To generate images of artworks, a dataset is required. In this tutorial we use the CIFAR-10 dataset, which consists of 60,000 32x32 color images from 10 classes and can be loaded directly with PyTorch's built-in dataset utilities. It serves here as a convenient stand-in; for a real exhibition you would swap in a dataset of artwork images, but the training pipeline stays the same.

3.1 CIFAR-10 Dataset

    
    import torchvision.transforms as transforms
    import torchvision.datasets as datasets
    from torch.utils.data import DataLoader

    # Data preprocessing: upscale the 32x32 CIFAR-10 images to 64x64 and
    # normalize each RGB channel to [-1, 1] to match the Generator's Tanh output
    transform = transforms.Compose([
        transforms.Resize(64),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    # Load CIFAR-10 dataset
    dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=2)
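
As a quick sanity check (not required for training), you can pull one batch from the dataloader and confirm its shape: with the settings above, each batch contains 128 RGB images of size 64x64, with pixel values normalized to roughly [-1, 1].

    # Inspect a single batch: expected shape [128, 3, 64, 64]
    images, labels = next(iter(dataloader))
    print(images.shape, labels.shape)
    print(images.min().item(), images.max().item())  # approximately -1.0 and 1.0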
    
    

4. Implementing the GAN Model

The GAN model is defined by two neural networks: the Generator and the Discriminator. The Generator produces images from random noise, while the Discriminator judges whether a given image is real or generated.

4.1 Generator Model

    
    import torch
    import torch.nn as nn

    class Generator(nn.Module):
        def __init__(self):
            super(Generator, self).__init__()
            self.model = nn.Sequential(
                nn.Linear(100, 256),
                nn.ReLU(),
                nn.Linear(256, 512),
                nn.ReLU(),
                nn.Linear(512, 1024),
                nn.ReLU(),
                nn.Linear(1024, 3*64*64),
                nn.Tanh(),
            )

        def forward(self, z):
            z = self.model(z)
            return z.view(-1, 3, 64, 64)
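
The Generator maps a 100-dimensional noise vector to a 3x64x64 image whose values lie in (-1, 1) because of the final Tanh. A quick, purely illustrative shape check:

    # Feed a batch of 4 random noise vectors through an untrained Generator
    g = Generator()
    z = torch.randn(4, 100)
    print(g(z).shape)  # torch.Size([4, 3, 64, 64])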
    
    

4.2 Discriminator Model

    
    class Discriminator(nn.Module):
        def __init__(self):
            super(Discriminator, self).__init__()
            self.model = nn.Sequential(
                nn.Linear(3*64*64, 512),
                nn.LeakyReLU(0.2),
                nn.Linear(512, 256),
                nn.LeakyReLU(0.2),
                nn.Linear(256, 1),
                nn.Sigmoid(),
            )

        def forward(self, img):
            img = img.view(-1, 3*64*64)
            return self.model(img)
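
The Discriminator flattens each image into a 3*64*64 vector and outputs a single probability in (0, 1) via the final Sigmoid. A matching illustrative check:

    # The Discriminator maps any 3x64x64 image to one probability per image
    d = Discriminator()
    dummy = torch.randn(4, 3, 64, 64)  # stand-in batch of 4 images
    print(d(dummy).shape)              # torch.Size([4, 1])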
    
    

5. Setting Loss Function and Optimizers

To train the GAN, you need to define a loss function and an optimizer for each network. Binary Cross-Entropy loss is typically used, together with the Adam optimizer.

    
    import torch.optim as optim

    generator = Generator()
    discriminator = Discriminator()

    criterion = nn.BCELoss()
    optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
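
The code above keeps everything on the CPU, which is sufficient for this tutorial. If you want to train on a GPU, one option (sketched below, not required) is to create the networks on the appropriate device before building the optimizers, and to move the tensors created in the training loop (imgs, the label tensors, and z) to the same device with .to(device).

    # Optional: run on a GPU when one is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    generator = Generator().to(device)
    discriminator = Discriminator().to(device)
    optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))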
    
    

6. Training the GAN

Training a GAN alternates between updating the Discriminator and updating the Generator. The code below shows the main training loop.

    
    import numpy as np
    import matplotlib.pyplot as plt

    def train_gan(num_epochs):
        for epoch in range(num_epochs):
            for i, (imgs, _) in enumerate(dataloader):
                # Real image labels: 1
                real_labels = torch.ones(imgs.size(0), 1)
                # Fake image labels: 0
                fake_labels = torch.zeros(imgs.size(0), 1)

                # --- Discriminator training ---
                optimizer_D.zero_grad()

                # Loss on real images (should be classified as 1)
                outputs = discriminator(imgs)
                d_loss_real = criterion(outputs, real_labels)
                d_loss_real.backward()

                # Loss on generated images (should be classified as 0);
                # detach() keeps this backward pass out of the Generator's graph
                z = torch.randn(imgs.size(0), 100)
                fake_images = generator(z)
                outputs = discriminator(fake_images.detach())
                d_loss_fake = criterion(outputs, fake_labels)
                d_loss_fake.backward()

                optimizer_D.step()
                d_loss = d_loss_real + d_loss_fake

                # --- Generator training ---
                # The Generator is updated so that the Discriminator
                # classifies its fake images as real (label 1)
                optimizer_G.zero_grad()
                outputs = discriminator(fake_images)
                g_loss = criterion(outputs, real_labels)
                g_loss.backward()
                optimizer_G.step()

                if i % 100 == 0:
                    print(f'Epoch [{epoch}/{num_epochs}], Step [{i}/{len(dataloader)}], '
                          f'd_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}')

    train_gan(num_epochs=20)
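
Training for 20 epochs on CIFAR-10 can take a while, so it is worth saving the trained weights once the loop finishes. A minimal sketch using torch.save (the file names are just examples):

    # Save the trained weights so the Generator can be reused later
    torch.save(generator.state_dict(), 'generator.pth')
    torch.save(discriminator.state_dict(), 'discriminator.pth')

    # Later: restore the Generator for inference
    generator = Generator()
    generator.load_state_dict(torch.load('generator.pth'))
    generator.eval()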
    
    

7. Visualizing Results

After training the model, you can visualize the images the Generator produces. The code below samples random noise, generates images, and displays them with matplotlib.

    
    def plot_generated_images(num_images):
        z = torch.randn(num_images, 100)
        generated_images = generator(z).detach().numpy()
        generated_images = (generated_images + 1) / 2  # (-1, 1) -> (0, 1)
        
        fig, axes = plt.subplots(1, num_images, figsize=(15, 5))
        for i in range(num_images):
            axes[i].imshow(generated_images[i].transpose(1, 2, 0))
            axes[i].axis('off')
        plt.show()

    plot_generated_images(10)
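
To prepare pieces for an exhibition catalog or for printing, you can also write a grid of generated samples to an image file with torchvision's save_image utility. A minimal sketch (the output file name is just an example):

    from torchvision.utils import save_image

    # Generate 64 samples and save them as an 8x8 grid;
    # normalize=True rescales the Tanh output from [-1, 1] to [0, 1]
    z = torch.randn(64, 100)
    samples = generator(z).detach()
    save_image(samples, 'gan_art_grid.png', nrow=8, normalize=True)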
    
    

8. Conclusion

In this article, we explored the basic concepts of GANs and walked through a complete implementation in PyTorch. GANs can generate new images, such as artworks, that can be shown at events like art exhibitions. As GAN technology continues to evolve, we can expect it to open up a wide range of creative possibilities.
