pytorch MSELoss计算平均的实现方法


Posted in Python onMay 12, 2021

给定损失函数的输入y,pred,shape均为bxc。

若设定loss_fn = torch.nn.MSELoss(reduction='mean'),最终的输出值其实是(y - pred)每个元素数字的平方之和除以(bxc),也就是在batch和特征维度上都取了平均。

如果只想在batch上做平均,可以这样写:

loss_fn = torch.nn.MSELoss(reduction='sum')
loss = loss_fn(pred, y) / pred.size(0)

补充:PyTorch中MSELoss的使用

参数

torch.nn.MSELoss(size_average=None, reduce=None, reduction: str = 'mean')

size_average和reduce在当前版本的pytorch已经不建议使用了,只设置reduction就行了。

reduction的可选参数有:'none' 、'mean' 、'sum'

reduction='none':求所有对应位置的差的平方,返回的仍然是一个和原来形状一样的矩阵。

reduction='mean':求所有对应位置差的平方的均值,返回的是一个标量。

reduction='sum':求所有对应位置差的平方的和,返回的是一个标量。

更多可查看官方文档​

举例

首先假设有三个数据样本分别经过神经网络运算,得到三个输出与其标签分别是:

y_pre = torch.Tensor([[1, 2, 3],
                      [2, 1, 3],
                      [3, 1, 2]])

y_label = torch.Tensor([[1, 0, 0],
                        [0, 1, 0],
                        [0, 0, 1]])

如果reduction='none':

criterion1 = nn.MSELoss(reduction='none')
loss1 = criterion1(x, y)
print(loss1)

则输出:

tensor([[0., 4., 9.],

[4., 0., 9.],

[9., 1., 1.]])

如果reduction='mean':

criterion2 = nn.MSELoss(reduction='mean')
loss2 = criterion2(x, y)
print(loss2)

则输出:

tensor(4.1111)

如果reduction='sum':

criterion3 = nn.MSELoss(reduction='sum')
loss3 = criterion3(x, y)
print(loss3)

则输出:

tensor(37.)

在反向传播时的使用

一般在反向传播时,都是先求loss,再使用loss.backward()求loss对每个参数 w_ij和b的偏导数(也可以理解为梯度)。

这里要注意的是,只有标量才能执行backward()函数,因此在反向传播中reduction不能设为'none'。

但具体设置为'sum'还是'mean'都是可以的。

若设置为'sum',则有Loss=loss_1+loss_2+loss_3,表示总的Loss由每个实例的loss_i构成,在通过Loss求梯度时,将每个loss_i的梯度也都考虑进去了。

若设置为'mean',则相比'sum'相当于Loss变成了Loss*(1/i),这在参数更新时影响不大,因为有学习率a的存在。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
python字典序问题实例
Sep 26 Python
使用Python写一个小游戏
Apr 02 Python
浅谈python的输入输出,注释,基本数据类型
Apr 02 Python
python多线程共享变量的使用和效率方法
Jul 16 Python
8段用于数据清洗Python代码(小结)
Oct 31 Python
python批量处理txt文件的实例代码
Jan 13 Python
Django bulk_create()、update()与数据库事务的效率对比分析
May 15 Python
Keras构建神经网络踩坑(解决model.predict预测值全为0.0的问题)
Jul 07 Python
Django URL参数Template反向解析
Nov 24 Python
python 实现图片批量压缩的示例
Dec 18 Python
python中re模块知识点总结
Jan 17 Python
python实战之一步一步教你绘制小猪佩奇
Apr 22 Python
Django如何创作一个简单的最小程序
May 12 #Python
Pytorch中TensorBoard及torchsummary的使用详解
pytorch 一行代码查看网络参数总量的实现
May 12 #Python
pytorch查看网络参数显存占用量等操作
May 12 #Python
Python入门之使用pandas分析excel数据
May 12 #Python
将Python代码打包成.exe可执行文件的完整步骤
python3实现Dijkstra算法最短路径的实现
You might like
怎样在PHP中通过ADO调用Asscess数据库和COM程序
2006/10/09 PHP
屏蔽机器人从你的网站搜取email地址的php代码
2012/11/14 PHP
如何使用php判断所处服务器操作系统的类型
2013/06/20 PHP
php绘制一个扇形的方法
2015/01/24 PHP
实例讲解PHP设计模式编程中的简单工厂模式
2016/02/29 PHP
php语言注释,单行注释和多行注释
2018/01/21 PHP
javascript 数据类型转换(parseInt,parseFloat)
2010/07/20 Javascript
javascript开发中因空格引发的错误
2010/11/08 Javascript
ASP.NET jQuery 实例6 (实现CheckBoxList成员全选或全取消)
2012/01/13 Javascript
利用Keydown事件阻止用户输入实现代码
2014/03/11 Javascript
JavaScript判断变量是对象还是数组的方法
2014/08/28 Javascript
学习Bootstrap组件之下拉菜单
2015/07/28 Javascript
JavaScript驾驭网页-获取网页元素
2016/03/24 Javascript
Jquery ajax请求导出Excel表格的实现代码
2016/06/08 Javascript
ionic+AngularJs实现获取验证码倒计时按钮
2017/04/22 Javascript
JS使用setInterval实现的简单计时器功能示例
2018/04/19 Javascript
Vue-Router的使用方法
2018/09/05 Javascript
Vue使用watch监听一个对象中的属性的实现方法
2019/05/10 Javascript
文章或博客自动生成章节目录索引(支持三级)的实现代码
2020/05/10 Javascript
vue实现简单全选和反选功能
2020/09/15 Javascript
跟老齐学Python之关于类的初步认识
2014/10/11 Python
python正则分析nginx的访问日志
2017/01/17 Python
python:pandas合并csv文件的方法(图书数据集成)
2018/04/12 Python
python删除字符串中指定字符的方法
2018/08/13 Python
python批量获取html内body内容的实例
2019/01/02 Python
python 基于dlib库的人脸检测的实现
2019/11/08 Python
pygame实现飞机大战
2020/03/11 Python
CSS3 2D模拟实现摩天轮旋转效果
2016/11/16 HTML / CSS
全球领先的全景影像品牌:Insta360
2019/08/21 全球购物
西雅图电动自行车公司:Rad Power Bikes
2020/02/02 全球购物
Perfume’s Club中文官网:西班牙美妆在线零售品牌
2020/08/24 全球购物
党员干部三严三实心得体会
2014/10/13 职场文书
谢师宴家长答谢词
2015/09/30 职场文书
入党心得体会
2019/06/20 职场文书
虚拟机linux端mysql数据库无法远程访问的解决办法
2021/05/26 MySQL
logback 实现给变量指定默认值
2021/08/30 Java/Android