tensorflow TFRecords文件的生成和读取的方法


Posted in Python onFebruary 06, 2018

TensorFlow提供了TFRecords的格式来统一存储数据,理论上,TFRecords可以存储任何形式的数据。

TFRecords文件中的数据都是通过tf.train.Example Protocol Buffer的格式存储的。以下的代码给出了tf.train.Example的定义。

message Example { 
  Features features = 1; 
}; 
message Features { 
  map<string, Feature> feature = 1; 
}; 
message Feature { 
  oneof kind { 
  BytesList bytes_list = 1; 
  FloatList float_list = 2; 
  Int64List int64_list = 3; 
} 
};

下面将介绍如何生成和读取tfrecords文件:

首先介绍tfrecords文件的生成,直接上代码:

from random import shuffle 
import numpy as np 
import glob 
import tensorflow as tf 
import cv2 
import sys 
import os 
 
# 因为我装的是CPU版本的,运行起来会有'warning',解决方法入下,眼不见为净~ 
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 
 
shuffle_data = True 
image_path = '/path/to/image/*.jpg' 
 
# 取得该路径下所有图片的路径,type(addrs)= list 
addrs = glob.glob(image_path) 
# 标签数据的获得具体情况具体分析,type(labels)= list 
labels = ... 
 
# 这里是打乱数据的顺序 
if shuffle_data: 
  c = list(zip(addrs, labels)) 
  shuffle(c) 
  addrs, labels = zip(*c) 
 
# 按需分割数据集 
train_addrs = addrs[0:int(0.7*len(addrs))] 
train_labels = labels[0:int(0.7*len(labels))] 
 
val_addrs = addrs[int(0.7*len(addrs)):int(0.9*len(addrs))] 
val_labels = labels[int(0.7*len(labels)):int(0.9*len(labels))] 
 
test_addrs = addrs[int(0.9*len(addrs)):] 
test_labels = labels[int(0.9*len(labels)):] 
 
# 上面不是获得了image的地址么,下面这个函数就是根据地址获取图片 
def load_image(addr): # A function to Load image 
  img = cv2.imread(addr) 
  img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_CUBIC) 
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 
  # 这里/255是为了将像素值归一化到[0,1] 
  img = img / 255. 
  img = img.astype(np.float32) 
  return img 
 
# 将数据转化成对应的属性 
def _int64_feature(value):  
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 
 
 
def _bytes_feature(value): 
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 
 
 
def _float_feature(value): 
  return tf.train.Feature(float_list=tf.train.FloatList(value=[value])) 
 
# 下面这段就开始把数据写入TFRecods文件 
 
train_filename = '/path/to/train.tfrecords' # 输出文件地址 
 
# 创建一个writer来写 TFRecords 文件 
writer = tf.python_io.TFRecordWriter(train_filename) 
 
for i in range(len(train_addrs)): 
  # 这是写入操作可视化处理 
  if not i % 1000: 
    print('Train data: {}/{}'.format(i, len(train_addrs))) 
    sys.stdout.flush() 
  # 加载图片 
  img = load_image(train_addrs[i]) 
 
  label = train_labels[i] 
 
  # 创建一个属性(feature) 
  feature = {'train/label': _int64_feature(label), 
        'train/image': _bytes_feature(tf.compat.as_bytes(img.tostring()))} 
 
  # 创建一个 example protocol buffer 
  example = tf.train.Example(features=tf.train.Features(feature=feature)) 
 
  # 将上面的example protocol buffer写入文件 
  writer.write(example.SerializeToString()) 
 
writer.close() 
sys.stdout.flush()

上面只介绍了train.tfrecords文件的生成,其余的validation,test举一反三吧。。

接下来介绍tfrecords文件的读取:

import tensorflow as tf 
import numpy as np 
import matplotlib.pyplot as plt 
import os  
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 
data_path = 'train.tfrecords' # tfrecords 文件的地址 
 
with tf.Session() as sess: 
  # 先定义feature,这里要和之前创建的时候保持一致 
  feature = { 
    'train/image': tf.FixedLenFeature([], tf.string), 
    'train/label': tf.FixedLenFeature([], tf.int64) 
  } 
  # 创建一个队列来维护输入文件列表 
  filename_queue = tf.train.string_input_producer([data_path], num_epochs=1) 
 
  # 定义一个 reader ,读取下一个 record 
  reader = tf.TFRecordReader() 
  _, serialized_example = reader.read(filename_queue) 
 
  # 解析读入的一个record 
  features = tf.parse_single_example(serialized_example, features=feature) 
 
  # 将字符串解析成图像对应的像素组 
  image = tf.decode_raw(features['train/image'], tf.float32) 
 
  # 将标签转化成int32 
  label = tf.cast(features['train/label'], tf.int32) 
 
  # 这里将图片还原成原来的维度 
  image = tf.reshape(image, [224, 224, 3]) 
 
  # 你还可以进行其他一些预处理.... 
 
  # 这里是创建顺序随机 batches(函数不懂的自行百度) 
  images, labels = tf.train.shuffle_batch([image, label], batch_size=10, capacity=30, min_after_dequeue=10) 
 
  # 初始化 
  init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) 
  sess.run(init_op) 
 
  # 启动多线程处理输入数据 
  coord = tf.train.Coordinator() 
  threads = tf.train.start_queue_runners(coord=coord) 
 
  .... 
 
  #关闭线程 
  coord.request_stop() 
  coord.join(threads) 
  sess.close()

好了,就介绍到这里。。,有什么问题可以留言。。大家一起学习。。希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现360皮肤按钮控件示例
Feb 21 Python
python网络编程学习笔记(九):数据库客户端 DB-API
Jun 09 Python
Python 序列化 pickle/cPickle模块使用介绍
Nov 30 Python
Python中编写ORM框架的入门指引
Apr 29 Python
详解使用Python处理文件目录的相关方法
Oct 16 Python
python利用有道翻译实现&quot;语言翻译器&quot;的功能实例
Nov 14 Python
Python迭代器与生成器用法实例分析
Jul 09 Python
python自动化测试之如何解析excel文件
Jun 27 Python
pytorch 在sequential中使用view来reshape的例子
Aug 20 Python
python 函数的缺省参数使用注意事项分析
Sep 17 Python
完美处理python与anaconda环境变量的冲突问题
Apr 07 Python
用Python实现屏幕截图详解
Jan 22 Python
TensorFlow实现创建分类器
Feb 06 #Python
Python模拟随机游走图形效果示例
Feb 06 #Python
Python 12306抢火车票脚本 Python京东抢手机脚本
Feb 06 #Python
TensorFlow高效读取数据的方法示例
Feb 06 #Python
django使用xlwt导出excel文件实例代码
Feb 06 #Python
Python使用装饰器进行django开发实例代码
Feb 06 #Python
Python yield与实现方法代码分析
Feb 06 #Python
You might like
极典R601SW收音机
2021/03/02 无线电
PHP 获取远程文件内容的函数代码
2010/03/24 PHP
php利用新浪接口查询ip获取地理位置示例
2014/01/20 PHP
PHP实现模仿socket请求返回页面的方法
2014/11/04 PHP
PHP实现CSV文件的导入和导出类
2015/03/24 PHP
用JS实现的一个include函数
2007/07/21 Javascript
JavaScript高级程序设计阅读笔记(十六) javascript检测浏览器和操作系统-detect.js
2012/08/14 Javascript
正则表达式中特殊符号及正则表达式的几种方法总结(replace,test,search)
2013/11/26 Javascript
键盘KeyCode值列表汇总
2013/11/26 Javascript
jQuery选择器简明总结(含用法实例,一目了然)
2014/04/25 Javascript
浅谈js中变量初始化
2015/02/03 Javascript
直接拿来用的页面跳转进度条JS实现
2016/01/06 Javascript
jquery把int类型转换成字符串类型的方法
2016/10/07 Javascript
利用JS轻松实现获取表单数据
2016/12/06 Javascript
基于javascript实现的购物商城商品倒计时实例
2016/12/11 Javascript
JS简单判断函数是否存在的方法
2017/02/13 Javascript
微信小程序使用setData修改数组中单个对象的方法分析
2018/12/30 Javascript
深入理解vue中的slot与slot-scope
2019/04/22 Javascript
简单了解JavaScript异步
2019/05/23 Javascript
JS面向对象之多选框实现
2020/01/17 Javascript
Ant Design Vue 添加区分中英文的长度校验功能
2020/01/21 Javascript
使用python实现扫描端口示例
2014/03/29 Python
python字符类型的一些方法小结
2016/05/16 Python
Python3.4 tkinter,PIL图片转换
2018/06/21 Python
pycharm 设置项目的根目录教程
2020/02/12 Python
浅析matlab中imadjust函数
2020/02/27 Python
python matplotlib.pyplot.plot()参数用法
2020/04/14 Python
使用pymysql查询数据库,把结果保存为列表并获取指定元素下标实例
2020/05/15 Python
Python如何实现机器人聊天
2020/09/10 Python
Python利用Pillow(PIL)库实现验证码图片的全过程
2020/10/04 Python
Origins悦木之源英国官网:雅诗兰黛集团高端植物护肤品牌
2017/11/06 全球购物
凯伦·米莲女装网上商店:Karen Millen
2017/11/07 全球购物
KENZO官网:高田贤三在法国创立的品牌
2019/05/16 全球购物
英国设计师珠宝网站:Joshua James Jewellery
2020/03/01 全球购物
物流专业求职计划书
2014/01/10 职场文书
服装设计专业毕业生求职信
2014/04/09 职场文书