在深度学习模型训练中,我们有时需要冻结网络中的某些层,即阻止这些层的参数在反向传播过程中被更新。这在多种场景下非常有用,例如:
然而,如何正确地冻结一个中间层,同时确保其前后层能够正常更新,是一个常见的疑问。本文将详细探讨两种常用的方法,并通过实验分析它们的实际效果。
torch.no_grad() 是PyTorch提供的一个上下文管理器,其作用是在其内部执行的代码块中,禁用梯度计算。这意味着,在该代码块中创建的任何张量都不会追踪其操作历史,也不会计算梯度。
考虑一个简单的三层线性网络:lin0 -> lin1 -> lin2。如果我们的目标是冻结 lin1,同时允许 lin0 和 lin2 更新,一个直观的想法是在 lin1 的前向传播中使用 torch.no_grad():
import torch import torch.nn as nn class SimpleModelNoGrad(nn.Module): def __init__(self): super(SimpleModelNoGrad, self).__init__() self.lin0 = nn.Linear(1, 2) self.lin1 = nn.Linear(2, 2) self.lin2 = nn.Linear(2, 10) def forward(self, x): x = self.lin0(x) # 在lin1的前向传播中使用no_grad with torch.no_grad(): x = self.lin1(x) x = self.lin2(x) return x # 实例化模型 model_nograd = SimpleModelNoGrad() # 记录初始参数 initial_lin0_weight = model_nograd.lin0.weight.clone() initial_lin1_weight = model_nograd.lin1.weight.clone() initial_lin2_weight = model_nograd.lin2.weight.clone() # 模拟训练步骤 optimizer = torch.optim.SGD(model_nograd.parameters(), lr=0.01) input_data = torch.randn(1, 1) target = torch.randint(0, 10, (1,)) loss_fn = nn.CrossEntropyLoss() print("--- 使用 torch.no_grad() 策略 ---") print("初始 lin0 权重:\n", initial_lin0_weight) print("初始 lin1 权重:\n", initial_lin1_weight) print("初始 lin2 权重:\n", initial_lin2_weight) # 进行一次前向传播、反向传播和优化 optimizer.zero_grad() output = model_nograd(input_data) loss = loss_fn(output, target) loss.backward() optimizer.step() # 检查参数变化 print("\n更新后 lin0 权重:\n", model_nograd.lin0.weight) print("更新后 lin1 权重:\n", model_nograd.lin1.weight) print("更新后 lin2 权重:\n", model_nograd.lin2.weight) print("\nlin0 权重是否改变:", not torch.equal(initial_lin0_weight, model_nograd.lin0.weight)) print("lin1 权重是否改变:", not torch.equal(initial_lin1_weight, model_nograd.lin1.weight)) print("lin2 权重是否改变:", not torch.equal(initial_lin2_weight, model_nograd.lin2.weight))
实验结果分析: 在上述实验中,你会发现 lin0、lin1 和 lin2 的参数都没有更新。这是因为 torch.no_grad() 不仅阻止了 lin1 内部的梯度计算,更重要的是,它切断了从 lin2 到 lin1 再到 lin0 的整个梯度回传路径。一旦某个张量(lin1 的输出)在 no_grad 块中生成,它就没有梯度历史,因此其上游的 lin0 也无法接收到梯度信号,从而导致所有相关参数都无法更新。
结论: torch.no_grad() 适用于完全禁用某个计算分支的梯度计算,例如在推理阶段或特征提取阶段。它不适用于需要精确冻结中间层同时允许其上游层更新的场景。
更精确地冻结特定层的方法是直接修改其参数的 requires_grad 属性。PyTorch中的每个张量都有一个 requires_grad 属性,默认为 True。如果将其设置为 False,PyTorch将不会为该张量计算梯度,并且在反向传播时,任何依赖于该张量的操作的梯度都不会传播到该张量。
为了冻结 lin1,我们需要在模型定义之后,但在优化器初始化之前,将其所有参数(权重和偏置)的 requires_grad 属性设置为 False。
import torch import torch.nn as nn class SimpleModelRequiresGrad(nn.Module): def __init__(self): super(SimpleModelRequiresGrad, self).__init__() self.lin0 = nn.Linear(1, 2) self.lin1 = nn.Linear(2, 2) self.lin2 = nn.Linear(2, 10) def forward(self, x): x = self.lin0(x) x = self.lin1(x) x = self.lin2(x) return x # 实例化模型 model_req_grad = SimpleModelRequiresGrad() # 在优化器定义之前,冻结lin1的参数 for param in model_req_grad.lin1.parameters(): param.requires_grad = False # 记录初始参数 initial_lin0_weight_rg = model_req_grad.lin0.weight.clone() initial_lin1_weight_rg = model_req_grad.lin1.weight.clone() initial_lin2_weight_rg = model_req_grad.lin2.weight.clone() # 只有requires_grad=True的参数才会被优化器考虑 optimizer_rg = torch.optim.SGD(filter(lambda p: p.requires_grad, model_req_grad.parameters()), lr=0.01) input_data_rg = torch.randn(1, 1) target_rg = torch.randint(0, 10, (1,)) loss_fn_rg = nn.CrossEntropyLoss() print("\n--- 使用 requires_grad=False 策略 ---") print("初始 lin0 权重:\n", initial_lin0_weight_rg) print("初始 lin1 权重:\n", initial_lin1_weight_rg) print("初始 lin2 权重:\n", initial_lin2_weight_rg) # 进行一次前向传播、反向传播和优化 optimizer_rg.zero_grad() output_rg = model_req_grad(input_data_rg) loss_rg = loss_fn_rg(output_rg, target_rg) loss_rg.backward() optimizer_rg.step() # 检查参数变化 print("\n更新后 lin0 权重:\n", model_req_grad.lin0.weight) print("更新后 lin1 权重:\n", model_req_grad.lin1.weight) print("更新后 lin2 权重:\n", model_req_grad.lin2.weight) print("\nlin0 权重是否改变:", not torch.equal(initial_lin0_weight_rg, model_req_grad.lin0.weight)) print("lin1 权重是否改变:", not torch.equal(initial_lin1_weight_rg, model_req_grad.lin1.weight)) print("lin2 权重是否改变:", not torch.equal(initial_lin2_weight_rg, model_req_grad.lin2.weight))
实验结果分析: 通过这种方法,你会发现 lin0 和 lin2 的参数得到了更新,而 lin1 的参数保持不变。这是因为 lin1 的 requires_grad 被设置为 False,其梯度不会被计算,也不会参与优化。但 lin2 的梯度会正常计算并回传到 lin1 的输入,由于 lin1 的参数不需要梯度,梯度会继续回传到 lin0,从而使得 lin0 也能正常更新。
关键注意事项:
特性/方法 | torch.no_grad() | param.requires_grad = False |
---|---|---|
作用范围 | 局部,作用于上下文管理器内的所有计算操作 | 全局,作用于特定参数本身 |
梯度回传 | 切断梯度回传路径,其上游和自身均无法更新 | 允许梯度通过,但不会为 requires_grad=False 的参数计算和存储梯度,其上游层可正常更新 |
适用场景 | 推理阶段、性能评估、特征提取等不需要梯度计算的场景 | 冻结特定层进行迁移学习、微调、或实验控制等需要精确控制参数更新的场景 |
推荐程度 | 不推荐用于精确冻结中间层并允许前后层更新的场景 | 强烈推荐用于精确冻结特定层的场景 |
综上所述,当您需要在PyTorch中冻结一个中间层,同时确保其前后层能够正常训练和更新时,设置目标层的参数 requires_grad=False 是最准确和推荐的方法。torch.no_grad() 更适用于完全禁用某个计算路径的梯度追踪,它会影响到整个计算链条,导致意外的冻结效果。理解这两种机制的差异,对于高效和准确地进行模型训练至关重要。
以上就是PyTorch中冻结中间层参数的策略与实践的详细内容,更多请关注php中文网其它相关文章!
每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。
Copyright 2014-2025 //m.sbmmt.com/ All Rights Reserved | php.cn | 湘ICP备2023035733号