Tensorflow之构建自己的图片数据集TFrecords的方法


Posted in Python onFebruary 07, 2018

学习谷歌的深度学习终于有点眉目了,给大家分享我的Tensorflow学习历程。

tensorflow的官方中文文档比较生涩,数据集一直采用的MNIST二进制数据集。并没有过多讲述怎么构建自己的图片数据集tfrecords。

流程是:制作数据集—读取数据集—-加入队列

先贴完整的代码:

#encoding=utf-8
import os
import tensorflow as tf
from PIL import Image

cwd = os.getcwd()

classes = {'test','test1','test2'}
#制作二进制数据
def create_record():
  writer = tf.python_io.TFRecordWriter("train.tfrecords")
  for index, name in enumerate(classes):
    class_path = cwd +"/"+ name+"/"
    for img_name in os.listdir(class_path):
      img_path = class_path + img_name
      img = Image.open(img_path)
      img = img.resize((64, 64))
      img_raw = img.tobytes() #将图片转化为原生bytes
      print index,img_raw
      example = tf.train.Example(
        features=tf.train.Features(feature={
          "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
          'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
        }))
      writer.write(example.SerializeToString())
  writer.close()

data = create_record()

#读取二进制数据
def read_and_decode(filename):
  # 创建文件队列,不限读取的数量
  filename_queue = tf.train.string_input_producer([filename])
  # create a reader from file queue
  reader = tf.TFRecordReader()
  # reader从文件队列中读入一个序列化的样本
  _, serialized_example = reader.read(filename_queue)
  # get feature from serialized example
  # 解析符号化的样本
  features = tf.parse_single_example(
    serialized_example,
    features={
      'label': tf.FixedLenFeature([], tf.int64),
      'img_raw': tf.FixedLenFeature([], tf.string)
    }
  )
  label = features['label']
  img = features['img_raw']
  img = tf.decode_raw(img, tf.uint8)
  img = tf.reshape(img, [64, 64, 3])
  img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
  label = tf.cast(label, tf.int32)
  return img, label

if __name__ == '__main__':
  if 0:
    data = create_record("train.tfrecords")
  else:
    img, label = read_and_decode("train.tfrecords")
    print "tengxing",img,label
    #使用shuffle_batch可以随机打乱输入 next_batch挨着往下取
    # shuffle_batch才能实现[img,label]的同步,也即特征和label的同步,不然可能输入的特征和label不匹配
    # 比如只有这样使用,才能使img和label一一对应,每次提取一个image和对应的label
    # shuffle_batch返回的值就是RandomShuffleQueue.dequeue_many()的结果
    # Shuffle_batch构建了一个RandomShuffleQueue,并不断地把单个的[img,label],送入队列中
    img_batch, label_batch = tf.train.shuffle_batch([img, label],
                          batch_size=4, capacity=2000,
                          min_after_dequeue=1000)

    # 初始化所有的op
    init = tf.initialize_all_variables()

    with tf.Session() as sess:
      sess.run(init)
      # 启动队列
      threads = tf.train.start_queue_runners(sess=sess)
      for i in range(5):
        print img_batch.shape,label_batch
        val, l = sess.run([img_batch, label_batch])
        # l = to_categorical(l, 12)
        print(val.shape, l)

制作数据集

#制作二进制数据
def create_record():
  cwd = os.getcwd()
  classes = {'1','2','3'}
  writer = tf.python_io.TFRecordWriter("train.tfrecords")
  for index, name in enumerate(classes):
    class_path = cwd +"/"+ name+"/"
    for img_name in os.listdir(class_path):
      img_path = class_path + img_name
      img = Image.open(img_path)
      img = img.resize((28, 28))
      img_raw = img.tobytes() #将图片转化为原生bytes
      #print index,img_raw
      example = tf.train.Example(
        features=tf.train.Features(
          feature={
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
          }
        )
      )
      writer.write(example.SerializeToString())
  writer.close()

TFRecords文件包含了tf.train.Example 协议内存块(protocol buffer)(协议内存块包含了字段 Features)。我们可以写一段代码获取你的数据, 将数据填入到Example协议内存块(protocol buffer),将协议内存块序列化为一个字符串, 并且通过tf.python_io.TFRecordWriter 写入到TFRecords文件。

读取数据集

#读取二进制数据
def read_and_decode(filename):
  # 创建文件队列,不限读取的数量
  filename_queue = tf.train.string_input_producer([filename])
  # create a reader from file queue
  reader = tf.TFRecordReader()
  # reader从文件队列中读入一个序列化的样本
  _, serialized_example = reader.read(filename_queue)
  # get feature from serialized example
  # 解析符号化的样本
  features = tf.parse_single_example(
    serialized_example,
    features={
      'label': tf.FixedLenFeature([], tf.int64),
      'img_raw': tf.FixedLenFeature([], tf.string)
    }
  )
  label = features['label']
  img = features['img_raw']
  img = tf.decode_raw(img, tf.uint8)
  img = tf.reshape(img, [64, 64, 3])
  img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
  label = tf.cast(label, tf.int32)
  return img, label

一个Example中包含Features,Features里包含Feature(这里没s)的字典。最后,Feature里包含有一个 FloatList, 或者ByteList,或者Int64List

加入队列

with tf.Session() as sess:
      sess.run(init)
      # 启动队列
      threads = tf.train.start_queue_runners(sess=sess)
      for i in range(5):
        print img_batch.shape,label_batch
        val, l = sess.run([img_batch, label_batch])
        # l = to_categorical(l, 12)
        print(val.shape, l)

这样就可以的到和tensorflow官方的二进制数据集了,

注意:

  1. 启动队列那条code不要忘记,不然卡死
  2. 使用的时候记得使用val和l,不然会报类型错误:TypeError: The value of a feed cannot be a tf.Tensor object. Acceptable feed values include Python scalars, strings, lists, or numpy ndarrays.
  3. 算交叉熵时候:cross_entropy=tf.nn.sparse_softmax_cross_entropy_with_logits(logits,labels)算交叉熵
  4. 最后评估的时候用tf.nn.in_top_k(logits,labels,1)选logits最大的数的索引和label比较
  5. cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv))算交叉熵,所以label必须转成one-hot向量

实例2:将图片文件夹下的图片转存tfrecords的数据集。

############################################################################################ 
#!/usr/bin/python2.7 
# -*- coding: utf-8 -*- 
#Author : zhaoqinghui 
#Date  : 2016.5.10 
#Function: image convert to tfrecords  
############################################################################################# 
 
import tensorflow as tf 
import numpy as np 
import cv2 
import os 
import os.path 
from PIL import Image 
 
#参数设置 
############################################################################################### 
train_file = 'train.txt' #训练图片 
name='train'   #生成train.tfrecords 
output_directory='./tfrecords' 
resize_height=32 #存储图片高度 
resize_width=32 #存储图片宽度 
############################################################################################### 
def _int64_feature(value): 
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 
 
def _bytes_feature(value): 
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 
 
def load_file(examples_list_file): 
  lines = np.genfromtxt(examples_list_file, delimiter=" ", dtype=[('col1', 'S120'), ('col2', 'i8')]) 
  examples = [] 
  labels = [] 
  for example, label in lines: 
    examples.append(example) 
    labels.append(label) 
  return np.asarray(examples), np.asarray(labels), len(lines) 
 
def extract_image(filename, resize_height, resize_width): 
  image = cv2.imread(filename) 
  image = cv2.resize(image, (resize_height, resize_width)) 
  b,g,r = cv2.split(image)     
  rgb_image = cv2.merge([r,g,b])    
  return rgb_image 
 
def transform2tfrecord(train_file, name, output_directory, resize_height, resize_width): 
  if not os.path.exists(output_directory) or os.path.isfile(output_directory): 
    os.makedirs(output_directory) 
  _examples, _labels, examples_num = load_file(train_file) 
  filename = output_directory + "/" + name + '.tfrecords' 
  writer = tf.python_io.TFRecordWriter(filename) 
  for i, [example, label] in enumerate(zip(_examples, _labels)): 
    print('No.%d' % (i)) 
    image = extract_image(example, resize_height, resize_width) 
    print('shape: %d, %d, %d, label: %d' % (image.shape[0], image.shape[1], image.shape[2], label)) 
    image_raw = image.tostring() 
    example = tf.train.Example(features=tf.train.Features(feature={ 
      'image_raw': _bytes_feature(image_raw), 
      'height': _int64_feature(image.shape[0]), 
      'width': _int64_feature(image.shape[1]), 
      'depth': _int64_feature(image.shape[2]), 
      'label': _int64_feature(label) 
    })) 
    writer.write(example.SerializeToString()) 
  writer.close() 
 
def disp_tfrecords(tfrecord_list_file): 
  filename_queue = tf.train.string_input_producer([tfrecord_list_file]) 
  reader = tf.TFRecordReader() 
  _, serialized_example = reader.read(filename_queue) 
  features = tf.parse_single_example( 
    serialized_example, 
 features={ 
     'image_raw': tf.FixedLenFeature([], tf.string), 
     'height': tf.FixedLenFeature([], tf.int64), 
     'width': tf.FixedLenFeature([], tf.int64), 
     'depth': tf.FixedLenFeature([], tf.int64), 
     'label': tf.FixedLenFeature([], tf.int64) 
   } 
  ) 
  image = tf.decode_raw(features['image_raw'], tf.uint8) 
  #print(repr(image)) 
  height = features['height'] 
  width = features['width'] 
  depth = features['depth'] 
  label = tf.cast(features['label'], tf.int32) 
  init_op = tf.initialize_all_variables() 
  resultImg=[] 
  resultLabel=[] 
  with tf.Session() as sess: 
    sess.run(init_op) 
    coord = tf.train.Coordinator() 
    threads = tf.train.start_queue_runners(sess=sess, coord=coord) 
    for i in range(21): 
      image_eval = image.eval() 
      resultLabel.append(label.eval()) 
      image_eval_reshape = image_eval.reshape([height.eval(), width.eval(), depth.eval()]) 
      resultImg.append(image_eval_reshape) 
      pilimg = Image.fromarray(np.asarray(image_eval_reshape)) 
      pilimg.show() 
    coord.request_stop() 
    coord.join(threads) 
    sess.close() 
  return resultImg,resultLabel 
 
def read_tfrecord(filename_queuetemp): 
  filename_queue = tf.train.string_input_producer([filename_queuetemp]) 
  reader = tf.TFRecordReader() 
  _, serialized_example = reader.read(filename_queue) 
  features = tf.parse_single_example( 
    serialized_example, 
    features={ 
     'image_raw': tf.FixedLenFeature([], tf.string), 
     'width': tf.FixedLenFeature([], tf.int64), 
     'depth': tf.FixedLenFeature([], tf.int64), 
     'label': tf.FixedLenFeature([], tf.int64) 
   } 
  ) 
  image = tf.decode_raw(features['image_raw'], tf.uint8) 
  # image 
  tf.reshape(image, [256, 256, 3]) 
  # normalize 
  image = tf.cast(image, tf.float32) * (1. /255) - 0.5 
  # label 
  label = tf.cast(features['label'], tf.int32) 
  return image, label 
 
def test(): 
  transform2tfrecord(train_file, name , output_directory, resize_height, resize_width) #转化函数   
  img,label=disp_tfrecords(output_directory+'/'+name+'.tfrecords') #显示函数 
  img,label=read_tfrecord(output_directory+'/'+name+'.tfrecords') #读取函数 
  print label 
 
if __name__ == '__main__': 
  test()

这样就可以得到自己专属的数据集.tfrecords了  ,它可以直接用于tensorflow的数据集。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
php使用递归与迭代实现快速排序示例
Jan 23 Python
python实现进程间通信简单实例
Jul 23 Python
在Python的Django框架中创建语言文件
Jul 27 Python
Python3实现的画图及加载图片动画效果示例
Jan 19 Python
Windows下的Jupyter Notebook 安装与自定义启动(图文详解)
Feb 21 Python
python实现梯度下降算法
Mar 24 Python
利用Python如何实现一个小说网站雏形
Nov 23 Python
在Python 中实现图片加框和加字的方法
Jan 26 Python
机器学习实战之knn算法pandas
Jun 22 Python
keras的ImageDataGenerator和flow()的用法说明
Jul 03 Python
Python 打印自己设计的字体的实例讲解
Jan 04 Python
Python matplotlib安装以及实现简单曲线的绘制
Apr 26 Python
python深度优先搜索和广度优先搜索
Feb 07 #Python
Python Flask基础教程示例代码
Feb 07 #Python
Python装饰器用法实例总结
Feb 07 #Python
使用apidocJs快速生成在线文档的实例讲解
Feb 07 #Python
Python自定义线程池实现方法分析
Feb 07 #Python
使用apidoc管理RESTful风格Flask项目接口文档方法
Feb 07 #Python
Python列表推导式、字典推导式与集合推导式用法实例分析
Feb 07 #Python
You might like
PHP 和 XML: 使用expat函数(一)
2006/10/09 PHP
php过滤所有恶意字符(批量过滤post,get敏感数据)
2014/03/18 PHP
php实现表单多按钮提交action的处理方法
2015/10/24 PHP
[原创]静态页面也可以实现预览 列表不同的显示方式
2006/10/14 Javascript
DHTML Slide Show script图片轮换
2008/03/03 Javascript
javascript globalStorage类代码
2009/06/04 Javascript
jquery中获取id值方法小结
2013/09/22 Javascript
JavaScript实现找出数组中最长的连续数字序列
2014/09/03 Javascript
jquery调取json数据实现省市级联的方法
2015/01/29 Javascript
js比较日期大小的方法
2015/05/12 Javascript
jQuery ui autocomplete选择列表被Bootstrap模态窗遮挡的完美解决方法
2016/09/23 Javascript
JavaScript生成简单等差数列
2017/11/28 Javascript
浅谈MUI框架中加载外部网页或服务器数据的方法
2018/01/31 Javascript
vue.js中npm安装教程图解
2018/04/10 Javascript
解决node修改后需频繁手动重启的问题
2018/05/13 Javascript
Vue引入Stylus知识点总结
2020/01/16 Javascript
JavaScript交换变量的常用方法小结【4种方法】
2020/05/07 Javascript
JavaScript 中判断变量是否为数字的示例代码
2020/10/22 Javascript
关于vue 项目中浏览器跨域的配置问题
2020/11/10 Javascript
[43:48]Ti4正赛第一天 VG vs NEWBEE 2
2014/07/19 DOTA
[54:47]Liquid vs VP Supermajor决赛 BO 第五场 6.10
2018/07/05 DOTA
利用Python画ROC曲线和AUC值计算
2016/09/19 Python
python爬虫_自动获取seebug的poc实例
2017/08/05 Python
启动Atom并运行python文件的步骤
2018/11/09 Python
Python基于正则表达式实现计算器功能
2020/07/13 Python
CSS3实现粒子旋转伸缩加载动画
2016/04/22 HTML / CSS
美国在线家装零售商:Build.com
2016/09/02 全球购物
爱尔兰电子产品购物网站:Komplett.ie
2018/04/04 全球购物
SmartBuyGlasses荷兰:购买太阳镜和眼镜
2020/03/16 全球购物
共产党员公开承诺书范文
2014/03/28 职场文书
个人银行贷款担保书
2014/04/01 职场文书
组工干部演讲稿
2014/09/02 职场文书
2015年法院工作总结范文
2015/04/28 职场文书
新年晚会主持词开场白
2015/05/28 职场文书
初中地理教学反思
2016/02/19 职场文书
导游词之南京汤山温泉
2019/11/26 职场文书