详解PyTorch批训练及优化器比较
本篇文章主要介绍了详解PyTorch批训练及优化器比较,详细的介绍了什么是PyTorch批训练和PyTorch的Optimizer优化器,非常具有实用价值,需要的朋友可以参考下
一、PyTorch批训练
1. 概述
PyTorch提供了一种将数据包装起来进行批训练的工具——DataLoader。使用的时候,只需要将我们的数据首先转换为torch的tensor形式,再转换成torch可以识别的Dataset格式,然后将Dataset放入DataLoader中就可以啦。
import torch import torch.utils.data as Data torch.manual_seed(1) # 设定随机数种子 BATCH_SIZE = 5 x = torch.linspace(1, 10, 10) y = torch.linspace(0.5, 5, 10) # 将数据转换为torch的dataset格式 torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y) # 将torch_dataset置入Dataloader中 loader = Data.DataLoader( dataset=torch_dataset, batch_size=BATCH_SIZE, # 批大小 # 若dataset中的样本数不能被batch_size整除的话,最后剩余多少就使用多少 shuffle=True, # 是否随机打乱顺序 num_workers=2, # 多线程读取数据的线程数 ) for epoch in range(3): for step, (batch_x, batch_y) in enumerate(loader): print('Epoch:', epoch, '|Step:', step, '|batch_x:', batch_x.numpy(), '|batch_y', batch_y.numpy()) ''''' shuffle=True Epoch: 0 |Step: 0 |batch_x: [ 6. 7. 2. 3. 1.] |batch_y [ 3. 3.5 1. 1.5 0.5] Epoch: 0 |Step: 1 |batch_x: [ 9. 10. 4. 8. 5.] |batch_y [ 4.5 5. 2. 4. 2.5] Epoch: 1 |Step: 0 |batch_x: [ 3. 4. 2. 9. 10.] |batch_y [ 1.5 2. 1. 4.5 5. ] Epoch: 1 |Step: 1 |batch_x: [ 1. 7. 8. 5. 6.] |batch_y [ 0.5 3.5 4. 2.5 3. ] Epoch: 2 |Step: 0 |batch_x: [ 3. 9. 2. 6. 7.] |batch_y [ 1.5 4.5 1. 3. 3.5] Epoch: 2 |Step: 1 |batch_x: [ 10. 4. 8. 1. 5.] |batch_y [ 5. 2. 4. 0.5 2.5] shuffle=False Epoch: 0 |Step: 0 |batch_x: [ 1. 2. 3. 4. 5.] |batch_y [ 0.5 1. 1.5 2. 2.5] Epoch: 0 |Step: 1 |batch_x: [ 6. 7. 8. 9. 10.] |batch_y [ 3. 3.5 4. 4.5 5. ] Epoch: 1 |Step: 0 |batch_x: [ 1. 2. 3. 4. 5.] |batch_y [ 0.5 1. 1.5 2. 2.5] Epoch: 1 |Step: 1 |batch_x: [ 6. 7. 8. 9. 10.] |batch_y [ 3. 3.5 4. 4.5 5. ] Epoch: 2 |Step: 0 |batch_x: [ 1. 2. 3. 4. 5.] |batch_y [ 0.5 1. 1.5 2. 2.5] Epoch: 2 |Step: 1 |batch_x: [ 6. 7. 8. 9. 10.] |batch_y [ 3. 3.5 4. 4.5 5. ] '''
2. TensorDataset
classtorch.utils.data.TensorDataset(data_tensor, target_tensor)
TensorDataset类用来将样本及其标签打包成torch的Dataset,data_tensor,和target_tensor都是tensor。
3. DataLoader
复制代码 代码如下:
classtorch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,num_workers=0, collate_fn=
dataset就是Torch的Dataset格式的对象;batch_size即每批训练的样本数量,默认为;shuffle表示是否需要随机取样本;num_workers表示读取样本的线程数。
二、PyTorch的Optimizer优化器
本实验中,首先构造一组数据集,转换格式并置于DataLoader中,备用。定义一个固定结构的默认神经网络,然后为每个优化器构建一个神经网络,每个神经网络的区别仅仅是优化器不同。通过记录训练过程中的loss值,最后在图像上呈现得到各个优化器的优化过程。
代码实现:
import torch import torch.utils.data as Data import torch.nn.functional as F from torch.autograd import Variable import matplotlib.pyplot as plt torch.manual_seed(1) # 设定随机数种子 # 定义超参数 LR = 0.01 # 学习率 BATCH_SIZE = 32 # 批大小 EPOCH = 12 # 迭代次数 x = torch.unsqueeze(torch.linspace(-1, 1, 1000), dim=1) y = x.pow(2) + 0.1*torch.normal(torch.zeros(*x.size())) #plt.scatter(x.numpy(), y.numpy()) #plt.show() # 将数据转换为torch的dataset格式 torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y) # 将torch_dataset置入Dataloader中 loader = Data.DataLoader(dataset=torch_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2) class Net(torch.nn.Module): def __init__(self): super(Net, self).__init__() self.hidden = torch.nn.Linear(1, 20) self.predict = torch.nn.Linear(20, 1) def forward(self, x): x = F.relu(self.hidden(x)) x = self.predict(x) return x # 为每个优化器创建一个Net net_SGD = Net() net_Momentum = Net() net_RMSprop = Net() net_Adam = Net() nets = [net_SGD, net_Momentum, net_RMSprop, net_Adam] # 初始化优化器 opt_SGD = torch.optim.SGD(net_SGD.parameters(), lr=LR) opt_Momentum = torch.optim.SGD(net_Momentum.parameters(), lr=LR, momentum=0.8) opt_RMSprop = torch.optim.RMSprop(net_RMSprop.parameters(), lr=LR, alpha=0.9) opt_Adam = torch.optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9, 0.99)) optimizers = [opt_SGD, opt_Momentum, opt_RMSprop, opt_Adam] # 定义损失函数 loss_function = torch.nn.MSELoss() losses_history = [[], [], [], []] # 记录training时不同神经网络的loss值 for epoch in range(EPOCH): print('Epoch:', epoch + 1, 'Training...') for step, (batch_x, batch_y) in enumerate(loader): b_x = Variable(batch_x) b_y = Variable(batch_y) for net, opt, l_his in zip(nets, optimizers, losses_history): output = net(b_x) loss = loss_function(output, b_y) opt.zero_grad() loss.backward() opt.step() l_his.append(loss.data[0]) labels = ['SGD', 'Momentum', 'RMSprop', 'Adam'] for i, l_his in enumerate(losses_history): plt.plot(l_his, label=labels[i]) plt.legend(loc='best') plt.xlabel('Steps') plt.ylabel('Loss') plt.ylim((0, 0.2)) plt.show()
实验结果:
由实验结果可见,SGD的优化效果是最差的,速度很慢;作为SGD的改良版本,Momentum表现就好许多;相比RMSprop和Adam的优化速度就非常好。实验中,针对不同的优化问题,比较各个优化器的效果再来决定使用哪个。
三、其他补充
1. Python的zip函数
zip函数接受任意多个(包括0个和1个)序列作为参数,返回一个tuple列表。
x = [1, 2, 3] y = [4, 5, 6] z = [7, 8, 9] xyz = zip(x, y, z) print xyz [(1, 4, 7), (2, 5, 8), (3, 6, 9)] x = [1, 2, 3] x = zip(x) print x [(1,), (2,), (3,)] x = [1, 2, 3] y = [4, 5, 6, 7] xy = zip(x, y) print xy [(1, 4), (2, 5), (3, 6)]
相关推荐:
以上是详解PyTorch批训练及优化器比较的详细内容。更多信息请关注PHP中文网其他相关文章!

热AI工具

Undress AI Tool
免费脱衣服图片

Undresser.AI Undress
人工智能驱动的应用程序,用于创建逼真的裸体照片

AI Clothes Remover
用于从照片中去除衣服的在线人工智能工具。

Clothoff.io
AI脱衣机

Video Face Swap
使用我们完全免费的人工智能换脸工具轻松在任何视频中换脸!

热门文章

热工具

记事本++7.3.1
好用且免费的代码编辑器

SublimeText3汉化版
中文版,非常好用

禅工作室 13.0.1
功能强大的PHP集成开发环境

Dreamweaver CS6
视觉化网页开发工具

SublimeText3 Mac版
神级代码编辑软件(SublimeText3)

如今手机的性能和功能越来越强大,几乎所有手机都配备了便捷的NFC功能,方便用户进行移动支付和身份认证。然而,有些小米14Pro的用户可能不清楚如何启用NFC功能。接下来,让我详细向大家介绍一下。小米14Pro怎么开启nfc功能?步骤一:打开手机的设置菜单。步骤二:找到并点击“连接和共享”或“无线和网络”选项。步骤三:在连接和共享或无线和网络菜单中,找到并点击“NFC和支付”。步骤四:找到并点击“NFC开关”。一般情况下,默认是关闭的状态。步骤五:在NFC开关页面上,点击开关按钮,将其切换为开启状

隔空滑动屏幕是华为的一项功能,在华为mate60系列中可以说是备受好评,这个功能是通过利用手机上的激光感应器和前置摄像头的3D深感摄像头,来完成一系列不需要触碰屏幕的功能,比如说隔空刷抖音,但是华为Pocket2应该要怎么隔空刷抖音呢?华为Pocket2怎么隔空截图?1、打开华为Pocket2的设置2、然后选择【辅助功能】。3、点击打开【智慧感知】。4、打开【隔空滑动屏幕】、【隔空截屏】、【隔空按压】开关就可以了。5、在使用的时候,需要再距离屏幕20~40CM处,张开手掌,待屏幕上出现手掌图标,

时间复杂度衡量算法执行时间与输入规模的关系。降低C++程序时间复杂度的技巧包括:选择合适的容器(如vector、list)以优化数据存储和管理。利用高效算法(如快速排序)以减少计算时间。消除多重运算以减少重复计算。利用条件分支以避免不必要的计算。通过使用更快的算法(如二分搜索)来优化线性搜索。

组查询注意力(GroupedQueryAttention)是大型语言模型中的一种多查询注意力力方法,它的目标是在保持MQA速度的同时实现MHA的质量。GroupedQueryAttention将查询分组,每个组内的查询共享相同的注意力权重,这有助于降低计算复杂度和提高推理速度。这篇文章中,我们将解释GQA的思想以及如何将其转化为代码。GQA是在论文GQA:TrainingGeneralizedMulti-QueryTransformerModelsfromMulti-HeadCheckpoint

WPS是我们常用的办公软件,在进行长篇文章的编辑时,经常会因为字体太小而看不清楚,所以会对字体和整个文档进行调整。例如:把文档进行行距的调整,会让整个文档变得非常清晰,我建议各位小伙伴们都要学会这个操作步骤,今天就分享给大家,具体的操作步骤如下,快来看一看!打开要调整的WPS文本文件,在【开始】菜单中找到段落设置工具栏,你会看到行距设置小图标(如图中红色线圈所示)。2、点击行距设置右下角的小倒三角形,会出现相应的行距数值,可以选择1~3倍行距(如图箭头所示)。3、或者点击鼠标右键点击段落,就会出

根据3月2日数据统计,比特币二层网络MerlinChain总TVL已达30亿美元。其中比特币生态资产占比达90.83%,包括价值15.96亿美元的BTC以及4.04亿美元的BRC-20资产等。上一个月,MerlinChain在开启质押活动14天内,其TVL总额就已经达到了19.7亿美元,超过了去年11月份上线也是最近同样引人注目的Blast。2月26日,MerlinChain生态内的NFT总价值超过了4.2亿美元,成为除以太坊以外NFT市值最高的公链项目。项目简介MerlinChain是OKX支

天玑6020处理器和骁龙处理器一直是消费者们争论的焦点。两者都是市场上颇具竞争力的芯片,各有所长,各有适用的场景。究竟天玑6020处理器比骁龙处理器强多少?让我们来仔细比较一下它们的性能和特点。首先,从芯片制程上来看,天玑6020处理器采用了台积电6纳米制程技术,而骁龙处理器一般采用7纳米或更老的制程技术。在同样的工艺制程下,一般来说,制程越小,能耗越低,发

PHP函数效率优化的五大方法:避免不必要的变量复制。使用引用以避免变量复制。避免重复函数调用。内联简单的函数。使用数组优化循环。
