浅谈TensorFlow中读取图像数据的三种方式


Posted in Python onJune 30, 2020

 本文面对三种常常遇到的情况,总结三种读取数据的方式,分别用于处理单张图片、大量图片,和TFRecorder读取方式。并且还补充了功能相近的tf函数。

1、处理单张图片

我们训练完模型之后,常常要用图片测试,有的时候,我们并不需要对很多图像做测试,可能就是几张甚至一张。这种情况下没有必要用队列机制。

import tensorflow as tf
import matplotlib.pyplot as plt

def read_image(file_name):
 img = tf.read_file(filename=file_name)  # 默认读取格式为uint8
 print("img 的类型是",type(img));
 img = tf.image.decode_jpeg(img,channels=0) # channels 为1得到的是灰度图,为0则按照图片格式来读
 return img

def main( ):
 with tf.device("/cpu:0"):

  # img_path是文件所在地址包括文件名称,地址用相对地址或者绝对地址都行 
   img_path='./1.jpg'
   img=read_image(img_path)
   with tf.Session() as sess:
   image_numpy=sess.run(img)
   print(image_numpy)
   print(image_numpy.dtype)
   print(image_numpy.shape)
   plt.imshow(image_numpy)
   plt.show()

if __name__=="__main__":
 main()

"""

输出结果为:

img 的类型是 <class 'tensorflow.python.framework.ops.Tensor'>
[[[196 219 209]
  [196 219 209]
  [196 219 209]
  ...

 [[ 71 106  42]
  [ 59  89  39]
  [ 34  63  19]
  ...
  [ 21  52  46]
  [ 15  45  43]
  [ 22  50  53]]]
uint8
(675, 1200, 3)
"""

 

和tf.read_file用法相似的函数还有tf.gfile.FastGFile  tf.gfile.GFile,只是要指定读取方式是'r' 还是'rb' 。

2、需要读取大量图像用于训练

这种情况就需要使用Tensorflow队列机制。首先是获得每张图片的路径,把他们都放进一个list里面,然后用string_input_producer创建队列,再用tf.WholeFileReader读取。具体请看下例:

def get_image_batch(data_file,batch_size):
 data_names=[os.path.join(data_file,k) for k in os.listdir(data_file)]
 
 #这个num_epochs函数在整个Graph是local Variable,所以在sess.run全局变量的时候也要加上局部变量。 
 filenames_queue=tf.train.string_input_producer(data_names,num_epochs=50,shuffle=True,capacity=512)
 reader=tf.WholeFileReader()
 _,img_bytes=reader.read(filenames_queue)
 image=tf.image.decode_png(img_bytes,channels=1) #读取的是什么格式,就decode什么格式
 #解码成单通道的,并且获得的结果的shape是[?, ?,1],也就是Graph不知道图像的大小,需要set_shape
 image.set_shape([180,180,1]) #set到原本已知图像的大小。或者直接通过tf.image.resize_images,tf.reshape()
 image=tf.image.convert_image_dtype(image,tf.float32)
 #预处理 下面的一句代码可以换成自己想使用的预处理方式
 #image=tf.divide(image,255.0) 
 return tf.train.batch([image],batch_size)

这里的date_file是指文件夹所在的路径,不包括文件名。第一句是遍历指定目录下的文件名称,存放到一个list中。当然这个做法有很多种方法,比如glob.glob,或者tf.train.match_filename_once

全部代码如下:

import tensorflow as tf
import os
def read_image(data_file,batch_size):
 data_names=[os.path.join(data_file,k) for k in os.listdir(data_file)]
 filenames_queue=tf.train.string_input_producer(data_names,num_epochs=5,shuffle=True,capacity=30)
 reader=tf.WholeFileReader()
 _,img_bytes=reader.read(filenames_queue)
 image=tf.image.decode_jpeg(img_bytes,channels=1)
 image=tf.image.resize_images(image,(180,180))

 image=tf.image.convert_image_dtype(image,tf.float32)
 return tf.train.batch([image],batch_size)

def main( ):
 img_path=r'F:\dataSet\WIDER\WIDER_train\images\6--Funeral' #本地的一个数据集目录,有足够的图像
 img=read_image(img_path,batch_size=10)
 image=img[0] #取出每个batch的第一个数据
 print(image)
 init=[tf.global_variables_initializer(),tf.local_variables_initializer()]
 with tf.Session() as sess:
  sess.run(init)
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess,coord=coord)
  try:
   while not coord.should_stop():
    print(image.shape)
  except tf.errors.OutOfRangeError:
   print('read done')
  finally:
   coord.request_stop()
  coord.join(threads)


if __name__=="__main__":
 main()

"""

输出如下:

(180, 180, 1)
(180, 180, 1)
(180, 180, 1)
(180, 180, 1)
(180, 180, 1)
"""

这段代码可以说写的很是规整了。注意到init里面有对local变量的初始化,并且因为用到了队列,当然要告诉电脑什么时候队列开始, tf.train.Coordinator 和 tf.train.start_queue_runners 就是两个管理队列的类,用法如程序所示。

与 tf.train.string_input_producer相似的函数是 tf.train.slice_input_producer。 tf.train.slice_input_producer和tf.train.string_input_producer的第一个参数形式不一样。等有时间再做一个二者比较的博客

 3、对TFRecorder解码获得图像数据

其实这块和上一种方式差不多的,更重要的是怎么生成TFRecorder文件,这一部分我会补充到另一篇博客上。

仍然使用 tf.train.string_input_producer。

import tensorflow as tf
import matplotlib.pyplot as plt
import os
import cv2
import numpy as np
import glob

def read_image(data_file,batch_size):
 files_path=glob.glob(data_file)
 queue=tf.train.string_input_producer(files_path,num_epochs=None)
 reader = tf.TFRecordReader()
 print(queue)
 _, serialized_example = reader.read(queue)
 features = tf.parse_single_example(
  serialized_example,
  features={
   'image_raw': tf.FixedLenFeature([], tf.string),
   'label_raw': tf.FixedLenFeature([], tf.string),
  })
 image = tf.decode_raw(features['image_raw'], tf.uint8)
 image = tf.cast(image, tf.float32)
 image.set_shape((12*12*3))
 label = tf.decode_raw(features['label_raw'], tf.float32)
 label.set_shape((2))
 # 预处理部分省略,大家可以自己根据需要添加
 return tf.train.batch([image,label],batch_size=batch_size,num_threads=4,capacity=5*batch_size)

def main( ):
 img_path=r'F:\python\MTCNN_by_myself\prepare_data\pnet*.tfrecords' #本地的几个tf文件
 img,label=read_image(img_path,batch_size=10)
 image=img[0]
 init=[tf.global_variables_initializer(),tf.local_variables_initializer()]
 with tf.Session() as sess:
  sess.run(init)
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess,coord=coord)
  try:
   while not coord.should_stop():
    print(image.shape)
  except tf.errors.OutOfRangeError:
   print('read done')
  finally:
   coord.request_stop()
  coord.join(threads)


if __name__=="__main__":
 main()

在read_image函数中,先使用glob函数获得了存放tfrecord文件的列表,然后根据TFRecord文件是如何存的就如何parse,再set_shape;这里有必要提醒下parse的方式。我们看到这里用的是tf.decode_raw ,因为做TFRecord是将图像数据string化了,数据是串行的,丢失了空间结果。从features中取出image和label的数据,这时就要用 tf.decode_raw  解码,得到的结果当然也是串行的了,所以set_shape 成一个串行的,再reshape。这种方式是取决于你的编码TFRecord方式的。

再举一种例子:

reader=tf.TFRecordReader()
_,serialized_example=reader.read(file_name_queue)
features = tf.parse_single_example(serialized_example, features={
 'data': tf.FixedLenFeature([256,256], tf.float32), ###
 'label': tf.FixedLenFeature([], tf.int64),
 'id': tf.FixedLenFeature([], tf.int64)
})
img = features['data']
label =features['label']
id = features['id']

这个时候就不需要任何解码了。因为做TFRecord的方式就是直接把图像数据append进去了。

参考链接:

https://blog.csdn.net/qq_34914551/article/details/86286184

到此这篇关于浅谈TensorFlow中读取图像数据的三种方式的文章就介绍到这了,更多相关TensorFlow 读取图像数据内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python实现的系统实用log类实例
Jun 30 Python
在Python的Django框架中生成CSV文件的方法
Jul 22 Python
python中reload(module)的用法示例详解
Sep 15 Python
python3 pillow生成简单验证码图片的示例
Sep 19 Python
python好玩的项目—色情图片识别代码分享
Nov 07 Python
Python实现学校管理系统
Jan 11 Python
Python计算开方、立方、圆周率,精确到小数点后任意位的方法
Jul 17 Python
python语言线程标准库threading.local解读总结
Nov 10 Python
Python使用QQ邮箱发送邮件实例与QQ邮箱设置详解
Feb 18 Python
Python生成器generator原理及用法解析
Jul 20 Python
记一次Django响应超慢的解决过程
Sep 17 Python
pycharm部署django项目到云服务器的详细流程
Jun 29 Python
python中 _、__、__xx__()区别及使用场景
Jun 30 #Python
Django实现内容缓存实例方法
Jun 30 #Python
Pytorch 卷积中的 Input Shape用法
Jun 29 #Python
Python闭包装饰器使用方法汇总
Jun 29 #Python
使用已经得到的keras模型识别自己手写的数字方式
Jun 29 #Python
Python接口测试环境搭建过程详解
Jun 29 #Python
python字典的值可以修改吗
Jun 29 #Python
You might like
刚才在简化php的库,结果发现很多东西
2006/12/31 PHP
初级的用php写的采集程序
2007/03/16 PHP
PHP 字符串长度判断效率更高的方法
2014/03/02 PHP
PHP+ajaxfileupload+jcrop插件完美实现头像上传剪裁
2014/06/09 PHP
PHP反射使用实例和PHP反射API的中文说明
2014/07/02 PHP
Yii2实现ajax上传图片插件用法
2016/04/28 PHP
CI框架无限级分类+递归的实现代码
2016/11/01 PHP
JavaScript 产生不重复的随机数三种实现思路
2012/12/13 Javascript
获取下拉列表框的值是数组,split,$.inArray示例
2013/11/13 Javascript
javascript获得当前的信息的一些常用命令
2015/02/25 Javascript
JavaScript返回网页中锚点数目的方法
2015/04/03 Javascript
Jquery实现瀑布流布局(备有详细注释)
2015/07/31 Javascript
简单实现轮播图效果的实例
2016/07/15 Javascript
前端程序员必须知道的高性能Javascript知识
2016/08/24 Javascript
nodejs个人博客开发第七步 后台登陆
2017/04/12 NodeJs
浅谈Angular文字折叠展开组件的原理分析
2017/11/24 Javascript
Javascript中parseInt的正确使用方式
2018/10/17 Javascript
javascript sort()对数组中的元素进行排序详解
2019/10/13 Javascript
[01:15:44]首部DOTA2纪录片今日23时全网上映
2014/03/19 DOTA
[03:10]2014DOTA2 TI马来劲旅Titan首战告捷目标只是8强
2014/07/10 DOTA
Python批量重命名同一文件夹下文件的方法
2015/05/25 Python
解决Pycharm界面的子窗口不见了的问题
2019/01/17 Python
python爬虫租房信息在地图上显示的方法
2019/05/13 Python
用Python写一个自动木马程序
2019/09/17 Python
Python字典中的值为列表或字典的构造实例
2019/12/16 Python
Python如何在main中调用函数内的函数方式
2020/06/01 Python
python使用matplotlib绘制折线图的示例代码
2020/09/22 Python
Html5 语法与规则简要概述
2014/07/29 HTML / CSS
西班牙鞋子和箱包在线销售网站:zapatos.es
2020/02/17 全球购物
J2EE系统只能是基于web
2015/09/08 面试题
军训自我鉴定
2014/01/22 职场文书
如何写股份合作协议书
2014/09/11 职场文书
教学反思怎么写
2016/02/24 职场文书
使用 JavaScript 制作页面效果
2021/04/21 Javascript
Redis Cluster集群动态扩容的实现
2021/07/15 Redis
使用python绘制分组对比柱状图
2022/04/21 Python