浅谈Tensorflow加载Vgg预训练模型的几个注意事项


Posted in Python onMay 26, 2020

写这个博客的关键Bug: Value passed to parameter 'input' has DataType uint8 not in list of allowed values: float16, bfloat16, float32, float64。本博客将围绕 加载图片 和 保存图片到本地 来详细解释和解决上述的Bug及其引出来的一系列Bug。

加载图片

首先,造成上述Bug的代码如下所示

image_path = "data/test.jpg" # 本地的测试图片
 
image_raw = tf.gfile.GFile(image_path, 'rb').read()
# 一定要tf.float(),否则会报错
image_decoded = tf.image.decode_jpeg(image_raw)
 
# 扩展图片的维度,从三维变成四维,符合Vgg19的输入接口
image_expand_dim = tf.expand_dims(image_decoded, 0)
 
# 定义Vgg19模型
vgg19 = VGG19(data_path)
net = vgg19.feed_forward(image_expand_dim, 'vgg19')
print(net)

上述代码是加载Vgg19预训练模型,并传入图片得到所有层的特征图,具体的代码实现和原理讲解可参考我的另一篇博客:Tensorflow加载Vgg预训练模型。那么,为什么代码会出现: Value passed to parameter 'input' has DataType uint8 not in list of allowed values: float16, bfloat16, float32, float64,这个Bug呢?

这句英文翻译过来是指:传递的值类型是uint8,但是接受的参数类型必须是float的那几种。故原因就是传入值的数据类型错了,那么如何解决这个Bug呢,很简单

image_path = "data/test.jpg" # 本地的测试图片
 
image_raw = tf.gfile.GFile(image_path, 'rb').read()
# 一定要tf.float(),否则会报错
image_decoded = tf.to_float(tf.image.decode_jpeg(image_raw))
 
# 扩展图片的维度,从三维变成四维,符合Vgg19的输入接口
image_expand_dim = tf.expand_dims(image_decoded, 0)
 
# 定义Vgg19模型
vgg19 = VGG19(data_path)
net = vgg19.feed_forward(image_expand_dim, 'vgg19')
print(net)

这两个代码块唯一的变动就是:image_decoded结果在输出前加了一个tf.float(),将其转换为float类型。

在tensorflow API中,tf.image.decode_jpeg()默认读取的图片数据格式为unit8,而不是float。uint8数据的范围在(0, 255)中,正好符合图片的像素范围(0, 255)。但是,保存在本地的Vgg19预训练模型的数据接口为float,所以才造成了本文开头的Bug。

这里还要提一点,若是使用PIL的方法来加载图片,则不会出现上述的Bug,因为通过PIL得到的图片格式是float,而不是uint8,故不需要转换。

很多同学可能会疑惑,若是强行改变了原图片的数据格式,从uint8类型转变成float,会不会导致数据改变或者出错?故我做了下面这个实验:

image_path = "data/3.jpg"
image_raw = tf.gfile.GFile(image_path, 'rb').read()
image_unit8 = tf.image.decode_jpeg(image_raw)
image_float = tf.to_float(image_unit8)
 
with tf.Session() as sess:
 image_unit8_, image_float_ = sess.run([image_unit8, image_float])
 
print("image_unit8_", image_unit8_)
print("image_float_ ", image_float_ )

代码结果如下:

image_unit8_
 [180, 192, 204],
 [183, 195, 207],
 [186, 198, 210],
 ...,
 [191, 205, 218],
 [191, 205, 218],
 [190, 204, 217]],
 
 image_float_ 
 [180., 192., 204.],
 [183., 195., 207.],
 [186., 198., 210.],
 ...,
 [191., 205., 218.],
 [191., 205., 218.],
 [190., 204., 217.]],

可以看到,数据根本没有变化,只是后面多加了个小数点,变得只有类型,而没有强制改变值,故同学们不需要过度担心。

保存图片到本地

在加载图片的时候,为了使用保存在本地的预训练Vgg19模型,我们需要将读取的图片由uint8格式转换成float格式。那若是我们想将已经转换为float格式的图片再保存到本地,该怎么做呢?

首先,我们根据上述的文字的意思读取图片,并且将其转换为float格式,在将读取的图片再次保存到本地之前,我们首先可视化一下转换格式后的图片,代码如下:

import tensorflow as tf
from matplotlib import pyplot as plt
image_path = "data/boat.jpg"
 
image_raw = tf.gfile.GFile(image_path, 'rb').read()
image_decoded = tf.image.decode_jpeg(image_raw)
image_decoded = tf.to_float(image_decoded)
 
with tf.Session() as sess:
 image_decoded_ = sess.run(image_decoded)
 plt.imshow(image_decoded_)
 plt.show()

生成的图片如下图所示:

浅谈Tensorflow加载Vgg预训练模型的几个注意事项

左边是原图,右边是转换为float格式的图片,可见将图片转换为float格式,虽然数值没有造成太大影响,但是若想将图片保存到本地就会出现问题。

说了这么多,只为了说一点,在保存图片到本地之前,需要将其格式从float转回uint8,否则会造成一系列错误:图片显示异常,API报错等。正确的保存代码如下:

save_path = "data/boat_copy.jpg"
image_uint = tf.cast(image_decoded, tf.uint8)
with tf.Session() as sess:
 with open(save_path, 'wb') as img:
 image_saved = sess.run(tf.image.encode_jpeg(image_uint))
 img.write(image_saved)

其中只有一句话最关键,即 tf.cast(image_decoded, tf.uint8)。

以上这篇浅谈Tensorflow加载Vgg预训练模型的几个注意事项就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中删除文件的程序代码
Mar 13 Python
把django中admin后台界面的英文修改为中文显示的方法
Jul 26 Python
对pytorch中的梯度更新方法详解
Aug 20 Python
Python实现滑动平均(Moving Average)的例子
Aug 24 Python
python sqlite的Row对象操作示例
Sep 11 Python
sublime3之内网安装python插件Anaconda的流程
Nov 10 Python
使用pandas实现筛选出指定列值所对应的行
Dec 13 Python
python中Mako库实例用法
Dec 31 Python
python regex库实例用法总结
Jan 03 Python
Python 中的Sympy详细使用
Aug 07 Python
python 标准库原理与用法详解之os.path篇
Oct 24 Python
教你使用Python获取QQ音乐某个歌手的歌单
Apr 03 Python
Tensorflow加载Vgg预训练模型操作
May 26 #Python
PyQt5如何将.ui文件转换为.py文件的实例代码
May 26 #Python
TensorFlow实现模型断点训练,checkpoint模型载入方式
May 26 #Python
python 日志模块 日志等级设置失效的解决方案
May 26 #Python
python3.7+selenium模拟淘宝登录功能的实现
May 26 #Python
TensorFlow固化模型的实现操作
May 26 #Python
Python 如何批量更新已安装的库
May 26 #Python
You might like
DOTA2 玩家自创拉野攻略 特色英雄快速成长篇
2020/04/20 DOTA
php与XML、XSLT、Mysql的结合运用实现代码
2009/11/19 PHP
详解PHP用substr函数截取字符串中的某部分
2016/12/03 PHP
js模拟滚动条(横向竖向)
2013/02/22 Javascript
js jq 单击和双击区分示例介绍
2013/11/05 Javascript
三种取消选中单选框radio的方法
2014/09/09 Javascript
javascript实现通过表格绘制颜色填充矩形的方法
2015/04/21 Javascript
DOM 高级编程
2015/05/06 Javascript
Adapter适配器模式在JavaScript设计模式编程中的运用分析
2016/05/18 Javascript
JS仿淘宝搜索框用户输入事件的实现
2017/06/19 Javascript
JS简单实现滑动加载数据的方法示例
2017/10/18 Javascript
Angular实现搜索框及价格上下限功能
2018/01/19 Javascript
基于vue-cli 打包时抽离项目相关配置文件详解
2018/03/07 Javascript
使用vue-route 的 beforeEach 实现导航守卫(路由跳转前验证登录)功能
2018/03/22 Javascript
vuejs点击class变化的实例
2018/09/05 Javascript
vue如何获取自定义元素属性参数值的方法
2019/05/14 Javascript
判断js数据类型的函数实例详解
2019/05/23 Javascript
layui table 表格上添加日期控件的两种方法
2019/09/28 Javascript
python实现linux服务器批量修改密码并生成execl
2014/04/22 Python
Python中使用第三方库xlrd来读取Excel示例
2015/04/05 Python
用Eclipse写python程序
2018/02/10 Python
将pandas.dataframe的数据写入到文件中的方法
2018/12/07 Python
python实现socket+threading处理多连接的方法
2019/07/23 Python
nginx搭建基于python的web环境的实现步骤
2020/01/03 Python
通过cmd进入python的步骤
2020/06/16 Python
adidas瑞典官方网站:购买阿迪达斯鞋子和运动服
2019/12/11 全球购物
学生出入校管理制度
2014/01/16 职场文书
校园文明倡议书
2014/05/16 职场文书
员工培训协议书
2014/09/15 职场文书
安全隐患整改报告
2014/11/06 职场文书
2015年财务个人工作总结范文
2015/05/22 职场文书
大学生暑期实践报告
2015/07/13 职场文书
Golang的继承模拟实例
2021/06/30 Golang
Django路由层如何获取正确的url
2021/07/15 Python
海贼王十大逆天果实 魂魂果实上榜,岩浆果实攻击力最强
2022/03/18 日漫
SpringBoot2零基础到精通之异常处理与web原生组件注入
2022/03/22 Java/Android