pytorch MSELoss计算平均的实现方法


Posted in Python onMay 12, 2021

给定损失函数的输入y,pred,shape均为bxc。

若设定loss_fn = torch.nn.MSELoss(reduction='mean'),最终的输出值其实是(y - pred)每个元素数字的平方之和除以(bxc),也就是在batch和特征维度上都取了平均。

如果只想在batch上做平均,可以这样写:

loss_fn = torch.nn.MSELoss(reduction='sum')
loss = loss_fn(pred, y) / pred.size(0)

补充:PyTorch中MSELoss的使用

参数

torch.nn.MSELoss(size_average=None, reduce=None, reduction: str = 'mean')

size_average和reduce在当前版本的pytorch已经不建议使用了,只设置reduction就行了。

reduction的可选参数有:'none' 、'mean' 、'sum'

reduction='none':求所有对应位置的差的平方,返回的仍然是一个和原来形状一样的矩阵。

reduction='mean':求所有对应位置差的平方的均值,返回的是一个标量。

reduction='sum':求所有对应位置差的平方的和,返回的是一个标量。

更多可查看官方文档​

举例

首先假设有三个数据样本分别经过神经网络运算,得到三个输出与其标签分别是:

y_pre = torch.Tensor([[1, 2, 3],
                      [2, 1, 3],
                      [3, 1, 2]])

y_label = torch.Tensor([[1, 0, 0],
                        [0, 1, 0],
                        [0, 0, 1]])

如果reduction='none':

criterion1 = nn.MSELoss(reduction='none')
loss1 = criterion1(x, y)
print(loss1)

则输出:

tensor([[0., 4., 9.],

[4., 0., 9.],

[9., 1., 1.]])

如果reduction='mean':

criterion2 = nn.MSELoss(reduction='mean')
loss2 = criterion2(x, y)
print(loss2)

则输出:

tensor(4.1111)

如果reduction='sum':

criterion3 = nn.MSELoss(reduction='sum')
loss3 = criterion3(x, y)
print(loss3)

则输出:

tensor(37.)

在反向传播时的使用

一般在反向传播时,都是先求loss,再使用loss.backward()求loss对每个参数 w_ij和b的偏导数(也可以理解为梯度)。

这里要注意的是,只有标量才能执行backward()函数,因此在反向传播中reduction不能设为'none'。

但具体设置为'sum'还是'mean'都是可以的。

若设置为'sum',则有Loss=loss_1+loss_2+loss_3,表示总的Loss由每个实例的loss_i构成,在通过Loss求梯度时,将每个loss_i的梯度也都考虑进去了。

若设置为'mean',则相比'sum'相当于Loss变成了Loss*(1/i),这在参数更新时影响不大,因为有学习率a的存在。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
利用Python中的mock库对Python代码进行模拟测试
Apr 16 Python
python fabric使用笔记
May 09 Python
Python爬虫实例_利用百度地图API批量获取城市所有的POI点
Jan 10 Python
Python数据分析之获取双色球历史信息的方法示例
Feb 03 Python
python中利用zfill方法自动给数字前面补0
Apr 10 Python
Python基于opencv调用摄像头获取个人图片的实现方法
Feb 21 Python
python3 tkinter实现点击一个按钮跳出另一个窗口的方法
Jun 13 Python
python 环境搭建 及python-3.4.4的下载和安装过程
Jul 20 Python
在django模板中实现超链接配置
Aug 21 Python
python继承threading.Thread实现有返回值的子类实例
May 02 Python
jupyter notebook保存文件默认路径更改方法汇总(亲测可以)
Jun 09 Python
用 Python 定义 Schema 并生成 Parquet 文件详情
Sep 25 Python
Django如何创作一个简单的最小程序
May 12 #Python
Pytorch中TensorBoard及torchsummary的使用详解
pytorch 一行代码查看网络参数总量的实现
May 12 #Python
pytorch查看网络参数显存占用量等操作
May 12 #Python
Python入门之使用pandas分析excel数据
May 12 #Python
将Python代码打包成.exe可执行文件的完整步骤
python3实现Dijkstra算法最短路径的实现
You might like
PHP远程连接MYSQL数据库非常慢的解决方法
2008/07/05 PHP
JS弹出对话框返回值代码(asp.net后台)
2010/12/28 Javascript
用js的for循环获取radio选中的值
2013/10/21 Javascript
javascript Deferred和递归次数限制实例
2014/10/21 Javascript
浅谈Unicode与JavaScript的发展史
2015/01/19 Javascript
jQuery蓝色风格滑动导航栏代码分享
2015/08/19 Javascript
Javascript的无new构建实例详解
2016/05/15 Javascript
jQuery基础知识点总结(必看)
2016/05/31 Javascript
js方法数据验证的简单实例
2016/09/17 Javascript
EasyUi 打开对话框后控件赋值及赋值后不显示的问题解决办法
2017/01/19 Javascript
canvas实现刮刮卡效果
2017/03/14 Javascript
详解Angular路由 ng-route和ui-router的区别
2017/05/22 Javascript
vue上传图片组件编写代码
2017/07/26 Javascript
JS switch判断 三目运算 while 及 属性操作代码
2017/09/03 Javascript
jquery ajaxfileupload异步上传插件
2017/11/21 jQuery
Vux+Axios拦截器增加loading的问题及实现方法
2018/11/08 Javascript
ElementUI radio组件选中小改造
2019/08/12 Javascript
Vue proxyTable配置多个接口地址,解决跨域的问题
2020/09/11 Javascript
vue实现下拉菜单树
2020/10/22 Javascript
[06:53]DOTA2每周TOP10 精彩击杀集锦vol.3
2014/06/25 DOTA
[00:36]TI7不朽珍藏III——斯温不朽展示
2017/07/15 DOTA
[00:56]跨越时空加入战场 全新祈求者身心“失落奇艺侍祭”展示
2019/07/20 DOTA
Python实现冒泡,插入,选择排序简单实例
2014/08/18 Python
Python装饰器使用示例及实际应用例子
2015/03/06 Python
Python素数检测的方法
2015/05/11 Python
利用python的socket发送http(s)请求方法示例
2018/05/07 Python
基于DATAFRAME中元素的读取与修改方法
2018/06/08 Python
python实现翻转棋游戏(othello)
2019/07/29 Python
Python turtle绘画象棋棋盘
2019/08/21 Python
python获取依赖包和安装依赖包教程
2020/02/13 Python
CentOS 7如何实现定时执行python脚本
2020/06/24 Python
欧洲第一中国智能手机和平板电脑网上商店:CECT-SHOP
2018/01/08 全球购物
全球性的在线婚纱礼服工厂:27dress.com
2019/03/21 全球购物
学生干部培训方案
2014/06/12 职场文书
使用nginx动态转换图片大小生成缩略图
2021/03/31 Servers
详细介绍Java中的CyclicBarrier
2022/04/13 Java/Android