30秒轻松实现TensorFlow物体检测


Posted in Python onMarch 14, 2018

Google发布了新的TensorFlow物体检测API,包含了预训练模型,一个发布模型的jupyter notebook,一些可用于使用自己数据集对模型进行重新训练的有用脚本。

使用该API可以快速的构建一些图片中物体检测的应用。这里我们一步一步来看如何使用预训练模型来检测图像中的物体。

首先我们载入一些会使用的库

import numpy as np 
import os 
import six.moves.urllib as urllib 
import sys 
import tarfile 
import tensorflow as tf 
import zipfile 
 
from collections import defaultdict 
from io import StringIO 
from matplotlib import pyplot as plt 
from PIL import Image

接下来进行环境设置

%matplotlib inline 
sys.path.append("..")

物体检测载入

from utils import label_map_util 
 
from utils import visualization_utils as vis_util

准备模型

变量  任何使用export_inference_graph.py工具输出的模型可以在这里载入,只需简单改变PATH_TO_CKPT指向一个新的.pb文件。这里我们使用“移动网SSD”模型。

MODEL_NAME = 'ssd_mobilenet_v1_coco_11_06_2017' 
MODEL_FILE = MODEL_NAME + '.tar.gz' 
DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/' 
 
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb' 
 
PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt') 
 
NUM_CLASSES = 90

下载模型

opener = urllib.request.URLopener() 
opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE) 
tar_file = tarfile.open(MODEL_FILE) 
for file in tar_file.getmembers(): 
  file_name = os.path.basename(file.name) 
  if 'frozen_inference_graph.pb' in file_name: 
    tar_file.extract(file, os.getcwd())

将(frozen)TensorFlow模型载入内存

detection_graph = tf.Graph() 
with detection_graph.as_default(): 
  od_graph_def = tf.GraphDef() 
  with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid: 
    serialized_graph = fid.read() 
    od_graph_def.ParseFromString(serialized_graph) 
    tf.import_graph_def(od_graph_def, name='')

载入标签图

标签图将索引映射到类名称,当我们的卷积预测5时,我们知道它对应飞机。这里我们使用内置函数,但是任何返回将整数映射到恰当字符标签的字典都适用。

label_map = label_map_util.load_labelmap(PATH_TO_LABELS) 
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True) 
category_index = label_map_util.create_category_index(categories)

辅助代码

def load_image_into_numpy_array(image): 
 (im_width, im_height) = image.size 
 return np.array(image.getdata()).reshape( 
   (im_height, im_width, 3)).astype(np.uint8)

检测

PATH_TO_TEST_IMAGES_DIR = 'test_images' 
TEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(1, 3) ] 
IMAGE_SIZE = (12, 8) 
[python] view plain copy
with detection_graph.as_default(): 
 
 with tf.Session(graph=detection_graph) as sess: 
  for image_path in TEST_IMAGE_PATHS: 
   image = Image.open(image_path) 
   # 这个array在之后会被用来准备为图片加上框和标签 
   image_np = load_image_into_numpy_array(image) 
   # 扩展维度,应为模型期待: [1, None, None, 3] 
   image_np_expanded = np.expand_dims(image_np, axis=0) 
   image_tensor = detection_graph.get_tensor_by_name('image_tensor:0') 
   # 每个框代表一个物体被侦测到. 
   boxes = detection_graph.get_tensor_by_name('detection_boxes:0') 
   # 每个分值代表侦测到物体的可信度. 
   scores = detection_graph.get_tensor_by_name('detection_scores:0') 
   classes = detection_graph.get_tensor_by_name('detection_classes:0') 
   num_detections = detection_graph.get_tensor_by_name('num_detections:0') 
   # 执行侦测任务. 
   (boxes, scores, classes, num_detections) = sess.run( 
     [boxes, scores, classes, num_detections], 
     feed_dict={image_tensor: image_np_expanded}) 
   # 图形化. 
   vis_util.visualize_boxes_and_labels_on_image_array( 
     image_np, 
     np.squeeze(boxes), 
     np.squeeze(classes).astype(np.int32), 
     np.squeeze(scores), 
     category_index, 
     use_normalized_coordinates=True, 
     line_thickness=8) 
   plt.figure(figsize=IMAGE_SIZE) 
   plt.imshow(image_np)

在载入模型部分可以尝试不同的侦测模型以比较速度和准确度,将你想侦测的图片放入TEST_IMAGE_PATHS中运行即可。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之编写类之四再论继承
Oct 11 Python
python实现挑选出来100以内的质数
Mar 24 Python
Python的Flask开发框架简单上手笔记
Nov 16 Python
python3使用pyqt5制作一个超简单浏览器的实例
Oct 19 Python
python字符串替换第一个字符串的方法
Jun 26 Python
python区块及区块链的开发详解
Jul 03 Python
Python3从零开始搭建一个语音对话机器人的实现
Aug 23 Python
opencv-python 读取图像并转换颜色空间实例
Dec 09 Python
解决python 执行sql语句时所传参数含有单引号的问题
Jun 06 Python
Android Q之气泡弹窗的实现示例
Jun 23 Python
Python urlopen()参数代码示例解析
Dec 10 Python
详解解决jupyter不能使用pytorch的问题
Feb 18 Python
tensorflow识别自己手写数字
Mar 14 #Python
磁盘垃圾文件清理器python代码实现
Aug 24 #Python
Django自定义用户认证示例详解
Mar 14 #Python
python如何压缩新文件到已有ZIP文件
Mar 14 #Python
python中format()函数的简单使用教程
Mar 14 #Python
Python批量提取PDF文件中文本的脚本
Mar 14 #Python
深入理解Django的中间件middleware
Mar 14 #Python
You might like
php读取XML的常见方法实例总结
2017/04/25 PHP
laravel7学习之无限级分类的最新实现方法
2020/09/30 PHP
javascript 有趣而诡异的数组
2009/04/06 Javascript
js获取RadioButtonList的Value/Text及选中值等信息实现代码
2013/03/05 Javascript
javaScript array(数组)使用字符串作为数组下标的方法
2013/11/19 Javascript
javascript中parentNode,childNodes,children的应用详解
2013/12/17 Javascript
我的Node.js学习之路(二)NPM模块管理
2014/07/06 Javascript
javascript实现根据身份证号读取相关信息
2014/12/17 Javascript
JS判断键盘是否按的回车键并触发指定按钮点击操作的方法
2017/02/13 Javascript
在vue中实现点击选择框阻止弹出层消失的方法
2018/09/15 Javascript
vue将毫秒数转化为正常日期格式的实例
2018/09/16 Javascript
Nuxt.js开启SSR渲染的教程详解
2018/11/30 Javascript
在Node.js中将SVG图像转换为PNG,JPEG,TIFF,WEBP和HEIF格式的方法
2019/08/22 Javascript
vue使用map代替Aarry数组循环遍历的方法
2020/04/30 Javascript
Vue + Element-ui的下拉框el-select获取额外参数详解
2020/08/14 Javascript
vue项目里面引用svg文件并给svg里面的元素赋值
2020/08/17 Javascript
在vue中使用eslint,配合vscode的操作
2020/11/09 Javascript
k8s node节点重新加入master集群的实现
2021/02/22 Javascript
[01:18:45]DOTA2-DPC中国联赛 正赛 DLG vs Dragon BO3 第三场2月1日
2021/03/11 DOTA
Django数据库操作的实例(增删改查)
2017/09/04 Python
用Python下载一个网页保存为本地的HTML文件实例
2018/05/21 Python
Pipenv一键搭建python虚拟环境的方法
2018/05/22 Python
Python二维码生成识别实例详解
2019/07/16 Python
Python3 A*寻路算法实现方式
2019/12/24 Python
用CSS3将你的设计带入下个高度
2009/08/08 HTML / CSS
美国床垫和床上用品公司:Nest Bedding
2017/06/12 全球购物
Ben Sherman官方网站:英国男装品牌
2019/10/22 全球购物
乐高西班牙官方商店:LEGO Shop ES
2019/12/01 全球购物
给酒店员工的表扬信
2014/01/11 职场文书
入学生会自荐书范文
2014/02/05 职场文书
《记承天寺夜游》教学反思
2014/02/16 职场文书
2014年保洁员工作总结
2014/11/19 职场文书
李白故里导游词
2015/02/12 职场文书
情况说明书格式及范文
2019/06/24 职场文书
图文详解nginx日志切割的实现
2022/01/18 Servers
Python如何快速找到多个字典中的公共键(key)
2022/04/29 Python