如何在keras中添加自己的优化器(如adam等)


Posted in Python onJune 19, 2020

本文主要讨论windows下基于tensorflow的keras

1、找到tensorflow的根目录

如果安装时使用anaconda且使用默认安装路径,则在 C:\ProgramData\Anaconda3\envs\tensorflow-gpu\Lib\site-packages\tensorflow处可以找到(此处为GPU版本),cpu版本可在C:\ProgramData\Anaconda3\Lib\site-packages\tensorflow处找到。若并非使用默认安装路径,可参照根目录查看找到。

2、找到keras在tensorflow下的根目录

需要特别注意的是找到keras在tensorflow下的根目录而不是找到keras的根目录。一般来说,完成tensorflow以及keras的配置后即可在tensorflow目录下的python目录中找到keras目录,以GPU为例keras在tensorflow下的根目录为C:\ProgramData\Anaconda3\envs\tensorflow-gpu\Lib\site-packages\tensorflow\python\keras

3、找到keras目录下的optimizers.py文件并添加自己的优化器

找到optimizers.py中的adam等优化器类并在后面添加自己的优化器类

以本文来说,我在第718行添加如下代码

@tf_export('keras.optimizers.adamsss')
class Adamsss(Optimizer):

 def __init__(self,
  lr=0.002,
  beta_1=0.9,
  beta_2=0.999,
  epsilon=None,
  schedule_decay=0.004,
  **kwargs):
 super(Adamsss, self).__init__(**kwargs)
 with K.name_scope(self.__class__.__name__):
 self.iterations = K.variable(0, dtype='int64', name='iterations')
 self.m_schedule = K.variable(1., name='m_schedule')
 self.lr = K.variable(lr, name='lr')
 self.beta_1 = K.variable(beta_1, name='beta_1')
 self.beta_2 = K.variable(beta_2, name='beta_2')
 if epsilon is None:
 epsilon = K.epsilon()
 self.epsilon = epsilon
 self.schedule_decay = schedule_decay

 def get_updates(self, loss, params):
 grads = self.get_gradients(loss, params)
 self.updates = [state_ops.assign_add(self.iterations, 1)]

 t = math_ops.cast(self.iterations, K.floatx()) + 1

 # Due to the recommendations in [2], i.e. warming momentum schedule
 momentum_cache_t = self.beta_1 * (
 1. - 0.5 *
 (math_ops.pow(K.cast_to_floatx(0.96), t * self.schedule_decay)))
 momentum_cache_t_1 = self.beta_1 * (
 1. - 0.5 *
 (math_ops.pow(K.cast_to_floatx(0.96), (t + 1) * self.schedule_decay)))
 m_schedule_new = self.m_schedule * momentum_cache_t
 m_schedule_next = self.m_schedule * momentum_cache_t * momentum_cache_t_1
 self.updates.append((self.m_schedule, m_schedule_new))

 shapes = [K.int_shape(p) for p in params]
 ms = [K.zeros(shape) for shape in shapes]
 vs = [K.zeros(shape) for shape in shapes]

 self.weights = [self.iterations] + ms + vs

 for p, g, m, v in zip(params, grads, ms, vs):
 # the following equations given in [1]
 g_prime = g / (1. - m_schedule_new)
 m_t = self.beta_1 * m + (1. - self.beta_1) * g
 m_t_prime = m_t / (1. - m_schedule_next)
 v_t = self.beta_2 * v + (1. - self.beta_2) * math_ops.square(g)
 v_t_prime = v_t / (1. - math_ops.pow(self.beta_2, t))
 m_t_bar = (
  1. - momentum_cache_t) * g_prime + momentum_cache_t_1 * m_t_prime

 self.updates.append(state_ops.assign(m, m_t))
 self.updates.append(state_ops.assign(v, v_t))

 p_t = p - self.lr * m_t_bar / (K.sqrt(v_t_prime) + self.epsilon)
 new_p = p_t

 # Apply constraints.
 if getattr(p, 'constraint', None) is not None:
 new_p = p.constraint(new_p)

 self.updates.append(state_ops.assign(p, new_p))
 return self.updates

 def get_config(self):
 config = {
 'lr': float(K.get_value(self.lr)),
 'beta_1': float(K.get_value(self.beta_1)),
 'beta_2': float(K.get_value(self.beta_2)),
 'epsilon': self.epsilon,
 'schedule_decay': self.schedule_decay
 }
 base_config = super(Adamsss, self).get_config()
 return dict(list(base_config.items()) + list(config.items()))

然后修改之后的优化器调用类添加我自己的优化器adamss

需要修改的有(下面的两处修改依旧在optimizers.py内)

# Aliases.

sgd = SGD
rmsprop = RMSprop
adagrad = Adagrad
adadelta = Adadelta
adam = Adam
adamsss = Adamsss
adamax = Adamax
nadam = Nadam

以及

def deserialize(config, custom_objects=None):
 """Inverse of the `serialize` function.

 Arguments:
 config: Optimizer configuration dictionary.
 custom_objects: Optional dictionary mapping
  names (strings) to custom objects
  (classes and functions)
  to be considered during deserialization.

 Returns:
 A Keras Optimizer instance.
 """
 if tf2.enabled():
 all_classes = {
 'adadelta': adadelta_v2.Adadelta,
 'adagrad': adagrad_v2.Adagrad,
 'adam': adam_v2.Adam,
		'adamsss': adamsss_v2.Adamsss,
 'adamax': adamax_v2.Adamax,
 'nadam': nadam_v2.Nadam,
 'rmsprop': rmsprop_v2.RMSprop,
 'sgd': gradient_descent_v2.SGD
 }
 else:
 all_classes = {
 'adadelta': Adadelta,
 'adagrad': Adagrad,
 'adam': Adam,
 'adamax': Adamax,
 'nadam': Nadam,
		'adamsss': Adamsss,
 'rmsprop': RMSprop,
 'sgd': SGD,
 'tfoptimizer': TFOptimizer
 }

这里我们并没有v2版本,所以if后面的部分不改也可以。

4、调用我们的优化器对模型进行设置

model.compile(loss = 'crossentropy', optimizer = 'adamss', metrics=['accuracy'])

5、训练模型

train_history = model.fit(x, y_label, validation_split = 0.2, epoch = 10, batch = 128, verbose = 1)

补充知识:keras设置学习率--优化器的用法

优化器的用法

优化器 (optimizer) 是编译 Keras 模型的所需的两个参数之一:

from keras import optimizers
 
model = Sequential()
model.add(Dense(64, kernel_initializer='uniform', input_shape=(10,)))
model.add(Activation('softmax'))
 
sgd = optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='mean_squared_error', optimizer=sgd)

你可以先实例化一个优化器对象,然后将它传入 model.compile(),像上述示例中一样, 或者你可以通过名称来调用优化器。在后一种情况下,将使用优化器的默认参数。

# 传入优化器名称: 默认参数将被采用
model.compile(loss='mean_squared_error', optimizer='sgd')

以上这篇如何在keras中添加自己的优化器(如adam等)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中精确输出JSON浮点数的方法
Apr 18 Python
Python通过Django实现用户注册和邮箱验证功能代码
Dec 11 Python
python skimage 连通性区域检测方法
Jun 21 Python
Python WSGI的深入理解
Aug 01 Python
python实现梯度下降算法
Mar 24 Python
浅谈python已知元素,获取元素索引(numpy,pandas)
Nov 26 Python
在PyCharm中遇到pip安装 失败问题及解决方案(pip失效时的解决方案)
Mar 10 Python
python numpy库np.percentile用法说明
Jun 08 Python
Python使用socket模块实现简单tcp通信
Aug 18 Python
使用Python绘制台风轨迹图的示例代码
Sep 21 Python
Python jieba库分词模式实例用法
Jan 13 Python
python使用PySimpleGUI设置进度条及控件使用
Jun 10 Python
详解pyinstaller生成exe的闪退问题解决方案
Jun 19 #Python
Python实现爬取并分析电商评论
Jun 19 #Python
keras 实现轻量级网络ShuffleNet教程
Jun 19 #Python
Python爬虫实现HTTP网络请求多种实现方式
Jun 19 #Python
Keras设置以及获取权重的实现
Jun 19 #Python
Python包和模块的分发详细介绍
Jun 19 #Python
浅谈Keras中shuffle和validation_split的顺序
Jun 19 #Python
You might like
PHP的FTP学习(四)
2006/10/09 PHP
PHP中的事务使用实例
2015/05/26 PHP
PHP封装返回Ajax字符串和JSON数组的方法
2017/02/17 PHP
解决Laravel 不能创建 migration 的问题
2019/10/09 PHP
PHP上传图片到数据库并显示的实例代码
2019/12/20 PHP
新页面打开实际尺寸的图片
2006/08/25 Javascript
js实现单行文本向上滚动效果实例代码
2013/11/28 Javascript
Jquery原生态实现表格header头随滚动条滚动而滚动
2014/03/18 Javascript
实例分析js和C#中使用正则表达式匹配a标签
2014/11/26 Javascript
深入理解JavaScript系列(33):设计模式之策略模式详解
2015/03/03 Javascript
jquery简单实现图片切换效果的方法
2015/05/12 Javascript
jQuery实现强制cookie过期方法汇总
2015/05/22 Javascript
jquery实现可点击伸缩与展开的菜单效果代码
2015/08/31 Javascript
AngularJS中使用HTML5手机摄像头拍照
2016/02/22 Javascript
js实现砖头在页面拖拉效果
2020/11/20 Javascript
Vue.js实现一个漂亮、灵活、可复用的提示组件示例
2017/03/17 Javascript
解决vue-cli中stylus无法使用的问题方法
2017/06/19 Javascript
小程序自定义组件实现城市选择功能
2018/07/18 Javascript
微信小程序五子棋游戏的悔棋实现方法【附demo源码下载】
2019/02/20 Javascript
vue请求本地自己编写的json文件的方法
2019/04/25 Javascript
微信小程序实现多行文字滚动
2020/11/18 Javascript
手写Vue2.0 数据劫持的示例
2021/03/04 Vue.js
Python制作CSDN免积分下载器
2015/03/10 Python
python中函数总结之装饰器闭包详解
2016/06/12 Python
django实现前后台交互实例
2017/08/07 Python
pyecharts在数据可视化中的应用详解
2020/06/08 Python
Mio Skincare英国官网:身体紧致及孕期身体护理
2018/08/19 全球购物
PHP解析URL是哪个函数?怎么用?
2013/05/09 面试题
医学院学生求职简历的自我评价
2013/10/24 职场文书
学生会干部自荐信
2014/02/04 职场文书
高考寄语大全
2014/04/08 职场文书
优秀少先队工作者事迹材料
2014/05/13 职场文书
会计主管岗位职责
2015/04/02 职场文书
2015年推普周活动方案
2015/05/06 职场文书
Python基础之元类详解
2021/04/29 Python
python可视化之颜色映射详解
2021/09/15 Python