Les décorateurs sont des implémentations spécifiques des gestionnaires de contexte Python. Cet article illustrera comment les utiliser à travers un exemple de débogage GPU pytorch. Même si cela ne fonctionne pas dans toutes les situations, je les ai trouvés très utiles.
Il existe de nombreuses façons de déboguer les fuites de mémoire. Cet article présentera une méthode utile pour identifier les lignes problématiques dans votre code. Cette méthode peut aider à trouver l’emplacement spécifique de manière concise.
Si vous rencontrez un problème, une méthode classique et couramment utilisée consiste à utiliser le débogueur pour inspecter ligne par ligne, comme l'exemple suivant :
Cela fonctionne, mais une telle opération semble lourde. Nous pouvons l’encapsuler dans une fonction, qui peut être appelée en cas de besoin, il n’est donc quasiment pas nécessaire de modifier le code existant, ce qui nous amène à introduire la fonction du décorateur.
Les décorateurs peuvent être enveloppés dans n'importe quelle partie du code. Ici, nous utilisons le décorateur pour vérifier s'il existe des tenseurs supplémentaires. De plus, nous avons également besoin d'un compteur car le nombre de tenseurs doit être calculé avant et après l'exécution. Le modèle ressemble à ceci :
def memleak_wrapper(func): def wrap(*args, **kwargs): print("num tensors start is ...") out = func(*args, **kwargs) print("num tensors end is ...") return out return wrap@memleak_wrapper def function_to_debug(x): print(f"put line(s) of code here. Input is {x}") out = x + 10 return outout = function_to_debug(x=1000) print(f"out is {out}") #输入类似这样 #num tensors start is ... #put line(s) of code here. Input is 1000 #num tensors end is ... #outis 1010
Pour exécuter ce code, nous devons mettre la ligne de code que nous voulons archiver dans une fonction (function_to_debug). Mais ce n’est pas la meilleure solution car nous devons encore insérer beaucoup de code manuellement. L'autre chose est que si le bloc de code génère plus d'une variable, vous devez trouver des solutions supplémentaires pour utiliser ces variables en aval.
Afin de résoudre le problème ci-dessus, nous pouvons utiliser le gestionnaire de contexte au lieu du décorateur de fonction. L'exemple le plus largement utilisé de gestionnaire de contexte consiste à instancier un contexte à l'aide de l'instruction with. Le plus courant était :
with open("file") as f: …
En utilisant la bibliothèque contextlib de Python, les utilisateurs de Python peuvent facilement créer eux-mêmes des gestionnaires de contexte. Donc, dans cet article, nous utiliserons ContextDecorator pour terminer le travail que nous avons essayé d'utiliser le décorateur ci-dessus. Parce que c'est plus facile à développer et à utiliser :
from contextlib import ContextDecorator class check_memory_leak_context(ContextDecorator): def __enter__(self): print('Starting') return self def __exit__(self, *exc): print('Finishing') return False
ContextDecorator a 2 méthodes : enter() et exit() qui sont appelées lorsque l'on entre ou sort du contexte. Le paramètre *exc dans __exit__ représente toute exception entrante.
Utilisons-le maintenant pour résoudre le problème mentionné ci-dessus.
Parce que nous devons calculer le nombre total de tenseurs, nous encapsulons le processus de calcul dans une fonction get_n_tensors(), afin que le nombre de tenseurs puisse être calculé au début et à la fin du contexte :
class check_memory_leak_context(ContextDecorator): def __enter__(self): self.start = get_n_tensors() return self def __exit__(self, *exc): self.end = get_n_tensors() increase = self.end — self.start if increase > 0: print(f”num tensors increased with" f"{self.end — self.start} !”) else: print(”no added tensors”) return False
S'il y a des augmentations, imprimez-le sur la console.
get_n_tensor() utilise un garbage collector (gc) et est personnalisé pour pytorch, mais peut être facilement modifié pour d'autres bibliothèques :
import gc def get_n_tensors(): tensors= [] for obj in gc.get_objects(): try: if (torch.is_tensor(obj) or (hasattr(obj, ‘data’) and torch.is_tensor(obj.data))): tensors.append(obj) except: pass return len(tensors)
Cela fonctionne maintenant, nous l'utilisons pour n'importe quelle ligne (ou bloc) de code Contexte :
x = arbitrary_operation(x) ... with check_memory_leak_context(): y = x[0].permute(1, 2, 0).cpu().detach().numpy() x = some_harmless_operation() ... x = another_arbitrary_operation(x)
Si un nouveau tenseur est créé dans la ligne enveloppée par le décorateur de contexte, il sera imprimé.
C'est un très bon extrait de code, vous pouvez le mettre dans un fichier séparé lors du développement, voici le code complet de cet article :
https://gist.github.com/MarkTension /4783697ebd5212ba500cdd829b364338
Enfin, j'espère que ce petit article pourra vous aider à comprendre ce qu'est un gestionnaire de contexte, comment utiliser les décorateurs de contexte et comment les appliquer pour déboguer pytorch.
Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!