Pytorch 使用tensor特定条件判断索引


Posted in Python onApril 08, 2021

torch.where() 用于将两个broadcastable的tensor组合成新的tensor,类似于c++中的三元操作符“?:”

区别于python numpy中的where()直接可以找到特定条件元素的index

Pytorch 使用tensor特定条件判断索引

想要实现numpy中where()的功能,可以借助nonzero()

Pytorch 使用tensor特定条件判断索引

对应numpy中的where()操作效果:

Pytorch 使用tensor特定条件判断索引

补充:Pytorch torch.Tensor.detach()方法的用法及修改指定模块权重的方法

detach

detach的中文意思是分离,官方解释是返回一个新的Tensor,从当前的计算图中分离出来

Pytorch 使用tensor特定条件判断索引

需要注意的是,返回的Tensor和原Tensor共享相同的存储空间,但是返回的 Tensor 永远不会需要梯度

Pytorch 使用tensor特定条件判断索引

import torch as t
a = t.ones(10,)
b = a.detach()
print(b)
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

那么这个函数有什么作用?

?假如A网络输出了一个Tensor类型的变量a, a要作为输入传入到B网络中,如果我想通过损失函数反向传播修改B网络的参数,但是不想修改A网络的参数,这个时候就可以使用detcah()方法

a = A(input)
a = detach()
b = B(a)
loss = criterion(b, target)
loss.backward()

来看一个实际的例子:

import torch as t
x = t.ones(1, requires_grad=True)
x.requires_grad   #True
y = t.ones(1, requires_grad=True)
y.requires_grad   #True
x = x.detach()   #分离之后
x.requires_grad   #False
y = x+y         #tensor([2.])
y.requires_grad   #我还是True
y.retain_grad()   #y不是叶子张量,要加上这一行
z = t.pow(y, 2)
z.backward()    #反向传播
y.grad        #tensor([4.])
x.grad        #None

以上代码就说明了反向传播到y就结束了,没有到达x,所以x的grad属性为None

既然谈到了修改模型的权重问题,那么还有一种情况是:

?假如A网络输出了一个Tensor类型的变量a, a要作为输入传入到B网络中,如果我想通过损失函数反向传播修改A网络的参数,但是不想修改B网络的参数,这个时候又应该怎么办了?

这时可以使用Tensor.requires_grad属性,只需要将requires_grad修改为False即可.

for param in B.parameters():
 param.requires_grad = False
a = A(input)
b = B(a)
loss = criterion(b, target)
loss.backward()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
python抓取豆瓣图片并自动保存示例学习
Jan 10 Python
用Python的Django框架完成视频处理任务的教程
Apr 02 Python
基于Python的XSS测试工具XSStrike使用方法
Jul 29 Python
python修改txt文件中的某一项方法
Dec 29 Python
Gauss-Seidel迭代算法的Python实现详解
Jun 29 Python
解决pycharm 工具栏Tool中找不到Run manager.py Task的问题
Jul 01 Python
一文详述 Python 中的 property 语法
Sep 01 Python
基于Python模拟浏览器发送http请求
Nov 06 Python
python3字符串输出常见面试题总结
Dec 01 Python
python爬虫基础之urllib的使用
Dec 31 Python
如何使用Python对NetCDF数据做空间相关分析
Apr 21 Python
pandas中DataFrame数据合并连接(merge、join、concat)
May 30 Python
selenium.webdriver中add_argument方法常用参数表
Apr 08 #Python
python3使用diagrams绘制架构图的步骤
python实现求纯色彩图像的边框
python爬取企查查企业信息之selenium自动模拟登录企查查
Python3 使用pip安装git并获取Yahoo金融数据的操作
Apr 08 #Python
Django 如何实现文件上传下载
Apr 08 #Python
python3 删除所有自定义变量的操作
Apr 08 #Python
You might like
PHP目录操作实例总结
2016/09/27 PHP
php 多个变量指向同一个引用($b = &$a)用法分析
2019/11/13 PHP
JavaScript中null与undefined分析
2009/07/25 Javascript
JS继承 笔记
2011/07/13 Javascript
修复bash漏洞的shell脚本分享
2014/12/31 Javascript
jquery文档操作wrap()方法实例简述
2015/01/10 Javascript
JavaScript代码性能优化总结(推荐)
2016/05/16 Javascript
js无法获取到html标签的属性的解决方法
2016/07/26 Javascript
jQuery解决input元素的blur事件和其他非表单元素的click事件冲突问题
2016/08/15 Javascript
微信小程序 轮播图swiper详解及实例(源码下载)
2017/01/11 Javascript
简单实现jQuery手风琴效果
2017/08/18 jQuery
jquery无缝图片轮播组件封装
2020/11/25 jQuery
Element-ui中元素滚动时el-option超出元素区域的问题
2019/05/30 Javascript
Vue实现穿梭框效果
2020/09/30 Javascript
[53:15]2018DOTA2亚洲邀请赛3月29日 小组赛A组 KG VS OG
2018/03/30 DOTA
Python的函数的一些高阶特性
2015/04/27 Python
详解python中的 is 操作符
2017/12/26 Python
对python3 中方法各种参数和返回值详解
2018/12/15 Python
实例讲解Python中浮点型的基本内容
2019/02/11 Python
使用Python+wxpy 找出微信里把你删除的好友实例
2019/02/21 Python
Python何时应该使用Lambda函数
2019/07/02 Python
基于sklearn实现Bagging算法(python)
2019/07/11 Python
python标准库sys和OS的函数使用方法与实例详解
2020/02/12 Python
python中逻辑与或(and、or)和按位与或异或(&、|、^)区别
2020/08/05 Python
用python绘制樱花树
2020/10/09 Python
Python通过递归函数输出嵌套列表元素
2020/10/15 Python
关于h5中的fetch方法解读(小结)
2017/11/15 HTML / CSS
美国一家主营日韩美妆护肤品的在线商店:iMomoko
2016/09/11 全球购物
英国珠宝钟表和家居礼品精品店:David Shuttle
2018/02/24 全球购物
捷克玩具商店:Bambule
2019/02/23 全球购物
校园报刊亭创业计划书
2014/01/02 职场文书
教师个人剖析材料
2014/02/05 职场文书
优秀毕业生推荐信范文
2014/03/07 职场文书
课题研究阶段性总结
2015/08/13 职场文书
2016年公司新年寄语
2015/08/17 职场文书
教你如何用cmd快速登录服务器
2022/06/10 Servers