详解PyTorch批训练及优化器比较


Posted in Python onApril 28, 2018

一、PyTorch批训练

1. 概述

PyTorch提供了一种将数据包装起来进行批训练的工具——DataLoader。使用的时候,只需要将我们的数据首先转换为torch的tensor形式,再转换成torch可以识别的Dataset格式,然后将Dataset放入DataLoader中就可以啦。

import torch 
import torch.utils.data as Data 
 
torch.manual_seed(1) # 设定随机数种子 
 
BATCH_SIZE = 5 
 
x = torch.linspace(1, 10, 10) 
y = torch.linspace(0.5, 5, 10) 
 
# 将数据转换为torch的dataset格式 
torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y) 
 
# 将torch_dataset置入Dataloader中 
loader = Data.DataLoader( 
  dataset=torch_dataset, 
  batch_size=BATCH_SIZE, # 批大小 
  # 若dataset中的样本数不能被batch_size整除的话,最后剩余多少就使用多少 
  shuffle=True, # 是否随机打乱顺序 
  num_workers=2, # 多线程读取数据的线程数 
  ) 
 
for epoch in range(3): 
  for step, (batch_x, batch_y) in enumerate(loader): 
    print('Epoch:', epoch, '|Step:', step, '|batch_x:', 
       batch_x.numpy(), '|batch_y', batch_y.numpy()) 
''''' 
shuffle=True 
Epoch: 0 |Step: 0 |batch_x: [ 6. 7. 2. 3. 1.] |batch_y [ 3.  3.5 1.  1.5 0.5] 
Epoch: 0 |Step: 1 |batch_x: [ 9. 10.  4.  8.  5.] |batch_y [ 4.5 5.  2.  4.  2.5] 
Epoch: 1 |Step: 0 |batch_x: [ 3.  4.  2.  9. 10.] |batch_y [ 1.5 2.  1.  4.5 5. ] 
Epoch: 1 |Step: 1 |batch_x: [ 1. 7. 8. 5. 6.] |batch_y [ 0.5 3.5 4.  2.5 3. ] 
Epoch: 2 |Step: 0 |batch_x: [ 3. 9. 2. 6. 7.] |batch_y [ 1.5 4.5 1.  3.  3.5] 
Epoch: 2 |Step: 1 |batch_x: [ 10.  4.  8.  1.  5.] |batch_y [ 5.  2.  4.  0.5 2.5] 
 
shuffle=False 
Epoch: 0 |Step: 0 |batch_x: [ 1. 2. 3. 4. 5.] |batch_y [ 0.5 1.  1.5 2.  2.5] 
Epoch: 0 |Step: 1 |batch_x: [ 6.  7.  8.  9. 10.] |batch_y [ 3.  3.5 4.  4.5 5. ] 
Epoch: 1 |Step: 0 |batch_x: [ 1. 2. 3. 4. 5.] |batch_y [ 0.5 1.  1.5 2.  2.5] 
Epoch: 1 |Step: 1 |batch_x: [ 6.  7.  8.  9. 10.] |batch_y [ 3.  3.5 4.  4.5 5. ] 
Epoch: 2 |Step: 0 |batch_x: [ 1. 2. 3. 4. 5.] |batch_y [ 0.5 1.  1.5 2.  2.5] 
Epoch: 2 |Step: 1 |batch_x: [ 6.  7.  8.  9. 10.] |batch_y [ 3.  3.5 4.  4.5 5. ] 
'''

2. TensorDataset

classtorch.utils.data.TensorDataset(data_tensor, target_tensor)

TensorDataset类用来将样本及其标签打包成torch的Dataset,data_tensor,和target_tensor都是tensor。

3. DataLoader

classtorch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,num_workers=0, collate_fn=<function default_collate>, pin_memory=False,drop_last=False)

dataset就是Torch的Dataset格式的对象;batch_size即每批训练的样本数量,默认为;shuffle表示是否需要随机取样本;num_workers表示读取样本的线程数。

二、PyTorch的Optimizer优化器

本实验中,首先构造一组数据集,转换格式并置于DataLoader中,备用。定义一个固定结构的默认神经网络,然后为每个优化器构建一个神经网络,每个神经网络的区别仅仅是优化器不同。通过记录训练过程中的loss值,最后在图像上呈现得到各个优化器的优化过程。

代码实现:

import torch 
import torch.utils.data as Data 
import torch.nn.functional as F 
from torch.autograd import Variable 
import matplotlib.pyplot as plt 
torch.manual_seed(1) # 设定随机数种子 
 
# 定义超参数 
LR = 0.01 # 学习率 
BATCH_SIZE = 32 # 批大小 
EPOCH = 12 # 迭代次数 
 
x = torch.unsqueeze(torch.linspace(-1, 1, 1000), dim=1) 
y = x.pow(2) + 0.1*torch.normal(torch.zeros(*x.size())) 
 
#plt.scatter(x.numpy(), y.numpy()) 
#plt.show() 
 
# 将数据转换为torch的dataset格式 
torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y) 
# 将torch_dataset置入Dataloader中 
loader = Data.DataLoader(dataset=torch_dataset, batch_size=BATCH_SIZE, 
             shuffle=True, num_workers=2) 
 
class Net(torch.nn.Module): 
  def __init__(self): 
    super(Net, self).__init__() 
    self.hidden = torch.nn.Linear(1, 20) 
    self.predict = torch.nn.Linear(20, 1) 
 
  def forward(self, x): 
    x = F.relu(self.hidden(x)) 
    x = self.predict(x) 
    return x 
 
# 为每个优化器创建一个Net 
net_SGD = Net() 
net_Momentum = Net() 
net_RMSprop = Net() 
net_Adam = Net()  
nets = [net_SGD, net_Momentum, net_RMSprop, net_Adam] 
 
# 初始化优化器 
opt_SGD = torch.optim.SGD(net_SGD.parameters(), lr=LR) 
opt_Momentum = torch.optim.SGD(net_Momentum.parameters(), lr=LR, momentum=0.8) 
opt_RMSprop = torch.optim.RMSprop(net_RMSprop.parameters(), lr=LR, alpha=0.9) 
opt_Adam = torch.optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9, 0.99)) 
 
optimizers = [opt_SGD, opt_Momentum, opt_RMSprop, opt_Adam] 
 
# 定义损失函数 
loss_function = torch.nn.MSELoss() 
losses_history = [[], [], [], []] # 记录training时不同神经网络的loss值 
 
for epoch in range(EPOCH): 
  print('Epoch:', epoch + 1, 'Training...') 
  for step, (batch_x, batch_y) in enumerate(loader): 
    b_x = Variable(batch_x) 
    b_y = Variable(batch_y) 
 
    for net, opt, l_his in zip(nets, optimizers, losses_history): 
      output = net(b_x) 
      loss = loss_function(output, b_y) 
      opt.zero_grad() 
      loss.backward() 
      opt.step() 
      l_his.append(loss.data[0]) 
 
labels = ['SGD', 'Momentum', 'RMSprop', 'Adam'] 
 
for i, l_his in enumerate(losses_history): 
  plt.plot(l_his, label=labels[i]) 
plt.legend(loc='best') 
plt.xlabel('Steps') 
plt.ylabel('Loss') 
plt.ylim((0, 0.2)) 
plt.show()

实验结果:

详解PyTorch批训练及优化器比较

由实验结果可见,SGD的优化效果是最差的,速度很慢;作为SGD的改良版本,Momentum表现就好许多;相比RMSprop和Adam的优化速度就非常好。实验中,针对不同的优化问题,比较各个优化器的效果再来决定使用哪个。

三、其他补充

1. Python的zip函数

zip函数接受任意多个(包括0个和1个)序列作为参数,返回一个tuple列表。

x = [1, 2, 3] 
y = [4, 5, 6] 
z = [7, 8, 9] 
xyz = zip(x, y, z) 
print xyz 
[(1, 4, 7), (2, 5, 8), (3, 6, 9)] 
 
x = [1, 2, 3] 
x = zip(x) 
print x 
[(1,), (2,), (3,)] 
 
x = [1, 2, 3] 
y = [4, 5, 6, 7] 
xy = zip(x, y) 
print xy 
[(1, 4), (2, 5), (3, 6)]

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python语言使用技巧分享
May 31 Python
PyQt5 pyqt多线程操作入门
May 05 Python
Python获取系统所有进程PID及进程名称的方法示例
May 24 Python
Python3.5迭代器与生成器用法实例分析
Apr 30 Python
Flask-WTF表单的使用方法
Jul 12 Python
python scipy卷积运算的实现方法
Sep 16 Python
Python 生成器,迭代,yield关键字,send()传参给yield语句操作示例
Oct 12 Python
Python定时器线程池原理详解
Feb 26 Python
python读取配置文件方式(ini、yaml、xml)
Apr 09 Python
在keras下实现多个模型的融合方式
May 23 Python
浅谈Python中的继承
Jun 19 Python
如何基于Python按行合并两个txt
Nov 03 Python
Python使用matplotlib实现的图像读取、切割裁剪功能示例
Apr 28 #Python
浅谈python日志的配置文件路径问题
Apr 28 #Python
PyTorch上实现卷积神经网络CNN的方法
Apr 28 #Python
python 日志增量抓取实现方法
Apr 28 #Python
Django 使用logging打印日志的实例
Apr 28 #Python
python实现log日志的示例代码
Apr 28 #Python
Python学习笔记之open()函数打开文件路径报错问题
Apr 28 #Python
You might like
PHP调用Linux的命令行执行文件压缩命令
2013/01/27 PHP
关于更改Zend Studio/Eclipse代码风格主题的介绍
2013/06/23 PHP
php获取服务器端mac和客户端mac的地址支持WIN/LINUX
2014/05/15 PHP
PHP5.6读写excel表格文件操作示例
2019/02/26 PHP
javascript 读取XML数据,在页面中展现、编辑、保存的实现
2009/10/27 Javascript
document.body.scrollTop 值总为0的解决方法 比较常见的标准问题
2009/11/30 Javascript
使用Jquery搭建最佳用户体验的登录页面之记住密码自动登录功能(含后台代码)
2011/07/10 Javascript
IE与Firefox在JavaScript上的7个不同句法分享
2011/10/30 Javascript
解析javascript 数组以及json元素的添加删除
2013/06/26 Javascript
当滚动条滚动到页面底部自动加载增加内容的js代码
2014/05/13 Javascript
css如何让浮动元素水平居中
2015/08/07 Javascript
浅析JavaScript 箭头函数 generator Date JSON
2016/05/23 Javascript
详解Node.js模块间共享数据库连接的方法
2016/05/24 Javascript
JS Canvas定时器模拟动态加载动画
2016/09/17 Javascript
js中的eval()函数把含有转义字符的字符串转换成Object对象的方法
2016/12/02 Javascript
JavaScript组成、引入、输出、运算符基础知识讲解
2016/12/08 Javascript
浅析bootstrap原理及优缺点
2017/03/19 Javascript
Vue表单验证插件Vue Validator使用方法详解
2017/04/07 Javascript
vue.js框架实现表单排序和分页效果
2017/08/09 Javascript
Windows下Node.js安装及环境配置方法
2017/09/18 Javascript
详解vuex之store拆分即多模块状态管理(modules)篇
2018/11/13 Javascript
基于Vue实现微前端的示例代码
2020/04/24 Javascript
ant-design-vue 时间选择器赋值默认时间的操作
2020/10/27 Javascript
[01:01]青春无憾,一战成名——DOTA2全国高校联赛开启
2018/02/25 DOTA
跟老齐学Python之??碌某?? target=
2014/09/12 Python
python dataframe向下向上填充,fillna和ffill的方法
2018/11/28 Python
Python3实现计算两个数组的交集算法示例
2019/04/03 Python
django框架防止XSS注入的方法分析
2019/06/21 Python
Python requests及aiohttp速度对比代码实例
2020/07/16 Python
携程英文网站:Trip.com
2017/02/07 全球购物
学生会宣传部部长竞选演讲稿
2014/04/25 职场文书
三八节标语
2014/06/27 职场文书
乡领导班子四风问题对照检查材料
2014/09/25 职场文书
开展党的群众路线教育实践活动剖析材料
2014/10/13 职场文书
四风问题自查自纠工作情况报告
2014/10/28 职场文书
黄山导游词
2015/01/31 职场文书