Maison > développement back-end > Tutoriel Python > Pourquoi affiner un modèle MLP sur un petit ensemble de données conserve toujours la même précision de test que les poids pré-entraînés ?

Pourquoi affiner un modèle MLP sur un petit ensemble de données conserve toujours la même précision de test que les poids pré-entraînés ?

WBOY
Libérer: 2024-02-10 21:36:04
avant
617 Les gens l'ont consulté

为什么在小数据集上微调 MLP 模型,仍然保持与预训练权重相同的测试精度?

Contenu de la question

J'ai conçu un modèle mlp simple pour m'entraîner sur 6 000 échantillons de données.

class mlp(nn.module):
    def __init__(self,input_dim=92, hidden_dim = 150, num_classes=2):
        super().__init__()
        self.input_dim = input_dim
        self.num_classes = num_classes
        self.hidden_dim = hidden_dim
        #self.softmax = nn.softmax(dim=1)

        self.layers = nn.sequential(
            nn.linear(self.input_dim, self.hidden_dim),
            nn.relu(),
            nn.linear(self.hidden_dim, self.hidden_dim),
            nn.relu(),
            nn.linear(self.hidden_dim, self.hidden_dim),
            nn.relu(),
            nn.linear(self.hidden_dim, self.num_classes),

        )

    def forward(self, x):
        x = self.layers(x)
        return x
Copier après la connexion

et le modèle est instancié

model = mlp(input_dim=input_dim, hidden_dim=hidden_dim, num_classes=num_classes).to(device)

optimizer = optimizer.adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
criterion = nn.crossentropyloss()
Copier après la connexion

et hyperparamètres :

num_epoch = 300   # 200e3//len(train_loader)
learning_rate = 1e-3
batch_size = 64
device = torch.device("cuda")
seed = 42
torch.manual_seed(42)
Copier après la connexion

Ma mise en œuvre suit principalement cette question. J'enregistre le modèle sous forme de poids pré-entraînés model_weights.pth.

model在测试数据集上的准确率是96.80%.

Ensuite, j'ai encore 50 échantillons (en finetune_loader) sur lesquels j'essaye d'affiner le modèle :

model_finetune = MLP()
model_finetune.load_state_dict(torch.load('model_weights.pth'))
model_finetune.to(device)
model_finetune.train()
# train the network
for t in tqdm(range(num_epoch)):
  for i, data in enumerate(finetune_loader, 0):
    #def closure():
      # Get and prepare inputs
      inputs, targets = data
      inputs, targets = inputs.float(), targets.long()
      inputs, targets = inputs.to(device), targets.to(device)
      
      # Zero the gradients
      optimizer.zero_grad()
      # Perform forward pass
      outputs = model_finetune(inputs)
      # Compute loss
      loss = criterion(outputs, targets)
      # Perform backward pass
      loss.backward()
      #return loss
      optimizer.step()     # a

model_finetune.eval()
with torch.no_grad():
    outputs2 = model_finetune(test_data)
    #predicted_labels = outputs.squeeze().tolist()

    _, preds = torch.max(outputs2, 1)
    prediction_test = np.array(preds.cpu())
    accuracy_test_finetune = accuracy_score(y_test, prediction_test)
    accuracy_test_finetune
    
    Output: 0.9680851063829787
Copier après la connexion

J'ai vérifié, la précision reste la même qu'avant d'affiner le modèle à 50 échantillons, et les probabilités de sortie sont également les mêmes.

Quelle pourrait en être la raison ? Ai-je commis des erreurs en peaufinant le code ?


Bonne réponse


Vous devez réinitialiser l'optimiseur avec un nouveau modèle (objet model_finetune). Actuellement, comme je peux le voir dans votre code, il semble toujours utiliser l'optimiseur initialisé avec les anciens poids de modèle - model.parameters().

Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!

source:stackoverflow.com
Déclaration de ce site Web
Le contenu de cet article est volontairement contribué par les internautes et les droits d'auteur appartiennent à l'auteur original. Ce site n'assume aucune responsabilité légale correspondante. Si vous trouvez un contenu suspecté de plagiat ou de contrefaçon, veuillez contacter admin@php.cn
Tutoriels populaires
Plus>
Derniers téléchargements
Plus>
effets Web
Code source du site Web
Matériel du site Web
Modèle frontal