浅谈TensorFlow中读取图像数据的三种方式


Posted in Python onJune 30, 2020

 本文面对三种常常遇到的情况,总结三种读取数据的方式,分别用于处理单张图片、大量图片,和TFRecorder读取方式。并且还补充了功能相近的tf函数。

1、处理单张图片

我们训练完模型之后,常常要用图片测试,有的时候,我们并不需要对很多图像做测试,可能就是几张甚至一张。这种情况下没有必要用队列机制。

import tensorflow as tf
import matplotlib.pyplot as plt

def read_image(file_name):
 img = tf.read_file(filename=file_name)  # 默认读取格式为uint8
 print("img 的类型是",type(img));
 img = tf.image.decode_jpeg(img,channels=0) # channels 为1得到的是灰度图,为0则按照图片格式来读
 return img

def main( ):
 with tf.device("/cpu:0"):

  # img_path是文件所在地址包括文件名称,地址用相对地址或者绝对地址都行 
   img_path='./1.jpg'
   img=read_image(img_path)
   with tf.Session() as sess:
   image_numpy=sess.run(img)
   print(image_numpy)
   print(image_numpy.dtype)
   print(image_numpy.shape)
   plt.imshow(image_numpy)
   plt.show()

if __name__=="__main__":
 main()

"""

输出结果为:

img 的类型是 <class 'tensorflow.python.framework.ops.Tensor'>
[[[196 219 209]
  [196 219 209]
  [196 219 209]
  ...

 [[ 71 106  42]
  [ 59  89  39]
  [ 34  63  19]
  ...
  [ 21  52  46]
  [ 15  45  43]
  [ 22  50  53]]]
uint8
(675, 1200, 3)
"""

 

和tf.read_file用法相似的函数还有tf.gfile.FastGFile  tf.gfile.GFile,只是要指定读取方式是'r' 还是'rb' 。

2、需要读取大量图像用于训练

这种情况就需要使用Tensorflow队列机制。首先是获得每张图片的路径,把他们都放进一个list里面,然后用string_input_producer创建队列,再用tf.WholeFileReader读取。具体请看下例:

def get_image_batch(data_file,batch_size):
 data_names=[os.path.join(data_file,k) for k in os.listdir(data_file)]
 
 #这个num_epochs函数在整个Graph是local Variable,所以在sess.run全局变量的时候也要加上局部变量。 
 filenames_queue=tf.train.string_input_producer(data_names,num_epochs=50,shuffle=True,capacity=512)
 reader=tf.WholeFileReader()
 _,img_bytes=reader.read(filenames_queue)
 image=tf.image.decode_png(img_bytes,channels=1) #读取的是什么格式,就decode什么格式
 #解码成单通道的,并且获得的结果的shape是[?, ?,1],也就是Graph不知道图像的大小,需要set_shape
 image.set_shape([180,180,1]) #set到原本已知图像的大小。或者直接通过tf.image.resize_images,tf.reshape()
 image=tf.image.convert_image_dtype(image,tf.float32)
 #预处理 下面的一句代码可以换成自己想使用的预处理方式
 #image=tf.divide(image,255.0) 
 return tf.train.batch([image],batch_size)

这里的date_file是指文件夹所在的路径,不包括文件名。第一句是遍历指定目录下的文件名称,存放到一个list中。当然这个做法有很多种方法,比如glob.glob,或者tf.train.match_filename_once

全部代码如下:

import tensorflow as tf
import os
def read_image(data_file,batch_size):
 data_names=[os.path.join(data_file,k) for k in os.listdir(data_file)]
 filenames_queue=tf.train.string_input_producer(data_names,num_epochs=5,shuffle=True,capacity=30)
 reader=tf.WholeFileReader()
 _,img_bytes=reader.read(filenames_queue)
 image=tf.image.decode_jpeg(img_bytes,channels=1)
 image=tf.image.resize_images(image,(180,180))

 image=tf.image.convert_image_dtype(image,tf.float32)
 return tf.train.batch([image],batch_size)

def main( ):
 img_path=r'F:\dataSet\WIDER\WIDER_train\images\6--Funeral' #本地的一个数据集目录,有足够的图像
 img=read_image(img_path,batch_size=10)
 image=img[0] #取出每个batch的第一个数据
 print(image)
 init=[tf.global_variables_initializer(),tf.local_variables_initializer()]
 with tf.Session() as sess:
  sess.run(init)
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess,coord=coord)
  try:
   while not coord.should_stop():
    print(image.shape)
  except tf.errors.OutOfRangeError:
   print('read done')
  finally:
   coord.request_stop()
  coord.join(threads)


if __name__=="__main__":
 main()

"""

输出如下:

(180, 180, 1)
(180, 180, 1)
(180, 180, 1)
(180, 180, 1)
(180, 180, 1)
"""

这段代码可以说写的很是规整了。注意到init里面有对local变量的初始化,并且因为用到了队列,当然要告诉电脑什么时候队列开始, tf.train.Coordinator 和 tf.train.start_queue_runners 就是两个管理队列的类,用法如程序所示。

与 tf.train.string_input_producer相似的函数是 tf.train.slice_input_producer。 tf.train.slice_input_producer和tf.train.string_input_producer的第一个参数形式不一样。等有时间再做一个二者比较的博客

 3、对TFRecorder解码获得图像数据

其实这块和上一种方式差不多的,更重要的是怎么生成TFRecorder文件,这一部分我会补充到另一篇博客上。

仍然使用 tf.train.string_input_producer。

import tensorflow as tf
import matplotlib.pyplot as plt
import os
import cv2
import numpy as np
import glob

def read_image(data_file,batch_size):
 files_path=glob.glob(data_file)
 queue=tf.train.string_input_producer(files_path,num_epochs=None)
 reader = tf.TFRecordReader()
 print(queue)
 _, serialized_example = reader.read(queue)
 features = tf.parse_single_example(
  serialized_example,
  features={
   'image_raw': tf.FixedLenFeature([], tf.string),
   'label_raw': tf.FixedLenFeature([], tf.string),
  })
 image = tf.decode_raw(features['image_raw'], tf.uint8)
 image = tf.cast(image, tf.float32)
 image.set_shape((12*12*3))
 label = tf.decode_raw(features['label_raw'], tf.float32)
 label.set_shape((2))
 # 预处理部分省略,大家可以自己根据需要添加
 return tf.train.batch([image,label],batch_size=batch_size,num_threads=4,capacity=5*batch_size)

def main( ):
 img_path=r'F:\python\MTCNN_by_myself\prepare_data\pnet*.tfrecords' #本地的几个tf文件
 img,label=read_image(img_path,batch_size=10)
 image=img[0]
 init=[tf.global_variables_initializer(),tf.local_variables_initializer()]
 with tf.Session() as sess:
  sess.run(init)
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess,coord=coord)
  try:
   while not coord.should_stop():
    print(image.shape)
  except tf.errors.OutOfRangeError:
   print('read done')
  finally:
   coord.request_stop()
  coord.join(threads)


if __name__=="__main__":
 main()

在read_image函数中,先使用glob函数获得了存放tfrecord文件的列表,然后根据TFRecord文件是如何存的就如何parse,再set_shape;这里有必要提醒下parse的方式。我们看到这里用的是tf.decode_raw ,因为做TFRecord是将图像数据string化了,数据是串行的,丢失了空间结果。从features中取出image和label的数据,这时就要用 tf.decode_raw  解码,得到的结果当然也是串行的了,所以set_shape 成一个串行的,再reshape。这种方式是取决于你的编码TFRecord方式的。

再举一种例子:

reader=tf.TFRecordReader()
_,serialized_example=reader.read(file_name_queue)
features = tf.parse_single_example(serialized_example, features={
 'data': tf.FixedLenFeature([256,256], tf.float32), ###
 'label': tf.FixedLenFeature([], tf.int64),
 'id': tf.FixedLenFeature([], tf.int64)
})
img = features['data']
label =features['label']
id = features['id']

这个时候就不需要任何解码了。因为做TFRecord的方式就是直接把图像数据append进去了。

参考链接:

https://blog.csdn.net/qq_34914551/article/details/86286184

到此这篇关于浅谈TensorFlow中读取图像数据的三种方式的文章就介绍到这了,更多相关TensorFlow 读取图像数据内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python 随机数生成的代码的详细分析
May 15 Python
使用Python编写Linux系统守护进程实例
Feb 03 Python
Python中的连接符(+、+=)示例详解
Jan 13 Python
python下载图片实现方法(超简单)
Jul 21 Python
MySQL适配器PyMySQL详解
Sep 20 Python
Python Json序列化与反序列化的示例
Jan 31 Python
PyTorch的深度学习入门教程之构建神经网络
Jun 27 Python
python selenium登录豆瓣网过程解析
Aug 10 Python
Python中logging日志库实例详解
Feb 19 Python
Python代码一键转Jar包及Java调用Python新姿势
Mar 10 Python
探秘TensorFlow 和 NumPy 的 Broadcasting 机制
Mar 13 Python
OpenCV中resize函数插值算法的实现过程(五种)
Jun 05 Python
python中 _、__、__xx__()区别及使用场景
Jun 30 #Python
Django实现内容缓存实例方法
Jun 30 #Python
Pytorch 卷积中的 Input Shape用法
Jun 29 #Python
Python闭包装饰器使用方法汇总
Jun 29 #Python
使用已经得到的keras模型识别自己手写的数字方式
Jun 29 #Python
Python接口测试环境搭建过程详解
Jun 29 #Python
python字典的值可以修改吗
Jun 29 #Python
You might like
PHP连接access数据库
2008/03/27 PHP
php self,$this,const,static,-&amp;gt;的使用
2009/10/22 PHP
PHP数据类型的总结分析
2013/06/13 PHP
WordPress特定文章对搜索引擎隐藏或只允许搜索引擎查看
2015/12/31 PHP
Joomla开启SEF的方法
2016/05/04 PHP
PHP静态延迟绑定和普通静态效率的对比
2017/10/20 PHP
漂亮的thinkphp 跳转页封装示例
2019/10/16 PHP
JavaScript面向对象之静态与非静态类
2010/02/03 Javascript
关于query Javascript CSS Selector engine
2013/04/12 Javascript
cookie中的path与domain属性详解
2013/12/18 Javascript
一个判断抢购时间是否到达的简单的js函数
2014/06/23 Javascript
微信开发 JS-SDK 6.0.2 经常遇到问题总结
2016/12/08 Javascript
使用smartupload组件实现jsp+jdbc上传下载文件实例解析
2017/01/05 Javascript
Vue中的v-cloak使用解读
2017/03/27 Javascript
Jquery-data的三种用法
2017/04/18 jQuery
详解vue中引入stylus及报错解决方法
2017/09/22 Javascript
详解angular脏检查原理及伪代码实现
2018/06/08 Javascript
使用vue开发移动端管理后台的注意事项
2019/03/07 Javascript
vue轻量级框架无法获取到vue对象解决方法
2019/05/12 Javascript
微信小程序tab左右滑动切换功能的实现代码
2021/02/08 Javascript
[02:41]辉夜杯现场一家三口 “我爸玩风行 我玩血魔”
2015/12/27 DOTA
[47:12]TFT vs Secret Supermajor小组赛C组 BO3 第三场 6.3
2018/06/04 DOTA
介绍Python中的__future__模块
2015/04/27 Python
python UNIX_TIMESTAMP时间处理方法分析
2016/04/18 Python
python中scikit-learn机器代码实例
2018/08/05 Python
详解如何管理多个Python版本和虚拟环境
2019/05/10 Python
python 普通克里金(Kriging)法的实现
2019/12/19 Python
Python列表切片常用操作实例解析
2020/03/10 Python
CSS3自定义滚动条样式 ::webkit-scrollbar的示例代码详解
2020/06/01 HTML / CSS
C语言面试题
2015/10/30 面试题
Java基础知识面试题
2014/03/25 面试题
精细化工应届生求职信
2013/11/17 职场文书
优秀班组事迹材料
2014/12/24 职场文书
学前班语言教学计划
2015/01/20 职场文书
Jupyter notebook 不自动弹出网页的解决方案
2021/05/21 Python
JavaScript获取URL参数的方法分享
2022/04/07 Javascript