Fixing trainable=False having no effect when mixing Keras and TensorFlow


Posted in Python on June 28, 2020

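This is a problem I ran into recently. First, a description of it:

I have a pretrained model (for example VGG16) and want to modify it, for instance by adding a fully connected layer. For various reasons I can only use TensorFlow to optimize the model. By default, a tf optimizer updates the weights of everything in tf.trainable_variables(), and that is where the problem lies: even though the VGG16 model is set to trainable=False, the tf optimizer still updates VGG16's weights.

That is the problem. After a good deal of searching on Google and Baidu, I finally found a solution. Below we reproduce the whole problem step by step.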

trainable=False has no effect

First, we load the pretrained VGG16 model and set it to trainable=False

from keras.applications import VGG16
import tensorflow as tf
from keras import layers
# Load the pretrained model
base_mode = VGG16(include_top=False)
# List the trainable variables
tf.trainable_variables()
[<tf.Variable 'block1_conv1/kernel:0' shape=(3, 3, 3, 64) dtype=float32_ref>,
 <tf.Variable 'block1_conv1/bias:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'block1_conv2/kernel:0' shape=(3, 3, 64, 64) dtype=float32_ref>,
 <tf.Variable 'block1_conv2/bias:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'block2_conv1/kernel:0' shape=(3, 3, 64, 128) dtype=float32_ref>,
 <tf.Variable 'block2_conv1/bias:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'block2_conv2/kernel:0' shape=(3, 3, 128, 128) dtype=float32_ref>,
 <tf.Variable 'block2_conv2/bias:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'block3_conv1/kernel:0' shape=(3, 3, 128, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv1/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block3_conv2/kernel:0' shape=(3, 3, 256, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv2/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block3_conv3/kernel:0' shape=(3, 3, 256, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv3/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block4_conv1/kernel:0' shape=(3, 3, 256, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block4_conv2/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv2/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block4_conv3/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv3/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv2/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv2/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv3/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv3/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block1_conv1_1/kernel:0' shape=(3, 3, 3, 64) dtype=float32_ref>,
 <tf.Variable 'block1_conv1_1/bias:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'block1_conv2_1/kernel:0' shape=(3, 3, 64, 64) dtype=float32_ref>,
 <tf.Variable 'block1_conv2_1/bias:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'block2_conv1_1/kernel:0' shape=(3, 3, 64, 128) dtype=float32_ref>,
 <tf.Variable 'block2_conv1_1/bias:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'block2_conv2_1/kernel:0' shape=(3, 3, 128, 128) dtype=float32_ref>,
 <tf.Variable 'block2_conv2_1/bias:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'block3_conv1_1/kernel:0' shape=(3, 3, 128, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv1_1/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block3_conv2_1/kernel:0' shape=(3, 3, 256, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv2_1/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block3_conv3_1/kernel:0' shape=(3, 3, 256, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv3_1/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block4_conv1_1/kernel:0' shape=(3, 3, 256, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv1_1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block4_conv2_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv2_1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block4_conv3_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv3_1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv1_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv1_1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv2_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv2_1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv3_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv3_1/bias:0' shape=(512,) dtype=float32_ref>]
# Set trainable=False
# base_mode.trainable = False also seems to work
for layer in base_mode.layers:
  layer.trainable = False

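After setting trainable=False, listing the trainable variables again shows no change at all, which means the setting had no effect on TensorFlow's side

# List the trainable variables again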
tf.trainable_variables()

[<tf.Variable 'block1_conv1/kernel:0' shape=(3, 3, 3, 64) dtype=float32_ref>,
 <tf.Variable 'block1_conv1/bias:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'block1_conv2/kernel:0' shape=(3, 3, 64, 64) dtype=float32_ref>,
 <tf.Variable 'block1_conv2/bias:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'block2_conv1/kernel:0' shape=(3, 3, 64, 128) dtype=float32_ref>,
 <tf.Variable 'block2_conv1/bias:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'block2_conv2/kernel:0' shape=(3, 3, 128, 128) dtype=float32_ref>,
 <tf.Variable 'block2_conv2/bias:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'block3_conv1/kernel:0' shape=(3, 3, 128, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv1/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block3_conv2/kernel:0' shape=(3, 3, 256, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv2/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block3_conv3/kernel:0' shape=(3, 3, 256, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv3/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block4_conv1/kernel:0' shape=(3, 3, 256, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block4_conv2/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv2/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block4_conv3/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv3/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv2/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv2/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv3/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv3/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block1_conv1_1/kernel:0' shape=(3, 3, 3, 64) dtype=float32_ref>,
 <tf.Variable 'block1_conv1_1/bias:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'block1_conv2_1/kernel:0' shape=(3, 3, 64, 64) dtype=float32_ref>,
 <tf.Variable 'block1_conv2_1/bias:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'block2_conv1_1/kernel:0' shape=(3, 3, 64, 128) dtype=float32_ref>,
 <tf.Variable 'block2_conv1_1/bias:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'block2_conv2_1/kernel:0' shape=(3, 3, 128, 128) dtype=float32_ref>,
 <tf.Variable 'block2_conv2_1/bias:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'block3_conv1_1/kernel:0' shape=(3, 3, 128, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv1_1/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block3_conv2_1/kernel:0' shape=(3, 3, 256, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv2_1/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block3_conv3_1/kernel:0' shape=(3, 3, 256, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv3_1/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block4_conv1_1/kernel:0' shape=(3, 3, 256, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv1_1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block4_conv2_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv2_1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block4_conv3_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv3_1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv1_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv1_1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv2_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv2_1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv3_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv3_1/bias:0' shape=(512,) dtype=float32_ref>]
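
One way to see why the list does not change: in TF 1.x a variable is registered in the graph's TRAINABLE_VARIABLES collection once, when it is created, whereas Keras's layer.trainable is a flag on the Python layer object that Keras consults when it builds its own trainable_weights. The toy model below imitates that split in plain Python. ToyGraph and ToyLayer are hypothetical stand-ins, not TensorFlow or Keras classes, so treat this as a sketch of the mechanism rather than the real implementation.

```python
# Toy model (no TensorFlow needed): hypothetical stand-ins for a TF 1.x graph
# collection and a Keras layer, to show why the flag and the collection diverge.

class ToyGraph:
    """Imitates the graph-level TRAINABLE_VARIABLES collection."""

    def __init__(self):
        self.trainable_collection = []

    def create_variable(self, name):
        # Registration happens once, at creation time.
        self.trainable_collection.append(name)
        return name


class ToyLayer:
    """Imitates a Keras layer that owns variables and a `trainable` flag."""

    def __init__(self, graph, name):
        self.trainable = True
        self._weights = [
            graph.create_variable(name + "/kernel:0"),
            graph.create_variable(name + "/bias:0"),
        ]

    @property
    def trainable_weights(self):
        # The flag is only consulted here, on the layer side.
        return list(self._weights) if self.trainable else []


graph = ToyGraph()
layer = ToyLayer(graph, "block1_conv1")
before = list(graph.trainable_collection)

layer.trainable = False  # what the post does for every VGG16 layer

layer_view = layer.trainable_weights     # layer-side view: now empty
graph_view = graph.trainable_collection  # graph-side view: unchanged
print(layer_view)
print(graph_view == before)
```

An optimizer that falls back to the graph collection therefore still sees every VGG16 variable, which matches the unchanged listing above.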

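The solution

The fix is to create the imported model inside its own variable_scope, put the variables that need training in a different variable_scope, fetch the variables to train with tf.get_collection, and finally pass them to the tf optimizer through var_list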

from keras import models
with tf.variable_scope('base_model'):
  base_model = VGG16(include_top=False, input_shape=(224,224,3))
with tf.variable_scope('xxx'):
  model = models.Sequential()
  model.add(base_model)
  model.add(layers.Flatten())
  model.add(layers.Dense(10))
# Get the variables that should be trained
trainable_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'xxx')
trainable_var

[<tf.Variable 'xxx_2/dense_1/kernel:0' shape=(25088, 10) dtype=float32_ref>,
<tf.Variable 'xxx_2/dense_1/bias:0' shape=(10,) dtype=float32_ref>]
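
Note that the scope argument 'xxx' also picked up variables named xxx_2/... (the suffix presumably comes from re-running the code in the same graph). As far as I know, get_collection in TF 1.x filters items by applying re.match to each item's name, so the scope behaves like a regular-expression prefix. The sketch below reproduces that filtering in plain Python. filter_by_scope is a hypothetical helper, and the base-model variable name is assumed for illustration.

```python
import re


def filter_by_scope(names, scope):
    """Keep the names that `scope` matches from the start (regex prefix match)."""
    pattern = re.compile(scope)
    return [n for n in names if pattern.match(n)]


names = [
    "base_model/block1_conv1/kernel:0",  # assumed name, for illustration only
    "xxx_2/dense_1/kernel:0",            # taken from the output above
    "xxx_2/dense_1/bias:0",
]

selected = filter_by_scope(names, "xxx")
print(selected)

# A stricter pattern such as "xxx/" no longer matches the "xxx_2" scope.
strict = filter_by_scope(names, "xxx/")
print(strict)
```

This also means the scope names should not share a prefix: a frozen scope called 'xxx_base' would be selected by 'xxx' as well.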

# Define a tf optimizer for training; assume we have some loss here
loss = model.output / 2  # arbitrary definition, just for demonstration
train_step = tf.train.AdamOptimizer().minimize(loss, var_list=trainable_var)
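
If the base model has already been built without a dedicated variable_scope, a similar var_list can be assembled by exclusion: take every trainable variable and drop the ones that belong to the frozen model. The sketch below shows only the selection logic on stand-in objects. exclude_frozen and FakeVar are hypothetical names, and in real TF 1.x code the two arguments would presumably be tf.trainable_variables() and base_model.weights.

```python
from collections import namedtuple

# Stand-in for tf.Variable; only the `.name` attribute is used here.
FakeVar = namedtuple("FakeVar", ["name"])


def exclude_frozen(all_vars, frozen_vars):
    """Return the variables in all_vars whose names are not in frozen_vars."""
    frozen_names = {v.name for v in frozen_vars}
    return [v for v in all_vars if v.name not in frozen_names]


frozen = [FakeVar("block1_conv1/kernel:0"), FakeVar("block1_conv1/bias:0")]
head = [FakeVar("dense_1/kernel:0"), FakeVar("dense_1/bias:0")]

var_list = exclude_frozen(frozen + head, frozen)
print([v.name for v in var_list])
```

The resulting list can then be passed as var_list to minimize, in the same way as trainable_var above.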

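Summary

When mixing Keras and TensorFlow, setting trainable=False in Keras has no effect as far as TensorFlow's optimizers are concerned

The fix is to separate the variables with variable_scope, fetch the variables to train with tf.get_collection, and then tell the tf optimizer what to train through var_list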
