파이토치를 활용한 GAN 딥러닝, CycleGAN으로 모네 그림 그리기

딥러닝 분야는 데이터와 연산 능력의 발전에 힘입어 실질적인 성과를 많이 낸 분야입니다. 그 중에서도 GAN(Generative Adversarial Network)은 가장 혁신적인 결과를 보여준 모델 중 하나입니다. 본 글에서는 딥러닝 프레임워크 중 하나인 파이토치(PyTorch)를 활용하여 CycleGAN 모델을 학습시켜 모네(Monet) 스타일의 그림을 생성하는 방법을 소개할 것입니다.

1. CycleGAN 개요

CycleGAN은 두 개의 도메인 간 변환을 위한 GAN의 일종입니다. 예를 들어, 현실 사진을 화풍으로 변환하거나 낮의 풍경을 밤의 풍경으로 변환하는 일에 사용될 수 있습니다. CycleGAN의 주요 특징은 주어진 두 개의 도메인 간의 ‘순환 학습(cycle consistency)’을 통해 각각의 도메인 사이에서 변환의 일관성을 유지하는 것입니다.

1.1 CycleGAN 구조

CycleGAN은 두 개의 생성기(Generator)와 두 개의 판별기(Discriminator)로 구성됩니다. 각각의 생성기는 한 도메인의 이미지를 다른 도메인으로 변환하며, 판별기는 생성된 이미지가 진짜 이미지인지 구분하는 역할을 합니다.

  • Generator G: 도메인 X(예: 사진)에서 도메인 Y(예: 모네 스타일의 그림)으로 변환
  • Generator F: 도메인 Y에서 도메인 X로 변환
  • Discriminator D_X: 도메인 X의 진짜와 생성된 이미지를 구분
  • Discriminator D_Y: 도메인 Y의 진짜와 생성된 이미지를 구분

1.2 손실 함수

CycleGAN의 학습 과정은 다음과 같은 손실 함수 구성으로 이루어집니다.

  • Adversarial Loss: 생성된 이미지가 얼마나 진짜 같은지를 판별기에게 평가받는 손실
  • Cycle Consistency Loss: 이미지 변환 후 원래 이미지로 다시 변환했을 때의 손실

전체 손실은 다음과 같이 정의됩니다:

L = LGAN(G, DY, X, Y) + LGAN(F, DX, Y, X) + λ(CycleLoss(G, F) + CycleLoss(F, G))

2. 환경 설정

이번 프로젝트를 위해서는 Python, PyTorch 및 필요한 라이브러리들(예: NumPy, Matplotlib)이 설치되어 있어야 합니다. 필요한 라이브러리를 설치하기 위한 명령어는 다음과 같습니다:

pip install torch torchvision numpy matplotlib

3. 데이터셋 준비

모네 스타일의 그림과 사진 데이터셋이 필요합니다. 예를 들어, Monet Style의 그림은 Kaggle Monet Style Dataset에서 다운로드 받을 수 있습니다. 또한, 일반적인 사진 이미지는 다양한 공개 이미지 데이터베이스에서 구할 수 있습니다.

이미지 데이터셋이 준비되었으면, 이를 적절한 형식으로 로드하고 전처리 해줘야 합니다.

3.1 데이터 로드 및 전처리

import os
import glob
import random
from PIL import Image
import torchvision.transforms as transforms

def load_data(image_path, image_size=(256, 256)):
    images = glob.glob(os.path.join(image_path, '*.jpg'))
    dataset = []
    for img in images:
        image = Image.open(img).convert('RGB')
        transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
        ])
        image = transform(image)
        dataset.append(image)
    return dataset

# 이미지 경로 설정
monet_path = './data/monet/'
photo_path = './data/photos/'

monet_images = load_data(monet_path)
photo_images = load_data(photo_path)

4. CycleGAN 모델 구축

CycleGAN 모델을 구축하기 위해 기본적인 생성기와 판별기를 정의하겠습니다.

4.1 생성기 정의

여기서는 U-Net 구조를 기반으로 한 생성기를 정의합니다.

import torch
import torch.nn as nn

class UNetGenerator(nn.Module):
    def __init__(self):
        super(UNetGenerator, self).__init__()
        self.encoder1 = self.contracting_block(3, 64)
        self.encoder2 = self.contracting_block(64, 128)
        self.encoder3 = self.contracting_block(128, 256)
        self.encoder4 = self.contracting_block(256, 512)
        self.decoder1 = self.expansive_block(512, 256)
        self.decoder2 = self.expansive_block(256, 128)
        self.decoder3 = self.expansive_block(128, 64)
        self.decoder4 = nn.ConvTranspose2d(64, 3, kernel_size=3, stride=1, padding=1)

    def contracting_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    
    def expansive_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x):
        e1 = self.encoder1(x)
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)
        e4 = self.encoder4(e3)
        d1 = self.decoder1(e4)
        d2 = self.decoder2(d1 + e3)  # Skip connection
        d3 = self.decoder3(d2 + e2)  # Skip connection
        output = self.decoder4(d3 + e1)  # Skip connection
        return output

4.2 판별기 정의

패치 기반 구조를 사용하여 판별기를 정의합니다.

class PatchDiscriminator(nn.Module):
    def __init__(self):
        super(PatchDiscriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1)
        )

    def forward(self, x):
        return self.model(x)

5. 손실 함수 구현

CycleGAN의 손실 함수를 구현합니다. 생성기의 손실과 판별기의 손실을 모두 고려합니다.

def compute_gan_loss(predictions, targets):
    return nn.BCEWithLogitsLoss()(predictions, targets)

def compute_cycle_loss(real_image, cycled_image, lambda_cycle):
    return lambda_cycle * nn.L1Loss()(real_image, cycled_image)

def compute_total_loss(real_images_X, real_images_Y, 
                       fake_images_Y, fake_images_X, 
                       cycled_images_X, cycled_images_Y, 
                       D_X, D_Y, lambda_cycle):
    loss_GAN_X = compute_gan_loss(D_Y(fake_images_Y), torch.ones_like(fake_images_Y))
    loss_GAN_Y = compute_gan_loss(D_X(fake_images_X), torch.ones_like(fake_images_X))
    loss_cycle = compute_cycle_loss(real_images_X, cycled_images_X, lambda_cycle) + \
                compute_cycle_loss(real_images_Y, cycled_images_Y, lambda_cycle)
    return loss_GAN_X + loss_GAN_Y + loss_cycle

6. 학습 과정

이제 모델을 학습할 차례입니다. 데이터 로더를 설정하고, 모델을 초기화한 후, 손실을 저장하고 업데이트를 수행합니다.

from torch.utils.data import DataLoader

def train_cyclegan(monet_loader, photo_loader, epochs=200, lambda_cycle=10):
    G = UNetGenerator()
    F = UNetGenerator()
    D_X = PatchDiscriminator()
    D_Y = PatchDiscriminator()

    # Optimizers 설정
    optimizer_G = torch.optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_F = torch.optim.Adam(F.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_D_X = torch.optim.Adam(D_X.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_D_Y = torch.optim.Adam(D_Y.parameters(), lr=0.0002, betas=(0.5, 0.999))

    for epoch in range(epochs):
        for real_images_X, real_images_Y in zip(monet_loader, photo_loader):
            # 생성기 학습
            fake_images_Y = G(real_images_X)
            cycled_images_X = F(fake_images_Y)

            optimizer_G.zero_grad()
            optimizer_F.zero_grad()
            total_loss = compute_total_loss(real_images_X, real_images_Y, 
                                             fake_images_Y, fake_images_X, 
                                             cycled_images_X, cycled_images_Y, 
                                             D_X, D_Y, lambda_cycle)
            total_loss.backward()
            optimizer_G.step()
            optimizer_F.step()

            # 판별기 학습
            optimizer_D_X.zero_grad()
            optimizer_D_Y.zero_grad()
            loss_D_X = compute_gan_loss(D_X(real_images_X), torch.ones_like(real_images_X)) + \
                        compute_gan_loss(D_X(fake_images_X.detach()), torch.zeros_like(fake_images_X))
            loss_D_Y = compute_gan_loss(D_Y(real_images_Y), torch.ones_like(real_images_Y)) + \
                        compute_gan_loss(D_Y(fake_images_Y.detach()), torch.zeros_like(fake_images_Y))
            loss_D_X.backward()
            loss_D_Y.backward()
            optimizer_D_X.step()
            optimizer_D_Y.step()

        print(f'Epoch [{epoch+1}/{epochs}], Loss: {total_loss.item()}')

7. 결과 생성

모델이 학습을 마치면, 새로운 이미지를 생성하는 과정을 진행할 수 있습니다. 테스트 이미지를 사용하여 생성된 모네 스타일의 그림을 확인해봅시다.

def generate_images(test_loader, model_G):
    model_G.eval()
    for real_images in test_loader:
        with torch.no_grad():
            fake_images = model_G(real_images)
            # 이미지를 저장하거나 시각화하는 코드 추가

이미지를 시각화하기 위한 내장 함수를 추가합니다:

import matplotlib.pyplot as plt

def visualize_results(real_images, fake_images):
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.title('Real Images')
    plt.imshow(real_images.permute(1, 2, 0).numpy())
    
    plt.subplot(1, 2, 2)
    plt.title('Fake Images (Monet Style)')
    plt.imshow(fake_images.permute(1, 2, 0).numpy())
    plt.show()

8. 결론

이 글에서는 CycleGAN을 활용하여 모네 스타일의 그림을 생성하는 과정을 살펴보았습니다. 이 방법론은 많은 응용이 가능하며, 향후 더 많은 도메인 간의 변환 문제를 해결하는 데 사용될 수 있습니다. CycleGAN의 특징인 순환 일관성 또한 다양한 GAN 변형에 적용될 수 있어 앞으로의 연구 방향이 기대됩니다.

이 예제를 통해 파이토치에서 CycleGAN을 구현하는 기초를 습득하셨길 바랍니다. GAN은 높은 퀄리티의 이미지를 생성하는 데 있어 많은 가능성을 지니고 있으며, 이 기술의 발전이 더 많은 분야에 응용될 수 있을 것입니다.