解决Pytorch自定义层出现多Variable共享内存错误问题


Posted in Python onJune 28, 2020

错误信息:

RuntimeError: in-place operations can be only used on variables that don't share storage with any other variables, but detected that there are 4 objects sharing it

自动求导是很方便, 但是想想, 如果两个Variable共享内存, 再对这个共享的内存的数据进行修改, 就会引起错误!

一般是由于 inplace操作或是indexing或是转置. 这些都是共享内存的.

@staticmethod
 def backward(ctx, grad_output):
  ind_lst = ctx.ind_lst
  flag = ctx.flag

  c = grad_output.size(1)
  grad_former_all = grad_output[:, 0:c//3, :, :]
  grad_latter_all = grad_output[:, c//3: c*2//3, :, :]
  grad_swapped_all = grad_output[:, c*2//3:c, :, :]

  spatial_size = ctx.h * ctx.w

  W_mat_all = Variable(ctx.Tensor(ctx.bz, spatial_size, spatial_size).zero_())
  for idx in range(ctx.bz):
   W_mat = W_mat_all.select(0,idx)
   for cnt in range(spatial_size):
    indS = ind_lst[idx][cnt] 

    if flag[cnt] == 1:
     # 这里W_mat是W_mat_all通过select出来的, 他们共享内存.
     W_mat[cnt, indS] = 1

   W_mat_t = W_mat.t()

   grad_swapped_weighted = torch.mm(W_mat_t, grad_swapped_all[idx].view(c//3, -1).t())
   grad_swapped_weighted = grad_swapped_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)
   grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_swapped_weighted.mul(ctx.triple_w))

由于 这里W_mat是W_mat_all通过select出来的, 他们共享内存. 所以当对这个共享的内存进行修改W_mat[cnt, indS] = 1, 就会出错. 此时我们可以通过clone()将W_mat和W_mat_all独立出来. 这样的话, 梯度也会通过 clone()操作将W_mat的梯度正确反传到W_mat_all中.

@staticmethod
 def backward(ctx, grad_output):
  ind_lst = ctx.ind_lst
  flag = ctx.flag

  c = grad_output.size(1)
  grad_former_all = grad_output[:, 0:c//3, :, :]
  grad_latter_all = grad_output[:, c//3: c*2//3, :, :]
  grad_swapped_all = grad_output[:, c*2//3:c, :, :]

  spatial_size = ctx.h * ctx.w

  W_mat_all = Variable(ctx.Tensor(ctx.bz, spatial_size, spatial_size).zero_())
  for idx in range(ctx.bz):
   # 这里使用clone了
   W_mat = W_mat_all.select(0,idx).clone()
   for cnt in range(spatial_size):
    indS = ind_lst[idx][cnt]

    if flag[cnt] == 1:
     W_mat[cnt, indS] = 1

   W_mat_t = W_mat.t()

   grad_swapped_weighted = torch.mm(W_mat_t, grad_swapped_all[idx].view(c//3, -1).t())
   grad_swapped_weighted = grad_swapped_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)

   # 这句话删了不会出错, 加上就吹出错
   grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_swapped_weighted.mul(ctx.triple_w))

但是现在却出现 4个objects共享内存. 如果将最后一句话删掉, 那么则不会出错.

如果没有最后一句话, 我们看到

grad_swapped_weighted = torch.mm(W_mat_t, grad_swapped_all[idx].view(c//3, -1).t())

grad_swapped_weighted = grad_swapped_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)

grad_swapped_weighted 一个新的Variable, 因此并没有和其他Variable共享内存, 所以不会出错. 但是最后一句话,

grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_swapped_weighted.mul(ctx.triple_w))

你可能会说, 不对啊, 修改grad_latter_all[idx]又没有创建新的Variable, 怎么会出错. 这是因为grad_latter_all和grad_output是共享内存的. 因为 grad_latter_all = grad_output[:, c//3: c*2//3, :, :], 所以这里的解决方案是:

@staticmethod
 def backward(ctx, grad_output):
  ind_lst = ctx.ind_lst
  flag = ctx.flag

  c = grad_output.size(1)
  grad_former_all = grad_output[:, 0:c//3, :, :]
  # 这两个后面修改值了, 所以也要加clone, 防止它们与grad_output共享内存
  grad_latter_all = grad_output[:, c//3: c*2//3, :, :].clone()
  grad_swapped_all = grad_output[:, c*2//3:c, :, :].clone()

  spatial_size = ctx.h * ctx.w

  W_mat_all = Variable(ctx.Tensor(ctx.bz, spatial_size, spatial_size).zero_())
  for idx in range(ctx.bz):
   W_mat = W_mat_all.select(0,idx).clone()
   for cnt in range(spatial_size):
    indS = ind_lst[idx][cnt]

    if flag[cnt] == 1:
     W_mat[cnt, indS] = 1

   W_mat_t = W_mat.t()

   grad_swapped_weighted = torch.mm(W_mat_t, grad_swapped_all[idx].view(c//3, -1).t())

   grad_swapped_weighted = grad_swapped_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)
   grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_swapped_weighted.mul(ctx.triple_w))

  grad_input = torch.cat([grad_former_all, grad_latter_all], 1)

  return grad_input, None, None, None, None, None, None, None, None, None, None

补充知识:Pytorch 中 expand, expand_as是共享内存的,只是原始数据的一个视图 view

如下所示:

mask = mask_miss.expand_as(sxing).clone() # type: torch.Tensor
mask[:, :, -2, :, :] = 1 # except for person mask channel

为了避免对expand后对某个channel操作会影响原始tensor的全部元素,需要使用clone()

如果没有clone(),对mask_miss的某个通道赋值后,所有通道上的tensor都会变成1!

# Notice! expand does not allocate more memory but just make the tensor look as if you expanded it.
# You should call .clone() on the resulting tensor if you plan on modifying it
# https://discuss.pytorch.org/t/very-strange-behavior-change-one-element-of-a-tensor-will-influence-all-elements/41190

以上这篇解决Pytorch自定义层出现多Variable共享内存错误问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现斐波那契递归函数的方法
Sep 08 Python
Django自定义插件实现网站登录验证码功能
Apr 19 Python
python机器学习实战之树回归详解
Dec 20 Python
Python内置模块ConfigParser实现配置读写功能的方法
Feb 12 Python
Python 3.7新功能之dataclass装饰器详解
Apr 21 Python
Python实现爬虫设置代理IP和伪装成浏览器的方法分享
May 07 Python
python针对不定分隔符切割提取字符串的方法
Oct 26 Python
python 在指定范围内随机生成不重复的n个数实例
Jan 28 Python
python ffmpeg任意提取视频帧的方法
Feb 21 Python
python 工具 字符串转numpy浮点数组的实现
Mar 14 Python
基于python调用jenkins-cli实现快速发布
Aug 14 Python
使用Pytorch搭建模型的步骤
Nov 16 Python
Pytorch学习之torch用法----比较操作(Comparison Ops)
Jun 28 #Python
PyTorch的torch.cat用法
Jun 28 #Python
使用pytorch 筛选出一定范围的值
Jun 28 #Python
解析python 中/ 和 % 和 //(地板除)
Jun 28 #Python
pytorch 常用函数 max ,eq说明
Jun 28 #Python
浅谈pytorch中torch.max和F.softmax函数的维度解释
Jun 28 #Python
Python turtle库的画笔控制说明
Jun 28 #Python
You might like
乱谈我对耳机、音箱的感受
2021/03/02 无线电
php 学习资料零碎东西
2010/12/04 PHP
js+css在交互上的应用
2010/07/18 Javascript
用jquery实现点击栏目背景色改变
2012/12/10 Javascript
raphael.js绘制中国地图 地图绘制方法
2014/02/12 Javascript
Javascript图片上传前的本地预览实例
2014/06/16 Javascript
js实现网页抽奖实例
2015/08/05 Javascript
jQuery实现浮动层随浏览器滚动条滚动的方法
2015/09/22 Javascript
详解如何在Angular中快速定位DOM元素
2017/05/17 Javascript
Vue.js实现微信过渡动画左右切换效果
2017/06/13 Javascript
AngularJS实现自定义指令与控制器数据交互的方法示例
2017/06/19 Javascript
jquery实现倒计时小应用
2017/09/19 jQuery
node简单实现一个更改头像功能的示例
2017/12/29 Javascript
vue初始化动画加载的实例
2018/09/01 Javascript
微信小程序左滑删除功能开发案例详解
2018/11/12 Javascript
vue-cli3.0+element-ui上传组件el-upload的使用
2018/12/03 Javascript
vue中利用Promise封装jsonp并调取数据
2019/06/18 Javascript
修改vue源码实现动态路由缓存的方法
2020/01/21 Javascript
前端vue如何使用高德地图
2020/11/05 Javascript
Python中列表的一些基本操作知识汇总
2015/05/20 Python
使用Python的urllib2模块处理url和图片的技巧两则
2016/02/18 Python
基于Python的XSS测试工具XSStrike使用方法
2017/07/29 Python
使用openCV去除文字中乱入的线条实例
2020/06/02 Python
python编写扎金花小程序的实例代码
2021/02/23 Python
加拿大最大的相机店:Henry’s
2017/05/17 全球购物
艺术用品:Arteza
2018/11/25 全球购物
意大利珠宝店:Luxury Zone
2019/01/05 全球购物
自考毕业生自我鉴定
2013/11/04 职场文书
餐厅考勤管理制度
2014/01/28 职场文书
擅自离岗检讨书
2014/02/11 职场文书
大学生简短的自我评价分享
2014/02/20 职场文书
酒店七夕情人节活动策划方案
2014/08/24 职场文书
新年祝酒词大全
2015/08/11 职场文书
2016年教师节特级教师获奖感言
2015/12/09 职场文书
诉讼和解协议书
2016/03/23 职场文书
python 离散点图画法的实现
2022/04/01 Python