딥러닝 모델을 훈련하는 과정에서 과적합(overfitting)은 흔히 발생하는 문제 중 하나입니다. 과적합이란, 모델이 훈련 데이터에 지나치게 적합하여 새로운 데이터에 대한 일반화 능력이 떨어지는 현상을 말합니다. 그래서 많은 연구자와 엔지니어가 다양한 방법을 통해 과적합을 방지하고자 노력합니다. 그 중 하나가 바로 ‘조기 종료(Early Stopping)’입니다.
조기 종료(Early Stopping)이란?
조기 종료는 모델의 훈련 과정을 모니터링하여 검증 데이터에 대한 성능이 개선되지 않을 때 훈련을 중단하는 기법입니다. 이 방법은 훈련 데이터에서 모델이 성공적으로 학습했더라도 검증 데이터에서 성능이 떨어지면 훈련을 멈춤으로써 과적합을 방지합니다.
조기 종료의 작동 원리
조기 종료는 기본적으로 모델 훈련 중 검증 손실(validation loss) 또는 검증 정확도(validation accuracy)를 관찰하며, 일정 에폭(epoch) 동안 성능 향상이 없을 경우 훈련을 중지합니다. 이때, 최상의 모델 파라미터를 저장해두어, 훈련이 끝난 후 이 모델을 사용할 수 있습니다.
조기 종료의 구현
여기서는 PyTorch를 활용하여 간단한 이미지 분류 모델을 훈련시키는 예제를 통해 조기 종료를 구현해 보겠습니다. 이 예제에서는 MNIST 데이터셋을 사용하여 손글씨 숫자를 인식하는 모델을 학습합니다.
필요한 라이브러리 설치
pip install torch torchvision matplotlib numpy
코드 예제
아래는 조기 종료를 적용한 PyTorch 코드 예제입니다.
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# 하이퍼파라미터 설정
input_size = 28 * 28 # MNIST 이미지 크기
num_classes = 10 # 분류할 클래스 수
num_epochs = 20 # 전체 학습 에폭
batch_size = 100 # 배치 크기
learning_rate = 0.001 # 학습률
# MNIST 데이터셋 로드
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform, download=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
# 단순한 신경망 모델 정의
class SimpleNN(nn.Module):
def __init__(self):
super(SimpleNN, self).__init__()
self.fc1 = nn.Linear(input_size, 128)
self.fc2 = nn.Linear(128, num_classes)
def forward(self, x):
x = x.view(-1, input_size) # 이미지의 차원 변경
x = torch.relu(self.fc1(x)) # 활성화 함수
x = self.fc2(x)
return x
# 모델, 손실 함수 및 옵티마이저 초기화
model = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# 조기 종료를 위한 변수 초기화
best_loss = float('inf')
patience, trials = 5, 0 # 최대 5회 성능 향상 없을 시 훈련 중지
train_losses, val_losses = [], []
# 훈련 루프
for epoch in range(num_epochs):
model.train() # 모델을 훈련 모드로 전환
running_loss = 0.0
for images, labels in train_loader:
optimizer.zero_grad() # 기울기 초기화
outputs = model(images) # 모델 예측
loss = criterion(outputs, labels) # 손실 계산
loss.backward() # 기울기 계산
optimizer.step() # 가중치 업데이트
running_loss += loss.item()
avg_train_loss = running_loss / len(train_loader)
train_losses.append(avg_train_loss)
# 검증 단계
model.eval() # 모델을 평가 모드로 전환
val_loss = 0.0
with torch.no_grad(): # 기울기 계산 비활성화
for images, labels in test_loader:
outputs = model(images)
loss = criterion(outputs, labels)
val_loss += loss.item()
avg_val_loss = val_loss / len(test_loader)
val_losses.append(avg_val_loss)
print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Valid Loss: {avg_val_loss:.4f}')
# 조기 종료 로직
if avg_val_loss < best_loss:
best_loss = avg_val_loss
trials = 0 # 성능 향상 기록 리셋
torch.save(model.state_dict(), 'best_model.pth') # 최적 모델 저장
else:
trials += 1
if trials >= patience: # patience 만큼 성능 향상이 없으면 훈련 중지
print("Early stopping...")
break
# 테스트 데이터에 대한 성능 평가
model.load_state_dict(torch.load('best_model.pth')) # 최적 모델 로드
model.eval() # 모델을 평가 모드로 전환
correct, total = 0, 0
with torch.no_grad():
for images, labels in test_loader:
outputs = model(images)
_, predicted = torch.max(outputs.data, 1) # 최대 확률 클래스 선택
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f'Accuracy of the model on the test images: {100 * correct / total:.2f}%')
코드 설명
위 코드는 MNIST 데이터셋을 사용하는 간단한 신경망 모델을 훈련하는 과정입니다. 먼저 필요한 라이브러리를 import하고, MNIST 데이터셋을 로드합니다. 그리고 단순한 두 개의 Fully Connected Layer로 구성된 신경망을 정의합니다.
이후, 에폭마다 훈련 손실과 검증 손실을 계산하고, 조기 종료 로직을 통해 검증 손실이 개선되지 않으면 훈련을 중단합니다. 마지막으로 테스트 데이터에 대한 정확도를 계산하여 모델의 성능을 평가합니다.
결론
조기 종료는 딥러닝 모델의 성능 최적화를 위한 유용한 기술입니다. 이를 통해 과적합을 방지하고 최적의 모델을 도출할 수 있습니다. 본 강좌에서는 PyTorch를 이용하여 조기 종료를 구현하여 MNIST 분류 문제를 해결해보았습니다. 이를 바탕으로 다양한 딥러닝 문제에 조기 종료 기술을 적용해 보시기를 권장합니다.
참고문헌
- Deep Learning, Ian Goodfellow, Yoshua Bengio, Aaron Courville.
- PyTorch Documentation: https://pytorch.org/docs/stable/index.html
- MNIST Dataset: http://yann.lecun.com/exdb/mnist/