OpenCV 강좌, PyTorch 모델을 OpenCV로 로드하여 실행하기

저자: 조광형

작성일: 2024년 11월 26일

1. 서론

OpenCV(Open Source Computer Vision Library)는 컴퓨터 비전 및 머신러닝을 위한 라이브러리로,
다양한 이미지 및 비디오 처리 기술을 제공합니다. PyTorch는 머신러닝 라이브러리로, 특히 딥러닝 모델의
개발과 학습에 많이 사용됩니다. 본 강좌에서는 PyTorch로 학습한 모델을 OpenCV를 통해 로드하여
실시간으로 실행하는 방법에 대해 설명하겠습니다.

2. 요구 사항

이 강좌를 진행하기 위해 필요한 요구 사항은 다음과 같습니다:

  • Python 3.x 버전
  • OpenCV 라이브러리
  • PyTorch 라이브러리
  • NumPy

Python과 필요한 라이브러리를 설치하려면 다음과 같은 명령어를 사용할 수 있습니다:

pip install opencv-python torchvision torch numpy

3. PyTorch 모델 학습하기

우선 PyTorch를 사용하여 간단한 CNN 모델을 학습시킵니다. 여기서는 MNIST 데이터셋을 사용하여 숫자 이미지를 분류하는 모델을 만들어 보겠습니다.


import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# 데이터셋 로드 및 전처리
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)

# 모델 정의
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.fc1 = nn.Linear(64 * 6 * 6, 128)
        self.fc2 = nn.Linear(128, 10)
    
    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 64 * 6 * 6)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

# 모델 인스턴스 생성 및 손실 함수 및 최적화 알고리즘 정의
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 모델 학습
model.train()
for epoch in range(5):  # 5 에포크 동안 학습
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    
    print(f'Epoch [{epoch+1}/5], Loss: {loss.item():.4f}')

# 모델 저장
torch.save(model.state_dict(), 'mnist_cnn.pth')
            

위 코드는 MNIST 데이터셋을 로드하여 간단한 CNN 모델을 정의하고 학습시키는 코드입니다.
모델이 학습한 후, ‘mnist_cnn.pth’ 파일로 모델을 저장합니다.

4. OpenCV에서 PyTorch 모델 로드하기

저장한 PyTorch 모델을 OpenCV에서 접근하려면, 먼저 모델을 로드하고 OpenCV 형식으로 변환해야 합니다.
OpenCV에서 사용하려면 ONNX(Open Neural Network Exchange) 형식으로 변환해야 합니다. 아래 코드를
사용하여 모델을 ONNX 형식으로 변환합니다.


dummy_input = torch.randn(1, 1, 28, 28)  # MNIST의 이미지 크기
torch.onnx.export(model, dummy_input, 'mnist_cnn.onnx')
            

위 코드는 모델을 ONNX 형식으로 변환하는 과정입니다. 이제 OpenCV를 사용하여 이 ONNX 모델을 로드할 수 있습니다.

5. OpenCV를 사용하여 ONNX 모델 실행하기

OpenCV에서 ONNX 모델을 로드하고 실행하는 방법은 다음과 같습니다. OpenCV의 dnn 모듈을 사용하여
모델을 로드하고 이미지를 입력으로 사용하여 예측할 수 있습니다.


import cv2
import numpy as np

# 모델 로드
net = cv2.dnn.readNetFromONNX('mnist_cnn.onnx')

# 이미지 전처리
image = cv2.imread('test_image.png', cv2.IMREAD_GRAYSCALE)
image = cv2.resize(image, (28, 28))
image = image.astype(np.float32) / 255.0
image = np.expand_dims(image, axis=0)
image = np.expand_dims(image, axis=0)
image = np.array(image, dtype=np.float32)

# 모델에 입력
blob = cv2.dnn.blobFromImage(image)
net.setInput(blob)
output = net.forward()

# 결과 출력
predicted_class = np.argmax(output, axis=1)
print(f'Predicted class: {predicted_class[0]}')
            

위 코드는 OpenCV를 사용하여 ONNX 모델을 로드하고, 주어진 이미지를 전처리한 후 모델을 통해
예측 결과를 출력합니다. 이미지가 MNIST 데이터셋의 크기에 맞게 전처리되었는지 확인하세요.

6. 실시간 이미지 처리

OpenCV를 사용하여 웹캠에서 실시간으로 이미지를 처리하고 예측하는 방법도 알아보겠습니다.
아래 코드는 웹캠에서 이미지를 캡처하고, 모델을 사용하여 숫자를 인식하는 예제입니다.


cap = cv2.VideoCapture(0)  # 웹캠 캡처 시작

while True:
    ret, frame = cap.read()
    if not ret:
        break
    
    # 이미지 전처리
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    resized = cv2.resize(gray, (28, 28))
    normed = resized.astype(np.float32) / 255.0
    blob = cv2.dnn.blobFromImage(normed)
    
    # 모델에 입력
    net.setInput(blob)
    output = net.forward()
    predicted_class = np.argmax(output, axis=1)[0]
    
    # 예측 결과 표시
    cv2.putText(frame, f'Predicted: {predicted_class}', (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 
                1, (0, 255, 0), 2, cv2.LINE_AA)
    
    cv2.imshow('Webcam', frame)
    
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

cap.release()
cv2.destroyAllWindows()
            

위 코드는 웹캠으로부터 프레임을 가져오고 각 프레임에 대해 전처리를 수행하여 모델을 통해
예측을 수행하고, 예측된 숫자를 화면에 표시합니다.
‘q’ 키를 눌러서 종료할 수 있습니다.

7. 결론

이번 강좌에서는 PyTorch로 학습한 CNN 모델을 OpenCV에서 로드하여 실행하는 방법을 알아보았습니다.
ONNX 형식으로 모델을 변환하고, OpenCV의 DNN 모듈을 활용하여 실시간 이미지 처리에 활용할 수 있음을
배웠습니다. 이러한 방식으로 기존의 다양한 머신러닝 모델을 OpenCV와 통합하여 이미지나 비디오 처리
작업에 활용할 수 있습니다. 향후 더 복잡한 모델 및 다양한 데이터셋으로의 확장을 고려해 보세요.