解决tensorflow模型参数保存和加载的问题


Posted in Python onJuly 26, 2018

终于找到bug原因!记一下;还是不熟悉平台的原因造成的!

Q:为什么会出现两个模型对象在同一个文件中一起运行,当直接读取他们分开运行时训练出来的模型会出错,而且总是有一个正确,一个读取错误? 而 直接在同一个文件又训练又重新加载模型预测不出错,而且更诡异的是此时用分文件里的对象加载模型不会出错?

model.py,里面含有 ModelV 和 ModelP,另外还有 modelP.py 和 modelV.py 分别只含有 ModelP 和 ModeV 这两个对象,先使用 modelP.py 和 modelV.py 分别训练好模型,然后再在 model.py 里加载进来:

# -*- coding: utf8 -*-

import tensorflow as tf

class ModelV():

 def __init__(self):

  self.v1 = tf.Variable(66, name="v1")
  self.v2 = tf.Variable(77, name="v2")
  self.save_path = "model_v/model.ckpt"
  self.init = tf.global_variables_initializer()
  self.saver = tf.train.Saver()
  self.sess = tf.Session()

 def train(self):
  self.sess.run(self.init)
  print 'v2', self.v2.eval(self.sess)

  self.saver.save(self.sess, self.save_path)
  print "ModelV saved."

 def predict(self):

  all_vars = tf.trainable_variables()
  for v in all_vars:
   print(v.name)
  self.saver.restore(self.sess, self.save_path)
  print "ModelV restored."
  print 'v2', self.v2.eval(self.sess)
  print '------------------------------------------------------------------'

class ModelP():

 def __init__(self):

  self.p1 = tf.Variable(88, name="p1")
  self.p2 = tf.Variable(99, name="p2")
  self.save_path = "model_p/model.ckpt"
  self.init = tf.global_variables_initializer()
  self.saver = tf.train.Saver()
  self.sess = tf.Session()

 def train(self):
  self.sess.run(self.init)
  print 'p2', self.p2.eval(self.sess)

  self.saver.save(self.sess, self.save_path)
  print "ModelP saved."

 def predict(self):

  all_vars = tf.trainable_variables()
  for v in all_vars:
   print v.name
  self.saver.restore(self.sess, self.save_path)
  print "ModelP restored."
  print 'p2', self.p2.eval(self.sess)
  print '---------------------------------------------------------------------'


if __name__ == '__main__':
 v = ModelV()
 p = ModelP()
 v.predict()
 #v.train()
 p.predict() 
 #p.train()

这里 tf.global_variables_initializer() 很关键! 尽管你是分别在对象 ModelP 和 ModelV 内部分配和定义的 tf.Variable(),即 v1 v2 和 p1 p2,但是 对 tf 这个模块而言, 这些都是全局变量,可以通过以下代码查看所有的变量,你就会发现同一个文件中同时运行 ModelP 和 ModelV 在初始化之后都打印出了一样的变量,这个是问题的关键所在:

all_vars = tf.trainable_variables()
for v in all_vars:
 print(v.name)

错误。你可以交换 modelP 和 modelV 初始化的顺序,看看错误信息的变化

v1:0
v2:0
p1:0
p2:0
ModelV restored.
v2 77
v1:0
v2:0
p1:0
p2:0
W tensorflow/core/framework/op_kernel.cc:975] Not found: Key v2 not found in checkpoint
W tensorflow/core/framework/op_kernel.cc:975] Not found: Key v1 not found in checkpoint

实际上,分开运行时,模型保存的参数是正确的,因为在一个模型里的Variable就只有 v1 v2 或者 p1 p2; 但是在一个文件同时运行的时候,模型参数实际上保存的是 v1 v2 p1 p2四个,因为在默认情况下,创建的Saver,会直接保存所有的参数。而 Saver.restore() 又是默认(无Variable参数列表时)按照已经定义好的全局模型变量来加载对应的参数值, 在进行 ModelV.predict时,按照顺序(从debug可以看出,应该是按照参数顺序一次检测)在模型文件中查找相应的 key,此时能够找到对应的v1 v2,加载成功,但是在 ModelP.predict时,在model_p的模型文件中找不到 v1 和 v2,只有 p1 和 p2, 此时就会报错;不过这里的 第一次加载 还有 p1 p2 找不到没有报错,解释不通, 未完待续

Saver.save() 和 Saver.restore() 是一对, 分别只保存和加载模型的参数, 但是模型的结构怎么知道呢? 必须是你定义好了,而且要和保存的模型匹配才能加载;

如果想要在不定义模型的情况下直接加载出模型结构和模型参数值,使用

# 加载 结构,即 模型参数 变量等
new_saver = tf.train.import_meta_graph("model_v/model.ckpt.meta")
print "ModelV construct"
all_vars = tf.trainable_variables()
for v in all_vars:
 print v.name
 #print v.name,v.eval(self.sess) # v 都还未初始化,不能求值
# 加载模型 参数变量 的 值
new_saver.restore(self.sess, tf.train.latest_checkpoint('model_v/'))
print "ModelV restored."
all_vars = tf.trainable_variables()
for v in all_vars:
 print v.name,v.eval(self.sess)

加载 结构,即 模型参数 变量等完成后,就会有变量了,但是不能访问他的值,因为还未赋值,然后再restore一次即可得到值了

那么上述错误的解决方法就是这个改进版本的model.py;其实 tf.train.Saver 是可以带参数的,他可以保存你想要保存的模型参数,如果不带参数,很可能就会保存 tf.trainable_variables() 所有的variable,而 tf.trainable_variables()又是从 tf 全局得到的,因此只要在模型保存和加载时,构造对应的带参数的tf.train.Saver即可,这样就会保存和加载正确的模型了

# -*- coding: utf8 -*-

import tensorflow as tf

class ModelV():

 def __init__(self):

  self.v1 = tf.Variable(66, name="v1")
  self.v2 = tf.Variable(77, name="v2")
  self.save_path = "model_v/model.ckpt"
  self.init = tf.global_variables_initializer()

  self.sess = tf.Session()

 def train(self):
  saver = tf.train.Saver([self.v1, self.v2])
  self.sess.run(self.init)
  print 'v2', self.v2.eval(self.sess)

  saver.save(self.sess, self.save_path)
  print "ModelV saved."

 def predict(self):
  saver = tf.train.Saver([self.v1, self.v2])
  all_vars = tf.trainable_variables()
  for v in all_vars:
   print v.name

  v_vars = [v for v in all_vars if v.name == 'v1:0' or v.name == 'v2:0']
  print "ModelV restored."
  saver.restore(self.sess, self.save_path)
  for v in v_vars:
   print v.name,v.eval(self.sess) 
  print 'v2', self.v2.eval(self.sess)
  print '------------------------------------------------------------------'

class ModelP():

 def __init__(self):

  self.p1 = tf.Variable(88, name="p1")
  self.p2 = tf.Variable(99, name="p2")
  self.save_path = "model_p/model.ckpt"
  self.init = tf.global_variables_initializer()
  self.sess = tf.Session()

 def train(self):
  saver = tf.train.Saver([self.p1, self.p2])
  self.sess.run(self.init)
  print 'p2', self.p2.eval(self.sess)

  saver.save(self.sess, self.save_path)
  print "ModelP saved."

 def predict(self):
  saver = tf.train.Saver([self.p1, self.p2])
  all_vars = tf.trainable_variables()
  p_vars = [v for v in all_vars if v.name == 'p1:0' or v.name == 'p2:0']
  for v in all_vars:
   print v.name
   #print v.name,v.eval(self.sess)
  saver.restore(self.sess, self.save_path)
  print "ModelP restored."
  for p in p_vars:
   print p.name,p.eval(self.sess)
  print 'p2', self.p2.eval(self.sess)
  print '----------------------------------------------------------'


if __name__ == '__main__':
 v = ModelV()
 p = ModelP()
 v.predict()
 #v.train()
 p.predict() 
 #p.train()

小结: 构造的Saver 最好带Variable参数,这样保证 保存和加载能够正确执行

以上这篇解决tensorflow模型参数保存和加载的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python比较两个图片相似度的方法
Mar 13 Python
Python3.6使用tesseract-ocr的正确方法
Oct 17 Python
python实现多张图片拼接成大图
Jan 15 Python
Python中super函数用法实例分析
Mar 18 Python
python itchat给指定联系人发消息的方法
Jun 11 Python
django框架自定义模板标签(template tag)操作示例
Jun 24 Python
pycharm 批量修改变量名称的方法
Aug 01 Python
python中使用while循环的实例
Aug 05 Python
解决pycharm启动后总是不停的updating indices...indexing的问题
Nov 27 Python
keras 解决加载lstm+crf模型出错的问题
Jun 10 Python
Python如何把字典写入到CSV文件的方法示例
Aug 23 Python
Python GUI库Tkiner使用方法代码示例
Nov 27 Python
解决tensorflow1.x版本加载saver.restore目录报错的问题
Jul 26 #Python
Flask web开发处理POST请求实现(登录案例)
Jul 26 #Python
基于tensorflow加载部分层的方法
Jul 26 #Python
利用python画出折线图
Jul 26 #Python
浅谈flask源码之请求过程
Jul 26 #Python
python画折线图的程序
Jul 26 #Python
TensorFlow利用saver保存和提取参数的实例
Jul 26 #Python
You might like
PHP实现远程下载文件到本地
2015/05/17 PHP
php结合md5的加密解密算法实例
2016/09/30 PHP
php关联数组与索引数组及其显示方法
2018/03/12 PHP
Yii2框架自定义验证规则操作示例
2019/02/08 PHP
vs2003 js文件编码问题的解决方法
2010/03/20 Javascript
JQuery EasyUI 数字格式化处理示例
2014/05/05 Javascript
JQuery $.each遍历JavaScript数组对象实例
2014/09/01 Javascript
JQuery中DOM事件绑定用法详解
2015/06/13 Javascript
对JavaScript客户端应用编程的一些建议
2015/06/24 Javascript
AngularJs实现ng1.3+表单验证
2015/12/10 Javascript
全面解析Angular中$Apply()及$Digest()的区别
2016/08/04 Javascript
javascript cookie基础应用之记录用户名的方法
2016/09/20 Javascript
javascript 操作cookies详解及实例
2017/02/22 Javascript
JS实现数组去重方法总结(六种方法)
2017/07/14 Javascript
详解JavaScript中的坐标和距离
2019/05/27 Javascript
使用JS location实现搜索框历史记录功能
2019/12/23 Javascript
基于JavaScript判断两个对象内容是否相等
2020/01/10 Javascript
JavaScript数组去重实现方法小结
2020/01/17 Javascript
vue-socket.io接收不到数据问题的解决方法
2020/05/13 Javascript
nestjs返回给前端数据格式的封装实现
2021/02/22 Javascript
[04:32]DOTA2著名解说配音敌法师 现场专访海涛怒切假腿
2013/12/20 DOTA
[03:05]DOTA2英雄基础教程 嗜血狂魔
2013/12/10 DOTA
获取python文件扩展名和文件名方法
2018/02/02 Python
python os.path模块常用方法实例详解
2018/09/16 Python
NLTK 3.2.4 环境搭建教程
2018/09/19 Python
使用python将请求的requests headers参数格式化方法
2019/01/02 Python
python实例化对象的具体方法
2020/06/17 Python
Python装饰器如何实现修复过程解析
2020/09/05 Python
详解CSS3选择器:nth-child和:nth-of-type之间的差异
2017/09/18 HTML / CSS
css3 column实现卡片瀑布流布局的示例代码
2018/06/22 HTML / CSS
Myprotein加拿大官网:欧洲第一的运动营养品牌
2018/01/06 全球购物
入党积极分子思想汇报范文
2014/01/05 职场文书
中秋节作文(五年级)之关于月亮
2019/09/11 职场文书
HTML5中 rem适配方案与 viewport 适配问题详解
2021/04/27 HTML / CSS
《游戏王:大师决斗》新活动上线 若无符合卡组可免费租用
2022/04/13 其他游戏
详解flex:1什么意思
2022/07/23 HTML / CSS