Deep learning has made remarkable progress thanks to advances in data availability and computational power. Among its most innovative models is the GAN (Generative Adversarial Network). In this article, we will show how to train a CycleGAN model with the PyTorch deep learning framework to generate paintings in the style of Monet.
1. Overview of CycleGAN
CycleGAN is a type of GAN used for translation between two image domains. For instance, it can transform real photos into artistic styles or convert daytime scenes into nighttime scenes. A key feature of CycleGAN is that it learns from unpaired data: instead of requiring matched image pairs across domains, it enforces 'cycle consistency', meaning an image translated to the other domain and back should return approximately to itself.
1.1 CycleGAN Structure
CycleGAN consists of two generators and two discriminators. Each generator translates images from one domain to the other, while each discriminator distinguishes real images of its domain from generated ones.
- Generator G: Transforms from domain X (e.g., photos) to domain Y (e.g., Monet-style paintings)
- Generator F: Transforms from domain Y to domain X
- Discriminator D_X: Distinguishes between real and generated images in domain X
- Discriminator D_Y: Distinguishes between real and generated images in domain Y
1.2 Loss Function
Training CycleGAN combines the following loss terms:
- Adversarial Loss: evaluates, via the discriminator, how realistic the generated images look
- Cycle Consistency Loss: penalizes the difference between an original image and its reconstruction after being translated to the other domain and back
The total loss is defined as follows:
L(G, F, D_X, D_Y) = L_GAN(G, D_Y, X, Y) + L_GAN(F, D_X, Y, X) + λ · L_cyc(G, F)
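Here λ weights the cycle term (we use λ = 10 in the training code below), and the cycle consistency loss, following the original CycleGAN paper (Zhu et al., 2017), covers both translation directions:

L_cyc(G, F) = E_x[ ||F(G(x)) − x||_1 ] + E_y[ ||G(F(y)) − y||_1 ]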
2. Environment Setup
For this project, Python, PyTorch, and the necessary libraries (e.g., NumPy, Matplotlib) must be installed. The command to install the required libraries is as follows:
pip install torch torchvision numpy matplotlib
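To confirm the installation (and check whether a GPU is available), a quick sanity check can be run; this snippet uses only the standard PyTorch API:

import torch

print(torch.__version__)          # installed PyTorch version
print(torch.cuda.is_available())  # True if a CUDA GPU can be used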
3. Dataset Preparation
You will need a dataset of Monet-style paintings and photographs. For instance, the Monet Style paintings can be downloaded from the Kaggle Monet Style Dataset. Additionally, general photograph images can be obtained from various public image databases.
Once the image datasets are prepared, they need to be loaded and preprocessed in the appropriate format.
3.1 Data Loading and Preprocessing
import os
import glob
from PIL import Image
import torchvision.transforms as transforms

def load_data(image_path, image_size=(256, 256)):
    # Collect all JPEG files in the directory
    image_files = glob.glob(os.path.join(image_path, '*.jpg'))
    # Resize every image and convert it to a tensor with values in [0, 1]
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
    ])
    dataset = []
    for img_file in image_files:
        image = Image.open(img_file).convert('RGB')
        dataset.append(transform(image))
    return dataset
# Set the image paths
monet_path = './data/monet/'
photo_path = './data/photos/'
monet_images = load_data(monet_path)
photo_images = load_data(photo_path)
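The training function in section 6 expects data loaders rather than raw lists. Since a Python list of tensors already satisfies PyTorch's dataset interface, it can be wrapped directly in a DataLoader; a minimal sketch (the batch size here is an arbitrary choice):

from torch.utils.data import DataLoader

monet_loader = DataLoader(monet_images, batch_size=1, shuffle=True)
photo_loader = DataLoader(photo_images, batch_size=1, shuffle=True)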
4. Building the CycleGAN Model
To build the CycleGAN model, we will define basic generators and discriminators.
4.1 Generator Definition
Here, we define a generator based on the U-Net architecture.
import torch
import torch.nn as nn

class UNetGenerator(nn.Module):
    def __init__(self):
        super(UNetGenerator, self).__init__()
        # Encoder: each block halves the spatial resolution
        self.encoder1 = self.contracting_block(3, 64)
        self.encoder2 = self.contracting_block(64, 128)
        self.encoder3 = self.contracting_block(128, 256)
        self.encoder4 = self.contracting_block(256, 512)
        # Decoder: each block doubles the spatial resolution
        self.decoder1 = self.expansive_block(512, 256)
        self.decoder2 = self.expansive_block(256, 128)
        self.decoder3 = self.expansive_block(128, 64)
        # Final upsampling back to the input resolution with 3 output channels
        self.decoder4 = nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1)

    def contracting_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def expansive_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        e1 = self.encoder1(x)   # 64 channels, 1/2 resolution
        e2 = self.encoder2(e1)  # 128 channels, 1/4 resolution
        e3 = self.encoder3(e2)  # 256 channels, 1/8 resolution
        e4 = self.encoder4(e3)  # 512 channels, 1/16 resolution
        d1 = self.decoder1(e4)
        d2 = self.decoder2(d1 + e3)      # Additive skip connection
        d3 = self.decoder3(d2 + e2)      # Additive skip connection
        output = self.decoder4(d3 + e1)  # Skip connection, then upsample to full size
        return output
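A quick shape check confirms that the generator maps a 256×256 image to an output of the same size; this is only a sanity test, not part of training:

G = UNetGenerator()
x = torch.randn(1, 3, 256, 256)  # dummy batch: 1 image, 3 channels, 256x256
print(G(x).shape)                # expected: torch.Size([1, 3, 256, 256])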
4.2 Discriminator Definition
The discriminator uses a patch-based (PatchGAN) structure: instead of a single real/fake score per image, it outputs a grid of scores, each judging one patch of the input.
class PatchDiscriminator(nn.Module):
    def __init__(self):
        super(PatchDiscriminator, self).__init__()
        self.model = nn.Sequential(
            # No normalization on the first layer, as is common for PatchGAN
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # Final layer maps to a single-channel grid of patch logits
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1)
        )

    def forward(self, x):
        return self.model(x)
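As with the generator, the output shape can be checked. For a 256×256 input, the four stride-2 convolutions reduce the resolution to 16×16, and the final stride-1 convolution yields a 15×15 grid of patch logits:

D = PatchDiscriminator()
x = torch.randn(1, 3, 256, 256)
print(D(x).shape)  # expected: torch.Size([1, 1, 15, 15])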
5. Implementing the Loss Function
We will implement the loss functions for the CycleGAN, considering both the generator’s loss and the discriminator’s loss.
def compute_gan_loss(predictions, targets):
    # Binary cross-entropy on the raw patch logits
    return nn.BCEWithLogitsLoss()(predictions, targets)

def compute_cycle_loss(real_image, cycled_image, lambda_cycle):
    # L1 distance between an image and its round-trip reconstruction
    return lambda_cycle * nn.L1Loss()(real_image, cycled_image)

def compute_total_loss(real_images_X, real_images_Y,
                       fake_images_Y, fake_images_X,
                       cycled_images_X, cycled_images_Y,
                       D_X, D_Y, lambda_cycle):
    # Adversarial losses: the generators try to make the discriminators
    # output "real" (all ones) on generated images. The targets must match
    # the shape of the discriminator's patch output, not the image.
    pred_fake_Y = D_Y(fake_images_Y)
    pred_fake_X = D_X(fake_images_X)
    loss_GAN_G = compute_gan_loss(pred_fake_Y, torch.ones_like(pred_fake_Y))
    loss_GAN_F = compute_gan_loss(pred_fake_X, torch.ones_like(pred_fake_X))
    # Cycle consistency losses in both directions
    loss_cycle = compute_cycle_loss(real_images_X, cycled_images_X, lambda_cycle) + \
                 compute_cycle_loss(real_images_Y, cycled_images_Y, lambda_cycle)
    return loss_GAN_G + loss_GAN_F + loss_cycle
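As a quick illustration of how these helpers fit together (using dummy tensors, purely for demonstration):

D_test = PatchDiscriminator()
img = torch.randn(1, 3, 256, 256)
pred = D_test(img)  # patch logits
loss_real = compute_gan_loss(pred, torch.ones_like(pred))  # target: "real"
print(loss_real.item())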
6. Training Process
Now it’s time to train the model. Set up the data loaders, initialize the models and optimizers, and alternate between generator and discriminator updates.
from torch.utils.data import DataLoader

def train_cyclegan(monet_loader, photo_loader, epochs=200, lambda_cycle=10):
    # G: X (photos) -> Y (Monet), F: Y -> X, as defined in section 1.1
    G = UNetGenerator()
    F = UNetGenerator()
    D_X = PatchDiscriminator()
    D_Y = PatchDiscriminator()

    # Set up optimizers with standard GAN hyperparameters
    optimizer_G = torch.optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_F = torch.optim.Adam(F.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_D_X = torch.optim.Adam(D_X.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_D_Y = torch.optim.Adam(D_Y.parameters(), lr=0.0002, betas=(0.5, 0.999))

    for epoch in range(epochs):
        # X comes from the photo loader, Y from the Monet loader
        for real_images_X, real_images_Y in zip(photo_loader, monet_loader):
            # ----- Train generators -----
            fake_images_Y = G(real_images_X)    # X -> Y
            cycled_images_X = F(fake_images_Y)  # X -> Y -> X
            fake_images_X = F(real_images_Y)    # Y -> X
            cycled_images_Y = G(fake_images_X)  # Y -> X -> Y

            optimizer_G.zero_grad()
            optimizer_F.zero_grad()
            total_loss = compute_total_loss(real_images_X, real_images_Y,
                                            fake_images_Y, fake_images_X,
                                            cycled_images_X, cycled_images_Y,
                                            D_X, D_Y, lambda_cycle)
            total_loss.backward()
            optimizer_G.step()
            optimizer_F.step()

            # ----- Train discriminators -----
            optimizer_D_X.zero_grad()
            optimizer_D_Y.zero_grad()
            # Real images should be classified as real (ones), generated images
            # as fake (zeros); detach() keeps the generators fixed in this step.
            pred_real_X = D_X(real_images_X)
            pred_fake_X = D_X(fake_images_X.detach())
            loss_D_X = compute_gan_loss(pred_real_X, torch.ones_like(pred_real_X)) + \
                       compute_gan_loss(pred_fake_X, torch.zeros_like(pred_fake_X))
            pred_real_Y = D_Y(real_images_Y)
            pred_fake_Y = D_Y(fake_images_Y.detach())
            loss_D_Y = compute_gan_loss(pred_real_Y, torch.ones_like(pred_real_Y)) + \
                       compute_gan_loss(pred_fake_Y, torch.zeros_like(pred_fake_Y))
            loss_D_X.backward()
            loss_D_Y.backward()
            optimizer_D_X.step()
            optimizer_D_Y.step()

        print(f'Epoch [{epoch+1}/{epochs}], Loss: {total_loss.item()}')

    # Return the trained generators so they can be used for inference
    return G, F
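Putting it together with the loaders from section 3 (epochs=1 here is just a smoke test; real training needs many epochs):

G, F = train_cyclegan(monet_loader, photo_loader, epochs=1, lambda_cycle=10)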
7. Generating Results
Once the model has finished training, you can proceed to generate new images. Let’s check the generated Monet-style paintings using test images.
def generate_images(test_loader, model_G):
    model_G.eval()  # switch to evaluation mode (affects BatchNorm)
    for real_images in test_loader:
        with torch.no_grad():  # no gradients needed at inference time
            fake_images = model_G(real_images)
        # Add code here to save or visualize the images
We also add a helper function to visualize the results with Matplotlib:
import matplotlib.pyplot as plt

def visualize_results(real_images, fake_images):
    # Expects single images shaped (C, H, W); take [0] first when passing a batch
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.title('Real Image')
    # Convert (C, H, W) to (H, W, C) for Matplotlib, clipped to [0, 1]
    plt.imshow(real_images.permute(1, 2, 0).clamp(0, 1).numpy())
    plt.subplot(1, 2, 2)
    plt.title('Fake Image (Monet Style)')
    plt.imshow(fake_images.permute(1, 2, 0).clamp(0, 1).numpy())
    plt.show()
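To persist results to disk instead of (or in addition to) plotting them, torchvision ships a save_image utility. A minimal sketch, applied to the fake_images tensor from generate_images (the output path is an arbitrary choice):

import os
from torchvision.utils import save_image

os.makedirs('./outputs', exist_ok=True)  # make sure the target folder exists
save_image(fake_images, './outputs/fake_monet.png')  # accepts a single image or a batch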
8. Conclusion
In this article, we walked through the process of generating Monet-style paintings using CycleGAN. The approach generalizes to many other domain translation problems, and the cycle consistency idea carries over to various GAN variants, making future research directions exciting.
We hope this example has helped you grasp the basics of implementing CycleGAN in PyTorch. GANs hold great potential for generating high-quality images, and as the technology advances it is likely to find applications in many more fields.