Tensorflow实现部分参数梯度更新操作


Posted in Python onJanuary 23, 2020

在深度学习中,迁移学习经常被使用,在大数据集上预训练的模型迁移到特定的任务,往往需要保持模型参数不变,而微调与任务相关的模型层。

本文主要介绍,使用tensorflow部分更新模型参数的方法。

1. 根据Variable scope剔除需要固定参数的变量

def get_variable_via_scope(scope_lst):
  vars = []
  for sc in scope_lst:
    sc_variable = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope=scope)
    vars.extend(sc_variable)
  return vars
 
trainable_vars = tf.trainable_variables()
no_change_scope = ['your_unchange_scope_name']
 
no_change_vars = get_variable_via_scope(no_change_scope)
 
for v in no_change_vars:
  trainable_vars.remove(v)
 
grads, _ = tf.gradients(loss, trainable_vars)
 
optimizer = tf.train.AdamOptimizer(lr)
 
train_op = optimizer.apply_gradient(zip(grads, trainable_vars), global_step=global_step)

2. 使用tf.stop_gradient()函数

在建立Graph过程中使用该函数,非常简洁地避免了使用scope获取参数

3. 一个矩阵中部分行或列参数更新

如果一个矩阵,只有部分行或列需要更新参数,其它保持不变,该场景很常见,例如word embedding中,一些预定义的领域相关词保持不变(使用领域相关word embedding初始化),而另一些通用词变化。

import tensorflow as tf
import numpy as np
 
def entry_stop_gradients(target, mask):
  mask_h = tf.abs(mask-1)
  return tf.stop_gradient(mask_h * target) + mask * target
 
mask = np.array([1., 0, 1, 1, 0, 0, 1, 1, 0, 1])
mask_h = np.abs(mask-1)
 
emb = tf.constant(np.ones([10, 5]))
 
matrix = entry_stop_gradients(emb, tf.expand_dims(mask,1))
 
parm = np.random.randn(5, 1)
t_parm = tf.constant(parm)
 
loss = tf.reduce_sum(tf.matmul(matrix, t_parm))
grad1 = tf.gradients(loss, emb)
grad2 = tf.gradients(loss, matrix)
print matrix
with tf.Session() as sess:
  print sess.run(loss)
  print sess.run([grad1, grad2])

以上这篇Tensorflow实现部分参数梯度更新操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python类属性与实例属性用法分析
May 09 Python
Python实现简单HTML表格解析的方法
Jun 15 Python
详解Python的Django框架中inclusion_tag的使用
Jul 21 Python
python 实现求解字符串集的最长公共前缀方法
Jul 20 Python
python3对拉勾数据进行可视化分析的方法详解
Apr 03 Python
Python求解正态分布置信区间教程
Nov 20 Python
Python批量启动多线程代码实例
Feb 18 Python
OpenCV Python实现拼图小游戏
Mar 23 Python
python删除某个目录文件夹的方法
May 26 Python
Python如何将装饰器定义为类
Jul 30 Python
Python自动化办公Excel模块openpyxl原理及用法解析
Nov 05 Python
python3中calendar返回某一时间点实例讲解
Nov 18 Python
将tensorflow模型打包成PB文件及PB文件读取方式
Jan 23 #Python
使用tensorflow显示pb模型的所有网络结点方式
Jan 23 #Python
tensorflow 实现打印pb模型的所有节点
Jan 23 #Python
TensorFlow命名空间和TensorBoard图节点实例
Jan 23 #Python
tensorflow通过模型文件,使用tensorboard查看其模型图Graph方式
Jan 23 #Python
如何定义TensorFlow输入节点
Jan 23 #Python
django 文件上传功能的相关实例代码(简单易懂)
Jan 22 #Python
You might like
eAccelerator的安装与使用详解
2013/06/13 PHP
ThinkPHP查询返回简单字段数组的方法
2014/08/25 PHP
Symfony2实现在controller中获取url的方法
2016/03/18 PHP
Linux系统中为php添加pcntl扩展
2016/08/28 PHP
自动完成JS类(纯JS, Ajax模式)
2009/03/12 Javascript
js event事件的传递与冒泡处理
2009/12/06 Javascript
Jquery遍历节点的方法小集
2014/01/22 Javascript
浅谈JavaScript function函数种类
2014/12/29 Javascript
js实现当复选框选择匿名登录时隐藏登录框效果
2015/08/14 Javascript
Bootstrap每天必学之级联下拉菜单
2016/03/27 Javascript
浅谈Javascript数据属性与访问器属性
2016/07/26 Javascript
值得分享的Bootstrap Table使用教程
2016/11/23 Javascript
AngularJS 中ui-view传参的实例详解
2017/08/25 Javascript
Node.js API详解之 vm模块用法实例分析
2020/05/27 Javascript
Python GAE、Django导出Excel的方法
2008/11/24 Python
python连接mongodb操作数据示例(mongodb数据库配置类)
2013/12/31 Python
六个窍门助你提高Python运行效率
2015/06/09 Python
python利用不到一百行代码实现一个小siri
2017/03/02 Python
python django事务transaction源码分析详解
2017/03/17 Python
python构建自定义回调函数详解
2017/06/20 Python
详谈python中冒号与逗号的区别
2018/04/18 Python
pandas的连接函数concat()函数的具体使用方法
2019/07/09 Python
Django中的用户身份验证示例详解
2019/08/07 Python
如何给Python代码进行加密
2020/01/10 Python
python实现拼接图片
2020/03/23 Python
导致python中import错误的原因是什么
2020/07/01 Python
python利用xlsxwriter模块 操作 Excel
2020/10/14 Python
详解利用canvas实现环形进度条的方法
2019/06/12 HTML / CSS
美国派对用品及装饰品网上商店:Shindigz
2016/07/30 全球购物
幼儿园教师个人反思
2014/01/30 职场文书
机关党员2014全国两会学习心得体会
2014/03/10 职场文书
会计的岗位职责
2014/03/15 职场文书
2014年两会学习心得范例
2014/03/17 职场文书
工会换届选举方案
2014/05/21 职场文书
新郎父母婚礼答谢词
2015/09/29 职场文书
社交电商模式的兴起:这些新的商机千万别错过
2019/07/26 职场文书