Keras预训练的ImageNet模型实现分类操作


Posted in Python onJuly 07, 2020

本文主要介绍通过预训练的ImageNet模型实现图像分类,主要使用到的网络结构有:VGG16、InceptionV3、ResNet50、MobileNet。

代码:

import keras
import numpy as np
from keras.applications import vgg16, inception_v3, resnet50, mobilenet
 
# 加载模型
vgg_model = vgg16.VGG16(weights='imagenet')
inception_model = inception_v3.InceptionV3(weights='imagenet')
resnet_model = resnet50.ResNet50(weights='imagenet')
mobilenet_model = mobilenet.MobileNet(weights='imagenet')
 
# 导入所需的图像预处理模块
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array
from keras.applications.imagenet_utils import decode_predictions
import matplotlib.pyplot as plt
%matplotlib inline
 
filename= 'images/cat.jpg'
 
# 将图片输入到网络之前执行预处理
'''
1、加载图像,load_img
2、将图像从PIL格式转换为Numpy格式,image_to_array
3、将图像形成批次,Numpy的expand_dims
'''
# 以PIL格式加载图像
original = load_img(filename, target_size=(224, 224))
print('PIL image size', original.size)
plt.imshow(original)
plt.show()
 
# 将输入图像从PIL格式转换为Numpy格式
# In PIL-- 图像为(width, height, channel)
# In Numpy——图像为(height, width, channel)
numpy_image = img_to_array(original)
plt.imshow(np.uint8(numpy_image))
plt.show()
print('numpy array size', numpy_image.size)
 
# 将图像/图像转换为批量格式
# expand_dims将为特定轴上的数据添加额外的维度
# 网络的输入矩阵具有形式(批量大小,高度,宽度,通道)
# 因此,将额外的维度添加到轴0。
image_batch = np.expand_dims(numpy_image, axis=0)
print('image batch size', image_batch.shape)
plt.imshow(np.uint8(image_batch[0]))
 
# 使用各种网络进行预测
# 通过从批处理中的图像的每个通道中减去平均值来预处理输入。 
# 平均值是通过从ImageNet获得的所有图像的R,G,B像素的平均值获得的三个元素的阵列
# 获得每个类的发生概率
# 将概率转换为人类可读的标签
# VGG16 网络模型
# 对输入到VGG模型的图像进行预处理
processed_image = vgg16.preprocess_input(image_batch.copy())
 
# 获取预测得到的属于各个类别的概率
predictions = vgg_model.predict(processed_image)
# 输出预测值
# 将预测概率转换为类别标签
# 缺省情况下将得到最有可能的五种类别
label_vgg = decode_predictions(predictions)
label_vgg
 
# ResNet50网络模型
# 对输入到ResNet50模型的图像进行预处理
processed_image = resnet50.preprocess_input(image_batch.copy())
 
# 获取预测得到的属于各个类别的概率
predictions = resnet_model.predict(processed_image)
 
# 将概率转换为类标签
# 如果要查看前3个预测,可以使用top参数指定它
label_resnet = decode_predictions(predictions, top=3)
label_resnet
 
# MobileNet网络结构
# 对输入到MobileNet模型的图像进行预处理
processed_image = mobilenet.preprocess_input(image_batch.copy())
 
# 获取预测得到属于各个类别的概率
predictions = mobilenet_model.predict(processed_image)
 
# 将概率转换为类标签
label_mobilnet = decode_predictions(predictions)
label_mobilnet
 
# InceptionV3网络结构
# 初始网络的输入大小与其他网络不同。 它接受大小的输入(299,299)。
# 因此,根据它加载具有目标尺寸的图像。
# 加载图像为PIL格式
original = load_img(filename, target_size=(299, 299))
 
# 将PIL格式的图像转换为Numpy数组
numpy_image = img_to_array(original)
 
# 根据批量大小重塑数据
image_batch = np.expand_dims(numpy_image, axis=0)
 
# 将输入图像转换为InceptionV3所能接受的格式
processed_image = inception_v3.preprocess_input(image_batch.copy())
 
# 获取预测得到的属于各个类别的概率
predictions = inception_model.predict(processed_image)
 
# 将概率转换为类标签
label_inception = decode_predictions(predictions)
label_inception
 
import cv2
numpy_image = np.uint8(img_to_array(original)).copy()
numpy_image = cv2.resize(numpy_image,(900,900))
 
cv2.putText(numpy_image, "VGG16: {}, {:.2f}".format(label_vgg[0][0][1], label_vgg[0][0][2]) , (350, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 3)
cv2.putText(numpy_image, "MobileNet: {}, {:.2f}".format(label_mobilenet[0][0][1], label_mobilenet[0][0][2]) , (350, 75), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 3)
cv2.putText(numpy_image, "Inception: {}, {:.2f}".format(label_inception[0][0][1], label_inception[0][0][2]) , (350, 110), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 3)
cv2.putText(numpy_image, "ResNet50: {}, {:.2f}".format(label_resnet[0][0][1], label_resnet[0][0][2]) , (350, 145), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 3)
numpy_image = cv2.resize(numpy_image, (700,700))
cv2.imwrite("images/{}_output.jpg".format(filename.split('/')[-1].split('.')[0]),cv2.cvtColor(numpy_image, cv2.COLOR_RGB2BGR))
 
plt.figure(figsize=[10,10])
plt.imshow(numpy_image)
plt.axis('off')

训练数据:

Keras预训练的ImageNet模型实现分类操作

运行结果:

Keras预训练的ImageNet模型实现分类操作

以上这篇Keras预训练的ImageNet模型实现分类操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
简单的Python2.7编程初学经验总结
Apr 01 Python
python统计文本字符串里单词出现频率的方法
May 26 Python
Python MySQL数据库连接池组件pymysqlpool详解
Jul 07 Python
matplotlib 输出保存指定尺寸的图片方法
May 24 Python
Python单向链表和双向链表原理与用法实例详解
Aug 31 Python
Python3.5文件读与写操作经典实例详解
May 01 Python
Python qqbot 实现qq机器人的示例代码
Jul 11 Python
Python3 A*寻路算法实现方式
Dec 24 Python
python mysql 字段与关键字冲突的解决方式
Mar 02 Python
pyinstaller打包找不到文件的问题解决
Apr 15 Python
两行代码解决Jupyter Notebook中文不能显示的问题
Apr 24 Python
Python爬虫之自动爬取某车之家各车销售数据
Jun 02 Python
Scrapy模拟登录赶集网的实现代码
Jul 07 #Python
scrapy框架携带cookie访问淘宝购物车功能的实现代码
Jul 07 #Python
Keras构建神经网络踩坑(解决model.predict预测值全为0.0的问题)
Jul 07 #Python
浅谈django框架集成swagger以及自定义参数问题
Jul 07 #Python
Django REST Swagger实现指定api参数
Jul 07 #Python
python中查看.db文件中表格的名字及表格中的字段操作
Jul 07 #Python
python db类用法说明
Jul 07 #Python
You might like
PHP中数组定义的几种方法
2013/09/01 PHP
php实现文件下载简单示例(代码实现文件下载)
2014/03/10 PHP
php通过array_push()函数添加多个变量到数组末尾的方法
2015/03/18 PHP
Yii2使用自带的UploadedFile实现的文件上传
2016/06/20 PHP
php实现异步将远程链接上内容(图片或内容)写到本地的方法
2016/11/30 PHP
PHP 获取指定地区的天气实例代码
2017/02/08 PHP
Yii2中多表关联查询hasOne hasMany的方法
2017/02/15 PHP
详解PHP字符串替换str_replace()函数四种用法
2017/10/13 PHP
PHP 实现链式操作
2021/03/09 PHP
JS跨域总结
2012/08/30 Javascript
js获取单选框或复选框值及操作
2012/12/18 Javascript
jQuery筛选器children()案例详解(图文)
2013/02/17 Javascript
JS对HTML标签select的获取、添加、删除操作
2013/10/17 Javascript
Javascript表单验证要注意的事项
2014/09/29 Javascript
javascript 判断整数方法分享
2014/12/16 Javascript
JavaScript创建对象的方式小结(4种方式)
2015/12/17 Javascript
浅谈js中的引用和复制(传值和传址)
2016/09/18 Javascript
vscode 开发Vue项目的方法步骤
2018/11/25 Javascript
express+vue+mongodb+session 实现注册登录功能
2018/12/06 Javascript
Vue中的nextTick作用和几个简单的使用场景
2021/01/25 Vue.js
python实现百度关键词排名查询
2014/03/30 Python
python操作xml文件详细介绍
2014/06/09 Python
Django1.7+python 2.78+pycharm配置mysql数据库教程
2014/11/18 Python
Django应用程序中如何发送电子邮件详解
2017/02/04 Python
Python控制Firefox方法总结
2019/06/03 Python
numpy np.newaxis 的实用分享
2019/11/30 Python
将数据集制作成VOC数据集格式的实例
2020/02/17 Python
Python Process创建进程的2种方法详解
2021/01/25 Python
机电专业毕业生推荐信
2013/11/10 职场文书
营销学习心得体会
2014/09/12 职场文书
入党积极分子对十八届四中全会期盼的思想汇报
2014/10/17 职场文书
2015大学迎新晚会策划书
2015/07/16 职场文书
家电创业计划书
2019/08/05 职场文书
nginx配置proxy_pass中url末尾带/与不带/的区别详解
2021/03/31 Servers
MySQL基于索引的压力测试的实现
2021/11/07 MySQL
Python利用Turtle绘制哆啦A梦和小猪佩奇
2022/04/04 Python