파이토치를 활용한 GAN 딥러닝, 첫 번째 심층 신경망

1. Gan이란 무엇인가?

Generative Adversarial Networks (GAN)은 Ian Goodfellow가 2014년 제안한 기본적인 딥러닝 모델 중 하나입니다. GAN은 두 개의 신경망으로 구성됩니다:
생성자(Generator)구별자(Discriminator). 생성자는 가짜 데이터를 생성하려 하고, 구별자는 이 데이터가 진짜인지 가짜인지 판단하려 합니다.
이 두 네트워크는 서로 경쟁하며 훈련되기 때문에 “Adversarial”이라는 용어가 붙습니다. 이 과정은 비지도 학습 방식으로, 생성자는 점점 더 진짜 데이터와 비슷한 데이터를 생성하게 됩니다.

2. GAN의 구조

GAN의 구조는 다음과 같이 작동합니다:

  • 생성자 네트워크: 랜덤 노이즈 벡터를 입력으로 받아 가짜 이미지를 생성합니다.
  • 구별자 네트워크: 실제 이미지와 가짜 이미지를 구별하는 역할을 합니다.
  • 훈련 과정에서는 생성자가 생성한 이미지를 구별자가 잘 구별하지 못하도록 학습합니다. 이로 인해 두 네트워크는 서로를 개선하며 경쟁하게 됩니다.

3. GAN의 동작 원리

GAN의 훈련 과정은 다음과 같은 반복적인 단계로 이루어져 있습니다:

  1. 구별자의 훈련: 실제 이미지와 생성자가 만든 가짜 이미지를 입력으로 받아, 구별자는 이러한 이미지를 올바르게 분류하기 위해 파라미터를 업데이트합니다.
  2. 생성자의 훈련: 훈련된 구별자를 통해 생성자가 만든 이미지의 품질을 평가하며, 구별자가 자신이 만든 이미지를 가짜로 인식하지 않도록 파라미터를 업데이트합니다.
  3. 이 과정을 반복하여 두 네트워크가 서로 점점 더 강해지도록 합니다.

4. GAN을 파이토치로 구현하기

이제 GAN을 파이토치로 구현해보겠습니다. 우리의 목표는 MNIST 데이터셋을 사용하여 숫자 이미지를 생성하는 것입니다.

4.1 필요한 라이브러리 설치

        pip install torch torchvision matplotlib
    

4.2 데이터셋 준비하기

MNIST 데이터셋은 파이토치의 torchvision 라이브러리를 통해 쉽게 가져올 수 있습니다. 아래 코드는 데이터를 로드하고 전처리하는 과정입니다.

        
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 데이터 변환 설정
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# MNIST 데이터셋 로드
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
        
        

4.3 생성자와 구별자 네트워크 정의

이제 GAN의 두 네트워크를 정의합니다. 간단한 신경망 구조를 사용하여 생성자와 구별자를 만듭니다.

        
import torch.nn as nn

# 생성자 네트워크 정의
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 28 * 28),
            nn.Tanh()
        )

    def forward(self, z):
        return self.model(z).view(-1, 1, 28, 28)

# 구별자 네트워크 정의
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        return self.model(img)
        
    

4.4 손실 함수와 최적화 알고리즘 설정

손실 함수는 생성자와 구별자 모두에 대해 이진 교차 엔트로피 손실을 사용합니다. 최적화 알고리즘으로는 Adam을 채택합니다.

        
import torch.optim as optim

# 손실 함수 및 최적화 알고리즘 정의
criterion = nn.BCELoss()
generator = Generator()
discriminator = Discriminator()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
        
    

4.5 훈련 과정 구현

훈련은 다음과 같은 과정으로 진행됩니다:

        
import numpy as np

num_epochs = 50
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(train_loader):
        batch_size = images.size(0)
        
        # 현실과 비현실 태그 설정
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)

        # 구별자 훈련
        outputs = discriminator(images)
        d_loss_real = criterion(outputs, real_labels)

        noise = torch.randn(batch_size, 100)
        fake_images = generator(noise)
        outputs = discriminator(fake_images.detach())
        d_loss_fake = criterion(outputs, fake_labels)

        d_loss = d_loss_real + d_loss_fake
        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()

        # 생성자 훈련
        outputs = discriminator(fake_images)
        g_loss = criterion(outputs, real_labels)

        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()

        # 훈련 상태 출력
        if (i + 1) % 100 == 0:
            print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}')
        
    

4.6 생성된 이미지 시각화하기

훈련이 완료된 후 생성된 이미지를 시각화합니다.

        
import matplotlib.pyplot as plt

def visualize_generator(num_images):
    noise = torch.randn(num_images, 100)
    with torch.no_grad():
        generated_images = generator(noise)

    plt.figure(figsize=(10, 10))
    for i in range(num_images):
        plt.subplot(5, 5, i + 1)
        plt.imshow(generated_images[i][0].cpu().numpy(), cmap='gray')
        plt.axis('off')
    plt.show()

visualize_generator(25)
        
    

5. GAN의 활용

GAN은 이미지 생성 외에도 다양한 분야에 활용될 수 있습니다. 예를 들어, 스타일 전이, 이미지 복원, 비디오 생성 등에서 사용되고 있으며, 인공지능 분야에서 큰 주목을 받고 있습니다.
GAN의 발전은 생성적 모델을 통해 무엇이 가능한지를 새롭게 보여주고 있습니다.

6. 결론

이번 강좌에서는 파이토치를 사용하여 GAN을 구현하는 기본적인 방법에 대해 배우고 실제 코드를 통해 GAN의 작동 원리를 이해했습니다. GAN은 앞으로도 더욱 발전할 기술이며, 다양한 응용 프로그램에서 큰 가능성을 지니고 있습니다.
앞으로 더 나아가 GAN을 기반으로 한 다양한 변형 모델들을 탐구해보길 추천합니다.