pytorch中的自定义反向传播,求导实例


Posted in Python onJanuary 06, 2020

pytorch中自定义backward()函数。在图像处理过程中,我们有时候会使用自己定义的算法处理图像,这些算法多是基于numpy或者scipy等包。

那么如何将自定义算法的梯度加入到pytorch的计算图中,能使用Loss.backward()操作自动求导并优化呢。下面的代码展示了这个功能`

import torch
import numpy as np
from PIL import Image
from torch.autograd import gradcheck
class Bicubic(torch.autograd.Function):
def basis_function(self, x, a=-1):
  x_abs = np.abs(x)
  if x_abs < 1 and x_abs >= 0:
    y = (a + 2) * np.power(x_abs, 3) - (a + 3) * np.power(x_abs, 2) + 1
  elif x_abs > 1 and x_abs < 2:
    y = a * np.power(x_abs, 3) - 5 * a * np.power(x_abs, 2) + 8 * a * x_abs - 4 * a
  else:
    y = 0
  return y
def bicubic_interpolate(self,data_in, scale=1 / 4, mode='edge'):
  # data_in = data_in.detach().numpy()
  self.grad = np.zeros(data_in.shape,dtype=np.float32)
  obj_shape = (int(data_in.shape[0] * scale), int(data_in.shape[1] * scale), data_in.shape[2])
  data_tmp = data_in.copy()
  data_obj = np.zeros(shape=obj_shape, dtype=np.float32)
  data_in = np.pad(data_in, pad_width=((2, 2), (2, 2), (0, 0)), mode=mode)
  print(data_tmp.shape)
  for axis0 in range(obj_shape[0]):
    f_0 = float(axis0) / scale - np.floor(axis0 / scale)
    int_0 = int(axis0 / scale) + 2
    axis0_weight = np.array(
      [[self.basis_function(1 + f_0), self.basis_function(f_0), self.basis_function(1 - f_0), self.basis_function(2 - f_0)]])
    for axis1 in range(obj_shape[1]):
      f_1 = float(axis1) / scale - np.floor(axis1 / scale)
      int_1 = int(axis1 / scale) + 2
      axis1_weight = np.array(
        [[self.basis_function(1 + f_1), self.basis_function(f_1), self.basis_function(1 - f_1), self.basis_function(2 - f_1)]])
      nbr_pixel = np.zeros(shape=(obj_shape[2], 4, 4), dtype=np.float32)
      grad_point = np.matmul(np.transpose(axis0_weight, (1, 0)), axis1_weight)
      for i in range(4):
        for j in range(4):
          nbr_pixel[:, i, j] = data_in[int_0 + i - 1, int_1 + j - 1, :]
          for ii in range(data_in.shape[2]):
            self.grad[int_0 - 2 + i - 1, int_1 - 2 + j - 1, ii] = grad_point[i,j]
      tmp = np.matmul(axis0_weight, nbr_pixel)
      data_obj[axis0, axis1, :] = np.matmul(tmp, np.transpose(axis1_weight, (1, 0)))[:, 0, 0]
      # img = np.transpose(img[0, :, :, :], [1, 2, 0])
  return data_obj

def forward(self,input):
  print(type(input))
  input_ = input.detach().numpy()
  output = self.bicubic_interpolate(input_)
  # return input.new(output)
  return torch.Tensor(output)

def backward(self,grad_output):
  print(self.grad.shape,grad_output.shape)
  grad_output.detach().numpy()
  grad_output_tmp = np.zeros(self.grad.shape,dtype=np.float32)
  for i in range(self.grad.shape[0]):
    for j in range(self.grad.shape[1]):
      grad_output_tmp[i,j,:] = grad_output[int(i/4),int(j/4),:]
  grad_input = grad_output_tmp*self.grad
  print(type(grad_input))
  # return grad_output.new(grad_input)
  return torch.Tensor(grad_input)

def bicubic(input):
return Bicubic()(input)

def main():
	hr = Image.open('./baboon/baboon_hr.png').convert('L')
	hr = torch.Tensor(np.expand_dims(np.array(hr), axis=2))
	hr.requires_grad = True
	lr = bicubic(hr)
	print(lr.is_leaf)
	loss=torch.mean(lr)
	loss.backward()
if __name__ =='__main__':
	main()

要想实现自动求导,必须同时实现forward(),backward()两个函数。

1、从代码中可以看出来,forward()函数是针对numpy数据操作,返回值再重新指定为torch.Tensor类型。因此就有这个问题出现了:forward输入input被转换为numpy类型,输出转换为tensor类型,那么输出output的grad_fn参数是如何指定的呢。调试发现,当main()中hr的requires_grad被指定为True,即hr被指定为需要求导的叶子节点。只要Bicubic类继承自torch.autograd.Function,那么output也就是代码中的lr的grad_fn就会被指定为<main.Bicubic object at 0x000001DD5A280D68>,即Bicubic这个类。

2、backward()为求导的函数,gard_output是链式求导法则的上一级的梯度,grad_input即为我们想要得到的梯度。只需要在输入指定grad_output,在调用loss.backward()过程中的某一步会执行到Bicubic的backwward()函数

以上这篇pytorch中的自定义反向传播,求导实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python生成指定尺寸缩略图的示例
May 07 Python
详解Python发送邮件实例
Jan 10 Python
Python爬虫包BeautifulSoup异常处理(二)
Jun 17 Python
Python sorted函数详解(高级篇)
Sep 18 Python
Python中的类与类型示例详解
Jul 10 Python
Django错误:TypeError at / 'bool' object is not callable解决
Aug 16 Python
python实现的Iou与Giou代码
Jan 18 Python
Python 实现自动登录+点击+滑动验证功能
Jun 10 Python
用pandas划分数据集实现训练集和测试集
Jul 20 Python
Python读写csv文件流程及异常解决
Oct 20 Python
Python初识逻辑与if语句及用法大全
Aug 07 Python
Python使用openpyxl模块处理Excel文件
Jun 05 Python
PyTorch中 tensor.detach() 和 tensor.data 的区别详解
Jan 06 #Python
6行Python代码实现进度条效果(Progress、tqdm、alive-progress​​​​​​​和PySimpleGUI库)
Jan 06 #Python
基于python+selenium的二次封装的实现
Jan 06 #Python
Python使用Tkinter实现滚动抽奖器效果
Jan 06 #Python
Python使用Tkinter实现转盘抽奖器的步骤详解
Jan 06 #Python
pytorch 实现在预训练模型的 input上增减通道
Jan 06 #Python
Python 将json序列化后的字符串转换成字典(推荐)
Jan 06 #Python
You might like
php miniBB中文乱码问题解决方法
2008/11/25 PHP
php 批量添加多行文本框textarea一行一个
2014/06/03 PHP
php中HTTP_REFERER函数用法实例
2014/11/21 PHP
php 调用ffmpeg获取视频信息的简单实现
2017/04/03 PHP
推荐一些非常不错的javascript学习资源站点
2007/08/29 Javascript
jquery 学习之二 属性 文本与值(text,val)
2010/11/25 Javascript
js 高效去除数组重复元素示例代码
2013/12/19 Javascript
JavaScript+CSS控制打印格式示例介绍
2014/01/07 Javascript
javascript实现省市区三级联动下拉框菜单
2015/11/17 Javascript
使用JavaScript为Kindeditor自定义按钮增加Audio标签
2016/03/18 Javascript
浅析JS中对函数function的理解(基础篇)
2016/10/14 Javascript
Angular.JS判断复选框checkbox是否选中并实时显示
2016/11/30 Javascript
基于JavaScript实现下拉列表左右移动代码
2017/02/07 Javascript
canvas实现环形进度条效果
2017/03/23 Javascript
详解angular element()方法使用
2017/04/08 Javascript
nodejs个人博客开发第五步 分配数据
2017/04/12 NodeJs
JavaScript编写的网页小游戏,很给力
2017/08/18 Javascript
详解Vue微信公众号开发踩坑全记录
2017/08/21 Javascript
零基础之Node.js搭建API服务器的详解
2019/03/08 Javascript
js实现上传图片并显示图片名称
2019/12/18 Javascript
Vue 使用typescript如何优雅的调用swagger API
2020/09/01 Javascript
[01:15:16]DOTA2-DPC中国联赛 正赛 Elephant vs Aster BO3 第一场 1月26日
2021/03/11 DOTA
Python中实现对list做减法操作介绍
2015/01/09 Python
python实现图片文件批量重命名
2020/03/23 Python
python树莓派红外反射传感器
2019/01/21 Python
python 求1-100之间的奇数或者偶数之和的实例
2019/06/11 Python
python Django编写接口并用Jmeter测试的方法
2019/07/31 Python
解决Alexnet训练模型在每个epoch中准确率和loss都会一升一降问题
2020/06/17 Python
详解Django关于StreamingHttpResponse与FileResponse文件下载的最优方法
2021/01/07 Python
日本土著品牌,综合型购物网站:Cecile
2016/08/23 全球购物
C/C++有关内存的思考题
2015/12/04 面试题
党政领导班子群众路线对照检查材料
2014/10/26 职场文书
关于幸福的感言
2015/08/03 职场文书
简短的36句中秋节祝福信息语句
2019/09/09 职场文书
Python编程源码报错解决方法总结经验分享
2021/10/05 Python
table不让td文字溢出操作方法
2022/12/24 HTML / CSS