tensorflow使用range_input_producer多线程读取数据实例


Posted in Python onJanuary 20, 2020

先放关键代码:

i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs=1, shuffle=False).dequeue()
inputs = tf.slice(array, [i * BATCH_SIZE], [BATCH_SIZE])

原理解析:

第一行会产生一个队列,队列包含0到NUM_EXPOCHES-1的元素,如果num_epochs有指定,则每个元素只产生num_epochs次,否则循环产生。shuffle指定是否打乱顺序,这里shuffle=False表示队列的元素是按0到NUM_EXPOCHES-1的顺序存储。在Graph运行的时候,每个线程从队列取出元素,假设值为i,然后按照第二行代码切出array的一小段数据作为一个batch。例如NUM_EXPOCHES=3,如果num_epochs=2,则队列的内容是这样子;

0,1,2,0,1,2

队列只有6个元素,这样在训练的时候只能产生6个batch,迭代6次以后训练就结束。

如果num_epochs不指定,则队列内容是这样子:

0,1,2,0,1,2,0,1,2,0,1,2...

队列可以一直生成元素,训练的时候可以产生无限的batch,需要自己控制什么时候停止训练。

下面是完整的演示代码。

数据文件test.txt内容:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35

main.py内容:

import tensorflow as tf
import codecs
 
BATCH_SIZE = 6
NUM_EXPOCHES = 5
 
 
def input_producer():
 array = codecs.open("test.txt").readlines()
	array = map(lambda line: line.strip(), array)
 i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs=1, shuffle=False).dequeue()
 inputs = tf.slice(array, [i * BATCH_SIZE], [BATCH_SIZE])
 return inputs
 
 
class Inputs(object):
 def __init__(self):
  self.inputs = input_producer()
 
 
def main(*args, **kwargs):
 inputs = Inputs()
 init = tf.group(tf.initialize_all_variables(),
     tf.initialize_local_variables())
 sess = tf.Session()
 coord = tf.train.Coordinator()
 threads = tf.train.start_queue_runners(sess=sess, coord=coord)
 sess.run(init)
 try:
  index = 0
  while not coord.should_stop() and index<10:
   datalines = sess.run(inputs.inputs)
   index += 1
   print("step: %d, batch data: %s" % (index, str(datalines)))
 except tf.errors.OutOfRangeError:
  print("Done traing:-------Epoch limit reached")
 except KeyboardInterrupt:
  print("keyboard interrput detected, stop training")
 finally:
  coord.request_stop()
 coord.join(threads)
 sess.close()
 del sess
	
if __name__ == "__main__":
 main()

输出:

step: 1, batch data: ['1' '2' '3' '4' '5' '6']
step: 2, batch data: ['7' '8' '9' '10' '11' '12']
step: 3, batch data: ['13' '14' '15' '16' '17' '18']
step: 4, batch data: ['19' '20' '21' '22' '23' '24']
step: 5, batch data: ['25' '26' '27' '28' '29' '30']
Done traing:-------Epoch limit reached

如果range_input_producer去掉参数num_epochs=1,则输出:

step: 1, batch data: ['1' '2' '3' '4' '5' '6']
step: 2, batch data: ['7' '8' '9' '10' '11' '12']
step: 3, batch data: ['13' '14' '15' '16' '17' '18']
step: 4, batch data: ['19' '20' '21' '22' '23' '24']
step: 5, batch data: ['25' '26' '27' '28' '29' '30']
step: 6, batch data: ['1' '2' '3' '4' '5' '6']
step: 7, batch data: ['7' '8' '9' '10' '11' '12']
step: 8, batch data: ['13' '14' '15' '16' '17' '18']
step: 9, batch data: ['19' '20' '21' '22' '23' '24']
step: 10, batch data: ['25' '26' '27' '28' '29' '30']

有一点需要注意,文件总共有35条数据,BATCH_SIZE = 6表示每个batch包含6条数据,NUM_EXPOCHES = 5表示产生5个batch,如果NUM_EXPOCHES =6,则总共需要36条数据,就会报如下错误:

InvalidArgumentError (see above for traceback): Expected size[0] in [0, 5], but got 6
 [[Node: Slice = Slice[Index=DT_INT32, T=DT_STRING, _device="/job:localhost/replica:0/task:0/cpu:0"](Slice/input, Slice/begin/_5, Slice/size)]]

错误信息的意思是35/BATCH_SIZE=5,即NUM_EXPOCHES 的取值能只能在0到5之间。

以上这篇tensorflow使用range_input_producer多线程读取数据实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python调用微信公众平台接口操作示例
Jul 08 Python
pip命令无法使用的解决方法
Jun 12 Python
用xpath获取指定标签下的所有text的实例
Jan 02 Python
浅谈python标准库--functools.partial
Mar 13 Python
pygame实现成语填空游戏
Oct 29 Python
关于Numpy数据类型对象(dtype)使用详解
Nov 27 Python
python打印n位数“水仙花数”(实例代码)
Dec 25 Python
Keras使用ImageNet上预训练的模型方式
May 23 Python
记一次Django响应超慢的解决过程
Sep 17 Python
Python离线安装openpyxl模块的步骤
Mar 30 Python
python办公自动化之excel的操作
May 23 Python
利用python调用摄像头的实例分析
Jun 07 Python
浅谈tensorflow中Dataset图片的批量读取及维度的操作详解
Jan 20 #Python
使用tensorflow DataSet实现高效加载变长文本输入
Jan 20 #Python
python机器学习库xgboost的使用
Jan 20 #Python
python 爬取马蜂窝景点翻页文字评论的实现
Jan 20 #Python
tensorflow-gpu安装的常见问题及解决方案
Jan 20 #Python
win10安装tensorflow-gpu1.8.0详细完整步骤
Jan 20 #Python
tensorflow -gpu安装方法(不用自己装cuda,cdnn)
Jan 20 #Python
You might like
利用PHP访问带有密码的Redis方法示例
2017/02/09 PHP
PHP INT类型在内存中占字节详解
2019/07/20 PHP
在Windows上安装Node.js模块的方法
2011/09/25 Javascript
在JavaScript里嵌入大量字符串常量的实现方法
2013/07/07 Javascript
鼠标经过tr时,改变tr当前背景颜色
2014/01/13 Javascript
Windows系统中安装nodejs图文教程
2015/02/28 NodeJs
js支持键盘控制的左右切换立体式图片轮播效果代码分享
2015/08/26 Javascript
JS实现的自定义网页拖动类
2015/11/06 Javascript
Node.js中使用socket创建私聊和公聊聊天室
2015/11/19 Javascript
javascript性能优化之DOM交互操作实例分析
2015/12/12 Javascript
全面解析Bootstrap中nav、collapse的使用方法
2016/05/22 Javascript
使用JS正则表达式 替换括号,尖括号等
2016/11/29 Javascript
js实现返回顶部效果
2017/03/10 Javascript
Vue组件开发技巧总结
2018/03/04 Javascript
vue.js实现的经典计算器/科学计算器功能示例
2018/07/11 Javascript
详解微信JS-SDK选择图片遇到的坑
2018/08/15 Javascript
浅谈KOA2 Restful方式路由初探
2019/03/14 Javascript
Vue实战教程之仿肯德基宅急送App
2019/07/19 Javascript
PyMongo安装使用笔记
2015/04/27 Python
python中判断文件编码的chardet(实例讲解)
2017/12/21 Python
Python实现GUI学生信息管理系统
2020/04/05 Python
python调用staf自动化框架的方法
2018/12/26 Python
详解CSS3选择器的使用方法汇总
2015/11/24 HTML / CSS
Timberland俄罗斯官方网上商店:全球领先的户外品牌
2020/03/15 全球购物
公务员培训心得体会
2013/12/28 职场文书
房屋买卖协议书
2014/04/10 职场文书
2014年房地产销售工作总结
2014/12/01 职场文书
2014年留守儿童工作总结
2014/12/10 职场文书
刑事辩护词范文
2015/05/21 职场文书
演讲开头怎么书写?
2019/08/06 职场文书
《好妈妈胜过好老师》:每个孩子的优秀都是有源头的
2020/01/03 职场文书
react国际化react-intl的使用
2021/05/06 Javascript
浅谈Redis 中的过期删除策略和内存淘汰机制
2022/04/03 Redis
MySQL中优化SQL语句的方法(show status、explain分析服务器状态信息)
2022/04/09 MySQL
Spring Data JPA框架Repository自定义实现
2022/04/28 Java/Android
MySql统计函数COUNT的具体使用详解
2022/08/14 MySQL