pytorch MSELoss计算平均的实现方法


Posted in Python onMay 12, 2021

给定损失函数的输入y,pred,shape均为bxc。

若设定loss_fn = torch.nn.MSELoss(reduction='mean'),最终的输出值其实是(y - pred)每个元素数字的平方之和除以(bxc),也就是在batch和特征维度上都取了平均。

如果只想在batch上做平均,可以这样写:

loss_fn = torch.nn.MSELoss(reduction='sum')
loss = loss_fn(pred, y) / pred.size(0)

补充:PyTorch中MSELoss的使用

参数

torch.nn.MSELoss(size_average=None, reduce=None, reduction: str = 'mean')

size_average和reduce在当前版本的pytorch已经不建议使用了,只设置reduction就行了。

reduction的可选参数有:'none' 、'mean' 、'sum'

reduction='none':求所有对应位置的差的平方,返回的仍然是一个和原来形状一样的矩阵。

reduction='mean':求所有对应位置差的平方的均值,返回的是一个标量。

reduction='sum':求所有对应位置差的平方的和,返回的是一个标量。

更多可查看官方文档​

举例

首先假设有三个数据样本分别经过神经网络运算,得到三个输出与其标签分别是:

y_pre = torch.Tensor([[1, 2, 3],
                      [2, 1, 3],
                      [3, 1, 2]])

y_label = torch.Tensor([[1, 0, 0],
                        [0, 1, 0],
                        [0, 0, 1]])

如果reduction='none':

criterion1 = nn.MSELoss(reduction='none')
loss1 = criterion1(x, y)
print(loss1)

则输出:

tensor([[0., 4., 9.],

[4., 0., 9.],

[9., 1., 1.]])

如果reduction='mean':

criterion2 = nn.MSELoss(reduction='mean')
loss2 = criterion2(x, y)
print(loss2)

则输出:

tensor(4.1111)

如果reduction='sum':

criterion3 = nn.MSELoss(reduction='sum')
loss3 = criterion3(x, y)
print(loss3)

则输出:

tensor(37.)

在反向传播时的使用

一般在反向传播时,都是先求loss,再使用loss.backward()求loss对每个参数 w_ij和b的偏导数(也可以理解为梯度)。

这里要注意的是,只有标量才能执行backward()函数,因此在反向传播中reduction不能设为'none'。

但具体设置为'sum'还是'mean'都是可以的。

若设置为'sum',则有Loss=loss_1+loss_2+loss_3,表示总的Loss由每个实例的loss_i构成,在通过Loss求梯度时,将每个loss_i的梯度也都考虑进去了。

若设置为'mean',则相比'sum'相当于Loss变成了Loss*(1/i),这在参数更新时影响不大,因为有学习率a的存在。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
详解Python Socket网络编程
Jan 05 Python
定制FileField中的上传文件名称实例
Aug 23 Python
Python 查看文件的读写权限方法
Jan 23 Python
python更改已存在excel文件的方法
May 03 Python
Python利用ORM控制MongoDB(MongoEngine)的步骤全纪录
Sep 13 Python
NumPy 基本切片和索引的具体使用方法
Apr 24 Python
梅尔倒谱系数(MFCC)实现
Jun 19 Python
python函数的作用域及关键字详解
Aug 20 Python
python队列原理及实现方法示例
Nov 27 Python
Python 时间戳之获取整点凌晨时间戳的操作方法
Jan 28 Python
Pycharm同步远程服务器调试的方法步骤
Nov 04 Python
Python的三个重要函数详解
Jan 18 Python
Django如何创作一个简单的最小程序
May 12 #Python
Pytorch中TensorBoard及torchsummary的使用详解
pytorch 一行代码查看网络参数总量的实现
May 12 #Python
pytorch查看网络参数显存占用量等操作
May 12 #Python
Python入门之使用pandas分析excel数据
May 12 #Python
将Python代码打包成.exe可执行文件的完整步骤
python3实现Dijkstra算法最短路径的实现
You might like
php封装的连接Mysql类及用法分析
2015/12/10 PHP
推荐自用 Javascript 缩图函数 (onDOMLoaded)……
2007/10/23 Javascript
节点的插入之append()和appendTo()的用法介绍
2014/01/13 Javascript
自写的jQuery异步加载数据添加事件
2014/05/15 Javascript
jquery在ie7下选择器的问题导致append失效的解决方法
2016/01/10 Javascript
Bootstrap进度条学习使用
2017/02/09 Javascript
深入理解AngularJS中的ng-bind-html指令
2017/03/27 Javascript
AugularJS从入门到实践(必看篇)
2017/07/10 Javascript
深入探究node之Transform
2017/07/20 Javascript
JavaScript中重名的函数与对象示例详析
2017/09/28 Javascript
搭建element-ui的Vue前端工程操作实例
2018/02/23 Javascript
原生javascript实现连连看游戏
2019/01/03 Javascript
小程序页面动态配置实现方法
2019/02/05 Javascript
javascript的this关键字详解
2019/05/20 Javascript
vue使用map代替Aarry数组循环遍历的方法
2020/04/30 Javascript
详解element-ui 表单校验 Rules 配置 常用黑科技
2020/07/11 Javascript
[00:42]《辉夜杯》—职业组预选赛12月3日15点 正式打响
2015/12/03 DOTA
[01:36:17]DOTA2-DPC中国联赛 正赛 Ehome vs iG BO3 第一场 1月31日
2021/03/11 DOTA
Python去除、替换字符串空格的处理方法
2018/04/01 Python
python 获取当天每个准点时间戳的实例
2018/05/22 Python
利用python如何处理百万条数据(适用java新手)
2018/06/06 Python
Python中的引用知识点总结
2019/05/20 Python
python3中类的继承以及self和super的区别详解
2019/06/26 Python
Python叠加两幅栅格图像的实现方法
2019/07/05 Python
tensorflow之变量初始化(tf.Variable)使用详解
2020/02/06 Python
最新2019Pycharm安装教程 亲测
2020/02/28 Python
Pycharm自带Git实现版本管理的方法步骤
2020/09/18 Python
详解HTML5 canvas绘图基本使用方法
2018/01/29 HTML / CSS
英国标志性奢侈品牌:Burberry
2016/07/28 全球购物
澳大利亚网上书店:QBD
2021/01/09 全球购物
竞选班干部演讲稿600字
2014/08/20 职场文书
2014年城管个人工作总结
2014/12/08 职场文书
结婚通知短信大全
2015/04/17 职场文书
家访教师心得体会
2016/01/23 职场文书
使用python如何删除同一文件夹下相似的图片
2021/05/07 Python
十大好看的穿越动漫排名:《瑞克和莫蒂》第一,国漫《有药》在榜
2022/03/18 日漫