pytorch自定义二值化网络层方式


Posted in Python onJanuary 07, 2020

任务要求:

自定义一个层主要是定义该层的实现函数,只需要重载Function的forward和backward函数即可,如下:

import torch
from torch.autograd import Function
from torch.autograd import Variable

定义二值化函数

class BinarizedF(Function):
  def forward(self, input):
    self.save_for_backward(input)
    a = torch.ones_like(input)
    b = -torch.ones_like(input)
    output = torch.where(input>=0,a,b)
    return output
  def backward(self, output_grad):
    input, = self.saved_tensors
    input_abs = torch.abs(input)
    ones = torch.ones_like(input)
    zeros = torch.zeros_like(input)
    input_grad = torch.where(input_abs<=1,ones, zeros)
    return input_grad

定义一个module

class BinarizedModule(nn.Module):
  def __init__(self):
    super(BinarizedModule, self).__init__()
    self.BF = BinarizedF()
  def forward(self,input):
    print(input.shape)
    output =self.BF(input)
    return output

进行测试

a = Variable(torch.randn(4,480,640), requires_grad=True)
output = BinarizedModule()(a)
output.backward(torch.ones(a.size()))
print(a)
print(a.grad)

其中, 二值化函数部分也可以按照方式写,但是速度慢了0.05s

class BinarizedF(Function):
  def forward(self, input):
    self.save_for_backward(input)
    output = torch.ones_like(input)
    output[input<0] = -1
    return output
  def backward(self, output_grad):
    input, = self.saved_tensors
    input_grad = output_grad.clone()
    input_abs = torch.abs(input)
    input_grad[input_abs>1] = 0
    return input_grad

以上这篇pytorch自定义二值化网络层方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现电子词典
Apr 23 Python
python实现web方式logview的方法
Aug 10 Python
初学python的操作难点总结(新手必看篇)
Aug 03 Python
Python数据结构与算法之列表(链表,linked list)简单实现
Oct 30 Python
Python线程创建和终止实例代码
Jan 20 Python
Django实战之用户认证(用户登录与注销)
Jul 16 Python
python常用函数与用法示例
Jul 02 Python
python与C、C++混编的四种方式(小结)
Jul 15 Python
TensorFlow自定义损失函数来预测商品销售量
Feb 05 Python
Python如何实现在字符串里嵌入双引号或者单引号
Mar 02 Python
python suds访问webservice服务实现
Jun 26 Python
python中Pexpect的工作流程实例讲解
Mar 02 Python
Pytorch: 自定义网络层实例
Jan 07 #Python
Python StringIO如何在内存中读写str
Jan 07 #Python
Python内置数据类型list各方法的性能测试过程解析
Jan 07 #Python
python模拟实现斗地主发牌
Jan 07 #Python
python全局变量引用与修改过程解析
Jan 07 #Python
python__new__内置静态方法使用解析
Jan 07 #Python
Python常用模块sys,os,time,random功能与用法实例分析
Jan 07 #Python
You might like
php图片验证码代码
2008/03/27 PHP
php中的boolean(布尔)类型详解
2013/10/28 PHP
php中使用getimagesize获取图片、flash等文件的尺寸信息实例
2014/04/29 PHP
php打印一个边长为N的实心和空心菱型的方法
2015/03/02 PHP
Laravel框架定时任务2种实现方式示例
2018/12/08 PHP
YII框架模块化处理操作示例
2019/04/26 PHP
JS 参数传递的实际应用代码分析
2009/09/13 Javascript
JavaScript类和继承 constructor属性
2010/03/04 Javascript
Javascript 判断Flash是否加载完成的代码
2010/04/12 Javascript
JS禁用浏览器退格键实现思路及代码
2013/10/29 Javascript
node.js中的http.createClient方法使用说明
2014/12/15 Javascript
jQuery实现带延迟的二级tab切换下拉列表效果
2015/09/01 Javascript
angularjs创建弹出框实现拖动效果
2020/08/25 Javascript
Js获取图片原始宽高的实现代码
2016/05/17 Javascript
如何解决hover在ie6中的兼容性问题
2016/12/15 Javascript
详解webpack+vue-cli项目打包技巧
2017/06/17 Javascript
JS中将多个逗号替换为一个逗号的实现代码
2017/06/23 Javascript
js指定日期增加指定月份的实现方法
2018/12/19 Javascript
python实现的解析crontab配置文件代码
2014/06/30 Python
Python Web程序部署到Ubuntu服务器上的方法
2018/02/22 Python
python使用scrapy发送post请求的坑
2018/09/04 Python
浅谈python已知元素,获取元素索引(numpy,pandas)
2019/11/26 Python
Python While循环语句实例演示及原理解析
2020/01/03 Python
Pytorch DataLoader 变长数据处理方式
2020/01/08 Python
Python操作MongoDb数据库流程详解
2020/03/05 Python
Django Serializer HiddenField隐藏字段实例
2020/03/31 Python
阿迪达斯意大利在线商店:adidas意大利
2016/09/19 全球购物
威尔逊皮革:Wilsons Leather
2018/12/07 全球购物
应届毕业生的个人自我鉴定
2013/10/24 职场文书
税务专业毕业生自荐信
2013/11/10 职场文书
好家长事迹材料
2014/01/23 职场文书
餐厅楼面部长岗位职责范文
2014/02/16 职场文书
《谁的本领大》教后反思
2014/04/25 职场文书
毕业典礼致辞
2015/07/29 职场文书
 分享一个Python 遇到数据库超好用的模块
2022/04/06 Python
Redis实现分布式锁的五种方法详解
2022/06/14 Redis