pytorch 实现cross entropy损失函数计算方式


Posted in Python onJanuary 02, 2020

均方损失函数:

pytorch 实现cross entropy损失函数计算方式

这里 loss, x, y 的维度是一样的,可以是向量或者矩阵,i 是下标。

很多的 loss 函数都有 size_average 和 reduce 两个布尔类型的参数。因为一般损失函数都是直接计算 batch 的数据,因此返回的 loss 结果都是维度为 (batch_size, ) 的向量。

(1)如果 reduce = False,那么 size_average 参数失效,直接返回向量形式的 loss

(2)如果 reduce = True,那么 loss 返回的是标量

a)如果 size_average = True,返回 loss.mean();
b)如果 size_average = False,返回 loss.sum();

注意:默认情况下, reduce = True,size_average = True

import torch
import numpy as np

1、返回向量

loss_fn = torch.nn.MSELoss(reduce=False, size_average=False)
a=np.array([[1,2],[3,4]])
b=np.array([[2,3],[4,5]])
input = torch.autograd.Variable(torch.from_numpy(a))
target = torch.autograd.Variable(torch.from_numpy(b))

这里将Variable类型统一为float()(tensor类型也是调用xxx.float())

loss = loss_fn(input.float(), target.float())
print(loss)
tensor([[ 1., 1.],
  [ 1., 1.]])

2、返回平均值

a=np.array([[1,2],[3,4]])
b=np.array([[2,3],[4,4]])
loss_fn = torch.nn.MSELoss(reduce=True, size_average=True)
input = torch.autograd.Variable(torch.from_numpy(a))
target = torch.autograd.Variable(torch.from_numpy(b))
loss = loss_fn(input.float(), target.float())
print(loss)
tensor(0.7500)

以上这篇pytorch 实现cross entropy损失函数计算方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python调用windows api锁定计算机示例
Apr 17 Python
Python实现导出数据生成excel报表的方法示例
Jul 12 Python
python自动裁剪图像代码分享
Nov 25 Python
Python3 Random模块代码详解
Dec 04 Python
Python 获得命令行参数的方法(推荐)
Jan 24 Python
PyCharm2019安装教程及其使用(图文教程)
Sep 29 Python
pytorch模型预测结果与ndarray互转方式
Jan 15 Python
tensorflow ckpt模型和pb模型获取节点名称,及ckpt转pb模型实例
Jan 21 Python
在python中利用pycharm自定义代码块教程(三步搞定)
Apr 15 Python
基于Python组装jmx并调用JMeter实现压力测试
Nov 03 Python
基于Python实现粒子滤波效果
Dec 01 Python
在 Python 中利用 Pool 进行多线程
Apr 24 Python
Matplotlib scatter绘制散点图的方法实现
Jan 02 #Python
Python基础之函数基本用法与进阶详解
Jan 02 #Python
Python面向对象原理与基础语法详解
Jan 02 #Python
Pytorch 的损失函数Loss function使用详解
Jan 02 #Python
Python面向对象封装操作案例详解 II
Jan 02 #Python
Python实现搜索算法的实例代码
Jan 02 #Python
python 实现从高分辨图像上抠取图像块
Jan 02 #Python
You might like
PHP 用数组降低程序的时间复杂度
2009/12/04 PHP
kohana框架上传文件验证规则写法示例
2014/07/14 PHP
php单例模式实现方法分析
2015/03/14 PHP
php cli配置文件问题分析
2015/10/15 PHP
分析PHP中单双引号的误区和双引号小隐患
2016/07/19 PHP
Paypal实现循环扣款(订阅)功能
2017/03/23 PHP
js关闭子窗体刷新父窗体实现方法
2012/12/04 Javascript
jQuery.prototype.init选择器构造函数源码思路分析
2013/02/05 Javascript
js取float型小数点后两位数的方法
2014/01/18 Javascript
js实现继承的5种方式
2015/12/01 Javascript
关于微信中a链接无法跳转问题
2016/08/02 Javascript
JavaScript之RegExp_动力节点Java学院整理
2017/06/29 Javascript
原生JS实现轮播图效果
2018/10/12 Javascript
JS简单数组排序操作示例【sort方法】
2019/05/17 Javascript
详解vue中使用axios对同一个接口连续请求导致返回数据混乱的问题
2019/11/06 Javascript
AutoJs实现刷宝短视频的思路详解
2020/05/22 Javascript
js实现九宫格布局效果
2020/05/28 Javascript
Element Alert警告的具体使用方法
2020/07/27 Javascript
JavaScript经典案例之简易计算器
2020/08/24 Javascript
nodejs使用Sequelize框架操作数据库的实现
2020/10/21 NodeJs
python append、extend与insert的区别
2016/10/13 Python
new_zeros() pytorch版本的转换方式
2020/02/18 Python
django admin后管定制-显示字段的实例
2020/03/11 Python
英国标志性奢侈品牌:Burberry
2016/07/28 全球购物
Carolina工作鞋官网:Carolina Footwear
2019/03/14 全球购物
台湾演唱会订票网站:StubHub台湾
2019/06/11 全球购物
请编程遍历页面上所有 TextBox 控件并给它赋值为 string.Empty
2015/12/03 面试题
为什么要做架构设计
2015/07/08 面试题
公司授权委托书
2014/04/04 职场文书
倡议书格式
2014/04/14 职场文书
人事专员岗位职责
2015/02/03 职场文书
培训师岗位职责
2015/02/14 职场文书
CSS中em的正确打开方式详解
2021/04/08 HTML / CSS
详解Java实现数据结构之并查集
2021/06/23 Java/Android
http通过StreamingHttpResponse完成连续的数据传输长链接方式
2022/02/12 Python
Meta增速拉垮,元宇宙难当重任
2022/04/29 数码科技