생성적 적대 신경망(Generative Adversarial Networks, GAN)은 Ian Goodfellow가 2014년에 제안한 혁신적인 딥러닝 모델로, 두 개의 신경망이 서로 대립하며 학습하는 방법론입니다. GAN은 주로 이미지 생성, 텍스트 생성, 비디오 생성 등 여러 분야에서 광범위하게 사용됩니다. 이번 포스트에서는 PyTorch를 사용하여 GAN의 기본 개념과 구현 방법에 대해 단계적으로 설명하겠습니다.
1. GAN의 기본 개념
GAN은 생성자(Generator)와 판별자(Discriminator)라는 두 개의 신경망으로 구성됩니다. 생성자의 역할은 진짜처럼 보이는 데이터를 생성하는 것이고, 판별자는 주어진 데이터가 진짜인지 생성자가 만든 위조 데이터인지 판별하는 역할을 합니다. 이 두 신경망은 동시에 학습하며, 생성자는 판별자를 속이기 위해 더욱 정교한 데이터를 생성하도록 발전하고, 판별자는 더욱 정교하게 위조 데이터를 식별하도록 발전합니다.
1.1 GAN의 구조
GAN의 구조는 다음과 같이 단순히 설명할 수 있습니다:
- 생성자(Generator): 랜덤 노이즈를 입력받아 진짜 같은 데이터를 생성합니다.
- 판별자(Discriminator): 진짜 데이터와 생성된 위조 데이터를 입력받아, 각각의 데이터가 진짜인지 아닌지를 예측합니다.
1.2 GAN의 학습 과정
GAN의 학습 과정은 다음과 같이 진행됩니다:
- 진짜 데이터와 랜덤 노이즈를 사용하여 생성자(G)를 사용하여 위조 데이터를 생성합니다.
- 생성된 데이터와 진짜 데이터를 판별자(D)에 입력하고, 각각의 데이터에 대한 예측 값을 얻습니다.
- 생성자의 손실 함수는 판별자가 위조 데이터를 진짜로 판단할 확률을 극대화하는 방향으로 설정됩니다.
- 판별자의 손실 함수는 진짜 데이터를 진짜로, 위조 데이터를 위조로 판단할 확률을 극대화하는 방향으로 설정됩니다.
- 이 과정을 반복하여 두 네트워크가 서로 경쟁하며 성능이 향상됩니다.
2. PyTorch를 이용한 GAN 구현
이제 PyTorch를 사용하여 간단한 GAN을 구현해보겠습니다. 여기서는 MNIST 데이터셋을 이용해 수치 형태의 이미지를 생성하는 GAN을 만드는 작업을 진행할 것입니다.
2.1 환경 설정
첫 번째로, 필요한 라이브러리를 설치하고 임포트합니다. PyTorch 및 torchvision을 사용하여 데이터셋을 로드하고 모델을 구축합니다.
!pip install torch torchvision matplotlib
2.2 데이터셋 준비
MNIST 데이터셋을 로드하고 데이터 전처리를 수행하겠습니다. 이를 통해 이미지 데이터를 0~1 사이로 스케일링하고, 배치 단위로 나누도록 하겠습니다.
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# 데이터 로딩 및 전처리
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
2.3 생성자(Generator) 및 판별자(Discriminator) 정의
다음으로, GAN의 두 핵심 구성 요소인 생성자와 판별자를 정의하겠습니다. 여기서 생성자는 랜덤 노이즈를 입력받아 이미지를 생성하고, 판별자는 이미지를 입력받아 그 이미지가 진짜인지 위조인지 판단합니다.
import torch.nn as nn
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.Linear(100, 256),
nn.ReLU(),
nn.Linear(256, 512),
nn.ReLU(),
nn.Linear(512, 1024),
nn.ReLU(),
nn.Linear(1024, 28 * 28),
nn.Tanh()
)
def forward(self, z):
return self.model(z).view(-1, 1, 28, 28)
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Flatten(),
nn.Linear(28 * 28, 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 1),
nn.Sigmoid()
)
def forward(self, img):
return self.model(img)
2.4 모델 초기화 및 손실 함수, 옵티마이저 설정
생성자와 판별자를 초기화하고, 손실 함수와 옵티마이저를 지정하도록 하겠습니다. CrossEntropyLoss 및 Adam 옵티마이저를 사용할 것입니다.
generator = Generator()
discriminator = Discriminator()
ad = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
ag = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
criterion = nn.BCELoss()
2.5 GAN 학습
이제 GAN을 학습시켜보겠습니다. 각 에폭마다 생성자와 판별자를 학습시키고, 생성된 이미지를 볼 수 있습니다.
import matplotlib.pyplot as plt
import numpy as np
def train_gan(generator, discriminator, criterion, ag, ad, dataloader, epochs=50):
for epoch in range(epochs):
for real_imgs, _ in dataloader:
batch_size = real_imgs.size(0)
# 진짜 이미지와 레이블 생성
real_labels = torch.ones(batch_size, 1)
noise = torch.randn(batch_size, 100)
fake_imgs = generator(noise)
fake_labels = torch.zeros(batch_size, 1)
# 판별자 학습
discriminator.zero_grad()
real_loss = criterion(discriminator(real_imgs), real_labels)
fake_loss = criterion(discriminator(fake_imgs.detach()), fake_labels)
d_loss = real_loss + fake_loss
d_loss.backward()
ad.step()
# 생성자 학습
generator.zero_grad()
g_loss = criterion(discriminator(fake_imgs), real_labels)
g_loss.backward()
ag.step()
print(f'Epoch [{epoch + 1}/{epochs}], D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}')
# 생성된 이미지 저장
if (epoch + 1) % 10 == 0:
save_generated_images(generator, epoch + 1)
def save_generated_images(generator, epoch):
noise = torch.randn(64, 100)
generated_imgs = generator(noise)
generated_imgs = generated_imgs.detach().numpy()
generated_imgs = (generated_imgs + 1) / 2 # Rescale to [0, 1]
fig, axs = plt.subplots(8, 8, figsize=(8, 8))
for i, ax in enumerate(axs.flat):
ax.imshow(generated_imgs[i][0], cmap='gray')
ax.axis('off')
plt.savefig(f'generated_images_epoch_{epoch}.png')
plt.close()
train_gan(generator, discriminator, criterion, ag, ad, dataloader, epochs=50)
2.6 결과 확인
학습이 완료된 후, 생성된 이미지를 확인해보세요. GAN은 반복 학습을 통해 점점 더 실제와 유사한 데이터 이미지를 생성하게 됩니다. 결과적으로, GAN의 성능을 평가하는 것은 생성된 이미지의 품질입니다. 훈련이 잘되었을 경우, 생성된 이미지들이 생소하면서도 아름다운 형태를 가질 것입니다.
3. 결론
이번 포스트에서는 PyTorch를 활용하여 GAN을 구현하는 방법에 대해 설명했습니다. GAN의 기본 개념과 함께 실제 코드를 통해 자신만의 GAN을 만들어보는 경험을 할 수 있었기를 바랍니다. GAN은 강력한 도구이지만, 견고한 모델을 구축하는 데에는 다양하고 심도 있는 연구가 필요합니다. 아름답고 창의적인 결과를 만들어내는 GAN의 세상으로 여러분을 초대합니다!
4. 참고 자료
- Ian Goodfellow et al. “Generative Adversarial Networks”. NIPS 2014.
- PyTorch Documentation: https://pytorch.org/docs/stable/index.html
- GANs in Action book