tensorflow入门:TFRecordDataset变长数据的batch读取详解


Posted in Python onJanuary 20, 2020

在上一篇文章tensorflow入门:tfrecord 和tf.data.TFRecordDataset的使用里,讲到了使用如何使用tf.data.TFRecordDatase来对tfrecord文件进行batch读取,即使用dataset的batch方法进行;但如果每条数据的长度不一样(常见于语音、视频、NLP等领域),则不能直接用batch方法获取数据,这时则有两个解决办法:

1.在把数据写入tfrecord时,先把数据pad到统一的长度再写入tfrecord;这个方法的问题在于:若是有大量数据的长度都远远小于最大长度,则会造成存储空间的大量浪费。

2.使用dataset中的padded_batch方法来进行,参数padded_shapes #指明每条记录中各成员要pad成的形状,成员若是scalar,则用[],若是list,则用[mx_length],若是array,则用[d1,...,dn],假如各成员的顺序是scalar数据、list数据、array数据,则padded_shapes=([], [mx_length], [d1,...,dn]);该方法的函数说明如下:

padded_batch(
 batch_size,
 padded_shapes,
 padding_values=None #默认使用各类型数据的默认值,一般使用时可忽略该项
)

使用mnist数据来举例说明,首先在把mnist写入tfrecord之前,把mnist数据进行更改,以使得每个mnist图像的大小不等,如下:

import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
 
mnist = read_data_sets("MNIST_data/", one_hot=True)
 
 
def get_tfrecords_example(feature, label):
 tfrecords_features = {}
 feat_shape = feature.shape
 tfrecords_features['feature'] = tf.train.Feature(float_list=tf.train.FloatList(value=feature))
 tfrecords_features['shape'] = tf.train.Feature(int64_list=tf.train.Int64List(value=list(feat_shape)))
 tfrecords_features['label'] = tf.train.Feature(float_list=tf.train.FloatList(value=label))
 return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))
 
 
def make_tfrecord(data, outf_nm='mnist-train'):
 feats, labels = data
 outf_nm += '.tfrecord'
 tfrecord_wrt = tf.python_io.TFRecordWriter(outf_nm)
 ndatas = len(labels)
 print(feats[0].dtype, feats[0].shape, ndatas)
 assert len(labels[0]) > 1
 for inx in range(ndatas):
 ed = random.randint(0,3) #随机丢掉几个数据点,以使长度不等
 exmp = get_tfrecords_example(feats[inx][:-ed], labels[inx])
 exmp_serial = exmp.SerializeToString()
 tfrecord_wrt.write(exmp_serial)
 tfrecord_wrt.close()
 
import random
nDatas = len(mnist.train.labels)
inx_lst = range(nDatas)
random.shuffle(inx_lst)
random.shuffle(inx_lst)
ntrains = int(0.85*nDatas)
 
# make training set
data = ([mnist.train.images[i] for i in inx_lst[:ntrains]], \
 [mnist.train.labels[i] for i in inx_lst[:ntrains]])
make_tfrecord(data, outf_nm='mnist-train')
 
# make validation set
data = ([mnist.train.images[i] for i in inx_lst[ntrains:]], \
 [mnist.train.labels[i] for i in inx_lst[ntrains:]])
make_tfrecord(data, outf_nm='mnist-val')
 
# make test set
data = (mnist.test.images, mnist.test.labels)
make_tfrecord(data, outf_nm='mnist-test')

用dataset加载批量数据,在解析数据时用到tf.VarLenFeature(tf.datatype),而非tf.FixedLenFeature([], tf.datatype)},且要配合tf.sparse_tensor_to_dense函数使用,如下:

import tensorflow as tf
 
train_f, val_f, test_f = ['mnist-%s.tfrecord'%i for i in ['train', 'val', 'test']]
 
def parse_exmp(serial_exmp):
 feats = tf.parse_single_example(serial_exmp, features={'feature':tf.VarLenFeature(tf.float32),\
 'label':tf.FixedLenFeature([10],tf.float32), 'shape':tf.FixedLenFeature([], tf.int64)})
 image = tf.sparse_tensor_to_dense(feats['feature']) #使用VarLenFeature读入的是一个sparse_tensor,用该函数进行转换
 label = tf.reshape(feats['label'],[2,5]) #把label变成[2,5],以说明array数据如何padding
 shape = tf.cast(feats['shape'], tf.int32)
 return image, label, shape
 
def get_dataset(fname):
 dataset = tf.data.TFRecordDataset(fname)
 return dataset.map(parse_exmp) # use padded_batch method if padding needed
 
epochs = 16
batch_size = 50 
padded_shapes = ([784],[3,5],[]) #把image pad至784,把label pad至[3,5],shape是一个scalar,不输入数字
# training dataset
dataset_train = get_dataset(train_f)
dataset_train = dataset_train.repeat(epochs).shuffle(1000).padded_batch(batch_size, padded_shapes=padded_shapes)

以上这篇tensorflow入门:TFRecordDataset变长数据的batch读取详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
简单了解python模块概念
Jan 11 Python
Python iter()函数用法实例分析
Mar 17 Python
python3使用smtplib实现发送邮件功能
May 22 Python
PyCharm Anaconda配置PyQt5开发环境及创建项目的教程详解
Mar 24 Python
Python内存映射文件读写方式
Apr 24 Python
opencv 图像加法与图像融合的实现代码
Jul 08 Python
python 解决pycharm运行py文件只有unittest选项的问题
Sep 01 Python
Python通用唯一标识符uuid模块使用案例
Sep 10 Python
如何利用Python写个坦克大战
Nov 18 Python
python3处理word文档实例分析
Dec 01 Python
用Python制作音乐海报
Jan 26 Python
Python中json.dumps()函数的使用解析
May 17 Python
python如何通过pyqt5实现进度条
Jan 20 #Python
python super用法及原理详解
Jan 20 #Python
tensorflow 变长序列存储实例
Jan 20 #Python
在tensorflow中实现去除不足一个batch的数据
Jan 20 #Python
Tensorflow实现在训练好的模型上进行测试
Jan 20 #Python
Python线程条件变量Condition原理解析
Jan 20 #Python
tensorflow tf.train.batch之数据批量读取方式
Jan 20 #Python
You might like
PHP的范围解析操作符(::)的含义分析说明
2011/07/03 PHP
php 伪造本地文件包含漏洞的代码
2011/11/03 PHP
CodeIgniter实现更改view文件夹路径的方法
2014/07/04 PHP
PHP实现文件上传与下载实例与总结
2016/03/13 PHP
微信 开发生成带参数的二维码的实例
2016/11/23 PHP
php制作圆形用户头像的实例_自定义封装类源代码
2017/09/18 PHP
总结PHP内存释放以及垃圾回收
2018/03/29 PHP
laravel Validator ajax返回错误信息的方法
2019/09/29 PHP
jQuery.ajax 用户登录验证代码
2010/10/29 Javascript
js静态方法与实例方法分析
2011/07/04 Javascript
我的Node.js学习之路(一)
2014/07/06 Javascript
innerHTML属性,outerHTML属性,textContent属性,innerText属性区别详解
2015/03/13 Javascript
基于JS实现简单的样式切换效果代码
2015/09/04 Javascript
JS实现超简单的仿QQ折叠菜单效果
2015/09/21 Javascript
Jquery $when done then的用法详解
2016/05/20 Javascript
解决bootstrap导航栏navbar在IE8上存在缺陷的方法
2016/07/01 Javascript
关于Vue Webpack2单元测试示例详解
2017/08/14 Javascript
浅析java线程中断的办法
2018/07/29 Javascript
antd日期选择器禁止选择当天之前的时间操作
2020/10/29 Javascript
python3利用tcp实现文件夹远程传输
2018/07/28 Python
对Python的zip函数妙用,旋转矩阵详解
2018/12/13 Python
解决pycharm中的run和debug失效无法点击运行
2020/06/09 Python
聊聊python中的循环遍历
2020/09/07 Python
详解anaconda离线安装pytorchGPU版
2020/09/08 Python
纽约的奢华内衣店:Journelle
2016/07/29 全球购物
linux面试题参考答案(6)
2014/08/29 面试题
继承时候类的执行顺序问题,一般都是选择题,问你将会打印出什么?
2015/11/18 面试题
本科毕业生自荐信
2014/05/26 职场文书
镇党委书记群众路线整改措施思想汇报
2014/10/13 职场文书
工作检讨书怎么写
2015/01/23 职场文书
研究生导师推荐信
2015/03/25 职场文书
2015年加油站工作总结
2015/05/13 职场文书
雷锋之歌观后感
2015/06/10 职场文书
2019年恭贺升学祝福语集锦
2019/08/15 职场文书
导游词之五台山
2019/10/11 职场文书
纯html+css实现奥运五环的示例代码
2021/08/02 HTML / CSS