최근 몇 년 간 딥러닝은 이미지 생성, 변환, 분할 등 다양한 분야에서 혁신적인 발전을 이루었습니다. 그중에서도 GAN(Generative Adversarial Network)은 이미지 생성의 새로운 가능성을 열어주었습니다. GAN은 생성자(Generator)와 판별자(Discriminator)로 구성된 두 개의 네트워크가 서로 경쟁하며 성능을 향상시키는 구조입니다. 본 포스팅에서는 GAN의 개요 및 파이토치(PyTorch) 프레임워크를 이용하여 GAN을 구현하기 위한 환경 설정 방법에 대해 자세히 알아보겠습니다.
1. GAN 개요
GAN은 Ian Goodfellow가 2014년에 제안한 모델로, 두 개의 신경망이 상호작용하여 훈련되는 방식입니다. 생성자는 실제와 유사한 데이터를 생성하고, 판별자는 생성된 데이터가 실제 데이터인지 아닌지를 판단합니다. 이 두 네트워크는 서로 견제하며 점점 더 발전하게 됩니다.
1.1 GAN의 구조
GAN은 다음과 같은 두 가지 구성 요소로 이루어져 있습니다:
- 생성자(Generator): 무작위 노이즈 벡터를 입력으로 받아 가짜 데이터를 생성합니다.
- 판별자(Discriminator): 입력으로 받은 데이터가 진짜인지 가짜인지 판별합니다.
1.2 GAN의 수학적 정의
GAN의 목표는 Minimax 게임으로 표현할 수 있습니다. 생성자는 다음과 같은 목표를 가지고 있습니다:
G^{*} = arg \min_{G} \max_{D} V(D, G) = E_{x \sim pdata(x)}[\log D(x)] + E_{z \sim pz(z)}[\log(1 - D(G(z)))]
여기서 G는 생성자, D는 판별자를 의미하며, pdata(x)는 실제 데이터의 분포, pz(z)는 생성자가 사용하는 노이즈 분포입니다.
2. 파이토치(PyTorch) 환경 설정
파이토치는 텐서 연산, 자동 미분, 그리고 딥러닝 모델을 손쉽게 구축할 수 있는 여러 도구를 제공하는 오픈소스 머신러닝 라이브러리입니다. GAN 구현을 위해 파이토치를 설치하고 필요한 라이브러리를 구성하는 방법은 다음과 같습니다.
2.1 파이토치 설치하기
파이토치는 CUDA를 지원하여 NVIDIA GPU에서 효율적으로 동작할 수 있습니다. 아래의 명령어를 통해 설치할 수 있습니다.
pip install torch torchvision torchaudio
만약 CUDA를 사용하고 있다면, 파이토치 공식 홈페이지에서 환경에 맞는 설치 명령어를 확인하여 설치하기 바랍니다.
2.2 추가 라이브러리 설치하기
이미지 처리를 위해 필요한 추가 라이브러리도 설치해야 합니다. 아래 명령어를 통해 설치합니다:
pip install matplotlib numpy
2.3 기본적인 디렉토리 구조 설정하기
프로젝트 디렉토리를 다음과 같은 구조로 만들어 줍니다:
gan_project/
├── dataset/
├── models/
├── results/
└── train.py
각각의 디렉토리는 데이터셋, 모델, 결과물을 저장하는 역할을 합니다. train.py
파일은 GAN을 학습시키고 평가하는 스크립트를 담고 있습니다.
3. GAN 구현에 필요한 코드 예제
이제 GAN을 구현하기 위한 기본적인 코드를 작성해보겠습니다. 이 코드는 생성자와 판별자를 정의하고, 이를 이용하여 GAN을 훈련하는 과정을 포함하고 있습니다.
3.1 모델 정의하기
먼저 생성자와 판별자 네트워크를 정의합니다. 아래 코드는 간단한 CNN(컨볼루션 신경망)을 사용하여 생성자와 판별자를 구축하는 예제입니다:
import torch
import torch.nn as nn
# 생성자 모델
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.Linear(100, 256),
nn.ReLU(),
nn.Linear(256, 512),
nn.ReLU(),
nn.Linear(512, 1024),
nn.ReLU(),
nn.Linear(1024, 28 * 28), # MNIST 이미지 크기
nn.Tanh() # [-1, 1]로 정규화
)
def forward(self, z):
return self.model(z).view(-1, 1, 28, 28)
# 판별자 모델
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Flatten(),
nn.Linear(28 * 28, 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 1),
nn.Sigmoid() # [0, 1]로 정규화
)
def forward(self, img):
return self.model(img)
3.2 데이터셋 준비하기
MNIST 데이터셋을 가져오고 전처리합니다. torchvision 라이브러리를 이용하면 쉽게 데이터셋을 로드할 수 있습니다.
from torchvision import datasets, transforms
# 데이터 전처리
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
# MNIST 데이터셋 로드
dataloader = torch.utils.data.DataLoader(
datasets.MNIST('dataset/', download=True, transform=transform),
batch_size=64,
shuffle=True
)
3.3 GAN 훈련 코드
이제 생성자와 판별자를 훈련할 수 있는 루프를 만들어보겠습니다.
import torch.optim as optim
# 모델 초기화
generator = Generator()
discriminator = Discriminator()
# 손실 함수 및 옵티마이저
criterion = nn.BCELoss()
optimizer_g = optim.Adam(generator.parameters(), lr=0.0002)
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002)
num_epochs = 50
for epoch in range(num_epochs):
for i, (imgs, _) in enumerate(dataloader):
batch_size = imgs.size(0)
imgs = imgs.view(batch_size, -1)
# 진짜와 가짜 레이블 생성
real_labels = torch.ones(batch_size, 1)
fake_labels = torch.zeros(batch_size, 1)
# 판별자 훈련
optimizer_d.zero_grad()
outputs = discriminator(imgs)
d_loss_real = criterion(outputs, real_labels)
d_loss_real.backward()
z = torch.randn(batch_size, 100)
fake_images = generator(z)
outputs = discriminator(fake_images)
d_loss_fake = criterion(outputs, fake_labels)
d_loss_fake.backward()
optimizer_d.step()
# 생성자 훈련
optimizer_g.zero_grad()
z = torch.randn(batch_size, 100)
fake_images = generator(z)
outputs = discriminator(fake_images)
g_loss = criterion(outputs, real_labels)
g_loss.backward()
optimizer_g.step()
if i % 100 == 0:
print(f"[Epoch {epoch}/{num_epochs}] [Batch {i}/{len(dataloader)}] "
f"[D loss: {d_loss_real.item() + d_loss_fake.item()}] "
f"[G loss: {g_loss.item()}]")
4. 결과 시각화
훈련이 완료된 후에 생성된 이미지를 시각화하는 것도 중요합니다. matplotlib을 사용하여 생성된 이미지를 출력할 수 있습니다.
import matplotlib.pyplot as plt
def generate_and_plot_images(generator, n_samples=25):
z = torch.randn(n_samples, 100)
generated_images = generator(z).detach().numpy()
plt.figure(figsize=(5, 5))
for i in range(n_samples):
plt.subplot(5, 5, i + 1)
plt.imshow(generated_images[i][0], cmap='gray')
plt.axis('off')
plt.show()
generate_and_plot_images(generator)
5. 마치며
이번 포스팅에서는 GAN의 원리와 기본 구조를 설명하고, 파이토치를 활용하여 GAN을 구현하기 위한 환경 설정과 코드 예제를 제공했습니다. GAN은 매우 강력한 생성 모델이며, 다양한 응용 분야를 갖고 있습니다. 앞으로 GAN을 활용한 다양한 프로젝트를 시도해보길 권장합니다.