PyTorch中topk函数的用法详解


Posted in Python onJanuary 02, 2020

听名字就知道这个函数是用来求tensor中某个dim的前k大或者前k小的值以及对应的index。

用法

torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)

input:一个tensor数据

k:指明是得到前k个数据以及其index

dim: 指定在哪个维度上排序, 默认是最后一个维度

largest:如果为True,按照大到小排序; 如果为False,按照小到大排序

sorted:返回的结果按照顺序返回

out:可缺省,不要

topk最常用的场合就是求一个样本被网络认为前k个最可能属于的类别。我们就用这个场景为例,说明函数的使用方法。

假设一个PyTorch中topk函数的用法详解,N是样本数目,一般等于batch size, D是类别数目。我们想知道每个样本的最可能属于的那个类别,其实可以用torch.max得到。如果要使用topk,则k应该设置为1。

import torch

pred = torch.randn((4, 5))
print(pred)
values, indices = pred.topk(1, dim=1, largest=True, sorted=True)
print(indices)
# 用max得到的结果,设置keepdim为True,避免降维。因为topk函数返回的index不降维,shape和输入一致。
_, indices_max = pred.max(dim=1, keepdim=True)

print(indices_max == indices)
# pred
tensor([[-0.1480, -0.9819, -0.3364, 0.7912, -0.3263],
    [-0.8013, -0.9083, 0.7973, 0.1458, -0.9156],
    [-0.2334, -0.0142, -0.5493, 0.0673, 0.8185],
    [-0.4075, -0.1097, 0.8193, -0.2352, -0.9273]])
# indices, shape为 【4,1】,
tensor([[3],  #【0,0】代表 第一个样本最可能属于第一类别
    [2],  # 【1, 0】代表第二个样本最可能属于第二类别
    [4],
    [2]])
# indices_max等于indices
tensor([[True],
    [True],
    [True],
    [True]])

现在在尝试一下k=2

import torch

pred = torch.randn((4, 5))
print(pred)
values, indices = pred.topk(2, dim=1, largest=True, sorted=True) # k=2
print(indices)
# pred
tensor([[-0.2203, -0.7538, 1.8789, 0.4451, -0.2526],
    [-0.0413, 0.6366, 1.1155, 0.3484, 0.0395],
    [ 0.0365, 0.5158, 1.1067, -0.9276, -0.2124],
    [ 0.6232, 0.9912, -0.8562, 0.0148, 1.6413]])
# indices
tensor([[2, 3],
    [2, 1],
    [2, 1],
    [4, 1]])

可以发现indices的shape变成了【4, k】,k=2。

其中indices[0] = [2,3]。其意义是说明第一个样本的前两个最大概率对应的类别分别是第3类和第4类。

大家可以自行print一下values。可以发现values的shape和indices的shape是一样的。indices描述了在values中对应的值在pred中的位置。

以上这篇PyTorch中topk函数的用法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详细介绍Python函数中的默认参数
Mar 30 Python
Python正则抓取新闻标题和链接的方法示例
Apr 24 Python
浅谈python和C语言混编的几种方式(推荐)
Sep 27 Python
Python获取指定文件夹下的文件名的方法
Feb 06 Python
对PyTorch torch.stack的实例讲解
Jul 30 Python
使用Python实现从各个子文件夹中复制指定文件的方法
Oct 25 Python
Python企业编码生成系统之主程序模块设计详解
Jul 26 Python
程序员的七夕用30行代码让Python化身表白神器
Aug 07 Python
python selenium xpath定位操作
Sep 01 Python
Prometheus开发中间件Exporter过程详解
Nov 30 Python
python爬虫 requests-html的使用
Nov 30 Python
浅谈Selenium 控制浏览器的常用方法
Dec 04 Python
Pytorch训练过程出现nan的解决方式
Jan 02 #Python
pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率
Jan 02 #Python
基于MSELoss()与CrossEntropyLoss()的区别详解
Jan 02 #Python
python使用SQLAlchemy操作MySQL
Jan 02 #Python
pytorch 实现cross entropy损失函数计算方式
Jan 02 #Python
Matplotlib scatter绘制散点图的方法实现
Jan 02 #Python
Python基础之函数基本用法与进阶详解
Jan 02 #Python
You might like
关于PHP session 存储方式的详细介绍
2013/06/25 PHP
PHP中Header使用的HTTP协议及常用方法小结
2014/11/04 PHP
php实现生成验证码实例分享
2016/04/10 PHP
PHP对象克隆clone用法示例
2016/09/28 PHP
Jquery升级新版本后选择器的语法问题
2010/06/02 Javascript
JQuery自动触发事件的方法
2015/06/13 Javascript
jQuery实现灰蓝风格标准二级下拉菜单效果代码
2015/08/31 Javascript
实例详解JavaScript获取链接参数的方法
2016/01/01 Javascript
AngularJS 入门教程之HTML DOM实例详解
2016/07/28 Javascript
详解AngularJS 路由 resolve用法
2017/04/24 Javascript
JavaScript中数组常见操作技巧
2017/09/01 Javascript
Angular 作用域scope的具体使用
2017/12/11 Javascript
JS生成随机打乱数组的方法示例
2017/12/23 Javascript
vue组件(全局,局部,动态加载组件)
2018/09/02 Javascript
vue.js 打包时出现空白页和路径错误问题及解决方法
2019/06/26 Javascript
Python使用MySQLdb for Python操作数据库教程
2014/10/11 Python
python实现人脸识别代码
2017/11/08 Python
Python 实现删除某路径下文件及文件夹的实例讲解
2018/04/24 Python
python+mysql实现学生信息查询系统
2019/02/21 Python
python中的colorlog库使用详解
2019/07/05 Python
Python for循环与getitem的关系详解
2020/01/02 Python
Python3 实现爬取网站下所有URL方式
2020/01/16 Python
python matplotlib中的subplot函数使用详解
2020/01/19 Python
Python面向对象程序设计之私有变量,私有方法原理与用法分析
2020/03/23 Python
Python环境下安装PyGame和PyOpenGL的方法
2020/03/25 Python
利用CSS3实现毛玻璃效果示例源码
2016/09/25 HTML / CSS
h5页面背景图很长要有滚动条滑动效果的实现
2021/01/27 HTML / CSS
最畅销的视频游戏享受高达90%的折扣:CDKeys
2020/02/10 全球购物
大专生自我鉴定范文
2013/10/01 职场文书
公司员工的自我评价范例
2013/11/01 职场文书
给老婆的搞笑检讨书
2014/01/12 职场文书
优秀教师先进事迹
2014/01/22 职场文书
平安工地建设方案
2014/05/06 职场文书
植物生产学专业求职信
2014/08/08 职场文书
工作证明范本(2篇)
2014/09/14 职场文书
感恩教师节主题班会
2015/08/12 职场文书