tensorflow使用range_input_producer多线程读取数据实例


Posted in Python onJanuary 20, 2020

先放关键代码:

i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs=1, shuffle=False).dequeue()
inputs = tf.slice(array, [i * BATCH_SIZE], [BATCH_SIZE])

原理解析:

第一行会产生一个队列,队列包含0到NUM_EXPOCHES-1的元素,如果num_epochs有指定,则每个元素只产生num_epochs次,否则循环产生。shuffle指定是否打乱顺序,这里shuffle=False表示队列的元素是按0到NUM_EXPOCHES-1的顺序存储。在Graph运行的时候,每个线程从队列取出元素,假设值为i,然后按照第二行代码切出array的一小段数据作为一个batch。例如NUM_EXPOCHES=3,如果num_epochs=2,则队列的内容是这样子;

0,1,2,0,1,2

队列只有6个元素,这样在训练的时候只能产生6个batch,迭代6次以后训练就结束。

如果num_epochs不指定,则队列内容是这样子:

0,1,2,0,1,2,0,1,2,0,1,2...

队列可以一直生成元素,训练的时候可以产生无限的batch,需要自己控制什么时候停止训练。

下面是完整的演示代码。

数据文件test.txt内容:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35

main.py内容:

import tensorflow as tf
import codecs
 
BATCH_SIZE = 6
NUM_EXPOCHES = 5
 
 
def input_producer():
 array = codecs.open("test.txt").readlines()
	array = map(lambda line: line.strip(), array)
 i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs=1, shuffle=False).dequeue()
 inputs = tf.slice(array, [i * BATCH_SIZE], [BATCH_SIZE])
 return inputs
 
 
class Inputs(object):
 def __init__(self):
  self.inputs = input_producer()
 
 
def main(*args, **kwargs):
 inputs = Inputs()
 init = tf.group(tf.initialize_all_variables(),
     tf.initialize_local_variables())
 sess = tf.Session()
 coord = tf.train.Coordinator()
 threads = tf.train.start_queue_runners(sess=sess, coord=coord)
 sess.run(init)
 try:
  index = 0
  while not coord.should_stop() and index<10:
   datalines = sess.run(inputs.inputs)
   index += 1
   print("step: %d, batch data: %s" % (index, str(datalines)))
 except tf.errors.OutOfRangeError:
  print("Done traing:-------Epoch limit reached")
 except KeyboardInterrupt:
  print("keyboard interrput detected, stop training")
 finally:
  coord.request_stop()
 coord.join(threads)
 sess.close()
 del sess
	
if __name__ == "__main__":
 main()

输出:

step: 1, batch data: ['1' '2' '3' '4' '5' '6']
step: 2, batch data: ['7' '8' '9' '10' '11' '12']
step: 3, batch data: ['13' '14' '15' '16' '17' '18']
step: 4, batch data: ['19' '20' '21' '22' '23' '24']
step: 5, batch data: ['25' '26' '27' '28' '29' '30']
Done traing:-------Epoch limit reached

如果range_input_producer去掉参数num_epochs=1,则输出:

step: 1, batch data: ['1' '2' '3' '4' '5' '6']
step: 2, batch data: ['7' '8' '9' '10' '11' '12']
step: 3, batch data: ['13' '14' '15' '16' '17' '18']
step: 4, batch data: ['19' '20' '21' '22' '23' '24']
step: 5, batch data: ['25' '26' '27' '28' '29' '30']
step: 6, batch data: ['1' '2' '3' '4' '5' '6']
step: 7, batch data: ['7' '8' '9' '10' '11' '12']
step: 8, batch data: ['13' '14' '15' '16' '17' '18']
step: 9, batch data: ['19' '20' '21' '22' '23' '24']
step: 10, batch data: ['25' '26' '27' '28' '29' '30']

有一点需要注意,文件总共有35条数据,BATCH_SIZE = 6表示每个batch包含6条数据,NUM_EXPOCHES = 5表示产生5个batch,如果NUM_EXPOCHES =6,则总共需要36条数据,就会报如下错误:

InvalidArgumentError (see above for traceback): Expected size[0] in [0, 5], but got 6
 [[Node: Slice = Slice[Index=DT_INT32, T=DT_STRING, _device="/job:localhost/replica:0/task:0/cpu:0"](Slice/input, Slice/begin/_5, Slice/size)]]

错误信息的意思是35/BATCH_SIZE=5,即NUM_EXPOCHES 的取值能只能在0到5之间。

以上这篇tensorflow使用range_input_producer多线程读取数据实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用beautifulsoup从爱奇艺网抓取视频播放
Jan 23 Python
Python的Django框架中URLconf相关的一些技巧整理
Jul 18 Python
Python的Flask开发框架简单上手笔记
Nov 16 Python
Python基于pygame模块播放MP3的方法示例
Sep 30 Python
Python学习之Anaconda的使用与配置方法
Jan 04 Python
python实现推箱子游戏
Mar 25 Python
PyQt5组件读取参数的实例
Jun 25 Python
Python MongoDB 插入数据时已存在则不执行,不存在则插入的解决方法
Sep 24 Python
浅谈PyQt5中异步刷新UI和Python多线程总结
Dec 13 Python
python3连接MySQL8.0的两种方式
Feb 17 Python
用Python爬取LOL所有的英雄信息以及英雄皮肤的示例代码
Jul 13 Python
给numpy.array增加维度的超简单方法
Jun 02 Python
浅谈tensorflow中Dataset图片的批量读取及维度的操作详解
Jan 20 #Python
使用tensorflow DataSet实现高效加载变长文本输入
Jan 20 #Python
python机器学习库xgboost的使用
Jan 20 #Python
python 爬取马蜂窝景点翻页文字评论的实现
Jan 20 #Python
tensorflow-gpu安装的常见问题及解决方案
Jan 20 #Python
win10安装tensorflow-gpu1.8.0详细完整步骤
Jan 20 #Python
tensorflow -gpu安装方法(不用自己装cuda,cdnn)
Jan 20 #Python
You might like
《一拳超人》埼玉一拳下去,他们存在了800年毫无意义!
2020/03/02 日漫
php 读取shell管道传输过来的内容
2010/03/01 PHP
php 伪造本地文件包含漏洞的代码
2011/11/03 PHP
自定义php类(查找/修改)xml文档
2013/03/26 PHP
Yii2中SqlDataProvider用法示例
2016/09/22 PHP
php源码 fsockopen获取网页内容实例详解
2016/09/24 PHP
php中10个不同等级压缩优化图片操作示例
2016/11/14 PHP
PHP读取word文档的方法分析【基于COM组件】
2017/08/01 PHP
Laravel实现短信注册的示例代码
2018/05/29 PHP
Javascript 不能释放内存.
2006/09/07 Javascript
将HTMLCollection/NodeList/伪数组转换成数组的实现方法
2011/06/20 Javascript
基于JQuery实现的Select级联
2014/01/27 Javascript
JavaScript中的console.group()函数详细介绍
2014/12/29 Javascript
JavaScript 实现打印,打印预览,打印设置
2014/12/30 Javascript
基于JavaScript实现的折半查找算法示例
2017/04/14 Javascript
在vue中使用v-bind:class的选项卡方法
2018/09/27 Javascript
微信小程序使用for循环动态渲染页面操作示例
2018/12/25 Javascript
ES6 Generator函数的应用实例分析
2019/06/26 Javascript
vue实现放大镜效果
2020/09/17 Javascript
Python实现提取文章摘要的方法
2015/04/21 Python
python将字符串转换成数组的方法
2015/04/29 Python
关于Python 3中print函数的换行详解
2017/08/08 Python
Python求解任意闭区间的所有素数
2018/06/10 Python
基于python实现学生管理系统
2018/10/17 Python
tensorflow图像裁剪进行数据增强操作
2020/06/30 Python
使用css实现android系统的loading加载动画
2019/07/25 HTML / CSS
如何在Canvas上的图形/图像绑定事件监听的实现
2020/09/16 HTML / CSS
Carter’s OshKosh加拿大:购买婴幼儿服装和童装
2018/11/27 全球购物
意大利体育用品和运动服网上商店:Maxi Sport
2019/09/14 全球购物
红领巾心向党广播稿
2014/01/19 职场文书
动员大会主持词
2014/03/20 职场文书
奉献家乡演讲稿
2014/09/16 职场文书
四风问题个人对照检查材料
2014/09/26 职场文书
试用期解除劳动合同通知书
2015/04/16 职场文书
你真的会用Mysql的explain吗
2022/03/31 MySQL
向Spring IOC 容器动态注册bean实现方式
2022/07/15 Java/Android