TensorFlow saver指定变量的存取


Posted in Python onMarch 10, 2018

今天和大家分享一下用TensorFlow的saver存取训练好的模型那点事。

1. 用saver存取变量;
2. 用saver存取指定变量。

用saver存取变量。

话不多说,先上代码

# coding=utf-8
import os        
import tensorflow as tf
import numpy
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #有些指令集没有装,加这个不显示那些警告
w = tf.Variable([[1,2,3],[2,3,4],[6,7,8]],dtype=tf.float32)
b = tf.Variable([[4,5,6]],dtype=tf.float32,)
s = tf.Variable([[2, 5],[5, 6]], dtype=tf.float32)
init = tf.global_variables_initializer()
saver =tf.train.Saver()
with tf.Session() as sess:
 sess.run(init)
 save_path = saver.save(sess, "save_net.ckpt")#路径可以自己定
 print("save to path:",save_path)

这里我随便定义了几个变量然后进行存操作,运行后,变量w,b,s会被保存下来。保存会生成如下几个文件:

  • cheakpoint
  • save_net.ckpt.data-*
  • save_net.ckpt.index
  • save_net.ckpt.meta

接下来是读取的代码

import tensorflow as tf
import os
import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

w = tf.Variable(np.arange(9).reshape((3,3)),dtype=tf.float32)
b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32)
a = tf.Variable(np.arange(4).reshape((2,2)),dtype=tf.float32)
saver =tf.train.Saver()
with tf.Session() as sess:

 saver.restore(sess,'save_net.ckpt')
 print ("weights",sess.run(w))
 print ("b",sess.run(b))
 print ("s",sess.run(a))

在写读取代码时要注意变量定义的类型、大小和变量的数量以及顺序等要与存的时候一致,不然会报错。你存的时候顺序是w,b,s,取的时候同样这个顺序。存的时候w定义了dtype没有 定义name,取的时候同样要这样,因为TensorFlow存取是按照键值对来存取的,所以必须一致。这里变量名,也就是w,s之类可以不同。

如下是我成功读取的效果

TensorFlow saver指定变量的存取

用saver存取指定变量。

在我们做训练时候,有些变量是没有必要保存的,但是如果直接用tf.train.Saver()。程序会将所有的变量保存下来,这时候我们可以指定保存,只保存我们需要的变量,其他的统统丢掉。
其实很简单,只需要在上面代码基础上稍加修改,只需把tf.train.Saver()替换成如下代码

program = []
program += [w,b]
tf.train.Saver(program)

这样,程序就只会存w和b了。同样,读取程序里面的tf.train.Saver()也要做如上修改。dtype,name之类依旧必须一致。

最后附上最终代码:

# coding=utf-8
# saver保存变量测试
import os        
import tensorflow as tf
import numpy
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #有些指令集没有装,加这个不显示那些警告
w = tf.Variable([[1,2,3],[2,3,4],[6,7,8]],dtype=tf.float32)
b = tf.Variable([[4,5,6]],dtype=tf.float32,)
s = tf.Variable([[2, 5],[5, 6]], dtype=tf.float32)
init = tf.global_variables_initializer()
program = []
program += [w, b]
saver =tf.train.Saver(program)
with tf.Session() as sess:
 sess.run(init)
 save_path = saver.save(sess, "save_net.ckpt")#路径可以自己定
 print("save to path:",save_path)
#saver提取变量测试
import tensorflow as tf
import os
import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

w = tf.Variable(np.arange(9).reshape((3,3)),dtype=tf.float32)
b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32)
a = tf.Variable(np.arange(4).reshape((2,2)),dtype=tf.float32)
program = []
program +=[w,b]
saver =tf.train.Saver(program)
with tf.Session() as sess:

 saver.restore(sess,'save_net.ckpt')
 print ("weights",sess.run(w))
 print ("b",sess.run(b))
 #print ("s",sess.run(a))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
win7 下搭建sublime的python开发环境的配置方法
Jun 18 Python
Python将图片转换为字符画的方法
Jun 16 Python
python list是否包含另一个list所有元素的实例
May 04 Python
解决python matplotlib imshow无法显示的问题
May 24 Python
Python操作mongodb的9个步骤
Jun 04 Python
详解Python3.6的py文件打包生成exe
Jul 13 Python
python函数装饰器之带参数的函数和带参数的装饰器用法示例
Nov 06 Python
使用pygame写一个古诗词填空通关游戏
Dec 03 Python
flask 使用 flask_apscheduler 做定时循环任务的实现
Dec 10 Python
如何查看python关键字
Jan 17 Python
【超详细】八大排序算法的各项比较以及各自特点
Mar 31 Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 Python
TensorFLow用Saver保存和恢复变量
Mar 10 #Python
tensorflow创建变量以及根据名称查找变量
Mar 10 #Python
Python2中文处理纪要的实现方法
Mar 10 #Python
python实现冒泡排序算法的两种方法
Mar 10 #Python
Python使用pyh生成HTML文档的方法示例
Mar 10 #Python
tensorflow获取变量维度信息
Mar 10 #Python
TensorFlow变量管理详解
Mar 10 #Python
You might like
帖几个PHP的无限分类实现想法~
2007/01/02 PHP
PHP操作xml代码
2010/06/17 PHP
php+ajax 实现输入读取数据库显示匹配信息
2015/10/08 PHP
php arsort 数组降序排序详细介绍
2016/11/17 PHP
Yii2框架实现数据库常用操作总结
2017/02/08 PHP
jQuery的强大选择器小结
2009/12/27 Javascript
js AppendChild与insertBefore用法详细对比
2013/12/16 Javascript
node.js不得不说的12点内容
2014/07/14 Javascript
jquery实现鼠标滑过显示提示框的方法
2015/02/05 Javascript
Javascript实现Array和String互转换的方法
2015/12/21 Javascript
jQuery插件实现文字无缝向上滚动效果代码
2016/02/25 Javascript
三种带箭头提示框总结实例
2016/06/14 Javascript
textarea 在浏览器中固定大小和禁止拖动的实现方法
2016/12/03 Javascript
详解HTML5 使用video标签实现选择摄像头功能
2017/10/25 Javascript
JavaScript基础心法 深浅拷贝(浅拷贝和深拷贝)
2018/03/05 Javascript
JavaScript实现区块链
2018/03/14 Javascript
对layui初始化列表的CheckBox属性详解
2019/09/13 Javascript
vue解决使用$http获取数据时报错的问题
2019/10/30 Javascript
40行代码把Vue3的响应式集成进React做状态管理
2020/05/20 Javascript
[09:23]国际邀请赛采访专栏:iG战队VK,Tongfu战队Cu
2013/08/05 DOTA
[44:40]Spirit vs Navi Supermajor小组赛 A组败者组第一轮 BO3 第一场 6.2
2018/06/03 DOTA
python连接数据库的方法
2017/10/19 Python
python遍历文件夹下所有excel文件
2018/01/03 Python
Python读取YAML文件过程详解
2019/12/30 Python
如何在windows下安装Pycham2020软件(方法步骤详解)
2020/05/03 Python
Python3 pyecharts生成Html文件柱状图及折线图代码实例
2020/09/29 Python
css3背景_动力节点Java学院整理
2017/07/11 HTML / CSS
详解HTML5 LocalStorage 本地存储
2016/12/23 HTML / CSS
北美三大旅游网站之一:Travelocity
2017/08/12 全球购物
工业学校毕业生自荐书
2014/01/03 职场文书
个人简历自我评价范文
2014/02/04 职场文书
英语系毕业生求职信
2014/07/13 职场文书
2015大学生求职信范文
2015/03/20 职场文书
销售督导岗位职责
2015/04/10 职场文书
2015年学校财务工作总结
2015/05/19 职场文书
2015年大学组织委员个人工作总结
2015/10/23 职场文书