pytorch自定义二值化网络层方式


Posted in Python onJanuary 07, 2020

任务要求:

自定义一个层主要是定义该层的实现函数,只需要重载Function的forward和backward函数即可,如下:

import torch
from torch.autograd import Function
from torch.autograd import Variable

定义二值化函数

class BinarizedF(Function):
  def forward(self, input):
    self.save_for_backward(input)
    a = torch.ones_like(input)
    b = -torch.ones_like(input)
    output = torch.where(input>=0,a,b)
    return output
  def backward(self, output_grad):
    input, = self.saved_tensors
    input_abs = torch.abs(input)
    ones = torch.ones_like(input)
    zeros = torch.zeros_like(input)
    input_grad = torch.where(input_abs<=1,ones, zeros)
    return input_grad

定义一个module

class BinarizedModule(nn.Module):
  def __init__(self):
    super(BinarizedModule, self).__init__()
    self.BF = BinarizedF()
  def forward(self,input):
    print(input.shape)
    output =self.BF(input)
    return output

进行测试

a = Variable(torch.randn(4,480,640), requires_grad=True)
output = BinarizedModule()(a)
output.backward(torch.ones(a.size()))
print(a)
print(a.grad)

其中, 二值化函数部分也可以按照方式写,但是速度慢了0.05s

class BinarizedF(Function):
  def forward(self, input):
    self.save_for_backward(input)
    output = torch.ones_like(input)
    output[input<0] = -1
    return output
  def backward(self, output_grad):
    input, = self.saved_tensors
    input_grad = output_grad.clone()
    input_abs = torch.abs(input)
    input_grad[input_abs>1] = 0
    return input_grad

以上这篇pytorch自定义二值化网络层方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python计算书页码的统计数字问题实例
Sep 26 Python
用virtualenv建立多个Python独立虚拟开发环境
Jul 06 Python
Python打印“菱形”星号代码方法
Feb 05 Python
Python 正则表达式匹配字符串中的http链接方法
Dec 25 Python
Python基于opencv调用摄像头获取个人图片的实现方法
Feb 21 Python
Python 实现自动导入缺失的库
Oct 29 Python
Python基于pandas爬取网页表格数据
May 11 Python
Python 如何创建一个线程池
Jul 28 Python
Idea安装python显示无SDK问题解决方案
Aug 12 Python
Django如何实现密码错误报错提醒
Sep 04 Python
python 多线程共享全局变量的优劣
Sep 24 Python
Python扫描端口的实现
Jan 25 Python
Pytorch: 自定义网络层实例
Jan 07 #Python
Python StringIO如何在内存中读写str
Jan 07 #Python
Python内置数据类型list各方法的性能测试过程解析
Jan 07 #Python
python模拟实现斗地主发牌
Jan 07 #Python
python全局变量引用与修改过程解析
Jan 07 #Python
python__new__内置静态方法使用解析
Jan 07 #Python
Python常用模块sys,os,time,random功能与用法实例分析
Jan 07 #Python
You might like
PHP删除非空目录的函数代码小结
2013/02/28 PHP
探讨:web上存漏洞及原理分析、防范方法
2013/06/29 PHP
php过滤所有的空白字符(空格、全角空格、换行等)
2015/10/27 PHP
CentOS下搭建PHP环境与WordPress博客程序的全流程总结
2016/05/07 PHP
php实现异步将远程链接上内容(图片或内容)写到本地的方法
2016/11/30 PHP
Yii2.0框架模型添加/修改/删除数据操作示例
2019/07/18 PHP
Laravel 集成微信用户登录和绑定的实现
2019/12/27 PHP
可以显示单图片,多图片ajax请求的ThickBox3.1类下载
2007/12/23 Javascript
JS 文件大小判断的实现代码
2010/04/07 Javascript
浅析jQuery中常用的元素查找方法总结
2013/07/04 Javascript
javascript实例分享---具有立体效果的图片特效
2014/06/08 Javascript
JQuery控制radio选中和不选中方法总结
2015/04/15 Javascript
javascript事件绑定学习要点
2016/03/09 Javascript
jQuery实现的页面遮罩层功能示例【测试可用】
2017/10/14 jQuery
JS实现登录页密码的显示和隐藏功能
2017/12/06 Javascript
详解vue-router 命名路由和命名视图
2018/06/01 Javascript
详解React之父子组件传递和其它一些要点
2018/06/25 Javascript
vue自定义指令之面板拖拽的实现
2019/04/14 Javascript
layui-select动态选中值的例子
2019/09/23 Javascript
layui table 复选框跳页后再回来保持原来选中的状态示例
2019/10/26 Javascript
vue-cli中实现响应式布局的方法
2021/03/02 Vue.js
跟老齐学Python之赋值,简单也不简单
2014/09/24 Python
Python最长公共子串算法实例
2015/03/07 Python
python实现在windows下操作word的方法
2015/04/28 Python
78行Python代码实现现微信撤回消息功能
2018/07/26 Python
Python实现中英文全文搜索的示例
2020/12/04 Python
贝嫂喜欢的婴儿品牌,个性化的婴儿礼物:My 1st Years
2017/11/19 全球购物
益模软件Java笔试题
2012/03/27 面试题
医药代表个人的求职信分享
2013/12/08 职场文书
学校运动会开幕演讲稿
2014/01/04 职场文书
大学活动总结模板
2014/07/10 职场文书
学籍证明模板
2014/11/21 职场文书
2014年销售工作总结范文
2014/12/01 职场文书
给男朋友的道歉短信
2015/05/12 职场文书
小学英语教学随笔
2015/08/14 职场文书
导游词之云南-元阳梯田
2019/10/08 职场文书