tensorflow 保存模型和取出中间权重例子


Posted in Python onJanuary 24, 2020

下面代码的功能是先训练一个简单的模型,然后保存模型,同时保存到一个pb文件当中,后续可以从pd文件里读取权重值。

import tensorflow as tf
import numpy as np
import os
import h5py
import pickle
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
#设置使用指定GPU
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
#下面这段代码是在训练好之后将所有的权重名字和权重值罗列出来,训练的时候需要注释掉
reader = tf.train.NewCheckpointReader('./model.ckpt-100')
variables = reader.get_variable_to_shape_map()
for ele in variables:
  print(ele)
  print(reader.get_tensor(ele))


x = tf.placeholder(tf.float32, shape=[None, 1])
y = 4 * x + 4

w = tf.Variable(tf.random_normal([1], -1, 1))
b = tf.Variable(tf.zeros([1]))
y_predict = w * x + b


loss = tf.reduce_mean(tf.square(y - y_predict))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)

isTrain = False#设成True去训练模型
train_steps = 100
checkpoint_steps = 50
checkpoint_dir = ''


saver = tf.train.Saver() # defaults to saving all variables - in this case w and b
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  if isTrain:
    for i in xrange(train_steps):
      sess.run(train, feed_dict={x: x_data})
      if (i + 1) % checkpoint_steps == 0:
        saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1)
  else:
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
      saver.restore(sess, ckpt.model_checkpoint_path)
    else:
      pass   
    print(sess.run(w))
    print(sess.run(b))
    graph_def = tf.get_default_graph().as_graph_def()
    #通过修改下面的函数,个人觉得理论上能够实现修改权重,但是很复杂,如果哪位有好办法,欢迎指教
    output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['Variable'])
    with tf.gfile.FastGFile('./test.pb', 'wb') as f:
      f.write(output_graph_def.SerializeToString())


with tf.Session() as sess:
#对应最后一部分的写,这里能够将对应的变量取出来
  with gfile.FastGFile('./test.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
  res = tf.import_graph_def(graph_def, return_elements=['Variable:0'])
  print(sess.run(res))
  print(sess.run(graph_def))

以上这篇tensorflow 保存模型和取出中间权重例子就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
haskell实现多线程服务器实例代码
Nov 26 Python
Python时区设置方法与pytz查询时区教程
Nov 27 Python
详细讲解用Python发送SMTP邮件的教程
Apr 29 Python
python opencv之SURF算法示例
Feb 24 Python
Python编程flask使用页面模版的方法
Dec 28 Python
pycharm 实现显示project 选项卡的方法
Jan 17 Python
djang常用查询SQL语句的使用代码
Feb 15 Python
Django 源码WSGI剖析过程详解
Aug 05 Python
python 伯努利分布详解
Feb 25 Python
Python实现RabbitMQ6种消息模型的示例代码
Mar 30 Python
django 装饰器 检测登录状态操作
Jul 02 Python
Python常用断言函数实例汇总
Nov 30 Python
tensorflow 模型权重导出实例
Jan 24 #Python
在Tensorflow中查看权重的实现
Jan 24 #Python
tensorflow求导和梯度计算实例
Jan 23 #Python
Tensorflow的梯度异步更新示例
Jan 23 #Python
在Tensorflow中实现梯度下降法更新参数值
Jan 23 #Python
Tensorflow实现部分参数梯度更新操作
Jan 23 #Python
将tensorflow模型打包成PB文件及PB文件读取方式
Jan 23 #Python
You might like
PHP JSON格式数据交互实例代码详解
2011/01/13 PHP
PHP mysql与mysqli事务使用说明 分享
2013/08/17 PHP
PHP中的日期时间处理利器实例(Carbon)
2017/06/09 PHP
PHP 计算两个特别大的整数实例代码
2018/05/07 PHP
通过修改referer下载文件的方法
2008/05/11 Javascript
Javascript结合css实现网页换肤功能
2009/11/02 Javascript
一个JavaScript变量声明的知识点
2013/10/28 Javascript
关闭浏览器输入框自动补齐 兼容IE,FF,Chrome等主流浏览器
2014/02/11 Javascript
浅谈JavaScript中null和undefined
2015/07/09 Javascript
Jquery全选与反选点击执行一次的解决方案
2015/08/14 Javascript
jquery自定义插件结合baiduTemplate.js实现异步刷新(附源码)
2016/12/22 Javascript
Bootstrap弹出框(Popover)被挤压的问题小结
2017/07/11 Javascript
vue.js给动态绑定的radio列表做批量编辑的方法
2018/02/28 Javascript
Vue中props的详解
2019/05/16 Javascript
jQuery/JS监听input输入框值变化实例
2019/10/17 jQuery
解决pycharm双击但是无法打开的情况
2020/10/31 Javascript
[01:03:56]Mineski vs TNC 2018国际邀请赛淘汰赛BO1 8.21
2018/08/22 DOTA
[39:00]Optic vs VP 2018国际邀请赛淘汰赛BO3 第三场 8.24
2018/08/25 DOTA
Python通过属性手段实现只允许调用一次的示例讲解
2018/04/21 Python
python模块之subprocess模块级方法的使用
2019/03/26 Python
python使用time、datetime返回工作日列表实例代码
2019/05/09 Python
用python做游戏的细节详解
2019/06/25 Python
Python面向对象之Web静态服务器
2019/09/03 Python
Tensorflow读取并输出已保存模型的权重数值方式
2020/01/04 Python
Jupyter Notebook折叠输出的内容实例
2020/04/22 Python
python实现人像动漫化的示例代码
2020/05/17 Python
python用Configobj模块读取配置文件
2020/09/26 Python
python 指定源路径来解决import问题的操作
2021/03/04 Python
HTML5实现表单自动验证功能实例代码
2017/01/11 HTML / CSS
HTML5 Web缓存和运用程序缓存(cookie,session)
2018/01/11 HTML / CSS
Linux内核产生并发的原因
2016/11/08 面试题
电子商务专业学生的学习自我评价
2013/10/27 职场文书
代办社保委托书范文
2014/10/06 职场文书
2016五四青年节活动总结范文
2016/04/06 职场文书
vue实现简单数据双向绑定
2021/04/28 Vue.js
MySQL数据库配置信息查看与修改方法详解
2022/06/25 MySQL