딥러닝 파이토치 강좌, LeNet-5

딥러닝은 최근 몇 년간 데이터 과학의 여러 분야에서 엄청난 인기를 끌고 있습니다. 다양한 분야의 문제를 해결하는 데 있어 매우 유용한 도구로 자리잡게 되었습니다. 이 강좌에서는 유명한 딥러닝 아키텍처 중 하나인 LeNet-5에 대해 자세히 살펴보겠습니다.

LeNet-5란?

LeNet-5는 Yann LeCun을 포함한 연구자들이 1998년에 개발한 합성곱 신경망(CNN) 아키텍처입니다. 이미지를 인식하는 데 유용한 모델로, 주로 손글씨 숫자 인식에 사용되었습니다. 이 모델은 CNN의 기본 구조를 따르며, 여러 계층으로 구성되어 있습니다. LeNet-5는 다음과 같은 계층으로 이루어져 있습니다:

  • 입력층: 32×32 픽셀의 그레이스케일 이미지.
  • 합성곱층 (C1): 6개의 필터(5×5)를 사용하여 특성 맵을 생성 (28×28 크기).
  • 풀링층 (S2): 평균 풀링을 통해 6개의 14×14 특성 맵 생성.
  • 합성곱층 (C3): 16개의 필터를 사용, 10×10 특성 맵 생성.
  • 풀링층 (S4): 평균 풀링을 통해 16개의 5×5 특성 맵 생성.
  • 합성곱층 (C5): 120개의 필터(5×5)를 사용하여 마지막 특성 맵 생성.
  • 완전 연결층 (F6): 84개의 뉴런으로 최종 출력.
  • 출력층: 10개의 클래스 (0-9)로 분류.

LeNet-5의 중요성

LeNet-5는 CNN의 기본 아키텍처 중 하나로, 많은 깊은 네트워크의 기반이 되었습니다. 이 모델은 이미지 인식 분야에서 많은 혁신을 가져왔고, 현재에도 다양한 변형 모델이 존재합니다. LeNet-5의 단순성과 효율성 덕분에, 많은 데이터셋에서 좋은 성능을 발휘합니다.

LeNet-5 구현하기

이제 PyTorch를 사용하여 LeNet-5를 구현해보겠습니다. PyTorch는 사용자 친화적인 딥러닝 프레임워크로, 다양한 연구와 산업에서 널리 사용됩니다. 또한, PyTorch는 동적 계산 그래프를 사용하는 장점이 있습니다.

환경 설정

먼저, 필요한 라이브러리를 설치하고 환경을 설정해야 합니다. 다음 코드를 사용하여 PyTorch와 torchvision을 설치하세요:

pip install torch torchvision

LeNet-5 모델 구현

이제 LeNet-5의 구조를 구현해보겠습니다:

import torch
import torch.nn as nn
import torch.nn.functional as F

class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.avg_pool2d(x, kernel_size=2, stride=2)
        x = F.relu(self.conv2(x))
        x = F.avg_pool2d(x, kernel_size=2, stride=2)
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

모델 훈련 위한 데이터셋 준비

LeNet-5는 MNIST 데이터셋을 사용하여 훈련할 것입니다. torchvision을 사용하여 데이터를 쉽게 다운로드하고 로드할 수 있습니다. 다음 코드를 사용하여 MNIST 데이터셋을 준비하세요:

from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

모델 훈련

모델을 훈련시키기 위해서는 손실 함수와 최적화 알고리즘을 설정해야 합니다. 여기서는 Cross Entropy Loss와 Adam 옵티마이저를 사용하겠습니다:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = LeNet5().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

num_epochs = 5
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {loss.item():.4f}')

모델 평가

훈련이 완료된 후, 모델의 성능을 평가할 수 있습니다. 테스트 데이터셋을 사용하여 정확도를 확인하겠습니다:

model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print(f'Accuracy of the model on the test images: {100 * correct / total:.2f}%')

결론

이번 강좌에서는 LeNet-5 아키텍처를 PyTorch를 사용하여 구현하고 훈련하는 과정을 살펴보았습니다. LeNet-5는 CNN의 기초를 이해하고 실습할 수 있는 좋은 예제입니다. 이 모델을 기반으로 더 복잡한 네트워크 아키텍처나 다양한 응용으로 발전시킬 수 있습니다. 다음 단계로는 더 깊은 네트워크 구조나 데이터셋을 활용해보는 것을 추천드립니다.

참고 자료