Tensorflow之构建自己的图片数据集TFrecords的方法


Posted in Python onFebruary 07, 2018

学习谷歌的深度学习终于有点眉目了,给大家分享我的Tensorflow学习历程。

tensorflow的官方中文文档比较生涩,数据集一直采用的MNIST二进制数据集。并没有过多讲述怎么构建自己的图片数据集tfrecords。

流程是:制作数据集—读取数据集—-加入队列

先贴完整的代码:

#encoding=utf-8
import os
import tensorflow as tf
from PIL import Image

cwd = os.getcwd()

classes = {'test','test1','test2'}
#制作二进制数据
def create_record():
  writer = tf.python_io.TFRecordWriter("train.tfrecords")
  for index, name in enumerate(classes):
    class_path = cwd +"/"+ name+"/"
    for img_name in os.listdir(class_path):
      img_path = class_path + img_name
      img = Image.open(img_path)
      img = img.resize((64, 64))
      img_raw = img.tobytes() #将图片转化为原生bytes
      print index,img_raw
      example = tf.train.Example(
        features=tf.train.Features(feature={
          "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
          'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
        }))
      writer.write(example.SerializeToString())
  writer.close()

data = create_record()

#读取二进制数据
def read_and_decode(filename):
  # 创建文件队列,不限读取的数量
  filename_queue = tf.train.string_input_producer([filename])
  # create a reader from file queue
  reader = tf.TFRecordReader()
  # reader从文件队列中读入一个序列化的样本
  _, serialized_example = reader.read(filename_queue)
  # get feature from serialized example
  # 解析符号化的样本
  features = tf.parse_single_example(
    serialized_example,
    features={
      'label': tf.FixedLenFeature([], tf.int64),
      'img_raw': tf.FixedLenFeature([], tf.string)
    }
  )
  label = features['label']
  img = features['img_raw']
  img = tf.decode_raw(img, tf.uint8)
  img = tf.reshape(img, [64, 64, 3])
  img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
  label = tf.cast(label, tf.int32)
  return img, label

if __name__ == '__main__':
  if 0:
    data = create_record("train.tfrecords")
  else:
    img, label = read_and_decode("train.tfrecords")
    print "tengxing",img,label
    #使用shuffle_batch可以随机打乱输入 next_batch挨着往下取
    # shuffle_batch才能实现[img,label]的同步,也即特征和label的同步,不然可能输入的特征和label不匹配
    # 比如只有这样使用,才能使img和label一一对应,每次提取一个image和对应的label
    # shuffle_batch返回的值就是RandomShuffleQueue.dequeue_many()的结果
    # Shuffle_batch构建了一个RandomShuffleQueue,并不断地把单个的[img,label],送入队列中
    img_batch, label_batch = tf.train.shuffle_batch([img, label],
                          batch_size=4, capacity=2000,
                          min_after_dequeue=1000)

    # 初始化所有的op
    init = tf.initialize_all_variables()

    with tf.Session() as sess:
      sess.run(init)
      # 启动队列
      threads = tf.train.start_queue_runners(sess=sess)
      for i in range(5):
        print img_batch.shape,label_batch
        val, l = sess.run([img_batch, label_batch])
        # l = to_categorical(l, 12)
        print(val.shape, l)

制作数据集

#制作二进制数据
def create_record():
  cwd = os.getcwd()
  classes = {'1','2','3'}
  writer = tf.python_io.TFRecordWriter("train.tfrecords")
  for index, name in enumerate(classes):
    class_path = cwd +"/"+ name+"/"
    for img_name in os.listdir(class_path):
      img_path = class_path + img_name
      img = Image.open(img_path)
      img = img.resize((28, 28))
      img_raw = img.tobytes() #将图片转化为原生bytes
      #print index,img_raw
      example = tf.train.Example(
        features=tf.train.Features(
          feature={
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
          }
        )
      )
      writer.write(example.SerializeToString())
  writer.close()

TFRecords文件包含了tf.train.Example 协议内存块(protocol buffer)(协议内存块包含了字段 Features)。我们可以写一段代码获取你的数据, 将数据填入到Example协议内存块(protocol buffer),将协议内存块序列化为一个字符串, 并且通过tf.python_io.TFRecordWriter 写入到TFRecords文件。

读取数据集

#读取二进制数据
def read_and_decode(filename):
  # 创建文件队列,不限读取的数量
  filename_queue = tf.train.string_input_producer([filename])
  # create a reader from file queue
  reader = tf.TFRecordReader()
  # reader从文件队列中读入一个序列化的样本
  _, serialized_example = reader.read(filename_queue)
  # get feature from serialized example
  # 解析符号化的样本
  features = tf.parse_single_example(
    serialized_example,
    features={
      'label': tf.FixedLenFeature([], tf.int64),
      'img_raw': tf.FixedLenFeature([], tf.string)
    }
  )
  label = features['label']
  img = features['img_raw']
  img = tf.decode_raw(img, tf.uint8)
  img = tf.reshape(img, [64, 64, 3])
  img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
  label = tf.cast(label, tf.int32)
  return img, label

一个Example中包含Features,Features里包含Feature(这里没s)的字典。最后,Feature里包含有一个 FloatList, 或者ByteList,或者Int64List

加入队列

with tf.Session() as sess:
      sess.run(init)
      # 启动队列
      threads = tf.train.start_queue_runners(sess=sess)
      for i in range(5):
        print img_batch.shape,label_batch
        val, l = sess.run([img_batch, label_batch])
        # l = to_categorical(l, 12)
        print(val.shape, l)

这样就可以的到和tensorflow官方的二进制数据集了,

注意:

  1. 启动队列那条code不要忘记,不然卡死
  2. 使用的时候记得使用val和l,不然会报类型错误:TypeError: The value of a feed cannot be a tf.Tensor object. Acceptable feed values include Python scalars, strings, lists, or numpy ndarrays.
  3. 算交叉熵时候:cross_entropy=tf.nn.sparse_softmax_cross_entropy_with_logits(logits,labels)算交叉熵
  4. 最后评估的时候用tf.nn.in_top_k(logits,labels,1)选logits最大的数的索引和label比较
  5. cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv))算交叉熵,所以label必须转成one-hot向量

实例2:将图片文件夹下的图片转存tfrecords的数据集。

############################################################################################ 
#!/usr/bin/python2.7 
# -*- coding: utf-8 -*- 
#Author : zhaoqinghui 
#Date  : 2016.5.10 
#Function: image convert to tfrecords  
############################################################################################# 
 
import tensorflow as tf 
import numpy as np 
import cv2 
import os 
import os.path 
from PIL import Image 
 
#参数设置 
############################################################################################### 
train_file = 'train.txt' #训练图片 
name='train'   #生成train.tfrecords 
output_directory='./tfrecords' 
resize_height=32 #存储图片高度 
resize_width=32 #存储图片宽度 
############################################################################################### 
def _int64_feature(value): 
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 
 
def _bytes_feature(value): 
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 
 
def load_file(examples_list_file): 
  lines = np.genfromtxt(examples_list_file, delimiter=" ", dtype=[('col1', 'S120'), ('col2', 'i8')]) 
  examples = [] 
  labels = [] 
  for example, label in lines: 
    examples.append(example) 
    labels.append(label) 
  return np.asarray(examples), np.asarray(labels), len(lines) 
 
def extract_image(filename, resize_height, resize_width): 
  image = cv2.imread(filename) 
  image = cv2.resize(image, (resize_height, resize_width)) 
  b,g,r = cv2.split(image)     
  rgb_image = cv2.merge([r,g,b])    
  return rgb_image 
 
def transform2tfrecord(train_file, name, output_directory, resize_height, resize_width): 
  if not os.path.exists(output_directory) or os.path.isfile(output_directory): 
    os.makedirs(output_directory) 
  _examples, _labels, examples_num = load_file(train_file) 
  filename = output_directory + "/" + name + '.tfrecords' 
  writer = tf.python_io.TFRecordWriter(filename) 
  for i, [example, label] in enumerate(zip(_examples, _labels)): 
    print('No.%d' % (i)) 
    image = extract_image(example, resize_height, resize_width) 
    print('shape: %d, %d, %d, label: %d' % (image.shape[0], image.shape[1], image.shape[2], label)) 
    image_raw = image.tostring() 
    example = tf.train.Example(features=tf.train.Features(feature={ 
      'image_raw': _bytes_feature(image_raw), 
      'height': _int64_feature(image.shape[0]), 
      'width': _int64_feature(image.shape[1]), 
      'depth': _int64_feature(image.shape[2]), 
      'label': _int64_feature(label) 
    })) 
    writer.write(example.SerializeToString()) 
  writer.close() 
 
def disp_tfrecords(tfrecord_list_file): 
  filename_queue = tf.train.string_input_producer([tfrecord_list_file]) 
  reader = tf.TFRecordReader() 
  _, serialized_example = reader.read(filename_queue) 
  features = tf.parse_single_example( 
    serialized_example, 
 features={ 
     'image_raw': tf.FixedLenFeature([], tf.string), 
     'height': tf.FixedLenFeature([], tf.int64), 
     'width': tf.FixedLenFeature([], tf.int64), 
     'depth': tf.FixedLenFeature([], tf.int64), 
     'label': tf.FixedLenFeature([], tf.int64) 
   } 
  ) 
  image = tf.decode_raw(features['image_raw'], tf.uint8) 
  #print(repr(image)) 
  height = features['height'] 
  width = features['width'] 
  depth = features['depth'] 
  label = tf.cast(features['label'], tf.int32) 
  init_op = tf.initialize_all_variables() 
  resultImg=[] 
  resultLabel=[] 
  with tf.Session() as sess: 
    sess.run(init_op) 
    coord = tf.train.Coordinator() 
    threads = tf.train.start_queue_runners(sess=sess, coord=coord) 
    for i in range(21): 
      image_eval = image.eval() 
      resultLabel.append(label.eval()) 
      image_eval_reshape = image_eval.reshape([height.eval(), width.eval(), depth.eval()]) 
      resultImg.append(image_eval_reshape) 
      pilimg = Image.fromarray(np.asarray(image_eval_reshape)) 
      pilimg.show() 
    coord.request_stop() 
    coord.join(threads) 
    sess.close() 
  return resultImg,resultLabel 
 
def read_tfrecord(filename_queuetemp): 
  filename_queue = tf.train.string_input_producer([filename_queuetemp]) 
  reader = tf.TFRecordReader() 
  _, serialized_example = reader.read(filename_queue) 
  features = tf.parse_single_example( 
    serialized_example, 
    features={ 
     'image_raw': tf.FixedLenFeature([], tf.string), 
     'width': tf.FixedLenFeature([], tf.int64), 
     'depth': tf.FixedLenFeature([], tf.int64), 
     'label': tf.FixedLenFeature([], tf.int64) 
   } 
  ) 
  image = tf.decode_raw(features['image_raw'], tf.uint8) 
  # image 
  tf.reshape(image, [256, 256, 3]) 
  # normalize 
  image = tf.cast(image, tf.float32) * (1. /255) - 0.5 
  # label 
  label = tf.cast(features['label'], tf.int32) 
  return image, label 
 
def test(): 
  transform2tfrecord(train_file, name , output_directory, resize_height, resize_width) #转化函数   
  img,label=disp_tfrecords(output_directory+'/'+name+'.tfrecords') #显示函数 
  img,label=read_tfrecord(output_directory+'/'+name+'.tfrecords') #读取函数 
  print label 
 
if __name__ == '__main__': 
  test()

这样就可以得到自己专属的数据集.tfrecords了  ,它可以直接用于tensorflow的数据集。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python常用内置函数总结
Feb 08 Python
Python实现将不规范的英文名字首字母大写
Nov 15 Python
Python数据结构与算法之常见的分配排序法示例【桶排序与基数排序】
Dec 15 Python
python时间日期函数与利用pandas进行时间序列处理详解
Mar 13 Python
Python实现动态添加属性和方法操作示例
Jul 25 Python
解决python执行不输出系统命令弹框的问题
Jun 24 Python
Pycharm运行加载文本出现错误的解决方法
Jun 27 Python
调试Django时打印SQL语句的日志代码实例
Sep 12 Python
TensorFlow Saver:保存和读取模型参数.ckpt实例
Feb 10 Python
Python面向对象魔法方法和单例模块代码实例
Mar 25 Python
学点简单的Django之第一个Django程序的实现
Feb 24 Python
Python控制台输出俄罗斯方块移动和旋转功能
Apr 18 Python
python深度优先搜索和广度优先搜索
Feb 07 #Python
Python Flask基础教程示例代码
Feb 07 #Python
Python装饰器用法实例总结
Feb 07 #Python
使用apidocJs快速生成在线文档的实例讲解
Feb 07 #Python
Python自定义线程池实现方法分析
Feb 07 #Python
使用apidoc管理RESTful风格Flask项目接口文档方法
Feb 07 #Python
Python列表推导式、字典推导式与集合推导式用法实例分析
Feb 07 #Python
You might like
不用iconv库的gb2312与utf-8的互换函数
2006/10/09 PHP
mysql下创建字段并设置主键的php代码
2010/05/16 PHP
Look And Say 序列php实现代码
2011/05/22 PHP
php删除文件夹及其文件夹下所有文件的函数代码
2013/01/23 PHP
PHP ajax 异步执行不等待执行结果的处理方法
2015/05/27 PHP
PHP Socket网络操作类定义与用法示例
2017/08/30 PHP
javascript编程起步(第四课)
2007/02/27 Javascript
Array.prototype.slice 使用扩展
2010/06/09 Javascript
JQuery对checkbox操作 (循环获取)
2011/05/20 Javascript
jQuery语法高亮插件支持各种程序源代码语法着色加亮
2013/04/27 Javascript
javascript获取下拉列表框当中的文本值示例代码
2013/07/31 Javascript
JS+css 图片自动缩放自适应大小
2013/08/08 Javascript
Blocksit插件实现瀑布流数据无限( 异步)加载
2014/06/20 Javascript
jQuery代码实现发展历程时间轴特效
2015/07/30 Javascript
js生成随机数的过程解析
2015/11/24 Javascript
通过javascript进行UTF-8编码的实现方法
2016/06/27 Javascript
利用Angularjs中模块ui-route管理状态的方法
2016/12/27 Javascript
JavaScript控制输入框中只能输入中文、数字和英文的方法【基于正则实现】
2017/03/03 Javascript
详谈AngularJs 控制器、数据绑定、作用域
2017/07/09 Javascript
Angular 4.0学习教程之架构详解
2017/09/12 Javascript
jq源码解析之绑在$,jQuery上面的方法(实例讲解)
2017/10/13 jQuery
vue router动态路由下让每个子路由都是独立组件的解决方案
2018/04/24 Javascript
jQuery.extend 与 jQuery.fn.extend的用法及区别实例分析
2018/07/25 jQuery
js代码规范之Eslint安装与配置详解
2018/09/08 Javascript
js中实例与对象的区别讲解
2019/01/21 Javascript
javascript使用正则表达式实现注册登入校验
2020/09/23 Javascript
python pickle存储、读取大数据量列表、字典数据的方法
2019/07/07 Python
Python实现某论坛自动签到功能
2019/08/20 Python
Python 私有化操作实例分析
2019/11/21 Python
css图标制作教程制作云图标
2014/01/19 HTML / CSS
在html5的Canvas上绘制椭圆的几种方法总结
2013/01/07 HTML / CSS
一些高难度的SQL面试题
2016/11/29 面试题
高中班长自我鉴定
2013/12/20 职场文书
会计专业应届生自荐信
2014/06/28 职场文书
食品安全主题班会
2015/08/13 职场文书
MongoDB支持的数据类型
2022/04/11 MongoDB