PyTorch中topk函数的用法详解


Posted in Python onJanuary 02, 2020

听名字就知道这个函数是用来求tensor中某个dim的前k大或者前k小的值以及对应的index。

用法

torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)

input:一个tensor数据

k:指明是得到前k个数据以及其index

dim: 指定在哪个维度上排序, 默认是最后一个维度

largest:如果为True,按照大到小排序; 如果为False,按照小到大排序

sorted:返回的结果按照顺序返回

out:可缺省,不要

topk最常用的场合就是求一个样本被网络认为前k个最可能属于的类别。我们就用这个场景为例,说明函数的使用方法。

假设一个PyTorch中topk函数的用法详解,N是样本数目,一般等于batch size, D是类别数目。我们想知道每个样本的最可能属于的那个类别,其实可以用torch.max得到。如果要使用topk,则k应该设置为1。

import torch

pred = torch.randn((4, 5))
print(pred)
values, indices = pred.topk(1, dim=1, largest=True, sorted=True)
print(indices)
# 用max得到的结果,设置keepdim为True,避免降维。因为topk函数返回的index不降维,shape和输入一致。
_, indices_max = pred.max(dim=1, keepdim=True)

print(indices_max == indices)
# pred
tensor([[-0.1480, -0.9819, -0.3364, 0.7912, -0.3263],
    [-0.8013, -0.9083, 0.7973, 0.1458, -0.9156],
    [-0.2334, -0.0142, -0.5493, 0.0673, 0.8185],
    [-0.4075, -0.1097, 0.8193, -0.2352, -0.9273]])
# indices, shape为 【4,1】,
tensor([[3],  #【0,0】代表 第一个样本最可能属于第一类别
    [2],  # 【1, 0】代表第二个样本最可能属于第二类别
    [4],
    [2]])
# indices_max等于indices
tensor([[True],
    [True],
    [True],
    [True]])

现在在尝试一下k=2

import torch

pred = torch.randn((4, 5))
print(pred)
values, indices = pred.topk(2, dim=1, largest=True, sorted=True) # k=2
print(indices)
# pred
tensor([[-0.2203, -0.7538, 1.8789, 0.4451, -0.2526],
    [-0.0413, 0.6366, 1.1155, 0.3484, 0.0395],
    [ 0.0365, 0.5158, 1.1067, -0.9276, -0.2124],
    [ 0.6232, 0.9912, -0.8562, 0.0148, 1.6413]])
# indices
tensor([[2, 3],
    [2, 1],
    [2, 1],
    [4, 1]])

可以发现indices的shape变成了【4, k】,k=2。

其中indices[0] = [2,3]。其意义是说明第一个样本的前两个最大概率对应的类别分别是第3类和第4类。

大家可以自行print一下values。可以发现values的shape和indices的shape是一样的。indices描述了在values中对应的值在pred中的位置。

以上这篇PyTorch中topk函数的用法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python计算书页码的统计数字问题实例
Sep 26 Python
微信跳一跳游戏python脚本
Apr 01 Python
Python使用matplotlib绘图无法显示中文问题的解决方法
Mar 14 Python
对numpy中数组元素的统一赋值实例
Apr 04 Python
numpy.std() 计算矩阵标准差的方法
Jul 11 Python
python 通过SSHTunnelForwarder隧道连接redis的方法
Feb 19 Python
Django 反向生成url实例详解
Jul 30 Python
wxPython电子表格功能wx.grid实例教程
Nov 19 Python
python获取依赖包和安装依赖包教程
Feb 13 Python
Python3如何判断三角形的类型
Apr 12 Python
Python使用Pyqt5实现简易浏览器(最新版本测试过)
Apr 27 Python
vscode配置anaconda3的方法步骤
Aug 08 Python
Pytorch训练过程出现nan的解决方式
Jan 02 #Python
pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率
Jan 02 #Python
基于MSELoss()与CrossEntropyLoss()的区别详解
Jan 02 #Python
python使用SQLAlchemy操作MySQL
Jan 02 #Python
pytorch 实现cross entropy损失函数计算方式
Jan 02 #Python
Matplotlib scatter绘制散点图的方法实现
Jan 02 #Python
Python基础之函数基本用法与进阶详解
Jan 02 #Python
You might like
Windows7下PHP开发环境安装配置图文方法
2010/05/20 PHP
PHP简单处理表单输入的特殊字符的方法
2016/02/03 PHP
[原创]php实现 data url的图片生成与保存
2016/12/04 PHP
php微信开发之关注事件
2018/06/14 PHP
PHP mongodb操作类定义与用法示例【适合mongodb2.x和mongodb3.x】
2018/06/16 PHP
php封装的page分页类完整实例代码
2020/02/01 PHP
发一个自己用JS写的实用看图工具实现代码
2008/07/26 Javascript
解决jQuery插件tipswindown与hintbox冲突
2010/11/05 Javascript
js 实现日期灵活格式化的小例子
2013/07/14 Javascript
javascript对JSON数据排序的3个例子
2014/04/12 Javascript
JS文件/图片从电脑里面拖拽到浏览器上传文件/图片
2017/03/08 Javascript
KOA+egg.js集成kafka消息队列的示例
2018/11/09 Javascript
[01:19:34]2014 DOTA2国际邀请赛中国区预选赛 New Element VS Dream time
2014/05/22 DOTA
python脚本实现分析dns日志并对受访域名排行
2014/09/18 Python
在Python下使用Txt2Html实现网页过滤代理的教程
2015/04/11 Python
python Celery定时任务的示例
2018/03/13 Python
Python实现确认字符串是否包含指定字符串的实例
2018/05/02 Python
python2 与 python3 实现共存的方法
2018/07/12 Python
Python 实现王者荣耀中的敏感词过滤示例
2019/01/21 Python
python判断字符串或者集合是否为空的实例
2019/01/23 Python
200行python代码实现2048游戏
2019/07/17 Python
用Python生成HTML表格的方法示例
2020/03/06 Python
anaconda安装pytorch1.7.1和torchvision0.8.2的方法(亲测可用)
2021/02/01 Python
详解CSS3伸缩布局盒模型Flex布局
2018/08/20 HTML / CSS
使用 css3 transform 属性来变换背景图的方法
2019/05/07 HTML / CSS
大学自主招生自荐信
2013/12/16 职场文书
日语专业个人求职信范文
2014/02/02 职场文书
优秀毕业生事迹材料
2014/02/12 职场文书
自主招生教师推荐信
2014/05/10 职场文书
篮球比赛口号
2014/06/10 职场文书
2014财务年终工作总结
2014/12/08 职场文书
2014年平安夜寄语
2014/12/08 职场文书
大二学年个人总结
2015/03/03 职场文书
交通安全教育主题班会
2015/08/12 职场文书
《詹天佑》教学反思
2016/02/20 职场文书
公司开业的祝贺语大全(60条)
2019/07/05 职场文书