tensorflow自定义激活函数实例


Posted in Python onFebruary 04, 2020

前言:因为研究工作的需要,要更改激活函数以适应自己的网络模型,但是单纯的函数替换会训练导致不能收敛。这里还有些不清楚为什么,希望有人可以给出解释。查了一些博客,发现了解决之道。下面将解决过程贴出来供大家指正。

1.背景

之前听某位老师提到说tensorflow可以在不给梯度函数的基础上做梯度下降,所以尝试了替换。我的例子时将ReLU改为平方。即原来的激活函数是 tensorflow自定义激活函数实例 现在换成 tensorflow自定义激活函数实例

单纯替换激活函数并不能较好的效果,在我的实验中,迭代到一定批次,准确率就会下降,最终降为10%左右保持稳定。而事实上,这中间最好的训练精度为92%。资源有限,问了对神经网络颇有研究的同学,说是激活函数的问题,然而某篇很厉害的论文中提到其精度在99%,着实有意思。之后开始研究自己些梯度函数以完成训练。

2.大概流程

首先要确定梯度函数,之后将其处理为tf能接受的类型。

2.1定义自己的激活函数

def square(x):
 return pow(x, 2)

2.2 定义该激活函数的一次梯度函数

def square_grad(x):
 return 2 * x

2.3 让numpy数组每一个元素都能应用该函数(全局)

square_np = np.vectorize(square)
square_grad_np = np.vectorize(square_grad)

2.4 转为tf可用的32位float型,numpy默认是64位(全局)

square_np_32 = lambda x: square_np(x).astype(np.float32)
square_grad_np_32 = lambda x: square_grad_np(x).astype(np.float32)

2.5 定义tf版的梯度函数

def square_grad_tf(x, name=None):
 with ops.name_scope(name, "square_grad_tf", [x]) as name:
 y = tf.py_func(square_grad_np_32, [x], [tf.float32], name=name, stateful=False)
 return y[0]

2.6 定义函数

def my_py_func(func, inp, Tout, stateful=False, name=None, my_grad_func=None):
 # need to generate a unique name to avoid duplicates:
 random_name = "PyFuncGrad" + str(np.random.randint(0, 1E+8))
 tf.RegisterGradient(random_name)(my_grad_func)
 g = tf.get_default_graph()
 with g.gradient_override_map({"PyFunc": random_name, "PyFuncStateless": random_name}):
 return tf.py_func(func, inp, Tout, stateful=stateful, name=name)

2.7 定义梯度,该函数依靠上一个函数my_py_func计算并传播

def _square_grad(op, pred_grad):
 x = op.inputs[0]
 cur_grad = square_grad(x)
 next_grad = pred_grad * cur_grad
 return next_grad

2.8 定义tf版的square函数

def square_tf(x, name=None):
 with ops.name_scope(name, "square_tf", [x]) as name:
 y = my_py_func(square_np_32,
   [x],
   [tf.float32],
   stateful=False,
   name=name,
   my_grad_func=_square_grad)
 return y[0]

3.使用

跟用其他激活函数一样,直接用就行了。input_data:输入数据。

h = square_tf(input_data)

over. 学艺不精,多多指教!

以上这篇tensorflow自定义激活函数实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python图像处理之反色实现方法
May 30 Python
Python程序中设置HTTP代理
Nov 06 Python
Python函数式编程
Jul 20 Python
matplotlib savefig 保存图片大小的实例
May 24 Python
python创建文件时去掉非法字符的方法
Oct 31 Python
使用python将图片按标签分入不同文件夹的方法
Dec 08 Python
pytorch 图像中的数据预处理和批标准化实例
Jan 15 Python
python opencv圆、椭圆与任意多边形的绘制实例详解
Feb 06 Python
总结python 三种常见的内存泄漏场景
Nov 20 Python
Python将QQ聊天记录生成词云的示例代码
Feb 10 Python
Matplotlib animation模块实现动态图
Feb 25 Python
python - asyncio异步编程
Apr 06 Python
pytorch对梯度进行可视化进行梯度检查教程
Feb 04 #Python
pytorch梯度剪裁方式
Feb 04 #Python
基于梯度爆炸的解决方法:clip gradient
Feb 04 #Python
Python 格式化输出_String Formatting_控制小数点位数的实例详解
Feb 04 #Python
python求一个字符串的所有排列的实现方法
Feb 04 #Python
Windows上安装tensorflow  详细教程(图文详解)
Feb 04 #Python
有关Tensorflow梯度下降常用的优化方法分享
Feb 04 #Python
You might like
PHP中把stdClass Object转array的几个方法
2014/05/08 PHP
php检查字符串中是否包含7位GSM字符的方法
2015/03/17 PHP
PHP判断上传文件类型的解决办法
2015/10/20 PHP
php实现图片上传、剪切功能
2016/05/07 PHP
PHP面试常用算法(推荐)
2016/07/22 PHP
PHP模版引擎原理、定义与用法实例
2019/03/29 PHP
php如何获取Http请求
2020/04/30 PHP
php swoft框架实例用法
2020/12/22 PHP
JavaScript脚本语言在网页中的简单应用
2007/05/13 Javascript
chrome原生方法之数组
2011/11/30 Javascript
浏览器打开层自动缓慢展开收缩实例代码
2013/07/04 Javascript
button没写type=button会导致点击时提交
2014/03/06 Javascript
JS实现简单的顶部定时关闭层效果
2014/06/15 Javascript
JQuery之proxy实现绑定代理方法
2016/08/01 Javascript
JavaScript实现页面定时刷新(定时器,meta)
2016/10/12 Javascript
vue.js利用Object.defineProperty实现双向绑定
2017/03/09 Javascript
JS回调函数基本定义与用法实例分析
2017/05/24 Javascript
微信小程序scroll-view组件实现滚动动画
2018/01/31 Javascript
ES6基础之展开语法(Spread syntax)
2019/02/21 Javascript
Vue开发环境中修改端口号的实现方法
2019/08/15 Javascript
VUE 组件转换为微信小程序组件的方法
2019/11/06 Javascript
vue+element-ui JYAdmin后台管理系统模板解析
2020/07/28 Javascript
原生JavaScript实现拖动校验功能
2020/09/29 Javascript
老生常谈Python进阶之装饰器
2017/05/11 Python
Python中定时任务框架APScheduler的快速入门指南
2017/07/06 Python
详解python执行shell脚本创建用户及相关操作
2019/04/11 Python
python3 BeautifulSoup模块使用字典的方法抓取a标签内的数据示例
2019/11/28 Python
Python使用matplotlib绘制Logistic曲线操作示例
2019/11/28 Python
Pytorch高阶OP操作where,gather原理
2020/04/30 Python
利用python3筛选excel中特定的行(行值满足某个条件/行值属于某个集合)
2020/09/04 Python
html5指南-7.geolocation结合google maps开发一个小的应用
2013/01/07 HTML / CSS
教育课题研究自我鉴定范文
2013/12/28 职场文书
电子商务专业求职信范文
2015/03/19 职场文书
使用redis生成唯一编号及原理示例详解
2021/09/15 Redis
一文弄懂MySQL索引创建原则
2022/02/28 MySQL
MySQL限制查询和数据排序介绍
2022/03/25 MySQL