파이토치를 활용한 GAN 딥러닝, WGAN-GP

Generative Adversarial Networks(이하 GAN)는 Ian Goodfellow가 2014년에 제안한 강력한 생성 모델입니다. GAN은 두 개의 신경망, 즉 생성자(Generator)와 판별자(Discriminator)로 구성되며, 이 두 네트워크는 서로 경쟁하여 학습합니다. 생성자는 실제 데이터와 유사한 새로운 데이터를 생성하려고 하고, 판별자는 주어진 데이터가 실제 데이터인지 생성된 데이터인지를 구별하려고 합니다. 이들은 지속적으로 개선되면서 최종적으로 매우 평범한 데이터를 신뢰성 있게 생성할 수 있게 됩니다.

본 글에서는 GAN의 변형인 Wasserstein GAN with Gradient Penalty (WGAN-GP)에 대해 설명하고, 파이토치를 사용하여 WGAN-GP를 구현해보겠습니다. WGAN-GP는 Wasserstein 거리(Wasserstein Distance)를 기반으로 하며, 판별자에서의 Gradient Penalty를 추가하여 훈련의 안정성을 높입니다.

1. GAN의 기본 구조

GAN의 기본 구조는 다음과 같습니다.

  • 생성자(Generator): 랜덤 노이즈를 입력받아 가짜 데이터를 생성합니다.
  • 판별자(Discriminator): 실제 데이터와 생성자가 만든 가짜 데이터를 입력받아 이 둘이 얼마나 유사한지를 판단합니다.

GAN의 학습 과정은 다음의 두 단계로 이루어집니다.

이 과정은 반복적으로 수행되어 두 네트워크 모두 개선됩니다. 다만, 기존 GAN은 훈련 불안정성과 모드 붕괴(mode collapse) 문제가 종종 발생하여 보다 안정적인 학습을 위한 여러 접근법이 연구되었습니다.

2. WGAN-GP 소개

WGAN에서는 Wasserstein 거리 개념을 도입하여 GAN의 본질적인 문제를 해결하고자 하였습니다. Wasserstein 거리는 두 분포 간의 차이를 보다 명확히 정의할 수 있어, 네트워크를 훈련하는 데 유리합니다. WGAN의 핵심 아이디어는 판별자가 아닌 “비평가(critic)”라는 개념을 도입하는 것입니다. 비평가는 생성된 데이터와 실제 데이터 간의 거리를 평가하고, 이 평가 결과를 바탕으로 평균 제곱 오차(MSE) 손실이 아닌 Wasserstein 손실을 사용하여 네트워크를 업데이트합니다.

WGAN에서 Gradient Penalty (GP)를 추가함으로써 판별자의 Lipschitz 조건을 준수할 수 있도록 하여 훈련의 안정성을 더욱 높였습니다. Gradient Penalty는 다음과 같이 정의됩니다:

GP = λ * E[(||∇D(x) ||2 - 1)²]

여기서 λ는 하이퍼파라미터이며, D(x)는 판별자의 출력입니다. Gradient Penalty를 통해 판별자의 기울기가 1로 유지되도록 강화할 수 있습니다. 이 방식을 통해 WGAN-GP는 GAN의 불안정을 극복하고 보다 안정적인 훈련이 가능해집니다.

3. WGAN-GP 파이토치 구현

이제 파이토치를 사용하여 WGAN-GP를 구현해보겠습니다. 다음의 단계로 진행됩니다:

  1. 필요한 라이브러리 설치 및 데이터셋 불러오기
  2. 생성자 및 판별자 모델 정의
  3. WGAN-GP 훈련 루프 구현
  4. 결과 시각화

3.1 라이브러리 설치 및 데이터셋 불러오기

우선 필요한 라이브러리를 설치하고, MNIST 데이터셋을 불러옵니다.

!pip install torch torchvision matplotlib
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

3.2 생성자 및 판별자 모델 정의

생성자와 판별자 모델을 정의합니다. 생성자는 랜덤한 노이즈 벡터를 입력으로 받아 이미지로 변환하고, 판별자는 입력된 이미지를 바탕으로 진짜인지 가짜인지를 평가합니다.

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(True),
            nn.Linear(256, 512),
            nn.ReLU(True),
            nn.Linear(512, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 28 * 28),
            nn.Tanh()
        )
    
    def forward(self, z):
        return self.model(z).reshape(-1, 1, 28, 28)

class Critic(nn.Module):
    def __init__(self):
        super(Critic, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(28 * 28, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1)
        )

    def forward(self, x):
        return self.model(x.view(-1, 28 * 28))

3.3 WGAN-GP 훈련 루프 구현

이제 WGAN-GP의 훈련 루프를 구현합니다. 훈련 과정에서는 판별자를 일정 횟수만큼 업데이트한 후 생성자를 업데이트합니다. Gradient Penalty도 손실에 포함됩니다.

def compute_gradient_penalty(critic, real_samples, fake_samples):
    alpha = torch.rand(real_samples.size(0), 1, 1, 1).expand_as(real_samples)
    interpolated_samples = alpha * real_samples + (1 - alpha) * fake_samples
    interpolated_samples.requires_grad_(True)

    d_interpolated = critic(interpolated_samples)

    gradients = torch.autograd.grad(outputs=d_interpolated, inputs=interpolated_samples,
                                    grad_outputs=torch.ones_like(d_interpolated),
                                    create_graph=True, retain_graph=True)[0]

    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

generator = Generator().to(device)
critic = Critic().to(device)

learning_rate = 0.00005
num_epochs = 100
critic_iterations = 5
lambda_gp = 10

criterion = nn.MSELoss()
optimizer_generator = optim.Adam(generator.parameters(), lr=learning_rate)
optimizer_critic = optim.Adam(critic.parameters(), lr=learning_rate)

Real_data = dsets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)

data_loader = torch.utils.data.DataLoader(Real_data, batch_size=64, shuffle=True)

for epoch in range(num_epochs):
    for i, (real_images, _) in enumerate(data_loader):
        real_images = real_images.to(device)

        for _ in range(critic_iterations):
            optimizer_critic.zero_grad()

            # Generate fake images
            z = torch.randn(real_images.size(0), 100).to(device)
            fake_images = generator(z)

            # Get critic scores
            real_validity = critic(real_images)
            fake_validity = critic(fake_images)
            gradient_penalty = compute_gradient_penalty(critic, real_images.data, fake_images.data)

            # Compute loss
            critic_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty
            critic_loss.backward()
            optimizer_critic.step()

        # Update generator
        optimizer_generator.zero_grad()
        
        # Get generator score
        fake_images = generator(z)
        validity = critic(fake_images)
        generator_loss = -torch.mean(validity)
        generator_loss.backward()
        optimizer_generator.step()

    if epoch % 10 == 0:
        print(f"Epoch: {epoch}/{num_epochs}, Critic Loss: {critic_loss.item():.4f}, Generator Loss: {generator_loss.item():.4f}")

3.4 결과 시각화

최종적으로 생성된 이미지를 시각화합니다. 이는 학습 과정에서 생성자가 얼마나 잘 학습되었는지를 확인하는 좋은 방법입니다.

def show_generated_images(generator, num_images=25):
    z = torch.randn(num_images, 100).to(device)
    generated_images = generator(z).cpu().detach().numpy()

    plt.figure(figsize=(5, 5))
    for i in range(num_images):
        plt.subplot(5, 5, i + 1)
        plt.imshow(generated_images[i][0], cmap='gray')
        plt.axis('off')
    plt.show()

show_generated_images(generator)

4. 결론

이 글에서는 GAN의 변형 모델인 WGAN-GP에 대해 설명하고, 파이토치를 활용하여 WGAN-GP를 구현해보았습니다. WGAN-GP는 Wasserstein 거리와 Gradient Penalty를 활용하여 보다 안정적인 훈련이 가능하다는 장점을 가지고 있습니다. 이러한 GAN 계열 모델은 이미지 생성, 이미지 변환, 스타일 전이 등 다양한 분야에 활용될 수 있습니다.

딥러닝의 발전과 함께 GAN과 그 변형 모델들은 계속해서 주목받고 있으며, 앞으로의 발전이 기대됩니다. 여러분도 GAN 및 WGAN-GP를 활용하여 다양한 프로젝트에 도전해보시기를 바랍니다!