파이토치를 활용한 GAN 딥러닝, GAN 소개

1. GAN(Generative Adversarial Network) 소개

GAN(Generative Adversarial Network)은 2014년에 Ian Goodfellow가 처음 제안한 딥러닝 모델로,
두 개의 신경망인 생성자(Generator)와 판별자(Discriminator)가 서로 경쟁하는 구조로 이루어져 있습니다.
생성자는 가짜 데이터를 생성하고, 판별자는 이 데이터가 진짜인지 가짜인지 판단하는 역할을 합니다.
이 두 네트워크는 서로의 성능을 개선하기 위해 지속적으로 학습합니다.

GAN의 핵심 아이디어는 “적대적 학습”(Adversarial Training)입니다.
생성자는 판별자가 진짜 데이터와 가짜 데이터를 잘 구분하지 못하게 하기 위해 계속해서 더 그럴듯한 가짜 데이터를 생성하게 됩니다.
반면, 판별자는 생성자가 만든 데이터가 진짜인지 가짜인지 정확하게 판단하기 위해 더욱 정교하게 학습합니다.
이러한 경쟁 구조는 GAN의 독특한 특징이며, 창의적인 이미지 생성, 비디오 생성, 텍스트 생성 등 다양한 분야에서 활용되고 있습니다.

2. GAN의 구조와 학습 과정

GAN의 학습 과정은 다음과 같은 단계로 이루어집니다:

  1. 데이터 수집: GAN은 대량의 데이터를 필요로 합니다. 일반적으로 실제 데이터셋에서 샘플을 사용합니다.
  2. 생성자(Generator) 훈련: 생성자는 노이즈(z)를 입력받아 가짜 이미지(또는 데이터)를 생성합니다.
  3. 판별자(Discriminator) 훈련: 판별자는 진짜 이미지와 생성자가 만든 가짜 이미지를 입력받아 이들이 진짜인지 가짜인지를 예측합니다.
  4. 손실 함수 계산: 생성자와 판별자의 성능을 평가하기 위해 손실 함수를 계산합니다.
    생성자의 목표는 판별자를 속이는 것이고, 판별자의 목표는 생성자가 만든 가짜 이미지를 정확하게 판단하는 것입니다.
  5. 모델 업데이트: 손실 함수에 기반하여 생성자와 판별자 모두 최적화 알고리즘을 통해 모델 파라미터를 업데이트합니다.
  6. 반복: 2~5 단계를 반복하여 두 네트워크가 상호 개선될 수 있도록 합니다.

이런 방식으로 생성자는 점점 더 나은 이미지를 생성하고, 판별자는 이를 잘 구분할 수 있게 됩니다.
이 과정이 반복되면서 결국 생성자는 매우 현실적인 데이터를 생성할 수 있는 수준에 도달하게 됩니다.

3. GAN을 구현하는 방법

이제 GAN을 파이토치(PyTorch)를 사용하여 구현해보겠습니다.
이번 예제에서는 간단한 GAN을 만들어서 손글씨 숫자 데이터셋인 MNIST를 사용해보겠습니다.
MNIST는 0에서 9까지의 숫자가 포함된 70,000개의 흑백 이미지 데이터로 구성되어 있습니다.
우리가 생성하려는 목표는 이 숫자 이미지를 생성하는 것입니다.

3.1. 필수 라이브러리 설치

먼저 파이토치와 기타 필요한 라이브러리를 설치합니다.
아래 구문을 사용하여 필요한 패키지를 설치할 수 있습니다.

!pip install torch torchvision matplotlib

3.2. 데이터셋 로드 및 전처리

이제 MNIST 데이터셋을 로드하고, Tensor 형태로 변환한 후, 훈련을 위해 준비하겠습니다.


import torch
from torchvision import datasets, transforms

# 데이터 변환 설정
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# MNIST 데이터셋 다운로드 및 로드
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)

3.3. GAN의 생성자와 판별자 정의하기

GAN의 생성자와 판별자를 정의하겠습니다.
생성자는 랜덤 노이즈를 입력받아 이미지를 생성하고, 판별자는 주어진 이미지가 진짜인지 가짜인지 판단합니다.


import torch.nn as nn

# 생성자 정의
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 28 * 28),
            nn.Tanh() # -1 ~ 1로 출력을 정규화
        )

    def forward(self, z):
        return self.model(z).view(-1, 1, 28, 28)

# 판별자 정의
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid() # 0 ~ 1로의 출력을 정규화
        )

    def forward(self, img):
        return self.model(img)

3.4. 손실 함수 및 최적화 알고리즘 설정

GAN의 손실 함수는 두 개의 손실로 구성됩니다.
생성자의 손실과 판별자의 손실을 설정하고, 두 신경망의 최적화 알고리즘을 정의하겠습니다.


import torch.optim as optim

# 모델 초기화
generator = Generator()
discriminator = Discriminator()

# 손실 함수 및 최적화 알고리즘 설정
criterion = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

3.5. GAN 훈련하기

이제 실제로 GAN을 훈련시켜보겠습니다.
훈련 과정에서는 생성자와 판별자가 교대로 훈련됩니다.


import matplotlib.pyplot as plt

def train_gan(num_epochs):
    for epoch in range(num_epochs):
        for i, (imgs, _) in enumerate(train_loader):
            # 진짜 이미지에 대한 라벨
            real_imgs = imgs
            real_labels = torch.ones(real_imgs.size(0), 1)
            fake_labels = torch.zeros(real_imgs.size(0), 1)

            # 판별자 훈련
            optimizer_D.zero_grad()
            outputs = discriminator(real_imgs)
            d_loss_real = criterion(outputs, real_labels)
            d_loss_real.backward()

            z = torch.randn(real_imgs.size(0), 100)
            fake_imgs = generator(z)
            outputs = discriminator(fake_imgs.detach())
            d_loss_fake = criterion(outputs, fake_labels)
            d_loss_fake.backward()
            optimizer_D.step()

            # 생성자 훈련
            optimizer_G.zero_grad()
            outputs = discriminator(fake_imgs)
            g_loss = criterion(outputs, real_labels)
            g_loss.backward()
            optimizer_G.step()

        if epoch % 100 == 0:
            print(f'Epoch [{epoch}/{num_epochs}], d_loss: {d_loss_real.item() + d_loss_fake.item():.4f}, g_loss: {g_loss.item():.4f}')

            # 생성된 이미지 출력
            with torch.no_grad():
                generated_images = generator(torch.randn(64, 100)).detach().cpu()
                plt.figure(figsize=(10, 10))
                plt.imshow(torchvision.utils.make_grid(generated_images, nrow=8, normalize=True).permute(1, 2, 0))
                plt.axis('off')
                plt.show()

train_gan(num_epochs=1000)

4. 결론

GAN은 매우 강력한 생성 모델로, 다양한 분야에서 응용되고 있습니다.
이번 튜토리얼에서는 파이토치를 이용하여 GAN을 구현하는 방법을 살펴보았습니다.
생성자와 판별자가 서로 경쟁하며 학습하는 방식으로 GAN은 고품질의 데이터를 생성할 수 있게 됩니다.
실제 응용을 위해서는 여러 가지 기법(예: 조건부 GAN, 스타일 GAN 등)을 활용하여 성능을 개선할 수 있습니다.

앞으로 더 발전된 GAN 아키텍처와 그 활용에 대해서도 이야기해보겠습니다.
GAN은 현재도 활발히 연구되고 있으며, 새로운 방식의 GAN이 계속 발표되고 있으니 이에 대한 업데이트도 주목할 필요가 있습니다.