pytorch中的自定义反向传播,求导实例


Posted in Python onJanuary 06, 2020

pytorch中自定义backward()函数。在图像处理过程中,我们有时候会使用自己定义的算法处理图像,这些算法多是基于numpy或者scipy等包。

那么如何将自定义算法的梯度加入到pytorch的计算图中,能使用Loss.backward()操作自动求导并优化呢。下面的代码展示了这个功能`

import torch
import numpy as np
from PIL import Image
from torch.autograd import gradcheck
class Bicubic(torch.autograd.Function):
def basis_function(self, x, a=-1):
  x_abs = np.abs(x)
  if x_abs < 1 and x_abs >= 0:
    y = (a + 2) * np.power(x_abs, 3) - (a + 3) * np.power(x_abs, 2) + 1
  elif x_abs > 1 and x_abs < 2:
    y = a * np.power(x_abs, 3) - 5 * a * np.power(x_abs, 2) + 8 * a * x_abs - 4 * a
  else:
    y = 0
  return y
def bicubic_interpolate(self,data_in, scale=1 / 4, mode='edge'):
  # data_in = data_in.detach().numpy()
  self.grad = np.zeros(data_in.shape,dtype=np.float32)
  obj_shape = (int(data_in.shape[0] * scale), int(data_in.shape[1] * scale), data_in.shape[2])
  data_tmp = data_in.copy()
  data_obj = np.zeros(shape=obj_shape, dtype=np.float32)
  data_in = np.pad(data_in, pad_width=((2, 2), (2, 2), (0, 0)), mode=mode)
  print(data_tmp.shape)
  for axis0 in range(obj_shape[0]):
    f_0 = float(axis0) / scale - np.floor(axis0 / scale)
    int_0 = int(axis0 / scale) + 2
    axis0_weight = np.array(
      [[self.basis_function(1 + f_0), self.basis_function(f_0), self.basis_function(1 - f_0), self.basis_function(2 - f_0)]])
    for axis1 in range(obj_shape[1]):
      f_1 = float(axis1) / scale - np.floor(axis1 / scale)
      int_1 = int(axis1 / scale) + 2
      axis1_weight = np.array(
        [[self.basis_function(1 + f_1), self.basis_function(f_1), self.basis_function(1 - f_1), self.basis_function(2 - f_1)]])
      nbr_pixel = np.zeros(shape=(obj_shape[2], 4, 4), dtype=np.float32)
      grad_point = np.matmul(np.transpose(axis0_weight, (1, 0)), axis1_weight)
      for i in range(4):
        for j in range(4):
          nbr_pixel[:, i, j] = data_in[int_0 + i - 1, int_1 + j - 1, :]
          for ii in range(data_in.shape[2]):
            self.grad[int_0 - 2 + i - 1, int_1 - 2 + j - 1, ii] = grad_point[i,j]
      tmp = np.matmul(axis0_weight, nbr_pixel)
      data_obj[axis0, axis1, :] = np.matmul(tmp, np.transpose(axis1_weight, (1, 0)))[:, 0, 0]
      # img = np.transpose(img[0, :, :, :], [1, 2, 0])
  return data_obj

def forward(self,input):
  print(type(input))
  input_ = input.detach().numpy()
  output = self.bicubic_interpolate(input_)
  # return input.new(output)
  return torch.Tensor(output)

def backward(self,grad_output):
  print(self.grad.shape,grad_output.shape)
  grad_output.detach().numpy()
  grad_output_tmp = np.zeros(self.grad.shape,dtype=np.float32)
  for i in range(self.grad.shape[0]):
    for j in range(self.grad.shape[1]):
      grad_output_tmp[i,j,:] = grad_output[int(i/4),int(j/4),:]
  grad_input = grad_output_tmp*self.grad
  print(type(grad_input))
  # return grad_output.new(grad_input)
  return torch.Tensor(grad_input)

def bicubic(input):
return Bicubic()(input)

def main():
	hr = Image.open('./baboon/baboon_hr.png').convert('L')
	hr = torch.Tensor(np.expand_dims(np.array(hr), axis=2))
	hr.requires_grad = True
	lr = bicubic(hr)
	print(lr.is_leaf)
	loss=torch.mean(lr)
	loss.backward()
if __name__ =='__main__':
	main()

要想实现自动求导,必须同时实现forward(),backward()两个函数。

1、从代码中可以看出来,forward()函数是针对numpy数据操作,返回值再重新指定为torch.Tensor类型。因此就有这个问题出现了:forward输入input被转换为numpy类型,输出转换为tensor类型,那么输出output的grad_fn参数是如何指定的呢。调试发现,当main()中hr的requires_grad被指定为True,即hr被指定为需要求导的叶子节点。只要Bicubic类继承自torch.autograd.Function,那么output也就是代码中的lr的grad_fn就会被指定为<main.Bicubic object at 0x000001DD5A280D68>,即Bicubic这个类。

2、backward()为求导的函数,gard_output是链式求导法则的上一级的梯度,grad_input即为我们想要得到的梯度。只需要在输入指定grad_output,在调用loss.backward()过程中的某一步会执行到Bicubic的backwward()函数

以上这篇pytorch中的自定义反向传播,求导实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python BeautifulSoup使用方法详解
Nov 21 Python
Python3实现购物车功能
Apr 18 Python
利用Python写一个爬妹子的爬虫
Jun 08 Python
Python中实现单例模式的n种方式和原理
Nov 14 Python
通过python爬虫赚钱的方法
Jan 29 Python
Django实现学员管理系统
Feb 26 Python
python实现图片中文字分割效果
Jul 22 Python
Python3.7 读取 mp3 音频文件生成波形图效果
Nov 05 Python
使用Python的networkx绘制精美网络图教程
Nov 21 Python
使用Matplotlib 绘制精美的数学图形例子
Dec 13 Python
pycharm 关掉syntax检查操作
Jun 09 Python
Django多个app urls配置代码实例
Nov 26 Python
PyTorch中 tensor.detach() 和 tensor.data 的区别详解
Jan 06 #Python
6行Python代码实现进度条效果(Progress、tqdm、alive-progress​​​​​​​和PySimpleGUI库)
Jan 06 #Python
基于python+selenium的二次封装的实现
Jan 06 #Python
Python使用Tkinter实现滚动抽奖器效果
Jan 06 #Python
Python使用Tkinter实现转盘抽奖器的步骤详解
Jan 06 #Python
pytorch 实现在预训练模型的 input上增减通道
Jan 06 #Python
Python 将json序列化后的字符串转换成字典(推荐)
Jan 06 #Python
You might like
php 大数据量及海量数据处理算法总结
2011/05/07 PHP
php中计算未知长度的字符串哪个字符出现的次数最多的代码
2012/08/14 PHP
PHP中file_exists与is_file,is_dir的区别介绍
2012/09/12 PHP
php float不四舍五入截取浮点型字符串方法总结
2013/10/28 PHP
PHP比你想象的好得多
2014/11/27 PHP
Prototype中dom对象方法汇总
2008/09/17 Javascript
javascript脚本编程解决考试分数统计问题
2008/10/18 Javascript
JavaScript CSS修改学习第五章 给“上传”添加样式
2010/02/19 Javascript
jQuery LigerUI 使用教程入门篇
2012/01/18 Javascript
Js实现网页键盘控制翻页的方法
2014/10/30 Javascript
小心!AngularJS结合RequireJS做文件合并压缩的那些坑
2016/01/09 Javascript
详解jquery事件delegate()的使用方法
2016/01/25 Javascript
JavaScript中的return布尔值的用法和原理解析
2017/08/14 Javascript
jQuery Position方法使用和兼容性
2017/08/23 jQuery
浅谈js中的this问题
2017/08/31 Javascript
详解用webpack的CommonsChunkPlugin提取公共代码的3种方式
2017/11/09 Javascript
基于JavaScript实现表格滚动分页
2017/11/22 Javascript
Bootstrap实现可折叠分组侧边导航菜单
2018/03/07 Javascript
JS简单获取并修改input文本框内容的方法示例
2018/04/08 Javascript
JS实现的透明度渐变动画效果示例
2018/04/28 Javascript
[02:33]DOTA2英雄基础教程 司夜刺客
2013/12/04 DOTA
Python实现删除文件但保留指定文件
2015/06/21 Python
python定向爬取淘宝商品价格
2018/02/27 Python
python 美化输出信息的实例
2018/10/15 Python
python安装gdal的两种方法
2019/10/29 Python
pyecharts调整图例与各板块的位置间距实例
2020/05/16 Python
对python pandas中 inplace 参数的理解
2020/06/27 Python
Python用access判断文件是否被占用的实例方法
2020/12/17 Python
Linux文件系统类型
2012/02/15 面试题
高校教师岗位职责
2014/03/18 职场文书
2014年教研活动总结范文
2014/04/26 职场文书
转让协议书
2015/01/27 职场文书
超级详细实用的pycharm常用快捷键
2021/05/12 Python
新手必备之MySQL msi版本下载安装图文详细教程
2021/05/21 MySQL
Win10玩csgo闪退如何解决?Win10玩csgo闪退的解决方法
2022/07/23 数码科技
clear 万能清除浮动(clearfix:after)
2023/05/21 HTML / CSS