Deep Learning with PyTorch: GANs and a First LSTM Network

Deep learning is one of the most prominent technologies in artificial intelligence today and is used in a wide range of applications. Two architectures stand out: GANs (Generative Adversarial Networks) for data generation and LSTMs (Long Short-Term Memory networks) for processing time series data. In this article, we explore both in detail and implement them with the PyTorch framework.

1. Overview of GAN (Generative Adversarial Network)

GAN is a generative model proposed by Ian Goodfellow and his colleagues in 2014. A GAN consists of two neural networks, a Generator and a Discriminator. The Generator produces fake data from random noise, and the Discriminator tries to tell real data from fake data. The two networks are trained against each other, so each one's improvement forces the other to improve.

The process is as follows:

  • The Generator takes random noise as input and generates fake data.
  • The Discriminator receives the generated data and real data and classifies them as real or fake.
  • The Discriminator is trained to classify real data as real and generated data as fake, while the Generator is trained to produce data that the Discriminator classifies as real.
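
Formally, the original GAN paper describes this competition as a two-player minimax game over a value function V(D, G). Here D(x) is the Discriminator's estimated probability that x is real, G(z) is the Generator's output for noise z, p_data is the data distribution, and p_z is the noise distribution:

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x)\right]
  + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z))\right)\right]
```

Maximizing V over D is the same as minimizing binary cross-entropy with target 1 for real data and 0 for fake data. For the Generator, the same paper notes that log(1 - D(G(z))) saturates early in training, when the Discriminator rejects fakes easily, so in practice the Generator is trained to maximize log D(G(z)) instead. Labeling fake images as real in the BCE loss, as the training loop in section 4.5 does, implements this variant.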

2. Overview of LSTM (Long Short-Term Memory) Network

LSTM is a type of RNN (Recurrent Neural Network) designed for sequential data such as time series. Each LSTM cell maintains a cell state (its memory) and uses gates to decide what to store, what to forget, and what to output. This gating lets LSTMs retain information over long sequences, where plain RNNs struggle because of vanishing gradients.

The basic components of LSTM are as follows:

  • Input Gate: Determines how much new candidate information to write into the cell state.
  • Forget Gate: Determines how much of the existing cell state to discard.
  • Output Gate: Determines how much of the cell state to expose as the hidden state (the cell's output).
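
The gates above can be written out directly. The following is a from-scratch NumPy sketch of a single LSTM time step, not PyTorch's internal implementation. It also includes the candidate values (often written g), which the input gate scales before they are added to the cell state. The weight shapes and random initialization are illustrative assumptions. The gate order (input, forget, candidate, output) follows the layout PyTorch documents for nn.LSTM weights, and a single bias vector stands in for the two bias vectors PyTorch keeps.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_cell_step(x, h_prev, c_prev, W, U, b):
    """One LSTM time step (from-scratch sketch).

    x: (input_size,); h_prev, c_prev: (hidden_size,)
    W: (4*hidden_size, input_size); U: (4*hidden_size, hidden_size); b: (4*hidden_size,)
    Rows are stacked as input, forget, candidate, output.
    """
    z = W @ x + U @ h_prev + b
    i, f, g, o = np.split(z, 4)
    i = sigmoid(i)          # input gate: how much of the candidate to write
    f = sigmoid(f)          # forget gate: how much of the old cell state to keep
    g = np.tanh(g)          # candidate values for the cell state
    o = sigmoid(o)          # output gate: how much of the cell state to expose
    c = f * c_prev + i * g  # new cell state
    h = o * np.tanh(c)      # new hidden state
    return h, c

rng = np.random.default_rng(0)
input_size, hidden_size = 3, 4
W = rng.normal(scale=0.1, size=(4 * hidden_size, input_size))
U = rng.normal(scale=0.1, size=(4 * hidden_size, hidden_size))
b = np.zeros(4 * hidden_size)

h = np.zeros(hidden_size)
c = np.zeros(hidden_size)
for x in rng.normal(size=(5, input_size)):  # a toy sequence of 5 time steps
    h, c = lstm_cell_step(x, h, c, W, U, b)
print(h.shape, c.shape)  # (4,) (4,)
```

The cell state c is updated only by elementwise scaling and addition. This additive path is what helps gradients survive across many time steps.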

3. Introduction to PyTorch

PyTorch is an open-source machine learning framework originally developed by Facebook's AI Research lab (now Meta AI). It builds computation graphs dynamically as code runs, which makes neural networks easy to construct, debug, and train. It is widely used in fields such as computer vision and natural language processing.
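
A dynamic ("define-by-run") graph is recorded as the Python code executes, so branches and loops can change it from one call to the next. The following minimal sketch uses a toy function f, which is hypothetical and for illustration only:

```python
import torch

def f(x):
    # Ordinary Python control flow decides which operations run, and autograd
    # records exactly the operations that were executed on this call.
    if x.sum() > 0:
        return (x ** 2).sum()
    return (x ** 3).sum()

x_pos = torch.tensor([1.0, 2.0], requires_grad=True)
f(x_pos).backward()
print(x_pos.grad)  # gradient of sum(x**2) is 2x, i.e. [2., 4.]

x_neg = torch.tensor([-1.0, -2.0], requires_grad=True)
f(x_neg).backward()
print(x_neg.grad)  # gradient of sum(x**3) is 3x**2, i.e. [3., 12.]
```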

4. Implementing GAN with PyTorch

4.1 Environment Setup

Install PyTorch and the necessary packages. You can install them using pip as follows.

pip install torch torchvision matplotlib

4.2 Preparing the Dataset

Let’s implement a GAN to generate handwritten digits using the MNIST dataset as an example.


import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

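# Load MNIST: ToTensor() scales pixels to [0, 1], and Normalize maps them to [-1, 1] to match the Generator's Tanh output range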
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
mnist = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(mnist, batch_size=64, shuffle=True)
    

4.3 Defining Generator and Discriminator

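Both networks are simple fully connected (multilayer perceptron) models. The Generator maps a 100-dimensional noise vector to a 28x28 image with pixel values in [-1, 1] (via Tanh). The Discriminator maps an image to a single probability that it is real (via Sigmoid).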


import torch.nn as nn

# Generator Model
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 28*28),
            nn.Tanh()
        )

    def forward(self, z):
        return self.model(z).view(-1, 1, 28, 28)  # Reshape to image format

# Discriminator Model
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28*28, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        return self.model(img)
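
Before training, a quick shape check confirms that the two networks fit together. So that it runs on its own, the sketch below rebuilds the same layer stacks as plain nn.Sequential objects, then counts the trainable parameters of each.

```python
import torch
import torch.nn as nn

latent_dim = 100  # noise size used by the Generator above

# Same layer stacks as Generator.model and Discriminator.model, rebuilt here
# so this check runs on its own.
gen_net = nn.Sequential(
    nn.Linear(latent_dim, 256), nn.ReLU(),
    nn.Linear(256, 512), nn.ReLU(),
    nn.Linear(512, 1024), nn.ReLU(),
    nn.Linear(1024, 28 * 28), nn.Tanh(),
)
disc_net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 512), nn.LeakyReLU(0.2),
    nn.Linear(512, 256), nn.LeakyReLU(0.2),
    nn.Linear(256, 1), nn.Sigmoid(),
)

z = torch.randn(16, latent_dim)
fake = gen_net(z).view(-1, 1, 28, 28)  # same reshape as Generator.forward
score = disc_net(fake)

print(fake.shape)   # torch.Size([16, 1, 28, 28])
print(score.shape)  # torch.Size([16, 1])

def count_params(m):
    return sum(p.numel() for p in m.parameters())

print(count_params(gen_net), count_params(disc_net))  # 1486352 533505
```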
    

4.4 Setting Loss Function and Optimizer

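Because the Discriminator outputs a probability (via Sigmoid), the loss function is Binary Cross Entropy (BCE). Both networks use the Adam optimizer with a learning rate of 0.0002 and beta1 = 0.5, settings commonly used for GANs since the DCGAN paper.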


import torch.optim as optim

# Initialize models
generator = Generator()
discriminator = Discriminator()

# Set loss function and optimizer
criterion = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
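
The sketch below shows what nn.BCELoss computes by evaluating it on two hand-picked Discriminator outputs (toy values, not from training) and comparing the result with the formula written out by hand. It also shows nn.BCEWithLogitsLoss, which applies the sigmoid internally and is more numerically stable. Switching to it would mean removing the final nn.Sigmoid() from the Discriminator.

```python
import torch
import torch.nn as nn

criterion = nn.BCELoss()

p = torch.tensor([[0.9], [0.2]])  # Discriminator outputs (probabilities)
y = torch.tensor([[1.0], [0.0]])  # targets: 1 = real, 0 = fake

# BCE = -[y * log(p) + (1 - y) * log(1 - p)], averaged over the batch
manual = -(y * torch.log(p) + (1 - y) * torch.log(1 - p)).mean()
print(criterion(p, y).item(), manual.item())  # both about 0.1643

# Equivalent form that takes raw scores (logits) instead of probabilities
logits = torch.log(p / (1 - p))
print(nn.BCEWithLogitsLoss()(logits, y).item())  # also about 0.1643
```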
    

4.5 GAN Training Loop

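Now we can train the GAN. In each iteration, the Discriminator is first updated on a batch of real images and a batch of generated images. The Generator is then updated so that the Discriminator classifies its generated images as real.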


num_epochs = 50
for epoch in range(num_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        # Create labels for real and fake data
        real_labels = torch.ones(imgs.size(0), 1)
        fake_labels = torch.zeros(imgs.size(0), 1)

        # Train Discriminator: real images should score 1, fake images 0
        optimizer_D.zero_grad()
        outputs = discriminator(imgs)
        d_loss_real = criterion(outputs, real_labels)

        z = torch.randn(imgs.size(0), 100)
        fake_imgs = generator(z)
        outputs = discriminator(fake_imgs.detach())
        d_loss_fake = criterion(outputs, fake_labels)

        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_D.step()

        # Train Generator: label the fakes as real so G learns to fool D
        optimizer_G.zero_grad()
        outputs = discriminator(fake_imgs)
        g_loss = criterion(outputs, real_labels)

        g_loss.backward()
        optimizer_G.step()

    print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}')
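
Calling fake_imgs.detach() in the Discriminator step keeps that step from sending gradients into the Generator. The Generator step, by contrast, deliberately reuses fake_imgs without detaching. The tiny stand-in networks below use hypothetical sizes, not the MNIST models, to make this visible:

```python
import torch
import torch.nn as nn

# Tiny stand-ins for the Generator and Discriminator (hypothetical sizes)
G = nn.Linear(4, 3)
D = nn.Sequential(nn.Linear(3, 1), nn.Sigmoid())
criterion = nn.BCELoss()

z = torch.randn(8, 4)
fake = G(z)

# Discriminator step: detach() cuts the graph, so no gradient reaches G
d_loss = criterion(D(fake.detach()), torch.zeros(8, 1))
d_loss.backward()
grad_after_d_step = G.weight.grad
print(grad_after_d_step)  # None

# Generator step: without detach(), gradients flow back through D into G
g_loss = criterion(D(fake), torch.ones(8, 1))
g_loss.backward()
print(G.weight.grad is not None)  # True
```

The Generator step also accumulates gradients in the Discriminator's parameters. The training loop relies on optimizer_D.zero_grad() at the start of the next iteration to clear them.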
    

4.6 Visualizing Generated Images

After training, we visualize the images generated by the Generator.


import matplotlib.pyplot as plt

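# Switch the Generator to evaluation mode (good practice, although this model has no dropout or batch norm)
generator.eval()
with torch.no_grad():  # no gradients are needed for sampling
    z = torch.randn(64, 100)
    fake_imgs = generator(z).numpy()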

# Output images
plt.figure(figsize=(8, 8))
for i in range(64):
    plt.subplot(8, 8, i + 1)
    plt.imshow(fake_imgs[i][0], cmap='gray')
    plt.axis('off')
plt.show()
    

5. Implementing LSTM Network

5.1 Time Series Data Prediction Using LSTM

LSTMs are also well suited to time series prediction. As an example, we implement a simple LSTM model that predicts the next value of a sine wave from the values that precede it.

5.2 Preparing the Data

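We generate sine function data and split it into overlapping windows: each input is a sequence of 10 consecutive values, and its label is the value that immediately follows.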


import numpy as np

# Generate data
time = np.arange(0, 100, 0.1)
data = np.sin(time)

# Preprocess data for LSTM input
def create_sequences(data, seq_length):
    sequences = []
    labels = []
    for i in range(len(data) - seq_length):
        sequences.append(data[i:i+seq_length])
        labels.append(data[i+seq_length])
    return np.array(sequences), np.array(labels)

seq_length = 10
X, y = create_sequences(data, seq_length)
X = X.reshape((X.shape[0], X.shape[1], 1))  # (samples, seq_length, features=1), as expected with batch_first=True
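
Running create_sequences on a tiny array shows exactly what it produces. The sketch below also includes a vectorized alternative based on NumPy's sliding_window_view (available since NumPy 1.20). The name create_sequences_vectorized is a hypothetical helper, not part of the original code.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def create_sequences(data, seq_length):
    # Loop version from the article: window i is data[i:i+seq_length],
    # and its label is the value right after the window.
    sequences, labels = [], []
    for i in range(len(data) - seq_length):
        sequences.append(data[i:i + seq_length])
        labels.append(data[i + seq_length])
    return np.array(sequences), np.array(labels)

def create_sequences_vectorized(data, seq_length):
    # sliding_window_view yields len(data) - seq_length + 1 windows; the last
    # one has no following value to predict, so it is dropped.
    windows = sliding_window_view(data, seq_length)[:-1]
    return windows, data[seq_length:]

toy = np.arange(6)
X_toy, y_toy = create_sequences(toy, 3)
print(X_toy)  # rows: [0 1 2], [1 2 3], [2 3 4]
print(y_toy)  # [3 4 5]
```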
    

5.3 Defining the LSTM Model

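The model stacks two LSTM layers with 50 hidden units each. A linear layer then maps the final hidden state of the top layer to a single predicted value.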


class LSTMModel(nn.Module):
    def __init__(self):
        super(LSTMModel, self).__init__()
        self.lstm = nn.LSTM(input_size=1, hidden_size=50, num_layers=2, batch_first=True)
        self.fc = nn.Linear(50, 1)
        
    def forward(self, x):
        out, (hn, cn) = self.lstm(x)
        out = self.fc(hn[-1])
        return out
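
The forward method uses hn[-1], the final hidden state of the top LSTM layer. The following check uses random input rather than the sine data. It shows the shapes nn.LSTM returns and confirms that, for a unidirectional LSTM, hn[-1] equals out[:, -1, :] (the top layer's output at the last time step), so either could be used.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=1, hidden_size=50, num_layers=2, batch_first=True)

x = torch.randn(8, 10, 1)  # (batch, seq_length, features)
out, (hn, cn) = lstm(x)

print(out.shape)  # torch.Size([8, 10, 50]): top layer's hidden state at every step
print(hn.shape)   # torch.Size([2, 8, 50]): final hidden state of each layer
print(cn.shape)   # torch.Size([2, 8, 50]): final cell state of each layer

# For a unidirectional LSTM, the last layer's final hidden state equals the
# output at the last time step.
print(torch.allclose(hn[-1], out[:, -1, :]))  # True
```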
    

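5.4 Setting Loss Function and Optimizer

Predicting the next value is a regression task, so we use Mean Squared Error (MSE) as the loss function and Adam as the optimizer.
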

model = LSTMModel()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
    

5.5 LSTM Training Loop

The training loop uses full-batch training: each epoch feeds all windows through the model at once and takes a single optimizer step.


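# Convert the NumPy arrays to float32 tensors once, outside the loop
X_tensor = torch.tensor(X, dtype=torch.float32)               # (samples, seq_length, 1)
y_tensor = torch.tensor(y, dtype=torch.float32).unsqueeze(1)  # (samples, 1)

num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    output = model(X_tensor)
    loss = criterion(output, y_tensor)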
    loss.backward()
    optimizer.step()

    if (epoch+1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
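
Full-batch training takes one optimizer step per epoch. A common alternative is mini-batch training with TensorDataset and DataLoader, which takes many smaller steps per epoch. The sketch below is self-contained, so it rebuilds the data and model. The batch size of 32 and the 5 epochs are arbitrary illustrative choices.

```python
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

torch.manual_seed(0)

# Same sine-wave windows as in section 5.2 (seq_length = 10)
data = np.sin(np.arange(0, 100, 0.1))
seq_length = 10
X = np.stack([data[i:i + seq_length] for i in range(len(data) - seq_length)])
y = data[seq_length:]

X_t = torch.tensor(X, dtype=torch.float32).unsqueeze(-1)  # (samples, seq_length, 1)
y_t = torch.tensor(y, dtype=torch.float32).unsqueeze(-1)  # (samples, 1)
loader = DataLoader(TensorDataset(X_t, y_t), batch_size=32, shuffle=True)

class LSTMModel(nn.Module):
    # Same architecture as section 5.3
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=1, hidden_size=50, num_layers=2, batch_first=True)
        self.fc = nn.Linear(50, 1)

    def forward(self, x):
        _, (hn, _) = self.lstm(x)
        return self.fc(hn[-1])

model = LSTMModel()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

epoch_losses = []
for epoch in range(5):
    model.train()
    total = 0.0
    for xb, yb in loader:
        optimizer.zero_grad()
        loss = criterion(model(xb), yb)
        loss.backward()
        optimizer.step()
        total += loss.item() * xb.size(0)  # weight by batch size for a true mean
    epoch_losses.append(total / len(loader.dataset))
    print(f'Epoch [{epoch + 1}/5], mean train loss: {epoch_losses[-1]:.4f}')
```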
    

5.6 Visualizing the Prediction Results

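After training is complete, we visualize the predictions. The model is evaluated on the same windows it was trained on, so the plot shows how well it fits the training data rather than how well it generalizes. Each prediction also uses the true preceding values as input.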


import matplotlib.pyplot as plt

# Prediction (one step ahead for every window)
model.eval()
with torch.no_grad():
    predictions = model(torch.FloatTensor(X)).numpy()

# Visualize prediction results
plt.figure(figsize=(12, 6))
plt.plot(data, label='Real Data')
plt.plot(np.arange(seq_length, seq_length + len(predictions)), predictions, label='Predicted Data', color='red')
plt.legend()
plt.show()
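
The predictions above are one step ahead: every input window contains true values. Forecasting beyond the end of the data requires feeding each prediction back into the window, and errors can accumulate as a result. The sketch below shows this recursive approach. The forecast helper is hypothetical, and the model here is an untrained copy of the section 5.3 architecture, so the numbers it prints are meaningless. With the trained model, you would call `forecast(model, data[-seq_length:], steps=50)`.

```python
import numpy as np
import torch
import torch.nn as nn

class LSTMModel(nn.Module):
    # Same architecture as section 5.3 (untrained here, for a runnable demo)
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=1, hidden_size=50, num_layers=2, batch_first=True)
        self.fc = nn.Linear(50, 1)

    def forward(self, x):
        _, (hn, _) = self.lstm(x)
        return self.fc(hn[-1])

def forecast(model, seed, steps):
    """Hypothetical helper: recursive multi-step forecast.

    seed: 1-D array with the last seq_length observed values.
    Each prediction is appended to the window and the oldest value dropped.
    """
    model.eval()
    window = torch.tensor(seed, dtype=torch.float32).view(1, -1, 1)  # (1, seq_length, 1)
    preds = []
    with torch.no_grad():
        for _ in range(steps):
            next_val = model(window)  # shape (1, 1)
            preds.append(next_val.item())
            window = torch.cat([window[:, 1:, :], next_val.view(1, 1, 1)], dim=1)
    return np.array(preds)

data = np.sin(np.arange(0, 100, 0.1))
seq_length = 10
future = forecast(LSTMModel(), data[-seq_length:], steps=50)
print(future.shape)  # (50,)
```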
    

6. Conclusion

In this post, we explored GANs and LSTMs. A GAN is a generative model for producing data such as images, while an LSTM is well suited to predicting time series data. Both are important in their respective fields, and both can be implemented in a few dozen lines of PyTorch. We encourage you to experiment with other applications of these models and apply them to your own projects.

7. References

Please refer to the materials below for a deeper understanding of the topics covered in this post.