解决Pytorch自定义层出现多Variable共享内存错误问题


Posted in Python onJune 28, 2020

错误信息:

RuntimeError: in-place operations can be only used on variables that don't share storage with any other variables, but detected that there are 4 objects sharing it

自动求导是很方便, 但是想想, 如果两个Variable共享内存, 再对这个共享的内存的数据进行修改, 就会引起错误!

一般是由于 inplace操作或是indexing或是转置. 这些都是共享内存的.

@staticmethod
 def backward(ctx, grad_output):
  ind_lst = ctx.ind_lst
  flag = ctx.flag

  c = grad_output.size(1)
  grad_former_all = grad_output[:, 0:c//3, :, :]
  grad_latter_all = grad_output[:, c//3: c*2//3, :, :]
  grad_swapped_all = grad_output[:, c*2//3:c, :, :]

  spatial_size = ctx.h * ctx.w

  W_mat_all = Variable(ctx.Tensor(ctx.bz, spatial_size, spatial_size).zero_())
  for idx in range(ctx.bz):
   W_mat = W_mat_all.select(0,idx)
   for cnt in range(spatial_size):
    indS = ind_lst[idx][cnt] 

    if flag[cnt] == 1:
     # 这里W_mat是W_mat_all通过select出来的, 他们共享内存.
     W_mat[cnt, indS] = 1

   W_mat_t = W_mat.t()

   grad_swapped_weighted = torch.mm(W_mat_t, grad_swapped_all[idx].view(c//3, -1).t())
   grad_swapped_weighted = grad_swapped_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)
   grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_swapped_weighted.mul(ctx.triple_w))

由于 这里W_mat是W_mat_all通过select出来的, 他们共享内存. 所以当对这个共享的内存进行修改W_mat[cnt, indS] = 1, 就会出错. 此时我们可以通过clone()将W_mat和W_mat_all独立出来. 这样的话, 梯度也会通过 clone()操作将W_mat的梯度正确反传到W_mat_all中.

@staticmethod
 def backward(ctx, grad_output):
  ind_lst = ctx.ind_lst
  flag = ctx.flag

  c = grad_output.size(1)
  grad_former_all = grad_output[:, 0:c//3, :, :]
  grad_latter_all = grad_output[:, c//3: c*2//3, :, :]
  grad_swapped_all = grad_output[:, c*2//3:c, :, :]

  spatial_size = ctx.h * ctx.w

  W_mat_all = Variable(ctx.Tensor(ctx.bz, spatial_size, spatial_size).zero_())
  for idx in range(ctx.bz):
   # 这里使用clone了
   W_mat = W_mat_all.select(0,idx).clone()
   for cnt in range(spatial_size):
    indS = ind_lst[idx][cnt]

    if flag[cnt] == 1:
     W_mat[cnt, indS] = 1

   W_mat_t = W_mat.t()

   grad_swapped_weighted = torch.mm(W_mat_t, grad_swapped_all[idx].view(c//3, -1).t())
   grad_swapped_weighted = grad_swapped_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)

   # 这句话删了不会出错, 加上就吹出错
   grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_swapped_weighted.mul(ctx.triple_w))

但是现在却出现 4个objects共享内存. 如果将最后一句话删掉, 那么则不会出错.

如果没有最后一句话, 我们看到

grad_swapped_weighted = torch.mm(W_mat_t, grad_swapped_all[idx].view(c//3, -1).t())

grad_swapped_weighted = grad_swapped_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)

grad_swapped_weighted 一个新的Variable, 因此并没有和其他Variable共享内存, 所以不会出错. 但是最后一句话,

grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_swapped_weighted.mul(ctx.triple_w))

你可能会说, 不对啊, 修改grad_latter_all[idx]又没有创建新的Variable, 怎么会出错. 这是因为grad_latter_all和grad_output是共享内存的. 因为 grad_latter_all = grad_output[:, c//3: c*2//3, :, :], 所以这里的解决方案是:

@staticmethod
 def backward(ctx, grad_output):
  ind_lst = ctx.ind_lst
  flag = ctx.flag

  c = grad_output.size(1)
  grad_former_all = grad_output[:, 0:c//3, :, :]
  # 这两个后面修改值了, 所以也要加clone, 防止它们与grad_output共享内存
  grad_latter_all = grad_output[:, c//3: c*2//3, :, :].clone()
  grad_swapped_all = grad_output[:, c*2//3:c, :, :].clone()

  spatial_size = ctx.h * ctx.w

  W_mat_all = Variable(ctx.Tensor(ctx.bz, spatial_size, spatial_size).zero_())
  for idx in range(ctx.bz):
   W_mat = W_mat_all.select(0,idx).clone()
   for cnt in range(spatial_size):
    indS = ind_lst[idx][cnt]

    if flag[cnt] == 1:
     W_mat[cnt, indS] = 1

   W_mat_t = W_mat.t()

   grad_swapped_weighted = torch.mm(W_mat_t, grad_swapped_all[idx].view(c//3, -1).t())

   grad_swapped_weighted = grad_swapped_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)
   grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_swapped_weighted.mul(ctx.triple_w))

  grad_input = torch.cat([grad_former_all, grad_latter_all], 1)

  return grad_input, None, None, None, None, None, None, None, None, None, None

补充知识:Pytorch 中 expand, expand_as是共享内存的,只是原始数据的一个视图 view

如下所示:

mask = mask_miss.expand_as(sxing).clone() # type: torch.Tensor
mask[:, :, -2, :, :] = 1 # except for person mask channel

为了避免对expand后对某个channel操作会影响原始tensor的全部元素,需要使用clone()

如果没有clone(),对mask_miss的某个通道赋值后,所有通道上的tensor都会变成1!

# Notice! expand does not allocate more memory but just make the tensor look as if you expanded it.
# You should call .clone() on the resulting tensor if you plan on modifying it
# https://discuss.pytorch.org/t/very-strange-behavior-change-one-element-of-a-tensor-will-influence-all-elements/41190

以上这篇解决Pytorch自定义层出现多Variable共享内存错误问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python的内建模块collections的教程
Apr 28 Python
python实现画圆功能
Jan 25 Python
python跳过第一行快速读取文件内容的实例
Jul 12 Python
Python寻找两个有序数组的中位数实例详解
Dec 05 Python
pytorch中如何使用DataLoader对数据集进行批处理的方法
Aug 06 Python
Pyinstaller 打包发布经验总结
Jun 02 Python
python dict如何定义
Sep 02 Python
解决Python3.7.0 SSL低版本导致Pip无法使用问题
Sep 03 Python
如何使用 Python 读取文件和照片的创建日期
Sep 05 Python
python入门教程之基本算术运算符
Nov 13 Python
python 下载文件的多种方法汇总
Nov 17 Python
python tqdm用法及实例详解
Jun 16 Python
Pytorch学习之torch用法----比较操作(Comparison Ops)
Jun 28 #Python
PyTorch的torch.cat用法
Jun 28 #Python
使用pytorch 筛选出一定范围的值
Jun 28 #Python
解析python 中/ 和 % 和 //(地板除)
Jun 28 #Python
pytorch 常用函数 max ,eq说明
Jun 28 #Python
浅谈pytorch中torch.max和F.softmax函数的维度解释
Jun 28 #Python
Python turtle库的画笔控制说明
Jun 28 #Python
You might like
mysql建立外键
2006/11/25 PHP
解析php扩展php_curl.dll不加载的解决方法
2013/06/26 PHP
WordPress主题制作之模板文件的引入方法
2015/12/28 PHP
thinkPHP自动验证机制详解
2016/12/05 PHP
php使用redis的几种常见操作方式和用法示例
2020/02/20 PHP
在JavaScript中处理字符串之link()方法的使用
2015/06/08 Javascript
Jquery调用iframe父页面中的元素及方法
2016/08/23 Javascript
纯js和css完成贪吃蛇小游戏demo
2016/09/01 Javascript
脚本div实现拖放功能(两种)
2017/02/13 Javascript
详解vue-router 2.0 常用基础知识点之router-link
2017/05/10 Javascript
纯JS实现只能输入数字的简单代码
2017/06/21 Javascript
JavaScript设计模式之代理模式简单实例教程
2018/07/03 Javascript
解决angularjs中同步执行http请求的方法
2018/08/13 Javascript
原生JS实现简单的无缝自动轮播效果
2018/09/26 Javascript
微信小程序使用template标签实现五星评分功能
2018/11/03 Javascript
JavaScript碎片—函数闭包(模拟面向对象)
2019/03/13 Javascript
Vue路由对象属性 .meta $route.matched详解
2019/11/04 Javascript
详解如何修改 node_modules 里的文件
2020/05/22 Javascript
如何管理Vue中的缓存页面
2021/02/06 Vue.js
js 执行上下文和作用域的相关总结
2021/02/08 Javascript
Python version 2.7 required, which was not found in the registry
2014/08/26 Python
python格式化字符串实例总结
2014/09/28 Python
跟老齐学Python之编写类之二方法
2014/10/11 Python
Python编程入门之Hello World的三种实现方式
2015/11/13 Python
python获取微信小程序手机号并绑定遇到的坑
2018/11/19 Python
Python3对称加密算法AES、DES3实例详解
2018/12/06 Python
如何基于Python批量下载音乐
2019/11/11 Python
Python-split()函数实例用法讲解
2020/12/18 Python
Pycharm 解决自动格式化冲突的设置操作
2021/01/15 Python
js实现移动端H5页面手指滑动刻度尺功能
2017/11/16 HTML / CSS
意大利简约的休闲品牌:Aspesi
2018/02/08 全球购物
中文系学生自荐信范文
2013/11/13 职场文书
小学生母亲节演讲稿
2014/05/07 职场文书
电子商务求职信
2014/06/15 职场文书
小学生感恩父母演讲稿
2014/08/28 职场文书
特此通知格式
2015/04/27 职场文书