浅谈Tensorflow加载Vgg预训练模型的几个注意事项


Posted in Python onMay 26, 2020

写这个博客的关键Bug: Value passed to parameter 'input' has DataType uint8 not in list of allowed values: float16, bfloat16, float32, float64。本博客将围绕 加载图片 和 保存图片到本地 来详细解释和解决上述的Bug及其引出来的一系列Bug。

加载图片

首先,造成上述Bug的代码如下所示

image_path = "data/test.jpg" # 本地的测试图片
 
image_raw = tf.gfile.GFile(image_path, 'rb').read()
# 一定要tf.float(),否则会报错
image_decoded = tf.image.decode_jpeg(image_raw)
 
# 扩展图片的维度,从三维变成四维,符合Vgg19的输入接口
image_expand_dim = tf.expand_dims(image_decoded, 0)
 
# 定义Vgg19模型
vgg19 = VGG19(data_path)
net = vgg19.feed_forward(image_expand_dim, 'vgg19')
print(net)

上述代码是加载Vgg19预训练模型,并传入图片得到所有层的特征图,具体的代码实现和原理讲解可参考我的另一篇博客:Tensorflow加载Vgg预训练模型。那么,为什么代码会出现: Value passed to parameter 'input' has DataType uint8 not in list of allowed values: float16, bfloat16, float32, float64,这个Bug呢?

这句英文翻译过来是指:传递的值类型是uint8,但是接受的参数类型必须是float的那几种。故原因就是传入值的数据类型错了,那么如何解决这个Bug呢,很简单

image_path = "data/test.jpg" # 本地的测试图片
 
image_raw = tf.gfile.GFile(image_path, 'rb').read()
# 一定要tf.float(),否则会报错
image_decoded = tf.to_float(tf.image.decode_jpeg(image_raw))
 
# 扩展图片的维度,从三维变成四维,符合Vgg19的输入接口
image_expand_dim = tf.expand_dims(image_decoded, 0)
 
# 定义Vgg19模型
vgg19 = VGG19(data_path)
net = vgg19.feed_forward(image_expand_dim, 'vgg19')
print(net)

这两个代码块唯一的变动就是:image_decoded结果在输出前加了一个tf.float(),将其转换为float类型。

在tensorflow API中,tf.image.decode_jpeg()默认读取的图片数据格式为unit8,而不是float。uint8数据的范围在(0, 255)中,正好符合图片的像素范围(0, 255)。但是,保存在本地的Vgg19预训练模型的数据接口为float,所以才造成了本文开头的Bug。

这里还要提一点,若是使用PIL的方法来加载图片,则不会出现上述的Bug,因为通过PIL得到的图片格式是float,而不是uint8,故不需要转换。

很多同学可能会疑惑,若是强行改变了原图片的数据格式,从uint8类型转变成float,会不会导致数据改变或者出错?故我做了下面这个实验:

image_path = "data/3.jpg"
image_raw = tf.gfile.GFile(image_path, 'rb').read()
image_unit8 = tf.image.decode_jpeg(image_raw)
image_float = tf.to_float(image_unit8)
 
with tf.Session() as sess:
 image_unit8_, image_float_ = sess.run([image_unit8, image_float])
 
print("image_unit8_", image_unit8_)
print("image_float_ ", image_float_ )

代码结果如下:

image_unit8_
 [180, 192, 204],
 [183, 195, 207],
 [186, 198, 210],
 ...,
 [191, 205, 218],
 [191, 205, 218],
 [190, 204, 217]],
 
 image_float_ 
 [180., 192., 204.],
 [183., 195., 207.],
 [186., 198., 210.],
 ...,
 [191., 205., 218.],
 [191., 205., 218.],
 [190., 204., 217.]],

可以看到,数据根本没有变化,只是后面多加了个小数点,变得只有类型,而没有强制改变值,故同学们不需要过度担心。

保存图片到本地

在加载图片的时候,为了使用保存在本地的预训练Vgg19模型,我们需要将读取的图片由uint8格式转换成float格式。那若是我们想将已经转换为float格式的图片再保存到本地,该怎么做呢?

首先,我们根据上述的文字的意思读取图片,并且将其转换为float格式,在将读取的图片再次保存到本地之前,我们首先可视化一下转换格式后的图片,代码如下:

import tensorflow as tf
from matplotlib import pyplot as plt
image_path = "data/boat.jpg"
 
image_raw = tf.gfile.GFile(image_path, 'rb').read()
image_decoded = tf.image.decode_jpeg(image_raw)
image_decoded = tf.to_float(image_decoded)
 
with tf.Session() as sess:
 image_decoded_ = sess.run(image_decoded)
 plt.imshow(image_decoded_)
 plt.show()

生成的图片如下图所示:

浅谈Tensorflow加载Vgg预训练模型的几个注意事项

左边是原图,右边是转换为float格式的图片,可见将图片转换为float格式,虽然数值没有造成太大影响,但是若想将图片保存到本地就会出现问题。

说了这么多,只为了说一点,在保存图片到本地之前,需要将其格式从float转回uint8,否则会造成一系列错误:图片显示异常,API报错等。正确的保存代码如下:

save_path = "data/boat_copy.jpg"
image_uint = tf.cast(image_decoded, tf.uint8)
with tf.Session() as sess:
 with open(save_path, 'wb') as img:
 image_saved = sess.run(tf.image.encode_jpeg(image_uint))
 img.write(image_saved)

其中只有一句话最关键,即 tf.cast(image_decoded, tf.uint8)。

以上这篇浅谈Tensorflow加载Vgg预训练模型的几个注意事项就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python基础教程之简单入门说明(变量和控制语言使用方法)
Mar 25 Python
用实例解释Python中的继承和多态的概念
Apr 27 Python
python实现用于测试网站访问速率的方法
May 26 Python
详解python异步编程之asyncio(百万并发)
Jul 07 Python
利用Python如何制作好玩的GIF动图详解
Jul 11 Python
深入理解Django自定义信号(signals)
Oct 15 Python
matplotlib实现热成像图colorbar和极坐标图的方法
Dec 13 Python
Django Docker容器化部署之Django-Docker本地部署
Oct 09 Python
python进程的状态、创建及使用方法详解
Dec 06 Python
Django之choices选项和富文本编辑器的使用详解
Apr 01 Python
python 实现学生信息管理系统的示例
Nov 28 Python
Appium+Python实现简单的自动化登录测试的实现
Jan 26 Python
Tensorflow加载Vgg预训练模型操作
May 26 #Python
PyQt5如何将.ui文件转换为.py文件的实例代码
May 26 #Python
TensorFlow实现模型断点训练,checkpoint模型载入方式
May 26 #Python
python 日志模块 日志等级设置失效的解决方案
May 26 #Python
python3.7+selenium模拟淘宝登录功能的实现
May 26 #Python
TensorFlow固化模型的实现操作
May 26 #Python
Python 如何批量更新已安装的库
May 26 #Python
You might like
改写函数实现PHP二维/三维数组转字符串
2013/09/13 PHP
YII实现分页的方法
2014/07/09 PHP
thinkPHP实现MemCache分布式缓存功能
2016/03/23 PHP
PHP实现百度人脸识别
2019/05/06 PHP
JavaScript 学习历程和心得分享
2010/12/12 Javascript
JavaScript基础知识之数据类型
2012/08/06 Javascript
JS随即打乱数组实现代码
2012/12/03 Javascript
node.js中的fs.realpathSync方法使用说明
2014/12/16 Javascript
实例详解jQuery Mockjax 插件模拟 Ajax 请求
2016/01/12 Javascript
jquery插件方式实现table查询功能的简单实例
2016/06/06 Javascript
把普通对象转换成json格式的对象的简单实例
2016/07/04 Javascript
sublime text配置node.js调试(图文教程)
2017/11/23 Javascript
Vue实现导出excel表格功能
2018/03/30 Javascript
微信小程序使用scroll-view标签实现自动滑动到底部功能的实例代码
2018/11/09 Javascript
vue学习之Vue-Router用法实例分析
2020/01/06 Javascript
Vue SPA 首屏优化方案
2021/02/26 Vue.js
[06:48]DOTA2-DPC中国联赛2月26日Recap集锦
2021/03/11 DOTA
用实例详解Python中的Django框架中prefetch_related()函数对数据库查询的优化
2015/04/01 Python
python实现Floyd算法
2018/01/03 Python
Python实现简易版的Web服务器(推荐)
2018/01/29 Python
基于python3 OpenCV3实现静态图片人脸识别
2018/05/25 Python
浅谈Pycharm中的Python Console与Terminal
2019/01/17 Python
Python适配器模式代码实现解析
2019/08/02 Python
python脚本实现音频m4a格式转成MP3格式的实例代码
2019/10/09 Python
使用Python进行防病毒免杀解析
2019/12/13 Python
基于Python采集爬取微信公众号历史数据
2020/11/27 Python
美国五金商店:Ace Hardware
2018/03/27 全球购物
优秀大学生的自我评价
2014/01/16 职场文书
打架检讨书300字
2014/02/02 职场文书
竞选学生会演讲稿
2014/04/25 职场文书
ktv服务员岗位职责
2015/02/09 职场文书
《和时间赛跑》读后感3篇
2019/12/16 职场文书
《卧薪尝胆》读后感3篇
2019/12/26 职场文书
JVM上高性能数据格式库包Apache Arrow入门和架构详解(Gkatziouras)
2021/05/26 Servers
SpringBoot 拦截器妙用你真的了解吗
2021/07/01 Java/Android
微信小程序中wxs文件的一些妙用分享
2022/02/18 Javascript