이미지 분류의 카테고리 불균형 문제, 구체적인 코드 예제가 필요합니다
요약: 이미지 분류 작업에서 데이터 세트의 카테고리에 불균형 문제가 있을 수 있습니다. 즉, 일부 카테고리의 샘플 수가 훨씬 더 많습니다. 다른 카테고리보다 이러한 클래스 불균형은 모델 학습 및 성능에 부정적인 영향을 미칠 수 있습니다. 이 기사에서는 클래스 불균형 문제의 원인과 영향을 설명하고 문제를 해결하기 위한 구체적인 코드 예제를 제공합니다.
클래스 불균형 문제는 모델의 훈련과 성능에 부정적인 영향을 미칩니다. 첫째, 일부 범주의 표본 수가 적기 때문에 모델이 이러한 범주를 잘못 판단할 수 있습니다. 예를 들어, 2분류 문제에서 두 범주의 샘플 수는 각각 10개와 1000개입니다. 모델이 학습을 수행하지 않고 모든 샘플을 더 많은 수의 샘플이 포함된 범주로 직접 예측하는 경우 정확도는 다음과 같습니다. 매우 높지만 실제로는 표본이 효과적으로 분류되지 않습니다. 둘째, 불균형한 표본 분포로 인해 모델이 표본 수가 더 많은 범주를 예측하는 방향으로 편향되어 다른 범주에 대한 분류 성능이 저하될 수 있습니다. 마지막으로 불균형한 범주 분포는 소수 범주에 대한 모델의 훈련 샘플이 부족하여 학습된 모델의 소수 범주에 대한 일반화 능력이 저하될 수 있습니다.
언더샘플링이란 샘플 수가 많은 카테고리에서 일부 샘플을 무작위로 삭제하여 각 카테고리의 샘플 수가 더 가까워지도록 하는 것을 의미합니다. 이 방법은 간단하고 간단하지만 샘플을 삭제하면 일부 중요한 기능이 손실될 수 있으므로 정보가 손실될 수 있습니다.
오버샘플링은 각 카테고리의 샘플 수를 보다 균형 있게 만들기 위해 샘플 수가 적은 카테고리에서 일부 샘플을 복사하는 것을 의미합니다. 이 방법은 샘플 수를 늘릴 수 있지만 샘플을 복사하면 모델이 훈련 세트에 과적합되어 일반화 능력이 저하될 수 있으므로 과적합 문제가 발생할 수 있습니다.
가중치 조정은 손실 함수에서 다양한 카테고리의 샘플에 서로 다른 가중치를 부여하여 모델이 샘플 수가 적은 카테고리에 더 많은 주의를 기울이는 것을 의미합니다. 이 방법은 추가 샘플을 도입하지 않고도 클래스 불균형 문제를 효과적으로 해결할 수 있습니다. 구체적인 접근 방식은 샘플 수가 적은 범주가 더 큰 가중치를 갖도록 가중치 벡터를 지정하여 손실 함수에서 각 범주의 가중치를 조정하는 것입니다.
다음은 클래스 불균형 문제를 해결하기 위해 가중치 조정 방법을 사용하는 방법을 보여주는 PyTorch 프레임워크를 사용하는 코드 예제입니다.
import torch import torch.nn as nn import torch.optim as optim # 定义分类网络 class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.fc1 = nn.Linear(784, 100) self.fc2 = nn.Linear(100, 10) def forward(self, x): x = x.view(-1, 784) x = self.fc1(x) x = self.fc2(x) return x # 定义损失函数和优化器 criterion = nn.CrossEntropyLoss(weight=torch.tensor([0.1, 0.9])) # 根据样本数量设置权重 optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) # 训练模型 for epoch in range(10): running_loss = 0.0 for i, data in enumerate(trainloader, 0): inputs, labels = data optimizer.zero_grad() outputs = net(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item() if i % 2000 == 1999: print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000)) running_loss = 0.0 print('Finished Training')
위 코드에서 두 클래스의 가중치는 더 작은 클래스인 torch.tensor([0.1, 0.9])
로 지정됩니다. 샘플 수 가중치는 0.1이고, 샘플 수가 많은 범주의 가중치는 0.9입니다. 이를 통해 모델은 샘플 수가 적은 카테고리에 더 많은 주의를 기울일 수 있습니다.
참고자료:
[1] He, H., & Garcia, E. A.(2009). 불균형 데이터로부터 학습. IEEE Transactions on Knowledge and Data Engineering, 21(9), 1263-1284.
[2] Chawla , N. V., Bowyer, K. W., Hall, L. O., & Kegelmeyer, W. P. (2002). SMOTE: 인공 지능 연구 저널, 16, 321-357.
위 내용은 이미지 분류의 클래스 불균형 문제의 상세 내용입니다. 자세한 내용은 PHP 중국어 웹사이트의 기타 관련 기사를 참조하세요!