딥러닝 파이토치 강좌, GAN 구현

이번 강좌에서는 GAN(Generative Adversarial Network)을 파이토치(PyTorch)로 구현하는 방법에 대해 심도 깊은 설명을 드리겠습니다. GAN은 좋은 생성 모델을 학습시키는 도구로서, 이미지 생성, 스타일 변환, 데이터 보강 등 다양한 분야에서 사용되고 있습니다. 강좌에서는 GAN의 기본 개념부터 시작하여, 각 구성 요소를 구현하고, 마지막으로 실제적인 예제를 통해 GAN이 어떻게 작동하는지 이해해보겠습니다.

1. GAN의 기본 개념

GAN은 두 가지 주요 구성 요소, 즉 생성자(Generator)와 판별자(Discriminator)로 구성됩니다. 이 두 모델은 서로 경쟁하면서 학습하게 되며, 이 과정이 GAN의 핵심입니다.

1.1 생성자(Generator)

생성자의 역할은 랜덤 노이즈를 입력 받아서 실제 데이터와 유사한 가짜 데이터를 생성하는 것입니다. 이 모델은 실제 데이터를 모방하는 방법을 학습합니다.

1.2 판별자(Discriminator)

판별자는 입력된 데이터가 실제 데이터인지 아니면 생성자가 만들어낸 가짜 데이터인지를 판별하는 역할을 합니다. 이 모델은 실제 데이터와 가짜 데이터를 구별하는 방법을 학습합니다.

1.3 GAN의 학습 과정

GAN의 학습은 생성자와 판별자가 서로 경쟁하는 방식으로 진행됩니다. 생성자는 판별자를 속이기 위해 점점 더 나은 가짜 데이터를 생성하려고 하며, 판별자는 이러한 가짜 데이터를 인식하기 위해 노력합니다. 이 과정이 반복되면서 두 모델은 점점 더 발전하게 됩니다.

2. GAN의 구성요소 구현하기

이제 GAN을 구현하기 위해 필요한 주요 구성 요소들을 코드를 통해 구현해보겠습니다. 여기서는 간단한 GAN을 구현하고, MNIST 숫자 데이터셋을 사용하여 손글씨 숫자를 생성하는 모델을 만들어보겠습니다.

2.1 환경 설정

우선, 필요한 라이브러리를 설치하고, MNIST 데이터셋을 다운로드하여 준비합니다.

!pip install torch torchvision matplotlib
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt

2.2 데이터셋 로드

MNIST 데이터셋을 로드하고, 전처리를 수행합니다.

# 데이터셋 준비
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
mnist = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(mnist, batch_size=64, shuffle=True)

2.3 생성자(Generator) 모델 구현

생성자는 입력된 노이즈 벡터를 받아서 이미지로 변환하는 신경망입니다.

import torch.nn as nn

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 28 * 28),
            nn.Tanh()
        )
    
    def forward(self, z):
        return self.model(z).view(-1, 1, 28, 28)

2.4 판별자(Discriminator) 모델 구현

판별자는 입력된 이미지를 실제 이미지인지 가짜 이미지인지 판별하는 모델입니다.

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid() 
        )
    
    def forward(self, img):
        return self.model(img.view(-1, 28 * 28))

2.5 모델 초기화

생성자와 판별자 모델을 초기화하고, 손실 함수와 옵티마이저를 정의합니다.

generator = Generator()
discriminator = Discriminator()

criterion = nn.BCELoss()
optimizer_gen = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_disc = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

2.6 GAN 학습 루프

이제 GAN의 학습 루프를 구현해보겠습니다. 생성자와 판별자의 손실을 계산하고, 옵티마이저를 통해 가중치를 업데이트합니다.

def train_gan(num_epochs):
    for epoch in range(num_epochs):
        for i, (imgs, _) in enumerate(dataloader):
            z = torch.randn(imgs.size(0), 100)
            real_labels = torch.ones(imgs.size(0), 1)
            fake_labels = torch.zeros(imgs.size(0), 1)

            # 판별자 학습
            optimizer_disc.zero_grad()
            outputs = discriminator(imgs)
            d_loss_real = criterion(outputs, real_labels)
            d_loss_real.backward()

            fake_imgs = generator(z)
            outputs = discriminator(fake_imgs.detach())
            d_loss_fake = criterion(outputs, fake_labels)
            d_loss_fake.backward()
            optimizer_disc.step()

            # 생성자 학습
            optimizer_gen.zero_grad()
            outputs = discriminator(fake_imgs)
            g_loss = criterion(outputs, real_labels)
            g_loss.backward()
            optimizer_gen.step()

        if (epoch + 1) % 10 == 0:
            print(f'Epoch [{epoch + 1}/{num_epochs}], d_loss: {d_loss_real.item() + d_loss_fake.item():.4f}, g_loss: {g_loss.item():.4f}')

3. GAN 실행하기

이제 GAN을 학습시키고 생성된 이미지 결과를 시각화해보겠습니다.

num_epochs = 100
train_gan(num_epochs)

def show_generated_images(generator, num_images=16):
    z = torch.randn(num_images, 100)
    fake_images = generator(z).detach()
    plt.figure(figsize=(10, 10))
    for i in range(num_images):
        plt.subplot(4, 4, i + 1)
        plt.imshow(fake_images[i][0], cmap='gray')
        plt.axis('off')
    plt.show()

show_generated_images(generator)

4. 결론

이번 강좌에서는 GAN의 기본 개념과 파이토치를 사용한 간단한 GAN 모델 구현 과정을 살펴보았습니다. GAN은 이미지 생성, 스타일 변환 등 다양한 분야에 응용될 수 있으며, GAN을 통해 인공지능의 가능성을 더욱 확장할 수 있습니다. 이 강좌를 바탕으로 더 복잡한 GAN 변형 모델들을 탐구해보는 것도 좋습니다.

이상이 딥러닝 파이토치를 활용한 GAN 구현에 대한 강좌였습니다. 학습 중 궁금한 점이나 더 알고 싶은 내용이 있다면 언제든지 댓글로 문의해 주세요!