首页 后端开发 Python教程 详解PyTorch批训练及优化器比较

详解PyTorch批训练及优化器比较

Apr 28, 2018 am 09:46 AM
pytorch 优化 比较

本篇文章主要介绍了详解PyTorch批训练及优化器比较,详细的介绍了什么是PyTorch批训练和PyTorch的Optimizer优化器,非常具有实用价值,需要的朋友可以参考下

一、PyTorch批训练

1. 概述

PyTorch提供了一种将数据包装起来进行批训练的工具——DataLoader。使用的时候,只需要将我们的数据首先转换为torch的tensor形式,再转换成torch可以识别的Dataset格式,然后将Dataset放入DataLoader中就可以啦。

import torch 
import torch.utils.data as Data 
 
torch.manual_seed(1) # 设定随机数种子 
 
BATCH_SIZE = 5 
 
x = torch.linspace(1, 10, 10) 
y = torch.linspace(0.5, 5, 10) 
 
# 将数据转换为torch的dataset格式 
torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y) 
 
# 将torch_dataset置入Dataloader中 
loader = Data.DataLoader( 
  dataset=torch_dataset, 
  batch_size=BATCH_SIZE, # 批大小 
  # 若dataset中的样本数不能被batch_size整除的话,最后剩余多少就使用多少 
  shuffle=True, # 是否随机打乱顺序 
  num_workers=2, # 多线程读取数据的线程数 
  ) 
 
for epoch in range(3): 
  for step, (batch_x, batch_y) in enumerate(loader): 
    print('Epoch:', epoch, '|Step:', step, '|batch_x:', 
       batch_x.numpy(), '|batch_y', batch_y.numpy()) 
''''' 
shuffle=True 
Epoch: 0 |Step: 0 |batch_x: [ 6. 7. 2. 3. 1.] |batch_y [ 3.  3.5 1.  1.5 0.5] 
Epoch: 0 |Step: 1 |batch_x: [ 9. 10.  4.  8.  5.] |batch_y [ 4.5 5.  2.  4.  2.5] 
Epoch: 1 |Step: 0 |batch_x: [ 3.  4.  2.  9. 10.] |batch_y [ 1.5 2.  1.  4.5 5. ] 
Epoch: 1 |Step: 1 |batch_x: [ 1. 7. 8. 5. 6.] |batch_y [ 0.5 3.5 4.  2.5 3. ] 
Epoch: 2 |Step: 0 |batch_x: [ 3. 9. 2. 6. 7.] |batch_y [ 1.5 4.5 1.  3.  3.5] 
Epoch: 2 |Step: 1 |batch_x: [ 10.  4.  8.  1.  5.] |batch_y [ 5.  2.  4.  0.5 2.5] 
 
shuffle=False 
Epoch: 0 |Step: 0 |batch_x: [ 1. 2. 3. 4. 5.] |batch_y [ 0.5 1.  1.5 2.  2.5] 
Epoch: 0 |Step: 1 |batch_x: [ 6.  7.  8.  9. 10.] |batch_y [ 3.  3.5 4.  4.5 5. ] 
Epoch: 1 |Step: 0 |batch_x: [ 1. 2. 3. 4. 5.] |batch_y [ 0.5 1.  1.5 2.  2.5] 
Epoch: 1 |Step: 1 |batch_x: [ 6.  7.  8.  9. 10.] |batch_y [ 3.  3.5 4.  4.5 5. ] 
Epoch: 2 |Step: 0 |batch_x: [ 1. 2. 3. 4. 5.] |batch_y [ 0.5 1.  1.5 2.  2.5] 
Epoch: 2 |Step: 1 |batch_x: [ 6.  7.  8.  9. 10.] |batch_y [ 3.  3.5 4.  4.5 5. ] 
'''

2. TensorDataset

classtorch.utils.data.TensorDataset(data_tensor, target_tensor)

TensorDataset类用来将样本及其标签打包成torch的Dataset,data_tensor,和target_tensor都是tensor。

3. DataLoader


复制代码 代码如下:

classtorch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,num_workers=0, collate_fn=, pin_memory=False,drop_last=False)

dataset就是Torch的Dataset格式的对象;batch_size即每批训练的样本数量,默认为;shuffle表示是否需要随机取样本;num_workers表示读取样本的线程数。

二、PyTorch的Optimizer优化器

本实验中,首先构造一组数据集,转换格式并置于DataLoader中,备用。定义一个固定结构的默认神经网络,然后为每个优化器构建一个神经网络,每个神经网络的区别仅仅是优化器不同。通过记录训练过程中的loss值,最后在图像上呈现得到各个优化器的优化过程。

代码实现:

import torch 
import torch.utils.data as Data 
import torch.nn.functional as F 
from torch.autograd import Variable 
import matplotlib.pyplot as plt 
torch.manual_seed(1) # 设定随机数种子 
 
# 定义超参数 
LR = 0.01 # 学习率 
BATCH_SIZE = 32 # 批大小 
EPOCH = 12 # 迭代次数 
 
x = torch.unsqueeze(torch.linspace(-1, 1, 1000), dim=1) 
y = x.pow(2) + 0.1*torch.normal(torch.zeros(*x.size())) 
 
#plt.scatter(x.numpy(), y.numpy()) 
#plt.show() 
 
# 将数据转换为torch的dataset格式 
torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y) 
# 将torch_dataset置入Dataloader中 
loader = Data.DataLoader(dataset=torch_dataset, batch_size=BATCH_SIZE, 
             shuffle=True, num_workers=2) 
 
class Net(torch.nn.Module): 
  def __init__(self): 
    super(Net, self).__init__() 
    self.hidden = torch.nn.Linear(1, 20) 
    self.predict = torch.nn.Linear(20, 1) 
 
  def forward(self, x): 
    x = F.relu(self.hidden(x)) 
    x = self.predict(x) 
    return x 
 
# 为每个优化器创建一个Net 
net_SGD = Net() 
net_Momentum = Net() 
net_RMSprop = Net() 
net_Adam = Net()  
nets = [net_SGD, net_Momentum, net_RMSprop, net_Adam] 
 
# 初始化优化器 
opt_SGD = torch.optim.SGD(net_SGD.parameters(), lr=LR) 
opt_Momentum = torch.optim.SGD(net_Momentum.parameters(), lr=LR, momentum=0.8) 
opt_RMSprop = torch.optim.RMSprop(net_RMSprop.parameters(), lr=LR, alpha=0.9) 
opt_Adam = torch.optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9, 0.99)) 
 
optimizers = [opt_SGD, opt_Momentum, opt_RMSprop, opt_Adam] 
 
# 定义损失函数 
loss_function = torch.nn.MSELoss() 
losses_history = [[], [], [], []] # 记录training时不同神经网络的loss值 
 
for epoch in range(EPOCH): 
  print('Epoch:', epoch + 1, 'Training...') 
  for step, (batch_x, batch_y) in enumerate(loader): 
    b_x = Variable(batch_x) 
    b_y = Variable(batch_y) 
 
    for net, opt, l_his in zip(nets, optimizers, losses_history): 
      output = net(b_x) 
      loss = loss_function(output, b_y) 
      opt.zero_grad() 
      loss.backward() 
      opt.step() 
      l_his.append(loss.data[0]) 
 
labels = ['SGD', 'Momentum', 'RMSprop', 'Adam'] 
 
for i, l_his in enumerate(losses_history): 
  plt.plot(l_his, label=labels[i]) 
plt.legend(loc='best') 
plt.xlabel('Steps') 
plt.ylabel('Loss') 
plt.ylim((0, 0.2)) 
plt.show()

实验结果:

由实验结果可见,SGD的优化效果是最差的,速度很慢;作为SGD的改良版本,Momentum表现就好许多;相比RMSprop和Adam的优化速度就非常好。实验中,针对不同的优化问题,比较各个优化器的效果再来决定使用哪个。

三、其他补充

1. Python的zip函数

zip函数接受任意多个(包括0个和1个)序列作为参数,返回一个tuple列表。

x = [1, 2, 3] 
y = [4, 5, 6] 
z = [7, 8, 9] 
xyz = zip(x, y, z) 
print xyz 
[(1, 4, 7), (2, 5, 8), (3, 6, 9)] 
 
x = [1, 2, 3] 
x = zip(x) 
print x 
[(1,), (2,), (3,)] 
 
x = [1, 2, 3] 
y = [4, 5, 6, 7] 
xy = zip(x, y) 
print xy 
[(1, 4), (2, 5), (3, 6)]

相关推荐:

Pytorch入门之mnist分类实例

以上是详解PyTorch批训练及优化器比较的详细内容。更多信息请关注PHP中文网其他相关文章!

本站声明
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn

热AI工具

Undress AI Tool

Undress AI Tool

免费脱衣服图片

Undresser.AI Undress

Undresser.AI Undress

人工智能驱动的应用程序,用于创建逼真的裸体照片

AI Clothes Remover

AI Clothes Remover

用于从照片中去除衣服的在线人工智能工具。

Clothoff.io

Clothoff.io

AI脱衣机

Video Face Swap

Video Face Swap

使用我们完全免费的人工智能换脸工具轻松在任何视频中换脸!

热工具

记事本++7.3.1

记事本++7.3.1

好用且免费的代码编辑器

SublimeText3汉化版

SublimeText3汉化版

中文版,非常好用

禅工作室 13.0.1

禅工作室 13.0.1

功能强大的PHP集成开发环境

Dreamweaver CS6

Dreamweaver CS6

视觉化网页开发工具

SublimeText3 Mac版

SublimeText3 Mac版

神级代码编辑软件(SublimeText3)

热门话题

Laravel 教程
1604
29
PHP教程
1510
276
小米14 Pro怎么开启nfc功能? 小米14 Pro怎么开启nfc功能? Mar 19, 2024 pm 02:28 PM

如今手机的性能和功能越来越强大,几乎所有手机都配备了便捷的NFC功能,方便用户进行移动支付和身份认证。然而,有些小米14Pro的用户可能不清楚如何启用NFC功能。接下来,让我详细向大家介绍一下。小米14Pro怎么开启nfc功能?步骤一:打开手机的设置菜单。步骤二:找到并点击“连接和共享”或“无线和网络”选项。步骤三:在连接和共享或无线和网络菜单中,找到并点击“NFC和支付”。步骤四:找到并点击“NFC开关”。一般情况下,默认是关闭的状态。步骤五:在NFC开关页面上,点击开关按钮,将其切换为开启状

华为 Pocket2怎么隔空刷抖音? 华为 Pocket2怎么隔空刷抖音? Mar 18, 2024 pm 03:00 PM

隔空滑动屏幕是华为的一项功能,在华为mate60系列中可以说是备受好评,这个功能是通过利用手机上的激光感应器和前置摄像头的3D深感摄像头,来完成一系列不需要触碰屏幕的功能,比如说隔空刷抖音,但是华为Pocket2应该要怎么隔空刷抖音呢?华为Pocket2怎么隔空截图?1、打开华为Pocket2的设置2、然后选择【辅助功能】。3、点击打开【智慧感知】。4、打开【隔空滑动屏幕】、【隔空截屏】、【隔空按压】开关就可以了。5、在使用的时候,需要再距离屏幕20~40CM处,张开手掌,待屏幕上出现手掌图标,

C++ 程序优化:时间复杂度降低技巧 C++ 程序优化:时间复杂度降低技巧 Jun 01, 2024 am 11:19 AM

时间复杂度衡量算法执行时间与输入规模的关系。降低C++程序时间复杂度的技巧包括:选择合适的容器(如vector、list)以优化数据存储和管理。利用高效算法(如快速排序)以减少计算时间。消除多重运算以减少重复计算。利用条件分支以避免不必要的计算。通过使用更快的算法(如二分搜索)来优化线性搜索。

大模型中常用的注意力机制GQA详解以及Pytorch代码实现 大模型中常用的注意力机制GQA详解以及Pytorch代码实现 Apr 03, 2024 pm 05:40 PM

组查询注意力(GroupedQueryAttention)是大型语言模型中的一种多查询注意力力方法,它的目标是在保持MQA速度的同时实现MHA的质量。GroupedQueryAttention将查询分组,每个组内的查询共享相同的注意力权重,这有助于降低计算复杂度和提高推理速度。这篇文章中,我们将解释GQA的思想以及如何将其转化为代码。GQA是在论文GQA:TrainingGeneralizedMulti-QueryTransformerModelsfromMulti-HeadCheckpoint

WPS Word怎么设置行距让文档更工整 WPS Word怎么设置行距让文档更工整 Mar 20, 2024 pm 04:30 PM

WPS是我们常用的办公软件,在进行长篇文章的编辑时,经常会因为字体太小而看不清楚,所以会对字体和整个文档进行调整。例如:把文档进行行距的调整,会让整个文档变得非常清晰,我建议各位小伙伴们都要学会这个操作步骤,今天就分享给大家,具体的操作步骤如下,快来看一看!打开要调整的WPS文本文件,在【开始】菜单中找到段落设置工具栏,你会看到行距设置小图标(如图中红色线圈所示)。2、点击行距设置右下角的小倒三角形,会出现相应的行距数值,可以选择1~3倍行距(如图箭头所示)。3、或者点击鼠标右键点击段落,就会出

TrendX 研究院:Merlin Chain 项目分析及生态盘点 TrendX 研究院:Merlin Chain 项目分析及生态盘点 Mar 24, 2024 am 09:01 AM

根据3月2日数据统计,比特币二层网络MerlinChain总TVL已达30亿美元。其中比特币生态资产占比达90.83%,包括价值15.96亿美元的BTC以及4.04亿美元的BRC-20资产等。上一个月,MerlinChain在开启质押活动14天内,其TVL总额就已经达到了19.7亿美元,超过了去年11月份上线也是最近同样引人注目的Blast。2月26日,MerlinChain生态内的NFT总价值超过了4.2亿美元,成为除以太坊以外NFT市值最高的公链项目。项目简介MerlinChain是OKX支

天玑6020处理器究竟比骁龙处理器强多少 天玑6020处理器究竟比骁龙处理器强多少 Mar 18, 2024 pm 12:15 PM

天玑6020处理器和骁龙处理器一直是消费者们争论的焦点。两者都是市场上颇具竞争力的芯片,各有所长,各有适用的场景。究竟天玑6020处理器比骁龙处理器强多少?让我们来仔细比较一下它们的性能和特点。首先,从芯片制程上来看,天玑6020处理器采用了台积电6纳米制程技术,而骁龙处理器一般采用7纳米或更老的制程技术。在同样的工艺制程下,一般来说,制程越小,能耗越低,发

解决 PHP 函数效率低下的方法有哪些? 解决 PHP 函数效率低下的方法有哪些? May 02, 2024 pm 01:48 PM

PHP函数效率优化的五大方法:避免不必要的变量复制。使用引用以避免变量复制。避免重复函数调用。内联简单的函数。使用数组优化循环。

See all articles