파이토치를 활용한 GAN 딥러닝, 강화학습

1. 소개

Generative Adversarial Networks (GANs)는 2014년 Ian Goodfellow에 의해 제안된 모델로, 두 신경망 간의 경쟁을 통해 데이터를 생성하는 모델입니다. GAN은 특히 이미지 생성, 스타일 변환, 데이터 증강 등에 널리 사용되고 있습니다. 이번 포스팅에서는 GAN의 기본 구조, 파이토치를 이용한 구현 방법, 강화학습의 기본 개념 및 여러 응용 사례를 소개하겠습니다.

2. GAN의 기본 구조

GAN은 두 개의 신경망으로 구성됩니다: 생성자(Generator)와 판별자(Discriminator)입니다. 생성자는 무작위 노이즈를 입력으로 받아 새로운 데이터를 생성하고, 판별자는 입력 데이터가 진짜 데이터인지 생성된 데이터인지를 구별합니다. 이 두 네트워크는 서로 경쟁하면서 학습합니다.

2.1 생성자 (Generator)

생성자는 노이즈 벡터를 받아 진짜처럼 보이는 데이터를 생성합니다. 목표는 판별자를 속이는 것입니다.

2.2 판별자 (Discriminator)

판별자는 입력 데이터의 진위 여부를 판단합니다. 진짜 데이터일 경우 1, 생성된 데이터일 경우 0을 출력합니다.

2.3 GAN의 손실 함수

GAN의 손실 함수는 다음과 같이 설정됩니다:

min_G max_D V(D, G) = E[log(D(x))] + E[log(1 - D(G(z)))]

여기서 E는 기대값을 나타내며, x는 진짜 데이터, G(z)는 생성자가 생성한 데이터입니다. 생성자는 손실을 최소화하려고 하고, 판별자는 손실을 최대화하려 하고 있습니다.

3. 파이토치를 활용한 GAN 구현

이제 GAN을 파이토치로 구현해 보겠습니다. 데이터셋으로는 MNIST 손글씨 숫자 데이터를 사용할 것입니다.

3.1 데이터셋 준비

import torch
import torchvision
from torchvision import datasets, transforms

# 데이터 변환 및 다운로드
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# MNIST 데이터셋
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)

3.2 생성자 (Generator) 모델 정의

import torch.nn as nn

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(True)
        )
        self.layer2 = nn.Sequential(
            nn.Linear(256, 512),
            nn.ReLU(True)
        )
        self.layer3 = nn.Sequential(
            nn.Linear(512, 1024),
            nn.ReLU(True)
        )
        self.layer4 = nn.Sequential(
            nn.Linear(1024, 28*28),
            nn.Tanh()  # 픽셀 값은 -1과 1 사이
        )
    
    def forward(self, z):
        out = self.layer1(z)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        return out.view(-1, 1, 28, 28)  # 이미지 형태로 변환

3.3 판별자 (Discriminator) 모델 정의

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Linear(28*28, 1024),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.layer2 = nn.Sequential(
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.layer3 = nn.Sequential(
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.layer4 = nn.Sequential(
            nn.Linear(256, 1),
            nn.Sigmoid()  # 출력값을 0과 1 사이로
        )
    
    def forward(self, x):
        out = self.layer1(x.view(-1, 28*28))  # 평탄화
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        return out

3.4 모델 훈련

import torch.optim as optim

# 모델 초기화
generator = Generator()
discriminator = Discriminator()

# 손실 함수 및 최적화기 설정
criterion = nn.BCELoss()  # Binary Cross Entropy Loss
optimizer_g = optim.Adam(generator.parameters(), lr=0.0002)
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002)

# 훈련
num_epochs = 200
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(train_loader):
        # 진짜 데이터 레이블
        real_labels = torch.ones(images.size(0), 1)
        fake_labels = torch.zeros(images.size(0), 1)

        # 판별자 훈련
        optimizer_d.zero_grad()
        outputs = discriminator(images)
        d_loss_real = criterion(outputs, real_labels)
        d_loss_real.backward()
        
        z = torch.randn(images.size(0), 100)
        fake_images = generator(z)
        outputs = discriminator(fake_images.detach())
        d_loss_fake = criterion(outputs, fake_labels)
        d_loss_fake.backward()
        
        optimizer_d.step()
        
        # 생성자 훈련
        optimizer_g.zero_grad()
        outputs = discriminator(fake_images)
        g_loss = criterion(outputs, real_labels)
        g_loss.backward()
        optimizer_g.step()
    
    if (epoch+1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss_real.item() + d_loss_fake.item():.4f}, g_loss: {g_loss.item():.4f}')

3.5 결과 시각화

import matplotlib.pyplot as plt

# 생성된 이미지 시각화 함수
def plot_generated_images(generator, n=10):
    z = torch.randn(n, 100)
    with torch.no_grad():
        generated_images = generator(z).cpu()
    generated_images = generated_images.view(-1, 28, 28)
    
    plt.figure(figsize=(10, 1))
    for i in range(n):
        plt.subplot(1, n, i+1)
        plt.imshow(generated_images[i], cmap='gray')
        plt.axis('off')
    plt.show()

# 이미지 생성
plot_generated_images(generator)

4. 강화학습의 기본 개념

강화학습(Reinforcement Learning, RL)은 에이전트가 환경과 상호작용하며 최적의 행동을 학습하는 기계 학습의 한 분야입니다. 에이전트는 상태를 관찰하고, 행동을 선택하고, 보상을 받으며, 이를 통해 최적의 정책을 학습합니다.

4.1 강화학습의 구성 요소

  • 상태 (State): 에이전트가 현재의 환경을 나타내는 정보입니다.
  • 행동 (Action): 에이전트가 현재 상태에서 수행할 수 있는 작업입니다.
  • 보상 (Reward): 에이전트가 행동을 수행한 후에 환경으로부터 받는 피드백입니다.
  • 정책 (Policy): 에이전트가 각 상태에서 취할 행동의 확률 분포를 나타냅니다.

4.2 강화학습 알고리즘

  • Q-Learning: 가치 기반 방법으로, Q 값을 학습하여 최적의 정책을 유도합니다.
  • 정책 경사 방법 (Policy Gradient): 정책을 직접 학습하는 방법입니다.
  • Actor-Critic: 가치 함수와 정책을 동시에 학습하는 방법입니다.

4.3 파이토치를 활용한 강화학습 구현

간단한 강화학습 구현을 위해 OpenAI의 Gym 라이브러리를 사용할 것입니다. 여기서는 CartPole 환경을 다루겠습니다.

4.3.1 Gym 환경 설정

import gym

# Gym 환경 생성
env = gym.make('CartPole-v1')  # CartPole 환경

4.3.2 DQN 모델 정의

class DQN(nn.Module):
    def __init__(self, input_size, num_actions):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(input_size, 24)
        self.fc2 = nn.Linear(24, 24)
        self.fc3 = nn.Linear(24, num_actions)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

4.3.3 모델 훈련

def train_dqn(env, num_episodes):
    model = DQN(input_size=env.observation_space.shape[0], num_actions=env.action_space.n)
    optimizer = optim.Adam(model.parameters())
    criterion = nn.MSELoss()

    for episode in range(num_episodes):
        state = env.reset()
        state = torch.FloatTensor(state)
        done = False
        total_reward = 0

        while not done:
            q_values = model(state)
            action = torch.argmax(q_values).item()  # 또는 epsilon-greedy 정책 사용

            next_state, reward, done, _ = env.step(action)
            next_state = torch.FloatTensor(next_state)

            total_reward += reward

            # DQN 업데이트 로직 추가 필요

            state = next_state

        print(f'Episode {episode+1}, Total Reward: {total_reward}')  

    return model

# DQN 훈련 시작
train_dqn(env, num_episodes=1000)

5. 결론

이번 포스팅에서는 GAN과 강화학습의 기본 개념 및 파이토치를 활용한 구현 방법에 대해 알아보았습니다. GAN은 데이터 생성에 매우 유용한 모델이고, 강화학습은 에이전트가 최적의 정책을 학습하도록 돕는 기법입니다. 이러한 기술들은 다양한 분야에서 응용될 수 있으며, 앞으로의 연구와 발전이 기대됩니다.

6. 참고 자료