浅谈Tensorflow模型的保存与恢复加载


Posted in Python onApril 26, 2018

近期做了一些反垃圾的工作,除了使用常用的规则匹配过滤等手段,也采用了一些机器学习方法进行分类预测。我们使用TensorFlow进行模型的训练,训练好的模型需要保存,预测阶段我们需要将模型进行加载还原使用,这就涉及TensorFlow模型的保存与恢复加载。

总结一下Tensorflow常用的模型保存方式。

保存checkpoint模型文件(.ckpt)

首先,TensorFlow提供了一个非常方便的api,tf.train.Saver()来保存和还原一个机器学习模型。

模型保存

使用tf.train.Saver()来保存模型文件非常方便,下面是一个简单的例子:

import tensorflow as tf
import os

def save_model_ckpt(ckpt_file_path):
  x = tf.placeholder(tf.int32, name='x')
  y = tf.placeholder(tf.int32, name='y')
  b = tf.Variable(1, name='b')
  xy = tf.multiply(x, y)
  op = tf.add(xy, b, name='op_to_store')

  sess = tf.Session()
  sess.run(tf.global_variables_initializer())

  path = os.path.dirname(os.path.abspath(ckpt_file_path))
  if os.path.isdir(path) is False:
    os.makedirs(path)

  tf.train.Saver().save(sess, ckpt_file_path)
  
  # test
  feed_dict = {x: 2, y: 3}
  print(sess.run(op, feed_dict))

程序生成并保存四个文件(在版本0.11之前只会生成三个文件:checkpoint, model.ckpt, model.ckpt.meta)

  1. checkpoint 文本文件,记录了模型文件的路径信息列表
  2. model.ckpt.data-00000-of-00001 网络权重信息
  3. model.ckpt.index .data和.index这两个文件是二进制文件,保存了模型中的变量参数(权重)信息
  4. model.ckpt.meta 二进制文件,保存了模型的计算图结构信息(模型的网络结构)protobuf

以上是tf.train.Saver().save()的基本用法,save()方法还有很多可配置的参数:

tf.train.Saver().save(sess, ckpt_file_path, global_step=1000)

加上global_step参数代表在每1000次迭代后保存模型,会在模型文件后加上"-1000",model.ckpt-1000.index, model.ckpt-1000.meta, model.ckpt.data-1000-00000-of-00001

每1000次迭代保存一次模型,但是模型的结构信息文件不会变,就只用1000次迭代时保存一下,不用相应的每1000次保存一次,所以当我们不需要保存meta文件时,可以加上write_meta_graph=False参数,如下:

tf.train.Saver().save(sess, ckpt_file_path, global_step=1000, write_meta_graph=False)

如果想每两小时保存一次模型,并且只保存最新的4个模型,可以加上使用max_to_keep(默认值为5,如果想每训练一个epoch就保存一次,可以将其设置为None或0,但是没啥用不推荐), keep_checkpoint_every_n_hours参数,如下:

tf.train.Saver().save(sess, ckpt_file_path, max_to_keep=4, keep_checkpoint_every_n_hours=2)

同时在tf.train.Saver()类中,如果我们不指定任何信息,则会保存所有的参数信息,我们也可以指定部分想要保存的内容,例如只保存x, y参数(可传入参数list或dict):

tf.train.Saver([x, y]).save(sess, ckpt_file_path)

ps. 在模型训练过程中需要在保存后拿到的变量或参数名属性name不能丢,不然模型还原后不能通过get_tensor_by_name()获取。

模型加载还原

针对上面的模型保存例子,还原模型的过程如下:

import tensorflow as tf

def restore_model_ckpt(ckpt_file_path):
  sess = tf.Session()
  saver = tf.train.import_meta_graph('./ckpt/model.ckpt.meta') # 加载模型结构
  saver.restore(sess, tf.train.latest_checkpoint('./ckpt')) # 只需要指定目录就可以恢复所有变量信息

  # 直接获取保存的变量
  print(sess.run('b:0'))

  # 获取placeholder变量
  input_x = sess.graph.get_tensor_by_name('x:0')
  input_y = sess.graph.get_tensor_by_name('y:0')
  # 获取需要进行计算的operator
  op = sess.graph.get_tensor_by_name('op_to_store:0')

  # 加入新的操作
  add_on_op = tf.multiply(op, 2)

  ret = sess.run(add_on_op, {input_x: 5, input_y: 5})
  print(ret)

首先还原模型结构,然后还原变量(参数)信息,最后我们就可以获得已训练的模型中的各种信息了(保存的变量、placeholder变量、operator等),同时可以对获取的变量添加各种新的操作(见以上代码注释)。
并且,我们也可以加载部分模型,在此基础上加入其它操作,具体可以参考官方文档和demo。

针对ckpt模型文件的保存与还原,stackoverflow上有一个回答解释比较清晰,可以参考。

同时cv-tricks.com上面的TensorFlow模型保存与恢复的教程也非常好,可以参考。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python将MongoDB里的ObjectId转换为时间戳的方法
Mar 13 Python
python实现JAVA源代码从ANSI到UTF-8的批量转换方法
Aug 10 Python
python获取命令行输入参数列表的实例代码
Jun 23 Python
对python Tkinter Text的用法详解
Oct 11 Python
python安装pil库方法及代码
Jun 25 Python
django多文件上传,form提交,多对多外键保存的实例
Aug 06 Python
Pytorch之Variable的用法
Dec 31 Python
python在不同条件下的输入与输出
Feb 13 Python
Python实现动态给类和对象添加属性和方法操作示例
Feb 29 Python
python和js交互调用的方法
Jun 23 Python
属性与 @property 方法让你的python更高效
Sep 21 Python
python 制作磁力搜索工具
Mar 04 Python
Python实现爬取百度贴吧帖子所有楼层图片的爬虫示例
Apr 26 #Python
Python实现的计算器功能示例
Apr 26 #Python
python email smtplib模块发送邮件代码实例
Apr 26 #Python
Python利用正则表达式实现计算器算法思路解析
Apr 25 #Python
Python实现随机生成手机号及正则验证手机号的方法
Apr 25 #Python
Python实现按中文排序的方法示例
Apr 25 #Python
Python实现的基于优先等级分配糖果问题算法示例
Apr 25 #Python
You might like
php去除字符串换行符示例分享
2014/02/13 PHP
PHP防止post重复提交数据的简单例子
2014/06/07 PHP
PHP实现更新中间关联表数据的两种方法
2014/09/01 PHP
PHP如何将XML转成数组
2016/04/04 PHP
Yii框架实现记录日志到自定义文件的方法
2017/05/23 PHP
PHP编程实现计算抽奖概率算法完整实例
2017/08/09 PHP
向大师们学习Javascript(视频与PPT)
2009/12/27 Javascript
Javascript 判断是否存在函数的方法
2013/01/03 Javascript
Javascript代码在页面加载时的执行顺序介绍
2013/05/03 Javascript
JS实现的生成随机数的4个函数分享
2015/02/11 Javascript
JavaScript基础知识点归纳(推荐)
2016/07/09 Javascript
Javascript点击按钮随机改变数字与其颜色
2016/09/01 Javascript
Bootstrap3 多个模态对话框无法显示的解决方案
2017/02/23 Javascript
详解Vue.js搭建路由报错 router.map is not a function
2017/06/27 Javascript
关于Ajax的原理以及代码封装详解
2017/09/08 Javascript
深入koa-bodyparser原理解析
2019/01/16 Javascript
[01:06]DOTA2隆重推出2016冬季勇士令状 内含上海特级锦标赛互动指南
2016/02/17 DOTA
Python+django实现文件上传
2016/01/17 Python
Python引用类型和值类型的区别与使用解析
2017/10/17 Python
python pandas dataframe 按列或者按行合并的方法
2018/04/12 Python
浅谈利用numpy对矩阵进行归一化处理的方法
2018/07/11 Python
详解python中的Turtle函数库
2018/11/19 Python
利用python提取wav文件的mfcc方法
2019/01/09 Python
python之拟合的实现
2019/07/19 Python
python 使用OpenCV进行简单的人像分割与合成
2021/02/02 Python
深入浅析css3 border-image边框图像详解
2015/11/24 HTML / CSS
美国购买体育、音乐会和剧院门票网站:SelectATicket
2019/09/08 全球购物
法国隐形眼镜网站:VisionDirect.fr
2020/03/03 全球购物
开业庆典邀请函
2014/01/08 职场文书
革命英雄事迹演讲稿
2014/09/13 职场文书
爱情保证书
2015/01/17 职场文书
收入证明怎么写
2015/06/12 职场文书
oracle重置序列从0开始递增1
2022/02/28 Oracle
聊聊基于pytorch实现Resnet对本地数据集的训练问题
2022/03/25 Python
《我的美好婚事》动画化决定纪念插画与先导PV公开
2022/04/06 日漫
改造DE1103三步曲
2022/04/07 无线电