Generative Adversarial Networks (GANs)는 딥러닝의 가장 혁신적인 발전 중 하나로 간주됩니다. GAN은 두 개의 신경망, 즉 생성자(Generator)와 판별자(Discriminator)로 구성됩니다. 생성자는 데이터를 생성하고, 판별자는 해당 데이터가 진짜인지 가짜인지 판단합니다. 이러한 경쟁 구조는 서로의 성능을 향상시키는 역할을 합니다. 본 강좌에서는 파이토치를 사용하여 GAN을 구축하고, 이를 통해 흥미로운 방식으로 ‘고약한 범법자를 위한 문학 클럽’을 주제로 한 데이터 생성을 수행할 것입니다.
1. GAN의 기본 구조와 원리
GAN의 작동 방식은 다음과 같습니다:
- 생성자(Generator): 무작위 노이즈(z)를 입력으로 받아 현실적인 데이터를 생성합니다.
- 판별자(Discriminator): 입력된 데이터가 실제 데이터인지 생성자에 의해 만들어진 데이터인지 판단합니다.
- 생성자는 판별자를 속이려 하며, 판별자는 생성자를 구별하려 합니다. 이러한 경쟁이 지속되면서 두 네트워크는 점점 더 발전하게 됩니다.
2. 필요한 라이브러리와 데이터셋 준비
파이토치와 기타 필요한 라이브러리를 설치합니다. 그리고 데이터를 준비하는 과정에서 사용할 데이터셋을 선택해야 합니다. 예제에서는 MNIST 데이터셋을 사용하여 숫자 이미지를 생성해보겠습니다. MNIST 데이터셋은 손으로 쓴 숫자 이미지로 구성되어 있습니다.
2.1 환경 설정
pip install torch torchvision
2.2 데이터셋 로드
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# MNIST 데이터셋 로드
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
mnist_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
data_loader = DataLoader(dataset=mnist_dataset, batch_size=64, shuffle=True)
3. GAN 모델 구축
Generative Adversarial Network를 구현하기 위해 생성자와 판별자 모델을 정의합니다.
3.1 생성자 모델
import torch.nn as nn
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.Linear(100, 256),
nn.ReLU(),
nn.Linear(256, 512),
nn.ReLU(),
nn.Linear(512, 1024),
nn.ReLU(),
nn.Linear(1024, 28 * 28), # MNIST 이미지 크기
nn.Tanh() # 생성된 이미지의 픽셀 값 범위를 -1 ~ 1로 조정
)
def forward(self, z):
img = self.model(z)
img = img.view(img.size(0), 1, 28, 28) # 이미지 형태로 변환
return img
3.2 판별자 모델
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(28 * 28, 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 1),
nn.Sigmoid() # 실수가 아닌 0과 1 사이의 값으로 출력
)
def forward(self, img):
img_flat = img.view(img.size(0), -1) # 이미지를 평탄화
validity = self.model(img_flat)
return validity
4. GAN의 학습 과정
GAN의 학습 과정은 다음과 같이 이루어집니다.
- 진짜 이미지와 생성된 이미지를 판별자에 제공하여 판별자의 손실(loss)을 계산합니다.
- 생성자를 업데이트하여 생성된 이미지가 실제와 가까워지도록 합니다.
- 이 과정을 반복하여 각 네트워크가 서로 발전하게 합니다.
4.1 손실 함수 및 최적화기 정의
import torch.optim as optim
# 생성자 및 판별자 인스턴스 생성
generator = Generator()
discriminator = Discriminator()
# 손실 함수 및 최적화기 설정
adversarial_loss = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
4.2 학습 루프
num_epochs = 200
for epoch in range(num_epochs):
for i, (imgs, _) in enumerate(data_loader):
# 실제 데이터의 레이블: 1
real_imgs = imgs
valid = torch.ones(imgs.size(0), 1) # 진짜 이미지에 대한 정답
fake = torch.zeros(imgs.size(0), 1) # 가짜 이미지에 대한 정답
# 판별자 학습
optimizer_D.zero_grad()
z = torch.randn(imgs.size(0), 100) # 무작위 노이즈 샘플링
generated_imgs = generator(z) # 생성된 이미지
real_loss = adversarial_loss(discriminator(real_imgs), valid)
fake_loss = adversarial_loss(discriminator(generated_imgs.detach()), fake)
d_loss = (real_loss + fake_loss) / 2
d_loss.backward()
optimizer_D.step()
# 생성자 학습
optimizer_G.zero_grad()
g_loss = adversarial_loss(discriminator(generated_imgs), valid)
g_loss.backward()
optimizer_G.step()
print(f'Epoch {epoch}/{num_epochs} | D Loss: {d_loss.item()} | G Loss: {g_loss.item()}')
5. 결과 시각화
학습이 완료된 후에 생성된 이미지를 시각화하여 결과를 확인합니다.
import matplotlib.pyplot as plt
# 생성된 이미지 시각화
def show_generated_images(generator, num_images=16):
z = torch.randn(num_images, 100) # 무작위 노이즈 샘플링
generated_images = generator(z)
generated_images = generated_images.detach().numpy()
fig, axs = plt.subplots(4, 4, figsize=(10, 10))
for i in range(4):
for j in range(4):
axs[i, j].imshow(generated_images[i * 4 + j, 0], cmap='gray')
axs[i, j].axis('off')
plt.show()
show_generated_images(generator)
이와 같은 방법으로 GAN을 구축하고 학습시켜 생성된 이미지를 확인할 수 있습니다. GAN의 활용 가능성은 매우 넓으며, 이를 통해 창의적인 작업을 수행할 수 있습니다. 이제 여러분도 GAN의 세계에 한 걸음 더 가까워질 수 있습니다!
6. 결론
Generative Adversarial Networks는 매우 흥미로운 딥러닝 분야로, 많은 연구와 개발에서 활발히 사용되고 있습니다. 이번 강좌에서는 파이토치를 활용하여 GAN의 기본 원리와 구조를 살펴보았고, 실제로 딥러닝 모델을 구축하고 학습하는 과정을 다루었습니다. 여러분이 이 강좌를 통해 GAN에 대한 깊은 이해와 흥미를 느끼길 바라며, 앞으로의 딥러닝 여정에 큰 도움이 되길 바랍니다.