30秒轻松实现TensorFlow物体检测


Posted in Python onMarch 14, 2018

Google发布了新的TensorFlow物体检测API,包含了预训练模型,一个发布模型的jupyter notebook,一些可用于使用自己数据集对模型进行重新训练的有用脚本。

使用该API可以快速的构建一些图片中物体检测的应用。这里我们一步一步来看如何使用预训练模型来检测图像中的物体。

首先我们载入一些会使用的库

import numpy as np 
import os 
import six.moves.urllib as urllib 
import sys 
import tarfile 
import tensorflow as tf 
import zipfile 
 
from collections import defaultdict 
from io import StringIO 
from matplotlib import pyplot as plt 
from PIL import Image

接下来进行环境设置

%matplotlib inline 
sys.path.append("..")

物体检测载入

from utils import label_map_util 
 
from utils import visualization_utils as vis_util

准备模型

变量  任何使用export_inference_graph.py工具输出的模型可以在这里载入,只需简单改变PATH_TO_CKPT指向一个新的.pb文件。这里我们使用“移动网SSD”模型。

MODEL_NAME = 'ssd_mobilenet_v1_coco_11_06_2017' 
MODEL_FILE = MODEL_NAME + '.tar.gz' 
DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/' 
 
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb' 
 
PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt') 
 
NUM_CLASSES = 90

下载模型

opener = urllib.request.URLopener() 
opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE) 
tar_file = tarfile.open(MODEL_FILE) 
for file in tar_file.getmembers(): 
  file_name = os.path.basename(file.name) 
  if 'frozen_inference_graph.pb' in file_name: 
    tar_file.extract(file, os.getcwd())

将(frozen)TensorFlow模型载入内存

detection_graph = tf.Graph() 
with detection_graph.as_default(): 
  od_graph_def = tf.GraphDef() 
  with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid: 
    serialized_graph = fid.read() 
    od_graph_def.ParseFromString(serialized_graph) 
    tf.import_graph_def(od_graph_def, name='')

载入标签图

标签图将索引映射到类名称,当我们的卷积预测5时,我们知道它对应飞机。这里我们使用内置函数,但是任何返回将整数映射到恰当字符标签的字典都适用。

label_map = label_map_util.load_labelmap(PATH_TO_LABELS) 
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True) 
category_index = label_map_util.create_category_index(categories)

辅助代码

def load_image_into_numpy_array(image): 
 (im_width, im_height) = image.size 
 return np.array(image.getdata()).reshape( 
   (im_height, im_width, 3)).astype(np.uint8)

检测

PATH_TO_TEST_IMAGES_DIR = 'test_images' 
TEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(1, 3) ] 
IMAGE_SIZE = (12, 8) 
[python] view plain copy
with detection_graph.as_default(): 
 
 with tf.Session(graph=detection_graph) as sess: 
  for image_path in TEST_IMAGE_PATHS: 
   image = Image.open(image_path) 
   # 这个array在之后会被用来准备为图片加上框和标签 
   image_np = load_image_into_numpy_array(image) 
   # 扩展维度,应为模型期待: [1, None, None, 3] 
   image_np_expanded = np.expand_dims(image_np, axis=0) 
   image_tensor = detection_graph.get_tensor_by_name('image_tensor:0') 
   # 每个框代表一个物体被侦测到. 
   boxes = detection_graph.get_tensor_by_name('detection_boxes:0') 
   # 每个分值代表侦测到物体的可信度. 
   scores = detection_graph.get_tensor_by_name('detection_scores:0') 
   classes = detection_graph.get_tensor_by_name('detection_classes:0') 
   num_detections = detection_graph.get_tensor_by_name('num_detections:0') 
   # 执行侦测任务. 
   (boxes, scores, classes, num_detections) = sess.run( 
     [boxes, scores, classes, num_detections], 
     feed_dict={image_tensor: image_np_expanded}) 
   # 图形化. 
   vis_util.visualize_boxes_and_labels_on_image_array( 
     image_np, 
     np.squeeze(boxes), 
     np.squeeze(classes).astype(np.int32), 
     np.squeeze(scores), 
     category_index, 
     use_normalized_coordinates=True, 
     line_thickness=8) 
   plt.figure(figsize=IMAGE_SIZE) 
   plt.imshow(image_np)

在载入模型部分可以尝试不同的侦测模型以比较速度和准确度,将你想侦测的图片放入TEST_IMAGE_PATHS中运行即可。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python进阶教程之词典、字典、dict
Aug 29 Python
python排序方法实例分析
Apr 30 Python
python实现红包裂变算法
Feb 16 Python
Numpy之文件存取的示例代码
Aug 03 Python
opencv3/python 鼠标响应操作详解
Dec 11 Python
Python 2种方法求某个范围内的所有素数(质数)
Jan 31 Python
Python count函数使用方法实例解析
Mar 23 Python
python实现人脸签到系统
Apr 13 Python
python将dict中的unicode打印成中文实例
May 11 Python
python怎么判断模块安装完成
Jun 19 Python
python正则表达式的懒惰匹配和贪婪匹配说明
Jul 13 Python
python如何发送带有附件、正文为HTML的邮件
Feb 27 Python
tensorflow识别自己手写数字
Mar 14 #Python
磁盘垃圾文件清理器python代码实现
Aug 24 #Python
Django自定义用户认证示例详解
Mar 14 #Python
python如何压缩新文件到已有ZIP文件
Mar 14 #Python
python中format()函数的简单使用教程
Mar 14 #Python
Python批量提取PDF文件中文本的脚本
Mar 14 #Python
深入理解Django的中间件middleware
Mar 14 #Python
You might like
比较全的PHP 会话(session 时间设定)使用入门代码
2008/06/05 PHP
php使用curl和正则表达式抓取网页数据示例
2014/04/13 PHP
php实现在线考试系统【附源码】
2018/09/18 PHP
js用图作提交按钮或超连接
2008/03/26 Javascript
JS 自定义函数缺省值的设置方法
2010/05/05 Javascript
JavaScript基础语法让人疑惑的地方小结
2012/05/23 Javascript
jQuery在页面加载时动态修改图片尺寸的方法
2015/03/20 Javascript
在JavaScript中处理字符串之link()方法的使用
2015/06/08 Javascript
JavaScript常用标签和方法总结
2015/09/01 Javascript
JAVASCRIPT代码编写俄罗斯方块网页版
2015/11/26 Javascript
jQuery插件简单学习实例教程
2016/07/01 Javascript
Jquery Easyui菜单组件Menu使用详解(15)
2016/12/18 Javascript
node.js学习之事件模块Events的使用示例
2017/09/28 Javascript
jQuery中each方法的使用详解
2018/03/18 jQuery
ng-alain表单使用方式详解
2018/07/10 Javascript
Vue强制组件重新渲染的方法讨论
2020/02/03 Javascript
JavaScript使用setTimeout实现倒计时效果
2021/02/19 Javascript
[58:25]VP vs RNG 2019国际邀请赛小组赛 BO2 第一场 8.15
2019/08/17 DOTA
Python中shape计算矩阵的方法示例
2017/04/21 Python
详解Python文本操作相关模块
2017/06/22 Python
python模拟键盘输入 切换键盘布局过程解析
2019/08/15 Python
Python如何使用turtle库绘制图形
2020/02/26 Python
Python装饰器用法与知识点小结
2020/03/09 Python
python实现单张图像拼接与批量图片拼接
2020/03/23 Python
老生常谈CSS中的长度单位
2016/06/27 HTML / CSS
Bally澳大利亚官网:瑞士奢侈品牌
2018/11/01 全球购物
几道数据库的面试题或笔试题
2014/05/31 面试题
中国央视网签名寄语
2014/01/18 职场文书
医务工作者先进事迹材料
2014/01/26 职场文书
《孙权劝学》教学反思
2014/04/23 职场文书
入党积极分子学习优秀共产党员先进事迹思想汇报
2014/09/13 职场文书
销售顾问工作计划书
2014/09/15 职场文书
连锁超市项目计划书
2014/09/15 职场文书
司法局群众路线教育实践活动开展情况总结
2014/10/25 职场文书
2015年维修电工工作总结
2015/04/25 职场文书
通过feDisplacementMap和feImage实现水波特效
2022/04/24 HTML / CSS