딥러닝 분야는 데이터와 연산 능력의 발전에 힘입어 실질적인 성과를 많이 낸 분야입니다. 그 중에서도 GAN(Generative Adversarial Network)은 가장 혁신적인 결과를 보여준 모델 중 하나입니다. 본 글에서는 딥러닝 프레임워크 중 하나인 파이토치(PyTorch)를 활용하여 CycleGAN 모델을 학습시켜 모네(Monet) 스타일의 그림을 생성하는 방법을 소개할 것입니다.
1. CycleGAN 개요
CycleGAN은 두 개의 도메인 간 변환을 위한 GAN의 일종입니다. 예를 들어, 현실 사진을 화풍으로 변환하거나 낮의 풍경을 밤의 풍경으로 변환하는 일에 사용될 수 있습니다. CycleGAN의 주요 특징은 주어진 두 개의 도메인 간의 ‘순환 학습(cycle consistency)’을 통해 각각의 도메인 사이에서 변환의 일관성을 유지하는 것입니다.
1.1 CycleGAN 구조
CycleGAN은 두 개의 생성기(Generator)와 두 개의 판별기(Discriminator)로 구성됩니다. 각각의 생성기는 한 도메인의 이미지를 다른 도메인으로 변환하며, 판별기는 생성된 이미지가 진짜 이미지인지 구분하는 역할을 합니다.
- Generator G: 도메인 X(예: 사진)에서 도메인 Y(예: 모네 스타일의 그림)으로 변환
- Generator F: 도메인 Y에서 도메인 X로 변환
- Discriminator D_X: 도메인 X의 진짜와 생성된 이미지를 구분
- Discriminator D_Y: 도메인 Y의 진짜와 생성된 이미지를 구분
1.2 손실 함수
CycleGAN의 학습 과정은 다음과 같은 손실 함수 구성으로 이루어집니다.
- Adversarial Loss: 생성된 이미지가 얼마나 진짜 같은지를 판별기에게 평가받는 손실
- Cycle Consistency Loss: 이미지 변환 후 원래 이미지로 다시 변환했을 때의 손실
전체 손실은 다음과 같이 정의됩니다:
L = LGAN(G, DY, X, Y) + LGAN(F, DX, Y, X) + λ(CycleLoss(G, F) + CycleLoss(F, G))
2. 환경 설정
이번 프로젝트를 위해서는 Python, PyTorch 및 필요한 라이브러리들(예: NumPy, Matplotlib)이 설치되어 있어야 합니다. 필요한 라이브러리를 설치하기 위한 명령어는 다음과 같습니다:
pip install torch torchvision numpy matplotlib
3. 데이터셋 준비
모네 스타일의 그림과 사진 데이터셋이 필요합니다. 예를 들어, Monet Style의 그림은 Kaggle Monet Style Dataset에서 다운로드 받을 수 있습니다. 또한, 일반적인 사진 이미지는 다양한 공개 이미지 데이터베이스에서 구할 수 있습니다.
이미지 데이터셋이 준비되었으면, 이를 적절한 형식으로 로드하고 전처리 해줘야 합니다.
3.1 데이터 로드 및 전처리
import os
import glob
import random
from PIL import Image
import torchvision.transforms as transforms
def load_data(image_path, image_size=(256, 256)):
images = glob.glob(os.path.join(image_path, '*.jpg'))
dataset = []
for img in images:
image = Image.open(img).convert('RGB')
transform = transforms.Compose([
transforms.Resize(image_size),
transforms.ToTensor(),
])
image = transform(image)
dataset.append(image)
return dataset
# 이미지 경로 설정
monet_path = './data/monet/'
photo_path = './data/photos/'
monet_images = load_data(monet_path)
photo_images = load_data(photo_path)
4. CycleGAN 모델 구축
CycleGAN 모델을 구축하기 위해 기본적인 생성기와 판별기를 정의하겠습니다.
4.1 생성기 정의
여기서는 U-Net 구조를 기반으로 한 생성기를 정의합니다.
import torch
import torch.nn as nn
class UNetGenerator(nn.Module):
def __init__(self):
super(UNetGenerator, self).__init__()
self.encoder1 = self.contracting_block(3, 64)
self.encoder2 = self.contracting_block(64, 128)
self.encoder3 = self.contracting_block(128, 256)
self.encoder4 = self.contracting_block(256, 512)
self.decoder1 = self.expansive_block(512, 256)
self.decoder2 = self.expansive_block(256, 128)
self.decoder3 = self.expansive_block(128, 64)
self.decoder4 = nn.ConvTranspose2d(64, 3, kernel_size=3, stride=1, padding=1)
def contracting_block(self, in_channels, out_channels):
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def expansive_block(self, in_channels, out_channels):
return nn.Sequential(
nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
e1 = self.encoder1(x)
e2 = self.encoder2(e1)
e3 = self.encoder3(e2)
e4 = self.encoder4(e3)
d1 = self.decoder1(e4)
d2 = self.decoder2(d1 + e3) # Skip connection
d3 = self.decoder3(d2 + e2) # Skip connection
output = self.decoder4(d3 + e1) # Skip connection
return output
4.2 판별기 정의
패치 기반 구조를 사용하여 판별기를 정의합니다.
class PatchDiscriminator(nn.Module):
def __init__(self):
super(PatchDiscriminator, self).__init__()
self.model = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(512),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1)
)
def forward(self, x):
return self.model(x)
5. 손실 함수 구현
CycleGAN의 손실 함수를 구현합니다. 생성기의 손실과 판별기의 손실을 모두 고려합니다.
def compute_gan_loss(predictions, targets):
return nn.BCEWithLogitsLoss()(predictions, targets)
def compute_cycle_loss(real_image, cycled_image, lambda_cycle):
return lambda_cycle * nn.L1Loss()(real_image, cycled_image)
def compute_total_loss(real_images_X, real_images_Y,
fake_images_Y, fake_images_X,
cycled_images_X, cycled_images_Y,
D_X, D_Y, lambda_cycle):
loss_GAN_X = compute_gan_loss(D_Y(fake_images_Y), torch.ones_like(fake_images_Y))
loss_GAN_Y = compute_gan_loss(D_X(fake_images_X), torch.ones_like(fake_images_X))
loss_cycle = compute_cycle_loss(real_images_X, cycled_images_X, lambda_cycle) + \
compute_cycle_loss(real_images_Y, cycled_images_Y, lambda_cycle)
return loss_GAN_X + loss_GAN_Y + loss_cycle
6. 학습 과정
이제 모델을 학습할 차례입니다. 데이터 로더를 설정하고, 모델을 초기화한 후, 손실을 저장하고 업데이트를 수행합니다.
from torch.utils.data import DataLoader
def train_cyclegan(monet_loader, photo_loader, epochs=200, lambda_cycle=10):
G = UNetGenerator()
F = UNetGenerator()
D_X = PatchDiscriminator()
D_Y = PatchDiscriminator()
# Optimizers 설정
optimizer_G = torch.optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_F = torch.optim.Adam(F.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D_X = torch.optim.Adam(D_X.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D_Y = torch.optim.Adam(D_Y.parameters(), lr=0.0002, betas=(0.5, 0.999))
for epoch in range(epochs):
for real_images_X, real_images_Y in zip(monet_loader, photo_loader):
# 생성기 학습
fake_images_Y = G(real_images_X)
cycled_images_X = F(fake_images_Y)
optimizer_G.zero_grad()
optimizer_F.zero_grad()
total_loss = compute_total_loss(real_images_X, real_images_Y,
fake_images_Y, fake_images_X,
cycled_images_X, cycled_images_Y,
D_X, D_Y, lambda_cycle)
total_loss.backward()
optimizer_G.step()
optimizer_F.step()
# 판별기 학습
optimizer_D_X.zero_grad()
optimizer_D_Y.zero_grad()
loss_D_X = compute_gan_loss(D_X(real_images_X), torch.ones_like(real_images_X)) + \
compute_gan_loss(D_X(fake_images_X.detach()), torch.zeros_like(fake_images_X))
loss_D_Y = compute_gan_loss(D_Y(real_images_Y), torch.ones_like(real_images_Y)) + \
compute_gan_loss(D_Y(fake_images_Y.detach()), torch.zeros_like(fake_images_Y))
loss_D_X.backward()
loss_D_Y.backward()
optimizer_D_X.step()
optimizer_D_Y.step()
print(f'Epoch [{epoch+1}/{epochs}], Loss: {total_loss.item()}')
7. 결과 생성
모델이 학습을 마치면, 새로운 이미지를 생성하는 과정을 진행할 수 있습니다. 테스트 이미지를 사용하여 생성된 모네 스타일의 그림을 확인해봅시다.
def generate_images(test_loader, model_G):
model_G.eval()
for real_images in test_loader:
with torch.no_grad():
fake_images = model_G(real_images)
# 이미지를 저장하거나 시각화하는 코드 추가
이미지를 시각화하기 위한 내장 함수를 추가합니다:
import matplotlib.pyplot as plt
def visualize_results(real_images, fake_images):
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.title('Real Images')
plt.imshow(real_images.permute(1, 2, 0).numpy())
plt.subplot(1, 2, 2)
plt.title('Fake Images (Monet Style)')
plt.imshow(fake_images.permute(1, 2, 0).numpy())
plt.show()
8. 결론
이 글에서는 CycleGAN을 활용하여 모네 스타일의 그림을 생성하는 과정을 살펴보았습니다. 이 방법론은 많은 응용이 가능하며, 향후 더 많은 도메인 간의 변환 문제를 해결하는 데 사용될 수 있습니다. CycleGAN의 특징인 순환 일관성 또한 다양한 GAN 변형에 적용될 수 있어 앞으로의 연구 방향이 기대됩니다.
이 예제를 통해 파이토치에서 CycleGAN을 구현하는 기초를 습득하셨길 바랍니다. GAN은 높은 퀄리티의 이미지를 생성하는 데 있어 많은 가능성을 지니고 있으며, 이 기술의 발전이 더 많은 분야에 응용될 수 있을 것입니다.