파이토치를 활용한 GAN 딥러닝, 훈련 과정

생성적 적대 신경망(Generative Adversarial Network, GAN)은 2014년 Ian Goodfellow와 동료들이 발표한 신경망 아키텍처로, 두 개의 신경망인 생성자(Generator)와 판별자(Discriminator)가 경쟁하면서 훈련됩니다. GAN은 주로 이미지 생성, 변환, 재구성 등의 분야에 활용되며, 특히 고해상도의 사진이나 예술작품 생성에 많이 사용됩니다. 본 글에서는 파이토치를 활용하여 GAN의 전체 구조와 훈련 과정을 자세히 살펴보겠습니다.

1. GAN의 구조

GAN은 두 개의 주요 구성 요소로 이루어져 있습니다:

  • 생성자 (G): 무작위 노이즈 벡터를 입력으로 받아들이고, 이를 실제와 유사한 가짜 샘플로 변환하는 네트워크입니다.
  • 판별자 (D): 입력된 샘플이 실제 데이터인지(G)가 아닌지를 판별하는 네트워크입니다. 판별자는 실제 데이터와 생성자가 만든 가짜 데이터를 최대한 잘 구분해야 합니다.

이 두 네트워크는 서로를 상대방보다 더 잘 수행하기 위해 경쟁하는 구조를 갖습니다. 생성자는 점차적으로 더 그럴듯한 데이터를 생성하기 위해 개선되고, 판별자는 더 정교하게 생성된 데이터를 판별할 수 있도록 훈련됩니다.

2. GAN의 훈련 과정

GAN의 훈련 과정은 다음과 같은 단계로 진행됩니다:

  1. 무작위 노이즈 벡터를 생성하여 생성자에 입력합니다.
  2. 생성자는 노이즈 벡터를 가짜 샘플로 변환합니다.
  3. 판별자는 실제 데이터와 생성된 가짜 데이터 모두를 입력으로 받아들입니다.
  4. 판별자는 각 샘플이 실제 데이터인지 가짜 데이터인지를 예측합니다.
  5. 생성자는 판별자가 가짜 샘플을 실제 데이터라고 판단하게 만들기 위해 손실함수를 통해 업데이트됩니다. 반대로, 판별자는 실제 데이터와 가짜 데이터를 잘 구분하기 위해 업데이트됩니다.

3. 파이토치를 이용한 GAN 구현

이제 파이토치를 사용하여 GAN을 구현하는 코드를 작성해보겠습니다. 우리는 MNIST 데이터셋을 사용하여 손글씨 숫자를 생성하는 GAN을 만들어보겠습니다.

3.1 필요한 라이브러리 설치

pip install torch torchvision matplotlib

3.2 데이터셋 준비

MNIST 데이터셋을 로드해보겠습니다. PyTorch에서는 torchvision 라이브러리를 통해 쉽게 데이터를 다운로드할 수 있습니다.

import torch
from torchvision import datasets, transforms

# 데이터 변환
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# 데이터셋 다운로드 및 로드
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)

3.3 생성자(Generator)와 판별자(Discriminator) 구현

이제 GAN의 두 핵심 구성 요소인 생성자와 판별자를 정의해보겠습니다. 생성자는 간단한 완전연결 신경망을 사용하고, 판별자는 CNN을 사용하여 이미지를 처리하도록 하겠습니다.

import torch.nn as nn

# 생성자 모델
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 28 * 28),
            nn.Tanh()
        )
    
    def forward(self, z):
        return self.model(z).view(-1, 1, 28, 28)

# 판별자 모델
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        return self.model(x)

3.4 손실함수 및 최적화 알고리즘 설정

GAN의 훈련을 위해서 손실함수와 최적화 알고리즘을 정의하겠습니다. 일반적으로 생성자와 판별자는 각각 다른 손실함수를 사용합니다. 간단한 이진 교차 엔트로피 손실을 사용할 것입니다.

criterion = nn.BCELoss()
optimizer_G = torch.optim.Adam(Generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(Discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

3.5 GAN 훈련 루프

이제 GAN 훈련을 위한 루프를 구현해보겠습니다. 여기서는 일정 수의 에포크 동안 생성자와 판별자를 번갈아가며 훈련합니다.

def train_gan(generator, discriminator, train_loader, num_epochs=100):
    for epoch in range(num_epochs):
        for i, (real_images, _) in enumerate(train_loader):
            batch_size = real_images.size(0)

            # 진짜 라벨과 가짜 라벨 정의
            real_labels = torch.ones(batch_size, 1)
            fake_labels = torch.zeros(batch_size, 1)

            # 판별자 훈련
            discriminator.zero_grad()
            outputs = discriminator(real_images)
            d_loss_real = criterion(outputs, real_labels)
            d_loss_real.backward()

            z = torch.randn(batch_size, 100)
            fake_images = generator(z)
            outputs = discriminator(fake_images.detach())
            d_loss_fake = criterion(outputs, fake_labels)
            d_loss_fake.backward()

            optimizer_D.step()

            # 생성자 훈련
            generator.zero_grad()
            outputs = discriminator(fake_images)
            g_loss = criterion(outputs, real_labels)
            g_loss.backward()
            optimizer_G.step()

        print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss_real.item() + d_loss_fake.item()}, g_loss: {g_loss.item()}')

# GAN 훈련 시작
generator = Generator()
discriminator = Discriminator()
train_gan(generator, discriminator, train_loader)

4. 결과 시각화

훈련이 진행된 후에는 생성된 이미지들을 시각화하여 결과를 확인해볼 수 있습니다.

import matplotlib.pyplot as plt

def show_generated_images(generator, num_images=25):
    z = torch.randn(num_images, 100)
    generated_images = generator(z).detach().numpy()
    
    plt.figure(figsize=(10, 10))
    for i in range(num_images):
        plt.subplot(5, 5, i + 1)
        plt.imshow(generated_images[i][0], cmap='gray')
        plt.axis('off')
    plt.show()

# 생성된 이미지 보여주기
show_generated_images(generator)

5. 결론

이번 포스트에서는 GAN의 기본 개념과 파이토치를 이용한 간단한 구현 방법을 알아보았습니다. GAN은 매우 강력한 생성 모델로, 다양한 데이터 생성 문제에 적용할 수 있습니다. 하지만 GAN의 훈련은 불안정할 수 있으며, 다양한 기법과 하이퍼파라미터 조정이 필요할 수 있습니다. 더 복잡한 GAN 아키텍처들(예: DCGAN, WGAN)도 탐색해보면 흥미로운 결과를 얻을 수 있습니다.

이제 여러분은 GAN의 기본적인 동작 방식과 이를 파이토치로 구현하는 방법을 알게 되었습니다. 이를 바탕으로 더 나아가 다양한 예제를 시도해보시기 바랍니다!