PyTorch中冻结中间层参数的策略与实践

聖光之護
发布: 2025-08-22 15:42:27
原创
833人浏览过

PyTorch中冻结中间层参数的策略与实践

本文深入探讨了在PyTorch神经网络中冻结特定中间层参数的两种主要方法:使用torch.no_grad()上下文管理器和设置参数的requires_grad=False属性。通过实验对比,我们揭示了这两种方法在梯度回传机制上的关键差异,并明确指出在需要精确冻结特定层而允许其他层更新的场景下,应优先采用requires_grad=False策略,以实现灵活高效的模型训练。

导言:理解层冻结的需求

在深度学习模型训练中,我们有时需要冻结网络中的某些层,即阻止这些层的参数在反向传播过程中被更新。这在多种场景下非常有用,例如:

  • 迁移学习(Transfer Learning):使用预训练模型作为特征提取器,只微调顶层分类器。
  • 模型稳定性:在训练的某些阶段,固定部分层以稳定训练过程。
  • 实验控制:隔离特定层的影响,以便更好地理解模型行为。

然而,如何正确地冻结一个中间层,同时确保其前后层能够正常更新,是一个常见的疑问。本文将详细探讨两种常用的方法,并通过实验分析它们的实际效果。

方法一:使用 torch.no_grad() 上下文管理器

torch.no_grad() 是PyTorch提供的一个上下文管理器,其作用是在其内部执行的代码块中,禁用梯度计算。这意味着,在该代码块中创建的任何张量都不会追踪其操作历史,也不会计算梯度。

考虑一个简单的三层线性网络:lin0 -> lin1 -> lin2。如果我们的目标是冻结 lin1,同时允许 lin0 和 lin2 更新,一个直观的想法是在 lin1 的前向传播中使用 torch.no_grad():

import torch
import torch.nn as nn

class SimpleModelNoGrad(nn.Module):
    def __init__(self):
        super(SimpleModelNoGrad, self).__init__()
        self.lin0 = nn.Linear(1, 2)
        self.lin1 = nn.Linear(2, 2)
        self.lin2 = nn.Linear(2, 10)

    def forward(self, x):
        x = self.lin0(x)
        # 在lin1的前向传播中使用no_grad
        with torch.no_grad():
            x = self.lin1(x)
        x = self.lin2(x)
        return x

# 实例化模型
model_nograd = SimpleModelNoGrad()

# 记录初始参数
initial_lin0_weight = model_nograd.lin0.weight.clone()
initial_lin1_weight = model_nograd.lin1.weight.clone()
initial_lin2_weight = model_nograd.lin2.weight.clone()

# 模拟训练步骤
optimizer = torch.optim.SGD(model_nograd.parameters(), lr=0.01)
input_data = torch.randn(1, 1)
target = torch.randint(0, 10, (1,))
loss_fn = nn.CrossEntropyLoss()

print("--- 使用 torch.no_grad() 策略 ---")
print("初始 lin0 权重:\n", initial_lin0_weight)
print("初始 lin1 权重:\n", initial_lin1_weight)
print("初始 lin2 权重:\n", initial_lin2_weight)

# 进行一次前向传播、反向传播和优化
optimizer.zero_grad()
output = model_nograd(input_data)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()

# 检查参数变化
print("\n更新后 lin0 权重:\n", model_nograd.lin0.weight)
print("更新后 lin1 权重:\n", model_nograd.lin1.weight)
print("更新后 lin2 权重:\n", model_nograd.lin2.weight)

print("\nlin0 权重是否改变:", not torch.equal(initial_lin0_weight, model_nograd.lin0.weight))
print("lin1 权重是否改变:", not torch.equal(initial_lin1_weight, model_nograd.lin1.weight))
print("lin2 权重是否改变:", not torch.equal(initial_lin2_weight, model_nograd.lin2.weight))
登录后复制

实验结果分析: 在上述实验中,你会发现 lin0、lin1 和 lin2 的参数都没有更新。这是因为 torch.no_grad() 不仅阻止了 lin1 内部的梯度计算,更重要的是,它切断了从 lin2 到 lin1 再到 lin0 的整个梯度回传路径。一旦某个张量(lin1 的输出)在 no_grad 块中生成,它就没有梯度历史,因此其上游的 lin0 也无法接收到梯度信号,从而导致所有相关参数都无法更新。

结论: torch.no_grad() 适用于完全禁用某个计算分支的梯度计算,例如在推理阶段或特征提取阶段。它不适用于需要精确冻结中间层同时允许其上游层更新的场景。

方法二:设置参数的 requires_grad=False 属性

更精确地冻结特定层的方法是直接修改其参数的 requires_grad 属性。PyTorch中的每个张量都有一个 requires_grad 属性,默认为 True。如果将其设置为 False,PyTorch将不会为该张量计算梯度,并且在反向传播时,任何依赖于该张量的操作的梯度都不会传播到该张量。

为了冻结 lin1,我们需要在模型定义之后,但在优化器初始化之前,将其所有参数(权重和偏置)的 requires_grad 属性设置为 False。

import torch
import torch.nn as nn

class SimpleModelRequiresGrad(nn.Module):
    def __init__(self):
        super(SimpleModelRequiresGrad, self).__init__()
        self.lin0 = nn.Linear(1, 2)
        self.lin1 = nn.Linear(2, 2)
        self.lin2 = nn.Linear(2, 10)

    def forward(self, x):
        x = self.lin0(x)
        x = self.lin1(x)
        x = self.lin2(x)
        return x

# 实例化模型
model_req_grad = SimpleModelRequiresGrad()

# 在优化器定义之前,冻结lin1的参数
for param in model_req_grad.lin1.parameters():
    param.requires_grad = False

# 记录初始参数
initial_lin0_weight_rg = model_req_grad.lin0.weight.clone()
initial_lin1_weight_rg = model_req_grad.lin1.weight.clone()
initial_lin2_weight_rg = model_req_grad.lin2.weight.clone()

# 只有requires_grad=True的参数才会被优化器考虑
optimizer_rg = torch.optim.SGD(filter(lambda p: p.requires_grad, model_req_grad.parameters()), lr=0.01)
input_data_rg = torch.randn(1, 1)
target_rg = torch.randint(0, 10, (1,))
loss_fn_rg = nn.CrossEntropyLoss()

print("\n--- 使用 requires_grad=False 策略 ---")
print("初始 lin0 权重:\n", initial_lin0_weight_rg)
print("初始 lin1 权重:\n", initial_lin1_weight_rg)
print("初始 lin2 权重:\n", initial_lin2_weight_rg)

# 进行一次前向传播、反向传播和优化
optimizer_rg.zero_grad()
output_rg = model_req_grad(input_data_rg)
loss_rg = loss_fn_rg(output_rg, target_rg)
loss_rg.backward()
optimizer_rg.step()

# 检查参数变化
print("\n更新后 lin0 权重:\n", model_req_grad.lin0.weight)
print("更新后 lin1 权重:\n", model_req_grad.lin1.weight)
print("更新后 lin2 权重:\n", model_req_grad.lin2.weight)

print("\nlin0 权重是否改变:", not torch.equal(initial_lin0_weight_rg, model_req_grad.lin0.weight))
print("lin1 权重是否改变:", not torch.equal(initial_lin1_weight_rg, model_req_grad.lin1.weight))
print("lin2 权重是否改变:", not torch.equal(initial_lin2_weight_rg, model_req_grad.lin2.weight))
登录后复制

实验结果分析: 通过这种方法,你会发现 lin0 和 lin2 的参数得到了更新,而 lin1 的参数保持不变。这是因为 lin1 的 requires_grad 被设置为 False,其梯度不会被计算,也不会参与优化。但 lin2 的梯度会正常计算并回传到 lin1 的输入,由于 lin1 的参数不需要梯度,梯度会继续回传到 lin0,从而使得 lin0 也能正常更新。

关键注意事项:

  • 优化器参数过滤:在创建优化器时,务必只传入 requires_grad=True 的参数。optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.01) 是一种常见且推荐的做法。如果直接传入 model.parameters(),优化器会尝试为所有参数分配内存,即使它们不会更新,这可能导致不必要的资源消耗,虽然最终它们不会被更新。
  • 批量操作:对于包含多个子模块的复杂模型,可以通过循环遍历子模块或使用 named_parameters() 来批量设置 requires_grad。

总结与最佳实践

特性/方法 torch.no_grad() param.requires_grad = False
作用范围 局部,作用于上下文管理器内的所有计算操作 全局,作用于特定参数本身
梯度回传 切断梯度回传路径,其上游和自身均无法更新 允许梯度通过,但不会为 requires_grad=False 的参数计算和存储梯度,其上游层可正常更新
适用场景 推理阶段、性能评估、特征提取等不需要梯度计算的场景 冻结特定层进行迁移学习、微调、或实验控制等需要精确控制参数更新的场景
推荐程度 不推荐用于精确冻结中间层并允许前后层更新的场景 强烈推荐用于精确冻结特定层的场景

综上所述,当您需要在PyTorch中冻结一个中间层,同时确保其前后层能够正常训练和更新时,设置目标层的参数 requires_grad=False 是最准确和推荐的方法。torch.no_grad() 更适用于完全禁用某个计算路径的梯度追踪,它会影响到整个计算链条,导致意外的冻结效果。理解这两种机制的差异,对于高效和准确地进行模型训练至关重要。

以上就是PyTorch中冻结中间层参数的策略与实践的详细内容,更多请关注php中文网其它相关文章!

最佳 Windows 性能的顶级免费优化软件
最佳 Windows 性能的顶级免费优化软件

每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。

下载
相关标签:
来源:php中文网
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn
最新问题
开源免费商场系统广告
热门教程
更多>
最新下载
更多>
网站特效
网站源码
网站素材
前端模板
关于我们 免责申明 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送
PHP中文网APP
随时随地碎片化学习
PHP中文网抖音号
发现有趣的

Copyright 2014-2025 //m.sbmmt.com/ All Rights Reserved | php.cn | 湘ICP备2023035733号