파이토치를 활용한 GAN 딥러닝, 애니멀간

1. 서론

Generative Adversarial Networks (GANs)는 두 개의 신경망, 즉 생성자(Generator)와 판별자(Discriminator)가 서로 대립하면서 학습하는 모델입니다. 이러한 구조는 이미지 생성, 변환 및 스타일 전이와 같은 다양한 진보된 딥러닝 응용 분야에서 큰 주목을 받고 있습니다. 본 글에서는 PyTorch를 활용한 GAN의 기본 원리와 이를 통해 동물 이미지를 생성하는 애니멀간(AnimalGAN)에 대해 자세히 다뤄보겠습니다.

2. GAN의 기본 원리

GAN은 주로 두 개의 신경망으로 구성됩니다. 생성자는 무작위 노이즈 벡터를 입력 받아 가짜 이미지를 생성하고, 판별자는 진짜 이미지와 생성된 가짜 이미지를 구분합니다. 두 신경망은 서로의 학습을 방해하면서 최적화됩니다. 이 과정은 게임 이론에서의 ‘제로섬 게임’과 비슷합니다. 생성자는 판별자가 보지 못하게 만들기 위해 계속해서 개선하고, 판별자는 생성자가 만든 이미지의 진위를 판단하는 데 향상됩니다.

2.1 GAN의 학습 과정

학습 과정은 다음과 같은 단계로 진행됩니다:

  1. 진짜 데이터로 판별자를 학습시킵니다.
  2. 무작위 노이즈를 생성하고, 이를 기반으로 생성자를 통해 가짜 이미지를 만듭니다.
  3. 가짜 이미지로 또다시 판별자를 학습시킵니다.
  4. 위 과정을 반복합니다.

3. PyTorch를 이용한 GAN 구현

이제 PyTorch를 사용하여 간단한 GAN을 구현해보겠습니다. 전체 프로세스는 준비 단계, 모델 구현, 학습, 생성된 이미지 시각화로 나눌 수 있습니다.

3.1 환경 설정

python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
    

3.2 데이터셋 준비

애니멀간 프로젝트에서는 CIFAR-10 또는 동물 이미지 데이터셋을 사용할 수 있습니다. 여기서는 CIFAR-10 데이터셋을 로드해 보겠습니다.

python
transform = transforms.Compose([
    transforms.Resize(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# CIFAR-10 데이터셋 로드
dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)
    

3.3 GAN 모델 구현

GAN 모델은 생성자와 판별자로 구성됩니다. 생성자는 노이즈 벡터를 입력 받고 이미지를 생성하고, 판별자는 이미지가 진짜인지 가짜인지 판별하는 역할을 합니다.

python
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(True),
            nn.Linear(256, 512),
            nn.ReLU(True),
            nn.Linear(512, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 3 * 64 * 64),  # CIFAR-10 이미지 크기
            nn.Tanh()  # 출력 범위를 [-1, 1]로
        )

    def forward(self, z):
        return self.model(z).view(-1, 3, 64, 64)

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(3 * 64 * 64, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()  # 출력이 [0, 1] 사이에 있도록
        )

    def forward(self, img):
        return self.model(img.view(-1, 3 * 64 * 64))
    

3.4 모델 학습

GAN의 학습 과정은 판별자와 생성자를 번갈아 가며 학습시키는 방식으로 진행됩니다. 다음 코드를 통해 GAN을 학습해 보겠습니다.

python
# 모델, 손실 함수 및 옵티마이저 정의
generator = Generator().cuda()
discriminator = Discriminator().cuda()
criterion = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# 학습 루프
num_epochs = 50
for epoch in range(num_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        # 진짜 이미지 레이블 및 가짜 이미지 레이블 설정
        real_imgs = imgs.cuda()
        batch_size = real_imgs.size(0)
        labels_real = torch.ones(batch_size, 1).cuda()
        labels_fake = torch.zeros(batch_size, 1).cuda()

        # 판별자 학습
        optimizer_D.zero_grad()
        outputs_real = discriminator(real_imgs)
        loss_real = criterion(outputs_real, labels_real)

        z = torch.randn(batch_size, 100).cuda()  # 노이즈 생성
        fake_imgs = generator(z)
        outputs_fake = discriminator(fake_imgs.detach())
        loss_fake = criterion(outputs_fake, labels_fake)

        loss_D = loss_real + loss_fake
        loss_D.backward()
        optimizer_D.step()

        # 생성자 학습
        optimizer_G.zero_grad()
        outputs_fake = discriminator(fake_imgs)
        loss_G = criterion(outputs_fake, labels_real)  # 가짜 이미지가 진짜로 판단되도록 학습
        loss_G.backward()
        optimizer_G.step()

    print(f'Epoch [{epoch}/{num_epochs}], Loss D: {loss_D.item():.4f}, Loss G: {loss_G.item():.4f}')
    

3.5 결과 시각화

학습이 완료된 후 생성된 이미지를 시각화하여 GAN의 성능을 평가할 수 있습니다. 다음은 몇 개의 생성된 이미지를 시각화하는 코드입니다.

python
def show_generated_images(model, num_images=25):
    z = torch.randn(num_images, 100).cuda()
    with torch.no_grad():
        generated_imgs = model(z)
    generated_imgs = generated_imgs.cpu().numpy()
    generated_imgs = (generated_imgs + 1) / 2  # [0, 1] 범위로 변환

    fig, axes = plt.subplots(5, 5, figsize=(10, 10))
    for i, ax in enumerate(axes.flatten()):
        ax.imshow(generated_imgs[i].transpose(1, 2, 0))  # 채널 순서를 이미지에 맞게 변경
        ax.axis('off')
    plt.tight_layout()
    plt.show()

show_generated_images(generator)
    

4. 결론

본 글에서는 파이토치를 사용하여 GAN을 통해 동물 이미지를 생성하는 애니멀간(AnimalGAN)을 구현하였습니다. GAN의 기본 원리를 이해하고 실제로 코드를 통해 결과를 확인함으로써, GAN의 개념과 작동 방식을 명확히 이해할 수 있었습니다. GAN은 여전히 연구가 활발히 이루어지고 있는 분야이며, 보다 진보된 모델과 기술들이 지속적으로 등장하고 있습니다. 이와 같은 다양한 시도를 통해 우리는 더 많은 가능성을 탐색할 수 있을 것입니다.