PyTorch中topk函数的用法详解


Posted in Python onJanuary 02, 2020

听名字就知道这个函数是用来求tensor中某个dim的前k大或者前k小的值以及对应的index。

用法

torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)

input:一个tensor数据

k:指明是得到前k个数据以及其index

dim: 指定在哪个维度上排序, 默认是最后一个维度

largest:如果为True,按照大到小排序; 如果为False,按照小到大排序

sorted:返回的结果按照顺序返回

out:可缺省,不要

topk最常用的场合就是求一个样本被网络认为前k个最可能属于的类别。我们就用这个场景为例,说明函数的使用方法。

假设一个PyTorch中topk函数的用法详解,N是样本数目,一般等于batch size, D是类别数目。我们想知道每个样本的最可能属于的那个类别,其实可以用torch.max得到。如果要使用topk,则k应该设置为1。

import torch

pred = torch.randn((4, 5))
print(pred)
values, indices = pred.topk(1, dim=1, largest=True, sorted=True)
print(indices)
# 用max得到的结果,设置keepdim为True,避免降维。因为topk函数返回的index不降维,shape和输入一致。
_, indices_max = pred.max(dim=1, keepdim=True)

print(indices_max == indices)
# pred
tensor([[-0.1480, -0.9819, -0.3364, 0.7912, -0.3263],
    [-0.8013, -0.9083, 0.7973, 0.1458, -0.9156],
    [-0.2334, -0.0142, -0.5493, 0.0673, 0.8185],
    [-0.4075, -0.1097, 0.8193, -0.2352, -0.9273]])
# indices, shape为 【4,1】,
tensor([[3],  #【0,0】代表 第一个样本最可能属于第一类别
    [2],  # 【1, 0】代表第二个样本最可能属于第二类别
    [4],
    [2]])
# indices_max等于indices
tensor([[True],
    [True],
    [True],
    [True]])

现在在尝试一下k=2

import torch

pred = torch.randn((4, 5))
print(pred)
values, indices = pred.topk(2, dim=1, largest=True, sorted=True) # k=2
print(indices)
# pred
tensor([[-0.2203, -0.7538, 1.8789, 0.4451, -0.2526],
    [-0.0413, 0.6366, 1.1155, 0.3484, 0.0395],
    [ 0.0365, 0.5158, 1.1067, -0.9276, -0.2124],
    [ 0.6232, 0.9912, -0.8562, 0.0148, 1.6413]])
# indices
tensor([[2, 3],
    [2, 1],
    [2, 1],
    [4, 1]])

可以发现indices的shape变成了【4, k】,k=2。

其中indices[0] = [2,3]。其意义是说明第一个样本的前两个最大概率对应的类别分别是第3类和第4类。

大家可以自行print一下values。可以发现values的shape和indices的shape是一样的。indices描述了在values中对应的值在pred中的位置。

以上这篇PyTorch中topk函数的用法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
django定期执行任务(实例讲解)
Nov 03 Python
python生成excel的实例代码
Nov 08 Python
Python3解决棋盘覆盖问题的方法示例
Dec 07 Python
python读写LMDB文件的方法
Jul 02 Python
解决使用PyCharm时无法启动控制台的问题
Jan 19 Python
python快速编写单行注释多行注释的方法
Jul 31 Python
大家都说好用的Python命令行库click的使用
Nov 07 Python
Python中文分词库jieba,pkusegwg性能准确度比较
Feb 11 Python
Python编程快速上手——Excel表格创建乘法表案例分析
Feb 28 Python
Selenium+BeautifulSoup+json获取Script标签内的json数据
Dec 07 Python
Python实现byte转integer
Jun 03 Python
FP-growth算法发现频繁项集——发现频繁项集
Jun 24 Python
Pytorch训练过程出现nan的解决方式
Jan 02 #Python
pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率
Jan 02 #Python
基于MSELoss()与CrossEntropyLoss()的区别详解
Jan 02 #Python
python使用SQLAlchemy操作MySQL
Jan 02 #Python
pytorch 实现cross entropy损失函数计算方式
Jan 02 #Python
Matplotlib scatter绘制散点图的方法实现
Jan 02 #Python
Python基础之函数基本用法与进阶详解
Jan 02 #Python
You might like
Access数据库导入Mysql的方法之一
2006/10/09 PHP
php5.2以下版本无json_decode函数的解决方法
2014/05/25 PHP
Yii2中YiiBase自动加载类、引用文件方法分析(autoload)
2016/07/25 PHP
PHP文件操作实例总结【文件上传、下载、分页】
2018/12/08 PHP
javascript 动态table添加colspan\rowspan 参数的方法
2009/07/25 Javascript
Dom 学习总结以及实例的使用介绍
2013/04/24 Javascript
javascript+canvas制作九宫格小程序
2014/12/28 Javascript
JS实现的竖向折叠菜单代码
2015/10/21 Javascript
AngularJS 工作原理详解
2016/08/18 Javascript
jQuery编写网页版2048小游戏
2017/01/06 Javascript
jQuery实现扑克正反面翻牌效果
2017/03/10 Javascript
angular 未登录状态拦截路由跳转的方法
2018/10/09 Javascript
JavaScript中的事件与异常捕获详析
2019/02/24 Javascript
面试题:react和vue的区别分析
2019/04/08 Javascript
jQuery实现移动端下拉展现新的内容回弹动画
2020/06/24 jQuery
vue中利用three.js实现全景图的完整示例
2020/12/07 Vue.js
python常用知识梳理(必看篇)
2017/03/23 Python
利用python获取Ping结果示例代码
2017/07/06 Python
使用Python实现博客上进行自动翻页
2017/08/23 Python
tensorflow使用神经网络实现mnist分类
2018/09/08 Python
Python K最近邻从原理到实现的方法
2019/08/15 Python
python requests抓取one推送文字和图片代码实例
2019/11/04 Python
在jupyter notebook 添加 conda 环境的操作详解
2020/04/10 Python
Python3内置函数chr和ord实现进制转换
2020/06/05 Python
Python通过zookeeper实现分布式服务代码解析
2020/07/22 Python
Django框架请求生命周期实现原理
2020/11/13 Python
18-35岁旅游团的全球领导者:Contiki
2017/02/08 全球购物
经典优秀个人求职自荐信格式
2013/09/25 职场文书
项目专员岗位职责
2013/12/04 职场文书
综合办公室个人的自我评价
2013/12/22 职场文书
企业管理部经理岗位职责
2013/12/24 职场文书
三年大学自我鉴定
2014/01/16 职场文书
设备动力科岗位职责范本
2014/02/23 职场文书
企业负责人任命书
2014/06/05 职场文书
个人承诺书格式范文
2015/04/29 职场文书
Java数组与堆栈相关知识总结
2021/06/29 Java/Android