关于pytorch中网络loss传播和参数更新的理解


Posted in Python onAugust 20, 2019

相比于2018年,在ICLR2019提交论文中,提及不同框架的论文数量发生了极大变化,网友发现,提及tensorflow的论文数量从2018年的228篇略微提升到了266篇,keras从42提升到56,但是pytorch的数量从87篇提升到了252篇。

TensorFlow: 228--->266

Keras: 42--->56

Pytorch: 87--->252

在使用pytorch中,自己有一些思考,如下:

1. loss计算和反向传播

import torch.nn as nn
 
criterion = nn.MSELoss().cuda()
 
output = model(input)
 
loss = criterion(output, target)
loss.backward()

通过定义损失函数:criterion,然后通过计算网络真实输出和真实标签之间的误差,得到网络的损失值:loss;

最后通过loss.backward()完成误差的反向传播,通过pytorch的内在机制完成自动求导得到每个参数的梯度。

需要注意,在机器学习或者深度学习中,我们需要通过修改参数使得损失函数最小化或最大化,一般是通过梯度进行网络模型的参数更新,通过loss的计算和误差反向传播,我们得到网络中,每个参数的梯度值,后面我们再通过优化算法进行网络参数优化更新。

2. 网络参数更新

在更新网络参数时,我们需要选择一种调整模型参数更新的策略,即优化算法。

优化算法中,简单的有一阶优化算法:

关于pytorch中网络loss传播和参数更新的理解

其中关于pytorch中网络loss传播和参数更新的理解 就是通常说的学习率,关于pytorch中网络loss传播和参数更新的理解 是函数的梯度;

自己的理解是,对于复杂的优化算法,基本原理也是这样的,不过计算更加复杂。

在pytorch中,torch.optim是一个实现各种优化算法的包,可以直接通过这个包进行调用。

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

注意:

1)在前面部分1中,已经通过loss的反向传播得到了每个参数的梯度,然后再本部分通过定义优化器(优化算法),确定了网络更新的方式,在上述代码中,我们将模型的需要更新的参数传入优化器。

2)注意优化器,即optimizer中,传入的模型更新的参数,对于网络中有多个模型的网络,我们可以选择需要更新的网络参数进行输入即可,上述代码,只会更新model中的模型参数。对于需要更新多个模型的参数的情况,可以参考以下代码:

optimizer = torch.optim.Adam([{'params': model.parameters()}, {'params': gru.parameters()}], lr=0.01) 3) 在优化前需要先将梯度归零,即optimizer.zeros()。

3. loss计算和参数更新

import torch.nn as nn
import torch
 
criterion = nn.MSELoss().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
 
output = model(input)
 
loss = criterion(output, target)
 
​optimizer.zero_grad() # 将所有参数的梯度都置零
loss.backward()    # 误差反向传播计算参数梯度
optimizer.step()    # 通过梯度做一步参数更新

以上这篇关于pytorch中网络loss传播和参数更新的理解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现生成简单的Makefile文件代码示例
Mar 10 Python
Python中operator模块的操作符使用示例总结
Jun 28 Python
Python编程实现正则删除命令功能
Aug 30 Python
将tensorflow的ckpt模型存储为npy的实例
Jul 09 Python
Django  ORM 练习题及答案
Jul 19 Python
python控制台实现tab补全和清屏的例子
Aug 20 Python
python实现飞机大战项目
Mar 11 Python
详解PyQt5信号与槽的几种高级玩法
Mar 24 Python
如何导出python安装的所有模块名称和版本号到文件中
Jun 05 Python
15款Python编辑器的优缺点,别再问我“选什么编辑器”啦
Oct 19 Python
matplotlib之属性组合包(cycler)的使用
Feb 24 Python
Python中快速掌握Data Frame的常用操作
Mar 31 Python
对pytorch中的梯度更新方法详解
Aug 20 #Python
PyTorch: 梯度下降及反向传播的实例详解
Aug 20 #Python
python爬虫 urllib模块发起post请求过程解析
Aug 20 #Python
pytorch 加载(.pth)格式的模型实例
Aug 20 #Python
python multiprocessing模块用法及原理介绍
Aug 20 #Python
python 并发编程 阻塞IO模型原理解析
Aug 20 #Python
PyTorch中常用的激活函数的方法示例
Aug 20 #Python
You might like
2021年最新CPU天梯图
2021/03/04 数码科技
PHP读MYSQL中文乱码的解决方法
2006/12/17 PHP
简单介绍下 PHP5 中引入的 MYSQLI的用途
2007/03/19 PHP
实现PHP框架系列文章(6)mysql数据库方法
2016/03/04 PHP
类似GMAIL的Ajax信息反馈显示
2010/02/16 Javascript
面向对象的Javascript之二(接口实现介绍)
2012/01/27 Javascript
Javascript WebSocket使用实例介绍(简明入门教程)
2014/04/16 Javascript
js加减乘除丢失精度问题解决方法
2014/05/16 Javascript
jQuery实现带有洗牌效果的动画分页实例
2015/08/31 Javascript
学习JavaScript设计模式之观察者模式
2020/04/22 Javascript
Mvc提交表单的四种方法全程详解
2016/08/10 Javascript
jquery mobile实现可折叠的导航按钮
2017/03/11 Javascript
详解nodejs微信公众号开发——1.接入微信公众号
2017/04/10 NodeJs
d3.js实现自定义多y轴折线图的示例代码
2018/05/30 Javascript
使用Vue实现图片上传的三种方式
2018/07/17 Javascript
Vue使用mixin分发组件的可复用功能
2019/09/01 Javascript
微信小程序仿淘宝热搜词在搜索框中轮播功能
2020/01/21 Javascript
vite2.0+vue3移动端项目实战详解
2021/03/03 Vue.js
python实用代码片段收集贴
2015/06/03 Python
使用Python进行AES加密和解密的示例代码
2018/02/02 Python
一些Centos Python 生产环境的部署命令(推荐)
2018/05/07 Python
python 把列表转化为字符串的方法
2018/10/23 Python
浅谈python 导入模块和解决文件句柄找不到问题
2018/12/15 Python
Python小白必备的8个最常用的内置函数(推荐)
2019/04/03 Python
阿里云ECS服务器部署django的方法
2019/08/29 Python
使用phonegap获取位置信息的实现方法
2017/03/31 HTML / CSS
倩碧香港官方网站:Clinique香港
2017/11/13 全球购物
美国波西米亚风格服装品牌:Show Me Your Mumu
2018/01/05 全球购物
意大利拉斐尔时尚购物网:Raffaello Network(支持中文)
2018/11/09 全球购物
欧洲最大的预定车位市场:JustPark
2020/01/06 全球购物
Python里面如何实现tuple和list的转换
2012/06/13 面试题
商务英语专业毕业生自荐信
2013/11/05 职场文书
大学生开西餐厅创业计划书
2014/02/01 职场文书
2014年采购工作总结
2014/11/20 职场文书
2015年治庸问责工作总结
2015/07/27 职场文书
PHP使用QR Code生成二维码实例
2021/07/07 PHP