TensorFlow实现指数衰减学习率的方法


Posted in Python onFebruary 05, 2020

在TensorFlow中,tf.train.exponential_decay函数实现了指数衰减学习率,通过这个函数,可以先使用较大的学习率来快速得到一个比较优的解,然后随着迭代的继续逐步减小学习率,使得模型在训练后期更加稳定。

TensorFlow实现指数衰减学习率的方法

tf.train.exponential_decay(learning_rate, global_step, decay_steps, decay_rate, staircase, name)函数会指数级地减小学习率,它实现了以下代码的功能:

#tf.train.exponential_decay函数可以通过设置staircase参数选择不同的学习率衰减方式

#staircase参数为False(默认)时,选择连续衰减学习率:
decayed_learning_rate = learning_rate * math.pow(decay_rate, global_step / decay_steps)

#staircase参数为True时,选择阶梯状衰减学习率:
decayed_learning_rate = learning_rate * math.pow(decay_rate, global_step // decay_steps)

①decayed_leaming_rate为每一轮优化时使用的学习率;

②leaming_rate为事先设定的初始学习率;

③decay_rate为衰减系数;

④global_step为当前训练的轮数;

⑤decay_steps为衰减速度,通常代表了完整的使用一遍训练数据所需要的迭代轮数,这个迭代轮数也就是总训练样本数除以每一个batch中的训练样本数,比如训练数据集的大小为128,每一个batch中样例的个数为8,那么decay_steps就为16。

当staircase参数设置为True,使用阶梯状衰减学习率时,代码的含义是每完整地过完一遍训练数据即每训练decay_steps轮,学习率就减小一次,这可以使得训练数据集中的所有数据对模型训练有相等的作用;当staircase参数设置为False,使用连续的衰减学习率时,不同的训练数据有不同的学习率,而当学习率减小时,对应的训练数据对模型训练结果的影响也就小了。

接下来看一看tf.train.exponential_decay函数应用的两种形态(省略部分代码):

①第一种形态,global_step作为变量被优化,在这种形态下,global_step是变量,在minimize函数中传入global_step将自动更新global_step参数(global_step每轮迭代自动加一),从而使得学习率也得到相应更新:

import tensorflow as tf
 .
 .
 .
#设置学习率
global_step = tf.Variable(tf.constant(0))
learning_rate = tf.train.exponential_decay(0.01, global_step, 16, 0.96, staircase=True)
#定义反向传播算法的优化方法
train_step = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy, global_step=global_step)
 .
 .
 .
#创建会话
with tf.Session() as sess:
 .
 .
 .
 for i in range(STEPS):
 .
 .
 .
  #通过选取的样本训练神经网络并更新参数
  sess.run(train_step, feed_dict={x:X[start:end], y_:Y[start:end]})
  .
 .
 .

②第二种形态,global_step作为占位被feed,在这种形态下,global_step是占位,在调用sess.run(train_step)时使用当前迭代的轮数i进行feed:

import tensorflow as tf
 .
 .
 .
#设置学习率 
global_step = tf.placeholder(tf.float32, shape=())
learning_rate = tf.train.exponential_decay(0.01, global_step, 16, 0.96, staircase=True)
#定义反向传播算法的优化方法
train_step = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy)
 .
 .
 .
#创建会话
with tf.Session() as sess:
 .
 .
 .
 for i in range(STEPS):
 .
 .
 .
  #通过选取的样本训练神经网络并更新参数
  sess.run(train_step, feed_dict={x:X[start:end], y_:Y[start:end], global_step:i})
 .
 .
 .

总结

以上所述是小编给大家介绍的TensorFlow实现指数衰减学习率的方法,希望对大家有所帮助!

Python 相关文章推荐
Python3读取文件常用方法实例分析
May 22 Python
python实现树形打印目录结构
Mar 29 Python
浅谈python新式类和旧式类区别
Apr 26 Python
python自制包并用pip免提交到pypi仅安装到本机【推荐】
Jun 03 Python
python itchat实现调用微信接口的第三方模块方法
Jun 11 Python
使用python os模块复制文件到指定文件夹的方法
Aug 22 Python
python网络爬虫 CrawlSpider使用详解
Sep 27 Python
python应用Axes3D绘图(批量梯度下降算法)
Mar 25 Python
如何基于python3和Vue实现AES数据加密
Mar 27 Python
安装多个版本的TensorFlow的方法步骤
Apr 21 Python
Python持续监听文件变化代码实例
Jul 22 Python
基于 Python 实践感知器分类算法
Jan 07 Python
关于Tensorflow使用CPU报错的解决方式
Feb 05 #Python
解决Tensorflow sess.run导致的内存溢出问题
Feb 05 #Python
解决TensorFlow训练内存不断增长,进程被杀死问题
Feb 05 #Python
浅谈tensorflow之内存暴涨问题
Feb 05 #Python
对Tensorflow中Device实例的生成和管理详解
Feb 04 #Python
关于windows下Tensorflow和pytorch安装教程
Feb 04 #Python
django3.02模板中的超链接配置实例代码
Feb 04 #Python
You might like
php中error与exception的区别及应用
2014/07/28 PHP
PHP利用APC模块实现文件上传进度条的方法
2015/01/26 PHP
PHP处理大量表单字段的便捷方法
2015/02/07 PHP
PHP使用ffmpeg给视频增加字幕显示的方法
2015/03/12 PHP
php中Snoopy类用法实例
2015/06/19 PHP
YII CLinkPager分页类扩展增加显示共多少页
2016/01/29 PHP
php判断手机浏览还是web浏览,并执行相应的动作简单实例
2016/07/28 PHP
JQuery与JSon实现的无刷新分页代码
2011/09/13 Javascript
js/jquery判断浏览器的方法小结
2014/09/02 Javascript
JavaScript实现数组在指定位置插入若干元素的方法
2015/04/06 Javascript
谈谈AngularJs中的隐藏和显示
2015/12/09 Javascript
简单几步实现返回顶部效果
2016/12/05 Javascript
详谈js中数组(array)和对象(object)的区别
2017/02/27 Javascript
js字符限制(字符截取) 一个中文汉字算两个字符
2017/09/12 Javascript
再谈Angular4 脏值检测(性能优化)
2018/04/23 Javascript
小程序显示弹窗时禁止下层的内容滚动实现方法
2019/03/20 Javascript
封装微信小程序http拦截器过程解析
2019/08/13 Javascript
微信小程序实现树莓派(raspberry pi)小车控制
2020/02/12 Javascript
Vue自定义组件的四种方式示例详解
2020/02/28 Javascript
Angular进行简单单元测试的实现方法实例
2020/08/16 Javascript
django获取from表单multiple-select的value和id的方法
2019/07/19 Python
Python多线程及其基本使用方法实例分析
2019/10/29 Python
python双向链表原理与实现方法详解
2019/12/03 Python
pytorch中torch.max和Tensor.view函数用法详解
2020/01/03 Python
python shapely.geometry.polygon任意两个四边形的IOU计算实例
2020/04/12 Python
浅谈keras中的目标函数和优化函数MSE用法
2020/06/10 Python
django rest framework 过滤时间操作
2020/07/12 Python
python 装饰器重要在哪
2021/02/14 Python
澳大利亚领先的运动鞋商店:Hype DC
2018/03/31 全球购物
2014信息公开实施方案
2014/02/22 职场文书
公证处委托书
2015/01/28 职场文书
老乡聚会通知
2015/04/23 职场文书
小平您好观后感
2015/06/09 职场文书
养成教育工作总结
2015/08/13 职场文书
react国际化react-intl的使用
2021/05/06 Javascript
vue cli4中mockjs在dev环境和build环境的配置详情
2022/04/06 Vue.js