파이토치를 활용한 GAN 딥러닝, 심층 신경망

1. GAN의 개요

GAN(Generative Adversarial Networks)은 2014년 Ian Goodfellow에 의해 제안된 딥러닝 모델입니다. GAN은 주어진 데이터셋의 분포를 학습하여 새로운 데이터를 생성할 수 있는 능력을 가지고 있습니다.
GAN의 주요 구성 요소는 두 개의 신경망인 생성자(Generator)와 판별자(Discriminator)입니다. 생성자는 실제 데이터와 유사한 가짜 데이터를 생성하고, 판별자는 생성된 데이터가 진짜인지 가짜인지 판단합니다.

2. GAN의 구조

GAN은 다음과 같은 구조로 이루어져 있습니다:

  • Generator (G): 무작위 노이즈를 입력으로 받아들이고, 이를 통해 가짜 데이터를 생성합니다.
  • Discriminator (D): 실제 데이터와 생성된 가짜 데이터를 구별하는 역할을 합니다.

2.1. 손실 함수

GAN의 학습 과정에서 생성자와 판별자는 각각의 손실 함수를 최적화하여 경쟁적으로 학습합니다. 판별자의 목표는 실제 데이터를 가짜 데이터와 잘 구별하는 것이고, 생성자의 목표는 판별자를 속이는 것입니다. 이를 수식으로 표현하면 아래와 같습니다:


    min_G max_D V(D, G) = E[log(D(x))] + E[log(1 - D(G(z)))]
    

3. PyTorch를 활용한 GAN 구현

이번 섹션에서는 파이토치를 사용하여 간단한 GAN을 구현해보겠습니다. 간단한 예제로 MNIST 데이터셋을 사용하여 숫자 이미지를 생성하는 GAN을 만들어보겠습니다.

3.1. 라이브러리 임포트


    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torchvision import datasets, transforms
    from torch.utils.data import DataLoader
    import matplotlib.pyplot as plt
    

3.2. 하이퍼파라미터 설정


    # 하이퍼파라미터 설정
    latent_size = 64
    batch_size = 128
    learning_rate = 0.0002
    num_epochs = 50
    

3.3. 데이터셋 불러오기


    # MNIST 데이터셋 불러오기
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    mnist = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    dataloader = DataLoader(mnist, batch_size=batch_size, shuffle=True)
    

3.4. 생성자 및 판별자 정의


    class Generator(nn.Module):
        def __init__(self):
            super(Generator, self).__init__()
            self.model = nn.Sequential(
                nn.Linear(latent_size, 128),
                nn.ReLU(),
                nn.Linear(128, 256),
                nn.ReLU(),
                nn.Linear(256, 512),
                nn.ReLU(),
                nn.Linear(512, 784),
                nn.Tanh()
            )

        def forward(self, z):
            return self.model(z).reshape(-1, 1, 28, 28)

    class Discriminator(nn.Module):
        def __init__(self):
            super(Discriminator, self).__init__()
            self.model = nn.Sequential(
                nn.Flatten(),
                nn.Linear(784, 512),
                nn.LeakyReLU(0.2),
                nn.Linear(512, 256),
                nn.LeakyReLU(0.2),
                nn.Linear(256, 1),
                nn.Sigmoid()
            )

        def forward(self, img):
            return self.model(img)
    

3.5. 모델, 손실 함수 및 최적화 기법 설정


    generator = Generator()
    discriminator = Discriminator()

    criterion = nn.BCELoss()
    optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate)
    optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate)
    

3.6. GAN 학습 루프


    for epoch in range(num_epochs):
        for i, (imgs, _) in enumerate(dataloader):
            # 실제 이미지와 레이블을 정의합니다.
            real_imgs = imgs
            real_labels = torch.ones(batch_size, 1)
            fake_labels = torch.zeros(batch_size, 1)

            # 판별자 학습
            optimizer_D.zero_grad()
            outputs = discriminator(real_imgs)
            d_loss_real = criterion(outputs, real_labels)
            d_loss_real.backward()

            z = torch.randn(batch_size, latent_size)
            fake_imgs = generator(z)
            outputs = discriminator(fake_imgs.detach())
            d_loss_fake = criterion(outputs, fake_labels)
            d_loss_fake.backward()
            optimizer_D.step()

            # 생성자 학습
            optimizer_G.zero_grad()
            outputs = discriminator(fake_imgs)
            g_loss = criterion(outputs, real_labels)
            g_loss.backward()
            optimizer_G.step()

        print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss_real.item() + d_loss_fake.item()}, g_loss: {g_loss.item()}')
    

3.7. 결과 시각화

학습이 완료된 후 생성된 이미지를 시각화하여 GAN의 성능을 평가해보겠습니다.


    z = torch.randn(64, latent_size)
    generated_images = generator(z).detach().numpy()
    generated_images = (generated_images + 1) / 2  # 0-1로 정규화

    fig, axs = plt.subplots(8, 8, figsize=(10,10))
    for i in range(8):
        for j in range(8):
            axs[i,j].imshow(generated_images[i*8 + j][0], cmap='gray')
            axs[i,j].axis('off')
    plt.show()
    

4. 결론

본 글에서는 GAN의 기본 개념과 PyTorch를 사용한 간단한 GAN 구현 방법에 대해 살펴보았습니다. GAN은 데이터 생성 분야에서 뛰어난 성능을 보여주며, 여러 응용 분야에서 활용되고 있습니다.