Pytorch 使用tensor特定条件判断索引


Posted in Python onApril 08, 2021

torch.where() 用于将两个broadcastable的tensor组合成新的tensor,类似于c++中的三元操作符“?:”

区别于python numpy中的where()直接可以找到特定条件元素的index

Pytorch 使用tensor特定条件判断索引

想要实现numpy中where()的功能,可以借助nonzero()

Pytorch 使用tensor特定条件判断索引

对应numpy中的where()操作效果:

Pytorch 使用tensor特定条件判断索引

补充:Pytorch torch.Tensor.detach()方法的用法及修改指定模块权重的方法

detach

detach的中文意思是分离,官方解释是返回一个新的Tensor,从当前的计算图中分离出来

Pytorch 使用tensor特定条件判断索引

需要注意的是,返回的Tensor和原Tensor共享相同的存储空间,但是返回的 Tensor 永远不会需要梯度

Pytorch 使用tensor特定条件判断索引

import torch as t
a = t.ones(10,)
b = a.detach()
print(b)
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

那么这个函数有什么作用?

?假如A网络输出了一个Tensor类型的变量a, a要作为输入传入到B网络中,如果我想通过损失函数反向传播修改B网络的参数,但是不想修改A网络的参数,这个时候就可以使用detcah()方法

a = A(input)
a = detach()
b = B(a)
loss = criterion(b, target)
loss.backward()

来看一个实际的例子:

import torch as t
x = t.ones(1, requires_grad=True)
x.requires_grad   #True
y = t.ones(1, requires_grad=True)
y.requires_grad   #True
x = x.detach()   #分离之后
x.requires_grad   #False
y = x+y         #tensor([2.])
y.requires_grad   #我还是True
y.retain_grad()   #y不是叶子张量,要加上这一行
z = t.pow(y, 2)
z.backward()    #反向传播
y.grad        #tensor([4.])
x.grad        #None

以上代码就说明了反向传播到y就结束了,没有到达x,所以x的grad属性为None

既然谈到了修改模型的权重问题,那么还有一种情况是:

?假如A网络输出了一个Tensor类型的变量a, a要作为输入传入到B网络中,如果我想通过损失函数反向传播修改A网络的参数,但是不想修改B网络的参数,这个时候又应该怎么办了?

这时可以使用Tensor.requires_grad属性,只需要将requires_grad修改为False即可.

for param in B.parameters():
 param.requires_grad = False
a = A(input)
b = B(a)
loss = criterion(b, target)
loss.backward()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
50行代码实现贪吃蛇(具体思路及代码)
Apr 27 Python
本地文件上传到七牛云服务器示例(七牛云存储)
Jan 11 Python
浅析Python中将单词首字母大写的capitalize()方法
May 18 Python
Python 编码Basic Auth使用方法简单实例
May 25 Python
对numpy数据写入文件的方法讲解
Jul 09 Python
python 字典中取值的两种方法小结
Aug 02 Python
解决Shell执行python文件,传参空格引起的问题
Oct 30 Python
python程序快速缩进多行代码方法总结
Jun 23 Python
Python实现简单的2048小游戏
Mar 01 Python
Python OpenCV 彩色与灰度图像的转换实现
Jun 05 Python
教你如何用Python实现人脸识别(含源代码)
Jun 23 Python
利用Python多线程实现图片下载器
Mar 25 Python
selenium.webdriver中add_argument方法常用参数表
Apr 08 #Python
python3使用diagrams绘制架构图的步骤
python实现求纯色彩图像的边框
python爬取企查查企业信息之selenium自动模拟登录企查查
Python3 使用pip安装git并获取Yahoo金融数据的操作
Apr 08 #Python
Django 如何实现文件上传下载
Apr 08 #Python
python3 删除所有自定义变量的操作
Apr 08 #Python
You might like
php htmlentities和htmlspecialchars 的区别
2008/08/18 PHP
php数组一对一替换实现代码
2012/08/31 PHP
php使用Cookie控制访问授权的方法
2015/01/21 PHP
PHP简单实现生成txt文件到指定目录的方法
2016/04/25 PHP
PHP简单装饰器模式实现与用法示例
2017/06/22 PHP
JS Range HTML文档/文字内容选中、库及应用介绍
2011/05/12 Javascript
THREE.JS入门教程(5)你应当知道的十件事
2013/01/24 Javascript
js创建子窗口并且回传值示例代码
2013/07/02 Javascript
Nodejs使用mysql模块之获得更新和删除影响的行数的方法
2014/03/18 NodeJs
js如何准确获取当前页面url网址信息
2020/09/13 Javascript
ionic js 模型 $ionicModal 可以遮住用户主界面的内容框
2016/06/06 Javascript
浅谈jquery设置和获得checkbox选中的问题
2016/08/19 Javascript
JS字符串false转boolean的方法(推荐)
2017/03/08 Javascript
详解用vue-cli来搭建vue项目和webpack
2017/04/20 Javascript
微信小程序 swiper组件构建轮播图的实例
2017/09/20 Javascript
微信小程序实现随机验证码功能
2018/12/20 Javascript
解决layer.open弹出框不能获取input框的值为空的问题
2019/09/10 Javascript
Vue+Node实现商品列表的分页、排序、筛选,添加购物车功能详解
2019/12/07 Javascript
Python中文件遍历的两种方法
2014/06/16 Python
对python实现模板生成脚本的方法详解
2019/01/30 Python
对Python3中dict.keys()转换成list类型的方法详解
2019/02/03 Python
使用Django搭建网站实现商品分页功能
2020/05/22 Python
纯css3实现的竖形无限级导航
2014/12/10 HTML / CSS
CSS3实现歌词进度文字颜色填充变化动态效果的思路详解
2020/06/02 HTML / CSS
使用html5制作loading图的示例
2014/04/14 HTML / CSS
html5启动原生APP总结
2020/07/03 HTML / CSS
韩国保养品、日本药妆购物网:小三美日
2018/12/30 全球购物
求职者应聘的自我评价
2013/10/16 职场文书
经典安踏广告词
2014/03/21 职场文书
小学语文课后反思精选
2014/04/25 职场文书
大学运动会加油稿200字(5篇)
2014/09/27 职场文书
九寨沟导游词
2015/02/02 职场文书
2015年乡镇食品安全工作总结
2015/10/22 职场文书
2016党校培训心得体会
2016/01/07 职场文书
机关干部作风整顿心得体会
2016/01/22 职场文书
pytorch 实现在测试的时候启用dropout
2021/05/27 Python