pytorch中的自定义反向传播,求导实例


Posted in Python onJanuary 06, 2020

pytorch中自定义backward()函数。在图像处理过程中,我们有时候会使用自己定义的算法处理图像,这些算法多是基于numpy或者scipy等包。

那么如何将自定义算法的梯度加入到pytorch的计算图中,能使用Loss.backward()操作自动求导并优化呢。下面的代码展示了这个功能`

import torch
import numpy as np
from PIL import Image
from torch.autograd import gradcheck
class Bicubic(torch.autograd.Function):
def basis_function(self, x, a=-1):
  x_abs = np.abs(x)
  if x_abs < 1 and x_abs >= 0:
    y = (a + 2) * np.power(x_abs, 3) - (a + 3) * np.power(x_abs, 2) + 1
  elif x_abs > 1 and x_abs < 2:
    y = a * np.power(x_abs, 3) - 5 * a * np.power(x_abs, 2) + 8 * a * x_abs - 4 * a
  else:
    y = 0
  return y
def bicubic_interpolate(self,data_in, scale=1 / 4, mode='edge'):
  # data_in = data_in.detach().numpy()
  self.grad = np.zeros(data_in.shape,dtype=np.float32)
  obj_shape = (int(data_in.shape[0] * scale), int(data_in.shape[1] * scale), data_in.shape[2])
  data_tmp = data_in.copy()
  data_obj = np.zeros(shape=obj_shape, dtype=np.float32)
  data_in = np.pad(data_in, pad_width=((2, 2), (2, 2), (0, 0)), mode=mode)
  print(data_tmp.shape)
  for axis0 in range(obj_shape[0]):
    f_0 = float(axis0) / scale - np.floor(axis0 / scale)
    int_0 = int(axis0 / scale) + 2
    axis0_weight = np.array(
      [[self.basis_function(1 + f_0), self.basis_function(f_0), self.basis_function(1 - f_0), self.basis_function(2 - f_0)]])
    for axis1 in range(obj_shape[1]):
      f_1 = float(axis1) / scale - np.floor(axis1 / scale)
      int_1 = int(axis1 / scale) + 2
      axis1_weight = np.array(
        [[self.basis_function(1 + f_1), self.basis_function(f_1), self.basis_function(1 - f_1), self.basis_function(2 - f_1)]])
      nbr_pixel = np.zeros(shape=(obj_shape[2], 4, 4), dtype=np.float32)
      grad_point = np.matmul(np.transpose(axis0_weight, (1, 0)), axis1_weight)
      for i in range(4):
        for j in range(4):
          nbr_pixel[:, i, j] = data_in[int_0 + i - 1, int_1 + j - 1, :]
          for ii in range(data_in.shape[2]):
            self.grad[int_0 - 2 + i - 1, int_1 - 2 + j - 1, ii] = grad_point[i,j]
      tmp = np.matmul(axis0_weight, nbr_pixel)
      data_obj[axis0, axis1, :] = np.matmul(tmp, np.transpose(axis1_weight, (1, 0)))[:, 0, 0]
      # img = np.transpose(img[0, :, :, :], [1, 2, 0])
  return data_obj

def forward(self,input):
  print(type(input))
  input_ = input.detach().numpy()
  output = self.bicubic_interpolate(input_)
  # return input.new(output)
  return torch.Tensor(output)

def backward(self,grad_output):
  print(self.grad.shape,grad_output.shape)
  grad_output.detach().numpy()
  grad_output_tmp = np.zeros(self.grad.shape,dtype=np.float32)
  for i in range(self.grad.shape[0]):
    for j in range(self.grad.shape[1]):
      grad_output_tmp[i,j,:] = grad_output[int(i/4),int(j/4),:]
  grad_input = grad_output_tmp*self.grad
  print(type(grad_input))
  # return grad_output.new(grad_input)
  return torch.Tensor(grad_input)

def bicubic(input):
return Bicubic()(input)

def main():
	hr = Image.open('./baboon/baboon_hr.png').convert('L')
	hr = torch.Tensor(np.expand_dims(np.array(hr), axis=2))
	hr.requires_grad = True
	lr = bicubic(hr)
	print(lr.is_leaf)
	loss=torch.mean(lr)
	loss.backward()
if __name__ =='__main__':
	main()

要想实现自动求导,必须同时实现forward(),backward()两个函数。

1、从代码中可以看出来,forward()函数是针对numpy数据操作,返回值再重新指定为torch.Tensor类型。因此就有这个问题出现了:forward输入input被转换为numpy类型,输出转换为tensor类型,那么输出output的grad_fn参数是如何指定的呢。调试发现,当main()中hr的requires_grad被指定为True,即hr被指定为需要求导的叶子节点。只要Bicubic类继承自torch.autograd.Function,那么output也就是代码中的lr的grad_fn就会被指定为<main.Bicubic object at 0x000001DD5A280D68>,即Bicubic这个类。

2、backward()为求导的函数,gard_output是链式求导法则的上一级的梯度,grad_input即为我们想要得到的梯度。只需要在输入指定grad_output,在调用loss.backward()过程中的某一步会执行到Bicubic的backwward()函数

以上这篇pytorch中的自定义反向传播,求导实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用python提取html文件中的特定数据的实现代码
Mar 24 Python
Python中__new__与__init__方法的区别详解
May 04 Python
python实现的AES双向对称加密解密与用法分析
May 02 Python
python通过elixir包操作mysql数据库实例代码
Jan 31 Python
Python简单实现阿拉伯数字和罗马数字的互相转换功能示例
Apr 17 Python
用uWSGI和Nginx部署Flask项目的方法示例
May 05 Python
详解python中的数据类型和控制流
Aug 08 Python
Python异常模块traceback用法实例分析
Oct 22 Python
pytorch的梯度计算以及backward方法详解
Jan 10 Python
python实现电子词典
Mar 03 Python
python pyqtgraph 保存图片到本地的实例
Mar 14 Python
python数据处理之Pandas类型转换
Apr 28 Python
PyTorch中 tensor.detach() 和 tensor.data 的区别详解
Jan 06 #Python
6行Python代码实现进度条效果(Progress、tqdm、alive-progress​​​​​​​和PySimpleGUI库)
Jan 06 #Python
基于python+selenium的二次封装的实现
Jan 06 #Python
Python使用Tkinter实现滚动抽奖器效果
Jan 06 #Python
Python使用Tkinter实现转盘抽奖器的步骤详解
Jan 06 #Python
pytorch 实现在预训练模型的 input上增减通道
Jan 06 #Python
Python 将json序列化后的字符串转换成字典(推荐)
Jan 06 #Python
You might like
php header()函数使用说明
2008/07/10 PHP
详解PHP的Yii框架中的Controller控制器
2016/03/29 PHP
IE6/7/8中Option元素未设value时Select将获取空字符串
2011/04/07 Javascript
js实现文字超出部分用省略号代替实例代码
2016/09/01 Javascript
NodeJS、NPM安装配置步骤(windows版本) 以及环境变量详解
2017/05/13 NodeJs
nodejs之get/post请求的几种方式小结
2017/07/26 NodeJs
推荐VSCode 上特别好用的 Vue 插件之vetur
2017/09/14 Javascript
vue实现组件之间传值功能示例
2018/07/13 Javascript
基于layui内置模块(element常用元素的操作)
2019/09/20 Javascript
基于javascript的无缝滚动动画1
2020/08/07 Javascript
TypeScript 运行时类型检查补充工具
2020/09/28 Javascript
[01:21]辉夜杯战队访谈宣传片—CDEC
2015/12/25 DOTA
微信跳一跳辅助python代码实现
2018/01/05 Python
对Python中gensim库word2vec的使用详解
2018/05/08 Python
详解TensorFlow查看ckpt中变量的几种方法
2018/06/19 Python
对Python subprocess.Popen子进程管道阻塞详解
2018/10/29 Python
对python判断是否回文数的实例详解
2019/02/08 Python
Python3几个常见问题的处理方法
2019/02/26 Python
python实现多进程通信实例分析
2019/09/01 Python
基于python实现从尾到头打印链表
2019/11/02 Python
Python内置函数locals和globals对比
2020/04/28 Python
CSS3中box-shadow的用法介绍
2015/07/15 HTML / CSS
Cinque网上商店:德国服装品牌
2019/03/17 全球购物
Bose加拿大官方网站:美国知名音响品牌
2019/03/21 全球购物
限量版运动鞋和街头服饰:TheDrop
2020/09/06 全球购物
毕业生的自我评价范文
2013/12/31 职场文书
公司门卫管理制度
2014/02/01 职场文书
保密工作实施方案
2014/02/24 职场文书
求职面试个人自我评价
2014/02/28 职场文书
新闻学专业职业生涯规划范文:我的人生我做主
2014/09/12 职场文书
警察正风肃纪剖析材料
2014/10/16 职场文书
锦旗赠语
2015/06/23 职场文书
高温慰问简报
2015/07/21 职场文书
《角的初步认识》教学反思
2016/02/17 职场文书
评估“风险”创业计划的几大要点
2019/08/12 职场文书
python如何读取.mtx文件
2021/04/22 Python