tensorflow入门:TFRecordDataset变长数据的batch读取详解


Posted in Python onJanuary 20, 2020

在上一篇文章tensorflow入门:tfrecord 和tf.data.TFRecordDataset的使用里,讲到了使用如何使用tf.data.TFRecordDatase来对tfrecord文件进行batch读取,即使用dataset的batch方法进行;但如果每条数据的长度不一样(常见于语音、视频、NLP等领域),则不能直接用batch方法获取数据,这时则有两个解决办法:

1.在把数据写入tfrecord时,先把数据pad到统一的长度再写入tfrecord;这个方法的问题在于:若是有大量数据的长度都远远小于最大长度,则会造成存储空间的大量浪费。

2.使用dataset中的padded_batch方法来进行,参数padded_shapes #指明每条记录中各成员要pad成的形状,成员若是scalar,则用[],若是list,则用[mx_length],若是array,则用[d1,...,dn],假如各成员的顺序是scalar数据、list数据、array数据,则padded_shapes=([], [mx_length], [d1,...,dn]);该方法的函数说明如下:

padded_batch(
 batch_size,
 padded_shapes,
 padding_values=None #默认使用各类型数据的默认值,一般使用时可忽略该项
)

使用mnist数据来举例说明,首先在把mnist写入tfrecord之前,把mnist数据进行更改,以使得每个mnist图像的大小不等,如下:

import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
 
mnist = read_data_sets("MNIST_data/", one_hot=True)
 
 
def get_tfrecords_example(feature, label):
 tfrecords_features = {}
 feat_shape = feature.shape
 tfrecords_features['feature'] = tf.train.Feature(float_list=tf.train.FloatList(value=feature))
 tfrecords_features['shape'] = tf.train.Feature(int64_list=tf.train.Int64List(value=list(feat_shape)))
 tfrecords_features['label'] = tf.train.Feature(float_list=tf.train.FloatList(value=label))
 return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))
 
 
def make_tfrecord(data, outf_nm='mnist-train'):
 feats, labels = data
 outf_nm += '.tfrecord'
 tfrecord_wrt = tf.python_io.TFRecordWriter(outf_nm)
 ndatas = len(labels)
 print(feats[0].dtype, feats[0].shape, ndatas)
 assert len(labels[0]) > 1
 for inx in range(ndatas):
 ed = random.randint(0,3) #随机丢掉几个数据点,以使长度不等
 exmp = get_tfrecords_example(feats[inx][:-ed], labels[inx])
 exmp_serial = exmp.SerializeToString()
 tfrecord_wrt.write(exmp_serial)
 tfrecord_wrt.close()
 
import random
nDatas = len(mnist.train.labels)
inx_lst = range(nDatas)
random.shuffle(inx_lst)
random.shuffle(inx_lst)
ntrains = int(0.85*nDatas)
 
# make training set
data = ([mnist.train.images[i] for i in inx_lst[:ntrains]], \
 [mnist.train.labels[i] for i in inx_lst[:ntrains]])
make_tfrecord(data, outf_nm='mnist-train')
 
# make validation set
data = ([mnist.train.images[i] for i in inx_lst[ntrains:]], \
 [mnist.train.labels[i] for i in inx_lst[ntrains:]])
make_tfrecord(data, outf_nm='mnist-val')
 
# make test set
data = (mnist.test.images, mnist.test.labels)
make_tfrecord(data, outf_nm='mnist-test')

用dataset加载批量数据,在解析数据时用到tf.VarLenFeature(tf.datatype),而非tf.FixedLenFeature([], tf.datatype)},且要配合tf.sparse_tensor_to_dense函数使用,如下:

import tensorflow as tf
 
train_f, val_f, test_f = ['mnist-%s.tfrecord'%i for i in ['train', 'val', 'test']]
 
def parse_exmp(serial_exmp):
 feats = tf.parse_single_example(serial_exmp, features={'feature':tf.VarLenFeature(tf.float32),\
 'label':tf.FixedLenFeature([10],tf.float32), 'shape':tf.FixedLenFeature([], tf.int64)})
 image = tf.sparse_tensor_to_dense(feats['feature']) #使用VarLenFeature读入的是一个sparse_tensor,用该函数进行转换
 label = tf.reshape(feats['label'],[2,5]) #把label变成[2,5],以说明array数据如何padding
 shape = tf.cast(feats['shape'], tf.int32)
 return image, label, shape
 
def get_dataset(fname):
 dataset = tf.data.TFRecordDataset(fname)
 return dataset.map(parse_exmp) # use padded_batch method if padding needed
 
epochs = 16
batch_size = 50 
padded_shapes = ([784],[3,5],[]) #把image pad至784,把label pad至[3,5],shape是一个scalar,不输入数字
# training dataset
dataset_train = get_dataset(train_f)
dataset_train = dataset_train.repeat(epochs).shuffle(1000).padded_batch(batch_size, padded_shapes=padded_shapes)

以上这篇tensorflow入门:TFRecordDataset变长数据的batch读取详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python HTTP客户端自定义Cookie实现实例
Apr 28 Python
PyQt实现界面翻转切换效果
Apr 20 Python
Python实现快速计算词频功能示例
Jun 25 Python
python实现对指定字符串补足固定长度倍数截断输出的方法
Nov 15 Python
对Python+opencv将图片生成视频的实例详解
Jan 08 Python
Python基本数据结构之字典类型dict用法分析
Jun 08 Python
pytz格式化北京时间多出6分钟问题的解决方法
Jun 21 Python
python覆盖写入,追加写入的实例
Jun 26 Python
wxPython绘图模块wxPyPlot实现数据可视化
Nov 19 Python
Python操作多维数组输出和矩阵运算示例
Nov 28 Python
python数据爬下来保存的位置
Feb 17 Python
一劳永逸彻底解决pip install慢的办法
May 24 Python
python如何通过pyqt5实现进度条
Jan 20 #Python
python super用法及原理详解
Jan 20 #Python
tensorflow 变长序列存储实例
Jan 20 #Python
在tensorflow中实现去除不足一个batch的数据
Jan 20 #Python
Tensorflow实现在训练好的模型上进行测试
Jan 20 #Python
Python线程条件变量Condition原理解析
Jan 20 #Python
tensorflow tf.train.batch之数据批量读取方式
Jan 20 #Python
You might like
PHP MySQL应用中使用XOR运算加密算法分享
2011/08/28 PHP
php实现把数组按指定的个数分隔
2014/02/17 PHP
使用ExtJS技术实现的拖动树结点
2010/08/05 Javascript
在次封装easyui-Dialog插件实现代码
2010/11/14 Javascript
jquery选择器大全 全面详解jquery选择器
2014/03/06 Javascript
JS中改变this指向的方法(call和apply、bind)
2016/03/26 Javascript
使用PBFunc在Powerbuilder中支付宝当面付款功能
2016/10/01 Javascript
JavaScript和jQuery获取input框的绝对位置实现方法
2016/10/13 Javascript
详解AngularJS中的表单验证(推荐)
2016/11/17 Javascript
Js自动截取字符串长度,添加省略号(……)的实现方法
2017/03/06 Javascript
Angular CLI 安装和使用教程
2017/09/13 Javascript
原生js的ajax和解决跨域的jsonp(实例讲解)
2017/10/16 Javascript
js的函数的按值传递参数(实例讲解)
2017/11/16 Javascript
元素全屏的设置与监听实例
2017/11/28 Javascript
浅谈vue自定义全局组件并通过全局方法 Vue.use() 使用该组件
2017/12/07 Javascript
通过vue-cli来学习修改Webpack多环境配置和发布问题
2017/12/22 Javascript
详解vue使用vue-layer-mobile组件实现toast,loading效果
2018/08/31 Javascript
vue2.0 实现富文本编辑器功能
2019/05/26 Javascript
在layui中对table中的数据进行判断(0、1)转换为提示信息的方法
2019/09/28 Javascript
javascript设计模式 ? 中介者模式原理与用法实例分析
2020/04/20 Javascript
Vue中inheritAttrs的使用实例详解
2020/12/31 Vue.js
对Python 语音识别框架详解
2018/12/24 Python
Python实现将HTML转成PDF的方法分析
2019/05/04 Python
HTML5网页音乐播放器的示例代码
2017/11/09 HTML / CSS
施华洛世奇美国官网:SWAROVSKI美国
2018/02/08 全球购物
萨克斯第五大道英国:Saks Fifth Avenue英国
2019/04/01 全球购物
欧洲最大的预定车位市场:JustPark
2020/01/06 全球购物
牵手50新加坡:专为黄金岁月的单身人士而设的交友网站
2020/08/16 全球购物
波兰在线运动商店:YesSport
2020/07/23 全球购物
翻译专业应届生求职信
2013/11/23 职场文书
《生命 生命》教学反思
2014/04/19 职场文书
团支书竞选演讲稿
2014/04/28 职场文书
小学标准化建设汇报材料
2014/08/16 职场文书
班主任师德师风自我剖析材料
2014/10/02 职场文书
三十年同学聚会感言
2015/07/30 职场文书
秀!学妹看见都惊呆的Python小招数!【详细语言特性使用技巧】
2021/04/27 Python