TensorFlow 20: A worked example of building a network, exporting the model, and running it


Posted in Python on May 26, 2020

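Overview

Until now I mostly took projects that other people had already built and adapted them. I had rarely gone through the whole cycle of building a model, exporting it, and loading it back to run. Doing it end to end showed me that it is not as simple as it looks.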

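Building and exporting the model

According to 《TensorFlow固化模型》 (Freezing TensorFlow Models), there are two ways to export a frozen model.

Method 1: export the graph structure as a pb file plus a ckpt file, then use the freeze_graph tool to produce a single frozen pb (structure and parameters together)

In my code I tested generating the pb graph structure and the ckpt file, but I did not take it any further because it felt like a hassle. I use the second method instead.

Note that I save the ckpt only once, at the very end. In real training you should save one at regular intervals.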

saver = tf.train.Saver(max_to_keep=5)
# tf.train.write_graph(session.graph_def, FLAGS.model_dir, "nn_model.pbtxt", as_text=True)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    max_step = 2000
    for i in range(max_step):
        batch = mnist.train.next_batch(50)
        if i % 100 == 0:
            train_accuracy = accuracy.eval(feed_dict={
                x: batch[0], y_: batch[1], keep_prob: 1.0})
            print('step %d, training accuracy %g' % (i, train_accuracy))
        train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})

    print('test accuracy %g' % accuracy.eval(feed_dict={
        x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

    # Save the graph structure (pb) and the weights (ckpt)
    print('save pb file and ckpt file')
    tf.train.write_graph(sess.graph_def, graph_location, "graph.pb", as_text=False)
    checkpoint_path = os.path.join(graph_location, "model.ckpt")
    saver.save(sess, checkpoint_path, global_step=max_step)
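
As noted before the listing, a real training run would save a checkpoint every so often rather than once at the end. The sketch below does not call TensorFlow at all. It only simulates, in plain Python, which checkpoint files would remain on disk if the loop saved every `save_every` steps with `max_to_keep=5`. The `save_every` value and the helper name `surviving_checkpoints` are assumptions made for illustration; the `model.ckpt-<step>` naming follows what `saver.save(..., global_step=step)` produces.

```python
from collections import deque


def surviving_checkpoints(max_step, save_every, max_to_keep=5, prefix="model.ckpt"):
    """Simulate which checkpoint prefixes remain after periodic saving.

    Mirrors the idea behind tf.train.Saver(max_to_keep=...): only the most
    recent `max_to_keep` checkpoints are retained, older ones are dropped.
    """
    kept = deque(maxlen=max_to_keep)  # oldest entries fall off automatically
    for step in range(1, max_step + 1):
        # Save every `save_every` steps, and always at the final step.
        if step % save_every == 0 or step == max_step:
            kept.append("{}-{}".format(prefix, step))
    return list(kept)


if __name__ == "__main__":
    # Hypothetical schedule: 2000 steps, one checkpoint every 250 steps.
    for name in surviving_checkpoints(max_step=2000, save_every=250):
        print(name)
```

In the training loop itself, the matching change would be to call saver.save(sess, checkpoint_path, global_step=i) inside the for loop under a condition such as i % save_every == 0.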

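Method 2: convert_variables_to_constants

This is the method I actually use.

As the name suggests, it converts the variables into constants before saving, so the result can be loaded and used directly.

Note that you must specify which output nodes to keep; mine is 'out/fc2'. (My guess is that it follows the dependencies of the output node to work out which parts are needed only for training and not for inference.) Output node names follow a pattern: out is the name of a name_scope, and fc2 is the name of the op.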

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    max_step = 2000
    for i in range(max_step):
        batch = mnist.train.next_batch(50)
        if i % 100 == 0:
            train_accuracy = accuracy.eval(feed_dict={
                x: batch[0], y_: batch[1], keep_prob: 1.0})
            print('step %d, training accuracy %g' % (i, train_accuracy))
        train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})

    print('test accuracy %g' % accuracy.eval(feed_dict={
        x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

    print('save frozen file')
    pb_path = os.path.join(graph_location, 'frozen_graph.pb')
    print('pb_path:{}'.format(pb_path))

    # Freeze the model: replace variables with constants, keeping only
    # what the output node 'out/fc2' depends on
    output_graph_def = convert_variables_to_constants(sess, sess.graph_def, output_node_names=['out/fc2'])
    with tf.gfile.FastGFile(pb_path, mode='wb') as f:
        f.write(output_graph_def.SerializeToString())

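The code above saves the trained graph and its parameters to frozen_graph.pb once training finishes. That file can then be used to classify images.

Complete training and export code for method 2

The main function is the part to read. Also note the name of the last node in the deepnn function.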
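
The role of the output node can be illustrated without TensorFlow. To my understanding, freezing keeps only the nodes that the requested output depends on. The sketch below does the same kind of backward traversal on a toy dependency table. The table is a hand-written, heavily simplified stand-in for the real graph (the node names follow the name_scopes used in this article, but the edges and the `labels` node are my own approximation), and `reachable_from` is a hypothetical helper, not a TensorFlow function.

```python
from collections import deque

# Toy dependency table: node -> nodes it takes as input (an approximation).
TOY_GRAPH = {
    "input/x": [],
    "dropout/ratio": [],
    "labels": [],
    "conv1": ["input/x"],
    "conv2": ["conv1"],
    "fc1": ["conv2"],
    "dropout": ["fc1", "dropout/ratio"],
    "out/fc2": ["dropout"],
    "loss": ["out/fc2", "labels"],
    "adam_optimizer": ["loss"],
    "accuracy": ["out/fc2", "labels"],
}


def reachable_from(graph, output_nodes):
    """Return the set of nodes that the given outputs depend on."""
    seen = set(output_nodes)
    queue = deque(output_nodes)
    while queue:
        node = queue.popleft()
        for parent in graph[node]:
            if parent not in seen:
                seen.add(parent)
                queue.append(parent)
    return seen


if __name__ == "__main__":
    kept = reachable_from(TOY_GRAPH, ["out/fc2"])
    print("kept:   ", sorted(kept))
    print("dropped:", sorted(set(TOY_GRAPH) - kept))
```

In this toy graph the loss, optimizer, accuracy, and label nodes are dropped, while both input/x and dropout/ratio survive, which is consistent with inference still needing a value for each of them.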

"""A deep MNIST classifier using convolutional layers.

See extensive documentation at
https://www.tensorflow.org/get_started/mnist/pros
"""
# Disable linter warnings to maintain consistency with tutorial.
# pylint: disable=invalid-name
# pylint: disable=g-bad-import-order

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys
import tempfile
import os

from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.framework.graph_util import convert_variables_to_constants

import tensorflow as tf
FLAGS = None

def deepnn(x):
    """deepnn builds the graph for a deep net for classifying digits.

    Args:
      x: an input tensor with the dimensions (N_examples, 784), where 784 is the
        number of pixels in a standard MNIST image.

    Returns:
      A tuple (y, keep_prob). y is a tensor of shape (N_examples, 10), with values
      equal to the logits of classifying the digit into one of 10 classes (the
      digits 0-9). keep_prob is a scalar placeholder for the probability of
      dropout.
    """
    # Reshape to use within a convolutional neural net.
    # Last dimension is for "features" - there is only one here, since images are
    # grayscale -- it would be 3 for an RGB image, 4 for RGBA, etc.
    with tf.name_scope('reshape'):
        x_image = tf.reshape(x, [-1, 28, 28, 1])

    # First convolutional layer - maps one grayscale image to 32 feature maps.
    with tf.name_scope('conv1'):
        W_conv1 = weight_variable([5, 5, 1, 32])
        b_conv1 = bias_variable([32])
        h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)

    # Pooling layer - downsamples by 2X.
    with tf.name_scope('pool1'):
        h_pool1 = max_pool_2x2(h_conv1)

    # Second convolutional layer -- maps 32 feature maps to 64.
    with tf.name_scope('conv2'):
        W_conv2 = weight_variable([5, 5, 32, 64])
        b_conv2 = bias_variable([64])
        h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)

    # Second pooling layer.
    with tf.name_scope('pool2'):
        h_pool2 = max_pool_2x2(h_conv2)

    # Fully connected layer 1 -- after 2 rounds of downsampling, our 28x28 image
    # is down to 7x7x64 feature maps -- maps this to 1024 features.
    with tf.name_scope('fc1'):
        W_fc1 = weight_variable([7 * 7 * 64, 1024])
        b_fc1 = bias_variable([1024])

        h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
        h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

    # Dropout - controls the complexity of the model, prevents co-adaptation of
    # features.
    with tf.name_scope('dropout'):
        keep_prob = tf.placeholder(tf.float32, name='ratio')
        h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

    # Map the 1024 features to 10 classes, one for each digit
    with tf.name_scope('out'):
        W_fc2 = weight_variable([1024, 10])
        b_fc2 = bias_variable([10])

        y_conv = tf.add(tf.matmul(h_fc1_drop, W_fc2), b_fc2, name='fc2')
    return y_conv, keep_prob

def conv2d(x, W):
    """conv2d returns a 2d convolution layer with full stride."""
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

def max_pool_2x2(x):
    """max_pool_2x2 downsamples a feature map by 2X."""
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
                          strides=[1, 2, 2, 1], padding='SAME')

def weight_variable(shape):
    """weight_variable generates a weight variable of a given shape."""
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)

def bias_variable(shape):
    """bias_variable generates a bias variable of a given shape."""
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)

def main(_):
    # Import data
    mnist = input_data.read_data_sets(FLAGS.data_dir)

    # Create the model
    with tf.name_scope('input'):
        x = tf.placeholder(tf.float32, [None, 784], name='x')

    # Define loss and optimizer
    y_ = tf.placeholder(tf.int64, [None])

    # Build the graph for the deep net
    y_conv, keep_prob = deepnn(x)

    with tf.name_scope('loss'):
        cross_entropy = tf.losses.sparse_softmax_cross_entropy(
            labels=y_, logits=y_conv)
        cross_entropy = tf.reduce_mean(cross_entropy)

    with tf.name_scope('adam_optimizer'):
        train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

    with tf.name_scope('accuracy'):
        correct_prediction = tf.equal(tf.argmax(y_conv, 1), y_)
        correct_prediction = tf.cast(correct_prediction, tf.float32)
        accuracy = tf.reduce_mean(correct_prediction)

    graph_location = './model'
    print('Saving graph to: %s' % graph_location)
    train_writer = tf.summary.FileWriter(graph_location)
    train_writer.add_graph(tf.get_default_graph())

    saver = tf.train.Saver(max_to_keep=5)
    # tf.train.write_graph(session.graph_def, FLAGS.model_dir, "nn_model.pbtxt", as_text=True)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        max_step = 2000
        for i in range(max_step):
            batch = mnist.train.next_batch(50)
            if i % 100 == 0:
                train_accuracy = accuracy.eval(feed_dict={
                    x: batch[0], y_: batch[1], keep_prob: 1.0})
                print('step %d, training accuracy %g' % (i, train_accuracy))
            train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})

        print('test accuracy %g' % accuracy.eval(feed_dict={
            x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

        # save pb file and ckpt file
        # print('save pb file and ckpt file')
        # tf.train.write_graph(sess.graph_def, graph_location, "graph.pb", as_text=False)
        # checkpoint_path = os.path.join(graph_location, "model.ckpt")
        # saver.save(sess, checkpoint_path, global_step=max_step)

        print('save frozen file')
        pb_path = os.path.join(graph_location, 'frozen_graph.pb')
        print('pb_path:{}'.format(pb_path))

        output_graph_def = convert_variables_to_constants(sess, sess.graph_def, output_node_names=['out/fc2'])
        with tf.gfile.FastGFile(pb_path, mode='wb') as f:
            f.write(output_graph_def.SerializeToString())

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str,
                        default='./data',
                        help='Directory for storing input data')
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

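Loading the model for inference

The previous section trained the model and exported frozen_graph.pb.

This section runs it.

Loading the model

The code below loads the model. At inference time the graph has two placeholders that need data: the image (obviously) and the dropout ratio (the dropout/ratio placeholder). I bind the latter to a constant, drouout_ratio, so afterwards only the image has to be supplied.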

graph_location = './model'
pb_path = os.path.join(graph_location, 'frozen_graph.pb')
print('pb_path:{}'.format(pb_path))

newInput_X = tf.placeholder(tf.float32, [None, 784], name="X")
drouout_ratio = tf.constant(1., name="drouout")
with open(pb_path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

    output = tf.import_graph_def(graph_def,
                                 input_map={'input/x:0': newInput_X, 'dropout/ratio:0': drouout_ratio},
                                 return_elements=['out/fc2:0'])

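The input_map argument is not required. Without it, you can fetch the tensor handles with tf.get_default_graph().get_tensor_by_name before calling run. I find that approach less convenient, so I do not use it here.

Note that the tensor names in input_map come from the name_scope and op names used when the graph was built, followed by ':0'. The suffix is the index of the op's output, so 'input/x:0' means the first output of the op 'input/x'.

Also note that newInput_X has shape [None, 784]. The first dimension is the batch size; it is left as None, so it may vary between calls. The layout of the remaining dimension (784 values per image) must match what was used in training.

(I use MNIST images; during training each batch has shape [batchsize, 784], and each image is 28x28.)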
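
The naming rule can be spelled out in a few lines of plain Python. A tensor name has the form `<op name>:<output index>`, and the op name itself is the name_scope path followed by the op's own name. `split_tensor_name` below is a hypothetical helper written for this article, not part of TensorFlow.

```python
def split_tensor_name(tensor_name):
    """Split 'scope/op:index' into (scope, op, index).

    The scope may be empty (op at top level) or nested ('a/b').
    """
    op_path, sep, index = tensor_name.rpartition(":")
    if not sep or not index.isdigit():
        raise ValueError("expected '<op name>:<output index>', got %r" % tensor_name)
    scope, _, op = op_path.rpartition("/")
    return scope, op, int(index)


if __name__ == "__main__":
    # The three tensor names used in this article.
    for name in ("input/x:0", "dropout/ratio:0", "out/fc2:0"):
        print(name, "->", split_tensor_name(name))
```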

Running the model

I test the images one at a time. Before running the model, each image is reshaped to [1, 784] to match the dimensions of newInput_X.

with tf.Session() as sess:
    file_list = os.listdir(test_image_dir)

    # Iterate over the files
    for file in file_list:
        full_path = os.path.join(test_image_dir, file)
        print('full_path:{}'.format(full_path))

        # Read as grayscale and resize to (28, 28)
        img = cv2.imread(full_path, cv2.IMREAD_GRAYSCALE)
        res_img = cv2.resize(img, (28, 28), interpolation=cv2.INTER_CUBIC)
        # Flatten to a 1-D array of length 784
        new_img = res_img.reshape((784))

        # Add a leading dimension to get [1, 784]
        image_np_expanded = np.expand_dims(new_img, axis=0)
        # The dtype must match too; astype returns a new array, so assign it
        image_np_expanded = image_np_expanded.astype('float32')
        print('image_np_expanded shape:{}'.format(image_np_expanded.shape))

        # Run the model
        result = sess.run(output, feed_dict={newInput_X: image_np_expanded})

        # Drop the singleton dimensions from the result
        result = np.squeeze(result)
        print('result:{}'.format(result))
        #print('result:{}'.format(sess.run(output, feed_dict={newInput_X: image_np_expanded})))

        # The output is a 1-D array of length 10 (digits 0-9);
        # the index of the maximum is the predicted digit
        print('result:{}'.format((np.where(result == np.max(result)))[0][0]))
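
The preprocessing steps in the loop can be pulled into a small function and checked on their own. The sketch below uses only NumPy and a synthetic 28x28 array in place of a file read with cv2. It also scales pixels to [0, 1]: as far as I know, input_data.read_data_sets does this for the training data by default, whereas the loop in this article feeds raw 0-255 values, so treat the scaling as an assumption to verify against your own training pipeline. `to_model_input` is a hypothetical helper name.

```python
import numpy as np


def to_model_input(gray_img, scale=True):
    """Turn a 28x28 uint8 grayscale image into a float32 array of shape [1, 784]."""
    if gray_img.shape != (28, 28):
        raise ValueError("expected a 28x28 image, got {}".format(gray_img.shape))
    # astype returns a new array, so the converted result is what we keep.
    flat = gray_img.reshape(1, 784).astype(np.float32)
    if scale:
        flat = flat / np.float32(255.0)  # map 0..255 to 0..1 (assumed training range)
    return flat


if __name__ == "__main__":
    # Synthetic stand-in for cv2.imread(..., cv2.IMREAD_GRAYSCALE) + cv2.resize.
    fake = np.zeros((28, 28), dtype=np.uint8)
    fake[10:18, 12:16] = 255
    batch = to_model_input(fake)
    print(batch.shape, batch.dtype, batch.min(), batch.max())
```

With scaling enabled the logits would differ from the ones printed in this article, which were produced from unscaled input.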

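Note that the model's output is a 1-D array of length 10, i.e. the output of the last fully connected layer in the graph. There is no softmax here; taking the index of the maximum is enough to get the prediction.

Output: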

full_path:./test_images/97_7.jpg
image_np_expanded shape:(1, 784)
result:[-1340.37145996 -283.72436523 1305.03320312 437.6053772 -413.69961548
 -1218.08166504 -1004.83807373 1953.33984375 42.00457001 -504.43829346]
result:7

full_path:./test_images/98_6.jpg
image_np_expanded shape:(1, 784)
result:[ 567.4041748 -550.20904541 623.83496094 -1152.56884766 -217.92695618
 1033.45239258 2496.44750977 -1139.23620605 -5.64091825 -615.28491211]
result:6

full_path:./test_images/99_9.jpg
image_np_expanded shape:(1, 784)
result:[ -532.26409912 -1429.47277832 -368.58096313 505.82876587 358.42163086
 -317.48199463 -1108.6829834 1198.08752441 289.12286377 3083.52539062]
result:9
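
Why skipping softmax is safe: softmax is monotonic, so it never changes which entry is largest. The sketch below checks this in plain Python on the first logits vector printed above (the one for 97_7.jpg). The `softmax` and `argmax` helpers are my own; `softmax` subtracts the maximum first, because exponentiating values in the thousands directly would overflow.

```python
import math

# Logits copied from the first result printed above (97_7.jpg).
LOGITS = [-1340.37145996, -283.72436523, 1305.03320312, 437.6053772,
          -413.69961548, -1218.08166504, -1004.83807373, 1953.33984375,
          42.00457001, -504.43829346]


def softmax(values):
    """Numerically stable softmax: shift by the max before exponentiating."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values):
    """Index of the largest entry."""
    return max(range(len(values)), key=lambda i: values[i])


if __name__ == "__main__":
    probs = softmax(LOGITS)
    print("argmax of logits: ", argmax(LOGITS))
    print("argmax of softmax:", argmax(probs))
```

Both calls select index 7, matching the prediction shown for that image.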

Complete code for loading the model and running inference

import sys
import os
import cv2
import numpy as np
import tensorflow as tf
test_image_dir = './test_images/'

graph_location = './model'
pb_path = os.path.join(graph_location, 'frozen_graph.pb')
print('pb_path:{}'.format(pb_path))

newInput_X = tf.placeholder(tf.float32, [None, 784], name="X")
drouout_ratio = tf.constant(1., name="drouout")
with open(pb_path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    #output = tf.import_graph_def(graph_def)
    output = tf.import_graph_def(graph_def,
                                 input_map={'input/x:0': newInput_X, 'dropout/ratio:0': drouout_ratio},
                                 return_elements=['out/fc2:0'])

with tf.Session() as sess:
    file_list = os.listdir(test_image_dir)

    # Iterate over the files
    for file in file_list:
        full_path = os.path.join(test_image_dir, file)
        print('full_path:{}'.format(full_path))

        # Read as grayscale and resize to (28, 28)
        img = cv2.imread(full_path, cv2.IMREAD_GRAYSCALE)
        res_img = cv2.resize(img, (28, 28), interpolation=cv2.INTER_CUBIC)
        # Flatten to a 1-D array of length 784
        new_img = res_img.reshape((784))

        # Add a leading dimension to get [1, 784]
        image_np_expanded = np.expand_dims(new_img, axis=0)
        # The dtype must match too; astype returns a new array, so assign it
        image_np_expanded = image_np_expanded.astype('float32')
        print('image_np_expanded shape:{}'.format(image_np_expanded.shape))

        # Run the model
        result = sess.run(output, feed_dict={newInput_X: image_np_expanded})

        # Drop the singleton dimensions from the result
        result = np.squeeze(result)
        print('result:{}'.format(result))
        #print('result:{}'.format(sess.run(output, feed_dict={newInput_X: image_np_expanded})))

        # The output is a 1-D array of length 10 (digits 0-9);
        # the index of the maximum is the predicted digit
        print('result:{}'.format((np.where(result == np.max(result)))[0][0]))
