pytorch 实现cross entropy损失函数计算方式


Posted in Python onJanuary 02, 2020

均方损失函数:

pytorch 实现cross entropy损失函数计算方式

这里 loss, x, y 的维度是一样的,可以是向量或者矩阵,i 是下标。

很多的 loss 函数都有 size_average 和 reduce 两个布尔类型的参数。因为一般损失函数都是直接计算 batch 的数据,因此返回的 loss 结果都是维度为 (batch_size, ) 的向量。

(1)如果 reduce = False,那么 size_average 参数失效,直接返回向量形式的 loss

(2)如果 reduce = True,那么 loss 返回的是标量

a)如果 size_average = True,返回 loss.mean();
b)如果 size_average = False,返回 loss.sum();

注意:默认情况下, reduce = True,size_average = True

import torch
import numpy as np

1、返回向量

loss_fn = torch.nn.MSELoss(reduce=False, size_average=False)
a=np.array([[1,2],[3,4]])
b=np.array([[2,3],[4,5]])
input = torch.autograd.Variable(torch.from_numpy(a))
target = torch.autograd.Variable(torch.from_numpy(b))

这里将Variable类型统一为float()(tensor类型也是调用xxx.float())

loss = loss_fn(input.float(), target.float())
print(loss)
tensor([[ 1., 1.],
  [ 1., 1.]])

2、返回平均值

a=np.array([[1,2],[3,4]])
b=np.array([[2,3],[4,4]])
loss_fn = torch.nn.MSELoss(reduce=True, size_average=True)
input = torch.autograd.Variable(torch.from_numpy(a))
target = torch.autograd.Variable(torch.from_numpy(b))
loss = loss_fn(input.float(), target.float())
print(loss)
tensor(0.7500)

以上这篇pytorch 实现cross entropy损失函数计算方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python基于DES算法加密解密实例
Jun 03 Python
Flask的图形化管理界面搭建框架Flask-Admin的使用教程
Jun 13 Python
pandas 实现字典转换成DataFrame的方法
Jul 04 Python
NumPy 数学函数及代数运算的实现代码
Jul 18 Python
centos6.8安装python3.7无法import _ssl的解决方法
Sep 17 Python
python/Matplotlib绘制复变函数图像教程
Nov 21 Python
python with (as)语句实例详解
Feb 04 Python
Python 实现网课实时监控自动签到、打卡功能
Mar 12 Python
django admin管理工具自定义时间区间筛选器DateRangeFilter介绍
May 19 Python
python批量生成条形码的示例
Oct 10 Python
Python 操作SQLite数据库的示例
Oct 16 Python
python 自动刷新网页的两种方法
Apr 20 Python
Matplotlib scatter绘制散点图的方法实现
Jan 02 #Python
Python基础之函数基本用法与进阶详解
Jan 02 #Python
Python面向对象原理与基础语法详解
Jan 02 #Python
Pytorch 的损失函数Loss function使用详解
Jan 02 #Python
Python面向对象封装操作案例详解 II
Jan 02 #Python
Python实现搜索算法的实例代码
Jan 02 #Python
python 实现从高分辨图像上抠取图像块
Jan 02 #Python
You might like
显示youtube视频缩略图和Vimeo视频缩略图代码分享
2014/02/13 PHP
ThinkPHP实现支付宝接口功能实例
2014/12/02 PHP
php自动载入类用法实例分析
2016/06/24 PHP
修改Laravel5.3中的路由文件与路径
2016/08/10 PHP
跨浏览器开发经验总结(四) 怎么写入剪贴板
2010/05/13 Javascript
关于JavaScript的with 语句的使用方法
2011/05/09 Javascript
前后台交互过程中json格式如何解析以及如何生成
2012/12/26 Javascript
JS获取当前日期时间并定时刷新示例
2021/03/04 Javascript
JS实现鼠标经过好友列表中的好友头像时显示资料卡的效果
2014/07/02 Javascript
js实现的简洁网页滑动tab菜单效果代码
2015/08/24 Javascript
使用getBoundingClientRect方法实现简洁的sticky组件的方法
2016/03/22 Javascript
jquery Deferred 快速解决异步回调的问题
2016/04/05 Javascript
JQuery核心函数是什么及使用方法介绍
2016/05/03 Javascript
AngularJS 所有版本下载地址
2016/09/14 Javascript
把json格式的字符串转换成javascript对象或数组的方法总结
2016/11/03 Javascript
jQuery使用ajax方法解析返回的json数据功能示例
2017/01/10 Javascript
vue-hook-form使用详解
2017/04/07 Javascript
Vue.js常用指令的使用小结
2017/06/23 Javascript
分享Bootstrap简单表格、表单、登录页面
2017/08/04 Javascript
React-intl 实现多语言的示例代码
2017/11/03 Javascript
详解Vue项目中实现锚点定位
2019/04/24 Javascript
JS前端知识点offset,scroll,client,冒泡,事件对象的应用整理总结
2019/06/27 Javascript
[02:56]DOTA2亚洲邀请赛 VG出场战队巡礼
2015/02/07 DOTA
[00:12]DAC2018 Miracle-站上中单舞台,他能否再写奇迹?
2018/04/06 DOTA
Python查找相似单词的方法
2015/03/05 Python
python使用pandas实现数据分割实例代码
2018/01/25 Python
Python实现针对给定字符串寻找最长非重复子串的方法
2018/04/21 Python
python利用百度AI实现文字识别功能
2018/11/27 Python
Python循环实现n的全排列功能
2019/09/16 Python
python基于三阶贝塞尔曲线的数据平滑算法
2019/12/27 Python
浅析PyCharm 的初始设置(知道)
2020/10/12 Python
乌克兰第一的珠宝网上商店:Gold.ua
2019/11/29 全球购物
汇源肾宝广告词
2014/03/20 职场文书
在教室放鞭炮的检讨书
2014/09/28 职场文书
php 防护xss,PHP的防御XSS注入的终极解决方案
2021/04/01 PHP
Python中如何处理常见报错
2022/01/18 Python