Pytorch 使用tensor特定条件判断索引


Posted in Python onApril 08, 2021

torch.where() 用于将两个broadcastable的tensor组合成新的tensor,类似于c++中的三元操作符“?:”

区别于python numpy中的where()直接可以找到特定条件元素的index

Pytorch 使用tensor特定条件判断索引

想要实现numpy中where()的功能,可以借助nonzero()

Pytorch 使用tensor特定条件判断索引

对应numpy中的where()操作效果:

Pytorch 使用tensor特定条件判断索引

补充:Pytorch torch.Tensor.detach()方法的用法及修改指定模块权重的方法

detach

detach的中文意思是分离,官方解释是返回一个新的Tensor,从当前的计算图中分离出来

Pytorch 使用tensor特定条件判断索引

需要注意的是,返回的Tensor和原Tensor共享相同的存储空间,但是返回的 Tensor 永远不会需要梯度

Pytorch 使用tensor特定条件判断索引

import torch as t
a = t.ones(10,)
b = a.detach()
print(b)
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

那么这个函数有什么作用?

?假如A网络输出了一个Tensor类型的变量a, a要作为输入传入到B网络中,如果我想通过损失函数反向传播修改B网络的参数,但是不想修改A网络的参数,这个时候就可以使用detcah()方法

a = A(input)
a = detach()
b = B(a)
loss = criterion(b, target)
loss.backward()

来看一个实际的例子:

import torch as t
x = t.ones(1, requires_grad=True)
x.requires_grad   #True
y = t.ones(1, requires_grad=True)
y.requires_grad   #True
x = x.detach()   #分离之后
x.requires_grad   #False
y = x+y         #tensor([2.])
y.requires_grad   #我还是True
y.retain_grad()   #y不是叶子张量,要加上这一行
z = t.pow(y, 2)
z.backward()    #反向传播
y.grad        #tensor([4.])
x.grad        #None

以上代码就说明了反向传播到y就结束了,没有到达x,所以x的grad属性为None

既然谈到了修改模型的权重问题,那么还有一种情况是:

?假如A网络输出了一个Tensor类型的变量a, a要作为输入传入到B网络中,如果我想通过损失函数反向传播修改A网络的参数,但是不想修改B网络的参数,这个时候又应该怎么办了?

这时可以使用Tensor.requires_grad属性,只需要将requires_grad修改为False即可.

for param in B.parameters():
 param.requires_grad = False
a = A(input)
b = B(a)
loss = criterion(b, target)
loss.backward()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
python web.py开发httpserver解决跨域问题实例解析
Feb 12 Python
查看django执行的sql语句及消耗时间的两种方法
May 29 Python
Python中pandas模块DataFrame创建方法示例
Jun 20 Python
对python借助百度云API对评论进行观点抽取的方法详解
Feb 21 Python
python3实现斐波那契数列(4种方法)
Jul 15 Python
Python 50行爬虫抓取并处理图灵书目过程详解
Sep 20 Python
Python numpy线性代数用法实例解析
Nov 15 Python
Python模块_PyLibTiff读取tif文件的实例
Jan 13 Python
Python logging模块写入中文出现乱码
May 21 Python
Python sorted对list和dict排序
Jun 09 Python
Python数据可视化实现漏斗图过程图解
Jul 20 Python
用Python实现职工信息管理系统
Dec 30 Python
selenium.webdriver中add_argument方法常用参数表
Apr 08 #Python
python3使用diagrams绘制架构图的步骤
python实现求纯色彩图像的边框
python爬取企查查企业信息之selenium自动模拟登录企查查
Python3 使用pip安装git并获取Yahoo金融数据的操作
Apr 08 #Python
Django 如何实现文件上传下载
Apr 08 #Python
python3 删除所有自定义变量的操作
Apr 08 #Python
You might like
PHP速成大法
2015/01/30 PHP
PHP+Ajax实时自动检测是否联网的方法
2015/07/01 PHP
PHP中利用sleep函数实现定时执行功能实现代码
2016/08/25 PHP
Laravel 5.3 学习笔记之 错误&日志
2016/08/28 PHP
PHP实现类似题库抽题效果
2018/08/16 PHP
Javascript图像处理思路及实现代码
2012/12/25 Javascript
JS获得URL超链接的参数值实例代码
2013/06/21 Javascript
jquery.ui.draggable中文文档(原文翻译)
2013/11/15 Javascript
jquery单行文字向上滚动效果示例
2014/03/06 Javascript
jQuery+CSS实现的网页二级下滑菜单效果
2015/08/25 Javascript
JavaScript驾驭网页-获取网页元素
2016/03/24 Javascript
jQuery EasyUI菜单与按钮详解
2016/07/13 Javascript
bootstrap警告框使用方法解析
2017/01/13 Javascript
js实现年月日表单三级联动
2020/04/17 Javascript
JavaScript创建对象_动力节点Java学院整理
2017/06/27 Javascript
JavaScript栈和队列相关操作与实现方法详解
2018/12/07 Javascript
TypeScript高级用法的知识点汇总
2019/12/17 Javascript
Python编程之黑板上排列组合,你舍得解开吗
2017/10/30 Python
python中使用zip函数出现错误的原因
2018/09/28 Python
详解Python函数式编程—高阶函数
2019/03/29 Python
Python numpy线性代数用法实例解析
2019/11/15 Python
Python 创建守护进程的示例
2020/09/29 Python
一款CSS3实现多功能下拉菜单(带分享按)的教程
2014/11/05 HTML / CSS
Superdry极度乾燥官网:日本街头风格,纯英国制造品牌
2016/10/31 全球购物
西班牙购买行李箱和背包网站:Maletas Greenwich
2019/10/08 全球购物
欧姆龙医疗保健与医疗产品:Omron Healthcare
2020/02/10 全球购物
泰海淘:泰国king Power王权免税集团旗下跨境海淘综合型电商
2020/07/26 全球购物
社区居务公开实施方案
2014/03/27 职场文书
体育课课后反思
2014/04/24 职场文书
党在我心中演讲稿
2014/09/02 职场文书
2014年机关工会工作总结
2014/12/19 职场文书
小学班主任工作总结2015
2015/04/07 职场文书
公司奖励通知
2015/04/21 职场文书
工厂员工辞职信范文
2015/05/12 职场文书
决心书格式及范文
2019/06/24 职场文书
Python list列表删除元素的4种方法
2021/11/01 Python