Deep Learning PyTorch Course, ResNet

In the field of deep learning, the Residual Network, abbreviated as ResNet, has become one of the most important architectures. ResNet was proposed by Kaiming He and his colleagues in 2015 and provides a way to effectively increase the depth of deep learning models. Across modern computer vision problems, ResNet is considered one of the main drivers of performance improvement.

1. Overview of ResNet

ResNet is a neural network based on the “Residual Learning” framework. Traditionally, deep neural networks (DNNs) tend to suffer from performance degradation as they become deeper. This is primarily due to the vanishing gradient problem: as the network grows deeper, gradients shrink during backpropagation and the early layers effectively stop learning.

To address this issue, ResNet introduced residual connections. A residual (skip) connection adds a block’s input directly to its output, so the stacked layers only need to learn the residual F(x) rather than the full mapping H(x) = F(x) + x. This makes much deeper networks effectively trainable.
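
Conceptually, a residual block computes output = F(x) + x. The following minimal PyTorch-style sketch (the class name and layer choices are illustrative only, not part of the original architecture) shows the idea:

import torch
import torch.nn as nn

class TinyResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        # F(x): the residual that the stacked layers have to learn
        self.f = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
        )

    def forward(self, x):
        return torch.relu(self.f(x) + x)  # skip connection: add the input back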

2. Structure of ResNet

ResNet comes in models of various depths, typically denoted “ResNet18”, “ResNet50”, “ResNet101”, “ResNet152”, and so on. The number refers to the count of weighted layers (convolutional and fully connected) in the network. For example, ResNet18 consists of an initial 7×7 convolution, four stages of two BasicBlocks with two 3×3 convolutions each, and a final fully connected layer: 1 + 4×2×2 + 1 = 18.

2.1 Basic Block Composition

A basic ResNet block is composed of the following components:

  • Convolution Layer
  • Batch Normalization
  • ReLU Activation Function
  • Residual Connection

The structure of a typical ResNet block looks as follows (sketched here with the Keras functional API for brevity; the PyTorch version follows in the next section):


from tensorflow.keras.layers import Conv2D, BatchNormalization, ReLU, Add

def resnet_block(input_tensor, filters, kernel_size=3, stride=1):
    # Main path: Conv -> BN -> ReLU -> Conv -> BN
    x = Conv2D(filters, kernel_size=kernel_size, strides=stride, padding='same')(input_tensor)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    x = Conv2D(filters, kernel_size=kernel_size, strides=1, padding='same')(x)  # stride only on the first conv
    x = BatchNormalization()(x)

    # Projection shortcut: a 1x1 conv keeps the shapes compatible for the addition
    shortcut = Conv2D(filters, kernel_size=1, strides=stride, padding='same')(input_tensor)

    x = Add()([x, shortcut])  # residual connection
    x = ReLU()(x)

    return x

3. Implementing ResNet with PyTorch

Now, let’s implement ResNet using PyTorch. First, we need to install the required libraries:

pip install torch torchvision
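
If the installation succeeded, a quick check like the following (optional, not part of the original steps) should print the installed versions:

import torch
import torchvision

print(torch.__version__, torchvision.__version__)
print(torch.cuda.is_available())  # True if a CUDA-capable GPU can be used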

Next, we will implement the basic ResNet model:


import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

class BasicBlock(nn.Module):
    expansion = 1  # channel expansion factor (1 here; the Bottleneck block used in ResNet50 and deeper uses 4)

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1000):
        super(ResNet, self).__init__()
        self.in_channels = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, blocks, stride=1):
        downsample = None
        # A 1x1 convolution adjusts the identity path whenever the spatial size or channel count changes
        if stride != 1 or self.in_channels != out_channels * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * block.expansion),
            )
        layers = []
        layers.append(block(self.in_channels, out_channels, stride, downsample))
        self.in_channels = out_channels * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.in_channels, out_channels))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

def resnet18(num_classes=1000):
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)
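
Before training, it can be helpful to sanity-check the model with a dummy input. This is a small sketch using the classes defined above:

# Quick sanity check: a random batch of two 224x224 RGB images
model = resnet18(num_classes=10)
dummy = torch.randn(2, 3, 224, 224)
print(model(dummy).shape)  # expected: torch.Size([2, 10])

For a ResNet34-style model, only the per-stage block counts change, e.g. ResNet(BasicBlock, [3, 4, 6, 3], num_classes); ResNet50 and deeper additionally require a Bottleneck block, which this tutorial does not define.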

3.1 Preparing to Train the Model

To train the ResNet model, we need to prepare the dataset and set up the optimizer and loss function.


# Preparing the dataset
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

# Initializing the model
model = resnet18(num_classes=10)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
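
The code above (and the loop below) runs on the CPU as written. If a GPU is available, you would typically move the model and each batch to the device; a minimal sketch, not part of the original code:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
# Inside the training loop, each batch would then be moved as well:
# images, labels = images.to(device), labels.to(device)

In practice you would also usually add transforms.Normalize(...) to the transform pipeline; it is omitted here to keep the example short.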

3.2 Training Phase

Now we are ready to train the model:


for epoch in range(10):  # number of training epochs
    model.train()  # Switching the model to training mode
    for images, labels in train_loader:
        optimizer.zero_grad()  # Resetting gradients
        outputs = model(images)  # Model prediction
        loss = criterion(outputs, labels)  # Calculating loss
        loss.backward()  # Backpropagation
        optimizer.step()  # Updating parameters

    print(f'Epoch [{epoch+1}/10], Loss: {loss.item():.4f}')  # loss of the last batch in the epoch
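
After training, you would normally evaluate on the CIFAR-10 test split. The following is a minimal sketch under the same assumptions as above (test_dataset and test_loader are introduced here, not in the original code):

test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

model.eval()  # switch to evaluation mode (affects BatchNorm)
correct, total = 0, 0
with torch.no_grad():  # no gradients needed for evaluation
    for images, labels in test_loader:
        outputs = model(images)
        predicted = outputs.argmax(dim=1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

print(f'Test accuracy: {100 * correct / total:.2f}%')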

4. Applications of ResNet

ResNet can be used for a wide range of computer vision tasks. It is widely applied to image classification, object detection, segmentation, and more complex vision problems, and ResNet-based backbones appear in many of the image and video systems built by companies such as Google and Facebook.
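
One common pattern is to take an ImageNet-pretrained ResNet from torchvision and fine-tune it on a new task. A small sketch follows; note that the exact weights argument depends on your torchvision version (the string below assumes torchvision 0.13 or newer):

# Load an ImageNet-pretrained ResNet50 as a backbone
backbone = models.resnet50(weights='IMAGENET1K_V1')

# Replace the final classifier for a new task, e.g. 10 classes
backbone.fc = nn.Linear(backbone.fc.in_features, 10)

# Optionally freeze everything except the new classifier for quick fine-tuning
for name, param in backbone.named_parameters():
    param.requires_grad = name.startswith('fc')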

5. Conclusion

In this tutorial, we covered the basic concepts and architecture of ResNet and implemented a basic ResNet model in PyTorch. By leveraging residual learning, ResNet offers a flexible way to build deeper models and achieve better performance, and it continues to inspire many researchers and developers.

From here, you can study more advanced ResNet variants, hyperparameter tuning techniques, and data augmentation methods to further improve the model; a small augmentation example is sketched below.
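
As a small illustration of the data augmentation mentioned above, the training transform from section 3.1 could be extended like this (a sketch, assuming the same 224×224 input size):

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),   # random crop and resize
    transforms.RandomHorizontalFlip(),   # random left-right flip
    transforms.ToTensor(),
])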

6. References

  • Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun, “Deep Residual Learning for Image Recognition”, arXiv:1512.03385, 2015.