tensorflow 20:搭网络,导出模型,运行模型的实例


Posted in Python onMay 26, 2020

概述

以前自己都利用别人搭好的工程,修改过来用,很少把模型搭建、导出模型、加载模型运行走一遍,搞了一遍才知道这个事情也不是那么简单的。

搭建模型和导出模型

参考《TensorFlow固化模型》,导出固化的模型有两种方式.

方式1:导出pb图结构和ckpt文件,然后用 freeze_graph 工具冻结生成一个pb(包含结构和参数)

在我的代码里测试了生成pb图结构和ckpt文件,但是没接着往下走,感觉有点麻烦。我用的是第二种方法。

注意我这里只在最后保存了一次ckpt,实际应该在训练中每隔一段时间就保存一次的。

saver = tf.train.Saver(max_to_keep=5)
 #tf.train.write_graph(session.graph_def, FLAGS.model_dir, "nn_model.pbtxt", as_text=True)
 
 with tf.Session() as sess:
 sess.run(tf.global_variables_initializer())

 max_step = 2000
 for i in range(max_step):
 batch = mnist.train.next_batch(50)
 if i % 100 == 0:
 train_accuracy = accuracy.eval(feed_dict={
  x: batch[0], y_: batch[1], keep_prob: 1.0})
 print('step %d, training accuracy %g' % (i, train_accuracy))
 train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
 
 print('test accuracy %g' % accuracy.eval(feed_dict={
 x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))
 
 # 保存pb和ckpt
 print('save pb file and ckpt file')
 tf.train.write_graph(sess.graph_def, graph_location, "graph.pb",as_text=False)
 checkpoint_path = os.path.join(graph_location, "model.ckpt")
 saver.save(sess, checkpoint_path, global_step=max_step)

方式2:convert_variables_to_constants

我实际使用的就是这种方法。

看名字也知道,就是把变量转化为常量保存,这样就可以愉快的加载使用了。

注意这里需要指明保存的输出节点,我的输出节点为'out/fc2'(我猜测会根据输出节点的依赖推断哪些部分是训练用到的,推理时用不到)。关于输出节点的名字是有规律的,其中out是一个name_scope名字,fc2是op节点的名字。

with tf.Session() as sess:
 sess.run(tf.global_variables_initializer())

 max_step = 2000
 for i in range(max_step):
 batch = mnist.train.next_batch(50)
 if i % 100 == 0:
 train_accuracy = accuracy.eval(feed_dict={
  x: batch[0], y_: batch[1], keep_prob: 1.0})
 print('step %d, training accuracy %g' % (i, train_accuracy))
 train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
 
 print('test accuracy %g' % accuracy.eval(feed_dict={
 x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

 print('save frozen file')
 pb_path = os.path.join(graph_location, 'frozen_graph.pb')
 print('pb_path:{}'.format(pb_path))

 # 固化模型
 output_graph_def = convert_variables_to_constants(sess, sess.graph_def, output_node_names=['out/fc2'])
 with tf.gfile.FastGFile(pb_path, mode='wb') as f:
 f.write(output_graph_def.SerializeToString())

上述代码会在训练后把训练好的计算图和参数保存到frozen_graph.pb文件。后续就可以用这个模型来测试图片了。

方式2的完整训练和保存模型代码

主要看main函数就行。另外注意deepnn函数最后节点的名字。

"""A deep MNIST classifier using convolutional layers.

See extensive documentation at
https://www.tensorflow.org/get_started/mnist/pros
"""
# Disable linter warnings to maintain consistency with tutorial.
# pylint: disable=invalid-name
# pylint: disable=g-bad-import-order

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys
import tempfile
import os

from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.framework.graph_util import convert_variables_to_constants

import tensorflow as tf
FLAGS = None

def deepnn(x):
 """deepnn builds the graph for a deep net for classifying digits.

 Args:
 x: an input tensor with the dimensions (N_examples, 784), where 784 is the
 number of pixels in a standard MNIST image.

 Returns:
 A tuple (y, keep_prob). y is a tensor of shape (N_examples, 10), with values
 equal to the logits of classifying the digit into one of 10 classes (the
 digits 0-9). keep_prob is a scalar placeholder for the probability of
 dropout.
 """
 # Reshape to use within a convolutional neural net.
 # Last dimension is for "features" - there is only one here, since images are
 # grayscale -- it would be 3 for an RGB image, 4 for RGBA, etc.
 with tf.name_scope('reshape'):
 x_image = tf.reshape(x, [-1, 28, 28, 1])

 # First convolutional layer - maps one grayscale image to 32 feature maps.
 with tf.name_scope('conv1'):
 W_conv1 = weight_variable([5, 5, 1, 32])
 b_conv1 = bias_variable([32])
 h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)

 # Pooling layer - downsamples by 2X.
 with tf.name_scope('pool1'):
 h_pool1 = max_pool_2x2(h_conv1)

 # Second convolutional layer -- maps 32 feature maps to 64.
 with tf.name_scope('conv2'):
 W_conv2 = weight_variable([5, 5, 32, 64])
 b_conv2 = bias_variable([64])
 h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)

 # Second pooling layer.
 with tf.name_scope('pool2'):
 h_pool2 = max_pool_2x2(h_conv2)

 # Fully connected layer 1 -- after 2 round of downsampling, our 28x28 image
 # is down to 7x7x64 feature maps -- maps this to 1024 features.
 with tf.name_scope('fc1'):
 W_fc1 = weight_variable([7 * 7 * 64, 1024])
 b_fc1 = bias_variable([1024])

 h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
 h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

 # Dropout - controls the complexity of the model, prevents co-adaptation of
 # features.
 with tf.name_scope('dropout'):
 keep_prob = tf.placeholder(tf.float32, name='ratio')
 h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

 # Map the 1024 features to 10 classes, one for each digit
 with tf.name_scope('out'):
 W_fc2 = weight_variable([1024, 10])
 b_fc2 = bias_variable([10])

 y_conv = tf.add(tf.matmul(h_fc1_drop, W_fc2), b_fc2, name='fc2')
 return y_conv, keep_prob

def conv2d(x, W):
 """conv2d returns a 2d convolution layer with full stride."""
 return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

def max_pool_2x2(x):
 """max_pool_2x2 downsamples a feature map by 2X."""
 return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
   strides=[1, 2, 2, 1], padding='SAME')

def weight_variable(shape):
 """weight_variable generates a weight variable of a given shape."""
 initial = tf.truncated_normal(shape, stddev=0.1)
 return tf.Variable(initial)

def bias_variable(shape):
 """bias_variable generates a bias variable of a given shape."""
 initial = tf.constant(0.1, shape=shape)
 return tf.Variable(initial)

def main(_):
 # Import data
 mnist = input_data.read_data_sets(FLAGS.data_dir)

 # Create the model
 with tf.name_scope('input'):
 x = tf.placeholder(tf.float32, [None, 784], name='x')

 # Define loss and optimizer
 y_ = tf.placeholder(tf.int64, [None])

 # Build the graph for the deep net
 y_conv, keep_prob = deepnn(x)

 with tf.name_scope('loss'):
 cross_entropy = tf.losses.sparse_softmax_cross_entropy(
 labels=y_, logits=y_conv)
 cross_entropy = tf.reduce_mean(cross_entropy)

 with tf.name_scope('adam_optimizer'):
 train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

 with tf.name_scope('accuracy'):
 correct_prediction = tf.equal(tf.argmax(y_conv, 1), y_)
 correct_prediction = tf.cast(correct_prediction, tf.float32)
 accuracy = tf.reduce_mean(correct_prediction)

 graph_location = './model'
 print('Saving graph to: %s' % graph_location)
 train_writer = tf.summary.FileWriter(graph_location)
 train_writer.add_graph(tf.get_default_graph())

 saver = tf.train.Saver(max_to_keep=5)
 #tf.train.write_graph(session.graph_def, FLAGS.model_dir, "nn_model.pbtxt", as_text=True)
 
 with tf.Session() as sess:
 sess.run(tf.global_variables_initializer())

 max_step = 2000
 for i in range(max_step):
 batch = mnist.train.next_batch(50)
 if i % 100 == 0:
 train_accuracy = accuracy.eval(feed_dict={
  x: batch[0], y_: batch[1], keep_prob: 1.0})
 print('step %d, training accuracy %g' % (i, train_accuracy))
 train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
 
 print('test accuracy %g' % accuracy.eval(feed_dict={
 x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))
 
 # save pb file and ckpt file
 #print('save pb file and ckpt file')
 #tf.train.write_graph(sess.graph_def, graph_location, "graph.pb", as_text=False)
 #checkpoint_path = os.path.join(graph_location, "model.ckpt")
 #saver.save(sess, checkpoint_path, global_step=max_step)

 print('save frozen file')
 pb_path = os.path.join(graph_location, 'frozen_graph.pb')
 print('pb_path:{}'.format(pb_path))

 output_graph_def = convert_variables_to_constants(sess, sess.graph_def, output_node_names=['out/fc2'])
 with tf.gfile.FastGFile(pb_path, mode='wb') as f:
 f.write(output_graph_def.SerializeToString())

if __name__ == '__main__':
 parser = argparse.ArgumentParser()
 parser.add_argument('--data_dir', type=str,
   default='./data',
   help='Directory for storing input data')
 FLAGS, unparsed = parser.parse_known_args()
 tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

加载模型进行推理

上一节已经训练并导出了frozen_graph.pb。

这一节把它运行起来。

加载模型

下方的代码用来加载模型。推理时计算图里共两个placeholder需要填充数据,一个是图片(这不废话吗),一个是drouout_ratio,drouout_ratio用一个常量作为输入,后续就只需要输入图片了。

graph_location = './model'
pb_path = os.path.join(graph_location, 'frozen_graph.pb')
print('pb_path:{}'.format(pb_path))

newInput_X = tf.placeholder(tf.float32, [None, 784], name="X")
drouout_ratio = tf.constant(1., name="drouout")
with open(pb_path, 'rb') as f:
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())

 output = tf.import_graph_def(graph_def,
     input_map={'input/x:0': newInput_X, 'dropout/ratio:0':drouout_ratio},
     return_elements=['out/fc2:0'])

input_map参数并不是必须的。如果不用input_map,可以在run之前用tf.get_default_graph().get_tensor_by_name获取tensor的句柄。但是我觉得这种方法不是很友好,我这里没用这种方法。

注意input_map里的tensor名字是和搭计算图时的name_scope和op名字有关的,而且后面要补一个‘:0'(这点我还没细究)。

同时要注意,newInput_X的形状是[None, 784],第一维是batch大小,推理时和训练要一致。

(我用的是mnist图片,训练时每个bacth的形状是[batchsize, 784],每个图片是28x28)

运行模型

我是一张张图片单独测试的,运行模型之前先把图片变为[1, 784],以符合newInput_X的维数。

with tf.Session( ) as sess:
 file_list = os.listdir(test_image_dir)
 
 # 遍历文件
 for file in file_list:
 full_path = os.path.join(test_image_dir, file)
 print('full_path:{}'.format(full_path))
 
 # 只要黑白的,大小控制在(28,28)
 img = cv2.imread(full_path, cv2.IMREAD_GRAYSCALE )
 res_img = cv2.resize(img,(28,28),interpolation=cv2.INTER_CUBIC) 
 # 变成长784的一维数据
 new_img = res_img.reshape((784))
 
 # 增加一个维度,变为 [1, 784]
 image_np_expanded = np.expand_dims(new_img, axis=0)
 image_np_expanded.astype('float32') # 类型也要满足要求
 print('image_np_expanded shape:{}'.format(image_np_expanded.shape))
 
 # 注意注意,我要调用模型了
 result = sess.run(output, feed_dict={newInput_X: image_np_expanded})
 
 # 出来的结果去掉没用的维度
 result = np.squeeze(result)
 print('result:{}'.format(result))
 #print('result:{}'.format(sess.run(output, feed_dict={newInput_X: image_np_expanded})))
 
 # 输出结果是长度为10(对应0-9)的一维数据,最大值的下标就是预测的数字
 print('result:{}'.format( (np.where(result==np.max(result)))[0][0] ))

注意模型的输出是一个长度为10的一维数组,也就是计算图里全连接的输出。这里没有softmax,只要取最大值的下标即可得到结果。

输出结果:

full_path:./test_images/97_7.jpg
image_np_expanded shape:(1, 784)
result:[-1340.37145996 -283.72436523 1305.03320312 437.6053772 -413.69961548
 -1218.08166504 -1004.83807373 1953.33984375 42.00457001 -504.43829346]
result:7

full_path:./test_images/98_6.jpg
image_np_expanded shape:(1, 784)
result:[ 567.4041748 -550.20904541 623.83496094 -1152.56884766 -217.92695618
 1033.45239258 2496.44750977 -1139.23620605 -5.64091825 -615.28491211]
result:6

full_path:./test_images/99_9.jpg
image_np_expanded shape:(1, 784)
result:[ -532.26409912 -1429.47277832 -368.58096313 505.82876587 358.42163086
 -317.48199463 -1108.6829834 1198.08752441 289.12286377 3083.52539062]
result:9

加载模型进行推理的完整代码

import sys
import os
import cv2
import numpy as np
import tensorflow as tf
test_image_dir = './test_images/'

graph_location = './model'
pb_path = os.path.join(graph_location, 'frozen_graph.pb')
print('pb_path:{}'.format(pb_path))

newInput_X = tf.placeholder(tf.float32, [None, 784], name="X")
drouout_ratio = tf.constant(1., name="drouout")
with open(pb_path, 'rb') as f:
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())
 #output = tf.import_graph_def(graph_def)
 output = tf.import_graph_def(graph_def,
     input_map={'input/x:0': newInput_X, 'dropout/ratio:0':drouout_ratio},
     return_elements=['out/fc2:0'])

with tf.Session( ) as sess:
 file_list = os.listdir(test_image_dir)
 
 # 遍历文件
 for file in file_list:
 full_path = os.path.join(test_image_dir, file)
 print('full_path:{}'.format(full_path))
 
 # 只要黑白的,大小控制在(28,28)
 img = cv2.imread(full_path, cv2.IMREAD_GRAYSCALE )
 res_img = cv2.resize(img,(28,28),interpolation=cv2.INTER_CUBIC) 
 # 变成长784的一维数据
 new_img = res_img.reshape((784))
 
 # 增加一个维度,变为 [1, 784]
 image_np_expanded = np.expand_dims(new_img, axis=0)
 image_np_expanded.astype('float32') # 类型也要满足要求
 print('image_np_expanded shape:{}'.format(image_np_expanded.shape))
 
 # 注意注意,我要调用模型了
 result = sess.run(output, feed_dict={newInput_X: image_np_expanded})
 
 # 出来的结果去掉没用的维度
 result = np.squeeze(result)
 print('result:{}'.format(result))
 #print('result:{}'.format(sess.run(output, feed_dict={newInput_X: image_np_expanded})))
 
 # 输出结果是长度为10(对应0-9)的一维数据,最大值的下标就是预测的数字
 print('result:{}'.format( (np.where(result==np.max(result)))[0][0] ))

以上这篇tensorflow 20:搭网络,导出模型,运行模型的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
把大数据数字口语化(python与js)两种实现
Feb 21 Python
python的文件操作方法汇总
Nov 10 Python
分享一个简单的python读写文件脚本
Nov 25 Python
Python基于贪心算法解决背包问题示例
Nov 27 Python
Python实现按当前日期(年、月、日)创建多级目录的方法
Apr 26 Python
python使用turtle绘制分形树
Jun 22 Python
Python实现将多个空格换为一个空格.md的方法
Dec 20 Python
详解python列表生成式和列表生成式器区别
Mar 27 Python
用Cython加速Python到“起飞”(推荐)
Aug 01 Python
深度学习入门之Pytorch 数据增强的实现
Feb 26 Python
如何让python的运行速度得到提升
Jul 08 Python
用Python创建简易网站图文教程
Jun 11 Python
Python自定义聚合函数merge与transform区别详解
May 26 #Python
Python Tornado实现WEB服务器Socket服务器共存并实现交互的方法
May 26 #Python
tensorflow实现从.ckpt文件中读取任意变量
May 26 #Python
打印tensorflow恢复模型中所有变量与操作节点方式
May 26 #Python
tensorflow模型的save与restore,及checkpoint中读取变量方式
May 26 #Python
tensorflow从ckpt和从.pb文件读取变量的值方式
May 26 #Python
Pytorch转keras的有效方法,以FlowNet为例讲解
May 26 #Python
You might like
PHP与javascript的两种交互方式
2006/10/09 PHP
PHP mcrypt可逆加密算法分析
2011/07/19 PHP
比较简单实用的PHP无限分类源码分享(思路不错)
2011/10/13 PHP
PHP在网页中动态生成PDF文件详细教程
2014/07/05 PHP
如何在HTML 中嵌入 PHP 代码
2015/05/13 PHP
php+curl 发送图片处理代码分享
2015/07/09 PHP
PHP中filter函数校验数据的方法详解
2015/07/31 PHP
php实现概率性随机抽奖代码
2016/01/02 PHP
Javascript Object.extend
2010/05/18 Javascript
JavaScript中的函数嵌套使用
2015/06/04 Javascript
jquery实现表单验证并阻止非法提交
2015/07/09 Javascript
JavaScript页面实时显示当前时间实例代码
2016/10/23 Javascript
微信小程序 购物车简单实例
2016/10/24 Javascript
JS绘制微信小程序画布时钟
2016/12/24 Javascript
bootstrap手风琴制作方法详解
2017/01/11 Javascript
简单实现js悬浮导航效果
2017/02/05 Javascript
Bootstrap table使用方法总结
2017/05/10 Javascript
vue中appear的用法
2017/08/17 Javascript
JS中关于正则的巧妙操作
2017/08/31 Javascript
vue以组件或者插件的形式实现throttle或者debounce
2019/05/22 Javascript
es6 super关键字的理解与应用实例分析
2020/02/15 Javascript
python随机生成指定长度密码的方法
2015/04/04 Python
python爬取淘宝商品详情页数据
2018/02/23 Python
Python格式化输出字符串方法小结【%与format】
2018/10/29 Python
Django 创建/删除用户的示例代码
2019/07/24 Python
Python综合应用名片管理系统案例详解
2020/01/03 Python
Python读取分割压缩TXT文本文件实例
2020/02/14 Python
windows10环境下用anaconda和VScode配置的图文教程
2020/03/30 Python
利用python制作拼图小游戏的全过程
2020/12/04 Python
CSS3 rgb and rgba(透明色)的使用详解
2020/09/25 HTML / CSS
请说出这段代码执行后a和b的值分别是多少
2015/03/28 面试题
护校行动方案
2014/05/31 职场文书
《惊弓之鸟》教学反思
2016/02/20 职场文书
教你如何让spark sql写mysql的时候支持update操作
2022/02/15 MySQL
再谈python_tkinter弹出对话框创建
2022/03/20 Python
td 内容自动换行 table表格td设置宽度后文字太多自动换行
2022/12/24 HTML / CSS