python tensorflow学习之识别单张图片的实现的示例


Posted in Python onFebruary 09, 2018

假设我们已经安装好了tensorflow。

一般在安装好tensorflow后,都会跑它的demo,而最常见的demo就是手写数字识别的demo,也就是mnist数据集。

然而我们仅仅是跑了它的demo而已,可能很多人会有和我一样的想法,如果拿来一张数字图片,如何应用我们训练的网络模型来识别出来,下面我们就以mnist的demo来实现它。

1.训练模型

首先我们要训练好模型,并且把模型model.ckpt保存到指定文件夹

saver = tf.train.Saver()   
saver.save(sess, "model_data/model.ckpt")

将以上两行代码加入到训练的代码中,训练完成后保存模型即可,如果这部分有问题,你可以百度查阅资料,tensorflow怎么保存训练模型,在这里我们就不罗嗦了。

2.测试模型

我们训练好模型后,将它保存在了model_data文件夹中,你会发现文件夹中出现了4个文件

python tensorflow学习之识别单张图片的实现的示例

然后,我们就可以对这个模型进行测试了,将待检测图片放在images文件夹下,执行

# -*- coding:utf-8 -*-  
import cv2 
import tensorflow as tf 
import numpy as np 
from sys import path 
path.append('../..') 
from common import extract_mnist 
 
#初始化单个卷积核上的参数 
def weight_variable(shape): 
  initial = tf.truncated_normal(shape, stddev=0.1) 
  return tf.Variable(initial) 
 
#初始化单个卷积核上的偏置值 
def bias_variable(shape): 
  initial = tf.constant(0.1, shape=shape) 
  return tf.Variable(initial) 
 
#输入特征x,用卷积核W进行卷积运算,strides为卷积核移动步长, 
#padding表示是否需要补齐边缘像素使输出图像大小不变 
def conv2d(x, W): 
  return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 
 
#对x进行最大池化操作,ksize进行池化的范围, 
def max_pool_2x2(x): 
  return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],strides=[1, 2, 2, 1], padding='SAME') 
 
 
def main(): 
   
  #定义会话 
  sess = tf.InteractiveSession() 
   
  #声明输入图片数据,类别 
  x = tf.placeholder('float',[None,784]) 
  x_img = tf.reshape(x , [-1,28,28,1]) 
 
  W_conv1 = weight_variable([5, 5, 1, 32]) 
  b_conv1 = bias_variable([32]) 
  W_conv2 = weight_variable([5,5,32,64]) 
  b_conv2 = bias_variable([64]) 
  W_fc1 = weight_variable([7*7*64,1024]) 
  b_fc1 = bias_variable([1024]) 
  W_fc2 = weight_variable([1024,10]) 
  b_fc2 = bias_variable([10]) 
 
  saver = tf.train.Saver(write_version=tf.train.SaverDef.V1)  
  saver.restore(sess , 'model_data/model.ckpt') 
 
  #进行卷积操作,并添加relu激活函数 
  h_conv1 = tf.nn.relu(conv2d(x_img,W_conv1) + b_conv1) 
  #进行最大池化 
  h_pool1 = max_pool_2x2(h_conv1) 
 
  #同理第二层卷积层 
  h_conv2 = tf.nn.relu(conv2d(h_pool1,W_conv2) + b_conv2) 
  h_pool2 = max_pool_2x2(h_conv2) 
   
  #将卷积的产出展开 
  h_pool2_flat = tf.reshape(h_pool2,[-1,7*7*64]) 
  #神经网络计算,并添加relu激活函数 
  h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat,W_fc1) + b_fc1) 
 
  #输出层,使用softmax进行多分类 
  y_conv=tf.nn.softmax(tf.matmul(h_fc1, W_fc2) + b_fc2) 
 
  # mnist_data_set = extract_mnist.MnistDataSet('../../data/') 
  # x_img , y = mnist_data_set.next_train_batch(1) 
  im = cv2.imread('images/888.jpg',cv2.IMREAD_GRAYSCALE).astype(np.float32) 
  im = cv2.resize(im,(28,28),interpolation=cv2.INTER_CUBIC) 
  #图片预处理 
  #img_gray = cv2.cvtColor(im , cv2.COLOR_BGR2GRAY).astype(np.float32) 
  #数据从0~255转为-0.5~0.5 
  img_gray = (im - (255 / 2.0)) / 255 
  #cv2.imshow('out',img_gray) 
  #cv2.waitKey(0) 
  x_img = np.reshape(img_gray , [-1 , 784]) 
 
  print x_img 
  output = sess.run(y_conv , feed_dict = {x:x_img}) 
  print 'the y_con :  ', '\n',output 
  print 'the predict is : ', np.argmax(output) 
 
  #关闭会话 
  sess.close() 
 
if __name__ == '__main__': 
  main()

ok,贴一下效果图

python tensorflow学习之识别单张图片的实现的示例

输出:

python tensorflow学习之识别单张图片的实现的示例

最后再贴一个cifar10的,感觉我的输入数据有点问题,因为直接读cifar10的数据测试是没问题的,但是换成自己的图片做预处理后输入结果就有问题,(参考:cv2读入的数据是BGR顺序,PIL读入的数据是RGB顺序,cifar10的数据是RGB顺序),哪位童鞋能指出来记得留言告诉我

# -*- coding:utf-8 -*-   
from sys import path 
import numpy as np 
import tensorflow as tf 
import time 
import cv2 
from PIL import Image 
path.append('../..') 
from common import extract_cifar10 
from common import inspect_image 
 
 
#初始化单个卷积核上的参数 
def weight_variable(shape): 
  initial = tf.truncated_normal(shape, stddev=0.1) 
  return tf.Variable(initial) 
 
#初始化单个卷积核上的偏置值 
def bias_variable(shape): 
  initial = tf.constant(0.1, shape=shape) 
  return tf.Variable(initial) 
 
#卷积操作 
def conv2d(x, W): 
  return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 
 
 
 
def main(): 
  #定义会话 
  sess = tf.InteractiveSession() 
   
  #声明输入图片数据,类别 
  x = tf.placeholder('float',[None,32,32,3]) 
  y_ = tf.placeholder('float',[None,10]) 
 
  #第一层卷积层 
  W_conv1 = weight_variable([5, 5, 3, 64]) 
  b_conv1 = bias_variable([64]) 
  #进行卷积操作,并添加relu激活函数 
  conv1 = tf.nn.relu(conv2d(x,W_conv1) + b_conv1) 
  # pool1 
  pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],padding='SAME', name='pool1') 
  # norm1 
  norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75,name='norm1') 
 
 
  #第二层卷积层 
  W_conv2 = weight_variable([5,5,64,64]) 
  b_conv2 = bias_variable([64]) 
  conv2 = tf.nn.relu(conv2d(norm1,W_conv2) + b_conv2) 
  # norm2 
  norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75,name='norm2') 
  # pool2 
  pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1],strides=[1, 2, 2, 1], padding='SAME', name='pool2') 
 
  #全连接层 
  #权值参数 
  W_fc1 = weight_variable([8*8*64,384]) 
  #偏置值 
  b_fc1 = bias_variable([384]) 
  #将卷积的产出展开 
  pool2_flat = tf.reshape(pool2,[-1,8*8*64]) 
  #神经网络计算,并添加relu激活函数 
  fc1 = tf.nn.relu(tf.matmul(pool2_flat,W_fc1) + b_fc1) 
   
  #全连接第二层 
  #权值参数 
  W_fc2 = weight_variable([384,192]) 
  #偏置值 
  b_fc2 = bias_variable([192]) 
  #神经网络计算,并添加relu激活函数 
  fc2 = tf.nn.relu(tf.matmul(fc1,W_fc2) + b_fc2) 
 
 
  #输出层,使用softmax进行多分类 
  W_fc2 = weight_variable([192,10]) 
  b_fc2 = bias_variable([10]) 
  y_conv=tf.maximum(tf.nn.softmax(tf.matmul(fc2, W_fc2) + b_fc2),1e-30) 
 
  # 
  saver = tf.train.Saver() 
  saver.restore(sess , 'model_data/model.ckpt') 
  #input 
  im = Image.open('images/dog8.jpg') 
  im.show() 
  im = im.resize((32,32)) 
  # r , g , b = im.split() 
  # im = Image.merge("RGB" , (r,g,b)) 
  print im.size , im.mode 
 
  im = np.array(im).astype(np.float32) 
  im = np.reshape(im , [-1,32*32*3]) 
  im = (im - (255 / 2.0)) / 255 
  batch_xs = np.reshape(im , [-1,32,32,3]) 
  #print batch_xs 
  #获取cifar10数据 
  # cifar10_data_set = extract_cifar10.Cifar10DataSet('../../data/') 
  # batch_xs, batch_ys = cifar10_data_set.next_train_batch(1) 
  # print batch_ys 
  output = sess.run(y_conv , feed_dict={x:batch_xs}) 
  print output 
  print 'the out put is :' , np.argmax(output) 
  #关闭会话 
  sess.close() 
 
if __name__ == '__main__': 
  main()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python给你的头像加上圣诞帽
Jan 04 Python
tensorflow实现简单的卷积神经网络
May 24 Python
python实现人人自动回复、抢沙发功能
Jun 08 Python
Numpy之文件存取的示例代码
Aug 03 Python
Python 正则表达式 re.match/re.search/re.sub的使用解析
Jul 22 Python
Pytorch抽取网络层的Feature Map(Vgg)实例
Aug 20 Python
Python实现屏幕录制功能的代码
Mar 02 Python
python opencv实现图片缺陷检测(讲解直方图以及相关系数对比法)
Apr 07 Python
python3.6使用SMTP协议发送邮件
May 20 Python
新版Pycharm中Matplotlib不会弹出独立的显示窗口的问题
Jun 02 Python
Scrapy实现模拟登录的示例代码
Feb 21 Python
教你怎么用PyCharm为同一服务器配置多个python解释器
May 31 Python
python删除服务器文件代码示例
Feb 09 #Python
详解Python使用tensorflow入门指南
Feb 09 #Python
python编程测试电脑开启最大线程数实例代码
Feb 09 #Python
Python实现对一个函数应用多个装饰器的方法示例
Feb 09 #Python
Python+PIL实现支付宝AR红包
Feb 09 #Python
Python 实现12306登录功能实例代码
Feb 09 #Python
Python多层装饰器用法实例分析
Feb 09 #Python
You might like
PHP循环获取GET和POST值的代码
2008/04/09 PHP
PHP和Mysqlweb应用开发核心技术-第1部分 Php基础-2 php语言介绍
2011/07/03 PHP
探讨如何使用SimpleXML函数来加载和解析XML文档
2013/06/07 PHP
VB中的RasEnumConnections函数返回632错误解决方法
2014/07/29 PHP
3款值得推荐的微信开发开源框架
2014/10/28 PHP
PHP实现的简单排列组合算法应用示例
2017/06/20 PHP
php日志函数error_log用法实例分析
2019/09/23 PHP
PHP生成图表pChart的示例解析
2020/07/31 PHP
分享十五个最佳jQuery 幻灯插件和教程
2010/03/27 Javascript
js鼠标左右键 键盘值小结
2010/06/11 Javascript
jQuery Tools tab(幻灯片)
2012/07/14 Javascript
jQuery中replaceWith()方法用法实例
2014/12/25 Javascript
学习JavaScript设计模式(单例模式)
2015/11/26 Javascript
解决低版本的浏览器不支持es6的import问题
2018/03/09 Javascript
JS实现模糊查询带下拉匹配效果
2018/06/21 Javascript
JQuery 实现文件下载的常用方法分析
2019/10/29 jQuery
layui 弹出层值回传解决方式
2019/11/14 Javascript
在Python中使用异步Socket编程性能测试
2014/06/25 Python
Python编程实现双链表,栈,队列及二叉树的方法示例
2017/11/01 Python
Windows 7下Python Web环境搭建图文教程
2018/03/20 Python
Python3实现购物车功能
2018/04/18 Python
Python使用matplotlib和pandas实现的画图操作【经典示例】
2018/06/13 Python
利用python循环创建多个文件的方法
2018/10/25 Python
在pycharm中python切换解释器失败的解决方法
2018/10/29 Python
selenium设置proxy、headers的方法(phantomjs、Chrome、Firefox)
2018/11/29 Python
python实现串口通信的示例代码
2020/02/10 Python
keras 实现轻量级网络ShuffleNet教程
2020/06/19 Python
专门经营化妆刷的美国彩妆品牌:Sigma Beauty
2017/09/11 全球购物
团员的自我评价
2013/12/01 职场文书
红色革命电影观后感
2015/06/18 职场文书
分布式锁为什么要选择Zookeeper而不是Redis?看完这篇你就明白了
2021/05/21 Redis
解决MultipartFile.transferTo(dest) 报FileNotFoundExcep的问题
2021/07/01 Java/Android
实例详解Python的进程,线程和协程
2022/03/13 Python
十大最强水系宝可梦,最美宝可梦排第三,榜首大家最熟悉
2022/03/18 日漫
Golang 字符串的常见操作
2022/04/19 Golang
Spring Boot接口定义和全局异常统一处理
2022/04/20 Java/Android