如何在keras中添加自己的优化器(如adam等)


Posted in Python onJune 19, 2020

本文主要讨论windows下基于tensorflow的keras

1、找到tensorflow的根目录

如果安装时使用anaconda且使用默认安装路径,则在 C:\ProgramData\Anaconda3\envs\tensorflow-gpu\Lib\site-packages\tensorflow处可以找到(此处为GPU版本),cpu版本可在C:\ProgramData\Anaconda3\Lib\site-packages\tensorflow处找到。若并非使用默认安装路径,可参照根目录查看找到。

2、找到keras在tensorflow下的根目录

需要特别注意的是找到keras在tensorflow下的根目录而不是找到keras的根目录。一般来说,完成tensorflow以及keras的配置后即可在tensorflow目录下的python目录中找到keras目录,以GPU为例keras在tensorflow下的根目录为C:\ProgramData\Anaconda3\envs\tensorflow-gpu\Lib\site-packages\tensorflow\python\keras

3、找到keras目录下的optimizers.py文件并添加自己的优化器

找到optimizers.py中的adam等优化器类并在后面添加自己的优化器类

以本文来说,我在第718行添加如下代码

@tf_export('keras.optimizers.adamsss')
class Adamsss(Optimizer):

 def __init__(self,
  lr=0.002,
  beta_1=0.9,
  beta_2=0.999,
  epsilon=None,
  schedule_decay=0.004,
  **kwargs):
 super(Adamsss, self).__init__(**kwargs)
 with K.name_scope(self.__class__.__name__):
 self.iterations = K.variable(0, dtype='int64', name='iterations')
 self.m_schedule = K.variable(1., name='m_schedule')
 self.lr = K.variable(lr, name='lr')
 self.beta_1 = K.variable(beta_1, name='beta_1')
 self.beta_2 = K.variable(beta_2, name='beta_2')
 if epsilon is None:
 epsilon = K.epsilon()
 self.epsilon = epsilon
 self.schedule_decay = schedule_decay

 def get_updates(self, loss, params):
 grads = self.get_gradients(loss, params)
 self.updates = [state_ops.assign_add(self.iterations, 1)]

 t = math_ops.cast(self.iterations, K.floatx()) + 1

 # Due to the recommendations in [2], i.e. warming momentum schedule
 momentum_cache_t = self.beta_1 * (
 1. - 0.5 *
 (math_ops.pow(K.cast_to_floatx(0.96), t * self.schedule_decay)))
 momentum_cache_t_1 = self.beta_1 * (
 1. - 0.5 *
 (math_ops.pow(K.cast_to_floatx(0.96), (t + 1) * self.schedule_decay)))
 m_schedule_new = self.m_schedule * momentum_cache_t
 m_schedule_next = self.m_schedule * momentum_cache_t * momentum_cache_t_1
 self.updates.append((self.m_schedule, m_schedule_new))

 shapes = [K.int_shape(p) for p in params]
 ms = [K.zeros(shape) for shape in shapes]
 vs = [K.zeros(shape) for shape in shapes]

 self.weights = [self.iterations] + ms + vs

 for p, g, m, v in zip(params, grads, ms, vs):
 # the following equations given in [1]
 g_prime = g / (1. - m_schedule_new)
 m_t = self.beta_1 * m + (1. - self.beta_1) * g
 m_t_prime = m_t / (1. - m_schedule_next)
 v_t = self.beta_2 * v + (1. - self.beta_2) * math_ops.square(g)
 v_t_prime = v_t / (1. - math_ops.pow(self.beta_2, t))
 m_t_bar = (
  1. - momentum_cache_t) * g_prime + momentum_cache_t_1 * m_t_prime

 self.updates.append(state_ops.assign(m, m_t))
 self.updates.append(state_ops.assign(v, v_t))

 p_t = p - self.lr * m_t_bar / (K.sqrt(v_t_prime) + self.epsilon)
 new_p = p_t

 # Apply constraints.
 if getattr(p, 'constraint', None) is not None:
 new_p = p.constraint(new_p)

 self.updates.append(state_ops.assign(p, new_p))
 return self.updates

 def get_config(self):
 config = {
 'lr': float(K.get_value(self.lr)),
 'beta_1': float(K.get_value(self.beta_1)),
 'beta_2': float(K.get_value(self.beta_2)),
 'epsilon': self.epsilon,
 'schedule_decay': self.schedule_decay
 }
 base_config = super(Adamsss, self).get_config()
 return dict(list(base_config.items()) + list(config.items()))

然后修改之后的优化器调用类添加我自己的优化器adamss

需要修改的有(下面的两处修改依旧在optimizers.py内)

# Aliases.

sgd = SGD
rmsprop = RMSprop
adagrad = Adagrad
adadelta = Adadelta
adam = Adam
adamsss = Adamsss
adamax = Adamax
nadam = Nadam

以及

def deserialize(config, custom_objects=None):
 """Inverse of the `serialize` function.

 Arguments:
 config: Optimizer configuration dictionary.
 custom_objects: Optional dictionary mapping
  names (strings) to custom objects
  (classes and functions)
  to be considered during deserialization.

 Returns:
 A Keras Optimizer instance.
 """
 if tf2.enabled():
 all_classes = {
 'adadelta': adadelta_v2.Adadelta,
 'adagrad': adagrad_v2.Adagrad,
 'adam': adam_v2.Adam,
		'adamsss': adamsss_v2.Adamsss,
 'adamax': adamax_v2.Adamax,
 'nadam': nadam_v2.Nadam,
 'rmsprop': rmsprop_v2.RMSprop,
 'sgd': gradient_descent_v2.SGD
 }
 else:
 all_classes = {
 'adadelta': Adadelta,
 'adagrad': Adagrad,
 'adam': Adam,
 'adamax': Adamax,
 'nadam': Nadam,
		'adamsss': Adamsss,
 'rmsprop': RMSprop,
 'sgd': SGD,
 'tfoptimizer': TFOptimizer
 }

这里我们并没有v2版本,所以if后面的部分不改也可以。

4、调用我们的优化器对模型进行设置

model.compile(loss = 'crossentropy', optimizer = 'adamss', metrics=['accuracy'])

5、训练模型

train_history = model.fit(x, y_label, validation_split = 0.2, epoch = 10, batch = 128, verbose = 1)

补充知识:keras设置学习率--优化器的用法

优化器的用法

优化器 (optimizer) 是编译 Keras 模型的所需的两个参数之一:

from keras import optimizers
 
model = Sequential()
model.add(Dense(64, kernel_initializer='uniform', input_shape=(10,)))
model.add(Activation('softmax'))
 
sgd = optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='mean_squared_error', optimizer=sgd)

你可以先实例化一个优化器对象,然后将它传入 model.compile(),像上述示例中一样, 或者你可以通过名称来调用优化器。在后一种情况下,将使用优化器的默认参数。

# 传入优化器名称: 默认参数将被采用
model.compile(loss='mean_squared_error', optimizer='sgd')

以上这篇如何在keras中添加自己的优化器(如adam等)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现多线程暴力破解登陆路由器功能代码分享
Jan 04 Python
Python功能键的读取方法
May 28 Python
详解Python读取配置文件模块ConfigParser
May 11 Python
Python实现字符串反转的常用方法分析【4种方法】
Sep 30 Python
使用python实现ANN
Dec 20 Python
python正则表达式面试题解答
Apr 28 Python
基于python生成器封装的协程类
Mar 20 Python
Python 使用 attrs 和 cattrs 实现面向对象编程的实践
Jun 12 Python
python实现爬取百度图片的方法示例
Jul 06 Python
ML神器:sklearn的快速使用及入门
Jul 11 Python
python入门:argparse浅析 nargs='+'作用
Jul 12 Python
python实现三次密码验证的示例
Apr 29 Python
详解pyinstaller生成exe的闪退问题解决方案
Jun 19 #Python
Python实现爬取并分析电商评论
Jun 19 #Python
keras 实现轻量级网络ShuffleNet教程
Jun 19 #Python
Python爬虫实现HTTP网络请求多种实现方式
Jun 19 #Python
Keras设置以及获取权重的实现
Jun 19 #Python
Python包和模块的分发详细介绍
Jun 19 #Python
浅谈Keras中shuffle和validation_split的顺序
Jun 19 #Python
You might like
AM/FM收音机的安装与调试
2021/03/02 无线电
轻松入门: 煮好咖啡的七个诀窍
2021/03/03 冲泡冲煮
自己写的兼容低于PHP 5.5版本的array_column()函数
2014/10/24 PHP
php正则表达式验证(邮件地址、Url地址、电话号码、邮政编码)
2016/03/14 PHP
PHP自定义错误处理的方法分析
2018/12/19 PHP
php解决约瑟夫环算法实例分析
2019/09/30 PHP
PHP的Trait机制原理与用法分析
2019/10/18 PHP
Laravel5.1 框架Request请求操作常见用法实例分析
2020/01/04 PHP
JavaScript脚本性能的优化方法
2007/02/02 Javascript
Js(JavaScript)中,弹出是或否的选择框示例(confirm用法的实例分析)
2013/07/09 Javascript
javascript三元运算符用法实例
2015/04/16 Javascript
浅谈javascript alert和confirm的美化
2016/12/15 Javascript
浅析webpack 如何优雅的使用tree-shaking(摇树优化)
2017/08/16 Javascript
浅谈React的最大亮点之虚拟DOM
2018/05/29 Javascript
快速解决bootstrap下拉菜单无法隐藏的问题
2018/08/10 Javascript
vue 之 css module的使用方法
2018/12/04 Javascript
详解基于Vue的支持数据双向绑定的select组件
2019/09/02 Javascript
基于vue-cli3创建libs库的实现方法
2019/12/04 Javascript
vue使用map代替Aarry数组循环遍历的方法
2020/04/30 Javascript
[00:20]DOTA2荣耀之路7:-ah fu-抢盾
2018/05/31 DOTA
深入剖析Python的爬虫框架Scrapy的结构与运作流程
2016/01/20 Python
Python设计模式之观察者模式原理与用法详解
2019/01/16 Python
TensorFlow2.0矩阵与向量的加减乘实例
2020/02/07 Python
python图片剪裁代码(图片按四个点坐标剪裁)
2020/03/10 Python
印度手工编织服装和家居用品商店:Fabindi
2019/10/07 全球购物
应届生保险求职信
2013/11/11 职场文书
《逃家小兔》教学反思
2014/02/23 职场文书
行政监察建议书
2014/05/19 职场文书
意外伤害赔偿协议书
2014/09/16 职场文书
青年教师个人总结
2015/02/11 职场文书
2015年档案室工作总结
2015/05/23 职场文书
企业反腐倡廉心得体会
2015/08/15 职场文书
2016年秋季运动会通讯稿
2015/11/25 职场文书
2016保送生自荐信范文
2016/01/29 职场文书
德生BCL3000抢先使用感受和评价
2022/04/07 无线电
德生2P3收音机开箱评测
2022/04/30 无线电