TensorFlow入门使用 tf.train.Saver()保存模型


Posted in Python onApril 24, 2018

关于模型保存的一点心得

saver = tf.train.Saver(max_to_keep=3)

在定义 saver 的时候一般会定义最多保存模型的数量,一般来说,如果模型本身很大,我们需要考虑到硬盘大小。如果你需要在当前训练好的模型的基础上进行 fine-tune,那么尽可能多的保存模型,后继 fine-tune 不一定从最好的 ckpt 进行,因为有可能一下子就过拟合了。但是如果保存太多,硬盘也有压力呀。如果只想保留最好的模型,方法就是每次迭代到一定步数就在验证集上计算一次 accuracy 或者 f1 值,如果本次结果比上次好才保存新的模型,否则没必要保存。

如果你想用不同 epoch 保存下来的模型进行融合的话,3到5 个模型已经足够了,假设这各融合的模型成为 M,而最好的一个单模型称为 m_best, 这样融合的话对于M 确实可以比 m_best 更好。但是如果拿这个模型和其他结构的模型再做融合的话,M 的效果并没有 m_best 好,因为M 相当于做了平均操作,减少了该模型的“特性”。

但是又有一种新的融合方式,就是利用调整学习率来获取多个局部最优点,就是当 loss 降不下了,保存一个 ckpt, 然后开大学习率继续寻找下一个局部最优点,然后用这些 ckpt 来做融合,还没试过,单模型肯定是有提高的,就是不知道还会不会出现上面再与其他模型融合就没提高的情况。

如何使用 tf.train.Saver() 来保存模型

之前一直出错,主要是因为坑爹的编码问题。所以要注意文件的路径绝对不不要出现什么中文呀。

import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Create some variables.
v1 = tf.Variable([1.0, 2.3], name="v1")
v2 = tf.Variable(55.5, name="v2")

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

ckpt_path = './ckpt/test-model.ckpt'
# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
sess.run(init_op)
save_path = saver.save(sess, ckpt_path, global_step=1)
print("Model saved in file: %s" % save_path)

Model saved in file: ./ckpt/test-model.ckpt-1

注意,在上面保存完了模型之后。应该把 kernel restart 之后才能使用下面的模型导入。否则会因为两次命名 “v1” 而导致名字错误。

import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Create some variables.
v1 = tf.Variable([11.0, 16.3], name="v1")
v2 = tf.Variable(33.5, name="v2")

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
# Restore variables from disk.
ckpt_path = './ckpt/test-model.ckpt'
saver.restore(sess, ckpt_path + '-'+ str(1))
print("Model restored.")

print sess.run(v1)
print sess.run(v2)

INFO:tensorflow:Restoring parameters from ./ckpt/test-model.ckpt-1
Model restored.
[ 1.          2.29999995]
55.5

导入模型之前,必须重新再定义一遍变量。

但是并不需要全部变量都重新进行定义,只定义我们需要的变量就行了。

也就是说,你所定义的变量一定要在 checkpoint 中存在;但不是所有在checkpoint中的变量,你都要重新定义。

import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Create some variables.
v1 = tf.Variable([11.0, 16.3], name="v1")

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
# Restore variables from disk.
ckpt_path = './ckpt/test-model.ckpt'
saver.restore(sess, ckpt_path + '-'+ str(1))
print("Model restored.")

print sess.run(v1)

INFO:tensorflow:Restoring parameters from ./ckpt/test-model.ckpt-1
Model restored.
[ 1.          2.29999995]

tf.Saver([tensors_to_be_saved]) 中可以传入一个 list,把要保存的 tensors 传入,如果没有给定这个list的话,他会默认保存当前所有的 tensors。一般来说,tf.Saver 可以和 tf.variable_scope() 巧妙搭配,可以参考: 【迁移学习】往一个已经保存好的模型添加新的变量并进行微调

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python切片用法实例教程
Sep 08 Python
Python的Flask框架中实现分页功能的教程
Apr 20 Python
python logging 日志轮转文件不删除问题的解决方法
Aug 02 Python
一份python入门应该看的学习资料
Apr 11 Python
Python 使用PIL numpy 实现拼接图片的示例
May 08 Python
教你一步步利用python实现贪吃蛇游戏
Jun 27 Python
pandas对dataFrame中某一个列的数据进行处理的方法
Jul 08 Python
python移位运算的实现
Jul 15 Python
基于python代码批量处理图片resize
Jun 04 Python
Python创建简单的神经网络实例讲解
Jan 04 Python
Python 打印自己设计的字体的实例讲解
Jan 04 Python
pycharm 如何查看某一函数源码的快捷键
May 12 Python
Python使用 Beanstalkd 做异步任务处理的方法
Apr 24 #Python
Windows上使用Python增加或删除权限的方法
Apr 24 #Python
python编写暴力破解zip文档程序的实例讲解
Apr 24 #Python
解决python删除文件的权限错误问题
Apr 24 #Python
python3+PyQt5实现自定义流体混合窗口部件
Apr 24 #Python
python3+PyQt5实现拖放功能
Apr 24 #Python
python3+PyQt5使用数据库表视图
Apr 24 #Python
You might like
PHP is_dir() 判断给定文件名是否是一个目录
2010/05/10 PHP
PHP中copy on write写时复制机制介绍
2014/05/13 PHP
PHP代码维护,重构变困难的4种原因分析
2016/01/25 PHP
php通过会话控制实现身份验证实例
2016/10/18 PHP
禁止ajax缓存获取程序最新数据的方法
2013/11/19 Javascript
动态添加删除表格行的js实现代码
2014/02/28 Javascript
Jquery实现自定义弹窗示例
2014/03/12 Javascript
javascript将浮点数转换成整数的三个方法
2014/06/23 Javascript
jquery+CSS实现的多级竖向展开树形TRee菜单效果
2015/08/24 Javascript
CSS或者JS实现鼠标悬停显示另一元素
2016/01/22 Javascript
基于jquery实现三级下拉菜单
2016/05/10 Javascript
在JavaScript中模拟类(class)及类的继承关系
2016/05/20 Javascript
JS扩展类,克隆对象与混合类实例分析
2016/11/26 Javascript
JavaScript复制内容到剪贴板的两种常用方法
2018/02/27 Javascript
微信小程序实现保存图片到相册功能
2018/11/30 Javascript
vue-cli4.x创建企业级项目的方法步骤
2020/06/18 Javascript
[02:45]DOTA2英雄敌法师基础教程
2013/11/25 DOTA
pycharm 使用心得(一)安装和首次使用
2014/06/05 Python
python使用Queue在多个子进程间交换数据的方法
2015/04/18 Python
在Python中处理字符串之isdecimal()方法的使用
2015/05/20 Python
python导入时小括号大作用
2017/01/10 Python
Windows 8.1 64bit下搭建 Scrapy 0.22 环境
2018/11/18 Python
python celery分布式任务队列的使用详解
2019/07/08 Python
Pytorch卷积层手动初始化权值的实例
2019/08/17 Python
Python Django 添加首页尾页上一页下一页代码实例
2019/08/21 Python
Django中的cookie和session
2019/08/27 Python
解决Tensorboard可视化错误:不显示数据 No scalar data was found
2020/02/15 Python
浅谈matplotlib中FigureCanvasXAgg的用法
2020/06/16 Python
学python需要去培训机构吗
2020/07/01 Python
Stella McCartney官网:成衣、包袋、香水、内衣、童装及Adidas系列
2018/12/20 全球购物
意大利奢侈品零售商:ilDuomo Novara
2019/09/11 全球购物
个人找工作的自我评价
2013/10/17 职场文书
大学生学习生活的自我评价
2013/11/01 职场文书
公司联欢会主持词
2015/07/04 职场文书
如何利用map实现Nginx允许多个域名跨域
2021/03/31 Servers
Python MNIST手写体识别详解与试练
2021/11/07 Python