1. GAN의 개요
GAN(Generative Adversarial Networks)은 2014년 Ian Goodfellow에 의해 제안된 혁신적인 딥러닝 모델로, 생성 모델과 판별 모델이 서로 대립하여 학습하는 방식으로 작동합니다. GAN의 기본 구성 요소는 생성기(Generator)와 판별기(Discriminator)입니다. 생성기는 실제와 유사한 데이터를 생성하려고 하며, 판별기는 주어진 데이터가 실제 데이터인지 생성된 데이터인지를 판단합니다. 이러한 두 모델 간의 경쟁을 통해 생성기는 점점 더 실제와 유사한 데이터를 만들어 내게 됩니다.
2. GAN의 주요 구성 요소
2.1 생성기(Generator)
생성기는 랜덤 노이즈 벡터를 입력받아 실제 데이터와 유사한 데이터를 생성하는 신경망입니다. 이 네트워크는 일반적으로 다층 퍼셉트론 또는 컨볼루션 신경망을 사용할 수 있습니다.
2.2 판별기(Discriminator)
판별기는 입력받은 데이터가 실제 데이터인지 생성된 데이터인지를 판단하는 신경망입니다. 생성기가 만든 가짜 데이터와 실제 데이터를 잘 구분할 수 있는 능력이 중요합니다.
2.3 손실 함수
GAN의 손실 함수는 생성기의 손실과 판별기의 손실로 나뉩니다. 생성기의 목표는 판별기를 속이는 것이고, 판별기의 목표는 생성된 데이터를 잘 구별하는 것입니다.
3. GAN의 훈련 과정
GAN의 훈련 과정은 반복적이며, 다음과 같은 단계로 구성됩니다:
- 실제 데이터와 랜덤 노이즈를 사용하여 생성기를 통해 가짜 데이터를 생성합니다.
- 가짜 데이터와 실제 데이터를 판별기에 입력합니다.
- 판별기가 가짜와 실제를 얼마나 잘 구별했는지 계산하여 손실을 업데이트합니다.
- 생성기를 업데이트하여 판별기를 속일 수 있도록 합니다.
4. GAN의 활용 분야
GAN은 다음과 같은 다양한 분야에서 활용됩니다:
- 이미지 생성
- 비디오 생성
- 음성 합성
- 텍스트 생성
- 데이터 증강
5. 정형 데이터와 비정형 데이터의 차이
정형 데이터는 구조화된 데이터로, 관계형 데이터베이스에 쉽게 표현할 수 있는 데이터입니다. 반면, 비정형 데이터는 텍스트, 이미지, 비디오 등과 같이 고유한 형식과 구조가 없는 데이터입니다. GAN은 비정형 데이터에 주로 사용되지만 정형 데이터에도 활용될 수 있습니다.
6. 파이토치를 사용한 GAN 구현 예제
다음은 파이토치를 사용한 GAN의 간단한 구현 예제입니다. 이 예제에서는 MNIST 데이터셋을 사용하여 손글씨 숫자를 생성합니다.
6.1 환경 설정
먼저 필요한 라이브러리를 설치하고 불러옵니다.
!pip install torch torchvision matplotlib
6.2 데이터셋 준비
MNIST 데이터셋을 로드하고, 이를 텐서로 변환합니다.
import torch
import torchvision
import torchvision.transforms as transforms
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,)),
])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
6.3 모델 구성
생성기와 판별기 모델을 정의합니다.
import torch.nn as nn
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.Linear(100, 256),
nn.ReLU(),
nn.Linear(256, 512),
nn.ReLU(),
nn.Linear(512, 1024),
nn.ReLU(),
nn.Linear(1024, 784), # MNIST와 같은 28x28 이미지
nn.Tanh(),
)
def forward(self, x):
return self.model(x)
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(784, 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 1),
nn.Sigmoid(),
)
def forward(self, x):
return self.model(x)
6.4 손실 함수와 옵티마이저 설정
손실 함수로는 BCELoss를 사용하고, Adam 옵티마이저로 학습을 진행합니다.
criterion = nn.BCELoss()
generator = Generator()
discriminator = Discriminator()
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
6.5 모델 훈련
모델 훈련을 수행합니다. 매 에포크마다 생성된 샘플을 확인할 수 있습니다.
import matplotlib.pyplot as plt
num_epochs = 100
for epoch in range(num_epochs):
for i, (real_images, _) in enumerate(train_loader):
batch_size = real_images.size(0)
# 라벨 생성
real_labels = torch.ones(batch_size, 1)
fake_labels = torch.zeros(batch_size, 1)
# 판별기 훈련
optimizer_D.zero_grad()
outputs = discriminator(real_images.view(batch_size, -1))
d_loss_real = criterion(outputs, real_labels)
d_loss_real.backward()
noise = torch.randn(batch_size, 100)
fake_images = generator(noise)
outputs = discriminator(fake_images.detach())
d_loss_fake = criterion(outputs, fake_labels)
d_loss_fake.backward()
optimizer_D.step()
# 생성기 훈련
optimizer_G.zero_grad()
outputs = discriminator(fake_images)
g_loss = criterion(outputs, real_labels)
g_loss.backward()
optimizer_G.step()
if (epoch+1) % 10 == 0:
print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss_real.item() + d_loss_fake.item():.4f}, g_loss: {g_loss.item():.4f}')
# 생성된 이미지 확인
with torch.no_grad():
fake_images = generator(noise).view(-1, 1, 28, 28)
grid = torchvision.utils.make_grid(fake_images, nrow=8, normalize=True)
plt.imshow(grid.permute(1, 2, 0).numpy(), cmap='gray')
plt.show()
6.6 결과 확인
훈련이 진행됨에 따라 생성기는 점점 더 현실적인 손글씨 숫자를 생성합니다. 이를 통해 GAN의 성능을 확인할 수 있습니다.
7. 종료 및 결론
이번 글에서는 GAN의 개요, 구성 요소, 훈련 과정, 파이토치를 이용한 구현 방법을 다루었습니다. GAN은 비정형 데이터 생성에서 뛰어난 성능을 보여주며, 다양한 응용 분야에서 활용될 수 있습니다. 이러한 기술이 발전함에 따라 앞으로의 가능성은 무궁무진합니다.
8. 참고 문헌
1. Ian Goodfellow et al., “Generative Adversarial Networks”, NeurIPS 2014.
2. PyTorch Documentation – https://pytorch.org/docs/stable/index.html
3. torchvision Documentation – https://pytorch.org/vision/stable/index.html