파이토치를 활용한 GAN 딥러닝, WGAN – 와서스테인 GAN

딥러닝의 발전과 함께 이미지 생성, 강화학습, 이미지 변환, 이미지 결합 등 다양한 분야에서 Generative Adversarial Networks(GANs)의 사용이 증가하고 있습니다. GAN은 두 네트워크, 즉 생성자(Generator)와 판별자(Discriminator) 간의 경쟁을 통해 고해상도의 이미지를 생성하는데 사용됩니다. 본 글에서는 GAN의 기본 개념과 함께 WGAN(Wasserstein GAN)의 구조와 작동 방식, 이를 구현하기 위한 파이토치 예제 코드를 다루겠습니다.

1. GAN의 기본 개념

GAN은 Ian Goodfellow가 2014년에 제안한 모델로, 생성자와 판별자라는 두 개의 신경망으로 구성됩니다. 생성자는 난수 벡터를 입력받아 실제와 유사한 데이터를 생성하고, 판별자는 입력 데이터가 실제 데이터인지 생성된 데이터인지를 판단합니다. 이 과정에서 두 신경망은 서로 경쟁하면서 점점 완벽한 데이터를 생성하는 방향으로 학습하게 됩니다.

1.1 GAN의 구조

  • Generator (G): 랜덤 노이즈를 입력으로 받아 데이터를 생성하는 네트워크입니다.
  • Discriminator (D): 실제 데이터와 생성된 데이터 간의 차이를 구별하는 네트워크입니다.

1.2 GAN의 손실 함수

GAN의 손실 함수는 다음과 같습니다.

    L(D) = -E[log(D(x))] - E[log(1 - D(G(z)))],
    L(G) = -E[log(D(G(z)))]
    

여기서 D(x)는 실제 데이터의 진짜로 판단할 확률이며, G(z)는 생성자가 생성한 데이터입니다.

2. WGAN – 와서스테인 GAN

기존의 GAN은 판별자의 손실 함수가 안정적이지 않고 학습이 불안정하다는 문제점이 있었습니다. WGAN은 Wasserstein Distance를 사용하여 이러한 문제점을 해결합니다. Wasserstein 거리(혹은 Earth Mover’s Distance)는 두 확률 분포 간의 최적 운반 비용을 측정하는 방법입니다.

2.1 WGAN의 개선점

  • WGAN은 판별자 대신 비선형 회귀모델인 ‘Critic’을 사용합니다.
  • WGAN의 손실 함수는 다음과 같습니다:
                L(D) = E[D(x)] - E[D(G(z))],
                L(G) = -E[D(G(z))]
                
  • WGAN은 Weight Clipping을 통해 Critic의 Lipschitz 연속성을 보장합니다.
  • Gradient Penalty 기법을 사용하여 Lipschitz 제약 조건을 완화합니다.

2.2 WGAN의 구조

WGAN은 기본 GAN의 구조에서 Critic을 도입해 수정된 형태입니다. 다음은 WGAN의 네트워크 구조입니다:

  • 이전의 판별자는 현재의 Critic으로 대체됩니다.

3. 파이토치를 활용한 WGAN 구현

이제 파이토치로 WGAN을 구현해 보겠습니다. 이번 예제는 MNIST 데이터셋을 사용하여 손글씨 숫자를 생성하는모델을 구축할 것입니다.

3.1 데이터셋 준비

먼저 데이터셋을 불러오고 전처리합니다.


import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 데이터셋을 로드하고 전처리합니다.
transform = transforms.Compose([
    transforms.Resize(28),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
    

3.2 WGAN 모델 정의

이제 Generator와 Critic 모델을 정의할 차례입니다.


# 생성자 모델 정의
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 784),
            nn.Tanh()
        )
        
    def forward(self, z):
        return self.model(z).view(-1, 1, 28, 28)  # 28x28 이미지로 변형

# 비평가(critic) 모델 정의
class Critic(nn.Module):
    def __init__(self):
        super(Critic, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1)
        )
    
    def forward(self, img):
        return self.model(img.view(-1, 784))  # 784 차원으로 변형
    

3.3 WGAN 학습 과정

이제 WGAN의 학습 과정을 정의합니다.


def train_wgan(num_epochs):
    generator = Generator()
    critic = Critic()
    
    # 옵티마이저 설정
    optimizer_G = optim.RMSprop(generator.parameters(), lr=0.00005)
    optimizer_C = optim.RMSprop(critic.parameters(), lr=0.00005)

    for epoch in range(num_epochs):
        for i, (imgs, _) in enumerate(train_loader):
            imgs = imgs.to(device)

            # Critic 잔여 방정식
            optimizer_C.zero_grad()
            z = torch.randn(imgs.size(0), 100).to(device)
            fake_imgs = generator(z)
            c_real = critic(imgs)
            c_fake = critic(fake_imgs.detach())
            c_loss = c_fake.mean() - c_real.mean()
            c_loss.backward()
            optimizer_C.step()

            # Weight Clipping
            for p in critic.parameters():
                p.data.clamp_(-0.01, 0.01)

            # Generator 업데이트
            if i % 5 == 0:
                optimizer_G.zero_grad()
                g_loss = -critic(fake_imgs).mean()
                g_loss.backward()
                optimizer_G.step()
            
        print(f'Epoch [{epoch}/{num_epochs}], Loss C: {c_loss.item()}, Loss G: {g_loss.item()}')

# GPU 사용 설정
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_wgan(num_epochs=50)
    

3.4 결과 시각화

학습이 끝나면 생성된 이미지를 시각화하여 결과를 확인합니다.


import matplotlib.pyplot as plt

def show_generated_images(num_images):
    z = torch.randn(num_images, 100).to(device)
    generated_imgs = generator(z).cpu().detach()
    
    fig, axes = plt.subplots(1, num_images, figsize=(15, 15))
    for i in range(num_images):
        axes[i].imshow(generated_imgs[i][0], cmap='gray')
        axes[i].axis('off')
    plt.show()

# 결과 시각화
show_generated_images(5)
    

4. 결론

WGAN은 기존 GAN의 문제점을 극복하기 위해 Wasserstein Distance를 이용하여 더 안정적인 학습 과정을 제공합니다. 본 글에서는 파이토치를 활용하여 WGAN을 구현하는 방법을 소개하였으며, 이를 통해 생성적 적대 신경망의 이해를 높일 수 있기를 바랍니다. GAN과 그 변형 모델들은 이미지 생성뿐만 아니라 다양한 분야에서 혁신적인 결과를 가져올 수 있는 강력한 도구입니다.

5. 참고문헌

  • Ian J. Goodfellow et al., “Generative Adversarial Nets”, 2014.
  • Martin Arjovsky et al., “Wasserstein Generative Adversarial Networks”, 2017.
  • PyTorch Documentation: https://pytorch.org/docs/stable/index.html