浅谈Tensorflow加载Vgg预训练模型的几个注意事项


Posted in Python onMay 26, 2020

写这个博客的关键Bug: Value passed to parameter 'input' has DataType uint8 not in list of allowed values: float16, bfloat16, float32, float64。本博客将围绕 加载图片 和 保存图片到本地 来详细解释和解决上述的Bug及其引出来的一系列Bug。

加载图片

首先,造成上述Bug的代码如下所示

image_path = "data/test.jpg" # 本地的测试图片
 
image_raw = tf.gfile.GFile(image_path, 'rb').read()
# 一定要tf.float(),否则会报错
image_decoded = tf.image.decode_jpeg(image_raw)
 
# 扩展图片的维度,从三维变成四维,符合Vgg19的输入接口
image_expand_dim = tf.expand_dims(image_decoded, 0)
 
# 定义Vgg19模型
vgg19 = VGG19(data_path)
net = vgg19.feed_forward(image_expand_dim, 'vgg19')
print(net)

上述代码是加载Vgg19预训练模型,并传入图片得到所有层的特征图,具体的代码实现和原理讲解可参考我的另一篇博客:Tensorflow加载Vgg预训练模型。那么,为什么代码会出现: Value passed to parameter 'input' has DataType uint8 not in list of allowed values: float16, bfloat16, float32, float64,这个Bug呢?

这句英文翻译过来是指:传递的值类型是uint8,但是接受的参数类型必须是float的那几种。故原因就是传入值的数据类型错了,那么如何解决这个Bug呢,很简单

image_path = "data/test.jpg" # 本地的测试图片
 
image_raw = tf.gfile.GFile(image_path, 'rb').read()
# 一定要tf.float(),否则会报错
image_decoded = tf.to_float(tf.image.decode_jpeg(image_raw))
 
# 扩展图片的维度,从三维变成四维,符合Vgg19的输入接口
image_expand_dim = tf.expand_dims(image_decoded, 0)
 
# 定义Vgg19模型
vgg19 = VGG19(data_path)
net = vgg19.feed_forward(image_expand_dim, 'vgg19')
print(net)

这两个代码块唯一的变动就是:image_decoded结果在输出前加了一个tf.float(),将其转换为float类型。

在tensorflow API中,tf.image.decode_jpeg()默认读取的图片数据格式为unit8,而不是float。uint8数据的范围在(0, 255)中,正好符合图片的像素范围(0, 255)。但是,保存在本地的Vgg19预训练模型的数据接口为float,所以才造成了本文开头的Bug。

这里还要提一点,若是使用PIL的方法来加载图片,则不会出现上述的Bug,因为通过PIL得到的图片格式是float,而不是uint8,故不需要转换。

很多同学可能会疑惑,若是强行改变了原图片的数据格式,从uint8类型转变成float,会不会导致数据改变或者出错?故我做了下面这个实验:

image_path = "data/3.jpg"
image_raw = tf.gfile.GFile(image_path, 'rb').read()
image_unit8 = tf.image.decode_jpeg(image_raw)
image_float = tf.to_float(image_unit8)
 
with tf.Session() as sess:
 image_unit8_, image_float_ = sess.run([image_unit8, image_float])
 
print("image_unit8_", image_unit8_)
print("image_float_ ", image_float_ )

代码结果如下:

image_unit8_
 [180, 192, 204],
 [183, 195, 207],
 [186, 198, 210],
 ...,
 [191, 205, 218],
 [191, 205, 218],
 [190, 204, 217]],
 
 image_float_ 
 [180., 192., 204.],
 [183., 195., 207.],
 [186., 198., 210.],
 ...,
 [191., 205., 218.],
 [191., 205., 218.],
 [190., 204., 217.]],

可以看到,数据根本没有变化,只是后面多加了个小数点,变得只有类型,而没有强制改变值,故同学们不需要过度担心。

保存图片到本地

在加载图片的时候,为了使用保存在本地的预训练Vgg19模型,我们需要将读取的图片由uint8格式转换成float格式。那若是我们想将已经转换为float格式的图片再保存到本地,该怎么做呢?

首先,我们根据上述的文字的意思读取图片,并且将其转换为float格式,在将读取的图片再次保存到本地之前,我们首先可视化一下转换格式后的图片,代码如下:

import tensorflow as tf
from matplotlib import pyplot as plt
image_path = "data/boat.jpg"
 
image_raw = tf.gfile.GFile(image_path, 'rb').read()
image_decoded = tf.image.decode_jpeg(image_raw)
image_decoded = tf.to_float(image_decoded)
 
with tf.Session() as sess:
 image_decoded_ = sess.run(image_decoded)
 plt.imshow(image_decoded_)
 plt.show()

生成的图片如下图所示:

浅谈Tensorflow加载Vgg预训练模型的几个注意事项

左边是原图,右边是转换为float格式的图片,可见将图片转换为float格式,虽然数值没有造成太大影响,但是若想将图片保存到本地就会出现问题。

说了这么多,只为了说一点,在保存图片到本地之前,需要将其格式从float转回uint8,否则会造成一系列错误:图片显示异常,API报错等。正确的保存代码如下:

save_path = "data/boat_copy.jpg"
image_uint = tf.cast(image_decoded, tf.uint8)
with tf.Session() as sess:
 with open(save_path, 'wb') as img:
 image_saved = sess.run(tf.image.encode_jpeg(image_uint))
 img.write(image_saved)

其中只有一句话最关键,即 tf.cast(image_decoded, tf.uint8)。

以上这篇浅谈Tensorflow加载Vgg预训练模型的几个注意事项就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实例一个类背后发生了什么
Feb 09 Python
python中set常用操作汇总
Jun 30 Python
python2.7读取文件夹下所有文件名称及内容的方法
Feb 24 Python
Python3 利用requests 库进行post携带账号密码请求数据的方法
Oct 26 Python
Django之创建引擎索引报错及解决详解
Jul 17 Python
使用Python制作表情包实现换脸功能
Jul 19 Python
pandas中的数据去重处理的实现方法
Feb 10 Python
使用keras实现孪生网络中的权值共享教程
Jun 11 Python
Python selenium键盘鼠标事件实现过程详解
Jul 28 Python
详解向scrapy中的spider传递参数的几种方法(2种)
Sep 28 Python
Python实现迪杰斯特拉算法并生成最短路径的示例代码
Dec 01 Python
M1芯片安装python3.9.1的实现
Feb 02 Python
Tensorflow加载Vgg预训练模型操作
May 26 #Python
PyQt5如何将.ui文件转换为.py文件的实例代码
May 26 #Python
TensorFlow实现模型断点训练,checkpoint模型载入方式
May 26 #Python
python 日志模块 日志等级设置失效的解决方案
May 26 #Python
python3.7+selenium模拟淘宝登录功能的实现
May 26 #Python
TensorFlow固化模型的实现操作
May 26 #Python
Python 如何批量更新已安装的库
May 26 #Python
You might like
PHPMyadmin 配置文件详解(配置)
2009/12/03 PHP
详解cookie验证的php应用的一种SSO解决办法
2017/10/20 PHP
javascript flash下fromCharCode和charCodeAt方法使用说明
2008/01/12 Javascript
jquery 批量上传图片实现代码
2010/01/28 Javascript
javascript模拟的Ping效果代码 (Web Ping)
2011/03/13 Javascript
Javascript 面向对象(二)封装代码
2012/05/23 Javascript
js 回车提交表单两种实现方法
2012/12/31 Javascript
input:checkbox多选框实现单选效果跟radio一样
2014/06/16 Javascript
原生javascript获取元素样式
2014/12/31 Javascript
JavaScript使用addEventListener添加事件监听用法实例
2015/06/01 Javascript
jQuery头像裁剪工具jcrop用法实例(附演示与demo源码下载)
2016/01/22 Javascript
BootStrap 附加导航组件
2016/07/22 Javascript
HTML Table 空白单元格补全的简单实现
2016/10/13 Javascript
JavaScript中省略元素对数组长度的影响
2016/10/26 Javascript
微信小程序之小豆瓣图书实例
2016/11/30 Javascript
jQuery中DOM节点的删除方法总结(超全面)
2017/01/22 Javascript
JavaScript实现精美个性导航栏筋斗云效果
2017/10/29 Javascript
vue实现图片滚动的示例代码(类似走马灯效果)
2018/03/03 Javascript
浅谈Vue内置component组件的应用场景
2018/03/27 Javascript
AngularJS实现与后台服务器进行交互的示例讲解
2018/08/13 Javascript
详解Vue-cli3.X使用px2rem遇到的问题
2019/08/09 Javascript
vue.js实现左边导航切换右边内容
2019/10/21 Javascript
一篇超完整的Vue新手入门指导教程
2020/11/18 Vue.js
python使用nntp读取新闻组内容的方法
2015/05/08 Python
Python中的相关分析correlation analysis的实现
2019/08/29 Python
jupyter notebook 实现matplotlib图动态刷新
2020/04/22 Python
Python实现手绘图效果实例分享
2020/07/22 Python
Python爬取数据并实现可视化代码解析
2020/08/12 Python
STAUD官方网站:洛杉矶独有的闲适风格
2019/04/11 全球购物
美国知名眼镜网站:Target Optical
2020/04/04 全球购物
淘宝客服专员岗位职责
2014/04/11 职场文书
房地产资料员岗位职责
2014/07/02 职场文书
地心历险记观后感
2015/06/15 职场文书
Python selenium模拟网页点击爬虫交管12123违章数据
2021/05/26 Python
Python实现自动玩连连看的脚本分享
2022/04/04 Python
Nginx 502 bad gateway错误解决的九种方案及原因
2022/08/14 Servers