将自己的数据集制作成TFRecord格式教程


Posted in Python onFebruary 17, 2020

在使用TensorFlow训练神经网络时,首先面临的问题是:网络的输入

此篇文章,教大家将自己的数据集制作成TFRecord格式,feed进网络,除了TFRecord格式,TensorFlow也支持其他格

式的数据,此处就不再介绍了。建议大家使用TFRecord格式,在后面可以通过api进行多线程的读取文件队列。

1. 原本的数据集

此时,我有两类图片,分别是xiansu100,xiansu60,每一类中有10张图片。

将自己的数据集制作成TFRecord格式教程

2.制作成TFRecord格式

tfrecord会根据你选择输入文件的类,自动给每一类打上同样的标签。如在本例中,只有0,1 两类,想知道文件夹名与label关系的,可以自己保存起来。

#生成整数型的属性
def _int64_feature(value):
 return tf.train.Feature(int64_list = tf.train.Int64List(value = [value]))
 
#生成字符串类型的属性
def _bytes_feature(value):
 return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value]))
 
#制作TFRecord格式
def createTFRecord(filename,mapfile):
 class_map = {}
 data_dir = '/home/wc/DataSet/traffic/testTFRecord/'
 classes = {'xiansu60','xiansu100'}
 #输出TFRecord文件的地址
 
 writer = tf.python_io.TFRecordWriter(filename)
 
 for index,name in enumerate(classes):
  class_path=data_dir+name+'/'
  class_map[index] = name
  for img_name in os.listdir(class_path):
   img_path = class_path + img_name #每个图片的地址
   img = Image.open(img_path)
   img= img.resize((224,224))
   img_raw = img.tobytes()   #将图片转化成二进制格式
   example = tf.train.Example(features = tf.train.Features(feature = {
    'label':_int64_feature(index),
    'image_raw': _bytes_feature(img_raw)
   }))
   writer.write(example.SerializeToString())
 writer.close()
 
 txtfile = open(mapfile,'w+')
 for key in class_map.keys():
  txtfile.writelines(str(key)+":"+class_map[key]+"\n")
 txtfile.close()

此段代码,运行完后会产生生成的.tfrecord文件。

3. 读取TFRecord的数据,进行解析,此时使用了文件队列以及多线程

#读取train.tfrecord中的数据
def read_and_decode(filename): 
 #创建一个reader来读取TFRecord文件中的样例
 reader = tf.TFRecordReader()
 #创建一个队列来维护输入文件列表
 filename_queue = tf.train.string_input_producer([filename], shuffle=False,num_epochs = 1)
 #从文件中读出一个样例,也可以使用read_up_to一次读取多个样例
 _,serialized_example = reader.read(filename_queue)
#  print _,serialized_example
 
 #解析读入的一个样例,如果需要解析多个,可以用parse_example
 features = tf.parse_single_example(
 serialized_example,
 features = {'label':tf.FixedLenFeature([], tf.int64),
    'image_raw': tf.FixedLenFeature([], tf.string),})
 #将字符串解析成图像对应的像素数组
 img = tf.decode_raw(features['image_raw'], tf.uint8)
 img = tf.reshape(img,[224, 224, 3]) #reshape为128*128*3通道图片
 img = tf.image.per_image_standardization(img)
 labels = tf.cast(features['label'], tf.int32)
 return img, labels

4. 将图片几个一打包,形成batch

def createBatch(filename,batchsize):
 images,labels = read_and_decode(filename)
 
 min_after_dequeue = 10
 capacity = min_after_dequeue + 3 * batchsize
 
 image_batch, label_batch = tf.train.shuffle_batch([images, labels], 
              batch_size=batchsize, 
              capacity=capacity, 
              min_after_dequeue=min_after_dequeue
              )
 
 label_batch = tf.one_hot(label_batch,depth=2)
 return image_batch, label_batch

5.主函数

if __name__ =="__main__":
 #训练图片两张为一个batch,进行训练,测试图片一起进行测试
 mapfile = "/home/wc/DataSet/traffic/testTFRecord/classmap.txt"
 train_filename = "/home/wc/DataSet/traffic/testTFRecord/train.tfrecords"
#  createTFRecord(train_filename,mapfile)
 test_filename = "/home/wc/DataSet/traffic/testTFRecord/test.tfrecords"
#  createTFRecord(test_filename,mapfile)
 image_batch, label_batch = createBatch(filename = train_filename,batchsize = 2)
 test_images,test_labels = createBatch(filename = test_filename,batchsize = 20)
 with tf.Session() as sess:
  initop = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
  sess.run(initop)
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess = sess, coord = coord)
 
  try:
   step = 0
   while 1:
    _image_batch,_label_batch = sess.run([image_batch,label_batch])
    step += 1
    print step
    print (_label_batch)
  except tf.errors.OutOfRangeError:
   print (" trainData done!")
   
  try:
   step = 0
   while 1:
    _test_images,_test_labels = sess.run([test_images,test_labels])
    step += 1
    print step
 #     print _image_batch.shape
    print (_test_labels)
  except tf.errors.OutOfRangeError:
   print (" TEST done!")
  coord.request_stop()
  coord.join(threads)

此时,生成的batch,就可以feed进网络了。

以上这篇将自己的数据集制作成TFRecord格式教程就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python批量修改文件后缀的方法
Jan 26 Python
Python语言实现百度语音识别API的使用实例
Dec 13 Python
mac系统安装Python3初体验
Jan 02 Python
Python中函数的基本定义与调用及内置函数详解
May 13 Python
OpenCV图像颜色反转算法详解
May 13 Python
浅谈Django中的QueryDict元素为数组的坑
Mar 31 Python
python属于软件吗
Jun 18 Python
浅谈python处理json和redis hash的坑
Jul 16 Python
Python os库常用操作代码汇总
Nov 03 Python
用python制作个视频下载器
Feb 01 Python
Python实现GIF动图以及视频卡通化详解
Dec 06 Python
LeetCode189轮转数组python示例
Aug 05 Python
tensorflow 实现数据类型转换
Feb 17 #Python
Django Haystack 全文检索与关键词高亮的实现
Feb 17 #Python
python使用docx模块读写docx文件的方法与docx模块常用方法详解
Feb 17 #Python
python itsdangerous模块的具体使用方法
Feb 17 #Python
django-crontab实现服务端的定时任务的示例代码
Feb 17 #Python
TensorFlow通过文件名/文件夹名获取标签,并加入队列的实现
Feb 17 #Python
Django 项目通过加载不同env文件来区分不同环境
Feb 17 #Python
You might like
简体中文转换为繁体中文的PHP函数
2006/10/09 PHP
PHP批量检测并去除文件BOM头代码实例
2014/05/08 PHP
PHP判断一个gif图片是否为动态图片的方法
2014/11/19 PHP
护卫神php套件 php版本升级方法(php5.5.24)
2015/05/10 PHP
Yii2.0 Basic代码中路由链接被转义的处理方法
2016/09/21 PHP
jQuery+css实现图片滚动效果(附源码)
2013/03/18 Javascript
面向对象继承实例(a如何继承b问题)(自写)
2013/07/01 Javascript
了不起的node.js读书笔记之mongodb数据库交互
2014/12/22 Javascript
js面向对象之公有、私有、静态属性和方法详解
2015/04/17 Javascript
jquery 中ajax执行的优先级
2015/06/22 Javascript
所见即所得的富文本编辑器bootstrap-wysiwyg使用方法详解
2016/05/27 Javascript
js实现图片缓慢放大缩小效果
2016/08/02 Javascript
基于Bootstrap的Metronic框架实现页面链接收藏夹功能
2016/08/29 Javascript
浅谈JS使用[ ]来访问对象属性
2016/09/21 Javascript
Bootstrap 3.x打印预览背景色与文字显示异常的解决
2016/11/06 Javascript
Spring shiro + bootstrap + jquery.validate 实现登录、注册功能
2017/06/02 jQuery
Vue.js 动态为img的src赋值方法
2018/03/14 Javascript
electron实现qq快捷登录的方法示例
2018/10/22 Javascript
Mint UI实现A-Z字母排序的城市选择列表
2018/12/28 Javascript
tsconfig.json配置详解
2019/05/17 Javascript
vue读取本地的excel文件并显示在网页上方法示例
2019/05/29 Javascript
Vue 图片压缩并上传至服务器功能
2020/01/15 Javascript
vscode 使用Prettier插件格式化配置使用代码详解
2020/08/10 Javascript
Vue 实现监听窗口关闭事件,并在窗口关闭前发送请求
2020/09/01 Javascript
JavaScript实现滑块验证解锁
2021/01/07 Javascript
[00:23]DOTA2群星共贺开放测试 25日无码时代来袭
2013/09/23 DOTA
[49:58]完美世界DOTA2联赛PWL S3 Magma vs DLG 第一场 12.18
2020/12/19 DOTA
Python本地与全局命名空间用法实例
2015/06/16 Python
python基于pygame实现响应游戏中事件的方法(附源码)
2015/11/11 Python
Python导入模块时遇到的错误分析
2017/08/30 Python
python3实现基于用户的协同过滤
2018/05/31 Python
python 读取摄像头数据并保存的实例
2018/08/03 Python
全球度假村:Club Med
2017/11/27 全球购物
如何在存储过程中使用Loop
2016/01/05 面试题
会计求职信
2014/05/29 职场文书
mysql配置SSL证书登录的实现
2021/09/04 MySQL