如何从csv文件构建Tensorflow的数据集


Posted in Python onSeptember 21, 2020

从csv文件构建Tensorflow的数据集

当我们有一系列CSV文件,如何构建Tensorflow的数据集呢?

基本步骤

  1. 获得一组CSV文件的路径
  2. 将这组文件名,转成文件名对应的dataset => file_dataset
  3. 根据file_dataset中的每个文件名,读取文件内容 生成一个内容的dataset => content_dataset
  4. 这样的多个content_dataset, 拼接起来,形成一整个dataset
  5. 因为读出来的每条记录都是string类型, 所以还需要对每条记录做decode

存在一个这样的变量train_filenames

pprint.pprint(train_filenames)
#	['generate_csv\\train_00.csv',
#	 'generate_csv\\train_01.csv',
#	 'generate_csv\\train_02.csv',
#	 'generate_csv\\train_03.csv',
#	 'generate_csv\\train_04.csv',
#	 'generate_csv\\train_05.csv',
#	 'generate_csv\\train_06.csv',
#	 'generate_csv\\train_07.csv',
#	 'generate_csv\\train_08.csv',
#	 'generate_csv\\train_09.csv',
#	 'generate_csv\\train_10.csv',
#	 'generate_csv\\train_11.csv',
#	 'generate_csv\\train_12.csv',
#	 'generate_csv\\train_13.csv',
#	 'generate_csv\\train_14.csv',
#	 'generate_csv\\train_15.csv',
#	 'generate_csv\\train_16.csv',
#	 'generate_csv\\train_17.csv',
#	 'generate_csv\\train_18.csv',
#	 'generate_csv\\train_19.csv']

接着,我们用提前定义好的API构建文件名数据集file_dataset

filename_dataset = tf.data.Dataset.list_files(train_filenames)
for filename in filename_dataset:
  print(filename)
#tf.Tensor(b'generate_csv\\train_09.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_19.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_03.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_01.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_14.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_17.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_15.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_06.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_05.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_07.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_11.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_02.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_12.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_13.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_10.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_16.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_18.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_00.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_04.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_08.csv', shape=(), dtype=string)

第三步, 根据每个文件名,去读取文件里面的内容

dataset = filename_dataset.interleave(
  lambda filename: tf.data.TextLineDataset(filename).skip(1),
  cycle_length=5
)

for line in dataset.take(3):
  print(line)

#tf.Tensor(b'0.46908349737250216,1.8718193706428006,0.13936365871212536,-0.011055733363841472,-0.6349261778219746,-0.036732316700563934,1.0259470089944995,-1.319095600336748,2.171', shape=(), dtype=string)
#tf.Tensor(b'-1.102093775650278,1.313248890578542,-0.7212003024178728,-0.14707856286537277,0.34720121604358517,0.0965085401826684,-0.74698820254838,0.6810563907247876,1.428', shape=(), dtype=string)
#tf.Tensor(b'-0.8901003715328659,0.9142699762469286,-0.1851678950250224,-0.12947457252940406,0.5958187430364827,-0.021255215877779534,0.7914317693724252,-0.45618713536506217,0.75', shape=(), dtype=string)

interleave的作用可以类比map, 对每个元素应用操作,然后还能把结果合起来。
因此,有了interleave, 我们就把第三四步,一起完成了
之所以skip(1),是因为这个csv第一行是header.
cycle_length是并行化构建数据集的线程数

好,第五步,解析每条记录

def parse_csv_line(line, n_fields=9):
  defaults = [tf.constant(np.nan)] * n_fields
  parsed_fields = tf.io.decode_csv(line, record_defaults=defaults)
  x = tf.stack(parsed_fields[:-1])
  y = tf.stack(parsed_fields[-1:])
  return x, y

parse_csv_line('1.2286258796252256,-1.0806245954111382,0.4444161407754224,-0.0352172575329119,0.9740347681426992,-0.003516079473801425,-0.8126524696425611,0.865609068204283,2.803', 9)

#(<tf.Tensor: shape=(8,), dtype=float32, numpy= array([ 1.2286259 , -1.0806246 , 0.44441614, -0.03521726, 0.9740348 ,-0.00351608, -0.81265247, 0.86560905], dtype=float32)>,<tf.Tensor: shape=(1,), dtype=float32, numpy=array([2.803], dtype=float32)>)

最后,将每条记录都应用这个方法,就完成了构建。

dataset = dataset.map(parse_csv_line)

完整代码

def csv_2_dataset(filenames, n_readers_thread = 5, batch_size = 32, n_parse_thread = 5, shuffle_buffer_size = 10000):
  
  dataset = tf.data.Dataset.list_files(filenames)
  dataset = dataset.repeat()
  dataset = dataset.interleave(
    lambda filename: tf.data.TextLineDataset(filename).skip(1),
    cycle_length=n_readers_thread
  )
  dataset.shuffle(shuffle_buffer_size)
  dataset = dataset.map(parse_csv_line, num_parallel_calls = n_parse_thread)
  dataset = dataset.batch(batch_size)
  return dataset

如何使用

train_dataset = csv_2_dataset(train_filenames, batch_size=32)
valid_dataset = csv_2_dataset(valid_filenames, batch_size=32)

model = ...

model.fit(train_set, validation_data=valid_set, 
          steps_per_epoch = 11610 // 32,
          validation_steps = 3870 // 32,
          epochs=100, callbacks=callbacks)

这里的11610 和 3870是什么?

这是train_dataset 和 valid_dataset中数据的数量,需要在训练中手动指定每个batch中参与训练的数据的多少。

model.evaluate(test_set, steps=5160//32)

同理,测试的时候,使用这样的数据集,也需要手动指定。
5160是测试数据集的总量。

以上就是如何从csv文件构建Tensorflow的数据集的详细内容,更多关于csv文件构建Tensorflow的数据集的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
python实现基本进制转换的方法
Jul 11 Python
Python ldap实现登录实例代码
Sep 30 Python
基于python 二维数组及画图的实例详解
Apr 03 Python
Python中变量的输入输出实例代码详解
Jul 28 Python
python列表插入append(), extend(), insert()用法详解
Sep 14 Python
python应用文件读取与登录注册功能
Sep 23 Python
利用python、tensorflow、opencv、pyqt5实现人脸实时签到系统
Sep 25 Python
window7下的python2.7版本和python3.5版本的opencv-python安装过程
Oct 24 Python
Python实现RGB与HSI颜色空间的互换方式
Nov 27 Python
使用python实现哈希表、字典、集合操作
Dec 22 Python
浅谈Keras的Sequential与PyTorch的Sequential的区别
Jun 17 Python
python如何写try语句
Jul 14 Python
python打包多类型文件的操作方法
Sep 21 #Python
python 星号(*)的多种用途
Sep 21 #Python
Python+Selenium随机生成手机验证码并检查页面上是否弹出重复手机号码提示框
Sep 21 #Python
解决PyCharm不在run输出运行结果而不是再Console里输出的问题
Sep 21 #Python
python map比for循环快在哪
Sep 21 #Python
通过实例解析Python文件操作实现步骤
Sep 21 #Python
python Paramiko使用示例
Sep 21 #Python
You might like
PHP如何抛出异常处理错误
2011/03/02 PHP
PHP+Mysql+jQuery实现发布微博程序 jQuery篇
2011/10/08 PHP
php之curl设置超时实例
2014/11/03 PHP
解决Yii2邮件发送结果返回成功,但接收不到邮件的问题
2017/05/23 PHP
PHP JWT初识及其简单示例
2018/10/10 PHP
ext 代码生成器
2009/08/07 Javascript
JS中confirm,alert,prompt函数使用区别分析
2010/04/01 Javascript
javascript获取鼠标点击元素对象(示例代码)
2013/12/20 Javascript
js获取当前地址 JS获取当前URL的示例代码
2014/02/26 Javascript
使用phantomjs进行网页抓取的实现代码
2014/09/29 Javascript
解决JavaScript数字精度丢失问题的方法
2015/12/03 Javascript
Bootstrap每天必学之级联下拉菜单
2016/03/27 Javascript
Angular.js ng-file-upload结合springMVC的使用教程
2017/07/10 Javascript
用javascript获取任意颜色的更亮或更暗颜色值示例代码
2017/07/21 Javascript
用Vue.extend构建消息提示组件的方法实例
2017/08/08 Javascript
vue.js 微信支付前端代码分享
2018/02/10 Javascript
JS实现的冒泡排序,快速排序,插入排序算法示例
2019/03/02 Javascript
vue获取data数据改变前后的值方法
2019/11/07 Javascript
解决vuex刷新数据消失问题
2020/11/12 Javascript
Python中os.path用法分析
2015/01/15 Python
Python中一些自然语言工具的使用的入门教程
2015/04/13 Python
Python对文件操作知识汇总
2016/05/15 Python
Python编程之string相关操作实例详解
2017/07/22 Python
python得到电脑的开机时间方法
2018/10/15 Python
python实现文件的备份流程详解
2019/06/18 Python
简单了解python 生成器 列表推导式 生成器表达式
2019/08/22 Python
详解django使用include无法跳转的解决方法
2020/03/19 Python
类、抽象类、接口的差异
2016/06/13 面试题
销售演讲稿范文
2014/01/08 职场文书
《难忘的泼水节》教学反思
2014/02/27 职场文书
校长竞聘演讲稿
2014/05/16 职场文书
借条格式范本
2015/05/25 职场文书
2015上半年个人工作总结
2015/07/27 职场文书
话题作文之呼唤
2019/12/18 职场文书
开发一个封装iframe的vue组件
2021/03/29 Vue.js
了解Redis常见应用场景
2021/06/23 Redis