Deep Learning PyTorch Course: Fine-tuning Techniques

Fine-tuning is a deep learning technique that adapts a pre-trained model to a specific task in order to improve its performance. Because it reuses knowledge already learned on a large dataset, it is an efficient way to save the time and resources that would otherwise be spent on data collection and training, and it is used in fields such as image recognition and natural language processing.

1. Overview of Fine-tuning Techniques

Fine-tuning reuses the weights of a pre-trained model to improve predictive performance on a new dataset. The method proceeds through the following steps:

  • Select a pre-trained model
  • Modify the model structure to fit the new task (e.g. replace the output layer)
  • Retrain the model on the new dataset
  • Evaluate and optimize the model

2. Fine-tuning in PyTorch

PyTorch provides various tools and libraries that make it easy to implement fine-tuning functionality. The main steps are as follows:

  • Load a pre-trained model
  • Freeze or modify some layers of the model
  • Train the model using a new dataset
  • Save and evaluate the model

2.1 Loading a Pre-trained Model

In PyTorch, you can easily load a pre-trained model using the torchvision library. Here, we use the ResNet18 model as an example.

import torch
import torch.nn as nn
import torchvision.models as models

# Load ResNet18 model
model = models.resnet18(pretrained=True)
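
Note that recent torchvision releases deprecate the pretrained argument in favor of a weights argument. Assuming a reasonably current torchvision version, the equivalent call is:

# Newer torchvision API (the pretrained flag is deprecated there)
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)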

2.2 Freezing or Modifying Some Layers of the Model

In general, during fine-tuning the last layer of the model is replaced to match the number of classes in the new task. For instance, if you move from the 1000 ImageNet classes to a 10-class image classification problem, the final fully connected layer must be swapped out.

# Replace the existing last layer with a new layer
num_ftrs = model.fc.in_features  # Input feature size of the original fc layer
model.fc = nn.Linear(num_ftrs, 10)  # New output layer with 10 classes
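
The section title also mentions freezing layers. If you want to keep the pre-trained backbone fixed and train only the new output layer, the following minimal sketch (an optional variation, not required for the rest of this course) freezes the backbone before replacing the last layer:

# Freeze every parameter of the pre-trained backbone...
for param in model.parameters():
    param.requires_grad = False

# ...then replace the last layer; a freshly created nn.Linear has
# requires_grad=True by default, so only this layer will be trained
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)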

2.3 Training the Model Using a New Dataset

A data loader is set up to train the model on the new dataset.

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Set up data transformations (the normalization values are the ImageNet
# statistics that the pre-trained ResNet18 expects)
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Load the dataset (FakeData generates random sample images for demonstration)
train_dataset = datasets.FakeData(transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
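
In practice you would replace FakeData with your own data. For an image-classification dataset arranged as one folder per class, a minimal sketch looks like this (the directory path is a placeholder, not part of this course's files):

# Sketch: load a real dataset laid out as root/class_name/image.jpg
train_dataset = datasets.ImageFolder('data/train', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)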

2.4 Writing the Training Loop

Next, write the training loop that defines the training process: move the model to the available device, set up the loss function and optimizer, and iterate over the data loader.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Training loop
for epoch in range(10):  # Change the number of epochs if needed
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f'Epoch {epoch + 1}, Loss: {running_loss / len(train_loader):.4f}')
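
One detail worth noting: the optimizer above updates every parameter of the model. If you froze the backbone as sketched in section 2.2, you would typically pass only the trainable parameters to the optimizer, for example:

# Sketch: optimize only the parameters that still require gradients
# (relevant only if some layers were frozen as in section 2.2)
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable_params, lr=0.001, momentum=0.9)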

3. Evaluating the Fine-tuning Results

Once training is complete, the model can be evaluated. Typically, a validation dataset is used to assess the model’s performance.

# Load and evaluate the validation dataset
val_dataset = datasets.FakeData(transform=transform)  # Using sample data
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

model.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in val_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the model: {100 * correct / total:.2f}%')
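
The step list in section 2 also mentions saving the model. A minimal sketch using PyTorch's standard state_dict mechanism (the file name here is an arbitrary example):

# Save the fine-tuned weights
torch.save(model.state_dict(), 'finetuned_resnet18.pth')

# To reuse them later, rebuild the same architecture and load the weights
# model.load_state_dict(torch.load('finetuned_resnet18.pth'))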

4. Conclusion

In deep learning, fine-tuning is a crucial technique for using data efficiently and getting the most out of pre-trained models. PyTorch offers tools and libraries that make it easy to fine-tune such models. Understanding and applying this process is an important step toward using deep learning in practice.

I hope this course has been helpful. Thank you!