Maison > Périphériques technologiques > IA > Neuf opérations clés de PyTorch !

Neuf opérations clés de PyTorch !

PHPz
Libérer: 2024-01-06 16:45:47
avant
741 Les gens l'ont consulté

Aujourd'hui, nous allons parler de PyTorch. J'ai résumé les neuf opérations PyTorch les plus importantes, ce qui vous donnera un concept global.

Neuf opérations clés de PyTorch !

Création de tenseurs et opérations de base

Les tenseurs PyTorch sont similaires aux tableaux NumPy, mais ils ont des fonctions d'accélération GPU et de dérivation automatique. Nous pouvons utiliser la fonction torch.tensor pour créer des tenseurs, ou nous pouvons utiliser torch.zeros, torch.ones et d'autres fonctions pour créer des tenseurs. Ces fonctions peuvent nous aider à créer des tenseurs plus facilement. Le module

import torch# 创建张量a = torch.tensor([1, 2, 3])b = torch.tensor([4, 5, 6])# 张量加法c = a + bprint(c)
Copier après la connexion

Autograd (Autograd)

torch.autograd fournit un mécanisme de dérivation automatique, permettant les opérations d'enregistrement et le calcul des gradients.

x = torch.tensor([1.0], requires_grad=True)y = x**2y.backward()print(x.grad)
Copier après la connexion

Couche de réseau neuronal (nn.Module)

torch.nn.Module est le composant de base pour la construction d'un réseau neuronal. Il peut inclure diverses couches, telles qu'une couche linéaire (nn.Linear), une couche convolutive (nn.Conv2d). ) attendez.

import torch.nn as nnclass SimpleNN(nn.Module):def __init__(self): super(SimpleNN, self).__init__() self.fc = nn.Linear(10, 5)def forward(self, x): return self.fc(x)model = SimpleNN()
Copier après la connexion

Optimiseur

L'optimiseur est utilisé pour ajuster les paramètres du modèle afin de réduire la fonction de perte. Vous trouverez ci-dessous un exemple utilisant l'optimiseur Stochastic Gradient Descent (SGD).

import torch.optim as optimoptimizer = optim.SGD(model.parameters(), lr=0.01)
Copier après la connexion

Fonction de perte

La fonction de perte est utilisée pour mesurer l'écart entre la sortie du modèle et la cible. Par exemple, la perte d’entropie croisée convient aux problèmes de classification.

loss_function = nn.CrossEntropyLoss()
Copier après la connexion

Chargement et prétraitement des données

Le module torch.utils.data de PyTorch fournit les classes Dataset et DataLoader pour le chargement et le prétraitement des données. Les classes d'ensembles de données peuvent être personnalisées pour s'adapter à différents formats de données et tâches.

from torch.utils.data import DataLoader, Datasetclass CustomDataset(Dataset):# 实现数据集的初始化和__getitem__方法dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
Copier après la connexion

Sauvegarde et chargement du modèle

Vous pouvez utiliser torch.save pour enregistrer le dictionnaire d'état du modèle et torch.load pour charger le modèle. Le module

# 保存模型torch.save(model.state_dict(), 'model.pth')# 加载模型loaded_model = SimpleNN()loaded_model.load_state_dict(torch.load('model.pth'))
Copier après la connexion

Ajustement du taux d'apprentissage

torch.optim.lr_scheduler fournit des outils pour l'ajustement du taux d'apprentissage. Par exemple, StepLR peut être utilisé pour réduire le taux d'apprentissage après chaque époque.

from torch.optim import lr_schedulerscheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
Copier après la connexion

Évaluation du modèle

Une fois la formation du modèle terminée, les performances du modèle doivent être évaluées. Lors de l'évaluation, vous devez passer le modèle en mode d'évaluation (model.eval()) et utiliser le gestionnaire de contexte torch.no_grad() pour éviter les calculs de dégradé.

model.eval()with torch.no_grad():# 运行模型并计算性能指标
Copier après la connexion

Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!

Étiquettes associées:
source:51cto.com
Déclaration de ce site Web
Le contenu de cet article est volontairement contribué par les internautes et les droits d'auteur appartiennent à l'auteur original. Ce site n'assume aucune responsabilité légale correspondante. Si vous trouvez un contenu suspecté de plagiat ou de contrefaçon, veuillez contacter admin@php.cn
Tutoriels populaires
Plus>
Derniers téléchargements
Plus>
effets Web
Code source du site Web
Matériel du site Web
Modèle frontal