TensorFlow 20: An Example of Building a Network, Exporting the Model, and Running It


Posted in Python on May 26, 2020

Overview

In the past I always started from projects other people had already built and adapted them for my own use; I had rarely walked through the full cycle of building a model, exporting it, and loading it to run. Having now done it once end to end, I can say it is not as simple as it sounds.

Building and exporting the model

Following the article 《TensorFlow固化模型》 (on freezing TensorFlow models), there are two ways to export a frozen model.

Method 1: export the pb graph structure and a ckpt file, then use the freeze_graph tool to freeze them into a single pb containing both the structure and the parameters.

In my code I tried generating the pb graph structure and the ckpt file, but did not go on to the freezing step; it felt a bit cumbersome, so I used the second method instead (a sketch of the remaining freeze_graph invocation follows the snippet below).

Note that here I only save the ckpt once, at the end; in practice you should save a checkpoint every so often during training.
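
A minimal sketch of what periodic saving could look like, reusing the saver from the snippet below inside its training loop (the 500-step interval is an arbitrary choice of mine, not the original code):

# Sketch (assumption, not the original code): save every 500 steps.
# With max_to_keep=5, only the 5 most recent checkpoints are kept on disk.
if i % 500 == 0:
    saver.save(sess, os.path.join(graph_location, "model.ckpt"), global_step=i)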

saver = tf.train.Saver(max_to_keep=5)
#tf.train.write_graph(session.graph_def, FLAGS.model_dir, "nn_model.pbtxt", as_text=True)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    max_step = 2000
    for i in range(max_step):
        batch = mnist.train.next_batch(50)
        if i % 100 == 0:
            train_accuracy = accuracy.eval(feed_dict={
                x: batch[0], y_: batch[1], keep_prob: 1.0})
            print('step %d, training accuracy %g' % (i, train_accuracy))
        train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})

    print('test accuracy %g' % accuracy.eval(feed_dict={
        x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

    # Save the pb graph definition and the ckpt file
    print('save pb file and ckpt file')
    tf.train.write_graph(sess.graph_def, graph_location, "graph.pb", as_text=False)
    checkpoint_path = os.path.join(graph_location, "model.ckpt")
    saver.save(sess, checkpoint_path, global_step=max_step)
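
To finish method 1 from here, the freeze_graph tool that ships with TF 1.x would then merge graph.pb and the checkpoint into one frozen pb. I did not run this step myself, so treat the following invocation as a sketch (the -2000 checkpoint suffix matches max_step above, and the output node name is the one discussed under method 2):

python -m tensorflow.python.tools.freeze_graph \
    --input_graph=./model/graph.pb \
    --input_binary=true \
    --input_checkpoint=./model/model.ckpt-2000 \
    --output_node_names=out/fc2 \
    --output_graph=./model/frozen_graph.pb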

Method 2: convert_variables_to_constants

This is the method I actually used.

As the name suggests, it converts the variables into constants before saving, so the resulting file can be loaded and used with no extra fuss.

Note that you have to name the output node(s) to keep. Mine is 'out/fc2' (my guess is that TensorFlow follows the dependencies of the output node to work out which parts of the graph are needed for inference, discarding the parts only used for training). The output node name follows a pattern: 'out' is a name_scope and 'fc2' is the name of the op node inside it.
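
As a tiny illustration of that rule (my own snippet, not from the original code): an op named 'fc2' created inside name_scope('out') gets the node name 'out/fc2', and its first output tensor is 'out/fc2:0'.

# Sketch: how name_scope and the op name combine into the node name.
with tf.name_scope('out'):
    t = tf.add(tf.constant(1.0), tf.constant(2.0), name='fc2')
print(t.op.name)  # out/fc2
print(t.name)     # out/fc2:0

With that rule in mind, here is the training and freezing snippet for method 2: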

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    max_step = 2000
    for i in range(max_step):
        batch = mnist.train.next_batch(50)
        if i % 100 == 0:
            train_accuracy = accuracy.eval(feed_dict={
                x: batch[0], y_: batch[1], keep_prob: 1.0})
            print('step %d, training accuracy %g' % (i, train_accuracy))
        train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})

    print('test accuracy %g' % accuracy.eval(feed_dict={
        x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

    print('save frozen file')
    pb_path = os.path.join(graph_location, 'frozen_graph.pb')
    print('pb_path:{}'.format(pb_path))

    # Freeze the model: bake the variables into the graph as constants
    output_graph_def = convert_variables_to_constants(sess, sess.graph_def, output_node_names=['out/fc2'])
    with tf.gfile.FastGFile(pb_path, mode='wb') as f:
        f.write(output_graph_def.SerializeToString())

After training, the code above saves the trained graph and its parameters into frozen_graph.pb. That one file is all the later test code needs.
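
If you want to confirm that 'out/fc2' survived the freeze (and that the training-only ops were dropped), listing the node names of the exported GraphDef is a quick check. A minimal sketch of my own, not part of the original post:

# Sketch: print every node name in the frozen graph.
import tensorflow as tf

graph_def = tf.GraphDef()
with tf.gfile.FastGFile('./model/frozen_graph.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())
for node in graph_def.node:
    print(node.name)  # expect 'input/x', 'dropout/ratio', ..., 'out/fc2'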

Full training and export code for method 2

The main function is the part to focus on. Also note the name given to the final node in the deepnn function.

"""A deep MNIST classifier using convolutional layers.

See extensive documentation at
https://www.tensorflow.org/get_started/mnist/pros
"""
# Disable linter warnings to maintain consistency with tutorial.
# pylint: disable=invalid-name
# pylint: disable=g-bad-import-order

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys
import tempfile
import os

from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.framework.graph_util import convert_variables_to_constants

import tensorflow as tf
FLAGS = None

def deepnn(x):
    """deepnn builds the graph for a deep net for classifying digits.

    Args:
        x: an input tensor with the dimensions (N_examples, 784), where 784 is the
        number of pixels in a standard MNIST image.

    Returns:
        A tuple (y, keep_prob). y is a tensor of shape (N_examples, 10), with values
        equal to the logits of classifying the digit into one of 10 classes (the
        digits 0-9). keep_prob is a scalar placeholder for the probability of
        dropout.
    """
    # Reshape to use within a convolutional neural net.
    # Last dimension is for "features" - there is only one here, since images are
    # grayscale -- it would be 3 for an RGB image, 4 for RGBA, etc.
    with tf.name_scope('reshape'):
        x_image = tf.reshape(x, [-1, 28, 28, 1])

    # First convolutional layer - maps one grayscale image to 32 feature maps.
    with tf.name_scope('conv1'):
        W_conv1 = weight_variable([5, 5, 1, 32])
        b_conv1 = bias_variable([32])
        h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)

    # Pooling layer - downsamples by 2X.
    with tf.name_scope('pool1'):
        h_pool1 = max_pool_2x2(h_conv1)

    # Second convolutional layer -- maps 32 feature maps to 64.
    with tf.name_scope('conv2'):
        W_conv2 = weight_variable([5, 5, 32, 64])
        b_conv2 = bias_variable([64])
        h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)

    # Second pooling layer.
    with tf.name_scope('pool2'):
        h_pool2 = max_pool_2x2(h_conv2)

    # Fully connected layer 1 -- after 2 rounds of downsampling, our 28x28 image
    # is down to 7x7x64 feature maps -- maps this to 1024 features.
    with tf.name_scope('fc1'):
        W_fc1 = weight_variable([7 * 7 * 64, 1024])
        b_fc1 = bias_variable([1024])

        h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
        h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

    # Dropout - controls the complexity of the model, prevents co-adaptation of
    # features.
    with tf.name_scope('dropout'):
        keep_prob = tf.placeholder(tf.float32, name='ratio')
        h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

    # Map the 1024 features to 10 classes, one for each digit
    with tf.name_scope('out'):
        W_fc2 = weight_variable([1024, 10])
        b_fc2 = bias_variable([10])

        # tf.add is given an explicit name here, so the output node is 'out/fc2'
        y_conv = tf.add(tf.matmul(h_fc1_drop, W_fc2), b_fc2, name='fc2')
    return y_conv, keep_prob

def conv2d(x, W):
    """conv2d returns a 2d convolution layer with full stride."""
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

def max_pool_2x2(x):
    """max_pool_2x2 downsamples a feature map by 2X."""
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
                          strides=[1, 2, 2, 1], padding='SAME')

def weight_variable(shape):
    """weight_variable generates a weight variable of a given shape."""
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)

def bias_variable(shape):
    """bias_variable generates a bias variable of a given shape."""
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)

def main(_):
    # Import data
    mnist = input_data.read_data_sets(FLAGS.data_dir)

    # Create the model
    with tf.name_scope('input'):
        x = tf.placeholder(tf.float32, [None, 784], name='x')

    # Define loss and optimizer
    y_ = tf.placeholder(tf.int64, [None])

    # Build the graph for the deep net
    y_conv, keep_prob = deepnn(x)

    with tf.name_scope('loss'):
        cross_entropy = tf.losses.sparse_softmax_cross_entropy(
            labels=y_, logits=y_conv)
        cross_entropy = tf.reduce_mean(cross_entropy)

    with tf.name_scope('adam_optimizer'):
        train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

    with tf.name_scope('accuracy'):
        correct_prediction = tf.equal(tf.argmax(y_conv, 1), y_)
        correct_prediction = tf.cast(correct_prediction, tf.float32)
        accuracy = tf.reduce_mean(correct_prediction)

    graph_location = './model'
    print('Saving graph to: %s' % graph_location)
    train_writer = tf.summary.FileWriter(graph_location)
    train_writer.add_graph(tf.get_default_graph())

    saver = tf.train.Saver(max_to_keep=5)
    #tf.train.write_graph(session.graph_def, FLAGS.model_dir, "nn_model.pbtxt", as_text=True)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        max_step = 2000
        for i in range(max_step):
            batch = mnist.train.next_batch(50)
            if i % 100 == 0:
                train_accuracy = accuracy.eval(feed_dict={
                    x: batch[0], y_: batch[1], keep_prob: 1.0})
                print('step %d, training accuracy %g' % (i, train_accuracy))
            train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})

        print('test accuracy %g' % accuracy.eval(feed_dict={
            x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

        # save pb file and ckpt file
        #print('save pb file and ckpt file')
        #tf.train.write_graph(sess.graph_def, graph_location, "graph.pb", as_text=False)
        #checkpoint_path = os.path.join(graph_location, "model.ckpt")
        #saver.save(sess, checkpoint_path, global_step=max_step)

        print('save frozen file')
        pb_path = os.path.join(graph_location, 'frozen_graph.pb')
        print('pb_path:{}'.format(pb_path))

        output_graph_def = convert_variables_to_constants(sess, sess.graph_def, output_node_names=['out/fc2'])
        with tf.gfile.FastGFile(pb_path, mode='wb') as f:
            f.write(output_graph_def.SerializeToString())

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str,
                        default='./data',
                        help='Directory for storing input data')
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

Loading the model for inference

The previous section trained the model and exported frozen_graph.pb.

This section gets it running.

Loading the model

The code below loads the model. At inference time the graph contains two placeholders that need data: the image (obviously) and the dropout ratio. The dropout ratio is mapped to a constant here, so afterwards only the image has to be fed.

graph_location = './model'
pb_path = os.path.join(graph_location, 'frozen_graph.pb')
print('pb_path:{}'.format(pb_path))

newInput_X = tf.placeholder(tf.float32, [None, 784], name="X")
dropout_ratio = tf.constant(1., name="dropout")
with open(pb_path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

    output = tf.import_graph_def(graph_def,
                                 input_map={'input/x:0': newInput_X, 'dropout/ratio:0': dropout_ratio},
                                 return_elements=['out/fc2:0'])

The input_map argument is not strictly required. Without it, you can fetch tensor handles via tf.get_default_graph().get_tensor_by_name before calling run; I find that approach less friendly, so I did not use it here.
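
For completeness, that alternative could look roughly like this (a sketch I did not use; note that without an explicit name argument, import_graph_def puts every imported node under the 'import/' prefix, and the dropout placeholder then has to be fed on every run):

# Sketch: import without input_map, then look tensors up by name.
tf.import_graph_def(graph_def)  # imported nodes get the 'import/' prefix
g = tf.get_default_graph()
x_t = g.get_tensor_by_name('import/input/x:0')
ratio_t = g.get_tensor_by_name('import/dropout/ratio:0')
out_t = g.get_tensor_by_name('import/out/fc2:0')

with tf.Session() as sess:
    # image_batch: a float32 array of shape [N, 784] (hypothetical name)
    logits = sess.run(out_t, feed_dict={x_t: image_batch, ratio_t: 1.0})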

Note that the tensor names in input_map come from the name_scope and op names used when the graph was built, with ':0' appended: a tensor name has the form op_name:output_index, so ':0' selects the op's first output.

Also note that newInput_X has shape [None, 784]: the first dimension is the batch size, and the layout at inference time must match training. (I am using MNIST images, so during training each batch has shape [batchsize, 784], one flattened 28x28 image per row.)

Running the model

I test the images one at a time; before running the model, each image is reshaped to [1, 784] to match the dimensions of newInput_X.

with tf.Session() as sess:
    file_list = os.listdir(test_image_dir)

    # Iterate over the test images
    for file in file_list:
        full_path = os.path.join(test_image_dir, file)
        print('full_path:{}'.format(full_path))

        # Load as grayscale and resize to (28, 28)
        img = cv2.imread(full_path, cv2.IMREAD_GRAYSCALE)
        res_img = cv2.resize(img, (28, 28), interpolation=cv2.INTER_CUBIC)
        # Flatten to a 1-D array of length 784
        new_img = res_img.reshape((784))

        # Add a batch dimension, giving shape [1, 784]
        image_np_expanded = np.expand_dims(new_img, axis=0)
        # The dtype must match the placeholder; astype returns a new array,
        # so the result has to be assigned back
        image_np_expanded = image_np_expanded.astype('float32')
        print('image_np_expanded shape:{}'.format(image_np_expanded.shape))

        # Run the model
        result = sess.run(output, feed_dict={newInput_X: image_np_expanded})

        # Squeeze away the unused dimensions
        result = np.squeeze(result)
        print('result:{}'.format(result))
        #print('result:{}'.format(sess.run(output, feed_dict={newInput_X: image_np_expanded})))

        # The output is a length-10 array (one score per digit 0-9);
        # the index of the maximum value is the predicted digit
        print('result:{}'.format((np.where(result == np.max(result)))[0][0]))

Note that the model's output is a one-dimensional array of length 10, i.e. the raw output of the final fully connected layer in the graph. There is no softmax node, but none is needed for classification: the index of the largest value is the prediction.
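
np.argmax is a more direct way to pick that index, and if probabilities are wanted, a softmax can be applied outside the graph. A small sketch of both (my addition, not in the original code):

# Sketch: predicted class and optional probabilities from the raw logits.
pred = np.argmax(result)  # same digit as the np.where expression above

exp = np.exp(result - np.max(result))  # subtract the max for numerical stability
probs = exp / exp.sum()                # softmax over the 10 logits
print(pred, probs[pred])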

Output:

full_path:./test_images/97_7.jpg
image_np_expanded shape:(1, 784)
result:[-1340.37145996 -283.72436523 1305.03320312 437.6053772 -413.69961548
 -1218.08166504 -1004.83807373 1953.33984375 42.00457001 -504.43829346]
result:7

full_path:./test_images/98_6.jpg
image_np_expanded shape:(1, 784)
result:[ 567.4041748 -550.20904541 623.83496094 -1152.56884766 -217.92695618
 1033.45239258 2496.44750977 -1139.23620605 -5.64091825 -615.28491211]
result:6

full_path:./test_images/99_9.jpg
image_np_expanded shape:(1, 784)
result:[ -532.26409912 -1429.47277832 -368.58096313 505.82876587 358.42163086
 -317.48199463 -1108.6829834 1198.08752441 289.12286377 3083.52539062]
result:9

Full code for loading the model and running inference

import sys
import os
import cv2
import numpy as np
import tensorflow as tf

test_image_dir = './test_images/'

graph_location = './model'
pb_path = os.path.join(graph_location, 'frozen_graph.pb')
print('pb_path:{}'.format(pb_path))

newInput_X = tf.placeholder(tf.float32, [None, 784], name="X")
dropout_ratio = tf.constant(1., name="dropout")
with open(pb_path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    #output = tf.import_graph_def(graph_def)
    output = tf.import_graph_def(graph_def,
                                 input_map={'input/x:0': newInput_X, 'dropout/ratio:0': dropout_ratio},
                                 return_elements=['out/fc2:0'])

with tf.Session() as sess:
    file_list = os.listdir(test_image_dir)

    # Iterate over the test images
    for file in file_list:
        full_path = os.path.join(test_image_dir, file)
        print('full_path:{}'.format(full_path))

        # Load as grayscale and resize to (28, 28)
        img = cv2.imread(full_path, cv2.IMREAD_GRAYSCALE)
        res_img = cv2.resize(img, (28, 28), interpolation=cv2.INTER_CUBIC)
        # Flatten to a 1-D array of length 784
        new_img = res_img.reshape((784))

        # Add a batch dimension, giving shape [1, 784]
        image_np_expanded = np.expand_dims(new_img, axis=0)
        # The dtype must match the placeholder; astype returns a new array,
        # so the result has to be assigned back
        image_np_expanded = image_np_expanded.astype('float32')
        print('image_np_expanded shape:{}'.format(image_np_expanded.shape))

        # Run the model
        result = sess.run(output, feed_dict={newInput_X: image_np_expanded})

        # Squeeze away the unused dimensions
        result = np.squeeze(result)
        print('result:{}'.format(result))
        #print('result:{}'.format(sess.run(output, feed_dict={newInput_X: image_np_expanded})))

        # The output is a length-10 array (one score per digit 0-9);
        # the index of the maximum value is the predicted digit
        print('result:{}'.format((np.where(result == np.max(result)))[0][0]))
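
Since newInput_X was declared with shape [None, 784], nothing limits you to one image per run; several images can be stacked and classified in a single sess.run. A sketch under the assumption that imgs is a list of flattened 28x28 grayscale images (imgs is a hypothetical name, not from the code above):

# Sketch: batch inference over N images at once.
batch = np.stack(imgs).astype('float32')  # shape [N, 784]
logits = np.squeeze(sess.run(output, feed_dict={newInput_X: batch}))
preds = np.argmax(logits, axis=-1)        # one predicted digit per image
print(preds)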

That is all for this example of building a network, exporting the model, and running it with TensorFlow. I hope it can serve as a useful reference.
