Deep Learning PyTorch Course: Performance Optimization Using Early Stopping

Overfitting is one of the most common problems encountered while training deep learning models. It refers to the phenomenon where a model fits the training data too closely, reducing its ability to generalize to new data. Researchers and engineers therefore use a variety of methods to prevent overfitting, and one of them is ‘early stopping.’

What is Early Stopping?

Early stopping is a technique that monitors the training process of a model and stops training when performance on the validation data stops improving. Even if the model keeps improving on the training data, continuing to train once validation performance has stalled or started to degrade mainly increases overfitting, so stopping at that point helps preserve the model’s ability to generalize.

How Early Stopping Works

Early stopping fundamentally observes the validation loss or validation accuracy during model training and stops the training if there is no improvement for a certain number of epochs, usually called the patience. Whenever validation performance improves, the current model parameters are saved, so the best checkpoint can be restored and used after training is completed.
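
This logic can also be packaged as a small reusable helper. Below is a minimal sketch of such a class; the name EarlyStopping and its interface are illustrative assumptions, not part of PyTorch or any library.

# A minimal early-stopping helper (illustrative sketch, not a built-in PyTorch class)
class EarlyStopping:
    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience      # epochs to wait without improvement
        self.min_delta = min_delta    # minimum decrease that counts as improvement
        self.best_loss = float('inf')
        self.counter = 0

    def step(self, val_loss):
        # Call once per epoch with the current validation loss;
        # returns True when training should stop.
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss  # improvement: remember it and reset the counter
            self.counter = 0
        else:
            self.counter += 1          # no improvement this epoch
        return self.counter >= self.patience

In a training loop, you would call step(avg_val_loss) once per epoch and break when it returns True. The full example below inlines the same logic directly instead of using a helper class.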

Implementing Early Stopping

Here, we will implement early stopping through a simple example of training an image classification model using PyTorch. In this example, we will use the MNIST dataset to train a model that recognizes handwritten digits.

Installing Required Libraries

pip install torch torchvision matplotlib numpy
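
To confirm the installation, a quick check like the following prints the installed PyTorch version and whether a CUDA GPU is available (the example in this tutorial also runs fine on CPU):

import torch

print(torch.__version__)          # installed PyTorch version
print(torch.cuda.is_available())  # True if a CUDA-capable GPU can be used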

Code Example

Below is a PyTorch code example with early stopping applied.

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Hyperparameter settings
input_size = 28 * 28  # MNIST image size
num_classes = 10  # Number of classes to classify
num_epochs = 20  # Total number of training epochs
batch_size = 100  # Batch size
learning_rate = 0.001  # Learning rate

# Load MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform, download=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# Define a simple neural network model
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(input_size, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = x.view(-1, input_size)  # Reshape image dimensions
        x = torch.relu(self.fc1(x))  # Activation function
        x = self.fc2(x)
        return x

# Initialize model, loss function, and optimizer
model = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Initialize variables for early stopping
best_loss = float('inf')
patience, trials = 5, 0  # Stop training if the validation loss does not improve for 5 consecutive epochs
train_losses, val_losses = [], []

# Training loop
for epoch in range(num_epochs):
    model.train()  # Switch model to training mode
    running_loss = 0.0

    for images, labels in train_loader:
        optimizer.zero_grad()  # Reset gradients
        outputs = model(images)  # Model predictions
        loss = criterion(outputs, labels)  # Calculate loss
        loss.backward()  # Compute gradients
        optimizer.step()  # Update weights

        running_loss += loss.item()

    avg_train_loss = running_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    # Validation step
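    # Note: for simplicity, the MNIST test set is reused as the validation set here;
    # in practice you would hold out a separate validation split from the training data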
    model.eval()  # Switch model to evaluation mode
    val_loss = 0.0

    with torch.no_grad():  # Disable gradient computation
        for images, labels in test_loader:
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()

    avg_val_loss = val_loss / len(test_loader)
    val_losses.append(avg_val_loss)

    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Valid Loss: {avg_val_loss:.4f}')

    # Early stopping logic
    if avg_val_loss < best_loss:
        best_loss = avg_val_loss
        trials = 0  # Reset performance improvement record
        torch.save(model.state_dict(), 'best_model.pth')  # Save best model
    else:
        trials += 1
        if trials >= patience:  # Stop training if no improvement for patience
            print("Early stopping...")
            break

# Evaluate performance on test data
model.load_state_dict(torch.load('best_model.pth'))  # Load best model
model.eval()  # Switch model to evaluation mode
correct, total = 0, 0

with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)  # Select the class with the highest score
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the model on the test images: {100 * correct / total:.2f}%')

Code Explanation

The above code walks through the full process of training a simple neural network model on the MNIST dataset. First, we import the necessary libraries and load the MNIST dataset. Then, we define a simple neural network composed of two fully connected layers.

After that, at each epoch, we compute the training loss and the validation loss, and the early stopping logic halts training once the validation loss has not improved for patience consecutive epochs. Finally, we reload the best saved model and evaluate its performance by calculating the accuracy on the test data.
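
Because matplotlib was installed earlier, you can also visualize the recorded train_losses and val_losses lists to see where the validation loss stops improving. The following is a minimal sketch that assumes the training code above has already run in the same session:

import matplotlib.pyplot as plt

plt.plot(train_losses, label='Train Loss')     # training loss per epoch
plt.plot(val_losses, label='Validation Loss')  # validation loss per epoch
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Training vs. Validation Loss')
plt.show()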

Conclusion

Early stopping is a useful technique for optimizing the performance of deep learning models. It helps prevent overfitting and yields a model that generalizes better to unseen data. In this tutorial, we demonstrated how to implement early stopping in PyTorch on the MNIST classification problem. We encourage you to apply early stopping to your own deep learning problems.