使用Keras训练好的.h5模型来测试一个实例


Posted in Python onJuly 06, 2020

环境:python 3.6 +opencv3+Keras

训练集:MNIST

下面划重点:因为MNIST使用的是黑底白字的图片,所以你自己手写数字的时候一定要注意把得到的图片也改成黑底白字的,否则会识别错(至少我得到的结论是这样的 ,之前用白底黑字的图总是识别出错)

注意:需要测试图片需要为与训练模时相同大小的图片,RGB图像需转为gray

代码:

import cv2
import numpy as np
from keras.models import load_model

model = load_model('fm_cnn_BN.h5') #选取自己的.h模型名称
image = cv2.imread('6_b.png')
img = cv2.cvtColor(image,cv2.COLOR_RGB2GRAY) # RGB图像转为gray

#需要用reshape定义出例子的个数,图片的 通道数,图片的长与宽。具体的参加keras文档
img = (img.reshape(1, 1, 28, 28)).astype('int32')/255 
predict = model.predict_classes(img)
print ('识别为:')
print (predict)

cv2.imshow("Image1", image)
cv2.waitKey(0)

补充知识:keras转tf并加速(1)Keras转TensorFlow,并调用转换后模型进行预测

由于方便快捷,所以先使用Keras来搭建网络并进行训练,得到比较好的模型后,这时候就该考虑做成服务使用的问题了,TensorFlow的serving就很合适,所以需要把Keras保存的模型转为TensorFlow格式来使用。

Keras模型转TensorFlow

其实由于TensorFlow本身以及把Keras作为其高层简化API,且也是建议由浅入深地来研究应用,TensorFlow本身就对Keras的模型格式转化有支持,所以核心的代码很少。这里给出一份代码:https://github.com/amir-abdi/keras_to_tensorflow,作者提供了一份很好的工具,能够满足绝大多数人的需求了。原理很简单:原理很简单,首先用 Keras 读取 .h5 模型文件,然后用 tensorflow 的 convert_variables_to_constants 函数将所有变量转换成常量,最后再 write_graph 就是一个包含了网络以及参数值的 .pb 文件了。

如果你的Keras模型是一个包含了网络结构和权重的h5文件,那么使用下面的命令就可以了:

python keras_to_tensorflow.py 
 --input_model="path/to/keras/model.h5" 
 --output_model="path/to/save/model.pb"

两个参数,一个输入路径,一个输出路径。输出路径即使你没创建好,代码也会帮你创建。建议使用绝对地址。此外作者还做了很多选项,比如如果你的keras模型文件分为网络结构和权重两个文件也可以支持,或者你想给转化后的网络节点编号,或者想在TensorFlow下继续训练等等,这份代码都是支持的,只是使用上需要输入不同的参数来设置。

如果转换成功则输出如下:

begin====================================================
I1229 14:29:44.819010 140709034264384 keras_to_tf.py:119] Input nodes names are: [u'input_1']
I1229 14:29:44.819385 140709034264384 keras_to_tf.py:137] Converted output node names are: [u'dense_2/Sigmoid']
INFO:tensorflow:Froze 322 variables.
I1229 14:29:47.091161 140709034264384 tf_logging.py:82] Froze 322 variables.
Converted 322 variables to const ops.
I1229 14:29:48.504235 140709034264384 keras_to_tf.py:170] Saved the freezed graph at /path/to/save/model.pb

这里首先把输入的层和输出的层名字给出来了,也就是“input_1”和“dense_2/Sigmoid”,这两个下面会用到。另外还告诉你冻结了多少个变量,以及你输出的模型路径,pb文件就是TensorFlow下的模型文件。

使用TensorFlow模型

转换后我们当然要使用一下看是否转换成功,其实也就是TensorFlow的常见代码,如果只用过Keras的,可以参考一下:

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np
from tensorflow.python.platform import gfile
import cv2
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "6"
 
# img = cv2.imread(os.path.expanduser('/test_imgs/img_1.png'))
# img = cv2.resize(img, dsize=(1000, 1000), interpolation=cv2.INTER_LINEAR)
# img = img.astype(float)
# img /= 255
# img = np.array([img])
 
# 初始化TensorFlow的session
with tf.Session() as sess:
 # 读取得到的pb文件加载模型
 with gfile.FastGFile("/path/to/save/model.pb",'rb') as f:
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())
 # 把图加到session中
 tf.import_graph_def(graph_def, name='')
 
 # 获取当前计算图
 graph = tf.get_default_graph()
 
 # 从图中获输出那一层
 pred = graph.get_tensor_by_name("dense_2/Sigmoid:0")
 
 # 运行并预测输入的img
 res = sess.run(pred, feed_dict={"input_1:0": img})
 
 # 执行得到结果
 pred_index = res[0][0]
 print('Predict:', pred_index)

在代码中可以看到,我们用到了上面得到的输入层和输出层的名称,但是在后面加了一个“:0”,也就是索引,因为名称只是指定了一个层,大部分层的输出都是一个tensor,但依然有输出多个tensor的层,所以需要制定是第几个输出,对于一个输出的情况,那就是索引0了。输入同理。

如果你输出res,会得到这样的结果:

('Predict:', array([[0.9998584]], dtype=float32))

这也就是为什么我们要取res[0][0]了,这个输出其实取决于具体的需求,因为这里我是对一张图做二分类预测,所以会得到这样一个结果

运行的结果如果和使用Keras模型时一样,那就说明转换成功了!

以上这篇使用Keras训练好的.h5模型来测试一个实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
wxpython中利用线程防止假死的实现方法
Aug 11 Python
使用Python压缩和解压缩zip文件的教程
May 06 Python
Python urllib、urllib2、httplib抓取网页代码实例
May 09 Python
Python语法快速入门指南
Oct 12 Python
mac系统安装Python3初体验
Jan 02 Python
对python中词典的values值的修改或新增KEY详解
Jan 20 Python
python nmap实现端口扫描器教程
May 28 Python
Python字典生成式、集合生成式、生成器用法实例分析
Jan 07 Python
利用pytorch实现对CIFAR-10数据集的分类
Jan 14 Python
python解释器pycharm安装及环境变量配置教程图文详解
Feb 26 Python
python 可视化库PyG2Plot的使用
Jan 21 Python
python编程简单几行代码实现视频转换Gif示例
Oct 05 Python
Keras实现DenseNet结构操作
Jul 06 #Python
基于Python和C++实现删除链表的节点
Jul 06 #Python
基于Python 的语音重采样函数解析
Jul 06 #Python
python interpolate插值实例
Jul 06 #Python
基于Python实现2种反转链表方法代码实例
Jul 06 #Python
简单了解Django项目应用创建过程
Jul 06 #Python
如何在mac下配置python虚拟环境
Jul 06 #Python
You might like
《超神学院》霸气归来, 天使彦上演维多利亚的秘密
2020/03/02 国漫
消息持续发送的完整例子
2006/10/09 PHP
phpmyadmin 3.4 空密码登录的实现方法
2010/05/29 PHP
php基础学习之变量的使用
2011/06/09 PHP
PHP实现过滤各种HTML标签
2015/05/17 PHP
详解yii2实现分库分表的方案与思路
2017/02/03 PHP
php + WebUploader实现图片批量上传功能
2019/05/06 PHP
解决php extension 加载顺序问题
2019/08/16 PHP
IE不出现Flash激活框的小发现的js实现方法
2007/09/07 Javascript
js获取单元格自定义属性值的代码(IE/Firefox)
2010/04/05 Javascript
jquery中获取元素的几种方式小结
2011/07/05 Javascript
Js冒泡事件详解及阻止示例
2014/03/21 Javascript
Angular用来控制元素的展示与否的原生指令介绍
2015/01/07 Javascript
Select下拉框模糊查询功能实现代码
2016/07/22 Javascript
javascript使用闭包模拟对象的私有属性和方法
2016/10/05 Javascript
angularjs实现首页轮播图效果
2017/04/14 Javascript
微信小程序中显示html格式内容的方法
2017/04/25 Javascript
js实现1,2,3,5数字按照概率生成
2017/09/12 Javascript
判断滚动条滑到底部触发事件(实例讲解)
2017/11/15 Javascript
[01:23:35]Ti4主赛事胜者组 DK vs EG 1
2014/07/19 DOTA
[53:15]2018DOTA2亚洲邀请赛3月29日 小组赛A组 LGD VS TNC
2018/03/30 DOTA
[46:28]EG vs Liquid 2019国际邀请赛淘汰赛 败者组 BO3 第二场 8.23
2019/09/05 DOTA
[51:50]完美世界DOTA2联赛 Magma vs GXR 第一场 11.07
2020/11/10 DOTA
Python+Pika+RabbitMQ环境部署及实现工作队列的实例教程
2016/06/29 Python
基于python 处理中文路径的终极解决方法
2018/04/12 Python
对python中的for循环和range内置函数详解
2018/04/17 Python
python自定义时钟类、定时任务类
2021/02/22 Python
django 多对多表的创建和插入代码实现
2019/09/09 Python
python批量处理txt文件的实例代码
2020/01/13 Python
Python urlopen()参数代码示例解析
2020/12/10 Python
英国时尚饰品和发饰购物网站:Claire’s
2017/07/04 全球购物
斯凯奇新西兰官网:SKECHERS新西兰
2018/02/22 全球购物
办公设备采购方案
2014/03/16 职场文书
售房协议书范本
2015/08/11 职场文书
图解上海144收音机
2021/04/22 无线电
自动在Windows中运行Python脚本并定时触发功能实现
2021/09/04 Python