Alors que les modèles d'apprentissage automatique continuent de gagner en complexité et en capacités. Une technique efficace pour améliorer les performances de modèles volumineux et complexes sur de petits ensembles de données est la distillation des connaissances, qui implique la formation d'un modèle plus petit et plus efficace pour imiter le comportement d'un modèle « enseignant » plus grand.
Dans cet article, nous explorerons le concept de distillation des connaissances et comment le mettre en œuvre dans PyTorch. Nous verrons comment il peut être utilisé pour compresser un modèle volumineux et peu maniable en un modèle plus petit et plus efficace tout en conservant la précision et les performances du modèle d'origine.
Nous définissons d'abord le problème à résoudre par distillation des connaissances.
Nous avons formé un vaste réseau neuronal profond pour effectuer des tâches complexes telles que la classification d'images ou la traduction automatique. Ce modèle peut comporter des milliers de couches et des millions de paramètres, ce qui rend difficile son déploiement dans des applications du monde réel, des appareils de périphérie, etc. Et ce modèle très volumineux nécessite également beaucoup de ressources informatiques pour fonctionner, ce qui le rend incapable de fonctionner sur certaines plates-formes aux ressources limitées.
Une façon de résoudre ce problème consiste à utiliser la distillation des connaissances pour compresser de grands modèles en modèles plus petits. Ce processus implique la formation d'un modèle plus petit pour imiter le comportement du modèle plus grand dans une tâche donnée.
Nous utiliserons un exemple de distillation des connaissances en utilisant l'ensemble de données de radiographie pulmonaire de Kaggle pour la classification de la pneumonie. L'ensemble de données que nous avons utilisé est organisé en 3 dossiers (train, test, val) et contient des sous-dossiers pour chaque catégorie d'image (Pneumonie/Normal). Il existe 5 863 images radiographiques (JPEG) et 2 catégories (pneumonie/normale).
Comparez les images de ces deux classes :
Le chargement et le prétraitement des données sont indépendants du fait que nous utilisons la distillation des connaissances ou un modèle spécifique, l'extrait de code pourrait ressembler à ceci :
transforms_train = transforms.Compose([ transforms.Resize((224, 224)), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) transforms_test = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) train_data = ImageFolder(root=train_dir, transform=transforms_train) test_data = ImageFolder(root=test_dir, transform=transforms_test) train_loader = DataLoader(train_data, batch_size=32, shuffle=True) test_loader = DataLoader(test_data, batch_size=32, shuffle=True)
Dans ce contexte Pour le modèle d'enseignant intermédiaire, nous utilisons Resnet-18 et l'affinons sur cet ensemble de données.
import torch import torch.nn as nn import torchvision class TeacherNet(nn.Module): def __init__(self): super().__init__() self.model = torchvision.models.resnet18(pretrained=True) for params in self.model.parameters(): params.requires_grad_ = False n_filters = self.model.fc.in_features self.model.fc = nn.Linear(n_filters, 2) def forward(self, x): x = self.model(x) return x
Le code pour l'entraînement de réglage fin est le suivant
def train(model, train_loader, test_loader, optimizer, criterion, device): dataloaders = {'train': train_loader, 'val': test_loader} for epoch in range(30): print('Epoch {}/{}'.format(epoch, num_epochs - 1)) print('-' * 10) for phase in ['train', 'val']: if phase == 'train': model.train() else: model.eval() running_loss = 0.0 running_corrects = 0 for inputs, labels in tqdm.tqdm(dataloaders[phase]): inputs = inputs.to(device) labels = labels.to(device) optimizer.zero_grad() with torch.set_grad_enabled(phase == 'train'): outputs = model(inputs) loss = criterion(outputs, labels) _, preds = torch.max(outputs, 1) if phase == 'train': loss.backward() optimizer.step() running_loss += loss.item() * inputs.size(0) running_corrects += torch.sum(preds == labels.data) epoch_loss = running_loss / len(dataloaders[phase].dataset) epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset) print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
Il s'agit d'une étape d'entraînement de réglage fin standard. Après l'entraînement, nous pouvons voir que le modèle a atteint une précision de 91 % sur l'ensemble de test, ce qui signifie que nous n'avons pas choisi. un modèle plus grand. La raison en est que la précision du test 91 est suffisante pour être utilisée comme modèle de classe de base.
Nous savons que le modèle comporte 11,7 millions de paramètres, il ne pourra donc pas nécessairement s'adapter aux appareils de pointe ou à d'autres scénarios spécifiques.
Notre étudiant est un CNN moins profond avec seulement quelques couches et environ 100 000 paramètres.
class StudentNet(nn.Module): def __init__(self): super().__init__() self.layer1 = nn.Sequential( nn.Conv2d(3, 4, kernel_size=3, padding=1), nn.BatchNorm2d(4), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2) ) self.fc = nn.Linear(4 * 112 * 112, 2) def forward(self, x): out = self.layer1(x) out = out.view(out.size(0), -1) out = self.fc(out) return out
C'est très simple si vous regardez le code, n'est-ce pas.
Si je peux simplement former ce réseau neuronal plus petit, pourquoi devrais-je m'embêter avec la distillation des connaissances ? Nous joindrons enfin les résultats de la formation de ce réseau à partir de zéro grâce à l'ajustement des hyperparamètres et d'autres moyens de comparaison ?
Mais maintenant, nous continuons nos étapes de distillation des connaissances
Les étapes de base de la formation sont les mêmes, mais la différence est de savoir comment calculer la perte de formation finale, nous utiliserons la perte du modèle de l'enseignant, le modèle de l'étudiant perte et La perte de distillation est calculée avec la perte finale.
class DistillationLoss: def __init__(self): self.student_loss = nn.CrossEntropyLoss() self.distillation_loss = nn.KLDivLoss() self.temperature = 1 self.alpha = 0.25 def __call__(self, student_logits, student_target_loss, teacher_logits): distillation_loss = self.distillation_loss(F.log_softmax(student_logits / self.temperature, dim=1), F.softmax(teacher_logits / self.temperature, dim=1)) loss = (1 - self.alpha) * student_target_loss + self.alpha * distillation_loss return loss
La fonction de perte est la somme pondérée des deux choses suivantes :
En termes simples, notre modèle d'enseignant doit apprendre aux élèves à « penser », ce qui fait référence à son incertitude ; par exemple, si la probabilité de sortie finale du modèle de l'enseignant est [0,53, 0,47], nous espérons que les élèves obtiendront également les mêmes résultats similaires, la différence entre ces prédictions sont la perte de distillation.
Afin de contrôler la perte, il y a deux paramètres principaux :
Dans les points ci-dessus, les valeurs d'alpha et de température sont basées sur les meilleurs résultats que nous avons essayés avec quelques combinaisons.
Ceci est un résumé tabulaire de cette expérience.
Nous pouvons clairement voir les énormes avantages obtenus en utilisant un CNN plus petit (99,14 %), moins profond : 10 points d'amélioration de la précision par rapport à l'entraînement sans distillation, et 11 points plus rapide que Resnet-18 Times En d'autres termes, notre ! le petit modèle a vraiment appris quelque chose d'utile du grand modèle.
Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!