TensorFlow入门使用 tf.train.Saver()保存模型


Posted in Python onApril 24, 2018

关于模型保存的一点心得

saver = tf.train.Saver(max_to_keep=3)

在定义 saver 的时候一般会定义最多保存模型的数量,一般来说,如果模型本身很大,我们需要考虑到硬盘大小。如果你需要在当前训练好的模型的基础上进行 fine-tune,那么尽可能多的保存模型,后继 fine-tune 不一定从最好的 ckpt 进行,因为有可能一下子就过拟合了。但是如果保存太多,硬盘也有压力呀。如果只想保留最好的模型,方法就是每次迭代到一定步数就在验证集上计算一次 accuracy 或者 f1 值,如果本次结果比上次好才保存新的模型,否则没必要保存。

如果你想用不同 epoch 保存下来的模型进行融合的话,3到5 个模型已经足够了,假设这各融合的模型成为 M,而最好的一个单模型称为 m_best, 这样融合的话对于M 确实可以比 m_best 更好。但是如果拿这个模型和其他结构的模型再做融合的话,M 的效果并没有 m_best 好,因为M 相当于做了平均操作,减少了该模型的“特性”。

但是又有一种新的融合方式,就是利用调整学习率来获取多个局部最优点,就是当 loss 降不下了,保存一个 ckpt, 然后开大学习率继续寻找下一个局部最优点,然后用这些 ckpt 来做融合,还没试过,单模型肯定是有提高的,就是不知道还会不会出现上面再与其他模型融合就没提高的情况。

如何使用 tf.train.Saver() 来保存模型

之前一直出错,主要是因为坑爹的编码问题。所以要注意文件的路径绝对不不要出现什么中文呀。

import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Create some variables.
v1 = tf.Variable([1.0, 2.3], name="v1")
v2 = tf.Variable(55.5, name="v2")

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

ckpt_path = './ckpt/test-model.ckpt'
# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
sess.run(init_op)
save_path = saver.save(sess, ckpt_path, global_step=1)
print("Model saved in file: %s" % save_path)

Model saved in file: ./ckpt/test-model.ckpt-1

注意,在上面保存完了模型之后。应该把 kernel restart 之后才能使用下面的模型导入。否则会因为两次命名 “v1” 而导致名字错误。

import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Create some variables.
v1 = tf.Variable([11.0, 16.3], name="v1")
v2 = tf.Variable(33.5, name="v2")

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
# Restore variables from disk.
ckpt_path = './ckpt/test-model.ckpt'
saver.restore(sess, ckpt_path + '-'+ str(1))
print("Model restored.")

print sess.run(v1)
print sess.run(v2)

INFO:tensorflow:Restoring parameters from ./ckpt/test-model.ckpt-1
Model restored.
[ 1.          2.29999995]
55.5

导入模型之前,必须重新再定义一遍变量。

但是并不需要全部变量都重新进行定义,只定义我们需要的变量就行了。

也就是说,你所定义的变量一定要在 checkpoint 中存在;但不是所有在checkpoint中的变量,你都要重新定义。

import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Create some variables.
v1 = tf.Variable([11.0, 16.3], name="v1")

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
# Restore variables from disk.
ckpt_path = './ckpt/test-model.ckpt'
saver.restore(sess, ckpt_path + '-'+ str(1))
print("Model restored.")

print sess.run(v1)

INFO:tensorflow:Restoring parameters from ./ckpt/test-model.ckpt-1
Model restored.
[ 1.          2.29999995]

tf.Saver([tensors_to_be_saved]) 中可以传入一个 list,把要保存的 tensors 传入,如果没有给定这个list的话,他会默认保存当前所有的 tensors。一般来说,tf.Saver 可以和 tf.variable_scope() 巧妙搭配,可以参考: 【迁移学习】往一个已经保存好的模型添加新的变量并进行微调

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
wxPython事件驱动实例详解
Sep 28 Python
Python连接MySQL并使用fetchall()方法过滤特殊字符
Mar 13 Python
Python 使用requests模块发送GET和POST请求的实现代码
Sep 21 Python
python实现闹钟定时播放音乐功能
Jan 25 Python
Django中使用Whoosh进行全文检索的方法
Mar 31 Python
Python中请不要再用re.compile了
Jun 30 Python
django foreignkey(外键)的实现
Jul 29 Python
通过python3实现投票功能代码实例
Sep 26 Python
Python通过4种方式实现进程数据通信
Mar 12 Python
解决pycharm导入本地py文件时,模块下方出现红色波浪线的问题
Jun 01 Python
利用python实时刷新基金估值(摸鱼小工具)
Sep 15 Python
基于Python实现流星雨效果的绘制
Mar 18 Python
Python使用 Beanstalkd 做异步任务处理的方法
Apr 24 #Python
Windows上使用Python增加或删除权限的方法
Apr 24 #Python
python编写暴力破解zip文档程序的实例讲解
Apr 24 #Python
解决python删除文件的权限错误问题
Apr 24 #Python
python3+PyQt5实现自定义流体混合窗口部件
Apr 24 #Python
python3+PyQt5实现拖放功能
Apr 24 #Python
python3+PyQt5使用数据库表视图
Apr 24 #Python
You might like
fleaphp下不确定的多条件查询的巧妙解决方法
2008/09/11 PHP
解析PHP中empty is_null和isset的测试
2013/06/29 PHP
使用PHP下载CSS文件中的图片的代码
2013/09/24 PHP
thinkphp实现图片上传功能分享
2014/03/04 PHP
再谈javascript图片预加载技术(详细演示)
2011/03/12 Javascript
使用Javascript写的2048小游戏
2015/11/25 Javascript
关于jquery中动态增加select,事件无效的快速解决方法
2016/08/29 Javascript
jquery删除table当前行的实例代码
2016/10/07 Javascript
javascript显示系统当前时间代码
2016/12/29 Javascript
利用Vue构造器创建Form组件的通用解决方法
2018/12/03 Javascript
VuePress 静态网站生成方法步骤
2019/02/14 Javascript
Vue商品控件与购物车联动效果的实例代码
2019/07/21 Javascript
微信小程序页面滚动到指定位置代码实例
2019/09/07 Javascript
js判断在哪个浏览器打开项目的方法
2020/01/21 Javascript
[46:14]VGJ.T vs Liquid 2018国际邀请赛小组赛BO2 第一场 8.19
2018/08/21 DOTA
[01:14]DOTA2 7.22版本新增神杖效果展示(智力英雄篇)
2019/05/29 DOTA
python更新列表的方法
2015/07/28 Python
python机器学习之随机森林(七)
2018/03/26 Python
Python hashlib模块用法实例分析
2018/06/12 Python
Python之用户输入的实例
2018/06/22 Python
Python简易版停车管理系统
2019/08/12 Python
基于python分析你的上网行为 看看你平时上网都在干嘛
2019/08/13 Python
Python性能测试工具Locust安装及使用
2020/12/01 Python
Jupyter Notebook添加代码自动补全功能的实现
2021/01/07 Python
CSS3区域模块region相关编写示例
2015/08/28 HTML / CSS
HTML5新增的8类INPUT输入类型介绍
2015/07/06 HTML / CSS
美国著名的户外用品品牌:L.L.Bean
2018/01/05 全球购物
莫斯科制造商的廉价皮大衣:Fursk
2020/06/09 全球购物
办公室主任职责范文
2013/11/08 职场文书
副厂长岗位职责
2014/02/02 职场文书
商场消防演习方案
2014/02/12 职场文书
副校长个人对照检查材料思想汇报
2014/10/04 职场文书
如何写好开幕词?
2019/06/24 职场文书
Apache压力测试工具的安装使用
2021/03/31 Servers
Python re.sub 反向引用的实现
2021/07/07 Python
quickjs 封装 JavaScript 沙箱详情
2021/11/02 Javascript