1. Introduction
Transfer learning is a core technique in machine learning and deep learning. It refers to reusing the weights (parameters) a model learned on one task for a different but related task. When labeled samples are scarce or a new dataset must be handled quickly, transfer learning can save substantial time and computational resources.
2. The Necessity of Transfer Learning
Collecting data and training models from scratch is time-consuming and expensive. Reusing the knowledge captured by an existing model therefore makes new tasks far more efficient. For example, a model already trained for general image classification can be adapted to a related task such as plant classification.
3. The Concept of Transfer Learning
In general, transfer learning includes the following steps:
- Select a pre-trained model
- Load some or all weights from the existing model
- Retrain part of the model to fit new data (fine-tuning)
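A minimal sketch of these three steps in PyTorch follows. The ResNet18 backbone and the 10-class output layer here are illustrative choices on my part; Section 4 works through a complete example.
import torch
import torchvision.models as models

# Step 1: select a pre-trained model (ResNet18 trained on ImageNet)
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# Step 2: keep the loaded weights; optionally freeze them so that only
# the layers added afterwards are updated during training
for param in model.parameters():
    param.requires_grad = False

# Step 3: replace the final layer and retrain it on the new data (fine-tuning);
# newly created layers have requires_grad=True by default
model.fc = torch.nn.Linear(model.fc.in_features, 10)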
4. Transfer Learning in PyTorch
PyTorch provides a rich set of features that support transfer learning, making it easy to reuse complex pre-trained models. The following example walks through image classification with a pre-trained model from PyTorch's torchvision library.
4.1 Preparing the Dataset
This section explains how to load and preprocess the image dataset. We use CIFAR-10 here; its 32x32 images are resized to 224x224, the input size expected by ImageNet pre-trained models.
import torch
import torchvision
import torchvision.transforms as transforms
# Data preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Load CIFAR-10 dataset
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64,
                                         shuffle=False, num_workers=2)
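One detail worth adding to the transform above (an addition of mine, not part of the original code): ImageNet pre-trained models generally expect inputs normalized with the ImageNet channel statistics, so in practice the pipeline often looks like this:
# Sketch: the same pipeline with ImageNet normalization added, which
# pre-trained torchvision models generally expect
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],   # ImageNet channel means
                         std=[0.229, 0.224, 0.225]),   # ImageNet channel stds
])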
4.2 Loading the Pre-trained Model
This section describes how to load the pre-trained ResNet18 model from PyTorch’s torchvision.
import torchvision.models as models

# Load pre-trained model (torchvision >= 0.13; older versions used
# models.resnet18(pretrained=True) instead)
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# Replace the final fully connected layer to match the new task
num_classes = 10  # Number of classes in CIFAR-10
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
4.3 Defining the Loss Function and Optimizer
This section defines the loss function and optimization algorithm for the multi-class classification problem.
import torch.optim as optim
criterion = torch.nn.CrossEntropyLoss() # Loss function
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) # Optimization algorithm
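A common refinement when fine-tuning, sketched below with illustrative learning rates of my own choosing, is to give the randomly initialized head a larger learning rate than the pre-trained backbone, whose weights should change only slowly. If adopted, this would replace the optimizer defined above.
# Sketch: separate learning rates for the pre-trained backbone and the new head
# (the specific rates are illustrative assumptions)
backbone_params = [p for name, p in model.named_parameters()
                   if not name.startswith('fc')]
optimizer = optim.SGD([
    {'params': backbone_params, 'lr': 0.0001},        # small steps for pre-trained layers
    {'params': model.fc.parameters(), 'lr': 0.001},   # larger steps for the new layer
], momentum=0.9)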
4.4 Training the Model
This section presents the complete code for training the model.
# Model training
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

model.train()  # Put layers such as batch norm into training mode
for epoch in range(10):  # Number of epochs is adjustable
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass, loss computation, backward pass, parameter update
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 100 == 99:  # Print every 100 mini-batches
            print(f'[Epoch {epoch + 1}, Batch {i + 1}] loss: {running_loss / 100:.3f}')
            running_loss = 0.0

print('Finished Training')
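Once training finishes, the fine-tuned weights can be saved for later reuse. The snippet below is a minimal sketch; the file name is an illustrative assumption.
# Sketch: save the fine-tuned weights (file name is an assumption)
torch.save(model.state_dict(), 'resnet18_cifar10.pth')

# To reuse them later, rebuild the same architecture and load the weights:
# model.load_state_dict(torch.load('resnet18_cifar10.pth'))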
4.5 Evaluating the Model
This section describes how to evaluate the trained model. The accuracy of the model is measured using the test dataset.
# Model evaluation
model.eval()  # Put batch norm layers into evaluation mode
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)  # index of the highest score per sample
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct / total:.2f}%')
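Beyond overall accuracy, it can be informative to break the results down per class. The sketch below assumes the standard CIFAR-10 label order for the class names.
# Sketch: per-class accuracy (class names assume the standard CIFAR-10 label order)
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
class_correct = [0] * 10
class_total = [0] * 10
with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        predicted = model(images).argmax(dim=1)
        for label, pred in zip(labels, predicted):
            class_total[label.item()] += 1
            class_correct[label.item()] += int(label == pred)
for name, c, t in zip(classes, class_correct, class_total):
    print(f'Accuracy of {name}: {100 * c / t:.1f}%')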
5. Conclusion
In this course, we explored the concept of transfer learning in deep learning and how to implement it with PyTorch. Transfer learning is a key technique for achieving strong performance even when data is scarce, and the wide range of available pre-trained models makes it much easier to build high-performing systems. We hope transfer learning helps bring many more deep learning applications to life.