TensorFlow入门使用 tf.train.Saver()保存模型


Posted in Python onApril 24, 2018

关于模型保存的一点心得

saver = tf.train.Saver(max_to_keep=3)

在定义 saver 的时候一般会定义最多保存模型的数量,一般来说,如果模型本身很大,我们需要考虑到硬盘大小。如果你需要在当前训练好的模型的基础上进行 fine-tune,那么尽可能多的保存模型,后继 fine-tune 不一定从最好的 ckpt 进行,因为有可能一下子就过拟合了。但是如果保存太多,硬盘也有压力呀。如果只想保留最好的模型,方法就是每次迭代到一定步数就在验证集上计算一次 accuracy 或者 f1 值,如果本次结果比上次好才保存新的模型,否则没必要保存。

如果你想用不同 epoch 保存下来的模型进行融合的话,3到5 个模型已经足够了,假设这各融合的模型成为 M,而最好的一个单模型称为 m_best, 这样融合的话对于M 确实可以比 m_best 更好。但是如果拿这个模型和其他结构的模型再做融合的话,M 的效果并没有 m_best 好,因为M 相当于做了平均操作,减少了该模型的“特性”。

但是又有一种新的融合方式,就是利用调整学习率来获取多个局部最优点,就是当 loss 降不下了,保存一个 ckpt, 然后开大学习率继续寻找下一个局部最优点,然后用这些 ckpt 来做融合,还没试过,单模型肯定是有提高的,就是不知道还会不会出现上面再与其他模型融合就没提高的情况。

如何使用 tf.train.Saver() 来保存模型

之前一直出错,主要是因为坑爹的编码问题。所以要注意文件的路径绝对不不要出现什么中文呀。

import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Create some variables.
v1 = tf.Variable([1.0, 2.3], name="v1")
v2 = tf.Variable(55.5, name="v2")

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

ckpt_path = './ckpt/test-model.ckpt'
# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
sess.run(init_op)
save_path = saver.save(sess, ckpt_path, global_step=1)
print("Model saved in file: %s" % save_path)

Model saved in file: ./ckpt/test-model.ckpt-1

注意,在上面保存完了模型之后。应该把 kernel restart 之后才能使用下面的模型导入。否则会因为两次命名 “v1” 而导致名字错误。

import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Create some variables.
v1 = tf.Variable([11.0, 16.3], name="v1")
v2 = tf.Variable(33.5, name="v2")

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
# Restore variables from disk.
ckpt_path = './ckpt/test-model.ckpt'
saver.restore(sess, ckpt_path + '-'+ str(1))
print("Model restored.")

print sess.run(v1)
print sess.run(v2)

INFO:tensorflow:Restoring parameters from ./ckpt/test-model.ckpt-1
Model restored.
[ 1.          2.29999995]
55.5

导入模型之前,必须重新再定义一遍变量。

但是并不需要全部变量都重新进行定义,只定义我们需要的变量就行了。

也就是说,你所定义的变量一定要在 checkpoint 中存在;但不是所有在checkpoint中的变量,你都要重新定义。

import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Create some variables.
v1 = tf.Variable([11.0, 16.3], name="v1")

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
# Restore variables from disk.
ckpt_path = './ckpt/test-model.ckpt'
saver.restore(sess, ckpt_path + '-'+ str(1))
print("Model restored.")

print sess.run(v1)

INFO:tensorflow:Restoring parameters from ./ckpt/test-model.ckpt-1
Model restored.
[ 1.          2.29999995]

tf.Saver([tensors_to_be_saved]) 中可以传入一个 list,把要保存的 tensors 传入,如果没有给定这个list的话,他会默认保存当前所有的 tensors。一般来说,tf.Saver 可以和 tf.variable_scope() 巧妙搭配,可以参考: 【迁移学习】往一个已经保存好的模型添加新的变量并进行微调

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python深入学习之内存管理
Aug 31 Python
跟老齐学Python之总结参数的传递
Oct 10 Python
Python实现的使用telnet登陆聊天室实例
Jun 17 Python
python函数局部变量用法实例分析
Aug 04 Python
python类中super()和__init__()的区别
Oct 18 Python
Python 实现域名解析为ip的方法
Feb 14 Python
PyCharm安装Markdown插件的两种方法
Jun 24 Python
python BlockingScheduler定时任务及其他方式的实现
Sep 19 Python
Python通过TensorFLow进行线性模型训练原理与实现方法详解
Jan 15 Python
Python文件操作基础流程解析
Mar 19 Python
Python selenium爬取微博数据代码实例
May 22 Python
python爬虫泛滥的解决方法详解
Nov 25 Python
Python使用 Beanstalkd 做异步任务处理的方法
Apr 24 #Python
Windows上使用Python增加或删除权限的方法
Apr 24 #Python
python编写暴力破解zip文档程序的实例讲解
Apr 24 #Python
解决python删除文件的权限错误问题
Apr 24 #Python
python3+PyQt5实现自定义流体混合窗口部件
Apr 24 #Python
python3+PyQt5实现拖放功能
Apr 24 #Python
python3+PyQt5使用数据库表视图
Apr 24 #Python
You might like
如何冲泡挂耳包咖啡?技巧是什么
2021/03/04 冲泡冲煮
php检测用户是否用手机(Mobile)访问网站的类
2014/01/09 PHP
5款适合PHP使用的HTML编辑器推荐
2015/07/03 PHP
WordPress中用于获取搜索表单的PHP函数使用解析
2016/01/05 PHP
获取css样式表内样式的js函数currentStyle(IE),defaultView(FF)
2011/02/14 Javascript
ajax处理php返回json数据的实例代码
2013/01/24 Javascript
js中自定义方法实现停留几秒sleep
2014/07/11 Javascript
js改变embed标签src值的方法
2015/04/10 Javascript
Node.js本地文件操作之文件拷贝与目录遍历的方法
2016/02/16 Javascript
js图片上传前预览功能(兼容所有浏览器)
2016/08/24 Javascript
D3.js实现散点图和气泡图的方法详解
2016/09/21 Javascript
详解照片瀑布流效果(js,jquery分别实现与知识点总结)
2017/01/01 Javascript
用jQuery旋转插件jqueryrotate制作转盘抽奖
2017/02/10 Javascript
使用electron实现百度网盘悬浮窗口功能的示例代码
2018/10/24 Javascript
vue+element实现打印页面功能
2019/05/20 Javascript
Laravel admin实现消息提醒、播放音频功能
2019/07/10 Javascript
vue下使用nginx刷新页面404的问题解决
2019/08/02 Javascript
ES6的异步操作之promise用法和async函数的具体使用
2019/12/06 Javascript
如何实现小程序与小程序之间的跳转
2020/11/04 Javascript
Python中 Lambda表达式全面解析
2016/11/28 Python
Python聊天室程序(基础版)
2018/04/01 Python
Jupyter中直接显示Matplotlib的图形方法
2018/05/24 Python
Django REST framework 分页的实现代码
2019/06/19 Python
Pandas之MultiIndex对象的示例详解
2019/06/25 Python
win10安装tesserocr配置 Python使用tesserocr识别字母数字验证码
2020/01/16 Python
python使用正则表达式去除中文文本多余空格,保留英文之间空格方法详解
2020/02/11 Python
python cookie反爬处理的实现
2020/11/01 Python
python中reload重载实例用法
2020/12/15 Python
关于HTML5 Placeholder新标签低版本浏览器下不兼容的问题分析及解决办法
2016/01/27 HTML / CSS
研讨会主持词
2014/04/02 职场文书
2014年寒假社会实践活动心得体会
2014/04/07 职场文书
趣味运动会策划方案
2014/06/02 职场文书
经典团队口号大全
2014/06/21 职场文书
2015年大学学生会工作总结
2015/05/13 职场文书
《作风建设永远在路上》心得体会
2016/01/21 职场文书
Go中的条件语句Switch示例详解
2021/08/23 Golang