PyTorch的SoftMax交叉熵损失和梯度用法


Posted in Python onJanuary 15, 2020

在PyTorch中可以方便的验证SoftMax交叉熵损失和对输入梯度的计算

关于softmax_cross_entropy求导的过程,可以参考HERE

示例

# -*- coding: utf-8 -*-
import torch
import torch.autograd as autograd
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
import numpy as np

# 对data求梯度, 用于反向传播
data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]), requires_grad=True)

# 多分类标签 one-hot格式
label = Variable(torch.zeros((3, 3)))
label[0, 2] = 1
label[1, 1] = 1
label[2, 0] = 1
print(label)

# for batch loss = mean( -sum(Pj*logSj) )
# for one : loss = -sum(Pj*logSj)
loss = torch.mean(-torch.sum(label * torch.log(F.softmax(data, dim=1)), dim=1))

loss.backward()
print(loss, data.grad)

输出:

tensor([[ 0., 0., 1.],
    [ 0., 1., 0.],
    [ 1., 0., 0.]])
# loss:损失 和 input's grad:输入的梯度
tensor(1.4076) tensor([[ 0.0300, 0.0816, -0.1116],
    [ 0.0300, -0.2518, 0.2217],
    [-0.3033, 0.0816, 0.2217]])

注意

对于单输入的loss 和 grad

data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0]]), requires_grad=True)


label = Variable(torch.zeros((1, 3)))
#分别令不同索引位置label为1
label[0, 0] = 1
# label[0, 1] = 1
# label[0, 2] = 1
print(label)

# for batch loss = mean( -sum(Pj*logSj) )
# for one : loss = -sum(Pj*logSj)
loss = torch.mean(-torch.sum(label * torch.log(F.softmax(data, dim=1)), dim=1))

loss.backward()
print(loss, data.grad)

其输出:

# 第一组:
lable: tensor([[ 1., 0., 0.]])
loss: tensor(2.4076) 
grad: tensor([[-0.9100, 0.2447, 0.6652]])

# 第二组:
lable: tensor([[ 0., 1., 0.]])
loss: tensor(1.4076) 
grad: tensor([[ 0.0900, -0.7553, 0.6652]])

# 第三组:
lable: tensor([[ 0., 0., 1.]])
loss: tensor(0.4076) 
grad: tensor([[ 0.0900, 0.2447, -0.3348]])

"""
解释:
对于输入数据 tensor([[ 1., 2., 3.]]) softmax之后的结果如下
tensor([[ 0.0900, 0.2447, 0.6652]])
交叉熵求解梯度推导公式可知 s[0, 0]-1, s[0, 1]-1, s[0, 2]-1 是上面三组label对应的输入数据梯度
"""

pytorch提供的softmax, 和log_softmax 关系

# 官方提供的softmax实现
In[2]: import torch
 ...: import torch.autograd as autograd
 ...: from torch.autograd import Variable
 ...: import torch.nn.functional as F
 ...: import torch.nn as nn
 ...: import numpy as np
In[3]: data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0]]), requires_grad=True)
In[4]: data
Out[4]: tensor([[ 1., 2., 3.]])
In[5]: e = torch.exp(data)
In[6]: e
Out[6]: tensor([[ 2.7183,  7.3891, 20.0855]])
In[7]: s = torch.sum(e, dim=1)
In[8]: s
Out[8]: tensor([ 30.1929])
In[9]: softmax = e/s
In[10]: softmax
Out[10]: tensor([[ 0.0900, 0.2447, 0.6652]])
In[11]: # 等同于 pytorch 提供的 softmax 
In[12]: org_softmax = F.softmax(data, dim=1)
In[13]: org_softmax
Out[13]: tensor([[ 0.0900, 0.2447, 0.6652]])
In[14]: org_softmax == softmax # 计算结果相同
Out[14]: tensor([[ 1, 1, 1]], dtype=torch.uint8)

# 与log_softmax关系
# log_softmax = log(softmax)
In[15]: _log_softmax = torch.log(org_softmax) 
In[16]: _log_softmax
Out[16]: tensor([[-2.4076, -1.4076, -0.4076]])
In[17]: log_softmax = F.log_softmax(data, dim=1)
In[18]: log_softmax
Out[18]: tensor([[-2.4076, -1.4076, -0.4076]])

官方提供的softmax交叉熵求解结果

# -*- coding: utf-8 -*-
import torch
import torch.autograd as autograd
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
import numpy as np

data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]), requires_grad=True)
log_softmax = F.log_softmax(data, dim=1)

label = Variable(torch.zeros((3, 3)))
label[0, 2] = 1
label[1, 1] = 1
label[2, 0] = 1
print("lable: ", label)

# 交叉熵的计算方式之一
loss_fn = torch.nn.NLLLoss() # reduce=True loss.sum/batch & grad/batch
# NLLLoss输入是log_softmax, target是非one-hot格式的label
loss = loss_fn(log_softmax, torch.argmax(label, dim=1))
loss.backward()
print("loss: ", loss, "\ngrad: ", data.grad)

"""
# 交叉熵计算方式二
loss_fn = torch.nn.CrossEntropyLoss() # the target label is NOT an one-hotted
#CrossEntropyLoss适用于分类问题的损失函数
#input:没有softmax过的nn.output, target是非one-hot格式label
loss = loss_fn(data, torch.argmax(label, dim=1))
loss.backward()
print("loss: ", loss, "\ngrad: ", data.grad)
"""
"""

输出

lable: tensor([[ 0., 0., 1.],
    [ 0., 1., 0.],
    [ 1., 0., 0.]])
loss: tensor(1.4076) 
grad: tensor([[ 0.0300, 0.0816, -0.1116],
    [ 0.0300, -0.2518, 0.2217],
    [-0.3033, 0.0816, 0.2217]])

通过和示例的输出对比, 发现两者是一样的

以上这篇PyTorch的SoftMax交叉熵损失和梯度用法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
深入理解Python中字典的键的使用
Aug 19 Python
关于python pyqt5安装失败问题的解决方法
Aug 08 Python
python topN 取最大的N个数或最小的N个数方法
Jun 04 Python
django Serializer序列化使用方法详解
Oct 16 Python
python selenium firefox使用详解
Feb 26 Python
Django model select的多种用法详解
Jul 16 Python
python实现KNN分类算法
Oct 16 Python
Django 自定义404 500等错误页面的实现
Mar 08 Python
python实现吃苹果小游戏
Mar 21 Python
使用tensorflow进行音乐类型的分类
Aug 14 Python
python+selenium 简易地疫情信息自动打卡签到功能的实现代码
Aug 22 Python
python cookie反爬处理的实现
Nov 01 Python
pytorch方法测试——激活函数(ReLU)详解
Jan 15 #Python
pytorch的batch normalize使用详解
Jan 15 #Python
pytorch方法测试详解——归一化(BatchNorm2d)
Jan 15 #Python
Python 中@property的用法详解
Jan 15 #Python
Python字符串中删除特定字符的方法
Jan 15 #Python
计算pytorch标准化(Normalize)所需要数据集的均值和方差实例
Jan 15 #Python
pytorch 图像中的数据预处理和批标准化实例
Jan 15 #Python
You might like
Session保存到数据库的php类分享
2011/10/24 PHP
php数据结构与算法(PHP描述) 查找与二分法查找
2012/06/21 PHP
PHP对MongoDB[NoSQL]数据库的操作
2013/03/01 PHP
Zend Framework入门教程之Zend_Mail用法示例
2016/12/08 PHP
php实现微信发红包功能
2018/07/13 PHP
JXTree对象,读取外部xml文件数据,生成树的函数
2007/04/02 Javascript
JavaScript学习笔记(二) js对象
2011/10/25 Javascript
JavaScript数组对象赋值用法实例
2015/08/04 Javascript
学习JavaScript设计模式之状态模式
2016/01/08 Javascript
jQuery Validate插件实现表单验证
2016/08/19 Javascript
Angular1.x复杂指令实例详解
2017/03/01 Javascript
详解webpack 如何集成第三方js库
2017/06/29 Javascript
angular2路由切换改变页面title的示例代码
2017/08/23 Javascript
js仿微信抢红包功能
2020/09/25 Javascript
JavaScript设计模式之调停者模式实例详解
2018/02/03 Javascript
WebSocket的简单介绍及应用
2019/05/23 Javascript
中高级前端必须了解的JS中的内存管理(推荐)
2019/07/04 Javascript
无法使用pip命令安装python第三方库的原因及解决方法
2018/06/12 Python
Python3对称加密算法AES、DES3实例详解
2018/12/06 Python
pandas 对日期类型数据的处理方法详解
2019/08/08 Python
python中的函数递归和迭代原理解析
2019/11/14 Python
python GUI模拟实现计算器
2020/06/22 Python
详解pycharm自动import所需的库的操作方法
2020/11/30 Python
HTML5 UTF-8 中文乱码的解决方法
2013/11/18 HTML / CSS
澳大利亚领先的睡衣品牌:Peter Alexander
2016/08/16 全球购物
教育技术学专业职业规划书
2014/03/03 职场文书
检察院对照“四风”认真查找问题落实整改措施
2014/09/26 职场文书
报到证办理个人委托书
2014/10/06 职场文书
幼儿园六一主持词开场白
2015/05/28 职场文书
荒岛余生观后感
2015/06/09 职场文书
小学运动会入场词
2015/07/18 职场文书
防震减灾主题班会
2015/08/14 职场文书
python开发实时可视化仪表盘的示例
2021/05/07 Python
Python爬虫基础讲解之请求
2021/05/13 Python
Vue3.0中Ref与Reactive的区别示例详析
2021/07/07 Vue.js
win11自动弹出虚拟键盘怎么关闭? Win11关闭虚拟键盘的技巧
2023/01/09 数码科技