TensorFlow 滑动平均的示例代码


Posted in Python onJune 19, 2018

滑动平均会为目标变量维护一个影子变量,影子变量不影响原变量的更新维护,但是在测试或者实际预测过程中(非训练时),使用影子变量代替原变量。

1、滑动平均求解对象初始化

ema = tf.train.ExponentialMovingAverage(decay,num_updates)

参数decay

`shadow_variable = decay * shadow_variable + (1 - decay) * variable`

参数num_updates

`min(decay, (1 + num_updates) / (10 + num_updates))`

2、添加/更新变量

添加目标变量,为之维护影子变量

注意维护不是自动的,需要每轮训练中运行此句,所以一般都会使用tf.control_dependencies使之和train_op绑定,以至于每次train_op都会更新影子变量

ema.apply([var0, var1])

3、获取影子变量值

这一步不需要定义图中,从影子变量集合中提取目标值

sess.run(ema.average([var0, var1]))

4、保存&载入影子变量

我们知道,在TensorFlow中,变量的滑动平均值都是由影子变量所维护的,如果你想要获取变量的滑动平均值需要获取的是影子变量而不是变量本身。

保存影子变量

建立tf.train.ExponentialMovingAverage对象后,Saver正常保存就会存入影子变量,命名规则是"v/ExponentialMovingAverage"对应变量”v“

import tensorflow as tf 
if __name__ == "__main__": 

  v = tf.Variable(0.,name="v") 

  #设置滑动平均模型的系数 

  ema = tf.train.ExponentialMovingAverage(0.99) 

  #设置变量v使用滑动平均模型,tf.all_variables()设置所有变量 

  op = ema.apply([v]) 

  #获取变量v的名字 

  print(v.name) 

  #v:0 

  #创建一个保存模型的对象 

  save = tf.train.Saver() 

  sess = tf.Session() 

  #初始化所有变量 

  init = tf.initialize_all_variables() 

  sess.run(init) 

  #给变量v重新赋值 

  sess.run(tf.assign(v,10)) 

  #应用平均滑动设置 

  sess.run(op) 

  #保存模型文件 

  save.save(sess,"./model.ckpt") 

  #输出变量v之前的值和使用滑动平均模型之后的值 

  print(sess.run([v,ema.average(v)])) 

  #[10.0, 0.099999905]

载入影子变量并映射到变量

利用了Saver载入模型的变量名映射功能,实际上对所有的变量都可以如此操作『TensorFlow』模型载入方法汇总

v = tf.Variable(1.,name="v") 

#定义模型对象 

saver = tf.train.Saver({"v/ExponentialMovingAverage":v}) 

sess = tf.Session() 

saver.restore(sess,"./model.ckpt") 

print(sess.run(v)) 

#0.0999999

这里特别需要注意的一个地方就是,在使用tf.train.Saver函数中,所传递的模型参数是{"v/ExponentialMovingAverage":v}而不是{"v":v},如果你使用的是后面的参数,那么你得到的结果将是10而不是0.09,那是因为后者获取的是变量本身而不是影子变量。

使用这种方式来读取模型文件的时候,还需要输入一大串的变量名称。

variables_to_restore函数的使用

v = tf.Variable(1.,name="v") 

#滑动模型的参数的大小并不会影响v的值 

ema = tf.train.ExponentialMovingAverage(0.99) 

print(ema.variables_to_restore()) 

#{'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>} 

sess = tf.Session() 

saver = tf.train.Saver(ema.variables_to_restore()) 

saver.restore(sess,"./model.ckpt") 

print(sess.run(v)) 

#0.0999999

variables_to_restore会识别网络中的变量,并自动生成影子变量名。

通过使用variables_to_restore函数,可以使在加载模型的时候将影子变量直接映射到变量的本身,所以我们在获取变量的滑动平均值的时候只需要获取到变量的本身值而不需要去获取影子变量。

5、官方文档例子

官方文档中将每次apply更新就会自动训练一边模型,实际上可以反过来两者关系,《tf实战google》P128中有示例

| Example usage when creating a training model:
 | 
 | ```python
 | # Create variables.
 | var0 = tf.Variable(...)
 | var1 = tf.Variable(...)
 | # ... use the variables to build a training model...
 | ...
 | # Create an op that applies the optimizer. This is what we usually
 | # would use as a training op.
 | opt_op = opt.minimize(my_loss, [var0, var1])
 | 
 | # Create an ExponentialMovingAverage object
 | ema = tf.train.ExponentialMovingAverage(decay=0.9999)
 | 
 | with tf.control_dependencies([opt_op]):
 |   # Create the shadow variables, and add ops to maintain moving averages
 |   # of var0 and var1. This also creates an op that will update the moving
 |   # averages after each training step. This is what we will use in place
 |   # of the usual training op.
 |   training_op = ema.apply([var0, var1])
 | 
 | ...train the model by running training_op...
 | ```

6、batch_normal的例子

和上面不太一样的是,batch_normal中不太容易绑定到train_op(位于函数体外面),则强行将两个variable的输出过程化为节点,绑定给参数更新步骤

def batch_norm(x,beta,gamma,phase_train,scope='bn',decay=0.9,eps=1e-5):

  with tf.variable_scope(scope):

    # beta = tf.get_variable(name='beta', shape=[n_out], initializer=tf.constant_initializer(0.0), trainable=True)

    # gamma = tf.get_variable(name='gamma', shape=[n_out],

    #             initializer=tf.random_normal_initializer(1.0, stddev), trainable=True)

    batch_mean,batch_var = tf.nn.moments(x,[0,1,2],name='moments')

    ema = tf.train.ExponentialMovingAverage(decay=decay)

 

    def mean_var_with_update():

      ema_apply_op = ema.apply([batch_mean,batch_var])

      with tf.control_dependencies([ema_apply_op]):

        return tf.identity(batch_mean),tf.identity(batch_var)

        # identity之后会把Variable转换为Tensor并入图中,

        # 否则由于Variable是独立于Session的,不会被图控制control_dependencies限制

 

    mean,var = tf.cond(phase_train,

              mean_var_with_update,

              lambda: (ema.average(batch_mean),ema.average(batch_var)))

   normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, eps)

  return normed

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之一个免费的实验室
Sep 14 Python
Python开发常用的一些开源Package分享
Feb 14 Python
Python获取任意xml节点值的方法
May 05 Python
Python中文字符串截取问题
Jun 15 Python
python 读写txt文件 json文件的实现方法
Oct 22 Python
python re模块findall()函数实例解析
Jan 19 Python
利用nohup来开启python文件的方法
Jan 14 Python
用Q-learning算法实现自动走迷宫机器人的方法示例
Jun 03 Python
Python 200行代码实现一个滑动验证码过程详解
Jul 11 Python
nginx黑名单和django限速,最简单的防恶意请求方法分享
Aug 09 Python
python获取本周、上周、本月、上月及本季的时间代码实例
Sep 08 Python
Python 居然可以在 Excel 中画画你知道吗
Feb 15 Python
python3个性签名设计实现代码
Jun 19 #Python
TensorFlow 模型载入方法汇总(小结)
Jun 19 #Python
python3爬虫之设计签名小程序
Jun 19 #Python
Python GUI Tkinter简单实现个性签名设计
Jun 19 #Python
TensorFlow数据输入的方法示例
Jun 19 #Python
深入分析python中整型不会溢出问题
Jun 18 #Python
Python登录注册验证功能实现
Jun 18 #Python
You might like
PHP动态编译出现Cannot find autoconf的解决方法
2014/11/05 PHP
PHP使用PDO 连接与连接管理操作实例分析
2020/04/21 PHP
PHP设计模式(五)适配器模式Adapter实例详解【结构型】
2020/05/02 PHP
window.open被浏览器拦截后的自定义提示效果代码
2007/11/19 Javascript
Jquery中getJSON在asp.net中的使用说明
2011/03/10 Javascript
JS读取cookies信息(记录用户名)
2012/01/10 Javascript
几种延迟加载JS代码的方法加快网页的访问速度
2013/10/12 Javascript
javascript中indexOf技术详解
2015/05/07 Javascript
JavaScript给input的value赋值引发的关于基本类型值和引用类型值问题
2015/12/07 Javascript
详解Node.Js如何处理post数据
2016/09/19 Javascript
js使用html2canvas实现屏幕截取的示例代码
2017/08/28 Javascript
vue项目实战总结篇
2018/02/11 Javascript
vue 使某个组件不被 keep-alive 缓存的方法
2018/09/21 Javascript
JavaScript模板引擎实现原理实例详解
2018/12/14 Javascript
Vue使用.sync 实现父子组件的双向绑定数据问题
2019/04/04 Javascript
vue实现绑定事件的方法实例代码详解
2019/06/20 Javascript
Vue CLI项目 axios模块前后端交互的使用(类似ajax提交)
2019/09/01 Javascript
微信小程序和H5页面间相互跳转代码实例
2019/09/19 Javascript
一看就会的vuex实现登录验证(附案例)
2020/01/09 Javascript
JavaScript使用canvas绘制随机验证码
2020/02/17 Javascript
微信小程序实现比较功能的方法汇总(五种方法)
2020/03/07 Javascript
Vue动态加载图片在跨域时无法显示的问题及解决方法
2020/03/10 Javascript
JS正则表达式常见函数与用法小结
2020/04/13 Javascript
vue项目实现设置根据路由高亮对应的菜单项操作
2020/08/06 Javascript
Python爬取网易云音乐热门评论
2017/03/31 Python
Python数据预处理之数据规范化(归一化)示例
2019/01/08 Python
Django 查询数据库并返回页面的例子
2019/08/12 Python
python requests更换代理适用于IP频率限制的方法
2019/08/21 Python
Python自动创建Excel并获取内容
2020/09/16 Python
call在Python中改进数列的实例讲解
2020/12/09 Python
24个canvas基础知识小结
2014/12/17 HTML / CSS
世界上最大的折扣香水店:FragranceNet.com
2016/10/26 全球购物
村容村貌整治方案
2014/05/21 职场文书
网吧七夕活动策划方案
2014/08/31 职场文书
党的群众路线教育实践活动实施方案
2014/10/31 职场文书
夫妻双方自愿离婚协议书怎么写
2014/12/01 职场文书