TensorFlow的权值更新方法


Posted in Python onJune 14, 2018

一. MovingAverage权值滑动平均更新

1.1 示例代码:

def create_target_q_network(self,state_dim,action_dim,net):
  state_input = tf.placeholder("float",[None,state_dim])
  action_input = tf.placeholder("float",[None,action_dim])

  ema = tf.train.ExponentialMovingAverage(decay=1-TAU)
  target_update = ema.apply(net)
  target_net = [ema.average(x) for x in net]

  layer1 = tf.nn.relu(tf.matmul(state_input,target_net[0]) + target_net[1])
  layer2 = tf.nn.relu(tf.matmul(layer1,target_net[2]) + tf.matmul(action_input,target_net[3]) + target_net[4])
  q_value_output = tf.identity(tf.matmul(layer2,target_net[5]) + target_net[6])

  return state_input,action_input,q_value_output,target_update

def update_target(self):
  self.sess.run(self.target_update)

其中,TAU=0.001,net是原始网络(该示例代码来自DDPG算法,经过滑动更新后的target_net是目标网络 )

第一句 tf.train.ExponentialMovingAverage,创建一个权值滑动平均的实例;

第二句 apply创建所训练模型参数的一个复制品(shadow_variable),并对这个复制品增加一个保留权值滑动平均的op,函数average()或average_name()可以用来获取最终这个复制品(平滑后)的值的。

更新公式为:

shadow_variable = decay * shadow_variable + (1 - decay) * variable

在上述代码段中,target_net是shadow_variable,net是variable

1.2 tf.train.ExponentialMovingAverage.apply(var_list=None)

var_list必须是Variable或Tensor形式的列表。这个方法对var_list中所有元素创建一个复制,当其是Variable类型时,shadow_variable被初始化为variable的初值,当其是Tensor类型时,初始化为0,无偏。

函数返回一个进行权值平滑的op,因此更新目标网络时单独run这个函数就行。

1.3 tf.train.ExponentialMovingAverage.average(var)

用于获取var的滑动平均结果。

二. tf.train.Optimizer更新网络权值

2.1 tf.train.Optimizer

tf.train.Optimizer允许网络通过minimize()损失函数自动进行权值更新,此时tf.train.Optimizer.minimize()做了两件事:计算梯度,并把梯度自动更新到权值上。

此外,tensorflow也允许用户自己计算梯度,并做处理后应用给权值进行更新,此时分为以下三个步骤:

1.利用tf.train.Optimizer.compute_gradients计算梯度

2.对梯度进行自定义处理

3.利用tf.train.Optimizer.apply_gradients更新权值

tf.train.Optimizer.compute_gradients(loss, var_list=None, gate_gradients=1, aggregation_method=None, colocate_gradients_with_ops=False, grad_loss=None)

返回一个(梯度,权值)的列表对。

tf.train.Optimizer.apply_gradients(grads_and_vars, global_step=None, name=None)

返回一个更新权值的op,因此可以用它的返回值ret进行sess.run(ret)

2.2 其它

此外,tensorflow还提供了其它计算梯度的方法:

• tf.gradients(ys, xs, grad_ys=None, name='gradients', colocate_gradients_with_ops=False, gate_gradients=False, aggregation_method=None)

该函数计算ys在xs方向上的梯度,需要注意与train.compute_gradients所不同的地方是,该函数返回一组dydx dydx的列表,而不是梯度-权值对。

其中,gate_gradients是在ys方向上的初始梯度,个人理解可以看做是偏微分链式求导中所需要的。

• tf.stop_gradient(input, name=None)

该函数告知整个graph图中,对input不进行梯度计算,将其伪装成一个constant常量。比如,可以用在类似于DQN算法中的目标函数:

cost=|r+Q next −Q current | cost=|r+Qnext−Qcurrent|

可以事先声明

y=tf.stop_gradient(r+Q next r+Qnext)

以上这篇TensorFlow的权值更新方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python标准库中的wave模块绘制乐谱的简单教程
Mar 30 Python
基于Python_脚本CGI、特点、应用、开发环境(详解)
May 23 Python
python学习必备知识汇总
Sep 08 Python
selenium+python 去除启动的黑色cmd窗口方法
May 22 Python
WIn10+Anaconda环境下安装PyTorch(避坑指南)
Jan 30 Python
解决pyqt5中QToolButton无法使用的问题
Jun 21 Python
一文秒懂python读写csv xml json文件各种骚操作
Jul 04 Python
如何使用Python多线程测试并发漏洞
Dec 18 Python
Python破解BiliBili滑块验证码的思路详解(完美避开人机识别)
Feb 17 Python
基于jupyter代码无法在pycharm中运行的解决方法
Apr 21 Python
PyTorch 如何设置随机数种子使结果可复现
May 12 Python
Python之基础函数案例详解
Aug 30 Python
python字符串常用方法
Jun 14 #Python
tensorflow 输出权重到csv或txt的实例
Jun 14 #Python
修复 Django migration 时遇到的问题解决
Jun 14 #Python
tensorflow 获取模型所有参数总和数量的方法
Jun 14 #Python
tensorflow 获取变量&打印权值的实例讲解
Jun 14 #Python
利用python对Excel中的特定数据提取并写入新表的方法
Jun 14 #Python
Python基于最小二乘法实现曲线拟合示例
Jun 14 #Python
You might like
十天学会php之第六天
2006/10/09 PHP
PHP中数组合并的两种方法及区别介绍
2012/09/14 PHP
PHP版本升级到7.x后wordpress的一些修改及wordpress技巧
2015/12/25 PHP
Windows下PHP开发环境搭建教程(Apache+PHP+MySQL)
2016/06/13 PHP
PHP精确到毫秒秒杀倒计时实例详解
2019/03/14 PHP
jQuery Tips 为AJAX回调函数传递额外参数的方法
2010/12/28 Javascript
jquery animate图片模向滑动示例代码
2011/01/26 Javascript
鼠标滑上去后图片放大浮出效果的js代码
2011/05/28 Javascript
firebug的一个有趣现象介绍
2011/11/30 Javascript
jquery实现经典的淡入淡出选项卡效果代码
2015/09/22 Javascript
jQuery自动完成插件completer附源码下载
2016/01/04 Javascript
关于cookie的初识和运用(js和jq)
2016/04/07 Javascript
深入浅析JavaScript中数据共享和数据传递
2016/04/25 Javascript
Bootstrap Table使用方法解析
2016/10/19 Javascript
微信公众号开发 自定义菜单跳转页面并获取用户信息实例详解
2016/12/08 Javascript
详解vue-cli + webpack 多页面实例应用
2017/04/25 Javascript
实现微信小程序的wxml文件和wxss文件在webstrom的支持
2017/06/12 Javascript
简单实现js进度条加载效果
2020/03/25 Javascript
Vee-Validate的使用方法详解
2017/09/22 Javascript
Vue项目使用CDN优化首屏加载问题
2018/04/01 Javascript
微信小程序canvas.drawImage完全显示图片问题的解决
2018/11/30 Javascript
微信小程序如何访问公众号文章
2019/07/08 Javascript
Vue实现回到顶部和底部动画效果
2019/07/31 Javascript
Python使用arrow库优雅地处理时间数据详解
2017/10/10 Python
用pyqt5 给按钮设置图标和css样式的方法
2019/06/24 Python
python networkx 根据图的权重画图实现
2019/07/10 Python
500行代码使用python写个微信小游戏飞机大战游戏
2019/10/16 Python
使用CSS3设计地图上的雷达定位提示效果
2016/04/05 HTML / CSS
HTML5 Canvas锯齿图代码实例
2014/04/10 HTML / CSS
医学院学生的自我评价分享
2013/11/19 职场文书
运动会入场式解说词
2014/02/18 职场文书
垃圾桶标语
2014/06/24 职场文书
环境科学专业求职信
2014/08/04 职场文书
车贷收入证明范本
2014/09/14 职场文书
幼儿园父亲节活动总结
2015/02/12 职场文书
Python使用openpyxl模块处理Excel文件
2022/06/05 Python