tensorflow模型的save与restore,及checkpoint中读取变量方式


Posted in Python onMay 26, 2020

创建一个NN

import tensorflow as tf
import numpy as np

#fake data
x = np.linspace(-1, 1, 100)[:, np.newaxis] #shape(100,1)
noise = np.random.normal(0, 0.1, size=x.shape)
y = np.power(x, 2) + noise  #shape(100,1) + noise
tf_x = tf.placeholder(tf.float32, x.shape) #input x
tf_y = tf.placeholder(tf.float32, y.shape) #output y
l = tf.layers.dense(tf_x, 10, tf.nn.relu) #hidden layer
o = tf.layers.dense(l, 1)     #output layer
loss = tf.losses.mean_squared_error(tf_y, o ) #compute loss
train_op = tf.train.GradientDescentOptimizer(learning_rate=0.5).minimize(loss)

1.使用save对模型进行保存

sess= tf.Session()
sess.run(tf.global_variables_initializer())  #initialize var in graph
saver = tf.train.Saver() # define a saver for saving and restoring
for step in range(100):   #train
 sess.run(train_op,{tf_x:x, tf_y:y})
saver.save(sess, 'params/params.ckpt', write_meta_graph=False) # mate_graph is not recommend

生成三个文件,分别是checkpoint,.ckpt.data-00000-of-00001,.ckpt.index

2.使用restore对提取模型

在提取模型时,需要将模型结构再定义一遍,再将各参数加载出来

#bulid entire net again and restore
tf_x = tf.placeholder(tf.float32, x.shape)
tf_y = tf.placeholder(tf.float32, y.shape)
l_ = tf.layers.dense(tf_x, 10, tf.nn.relu)
o_ = tf.layers.dense(l_, 1)
loss_ = tf.losses.mean_squared_error(tf_y, o_)
 
sess = tf.Session()
# don't need to initialize variables, just restoring trained variables
saver = tf.train.Saver() # define a saver for saving and restoring
saver.restore(sess, './params/params.ckpt')

3.有时会报错Not found:b1 not found in checkpoint

这时我们想知道我在文件中到底保存了什么内容,即需要读取出checkpoint中的tensor

import os
from tensorflow.python import pywrap_tensorflow
checkpoint_path = os.path.join('params','params.ckpt')
# Read data from checkpoint file
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
# Print tensor name and value
f = open('params.txt','w')
for key in var_to_shape_map: # write tensors' names and values in file
 print(key,file=f)
 print(reader.get_tensor(key),file=f)
f.close()

运行后生成一个params.txt文件,在其中可以看到模型的参数。

补充知识:TensorFlow按时间保存检查点

一 实例

介绍一种更简便地保存检查点功能的方法——tf.train.MonitoredTrainingSession函数,该函数可以直接实现保存及载入检查点模型的文件。

演示使用MonitoredTrainingSession函数来自动管理检查点文件。

二 代码

import tensorflow as tf
tf.reset_default_graph()
global_step = tf.train.get_or_create_global_step()
step = tf.assign_add(global_step, 1)
with tf.train.MonitoredTrainingSession(checkpoint_dir='log/checkpoints',save_checkpoint_secs = 2) as sess:
 print(sess.run([global_step]))
 while not sess.should_stop():
  i = sess.run( step)
  print( i)

三 运行结果

1 第一次运行后,会发现log文件夹下产生如下文件

tensorflow模型的save与restore,及checkpoint中读取变量方式

2 第二次运行后,结果如下:

INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Restoring parameters from log/checkpoints\model.ckpt-15147
INFO:tensorflow:Saving checkpoints for 15147 into log/checkpoints\model.ckpt.
[15147]
15148
15149
15150
15151
15152
15153
15154
15155
15156
15157
15158
15159

四 说明

本例是按照训练时间来保存的。通过指定save_checkpoint_secs参数的具体秒数,来设置每训练多久保存一次检查点。

可见程序自动载入检查点是从第15147次开始运行的。

五 注意

1 如果不设置save_checkpoint_secs参数,默认的保存时间是10分钟,这种按照时间保存的模式更适合用于使用大型数据集来训练复杂模型的情况。

2 使用该方法,必须要定义global_step变量,否则会报错误。

以上这篇tensorflow模型的save与restore,及checkpoint中读取变量方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python标准库之sqlite3使用实例
Nov 25 Python
巧用python和libnmapd,提取Nmap扫描结果
Aug 23 Python
Python/ArcPy遍历指定目录中的MDB文件方法
Oct 27 Python
利用python提取wav文件的mfcc方法
Jan 09 Python
Python实现简单查找最长子串功能示例
Feb 26 Python
pytorch 更改预训练模型网络结构的方法
Aug 19 Python
pycharm 中mark directory as exclude的用法详解
Feb 14 Python
解决Python在导入文件时的FileNotFoundError问题
Apr 10 Python
Python如何脚本过滤文件中的注释
May 27 Python
Python之Matplotlib文字与注释的使用方法
Jun 18 Python
基于Python模拟浏览器发送http请求
Nov 06 Python
详解Django关于StreamingHttpResponse与FileResponse文件下载的最优方法
Jan 07 Python
tensorflow从ckpt和从.pb文件读取变量的值方式
May 26 #Python
Pytorch转keras的有效方法,以FlowNet为例讲解
May 26 #Python
Django+Celery实现动态配置定时任务的方法示例
May 26 #Python
python删除某个目录文件夹的方法
May 26 #Python
Pytorch使用PIL和Numpy将单张图片转为Pytorch张量方式
May 25 #Python
Pytorch转onnx、torchscript方式
May 25 #Python
使用pandas库对csv文件进行筛选保存
May 25 #Python
You might like
php中使用DOM类读取XML文件的实现代码
2011/12/14 PHP
PHP整数取余返回负数的相关解决方法
2014/05/15 PHP
Yii框架上传图片用法总结
2016/03/28 PHP
Zend Framework教程之Zend_Helpers动作助手ViewRenderer用法详解
2016/07/20 PHP
mac系统下为 php 添加 pcntl 扩展
2016/08/28 PHP
在线一元二次方程计算器实例(方程计算器在线计算)
2013/12/22 Javascript
js脚本获取webform服务器控件的方法
2014/05/16 Javascript
分享javascript计算时间差的示例代码
2020/03/19 Javascript
jquery实现文本框textarea自适应高度
2016/03/09 Javascript
requireJS使用指南
2016/04/27 Javascript
Vue中fragment.js使用方法详解
2017/03/09 Javascript
NodeJs安装npm包一直失败的解决方法
2017/04/28 NodeJs
JavaScript事件冒泡与事件捕获实例分析
2018/08/01 Javascript
vue侧边栏动态生成下级菜单的方法
2018/09/07 Javascript
vue多级复杂列表展开/折叠及全选/分组全选实现
2018/11/05 Javascript
vue 兄弟组件的信息传递的方法实例详解
2019/08/30 Javascript
[03:03]2014DOTA2西雅图国际邀请赛 Alliance战队巡礼
2014/07/07 DOTA
Python跨文件全局变量的实现方法示例
2017/12/10 Python
Python设置在shell脚本中自动补全功能的方法
2018/06/25 Python
Python2实现的图片文本识别功能详解
2018/07/11 Python
Python为何不能用可变对象作为默认参数的值
2019/07/01 Python
Python二次规划和线性规划使用实例
2019/12/09 Python
利用matplotlib实现根据实时数据动态更新图形
2019/12/13 Python
css3实现平移效果(transfrom:translate)的示例
2020/11/13 HTML / CSS
英国男女奢华内衣和泳装购物网站:Figleaves
2017/01/28 全球购物
美国家庭鞋店:Shoe Sensation
2019/09/27 全球购物
ECCO俄罗斯官网:北欧丹麦鞋履及皮具品牌
2020/06/26 全球购物
大三在校生电子商务求职信
2013/10/29 职场文书
关于期中考试的反思
2014/02/02 职场文书
工程力学专业自荐信范文
2014/03/17 职场文书
青春飞扬演讲稿
2014/09/11 职场文书
世界地球日活动总结
2015/02/09 职场文书
交通事故调解协议书
2015/05/20 职场文书
升学宴学生致辞
2015/07/27 职场文书
小学毕业感言200字
2015/07/30 职场文书
导游词之太行山青龙峡
2020/01/14 职场文书