30秒轻松实现TensorFlow物体检测


Posted in Python onMarch 14, 2018

Google发布了新的TensorFlow物体检测API,包含了预训练模型,一个发布模型的jupyter notebook,一些可用于使用自己数据集对模型进行重新训练的有用脚本。

使用该API可以快速的构建一些图片中物体检测的应用。这里我们一步一步来看如何使用预训练模型来检测图像中的物体。

首先我们载入一些会使用的库

import numpy as np 
import os 
import six.moves.urllib as urllib 
import sys 
import tarfile 
import tensorflow as tf 
import zipfile 
 
from collections import defaultdict 
from io import StringIO 
from matplotlib import pyplot as plt 
from PIL import Image

接下来进行环境设置

%matplotlib inline 
sys.path.append("..")

物体检测载入

from utils import label_map_util 
 
from utils import visualization_utils as vis_util

准备模型

变量  任何使用export_inference_graph.py工具输出的模型可以在这里载入,只需简单改变PATH_TO_CKPT指向一个新的.pb文件。这里我们使用“移动网SSD”模型。

MODEL_NAME = 'ssd_mobilenet_v1_coco_11_06_2017' 
MODEL_FILE = MODEL_NAME + '.tar.gz' 
DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/' 
 
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb' 
 
PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt') 
 
NUM_CLASSES = 90

下载模型

opener = urllib.request.URLopener() 
opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE) 
tar_file = tarfile.open(MODEL_FILE) 
for file in tar_file.getmembers(): 
  file_name = os.path.basename(file.name) 
  if 'frozen_inference_graph.pb' in file_name: 
    tar_file.extract(file, os.getcwd())

将(frozen)TensorFlow模型载入内存

detection_graph = tf.Graph() 
with detection_graph.as_default(): 
  od_graph_def = tf.GraphDef() 
  with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid: 
    serialized_graph = fid.read() 
    od_graph_def.ParseFromString(serialized_graph) 
    tf.import_graph_def(od_graph_def, name='')

载入标签图

标签图将索引映射到类名称,当我们的卷积预测5时,我们知道它对应飞机。这里我们使用内置函数,但是任何返回将整数映射到恰当字符标签的字典都适用。

label_map = label_map_util.load_labelmap(PATH_TO_LABELS) 
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True) 
category_index = label_map_util.create_category_index(categories)

辅助代码

def load_image_into_numpy_array(image): 
 (im_width, im_height) = image.size 
 return np.array(image.getdata()).reshape( 
   (im_height, im_width, 3)).astype(np.uint8)

检测

PATH_TO_TEST_IMAGES_DIR = 'test_images' 
TEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(1, 3) ] 
IMAGE_SIZE = (12, 8) 
[python] view plain copy
with detection_graph.as_default(): 
 
 with tf.Session(graph=detection_graph) as sess: 
  for image_path in TEST_IMAGE_PATHS: 
   image = Image.open(image_path) 
   # 这个array在之后会被用来准备为图片加上框和标签 
   image_np = load_image_into_numpy_array(image) 
   # 扩展维度,应为模型期待: [1, None, None, 3] 
   image_np_expanded = np.expand_dims(image_np, axis=0) 
   image_tensor = detection_graph.get_tensor_by_name('image_tensor:0') 
   # 每个框代表一个物体被侦测到. 
   boxes = detection_graph.get_tensor_by_name('detection_boxes:0') 
   # 每个分值代表侦测到物体的可信度. 
   scores = detection_graph.get_tensor_by_name('detection_scores:0') 
   classes = detection_graph.get_tensor_by_name('detection_classes:0') 
   num_detections = detection_graph.get_tensor_by_name('num_detections:0') 
   # 执行侦测任务. 
   (boxes, scores, classes, num_detections) = sess.run( 
     [boxes, scores, classes, num_detections], 
     feed_dict={image_tensor: image_np_expanded}) 
   # 图形化. 
   vis_util.visualize_boxes_and_labels_on_image_array( 
     image_np, 
     np.squeeze(boxes), 
     np.squeeze(classes).astype(np.int32), 
     np.squeeze(scores), 
     category_index, 
     use_normalized_coordinates=True, 
     line_thickness=8) 
   plt.figure(figsize=IMAGE_SIZE) 
   plt.imshow(image_np)

在载入模型部分可以尝试不同的侦测模型以比较速度和准确度,将你想侦测的图片放入TEST_IMAGE_PATHS中运行即可。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现将文本转换成语音的方法
May 28 Python
Python的网络编程库Gevent的安装及使用技巧
Jun 24 Python
Python标准库笔记struct模块的使用
Feb 22 Python
Python实现求一个集合所有子集的示例
May 04 Python
Django项目中model的数据处理以及页面交互方法
May 30 Python
想学python 这5本书籍你必看!
Dec 11 Python
Python字符串的常见操作实例小结
Apr 08 Python
python 整数越界问题详解
Jun 27 Python
Python中six模块基础用法
Dec 08 Python
Python3实现监控新型冠状病毒肺炎疫情的示例代码
Feb 13 Python
Python进程Multiprocessing模块原理解析
Feb 28 Python
完美解决keras 读取多个hdf5文件进行训练的问题
Jul 01 Python
tensorflow识别自己手写数字
Mar 14 #Python
磁盘垃圾文件清理器python代码实现
Aug 24 #Python
Django自定义用户认证示例详解
Mar 14 #Python
python如何压缩新文件到已有ZIP文件
Mar 14 #Python
python中format()函数的简单使用教程
Mar 14 #Python
Python批量提取PDF文件中文本的脚本
Mar 14 #Python
深入理解Django的中间件middleware
Mar 14 #Python
You might like
PHP中使用unset销毁变量并内存释放问题
2012/07/05 PHP
windows服务器中检测PHP SSL是否开启以及开启SSL的方法
2014/04/25 PHP
基于php+MySql实现学生信息管理系统实例
2020/08/04 PHP
js树形控件脚本代码
2008/07/24 Javascript
js 提交和设置表单的值
2008/12/19 Javascript
纯JS实现的批量图片预览加载功能
2011/08/14 Javascript
js 走马灯简单实例
2013/11/21 Javascript
Jquery给基本控件的取值、赋值示例
2014/05/23 Javascript
JavaScript中的原型和继承详解(图文)
2014/07/18 Javascript
jQuery链使用指南
2015/01/20 Javascript
javascript使用avalon绑定实现checkbox全选
2015/05/06 Javascript
深入理解Node.js 事件循环和回调函数
2016/11/02 Javascript
jquery 标签 隔若干行加空白或者加虚线的方法
2016/12/07 Javascript
Vue.js实现简单动态数据处理
2017/02/13 Javascript
支持移动端原生js轮播图
2017/02/16 Javascript
使用watch监听路由变化和watch监听对象的实例
2018/02/24 Javascript
如何使用pm2快速将项目部署到远程服务器
2019/03/12 Javascript
[55:47]DOTA2上海特级锦标赛C组小组赛#2 LGD VS Newbee第三局
2016/02/27 DOTA
[51:06]DOTA2-DPC中国联赛 正赛 Elephant vs Aster BO3 第二场 1月26日
2021/03/11 DOTA
Python文件去除注释的方法
2015/05/25 Python
Python进阶之尾递归的用法实例
2018/01/31 Python
Python实现合并两个有序链表的方法示例
2019/01/31 Python
python3 map函数和filter函数详解
2019/08/26 Python
Python安装whl文件过程图解
2020/02/18 Python
python2 对excel表格操作完整示例
2020/02/23 Python
通过Python pyecharts输出保存图片代码实例
2020/11/25 Python
面向对象概念面试题(.NET)
2016/11/04 面试题
python re模块和正则表达式
2021/03/24 Python
国际金融专业大学生职业生涯规划书
2013/12/28 职场文书
争论的故事教学反思
2014/02/06 职场文书
简单的大学生自我鉴定
2014/02/18 职场文书
硕士生工作推荐信
2014/03/07 职场文书
乡镇综治宣传月活动总结
2014/07/02 职场文书
浅谈JS的二进制家族
2021/05/09 Javascript
mysql sql常用语句大全
2022/06/21 MySQL
MySQL数据库查询之多表查询总结
2022/08/05 MySQL