如何在keras中添加自己的优化器(如adam等)


Posted in Python onJune 19, 2020

本文主要讨论windows下基于tensorflow的keras

1、找到tensorflow的根目录

如果安装时使用anaconda且使用默认安装路径,则在 C:\ProgramData\Anaconda3\envs\tensorflow-gpu\Lib\site-packages\tensorflow处可以找到(此处为GPU版本),cpu版本可在C:\ProgramData\Anaconda3\Lib\site-packages\tensorflow处找到。若并非使用默认安装路径,可参照根目录查看找到。

2、找到keras在tensorflow下的根目录

需要特别注意的是找到keras在tensorflow下的根目录而不是找到keras的根目录。一般来说,完成tensorflow以及keras的配置后即可在tensorflow目录下的python目录中找到keras目录,以GPU为例keras在tensorflow下的根目录为C:\ProgramData\Anaconda3\envs\tensorflow-gpu\Lib\site-packages\tensorflow\python\keras

3、找到keras目录下的optimizers.py文件并添加自己的优化器

找到optimizers.py中的adam等优化器类并在后面添加自己的优化器类

以本文来说,我在第718行添加如下代码

@tf_export('keras.optimizers.adamsss')
class Adamsss(Optimizer):

 def __init__(self,
  lr=0.002,
  beta_1=0.9,
  beta_2=0.999,
  epsilon=None,
  schedule_decay=0.004,
  **kwargs):
 super(Adamsss, self).__init__(**kwargs)
 with K.name_scope(self.__class__.__name__):
 self.iterations = K.variable(0, dtype='int64', name='iterations')
 self.m_schedule = K.variable(1., name='m_schedule')
 self.lr = K.variable(lr, name='lr')
 self.beta_1 = K.variable(beta_1, name='beta_1')
 self.beta_2 = K.variable(beta_2, name='beta_2')
 if epsilon is None:
 epsilon = K.epsilon()
 self.epsilon = epsilon
 self.schedule_decay = schedule_decay

 def get_updates(self, loss, params):
 grads = self.get_gradients(loss, params)
 self.updates = [state_ops.assign_add(self.iterations, 1)]

 t = math_ops.cast(self.iterations, K.floatx()) + 1

 # Due to the recommendations in [2], i.e. warming momentum schedule
 momentum_cache_t = self.beta_1 * (
 1. - 0.5 *
 (math_ops.pow(K.cast_to_floatx(0.96), t * self.schedule_decay)))
 momentum_cache_t_1 = self.beta_1 * (
 1. - 0.5 *
 (math_ops.pow(K.cast_to_floatx(0.96), (t + 1) * self.schedule_decay)))
 m_schedule_new = self.m_schedule * momentum_cache_t
 m_schedule_next = self.m_schedule * momentum_cache_t * momentum_cache_t_1
 self.updates.append((self.m_schedule, m_schedule_new))

 shapes = [K.int_shape(p) for p in params]
 ms = [K.zeros(shape) for shape in shapes]
 vs = [K.zeros(shape) for shape in shapes]

 self.weights = [self.iterations] + ms + vs

 for p, g, m, v in zip(params, grads, ms, vs):
 # the following equations given in [1]
 g_prime = g / (1. - m_schedule_new)
 m_t = self.beta_1 * m + (1. - self.beta_1) * g
 m_t_prime = m_t / (1. - m_schedule_next)
 v_t = self.beta_2 * v + (1. - self.beta_2) * math_ops.square(g)
 v_t_prime = v_t / (1. - math_ops.pow(self.beta_2, t))
 m_t_bar = (
  1. - momentum_cache_t) * g_prime + momentum_cache_t_1 * m_t_prime

 self.updates.append(state_ops.assign(m, m_t))
 self.updates.append(state_ops.assign(v, v_t))

 p_t = p - self.lr * m_t_bar / (K.sqrt(v_t_prime) + self.epsilon)
 new_p = p_t

 # Apply constraints.
 if getattr(p, 'constraint', None) is not None:
 new_p = p.constraint(new_p)

 self.updates.append(state_ops.assign(p, new_p))
 return self.updates

 def get_config(self):
 config = {
 'lr': float(K.get_value(self.lr)),
 'beta_1': float(K.get_value(self.beta_1)),
 'beta_2': float(K.get_value(self.beta_2)),
 'epsilon': self.epsilon,
 'schedule_decay': self.schedule_decay
 }
 base_config = super(Adamsss, self).get_config()
 return dict(list(base_config.items()) + list(config.items()))

然后修改之后的优化器调用类添加我自己的优化器adamss

需要修改的有(下面的两处修改依旧在optimizers.py内)

# Aliases.

sgd = SGD
rmsprop = RMSprop
adagrad = Adagrad
adadelta = Adadelta
adam = Adam
adamsss = Adamsss
adamax = Adamax
nadam = Nadam

以及

def deserialize(config, custom_objects=None):
 """Inverse of the `serialize` function.

 Arguments:
 config: Optimizer configuration dictionary.
 custom_objects: Optional dictionary mapping
  names (strings) to custom objects
  (classes and functions)
  to be considered during deserialization.

 Returns:
 A Keras Optimizer instance.
 """
 if tf2.enabled():
 all_classes = {
 'adadelta': adadelta_v2.Adadelta,
 'adagrad': adagrad_v2.Adagrad,
 'adam': adam_v2.Adam,
		'adamsss': adamsss_v2.Adamsss,
 'adamax': adamax_v2.Adamax,
 'nadam': nadam_v2.Nadam,
 'rmsprop': rmsprop_v2.RMSprop,
 'sgd': gradient_descent_v2.SGD
 }
 else:
 all_classes = {
 'adadelta': Adadelta,
 'adagrad': Adagrad,
 'adam': Adam,
 'adamax': Adamax,
 'nadam': Nadam,
		'adamsss': Adamsss,
 'rmsprop': RMSprop,
 'sgd': SGD,
 'tfoptimizer': TFOptimizer
 }

这里我们并没有v2版本,所以if后面的部分不改也可以。

4、调用我们的优化器对模型进行设置

model.compile(loss = 'crossentropy', optimizer = 'adamss', metrics=['accuracy'])

5、训练模型

train_history = model.fit(x, y_label, validation_split = 0.2, epoch = 10, batch = 128, verbose = 1)

补充知识:keras设置学习率--优化器的用法

优化器的用法

优化器 (optimizer) 是编译 Keras 模型的所需的两个参数之一:

from keras import optimizers
 
model = Sequential()
model.add(Dense(64, kernel_initializer='uniform', input_shape=(10,)))
model.add(Activation('softmax'))
 
sgd = optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='mean_squared_error', optimizer=sgd)

你可以先实例化一个优化器对象,然后将它传入 model.compile(),像上述示例中一样, 或者你可以通过名称来调用优化器。在后一种情况下,将使用优化器的默认参数。

# 传入优化器名称: 默认参数将被采用
model.compile(loss='mean_squared_error', optimizer='sgd')

以上这篇如何在keras中添加自己的优化器(如adam等)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python连接mysql并提交mysql事务示例
Mar 05 Python
python提取字典key列表的方法
Jul 11 Python
Django 配置多站点多域名的实现步骤
May 17 Python
使用Python计算玩彩票赢钱概率
Jun 26 Python
python代码编写计算器小程序
Mar 30 Python
flask框架渲染Jinja模板与传入模板变量操作详解
Jan 25 Python
Python实现在Windows平台修改文件属性
Mar 05 Python
django queryset 去重 .distinct()说明
May 19 Python
pymongo insert_many 批量插入的实例
Dec 05 Python
python爬虫利用selenium实现自动翻页爬取某鱼数据的思路详解
Dec 22 Python
pycharm 配置svn的图文教程(手把手教你)
Jan 15 Python
Python实现粒子群算法的示例
Feb 14 Python
详解pyinstaller生成exe的闪退问题解决方案
Jun 19 #Python
Python实现爬取并分析电商评论
Jun 19 #Python
keras 实现轻量级网络ShuffleNet教程
Jun 19 #Python
Python爬虫实现HTTP网络请求多种实现方式
Jun 19 #Python
Keras设置以及获取权重的实现
Jun 19 #Python
Python包和模块的分发详细介绍
Jun 19 #Python
浅谈Keras中shuffle和validation_split的顺序
Jun 19 #Python
You might like
PHP实现分页的一个示例
2006/10/09 PHP
Yii框架在页面输出执行sql语句以方便调试的实现方法
2014/12/24 PHP
php更新mysql后获取改变行数的方法
2014/12/25 PHP
PHP单态模式简单用法示例
2016/11/16 PHP
php结合redis高并发下发帖、发微博的实现方法
2016/12/15 PHP
jquery1.4 教程二 ajax方法的改进
2010/02/25 Javascript
JSChart轻量级图形报表工具(内置函数中文参考)
2010/10/11 Javascript
js 获取后台的字段 改变 checkbox的被选中的状态 代码
2013/06/05 Javascript
使用typeof方法判断undefined类型
2014/09/09 Javascript
javascript中hasOwnProperty() 方法使用指南
2015/03/09 Javascript
jQuery插件EasyUI校验规则 validatebox验证框
2015/11/29 Javascript
MVC+jQuery.Ajax异步实现增删改查和分页
2020/12/22 Javascript
jQuery实现的placeholder效果完整实例
2016/08/02 Javascript
jQuery插件DataTable使用方法详解(.Net平台)
2016/12/22 Javascript
JS对象的深度克隆方法示例
2017/03/16 Javascript
js+css实现红包雨效果
2018/07/12 Javascript
使用react context 实现vue插槽slot功能
2019/07/18 Javascript
原生js+canvas实现贪吃蛇效果
2020/08/02 Javascript
[01:01:52]DOTA2-DPC中国联赛定级赛 SAG vs iG BO3第二场 1月9日
2021/03/11 DOTA
python分割文件的常用方法
2014/11/01 Python
python实现发送和获取手机短信验证码
2016/01/15 Python
Django自定义分页与bootstrap分页结合
2021/02/22 Python
Python下实现的RSA加密/解密及签名/验证功能示例
2017/07/17 Python
Python实现两个list求交集,并集,差集的方法示例
2018/08/02 Python
Django中提供的6种缓存方式详解
2019/08/05 Python
python禁用键鼠与提权代码实例
2019/08/16 Python
pymysql模块的操作实例
2019/12/17 Python
python实现的分析并统计nginx日志数据功能示例
2019/12/21 Python
python按顺序重命名文件并分类转移到各个文件夹中的实现代码
2020/07/21 Python
银行介绍信范文
2014/01/10 职场文书
《彩色世界》教学反思
2014/04/12 职场文书
上课不认真检讨书
2014/09/17 职场文书
党支部先进事迹材料
2014/12/24 职场文书
2015社区精神文明建设工作总结
2015/04/21 职场文书
企业财务管理制度范本
2015/08/04 职场文书
vue中div禁止点击事件的实现
2022/04/02 Vue.js