PyTorch的SoftMax交叉熵损失和梯度用法


Posted in Python onJanuary 15, 2020

在PyTorch中可以方便的验证SoftMax交叉熵损失和对输入梯度的计算

关于softmax_cross_entropy求导的过程,可以参考HERE

示例

# -*- coding: utf-8 -*-
import torch
import torch.autograd as autograd
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
import numpy as np

# 对data求梯度, 用于反向传播
data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]), requires_grad=True)

# 多分类标签 one-hot格式
label = Variable(torch.zeros((3, 3)))
label[0, 2] = 1
label[1, 1] = 1
label[2, 0] = 1
print(label)

# for batch loss = mean( -sum(Pj*logSj) )
# for one : loss = -sum(Pj*logSj)
loss = torch.mean(-torch.sum(label * torch.log(F.softmax(data, dim=1)), dim=1))

loss.backward()
print(loss, data.grad)

输出:

tensor([[ 0., 0., 1.],
    [ 0., 1., 0.],
    [ 1., 0., 0.]])
# loss:损失 和 input's grad:输入的梯度
tensor(1.4076) tensor([[ 0.0300, 0.0816, -0.1116],
    [ 0.0300, -0.2518, 0.2217],
    [-0.3033, 0.0816, 0.2217]])

注意

对于单输入的loss 和 grad

data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0]]), requires_grad=True)


label = Variable(torch.zeros((1, 3)))
#分别令不同索引位置label为1
label[0, 0] = 1
# label[0, 1] = 1
# label[0, 2] = 1
print(label)

# for batch loss = mean( -sum(Pj*logSj) )
# for one : loss = -sum(Pj*logSj)
loss = torch.mean(-torch.sum(label * torch.log(F.softmax(data, dim=1)), dim=1))

loss.backward()
print(loss, data.grad)

其输出:

# 第一组:
lable: tensor([[ 1., 0., 0.]])
loss: tensor(2.4076) 
grad: tensor([[-0.9100, 0.2447, 0.6652]])

# 第二组:
lable: tensor([[ 0., 1., 0.]])
loss: tensor(1.4076) 
grad: tensor([[ 0.0900, -0.7553, 0.6652]])

# 第三组:
lable: tensor([[ 0., 0., 1.]])
loss: tensor(0.4076) 
grad: tensor([[ 0.0900, 0.2447, -0.3348]])

"""
解释:
对于输入数据 tensor([[ 1., 2., 3.]]) softmax之后的结果如下
tensor([[ 0.0900, 0.2447, 0.6652]])
交叉熵求解梯度推导公式可知 s[0, 0]-1, s[0, 1]-1, s[0, 2]-1 是上面三组label对应的输入数据梯度
"""

pytorch提供的softmax, 和log_softmax 关系

# 官方提供的softmax实现
In[2]: import torch
 ...: import torch.autograd as autograd
 ...: from torch.autograd import Variable
 ...: import torch.nn.functional as F
 ...: import torch.nn as nn
 ...: import numpy as np
In[3]: data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0]]), requires_grad=True)
In[4]: data
Out[4]: tensor([[ 1., 2., 3.]])
In[5]: e = torch.exp(data)
In[6]: e
Out[6]: tensor([[ 2.7183,  7.3891, 20.0855]])
In[7]: s = torch.sum(e, dim=1)
In[8]: s
Out[8]: tensor([ 30.1929])
In[9]: softmax = e/s
In[10]: softmax
Out[10]: tensor([[ 0.0900, 0.2447, 0.6652]])
In[11]: # 等同于 pytorch 提供的 softmax 
In[12]: org_softmax = F.softmax(data, dim=1)
In[13]: org_softmax
Out[13]: tensor([[ 0.0900, 0.2447, 0.6652]])
In[14]: org_softmax == softmax # 计算结果相同
Out[14]: tensor([[ 1, 1, 1]], dtype=torch.uint8)

# 与log_softmax关系
# log_softmax = log(softmax)
In[15]: _log_softmax = torch.log(org_softmax) 
In[16]: _log_softmax
Out[16]: tensor([[-2.4076, -1.4076, -0.4076]])
In[17]: log_softmax = F.log_softmax(data, dim=1)
In[18]: log_softmax
Out[18]: tensor([[-2.4076, -1.4076, -0.4076]])

官方提供的softmax交叉熵求解结果

# -*- coding: utf-8 -*-
import torch
import torch.autograd as autograd
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
import numpy as np

data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]), requires_grad=True)
log_softmax = F.log_softmax(data, dim=1)

label = Variable(torch.zeros((3, 3)))
label[0, 2] = 1
label[1, 1] = 1
label[2, 0] = 1
print("lable: ", label)

# 交叉熵的计算方式之一
loss_fn = torch.nn.NLLLoss() # reduce=True loss.sum/batch & grad/batch
# NLLLoss输入是log_softmax, target是非one-hot格式的label
loss = loss_fn(log_softmax, torch.argmax(label, dim=1))
loss.backward()
print("loss: ", loss, "\ngrad: ", data.grad)

"""
# 交叉熵计算方式二
loss_fn = torch.nn.CrossEntropyLoss() # the target label is NOT an one-hotted
#CrossEntropyLoss适用于分类问题的损失函数
#input:没有softmax过的nn.output, target是非one-hot格式label
loss = loss_fn(data, torch.argmax(label, dim=1))
loss.backward()
print("loss: ", loss, "\ngrad: ", data.grad)
"""
"""

输出

lable: tensor([[ 0., 0., 1.],
    [ 0., 1., 0.],
    [ 1., 0., 0.]])
loss: tensor(1.4076) 
grad: tensor([[ 0.0300, 0.0816, -0.1116],
    [ 0.0300, -0.2518, 0.2217],
    [-0.3033, 0.0816, 0.2217]])

通过和示例的输出对比, 发现两者是一样的

以上这篇PyTorch的SoftMax交叉熵损失和梯度用法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 控制语句
Nov 03 Python
Python提示[Errno 32]Broken pipe导致线程crash错误解决方法
Nov 19 Python
用Python编写一个简单的Lisp解释器的教程
Apr 03 Python
Python使用修饰器执行函数的参数检查功能示例
Sep 26 Python
flask中过滤器的使用详解
Aug 01 Python
在IPython中执行Python程序文件的示例
Nov 01 Python
python使用pandas处理大数据节省内存技巧(推荐)
May 05 Python
Python 多线程其他属性以及继承Thread类详解
Aug 28 Python
Python-for循环的内部机制
Jun 12 Python
python多线程semaphore实现线程数控制的示例
Aug 10 Python
Pytorch生成随机数Tensor的方法汇总
Sep 09 Python
用python进行视频剪辑
Nov 02 Python
pytorch方法测试——激活函数(ReLU)详解
Jan 15 #Python
pytorch的batch normalize使用详解
Jan 15 #Python
pytorch方法测试详解——归一化(BatchNorm2d)
Jan 15 #Python
Python 中@property的用法详解
Jan 15 #Python
Python字符串中删除特定字符的方法
Jan 15 #Python
计算pytorch标准化(Normalize)所需要数据集的均值和方差实例
Jan 15 #Python
pytorch 图像中的数据预处理和批标准化实例
Jan 15 #Python
You might like
PHP程序员最常犯的11个MySQL错误小结
2010/11/20 PHP
php编写的抽奖程序中奖概率算法
2015/05/14 PHP
Zend Framework动作助手Json用法实例分析
2016/03/05 PHP
使用composer安装使用thinkphp6.0框架问题【视频教程】
2019/10/01 PHP
HTML node相关的一些资料整理
2010/01/01 Javascript
在firefox和Chrome下关闭浏览器窗口无效的解决方法
2014/01/16 Javascript
Jquery实现的一种常用高亮效果示例代码
2014/01/28 Javascript
js判断横竖屏及禁止浏览器滑动条示例
2014/04/29 Javascript
JavaScript更改字符串的大小写
2015/05/07 Javascript
JavaScript中字符串(string)转json的2种方法
2015/06/25 Javascript
jquery实现用户信息修改验证输入方法汇总
2015/07/18 Javascript
JS实现灵巧的下拉导航效果代码
2015/08/25 Javascript
jquery心形点赞关注效果的简单实现
2016/11/14 Javascript
基于jQuery实现一个marquee无缝滚动的插件
2017/03/09 Javascript
nodejs动态创建二维码的方法
2017/08/12 NodeJs
解决JQuery全选/反选第二次失效的问题
2017/10/11 jQuery
前端axios下载excel文件(二进制)的处理方法
2018/07/31 Javascript
使用FormData实现上传多个文件
2018/12/04 Javascript
JavaScript实现页面中录音功能的方法
2019/06/04 Javascript
[02:43]DOTA2英雄基础教程 半人马战行者
2014/01/13 DOTA
[08:29]DOTA2每周TOP10 精彩击杀集锦vol.7
2014/06/25 DOTA
[01:29]2017 DOTA2国际邀请赛官方英雄手办展示
2017/03/18 DOTA
python 自动化将markdown文件转成html文件的方法
2016/09/23 Python
Python 通过URL打开图片实例详解
2017/06/01 Python
Python编程实现的简单Web服务器示例
2017/06/22 Python
Python使用getpass库读取密码的示例
2017/10/10 Python
解决python3中解压zip文件是文件名乱码的问题
2018/03/22 Python
python GUI库图形界面开发之PyQt5多线程中信号与槽的详细使用方法与实例
2020/03/08 Python
Python3.9 beta2版本发布了,看看这7个新的PEP都是什么
2020/06/10 Python
python七种方法判断字符串是否包含子串
2020/08/18 Python
Sentry错误日志监控使用方法解析
2020/11/12 Python
CSS3制作Dropdown下拉菜单的方法
2015/07/18 HTML / CSS
智能电子秤、手表和健康监测仪:Withings(之前为诺基亚健康)
2018/10/30 全球购物
大学生职业生涯规划范文——找准自我,定位人生
2014/01/23 职场文书
西门豹教学反思
2014/02/04 职场文书
用python画城市轮播地图
2021/05/28 Python