In the field of deep learning, the Residual Network, abbreviated as ResNet, has become one of the most influential architectures. ResNet was proposed by Kaiming He and his colleagues in 2015 and provides a way to effectively train much deeper models. In many modern computer vision problems, ResNet-style networks are considered one of the main drivers of performance improvement.
1. Overview of ResNet
ResNet is a neural network based on the “Residual Learning” framework. Traditionally, deep neural networks (DNNs) tend to suffer from performance degradation as they become deeper: beyond a certain depth, accuracy saturates and then gets worse. This is commonly attributed to optimization difficulties such as the vanishing gradient problem, where gradients shrink as they are propagated backward through many layers, making very deep networks hard to train.
To address this issue, ResNet introduced residual (skip) connections. A residual connection adds a block’s input directly to its output, so instead of learning a desired mapping H(x) from scratch, the block only has to learn the residual F(x) = H(x) - x and produces F(x) + x. Because the identity path carries the signal (and the gradient) through unchanged, this approach allows for the effective training of much deeper networks.
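In rough pseudocode, a residual block therefore looks like the following minimal sketch, where weight_layers stands for the small stack of convolution and normalization layers inside the block and relu for the activation (both are placeholders, not specific library calls):

def residual_block(x, weight_layers, relu):
    fx = weight_layers(x)   # F(x): the residual mapping the block learns
    return relu(fx + x)     # H(x) = F(x) + x, thanks to the skip connection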
2. Structure of ResNet
ResNet can be built at various depths, typically denoted as “ResNet18”, “ResNet34”, “ResNet50”, “ResNet101”, “ResNet152”, and so on. The number indicates the total number of weight layers (convolutional and fully connected layers) in the network.
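If you only need one of these standard depths, torchvision (which we use later in this tutorial) ships prebuilt constructors. Note that the weights argument shown below is the torchvision 0.13+ API; older versions use pretrained=True/False instead:

import torchvision.models as models

resnet50 = models.resnet50(weights=None)    # randomly initialized ResNet-50
resnet101 = models.resnet101(weights=None)  # ResNet-101
resnet152 = models.resnet152(weights=None)  # ResNet-152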
2.1 Basic Block Composition
A typical ResNet block is composed of the following components:
- Convolution Layer
- Batch Normalization
- ReLU Activation Function
- Residual Connection
The structure of a typical ResNet block, written here in Keras-style (tf.keras) code, is as follows:
from tensorflow.keras.layers import Conv2D, BatchNormalization, ReLU, Add

def resnet_block(input_tensor, filters, kernel_size=3, stride=1):
    # Main path: two conv layers, each followed by batch normalization
    x = Conv2D(filters, kernel_size=kernel_size, strides=stride, padding='same')(input_tensor)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    # The second convolution keeps stride 1 so the block downsamples at most once
    x = Conv2D(filters, kernel_size=kernel_size, strides=1, padding='same')(x)
    x = BatchNormalization()(x)
    # Projection shortcut: a 1x1 convolution matches the channel count (and stride) of the main path
    shortcut = Conv2D(filters, kernel_size=1, strides=stride, padding='same')(input_tensor)
    # Residual connection: add the shortcut to the main path, then apply the activation
    x = Add()([x, shortcut])
    x = ReLU()(x)
    return x
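As a quick usage sketch, the block can be wired into a tiny functional-API model (this assumes the tf.keras imports shown above, plus Input and Model):

from tensorflow.keras import Input, Model

inputs = Input(shape=(32, 32, 3))            # e.g., CIFAR-10-sized RGB images
outputs = resnet_block(inputs, filters=64)   # one residual block with 64 filters
model = Model(inputs, outputs)
model.summary()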
3. Implementing ResNet with PyTorch
Now, let’s implement ResNet using PyTorch. First, we need to install the required libraries:
pip install torch torchvision
Next, we will implement the basic ResNet model:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
class BasicBlock(nn.Module):
    # The basic residual block used by ResNet-18/34; it does not expand the channel count
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # Main path: 3x3 conv -> BN -> ReLU -> 3x3 conv -> BN
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # Optional projection for the shortcut when the spatial size or channel count changes
        self.downsample = downsample

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity   # residual connection: add the (possibly projected) input
        out = self.relu(out)
        return out
class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1000):
        super(ResNet, self).__init__()
        self.in_channels = 64
        # Stem: 7x7 conv with stride 2, followed by 3x3 max pooling
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Four stages of residual blocks; each stage after the first halves the spatial size
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        # Global average pooling followed by the classification head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, blocks, stride=1):
        # Build one stage: the first block may downsample and/or change channels,
        # so its shortcut needs a 1x1 projection; the remaining blocks keep the shape
        downsample = None
        if stride != 1 or self.in_channels != out_channels * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * block.expansion),
            )
        layers = []
        layers.append(block(self.in_channels, out_channels, stride, downsample))
        self.in_channels = out_channels * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.in_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)   # flatten to (batch_size, 512 * expansion)
        x = self.fc(x)
        return x


def resnet18(num_classes=1000):
    # ResNet-18: two BasicBlocks in each of the four stages
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)
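As a quick sanity check, you can pass a dummy batch through the model and confirm that the output has shape (batch_size, num_classes):

model = resnet18(num_classes=10)
dummy = torch.randn(2, 3, 224, 224)  # a batch of two 224x224 RGB images
out = model(dummy)
print(out.shape)                     # expected: torch.Size([2, 10])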
3.1 Preparing to Train the Model
To train the ResNet model, we need to prepare the dataset and set up the optimizer and loss function.
# Preparing the dataset
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # CIFAR-10 images are 32x32; resize them to the 224x224 input used here
    transforms.ToTensor(),
])
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
# Initializing the model
model = resnet18(num_classes=10)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
3.2 Training Phase
Now we are ready to train the model:
for epoch in range(10):  # Setting epochs
    model.train()  # Switching the model to training mode
    for images, labels in train_loader:
        optimizer.zero_grad()  # Resetting gradients
        outputs = model(images)  # Model prediction
        loss = criterion(outputs, labels)  # Calculating loss
        loss.backward()  # Backpropagation
        optimizer.step()  # Updating parameters
    print(f'Epoch [{epoch+1}/10], Loss: {loss.item():.4f}')  # Loss of the last batch in this epoch
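After training, the model can be evaluated on the CIFAR-10 test split. The following is a minimal evaluation sketch that reuses the transform defined above:

test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

model.eval()  # evaluation mode: BatchNorm uses its running statistics
correct, total = 0, 0
with torch.no_grad():  # no gradients needed during evaluation
    for images, labels in test_loader:
        outputs = model(images)
        predicted = outputs.argmax(dim=1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)
print(f'Test accuracy: {100.0 * correct / total:.2f}%')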
4. Applications of ResNet
ResNet can be used for a wide range of computer vision tasks. It is widely applied in image classification, object detection, semantic segmentation, and other, more complex vision problems, and ResNet-style backbones appear in many image and video systems at companies such as Google and Facebook.
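For instance, the classification head can be stripped off so that the network acts as a feature extractor (backbone) for such tasks. Below is a minimal sketch that reuses the model defined in Section 3:

# Drop the final fully connected layer; everything up to global average pooling becomes the backbone.
backbone = nn.Sequential(*list(model.children())[:-1])
backbone.eval()  # evaluation mode for the BatchNorm layers
with torch.no_grad():
    features = backbone(torch.randn(1, 3, 224, 224))
features = torch.flatten(features, 1)
print(features.shape)  # expected: torch.Size([1, 512]) for the BasicBlock-based ResNet-18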
5. Conclusion
In this tutorial, we covered the basic concepts and architecture of ResNet and implemented a basic ResNet model using PyTorch. By leveraging residual learning, ResNet makes it practical to build much deeper models and to achieve better performance, and it has inspired many researchers and developers.
From here, you can study more advanced ResNet variants, hyperparameter tuning techniques, and data augmentation methods to further improve the model.
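For example, a commonly used CIFAR-10 training transform with augmentation might look like the sketch below; the normalization statistics are the widely quoted CIFAR-10 channel means and standard deviations, so treat them as values to verify rather than ground truth:

train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),      # random crop with 4-pixel padding
    transforms.RandomHorizontalFlip(),         # flip half of the images horizontally
    transforms.Resize((224, 224)),             # match the input size used in this tutorial
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])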