TensorFlow实现指数衰减学习率的方法


Posted in Python onFebruary 05, 2020

在TensorFlow中,tf.train.exponential_decay函数实现了指数衰减学习率,通过这个函数,可以先使用较大的学习率来快速得到一个比较优的解,然后随着迭代的继续逐步减小学习率,使得模型在训练后期更加稳定。

TensorFlow实现指数衰减学习率的方法

tf.train.exponential_decay(learning_rate, global_step, decay_steps, decay_rate, staircase, name)函数会指数级地减小学习率,它实现了以下代码的功能:

#tf.train.exponential_decay函数可以通过设置staircase参数选择不同的学习率衰减方式

#staircase参数为False(默认)时,选择连续衰减学习率:
decayed_learning_rate = learning_rate * math.pow(decay_rate, global_step / decay_steps)

#staircase参数为True时,选择阶梯状衰减学习率:
decayed_learning_rate = learning_rate * math.pow(decay_rate, global_step // decay_steps)

①decayed_leaming_rate为每一轮优化时使用的学习率;

②leaming_rate为事先设定的初始学习率;

③decay_rate为衰减系数;

④global_step为当前训练的轮数;

⑤decay_steps为衰减速度,通常代表了完整的使用一遍训练数据所需要的迭代轮数,这个迭代轮数也就是总训练样本数除以每一个batch中的训练样本数,比如训练数据集的大小为128,每一个batch中样例的个数为8,那么decay_steps就为16。

当staircase参数设置为True,使用阶梯状衰减学习率时,代码的含义是每完整地过完一遍训练数据即每训练decay_steps轮,学习率就减小一次,这可以使得训练数据集中的所有数据对模型训练有相等的作用;当staircase参数设置为False,使用连续的衰减学习率时,不同的训练数据有不同的学习率,而当学习率减小时,对应的训练数据对模型训练结果的影响也就小了。

接下来看一看tf.train.exponential_decay函数应用的两种形态(省略部分代码):

①第一种形态,global_step作为变量被优化,在这种形态下,global_step是变量,在minimize函数中传入global_step将自动更新global_step参数(global_step每轮迭代自动加一),从而使得学习率也得到相应更新:

import tensorflow as tf
 .
 .
 .
#设置学习率
global_step = tf.Variable(tf.constant(0))
learning_rate = tf.train.exponential_decay(0.01, global_step, 16, 0.96, staircase=True)
#定义反向传播算法的优化方法
train_step = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy, global_step=global_step)
 .
 .
 .
#创建会话
with tf.Session() as sess:
 .
 .
 .
 for i in range(STEPS):
 .
 .
 .
  #通过选取的样本训练神经网络并更新参数
  sess.run(train_step, feed_dict={x:X[start:end], y_:Y[start:end]})
  .
 .
 .

②第二种形态,global_step作为占位被feed,在这种形态下,global_step是占位,在调用sess.run(train_step)时使用当前迭代的轮数i进行feed:

import tensorflow as tf
 .
 .
 .
#设置学习率 
global_step = tf.placeholder(tf.float32, shape=())
learning_rate = tf.train.exponential_decay(0.01, global_step, 16, 0.96, staircase=True)
#定义反向传播算法的优化方法
train_step = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy)
 .
 .
 .
#创建会话
with tf.Session() as sess:
 .
 .
 .
 for i in range(STEPS):
 .
 .
 .
  #通过选取的样本训练神经网络并更新参数
  sess.run(train_step, feed_dict={x:X[start:end], y_:Y[start:end], global_step:i})
 .
 .
 .

总结

以上所述是小编给大家介绍的TensorFlow实现指数衰减学习率的方法,希望对大家有所帮助!

Python 相关文章推荐
Python实现抓取页面上链接的简单爬虫分享
Jan 21 Python
Python中使用第三方库xlrd来读取Excel示例
Apr 05 Python
Python OS模块常用函数说明
May 23 Python
Python lxml模块安装教程
Jun 02 Python
Python中的列表生成式与生成器学习教程
Mar 13 Python
浅谈python配置与使用OpenCV踩的一些坑
Apr 02 Python
Pandas 同元素多列去重的实例
Jul 03 Python
transform python环境快速配置方法
Sep 27 Python
Python 元组拆包示例(Tuple Unpacking)
Dec 24 Python
python 爬取B站原视频的实例代码
Sep 09 Python
java关于string最常出现的面试题整理
Jan 18 Python
pytorch通过训练结果的复现设置随机种子
Jun 01 Python
关于Tensorflow使用CPU报错的解决方式
Feb 05 #Python
解决Tensorflow sess.run导致的内存溢出问题
Feb 05 #Python
解决TensorFlow训练内存不断增长,进程被杀死问题
Feb 05 #Python
浅谈tensorflow之内存暴涨问题
Feb 05 #Python
对Tensorflow中Device实例的生成和管理详解
Feb 04 #Python
关于windows下Tensorflow和pytorch安装教程
Feb 04 #Python
django3.02模板中的超链接配置实例代码
Feb 04 #Python
You might like
Cappuccino 卡布其诺咖啡之制作
2021/03/03 冲泡冲煮
DOTA2【瓜皮时刻】Vol.91 RTZ山史最惨“矿难”
2021/03/05 DOTA
PHP HTML JavaScript MySQL代码如何互相传值的方法分享
2012/09/30 PHP
基于php冒泡排序算法的深入理解
2013/06/09 PHP
解决yii2左侧菜单子级无法高亮问题的方法
2016/05/08 PHP
PHP批量删除jQuery操作
2017/07/23 PHP
nginx 设置多个站跨域
2021/03/09 Servers
深入认识javascript中的eval函数
2009/11/02 Javascript
基于jquery的修改当前TAB显示标题的代码
2010/12/11 Javascript
JavaScript自定义事件介绍
2013/08/29 Javascript
JS小功能(button选择颜色)简单实例
2013/11/29 Javascript
JS实现最简单的冒泡排序算法
2017/02/15 Javascript
easyUI下拉列表点击事件使用方法
2017/05/18 Javascript
Node之简单的前后端交互(实例讲解)
2017/11/14 Javascript
vue2中使用less简易教程
2018/03/27 Javascript
Bootstrap-table使用footerFormatter做统计列功能
2018/09/07 Javascript
JavaScript简单实现动态改变HTML内容的方法示例
2018/12/25 Javascript
详解VUE Element-UI多级菜单动态渲染的组件
2019/04/25 Javascript
微信小程序获取公众号文章列表及显示文章的示例代码
2020/03/10 Javascript
原生js实现自定义难度的扫雷游戏
2021/01/22 Javascript
提升Python程序运行效率的6个方法
2015/03/31 Python
python实现用户管理系统
2018/01/10 Python
Python中单例模式总结
2018/02/20 Python
Django数据库连接丢失问题的解决方法
2018/12/29 Python
python3安装crypto出错及解决方法
2019/07/30 Python
基于python调用psutil模块过程解析
2019/12/20 Python
解决Keras TensorFlow 混编中 trainable=False设置无效问题
2020/06/28 Python
浅谈HTML5新增和废弃的标签
2019/04/28 HTML / CSS
自动化毕业生专业自荐书范文
2014/02/04 职场文书
2015年环境整治工作总结
2015/05/22 职场文书
监护人证明
2015/06/19 职场文书
学会感恩主题班会
2015/08/12 职场文书
2016秋季运动会开幕词
2016/03/04 职场文书
蓝天保卫战收官在即 :15行业将开展环保分级评价
2019/07/19 职场文书
创业计划书之校园跑腿公司
2019/09/24 职场文书
为什么MySQL选择Repeatable Read作为默认隔离级别
2021/07/26 MySQL