Deep Learning PyTorch Course, PSPNet

In this course, we will explore PSPNet (Pyramid Scene Parsing Network), an influential deep learning architecture for image segmentation. PSPNet was designed for semantic segmentation, where every pixel of an image is assigned a class label. It performs particularly well at this task and can be applied to a range of image recognition problems.

1. Overview of PSPNet

PSPNet was proposed by Zhao et al. in 2017. It predicts class probabilities for each pixel while taking the global context of the image into account. To do this, it combines information from several spatial scales through the Pyramid Pooling Module (PPM). Because each prediction can draw on both local detail and scene-level context, the model is better at recognizing objects of various sizes.

1.1. Key Features

  • Pyramid Pooling Module: Pools the backbone’s feature map at several grid sizes and merges the results, giving each pixel access to context at multiple scales.
  • Global Context Integration: Fuses this global context with the local features just before the final classification layer, improving the per-pixel predictions.
  • Strong Benchmark Results: At publication, PSPNet achieved state-of-the-art results on benchmarks including PASCAL VOC 2012, Cityscapes, and ADE20K, and it won the ImageNet Scene Parsing Challenge 2016.

2. Structure of PSPNet

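The basic structure of PSPNet can be divided into three stages:

  1. Backbone Network: A pretrained CNN such as ResNet extracts a feature map. The original paper uses a dilated ResNet, so the feature map is 1/8 of the input resolution rather than 1/32.
  2. Pyramid Pooling Module: Pools the feature map at multiple scales and concatenates the results to capture the overall context.
  3. Prediction and Upsampling: A convolution layer produces class scores at each feature map location, and these scores are upsampled to the input resolution for the final per-pixel prediction.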
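To make the flow between these stages concrete, the short sketch below does only the shape arithmetic, without any tensors. It assumes a dilated ResNet backbone with output stride 8 and 2048 output channels, four pyramid levels (1, 2, 3, 6), and each level reduced to 2048 / 4 = 512 channels; psp_shapes is a hypothetical helper written for illustration, and the numbers are exact only for input sizes divisible by 8.

```python
# Shape arithmetic for the three PSPNet stages (illustrative sketch).
# Assumptions: dilated backbone with output stride 8 and 2048 channels,
# four pyramid levels, each reduced to backbone_channels // len(bins) channels.

def psp_shapes(height, width, num_classes, output_stride=8,
               backbone_channels=2048, bins=(1, 2, 3, 6)):
    """Return the (channels, height, width) shape after each stage."""
    feat_h, feat_w = height // output_stride, width // output_stride
    branch_channels = backbone_channels // len(bins)
    shapes = {
        "input": (3, height, width),
        "backbone": (backbone_channels, feat_h, feat_w),
    }
    # Each pyramid level pools to a b x b grid (before being upsampled back)
    for b in bins:
        shapes[f"pool_{b}x{b}"] = (branch_channels, b, b)
    # Backbone features concatenated with every upsampled pyramid level
    shapes["ppm_concat"] = (backbone_channels + branch_channels * len(bins), feat_h, feat_w)
    # Per-location class scores at feature map resolution
    shapes["logits"] = (num_classes, feat_h, feat_w)
    # Final upsampling restores the input resolution
    shapes["output"] = (num_classes, height, width)
    return shapes


for stage, shape in psp_shapes(512, 1024, num_classes=19).items():
    print(f"{stage:>12}: {shape}")
```

The concatenated map has 2048 + 4 × 512 = 4096 channels, which is why the final classifier in the implementation below takes 2048 + 512 * 4 input channels.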

2.1. Pyramid Pooling Module

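The PPM takes the feature map produced by the backbone and summarizes it at several resolutions. Each pyramid level pools the map into a fixed grid of bins, so coarse levels capture global context and finer levels capture sub-region context; the results are then brought back to the original resolution and combined. The PPM consists of the following steps:

  • Pools the input feature map into grids of several sizes (e.g., 1×1, 2×2, 3×3, and 6×6 bins). The 1×1 level is a global average over the whole map.
  • Applies a 1×1 convolution to each pooled map to reduce its channel count; with four levels, the paper reduces each level to 1/4 of the input channels.
  • Upsamples each pooled feature map back to the resolution of the input feature map using bilinear interpolation.
  • Concatenates the original feature map with all upsampled maps to form the new feature map.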
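To see what pooling "into bins" means, here is a dependency-free sketch of adaptive average pooling on a single-channel feature map stored as nested lists. It follows the bin-boundary rule of PyTorch’s nn.AdaptiveAvgPool2d (bin i along an axis of length n spans indices floor(i·n/b) up to, but not including, ceil((i+1)·n/b)). The function name is hypothetical; in a real model you would apply nn.AdaptiveAvgPool2d(b) to a tensor instead.

```python
def adaptive_avg_pool2d(grid, bins):
    """Average-pool a 2D list of numbers into a bins x bins grid."""
    h, w = len(grid), len(grid[0])
    pooled = []
    for i in range(bins):
        # Row range of bin i: [floor(i*h/bins), ceil((i+1)*h/bins))
        r0, r1 = (i * h) // bins, ((i + 1) * h + bins - 1) // bins
        row = []
        for j in range(bins):
            # Column range of bin j, using the same rule
            c0, c1 = (j * w) // bins, ((j + 1) * w + bins - 1) // bins
            cells = [grid[r][c] for r in range(r0, r1) for c in range(c0, c1)]
            row.append(sum(cells) / len(cells))
        pooled.append(row)
    return pooled


# A single-channel 6x6 feature map holding the values 0..35
feature_map = [[6 * r + c for c in range(6)] for r in range(6)]

# One pooled map per pyramid level, as in the PPM
pyramid = {b: adaptive_avg_pool2d(feature_map, b) for b in (1, 2, 3, 6)}
print(pyramid[1])  # [[17.5]]  (global average)
print(pyramid[2])  # [[7.0, 10.0], [25.0, 28.0]]
```

The output grid size depends only on the number of bins, not on the input size, which is what lets the PPM work with feature maps of any resolution. When the length is not divisible by the bin count, neighboring bins overlap slightly (e.g., 5 rows into 2 bins cover rows 0 to 2 and rows 2 to 4).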

3. Implementing PSPNet with PyTorch

Now, let’s implement PSPNet using PyTorch. The code below defines the structure of PSPNet.

3.1. Setting Up the Environment

import torch
import torch.nn as nn
import torchvision.models as models
    

3.2. Defining the PSPNet Class

The PSPNet class integrates the backbone network and the pyramid pooling module. It can be defined as follows:

class PSPModule(nn.Module):
    def __init__(self, in_channels, out_channels, bins=(1, 2, 3, 6)):
        super(PSPModule, self).__init__()
        # One branch per pyramid level: adaptive average pooling to a b x b grid,
        # then a 1x1 convolution to reduce channels, batch normalization, and ReLU
        self.stages = nn.ModuleList([
            nn.Sequential(
                nn.AdaptiveAvgPool2d(bin_size),
                nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            )
            for bin_size in bins
        ])

    def forward(self, x):
        size = x.shape[2:]
        # Upsample every pooled branch back to the input feature map size
        pyramids = [
            nn.functional.interpolate(stage(x), size=size, mode='bilinear', align_corners=True)
            for stage in self.stages
        ]
        # Output channels: in_channels + len(bins) * out_channels
        return torch.cat([x] + pyramids, dim=1)

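class PSPNet(nn.Module):
    def __init__(self, num_classes):
        super(PSPNet, self).__init__()
        # ImageNet-pretrained ResNet-101 (weights API of torchvision >= 0.13).
        # Dilating the last two stages keeps the output stride at 8, as in the paper.
        resnet = models.resnet101(weights=models.ResNet101_Weights.DEFAULT,
                                  replace_stride_with_dilation=[False, True, True])
        # Keep only the convolutional layers: drop the global average pooling
        # and the 1000-class fully connected layer
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])
        self.ppm = PSPModule(2048, 512)
        self.final_convolution = nn.Conv2d(2048 + 512 * 4, num_classes, kernel_size=1)

    def forward(self, x):
        input_size = x.shape[2:]
        x = self.backbone(x)           # (N, 2048, H/8, W/8)
        x = self.ppm(x)                # (N, 4096, H/8, W/8)
        x = self.final_convolution(x)  # (N, num_classes, H/8, W/8)
        # Upsample the class scores to the input resolution so they align with the masks
        return nn.functional.interpolate(x, size=input_size, mode='bilinear', align_corners=True)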

3.3. Training the Model

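To train the model, you need to prepare the dataset, set up the optimizer, and write the training loop. Let’s use torchvision’s Cityscapes dataset class as an example. The class does not download the data, so the Cityscapes files must be obtained separately. Its semantic masks store Cityscapes label IDs, which must be mapped to the 19 training classes and resized without interpolating between labels.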

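from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Preparing the dataset
# Images: resize, convert to tensors, normalize with ImageNet statistics (pretrained backbone)
transform = transforms.Compose([
    transforms.Resize((256, 512)),  # keeps Cityscapes' 2:1 aspect ratio
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Map Cityscapes label IDs to the 19 training IDs; all other IDs become 255 (ignored)
id_to_train_id = torch.full((256,), 255, dtype=torch.long)
for c in datasets.Cityscapes.classes:
    if c.id >= 0 and 0 <= c.train_id < 255:
        id_to_train_id[c.id] = c.train_id

# Masks: nearest-neighbor resize keeps labels intact; no scaling to [0, 1]
target_transform = transforms.Compose([
    transforms.Resize((256, 512), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.PILToTensor(),
    transforms.Lambda(lambda t: id_to_train_id[t.squeeze(0).long()]),
])

train_dataset = datasets.Cityscapes(root='path/to/cityscapes/', split='train', mode='fine', target_type='semantic',
                                    transform=transform, target_transform=target_transform)
val_dataset = datasets.Cityscapes(root='path/to/cityscapes/', split='val', mode='fine', target_type='semantic',
                                  transform=transform, target_transform=target_transform)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)

# Setting up the device, model, and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = PSPNet(num_classes=19).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss(ignore_index=255)  # unlabeled pixels do not contribute
num_epochs = 10  # illustrative value

# Training loop
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, masks in train_loader:
        images, masks = images.to(device), masks.to(device)

        optimizer.zero_grad()
        outputs = model(images)           # (N, 19, H, W) class scores
        loss = criterion(outputs, masks)  # masks: (N, H, W) class indices
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_loader):.4f}')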

4. Experiments and Evaluation

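After training is complete, the model should be evaluated on the validation dataset. Commonly used metrics are IoU (Intersection over Union), computed per class as the number of pixels in both the prediction and the ground truth divided by the number in either, and usually averaged over classes (mIoU), and Pixel Accuracy, the fraction of labeled pixels classified correctly. The following code computes the validation loss and pixel accuracy.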

def evaluate(model, val_loader, criterion, device, ignore_index=255):
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_pixels = 0

    with torch.no_grad():
        for images, masks in val_loader:
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            loss = criterion(outputs, masks)

            total_loss += loss.item()
            preds = outputs.argmax(dim=1)
            # Count only labeled pixels; ignored pixels (255) have no ground truth
            valid = masks != ignore_index
            total_correct += (preds[valid] == masks[valid]).sum().item()
            total_pixels += valid.sum().item()

    print(f'Validation Loss: {total_loss / len(val_loader):.4f}, Pixel Accuracy: {total_correct / total_pixels:.4f}')

# Performing evaluation
evaluate(model, val_loader, criterion, device)
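The evaluation code above reports pixel accuracy but not IoU. As a minimal, dependency-free sketch, assuming predictions and labels are given as flat lists of class indices with 255 marking unlabeled pixels, per-class IoU can be computed from a confusion matrix as IoU_c = TP / (TP + FP + FN), and mIoU as the mean over classes that occur among the labeled pixels. The helper names confusion_matrix, iou_per_class, and mean_iou are hypothetical.

```python
def confusion_matrix(preds, targets, num_classes, ignore_index=255):
    """Rows are ground-truth classes, columns are predicted classes."""
    cm = [[0] * num_classes for _ in range(num_classes)]
    for p, t in zip(preds, targets):
        if t == ignore_index:
            continue  # unlabeled pixels count toward no class
        cm[t][p] += 1
    return cm


def iou_per_class(cm):
    """IoU_c = TP / (TP + FP + FN); None if class c never occurs in the matrix."""
    n = len(cm)
    ious = []
    for c in range(n):
        tp = cm[c][c]
        fn = sum(cm[c]) - tp                       # class-c pixels predicted as another class
        fp = sum(cm[r][c] for r in range(n)) - tp  # other pixels predicted as class c
        union = tp + fp + fn
        ious.append(tp / union if union > 0 else None)
    return ious


def mean_iou(cm):
    """Average IoU over the classes that occur in predictions or ground truth."""
    ious = [iou for iou in iou_per_class(cm) if iou is not None]
    return sum(ious) / len(ious)


# Toy example: 6 pixels, 3 classes, the last pixel unlabeled (255)
preds = [0, 0, 1, 1, 2, 2]
targets = [0, 1, 1, 1, 2, 255]
cm = confusion_matrix(preds, targets, num_classes=3)
print(cm)                      # [[1, 0, 0], [1, 2, 0], [0, 0, 1]]
print(round(mean_iou(cm), 4))  # 0.7222
```

For full-resolution masks, looping in Python is slow. Inside evaluate, the same matrix can be accumulated per batch with torch.bincount(num_classes * masks[valid] + preds[valid], minlength=num_classes ** 2).reshape(num_classes, num_classes), summed over all validation batches, and converted with .tolist() before calling iou_per_class.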

5. Conclusion

In this lecture, we explored the structure and working principles of PSPNet. I hope that implementing and training the model in PyTorch has shown you how to approach semantic segmentation problems. PSPNet’s core idea, combining local features with global context through pyramid pooling, is useful in many real-world image segmentation applications.

References:

  • Zhao, H., Shi, J., Qi, X., Wang, X., & Jia, J. (2017). Pyramid Scene Parsing Network. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR).
  • PyTorch. (n.d.). PyTorch Documentation. Retrieved from https://pytorch.org/docs/stable/index.html