파이토치를 활용한 GAN 딥러닝, 고약한 범법자를 위한 문학 클럽

Generative Adversarial Networks (GANs)는 딥러닝의 가장 혁신적인 발전 중 하나로 간주됩니다. GAN은 두 개의 신경망, 즉 생성자(Generator)와 판별자(Discriminator)로 구성됩니다. 생성자는 데이터를 생성하고, 판별자는 해당 데이터가 진짜인지 가짜인지 판단합니다. 이러한 경쟁 구조는 서로의 성능을 향상시키는 역할을 합니다. 본 강좌에서는 파이토치를 사용하여 GAN을 구축하고, 이를 통해 흥미로운 방식으로 ‘고약한 범법자를 위한 문학 클럽’을 주제로 한 데이터 생성을 수행할 것입니다.

1. GAN의 기본 구조와 원리

GAN의 작동 방식은 다음과 같습니다:

  • 생성자(Generator): 무작위 노이즈(z)를 입력으로 받아 현실적인 데이터를 생성합니다.
  • 판별자(Discriminator): 입력된 데이터가 실제 데이터인지 생성자에 의해 만들어진 데이터인지 판단합니다.
  • 생성자는 판별자를 속이려 하며, 판별자는 생성자를 구별하려 합니다. 이러한 경쟁이 지속되면서 두 네트워크는 점점 더 발전하게 됩니다.

2. 필요한 라이브러리와 데이터셋 준비

파이토치와 기타 필요한 라이브러리를 설치합니다. 그리고 데이터를 준비하는 과정에서 사용할 데이터셋을 선택해야 합니다. 예제에서는 MNIST 데이터셋을 사용하여 숫자 이미지를 생성해보겠습니다. MNIST 데이터셋은 손으로 쓴 숫자 이미지로 구성되어 있습니다.

2.1 환경 설정

pip install torch torchvision

2.2 데이터셋 로드

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# MNIST 데이터셋 로드
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

mnist_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
data_loader = DataLoader(dataset=mnist_dataset, batch_size=64, shuffle=True)

3. GAN 모델 구축

Generative Adversarial Network를 구현하기 위해 생성자와 판별자 모델을 정의합니다.

3.1 생성자 모델

import torch.nn as nn

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 28 * 28),  # MNIST 이미지 크기
            nn.Tanh()  # 생성된 이미지의 픽셀 값 범위를 -1 ~ 1로 조정
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), 1, 28, 28)  # 이미지 형태로 변환
        return img

3.2 판별자 모델

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()  # 실수가 아닌 0과 1 사이의 값으로 출력
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)  # 이미지를 평탄화
        validity = self.model(img_flat)
        return validity

4. GAN의 학습 과정

GAN의 학습 과정은 다음과 같이 이루어집니다.

  • 진짜 이미지와 생성된 이미지를 판별자에 제공하여 판별자의 손실(loss)을 계산합니다.
  • 생성자를 업데이트하여 생성된 이미지가 실제와 가까워지도록 합니다.
  • 이 과정을 반복하여 각 네트워크가 서로 발전하게 합니다.

4.1 손실 함수 및 최적화기 정의

import torch.optim as optim

# 생성자 및 판별자 인스턴스 생성
generator = Generator()
discriminator = Discriminator()

# 손실 함수 및 최적화기 설정
adversarial_loss = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

4.2 학습 루프

num_epochs = 200
for epoch in range(num_epochs):
    for i, (imgs, _) in enumerate(data_loader):
        # 실제 데이터의 레이블: 1
        real_imgs = imgs
        valid = torch.ones(imgs.size(0), 1)  # 진짜 이미지에 대한 정답
        fake = torch.zeros(imgs.size(0), 1)  # 가짜 이미지에 대한 정답

        # 판별자 학습
        optimizer_D.zero_grad()
        z = torch.randn(imgs.size(0), 100)  # 무작위 노이즈 샘플링
        generated_imgs = generator(z)  # 생성된 이미지
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(generated_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # 생성자 학습
        optimizer_G.zero_grad()
        g_loss = adversarial_loss(discriminator(generated_imgs), valid)
        g_loss.backward()
        optimizer_G.step()

    print(f'Epoch {epoch}/{num_epochs} | D Loss: {d_loss.item()} | G Loss: {g_loss.item()}')

5. 결과 시각화

학습이 완료된 후에 생성된 이미지를 시각화하여 결과를 확인합니다.

import matplotlib.pyplot as plt

# 생성된 이미지 시각화
def show_generated_images(generator, num_images=16):
    z = torch.randn(num_images, 100)  # 무작위 노이즈 샘플링
    generated_images = generator(z)
    generated_images = generated_images.detach().numpy()

    fig, axs = plt.subplots(4, 4, figsize=(10, 10))
    for i in range(4):
        for j in range(4):
            axs[i, j].imshow(generated_images[i * 4 + j, 0], cmap='gray')
            axs[i, j].axis('off')
    plt.show()

show_generated_images(generator)

이와 같은 방법으로 GAN을 구축하고 학습시켜 생성된 이미지를 확인할 수 있습니다. GAN의 활용 가능성은 매우 넓으며, 이를 통해 창의적인 작업을 수행할 수 있습니다. 이제 여러분도 GAN의 세계에 한 걸음 더 가까워질 수 있습니다!

6. 결론

Generative Adversarial Networks는 매우 흥미로운 딥러닝 분야로, 많은 연구와 개발에서 활발히 사용되고 있습니다. 이번 강좌에서는 파이토치를 활용하여 GAN의 기본 원리와 구조를 살펴보았고, 실제로 딥러닝 모델을 구축하고 학습하는 과정을 다루었습니다. 여러분이 이 강좌를 통해 GAN에 대한 깊은 이해와 흥미를 느끼길 바라며, 앞으로의 딥러닝 여정에 큰 도움이 되길 바랍니다.