tensorflow模型的save与restore,及checkpoint中读取变量方式


Posted in Python onMay 26, 2020

创建一个NN

import tensorflow as tf
import numpy as np

#fake data
x = np.linspace(-1, 1, 100)[:, np.newaxis] #shape(100,1)
noise = np.random.normal(0, 0.1, size=x.shape)
y = np.power(x, 2) + noise  #shape(100,1) + noise
tf_x = tf.placeholder(tf.float32, x.shape) #input x
tf_y = tf.placeholder(tf.float32, y.shape) #output y
l = tf.layers.dense(tf_x, 10, tf.nn.relu) #hidden layer
o = tf.layers.dense(l, 1)     #output layer
loss = tf.losses.mean_squared_error(tf_y, o ) #compute loss
train_op = tf.train.GradientDescentOptimizer(learning_rate=0.5).minimize(loss)

1.使用save对模型进行保存

sess= tf.Session()
sess.run(tf.global_variables_initializer())  #initialize var in graph
saver = tf.train.Saver() # define a saver for saving and restoring
for step in range(100):   #train
 sess.run(train_op,{tf_x:x, tf_y:y})
saver.save(sess, 'params/params.ckpt', write_meta_graph=False) # mate_graph is not recommend

生成三个文件,分别是checkpoint,.ckpt.data-00000-of-00001,.ckpt.index

2.使用restore对提取模型

在提取模型时,需要将模型结构再定义一遍,再将各参数加载出来

#bulid entire net again and restore
tf_x = tf.placeholder(tf.float32, x.shape)
tf_y = tf.placeholder(tf.float32, y.shape)
l_ = tf.layers.dense(tf_x, 10, tf.nn.relu)
o_ = tf.layers.dense(l_, 1)
loss_ = tf.losses.mean_squared_error(tf_y, o_)
 
sess = tf.Session()
# don't need to initialize variables, just restoring trained variables
saver = tf.train.Saver() # define a saver for saving and restoring
saver.restore(sess, './params/params.ckpt')

3.有时会报错Not found:b1 not found in checkpoint

这时我们想知道我在文件中到底保存了什么内容,即需要读取出checkpoint中的tensor

import os
from tensorflow.python import pywrap_tensorflow
checkpoint_path = os.path.join('params','params.ckpt')
# Read data from checkpoint file
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
# Print tensor name and value
f = open('params.txt','w')
for key in var_to_shape_map: # write tensors' names and values in file
 print(key,file=f)
 print(reader.get_tensor(key),file=f)
f.close()

运行后生成一个params.txt文件,在其中可以看到模型的参数。

补充知识:TensorFlow按时间保存检查点

一 实例

介绍一种更简便地保存检查点功能的方法——tf.train.MonitoredTrainingSession函数,该函数可以直接实现保存及载入检查点模型的文件。

演示使用MonitoredTrainingSession函数来自动管理检查点文件。

二 代码

import tensorflow as tf
tf.reset_default_graph()
global_step = tf.train.get_or_create_global_step()
step = tf.assign_add(global_step, 1)
with tf.train.MonitoredTrainingSession(checkpoint_dir='log/checkpoints',save_checkpoint_secs = 2) as sess:
 print(sess.run([global_step]))
 while not sess.should_stop():
  i = sess.run( step)
  print( i)

三 运行结果

1 第一次运行后,会发现log文件夹下产生如下文件

tensorflow模型的save与restore,及checkpoint中读取变量方式

2 第二次运行后,结果如下:

INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Restoring parameters from log/checkpoints\model.ckpt-15147
INFO:tensorflow:Saving checkpoints for 15147 into log/checkpoints\model.ckpt.
[15147]
15148
15149
15150
15151
15152
15153
15154
15155
15156
15157
15158
15159

四 说明

本例是按照训练时间来保存的。通过指定save_checkpoint_secs参数的具体秒数,来设置每训练多久保存一次检查点。

可见程序自动载入检查点是从第15147次开始运行的。

五 注意

1 如果不设置save_checkpoint_secs参数,默认的保存时间是10分钟,这种按照时间保存的模式更适合用于使用大型数据集来训练复杂模型的情况。

2 使用该方法,必须要定义global_step变量,否则会报错误。

以上这篇tensorflow模型的save与restore,及checkpoint中读取变量方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python Web开发模板引擎优缺点总结
May 06 Python
Python文件夹与文件的操作实现代码
Jul 13 Python
Python编程实现删除VC临时文件及Debug目录的方法
Mar 22 Python
python方法生成txt标签文件的实例代码
May 10 Python
用Python中的turtle模块画图两只小羊方法
Apr 09 Python
使用python判断jpeg图片的完整性实例
Jun 10 Python
使用NumPy读取MNIST数据的实现代码示例
Nov 20 Python
pandas实现DataFrame显示最大行列,不省略显示实例
Dec 26 Python
python两种注释用法的示例
Oct 09 Python
Python通过Schema实现数据验证方式
Nov 12 Python
PyTorch的Debug指南
May 07 Python
Python中with上下文管理协议的作用及用法
Mar 18 Python
tensorflow从ckpt和从.pb文件读取变量的值方式
May 26 #Python
Pytorch转keras的有效方法,以FlowNet为例讲解
May 26 #Python
Django+Celery实现动态配置定时任务的方法示例
May 26 #Python
python删除某个目录文件夹的方法
May 26 #Python
Pytorch使用PIL和Numpy将单张图片转为Pytorch张量方式
May 25 #Python
Pytorch转onnx、torchscript方式
May 25 #Python
使用pandas库对csv文件进行筛选保存
May 25 #Python
You might like
PHP 图片文件上传实现代码
2010/12/29 PHP
php+iframe实现隐藏无刷新上传文件
2012/02/10 PHP
PHP mysqli事务操作常用方法分析
2017/07/22 PHP
js几个不错的函数 $$()
2006/10/09 Javascript
判断多个元素(RADIO,CHECKBOX等)是否被选择的原理说明
2009/02/18 Javascript
深入理解JavaScript系列(13) This? Yes,this!
2012/01/18 Javascript
js性能优化 如何更快速加载你的JavaScript页面
2012/03/17 Javascript
window.requestAnimationFrame是什么意思,怎么用
2013/01/13 Javascript
Extjs4 GridPanel的主要配置参数详细介绍
2013/04/18 Javascript
js操纵跨frame的三级联动select下拉选项实例介绍
2013/05/19 Javascript
在JavaScript中call()与apply()区别
2016/01/22 Javascript
复杂的javascript窗口分帧解析
2016/02/19 Javascript
深入浅析Extjs中store分组功能的使用方法
2016/04/20 Javascript
jQuery实现查找链接文字替换属性的方法
2016/06/27 Javascript
Javascript实现图片懒加载插件的方法
2016/10/20 Javascript
js实现定时进度条完成后切换图片
2017/01/04 Javascript
初识NodeJS服务端开发入门(Express+MySQL)
2017/04/07 NodeJs
使用Object.defineProperty如何巧妙找到修改某个变量的准确代码位置
2018/11/02 Javascript
开发中常用的25个JavaScript单行代码(小结)
2019/06/28 Javascript
layui实现下拉框三级联动
2019/07/26 Javascript
[01:34]2014DOTA2 TI预选赛预选赛 选手比赛房大揭秘!
2014/05/20 DOTA
python实现用户管理系统
2018/01/10 Python
python实现Zabbix-API监控
2018/09/17 Python
在PyCharm中实现关闭一个死循环程序的方法
2018/11/29 Python
在python中实现将一张图片剪切成四份的方法
2018/12/05 Python
python pytest进阶之xunit fixture详解
2019/06/27 Python
python3 正则表达式基础廖雪峰
2020/03/25 Python
使用phonegap创建联系人的实现方法
2017/03/30 HTML / CSS
波兰香水和化妆品购物网站:Notino.pl
2017/11/07 全球购物
草莓网化妆品澳大利亚站:Strawberrynet AU
2017/12/18 全球购物
美国波西米亚风格精品店:South Moon Under
2019/10/26 全球购物
销售自我评价
2013/10/22 职场文书
公务员职业生涯规划书范文  
2014/01/19 职场文书
物资采购方案
2014/06/12 职场文书
医院义诊活动总结
2014/07/04 职场文书
青岛导游词
2015/02/12 职场文书