关于pytorch中网络loss传播和参数更新的理解


Posted in Python onAugust 20, 2019

相比于2018年,在ICLR2019提交论文中,提及不同框架的论文数量发生了极大变化,网友发现,提及tensorflow的论文数量从2018年的228篇略微提升到了266篇,keras从42提升到56,但是pytorch的数量从87篇提升到了252篇。

TensorFlow: 228--->266

Keras: 42--->56

Pytorch: 87--->252

在使用pytorch中,自己有一些思考,如下:

1. loss计算和反向传播

import torch.nn as nn
 
criterion = nn.MSELoss().cuda()
 
output = model(input)
 
loss = criterion(output, target)
loss.backward()

通过定义损失函数:criterion,然后通过计算网络真实输出和真实标签之间的误差,得到网络的损失值:loss;

最后通过loss.backward()完成误差的反向传播,通过pytorch的内在机制完成自动求导得到每个参数的梯度。

需要注意,在机器学习或者深度学习中,我们需要通过修改参数使得损失函数最小化或最大化,一般是通过梯度进行网络模型的参数更新,通过loss的计算和误差反向传播,我们得到网络中,每个参数的梯度值,后面我们再通过优化算法进行网络参数优化更新。

2. 网络参数更新

在更新网络参数时,我们需要选择一种调整模型参数更新的策略,即优化算法。

优化算法中,简单的有一阶优化算法:

关于pytorch中网络loss传播和参数更新的理解

其中关于pytorch中网络loss传播和参数更新的理解 就是通常说的学习率,关于pytorch中网络loss传播和参数更新的理解 是函数的梯度;

自己的理解是,对于复杂的优化算法,基本原理也是这样的,不过计算更加复杂。

在pytorch中,torch.optim是一个实现各种优化算法的包,可以直接通过这个包进行调用。

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

注意:

1)在前面部分1中,已经通过loss的反向传播得到了每个参数的梯度,然后再本部分通过定义优化器(优化算法),确定了网络更新的方式,在上述代码中,我们将模型的需要更新的参数传入优化器。

2)注意优化器,即optimizer中,传入的模型更新的参数,对于网络中有多个模型的网络,我们可以选择需要更新的网络参数进行输入即可,上述代码,只会更新model中的模型参数。对于需要更新多个模型的参数的情况,可以参考以下代码:

optimizer = torch.optim.Adam([{'params': model.parameters()}, {'params': gru.parameters()}], lr=0.01) 3) 在优化前需要先将梯度归零,即optimizer.zeros()。

3. loss计算和参数更新

import torch.nn as nn
import torch
 
criterion = nn.MSELoss().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
 
output = model(input)
 
loss = criterion(output, target)
 
​optimizer.zero_grad() # 将所有参数的梯度都置零
loss.backward()    # 误差反向传播计算参数梯度
optimizer.step()    # 通过梯度做一步参数更新

以上这篇关于pytorch中网络loss传播和参数更新的理解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python写的创建文件夹自定义函数mkdir()
Aug 25 Python
python在windows下创建隐藏窗口子进程的方法
Jun 04 Python
python安装教程
Feb 28 Python
Python实现从log日志中提取ip的方法【正则提取】
Mar 31 Python
selenium设置proxy、headers的方法(phantomjs、Chrome、Firefox)
Nov 29 Python
python实现列表中最大最小值输出的示例
Jul 09 Python
Python 文件数据读写的具体实现
Jan 24 Python
django之导入并执行自定义的函数模块图解
Apr 01 Python
tensorflow 20:搭网络,导出模型,运行模型的实例
May 26 Python
keras 多gpu并行运行案例
Jun 10 Python
在Keras中利用np.random.shuffle()打乱数据集实例
Jun 15 Python
详解pandas中利用DataFrame对象的.loc[]、.iloc[]方法抽取数据
Dec 13 Python
对pytorch中的梯度更新方法详解
Aug 20 #Python
PyTorch: 梯度下降及反向传播的实例详解
Aug 20 #Python
python爬虫 urllib模块发起post请求过程解析
Aug 20 #Python
pytorch 加载(.pth)格式的模型实例
Aug 20 #Python
python multiprocessing模块用法及原理介绍
Aug 20 #Python
python 并发编程 阻塞IO模型原理解析
Aug 20 #Python
PyTorch中常用的激活函数的方法示例
Aug 20 #Python
You might like
PHP5 安装方法
2006/10/09 PHP
php中的MVC模式运用技巧
2007/05/03 PHP
PHP进程通信基础之信号
2017/02/19 PHP
phpstudy2018升级MySQL5.5为5.7教程(图文)
2018/10/24 PHP
PHP学习记录之常用的魔术常量详解
2019/12/12 PHP
jquery禁用右键示例
2014/04/28 Javascript
js+html5获取用户地理位置信息并在Google地图上显示的方法
2015/06/05 Javascript
AngularJS 2.0入门权威指南
2016/10/08 Javascript
实现隔行换色效果的两种方式【实用】
2016/11/27 Javascript
javascript数组去重常用方法实例分析
2017/04/11 Javascript
理解Angular的providers给Http添加默认headers
2017/07/04 Javascript
在 Angular 中使用Chart.js 和 ng2-charts的示例代码
2017/08/17 Javascript
详解Nuxt.js Vue服务端渲染摸索
2018/02/08 Javascript
Vue渲染过程浅析
2019/03/14 Javascript
微信小程序实现电子签名功能
2020/07/29 Javascript
[03:00]2014DOTA2国际邀请赛 Titan淘汰潸然泪下Ohaiyo专访
2014/07/15 DOTA
python3使用urllib示例取googletranslate(谷歌翻译)
2014/01/23 Python
使用优化器来提升Python程序的执行效率的教程
2015/04/02 Python
Django中的“惰性翻译”方法的相关使用
2015/07/27 Python
Python 文件管理实例详解
2015/11/10 Python
python3 与python2 异常处理的区别与联系
2016/06/19 Python
Python入门_浅谈数据结构的4种基本类型
2017/05/16 Python
浅谈flask源码之请求过程
2018/07/26 Python
django 消息框架 message使用详解
2019/07/22 Python
jupyter notebook 多环境conda kernel配置方式
2020/04/10 Python
Python控制台实现交互式环境执行
2020/06/09 Python
浅谈django不使用restframework自定义接口与使用的区别
2020/07/15 Python
python 获取字典特定值对应的键的实现
2020/09/29 Python
python代数式括号有效性检验示例代码
2020/10/04 Python
html5使用canvas画一条线
2014/12/15 HTML / CSS
Steve Madden官网:美国鞋类品牌
2017/01/29 全球购物
ghd法国官方网站:英国最受欢迎的美发工具品牌
2019/04/18 全球购物
台湾7-ELEVEN线上购物中心:7-11
2021/01/21 全球购物
办公室主任主任岗位责任制
2014/02/11 职场文书
开会通知
2015/04/20 职场文书
go语言中切片与内存复制 memcpy 的实现操作
2021/04/27 Golang