在pytorch中实现只让指定变量向后传播梯度


Posted in Python onFebruary 29, 2020

pytorch中如何只让指定变量向后传播梯度?

(或者说如何让指定变量不参与后向传播?)

有以下公式,假如要让L对xvar求导:

在pytorch中实现只让指定变量向后传播梯度

(1)中,L对xvar的求导将同时计算out1部分和out2部分;

(2)中,L对xvar的求导只计算out2部分,因为out1的requires_grad=False;

(3)中,L对xvar的求导只计算out1部分,因为out2的requires_grad=False;

验证如下:

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Wed May 23 10:02:04 2018
@author: hy
"""
 
import torch
from torch.autograd import Variable
print("Pytorch version: {}".format(torch.__version__))
x=torch.Tensor([1])
xvar=Variable(x,requires_grad=True)
y1=torch.Tensor([2])
y2=torch.Tensor([7])
y1var=Variable(y1)
y2var=Variable(y2)
#(1)
print("For (1)")
print("xvar requres_grad: {}".format(xvar.requires_grad))
print("y1var requres_grad: {}".format(y1var.requires_grad))
print("y2var requres_grad: {}".format(y2var.requires_grad))
out1 = xvar*y1var
print("out1 requres_grad: {}".format(out1.requires_grad))
out2 = xvar*y2var
print("out2 requres_grad: {}".format(out2.requires_grad))
L=torch.pow(out1-out2,2)
L.backward()
print("xvar.grad: {}".format(xvar.grad))
xvar.grad.data.zero_()
#(2)
print("For (2)")
print("xvar requres_grad: {}".format(xvar.requires_grad))
print("y1var requres_grad: {}".format(y1var.requires_grad))
print("y2var requres_grad: {}".format(y2var.requires_grad))
out1 = xvar*y1var
print("out1 requres_grad: {}".format(out1.requires_grad))
out2 = xvar*y2var
print("out2 requres_grad: {}".format(out2.requires_grad))
out1 = out1.detach()
print("after out1.detach(), out1 requres_grad: {}".format(out1.requires_grad))
L=torch.pow(out1-out2,2)
L.backward()
print("xvar.grad: {}".format(xvar.grad))
xvar.grad.data.zero_()
#(3)
print("For (3)")
print("xvar requres_grad: {}".format(xvar.requires_grad))
print("y1var requres_grad: {}".format(y1var.requires_grad))
print("y2var requres_grad: {}".format(y2var.requires_grad))
out1 = xvar*y1var
print("out1 requres_grad: {}".format(out1.requires_grad))
out2 = xvar*y2var
print("out2 requres_grad: {}".format(out2.requires_grad))
#out1 = out1.detach()
out2 = out2.detach()
print("after out2.detach(), out2 requres_grad: {}".format(out1.requires_grad))
L=torch.pow(out1-out2,2)
L.backward()
print("xvar.grad: {}".format(xvar.grad))
xvar.grad.data.zero_()

pytorch中,将变量的requires_grad设为False,即可让变量不参与梯度的后向传播;

但是不能直接将out1.requires_grad=False;

其实,Variable类型提供了detach()方法,所返回变量的requires_grad为False。

注意:如果out1和out2的requires_grad都为False的话,那么xvar.grad就出错了,因为梯度没有传到xvar

补充:

volatile=True表示这个变量不计算梯度, 参考:Volatile is recommended for purely inference mode, when you're sure you won't be even calling .backward(). It's more efficient than any other autograd setting - it will use the absolute minimal amount of memory to evaluate the model. volatile also determines that requires_grad is False.

以上这篇在pytorch中实现只让指定变量向后传播梯度就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 合并文件的具体实例
Aug 08 Python
跟老齐学Python之永远强大的函数
Sep 14 Python
python获取一组数据里最大值max函数用法实例
May 26 Python
详解Django框架中用户的登录和退出的实现
Jul 23 Python
Python中map,reduce,filter和sorted函数的使用方法
Aug 17 Python
离线安装Pyecharts的步骤以及依赖包流程
Apr 23 Python
PyQt5 窗口切换与自定义对话框的实例
Jun 20 Python
Python3.0 实现决策树算法的流程
Aug 08 Python
python3实现微型的web服务器
Sep 03 Python
基于python实现把json数据转换成Excel表格
May 07 Python
numba提升python运行速度的实例方法
Jan 25 Python
Python关于OS文件目录处理的实例分享
May 23 Python
浅谈Pytorch中的自动求导函数backward()所需参数的含义
Feb 29 #Python
python数据预处理 :样本分布不均的解决(过采样和欠采样)
Feb 29 #Python
python实现门限回归方式
Feb 29 #Python
Python3.9又更新了:dict内置新功能
Feb 28 #Python
python实现logistic分类算法代码
Feb 28 #Python
python GUI库图形界面开发之PyQt5打印控件QPrinter详细使用方法与实例
Feb 28 #Python
使用sklearn的cross_val_score进行交叉验证实例
Feb 28 #Python
You might like
php小偷相关截取函数备忘
2010/11/28 PHP
对于PHP 5.4 你必须要知道的
2013/08/07 PHP
通过chrome浏览器控制台(Console)进行PHP Debug的方法
2016/10/19 PHP
tp5框架基于ajax实现异步删除图片的方法示例
2020/02/10 PHP
JSQL 基于客户端的成绩统计实现方法
2010/05/05 Javascript
动态加载jquery库的方法
2014/02/12 Javascript
js实现文本框选中的方法
2015/05/26 Javascript
JavaScript代码性能优化总结(推荐)
2016/05/16 Javascript
对象转换为原始值的实现方法
2016/06/06 Javascript
微信小程序左滑动显示菜单功能的实现
2018/06/14 Javascript
详解Angular5/Angular6项目如何添加热更新(HMR)功能
2018/10/10 Javascript
微信小程序提交form操作示例
2018/12/30 Javascript
JavaScript常见继承模式实例小结
2019/01/11 Javascript
layui 动态设置checbox 选中状态的例子
2019/09/02 Javascript
layer插件实现在弹出层中弹出一警告提示并关闭弹出层的方法
2019/09/24 Javascript
JS实现压缩上传图片base64长度功能
2019/12/03 Javascript
小程序使用wxs解决wxml保留2位小数问题
2019/12/13 Javascript
微信小程序开发中var that =this的用法详解
2020/01/18 Javascript
Vue+Vuex实现自动登录的知识点详解
2020/03/04 Javascript
Python计算程序运行时间的方法
2014/12/13 Python
使用70行Python代码实现一个递归下降解析器的教程
2015/04/17 Python
python计算方程式根的方法
2015/05/07 Python
python 将json数据提取转化为txt的方法
2018/10/26 Python
三个python爬虫项目实例代码
2019/12/28 Python
python logging 日志的级别调整方式
2020/02/21 Python
python 代码实现k-means聚类分析的思路(不使用现成聚类库)
2020/06/01 Python
使用Python项目生成所有依赖包的清单方式
2020/07/13 Python
Python远程方法调用实现过程解析
2020/07/28 Python
Python3爬虫里关于识别微博宫格验证码的知识点详解
2020/07/30 Python
The Kooples美国官方网站:为情侣提供的法国当代时尚品牌
2019/01/03 全球购物
俄罗斯小米家用电器、电子产品和智能家居商店:Poood.ru
2020/04/03 全球购物
函授毕业生自我鉴定
2013/11/06 职场文书
企业道德讲堂实施方案
2014/03/19 职场文书
安全教育演讲稿
2014/05/09 职场文书
工程负责人任命书
2014/06/06 职场文书
MYSQL如何查看操作日志详解
2022/05/30 MySQL