Tensorflow实现部分参数梯度更新操作


Posted in Python onJanuary 23, 2020

在深度学习中,迁移学习经常被使用,在大数据集上预训练的模型迁移到特定的任务,往往需要保持模型参数不变,而微调与任务相关的模型层。

本文主要介绍,使用tensorflow部分更新模型参数的方法。

1. 根据Variable scope剔除需要固定参数的变量

def get_variable_via_scope(scope_lst):
  vars = []
  for sc in scope_lst:
    sc_variable = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope=scope)
    vars.extend(sc_variable)
  return vars
 
trainable_vars = tf.trainable_variables()
no_change_scope = ['your_unchange_scope_name']
 
no_change_vars = get_variable_via_scope(no_change_scope)
 
for v in no_change_vars:
  trainable_vars.remove(v)
 
grads, _ = tf.gradients(loss, trainable_vars)
 
optimizer = tf.train.AdamOptimizer(lr)
 
train_op = optimizer.apply_gradient(zip(grads, trainable_vars), global_step=global_step)

2. 使用tf.stop_gradient()函数

在建立Graph过程中使用该函数,非常简洁地避免了使用scope获取参数

3. 一个矩阵中部分行或列参数更新

如果一个矩阵,只有部分行或列需要更新参数,其它保持不变,该场景很常见,例如word embedding中,一些预定义的领域相关词保持不变(使用领域相关word embedding初始化),而另一些通用词变化。

import tensorflow as tf
import numpy as np
 
def entry_stop_gradients(target, mask):
  mask_h = tf.abs(mask-1)
  return tf.stop_gradient(mask_h * target) + mask * target
 
mask = np.array([1., 0, 1, 1, 0, 0, 1, 1, 0, 1])
mask_h = np.abs(mask-1)
 
emb = tf.constant(np.ones([10, 5]))
 
matrix = entry_stop_gradients(emb, tf.expand_dims(mask,1))
 
parm = np.random.randn(5, 1)
t_parm = tf.constant(parm)
 
loss = tf.reduce_sum(tf.matmul(matrix, t_parm))
grad1 = tf.gradients(loss, emb)
grad2 = tf.gradients(loss, matrix)
print matrix
with tf.Session() as sess:
  print sess.run(loss)
  print sess.run([grad1, grad2])

以上这篇Tensorflow实现部分参数梯度更新操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python使用matplotlib实现在坐标系中画一个矩形的方法
May 20 Python
Python爬虫爬验证码实现功能详解
Apr 14 Python
Python 常用的安装Module方式汇总
May 06 Python
Python实现读取json文件到excel表
Nov 18 Python
python中的随机函数random的用法示例
Jan 27 Python
python中kmeans聚类实现代码
Feb 23 Python
Python Logging 日志记录入门学习
Jun 02 Python
Python小游戏之300行代码实现俄罗斯方块
Jan 04 Python
Python中使用logging和traceback模块记录日志和跟踪异常
Apr 09 Python
django 单表操作实例详解
Jul 30 Python
Python Tkinter Entry和Text的添加与使用详解
Mar 04 Python
python tkinter模块的简单使用
Apr 07 Python
将tensorflow模型打包成PB文件及PB文件读取方式
Jan 23 #Python
使用tensorflow显示pb模型的所有网络结点方式
Jan 23 #Python
tensorflow 实现打印pb模型的所有节点
Jan 23 #Python
TensorFlow命名空间和TensorBoard图节点实例
Jan 23 #Python
tensorflow通过模型文件,使用tensorboard查看其模型图Graph方式
Jan 23 #Python
如何定义TensorFlow输入节点
Jan 23 #Python
django 文件上传功能的相关实例代码(简单易懂)
Jan 22 #Python
You might like
PHP获取表单textarea数据中的换行问题
2010/09/10 PHP
注意:php5.4删除了session_unregister函数
2013/08/05 PHP
PHP网页游戏学习之Xnova(ogame)源码解读(十二)
2014/06/25 PHP
PHP原生函数一定好吗?
2014/12/08 PHP
老版本PHP转义Json里的特殊字符的函数
2015/06/08 PHP
PHP将MySQL的查询结果转换为数组并用where拼接的示例
2016/05/13 PHP
PHP按符号截取字符串的指定部分的实现方法
2018/09/10 PHP
js几个验证函数代码
2010/03/25 Javascript
Javascript变量函数浅析
2011/09/02 Javascript
js获取指定日期前后的日期代码
2013/08/20 Javascript
js操作输入框提示信息且响应鼠标事件
2014/03/25 Javascript
JavaScript在浏览器标题栏上显示当前日期和时间的方法
2015/03/19 Javascript
js实现简洁的滑动门菜单(选项卡)效果代码
2015/09/04 Javascript
jQuery+CSS3实现3D立方体旋转效果
2015/11/10 Javascript
JavaScript原型及原型链终极详解
2016/01/04 Javascript
ES6新特性二:Iterator(遍历器)和for-of循环详解
2017/04/20 Javascript
js指定日期增加指定月份的实现方法
2018/12/19 Javascript
浅谈Javascript中的对象和继承
2019/04/19 Javascript
Vue.js+cube-ui(Scroll组件)实现类似头条效果的横向滚动导航条
2019/06/24 Javascript
用Angular实现一个扫雷的游戏示例
2020/05/15 Javascript
Javascript基于OOP实实现探测器功能代码实例
2020/08/26 Javascript
Vue中关闭弹窗组件时销毁并隐藏操作
2020/09/01 Javascript
vue 解决provide和inject响应的问题
2020/11/12 Javascript
[02:57]2014DOTA2国际邀请赛 选手辛苦解说更辛苦
2014/07/10 DOTA
[03:11]2014DOTA2国际邀请赛-VG掉入败者组 独家专访357
2014/07/19 DOTA
在Python下利用OpenCV来旋转图像的教程
2015/04/16 Python
Python编程之event对象的用法实例分析
2017/03/23 Python
详解appium+python 启动一个app步骤
2017/12/20 Python
python实现五子棋游戏(pygame版)
2020/01/19 Python
tensorflow dataset.shuffle、dataset.batch、dataset.repeat顺序区别详解
2020/06/03 Python
Farah官方网站:男士服装及配件
2019/11/01 全球购物
英国高街奥特莱斯:Highstreet Outlet
2019/11/21 全球购物
第一范式(1NF)、第二范式(2NF)和第三范式(3NF)之间的区别是什么?
2016/04/28 面试题
生产管理的三大手法
2013/11/11 职场文书
高中毕业自我鉴定
2013/12/22 职场文书
解决MySQL存储时间出现不一致的问题
2021/04/28 MySQL