딥러닝 파이토치 강좌, CycleGAN

딥러닝의 발전은 다양한 분야에서의 이미지 변환 및 생성 모델의 가능성을 열었습니다. 생성적 적대 신경망(Generative Adversarial Networks, GANs)은 이러한 발전의 중핵을 이루며, 그 중에서도 CycleGAN은 특히 스타일 변환에 유용한 모델로 각광받고 있습니다.
이 글에서는 CycleGAN의 원리, 활용법 및 파이썬의 PyTorch 라이브러리를 사용한 구현 과정을 자세히 설명하겠습니다.

1. CycleGAN의 개요

CycleGAN은 두 개의 이미지 도메인 간의 이미지 변환을 학습하는 데 사용되는 모델입니다. 이 모델은 각 도메인에서 이미지를 서로 변환하는 두 개의 생성기와 해당 생성기가 생성한 이미지를 원래 도메인으로 변환하는 두 개의 판별기로 이루어져 있습니다.
CycleGAN은 특히 두 도메인 간의 직접적인 대응이 필요 없는 경우에 유리합니다. 예를 들어, 사진을 그림으로 변환하거나, 여름 사진을 겨울 사진으로 변환하는 등의 작업을 수행할 수 있습니다.

2. CycleGAN의 구조

CycleGAN의 기본 구조는 다음과 같은 네 가지 주요 구성 요소로 이루어져 있습니다.

  • Generator G: 도메인 X의 이미지를 도메인 Y의 이미지로 변환합니다.
  • Generator F: 도메인 Y의 이미지를 도메인 X의 이미지로 변환합니다.
  • Discriminator D_X: 도메인 X의 실제 이미지와 G가 생성한 변환 이미지를 구분합니다.
  • Discriminator D_Y: 도메인 Y의 실제 이미지와 F가 생성한 변환 이미지를 구분합니다.

2.1. Loss Function

CycleGAN은 여러 가지 손실 함수를 사용하여 훈련됩니다. 주요 손실 함수는 다음과 같습니다.

  • Adversarial Loss: 판별기가 생성한 이미지와 실제 이미지를 구별하는 능력을 기반으로 생성기의 성능을 평가합니다.
  • Cycle Consistency Loss: X에서 Y로 변환한 뒤 다시 X로 변환하는 과정을 거치면서 원래 이미지를 재구성할 수 있어야 한다는 원칙을 적용합니다. 즉, F(G(X)) ≈ X 여야 합니다.

3. CycleGAN 구현하기

이제 CycleGAN을 PyTorch를 사용하여 구현해보겠습니다. 이 과정은 데이터 준비, 모델 정의, 손실 함수 및 최적화 설정, 학습 루프, 그리고 결과 시각화 등을 포함합니다.

3.1. 데이터 준비

CycleGAN을 훈련시키기 위해서는 두 개의 이미지 도메인이 필요합니다. 우리는 예시로 ‘여름’과 ‘겨울’ 이미지 데이터를 사용할 것입니다. 해당 데이터셋은 Apple2Orange, Horse2Zebra와 같은 유명한 공개 데이터셋을 사용할 수 있습니다. 아래 코드는 데이터셋을 로드하는 방법을 보여줍니다.


import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

# 데이터 변환 정의
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
])

# 데이터 로드
summer_dataset = ImageFolder(root='data/summer', transform=transform)
winter_dataset = ImageFolder(root='data/winter', transform=transform)

summer_loader = DataLoader(summer_dataset, batch_size=1, shuffle=True)
winter_loader = DataLoader(winter_dataset, batch_size=1, shuffle=True)
    

3.2. 모델 정의

CycleGAN에서는 고차원 기능을 학습할 수 있도록 U-Net과 같은 구조를 따르는 생성기를 정의합니다. 다음 코드에서는 단순한 생성기 모델을 정의합니다.


import torch
import torch.nn as nn
import torch.nn.functional as F

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.hidden_layers = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=1, padding=3),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            # 중간 레이어 추가
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            # 디코더
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 3, kernel_size=7, stride=1, padding=3),
        )

    def forward(self, x):
        return self.hidden_layers(x)
    

3.3. 손실 함수 및 최적화 설정

이제 손실 함수와 최적화 알고리즘을 설정합니다. 우리는 진짜-가짜 판별을 위한 이진 교차 엔트로피 손실 함수와 Cycle Consistency Loss를 사용할 것입니다.


criterion_gan = nn.BCELoss()
criterion_cycle = nn.L1Loss()

# Adam 옵티마이저
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    

3.4. 학습 루프

학습 루프에서는 모델을 훈련시키고 손실 값을 기록합니다. 기본적인 학습 루프는 다음과 같은 구조로 작성할 수 있습니다.


num_epochs = 200
for epoch in range(num_epochs):
    for (summer_images, winter_images) in zip(summer_loader, winter_loader):
        real_A = summer_images[0].to(device)
        real_B = winter_images[0].to(device)

        # 생성적 손실 계산
        fake_B = generator_G(real_A)
        cycled_A = generator_F(fake_B)

        loss_cycle = criterion_cycle(cycled_A, real_A) 

        # Adversarial Loss 계산
        loss_G = criterion_gan(discriminator_D_Y(fake_B), real_labels) + loss_cycle

        # 역전파 및 최적화
        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()

        # 결과 기록
        print(f'Epoch [{epoch}/{num_epochs}], Loss: {loss_G.item()}')
    

3.5. 결과 시각화

학습이 완료된 후, CycleGAN의 결과를 시각화하기 위해 몇 가지 이미지를 생성하고 이를 사용자에게 보여줍니다. 다음 코드는 결과 이미지를 저장하고 시각화하는 방법을 보여줍니다.


import matplotlib.pyplot as plt

# 이미지를 생성하고 저장하는 함수
def save_image(tensor, filename):
    image = tensor.detach().cpu().numpy()
    image = image.transpose((1, 2, 0))
    plt.imsave(filename, (image * 255).astype('uint8'))

# 훈련된 생성기를 사용하여 이미지 생성
with torch.no_grad():
    for i, summer_images in enumerate(summer_loader):
        fake_images = generator_G(summer_images[0].to(device))
        save_image(fake_images, f'output/image_{i}.png')
    break
    

4. CycleGAN의 활용

CycleGAN은 이미지 변환 및 스타일 변환 외에도 다양한 분야에서 활용될 수 있습니다. 예를 들어, 의료 이미징, 비디오 변환, 그리고 패션 디자인 등에서 사용할 수 있습니다.

4.1. 의료 이미지 처리

CycleGAN은 의학적 이미지에서 병리적 변화를 식별하는 데 큰 도움이 됩니다. 환자의 CT 스캔을 MRI 이미지로 변환함으로써 의사들이 비교하고 분석하기 쉽게 할 수 있습니다.

4.2. 비디오 변환

CycleGAN을 사용하여 비디오의 한 스타일을 다른 스타일로 변환할 수 있습니다. 예를 들어, 실시간 비디오 스트림에서 여름의 풍경을 겨울로 변환하는데 활용될 수 있습니다.

4.3. 패션 디자인

CycleGAN은 패션 디자인 분야에서도 혁신을 가져올 수 있습니다. 디자이너가 다양한 스타일의 의상을 시뮬레이션하고 디자인하는 데 도움을 줄 수 있습니다.

5. 결론

CycleGAN은 이미지 변환 분야에서 매우 유용한 도구입니다. 이 모델은 비디오, 패션 등 다양한 응용 분야에 적합하며, 비전 분야에서의 한계를 극복하는 데 중요한 역할을 합니다.
이 글에서는 CycleGAN의 기본 원리부터 구현, 결과 시각화까지의 과정을 자세히 살펴보았습니다. 앞으로의 연구와 발전이 기대되며, CycleGAN에 대한 이해가 향후 개발에 큰 도움이 되기를 바랍니다.