Tensorflow使用tfrecord输入数据格式


Posted in Python onJune 19, 2018

Tensorflow 提供了一种统一的格式来存储数据,这个格式就是TFRecord,上一篇文章中所提到的方法当数据的来源更复杂,每个样例中的信息更丰富的时候就很难有效的记录输入数据中的信息了,于是Tensorflow提供了TFRecord来统一存储数据,接下来我们就来介绍如何使用TFRecord来同意输入数据的格式。

1. TFRecord格式介绍

TFRecord文件中的数据是通过tf.train.Example Protocol Buffer的格式存储的,下面是tf.train.Example的定义

message Example {
 Features features = 1;
};

message Features{
 map<string,Feature> featrue = 1;
};

message Feature{
  oneof kind{
    BytesList bytes_list = 1;
    FloatList float_list = 2;
    Int64List int64_list = 3;
  }
};

从上述代码可以看到,ft.train.Example 的数据结构相对简洁。tf.train.Example中包含了一个从属性名称到取值的字典,其中属性名称为一个字符串,属性的取值可以为字符串(BytesList ),实数列表(FloatList )或整数列表(Int64List )。例如我们可以将解码前的图片作为字符串,图像对应的类别标号作为整数列表。

2. 将自己的数据转化为TFRecord格式

准备数据

在上一篇中,我们为了像伟大的MNIST致敬,所以选择图像的前缀来进行不同类别的分类依据,但是大多数的情况下,在进行分类任务的过程中,不同的类别都会放在不同的文件夹下,而且类别的个数往往浮动性又很大,所以针对这样的情况,我们现在利用不同类别在不同文件夹中的图像来生成TFRecord.

我们在Iris&Contact这个文件夹下有两个文件夹,分别为iris,contact。对于每个文件夹中存放的是对应的图片

转换数据

数据准备好以后,就开始准备生成TFRecord,具体代码如下:

import os 
import tensorflow as tf 
from PIL import Image 
import matplotlib.pyplot as plt 

cwd='/home/ruyiwei/Documents/Iris&Contact/'
classes={'iris','contact'} 
writer= tf.python_io.TFRecordWriter("iris_contact.tfrecords") 

for index,name in enumerate(classes):
  class_path=cwd+name+'/'
  for img_name in os.listdir(class_path): 
    img_path=class_path+img_name 
    img=Image.open(img_path)
    img= img.resize((512,80))
    img_raw=img.tobytes()
    #plt.imshow(img) # if you want to check you image,please delete '#'
    #plt.show()
    example = tf.train.Example(features=tf.train.Features(feature={
      "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
      'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
    })) 
    writer.write(example.SerializeToString()) 

writer.close()

3. Tensorflow从TFRecord中读取数据

def read_and_decode(filename): # read iris_contact.tfrecords
  filename_queue = tf.train.string_input_producer([filename])# create a queue

  reader = tf.TFRecordReader()
  _, serialized_example = reader.read(filename_queue)#return file_name and file
  features = tf.parse_single_example(serialized_example,
                    features={
                      'label': tf.FixedLenFeature([], tf.int64),
                      'img_raw' : tf.FixedLenFeature([], tf.string),
                    })#return image and label

  img = tf.decode_raw(features['img_raw'], tf.uint8)
  img = tf.reshape(img, [512, 80, 3]) #reshape image to 512*80*3
  img = tf.cast(img, tf.float32) * (1. / 255) - 0.5 #throw img tensor
  label = tf.cast(features['label'], tf.int32) #throw label tensor
  return img, label

4. 将TFRecord中的数据保存为图片

filename_queue = tf.train.string_input_producer(["iris_contact.tfrecords"]) 
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)  #return file and file_name
features = tf.parse_single_example(serialized_example,
                  features={
                    'label': tf.FixedLenFeature([], tf.int64),
                    'img_raw' : tf.FixedLenFeature([], tf.string),
                  }) 
image = tf.decode_raw(features['img_raw'], tf.uint8)
image = tf.reshape(image, [512, 80, 3])
label = tf.cast(features['label'], tf.int32)
with tf.Session() as sess: 
  init_op = tf.initialize_all_variables()
  sess.run(init_op)
  coord=tf.train.Coordinator()
  threads= tf.train.start_queue_runners(coord=coord)
  for i in range(20):
    example, l = sess.run([image,label])#take out image and label
    img=Image.fromarray(example, 'RGB')
    img.save(cwd+str(i)+'_''Label_'+str(l)+'.jpg')#save image
    print(example, l)
  coord.request_stop()
  coord.join(threads)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中精确输出JSON浮点数的方法
Apr 18 Python
定制FileField中的上传文件名称实例
Aug 23 Python
机器学习10大经典算法详解
Dec 07 Python
Tensorflow之构建自己的图片数据集TFrecords的方法
Feb 07 Python
对numpy中布尔型数组的处理方法详解
Apr 17 Python
Selenium 模拟浏览器动态加载页面的实现方法
May 16 Python
详解Python3之数据指纹MD5校验与对比
Jun 11 Python
通过python实现随机交换礼物程序详解
Jul 10 Python
Python 实现将数组/矩阵转换成Image类
Jan 09 Python
Python使用ElementTree美化XML格式的操作
Mar 06 Python
68行Python代码实现带难度升级的贪吃蛇
Jan 18 Python
浅谈Python中对象是如何被调用的
Apr 06 Python
Tensorflow 训练自己的数据集将数据直接导入到内存
Jun 19 #Python
python如何爬取个性签名
Jun 19 #Python
详解TensorFlow查看ckpt中变量的几种方法
Jun 19 #Python
TensorFlow 滑动平均的示例代码
Jun 19 #Python
python3个性签名设计实现代码
Jun 19 #Python
TensorFlow 模型载入方法汇总(小结)
Jun 19 #Python
python3爬虫之设计签名小程序
Jun 19 #Python
You might like
shell脚本作为保证PHP脚本不挂掉的守护进程实例分享
2013/07/15 PHP
php CI框架插入一条或多条sql记录示例
2014/07/29 PHP
解决phpcms更换javascript的幻灯片代码调用图片问题
2014/12/26 PHP
php7安装yar扩展的方法详解
2017/08/03 PHP
PHP设计模式之抽象工厂模式实例分析
2019/03/25 PHP
解决Extjs4中form表单提交后无法进入success函数问题
2013/11/26 Javascript
javascript日期格式化示例分享
2014/03/05 Javascript
Javascript 浮点运算精度问题分析与解决
2014/03/26 Javascript
Jquery+asp.net后台数据传到前台js进行解析的方法
2014/05/11 Javascript
AngularJS+Node.js实现在线聊天室
2015/08/28 Javascript
设置jQueryUI DatePicker默认语言为中文
2016/06/04 Javascript
jQuery progressbar通过Ajax请求实现后台进度实时功能
2016/10/11 Javascript
JavaScript调试的多个必备小Tips
2017/01/15 Javascript
vue2.0全局组件之pdf详解
2017/06/26 Javascript
js实现首屏延迟加载实现方法 js实现多屏单张图片延迟加载效果
2017/07/17 Javascript
原生js jquery ajax请求以及jsonp的调用方法
2017/08/04 jQuery
vue-cli 3.x配置跨域代理的实现方法
2019/04/12 Javascript
Vue.js递归组件实现组织架构树和选人功能
2019/07/04 Javascript
Vue使用Three.js加载glTF模型的方法详解
2020/06/14 Javascript
如何搭建一个完整的Vue3.0+ts的项目步骤
2020/10/18 Javascript
Python实现包含min函数的栈
2016/04/29 Python
django开发之settings.py中变量的全局引用详解
2017/03/29 Python
Python 调用Java实例详解
2017/06/02 Python
Tornado协程在python2.7如何返回值(实现方法)
2017/06/22 Python
python:pandas合并csv文件的方法(图书数据集成)
2018/04/12 Python
详解python使用pip安装第三方库(工具包)速度慢、超时、失败的解决方案
2018/12/02 Python
python 实现交换两个列表元素的位置示例
2019/06/26 Python
Python如何向SQLServer存储二进制图片
2020/06/08 Python
重构Python代码的六个实例
2020/11/25 Python
自动化系在校本科生求职信
2013/10/23 职场文书
个人评价范文分享
2014/01/11 职场文书
标准自荐信范文
2014/01/29 职场文书
企业内部培训方案
2014/02/04 职场文书
幼儿园六一儿童节文艺汇演主持词
2014/03/21 职场文书
植树节标语
2014/06/27 职场文书
nginx之内存池的实现
2022/06/28 Servers