Tensorflow分类器项目自定义数据读入的实现


Posted in Python onFebruary 05, 2019

在照着Tensorflow官网的demo敲了一遍分类器项目的代码后,运行倒是成功了,结果也不错。但是最终还是要训练自己的数据,所以尝试准备加载自定义的数据,然而demo中只是出现了fashion_mnist.load_data()并没有详细的读取过程,随后我又找了些资料,把读取的过程记录在这里。

首先提一下需要用到的模块:

import os
import keras
import matplotlib.pyplot as plt
from PIL import Image
from keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split

图片分类器项目,首先确定你要处理的图片分辨率将是多少,这里的例子为30像素:

IMG_SIZE_X = 30
IMG_SIZE_Y = 30

其次确定你图片的方式目录:

image_path = r'D:\Projects\ImageClassifier\data\set'
path = ".\data"
# 你也可以使用相对路径的方式
# image_path =os.path.join(path, "set")

目录下的结构如下:

Tensorflow分类器项目自定义数据读入的实现

相应的label.txt如下:

动漫
风景
美女
物语
樱花

接下来是接在labels.txt,如下:

label_name = "labels.txt"
label_path = os.path.join(path, label_name)
class_names = np.loadtxt(label_path, type(""))

这里简便起见,直接利用了numpy的loadtxt函数直接加载。

之后便是正式处理图片数据了,注释就写在里面了:

re_load = False
re_build = False
# re_load = True
re_build = True

data_name = "data.npz"
data_path = os.path.join(path, data_name)
model_name = "model.h5"
model_path = os.path.join(path, model_name)

count = 0

# 这里判断是否存在序列化之后的数据,re_load是一个开关,是否强制重新处理,测试用,可以去除。
if not os.path.exists(data_path) or re_load:
  labels = []
  images = []
  print('Handle images')
  # 由于label.txt是和图片防止目录的分类目录一一对应的,即每个子目录的目录名就是labels.txt里的一个label,所以这里可以通过读取class_names的每一项去拼接path后读取
  for index, name in enumerate(class_names):
    # 这里是拼接后的子目录path
    classpath = os.path.join(image_path, name)
    # 先判断一下是否是目录
    if not os.path.isdir(classpath):
      continue
    # limit是测试时候用的这里可以去除
    limit = 0
    for image_name in os.listdir(classpath):
      if limit >= max_size:
        break
      # 这里是拼接后的待处理的图片path
      imagepath = os.path.join(classpath, image_name)
      count = count + 1
      limit = limit + 1
      # 利用Image打开图片
      img = Image.open(imagepath)
      # 缩放到你最初确定要处理的图片分辨率大小
      img = img.resize((IMG_SIZE_X, IMG_SIZE_Y))
      # 转为灰度图片,这里彩色通道会干扰结果,并且会加大计算量
      img = img.convert("L")
      # 转为numpy数组
      img = np.array(img)
      # 由(30,30)转为(1,30,30)(即`channels_first`),当然你也可以转换为(30,30,1)(即`channels_last`)但为了之后预览处理后的图片方便这里采用了(1,30,30)的格式存放
      img = np.reshape(img, (1, IMG_SIZE_X, IMG_SIZE_Y))
      # 这里利用循环生成labels数据,其中存放的实际是class_names中对应元素的索引
      labels.append([index])
      # 添加到images中,最后统一处理
      images.append(img)
      # 循环中一些状态的输出,可以去除
      print("{} class: {} {} limit: {} {}"
         .format(count, index + 1, class_names[index], limit, imagepath))
  # 最后一次性将images和labels都转换成numpy数组
  npy_data = np.array(images)
  npy_labels = np.array(labels)
  # 处理数据只需要一次,所以我们选择在这里利用numpy自带的方法将处理之后的数据序列化存储
  np.savez(data_path, x=npy_data, y=npy_labels)
  print("Save images by npz")
else:
  # 如果存在序列化号的数据,便直接读取,提高速度
  npy_data = np.load(data_path)["x"]
  npy_labels = np.load(data_path)["y"]
  print("Load images by npz")
image_data = npy_data
labels_data = npy_labels

到了这里原始数据的加工预处理便已经完成,只需要最后一步,就和demo中fashion_mnist.load_data()返回的结果一样了。代码如下:

# 最后一步就是将原始数据分成训练数据和测试数据
train_images, test_images, train_labels, test_labels = \
  train_test_split(image_data, labels_data, test_size=0.2, random_state=6)

这里将相关信息打印的方法也附上:

print("_________________________________________________________________")
print("%-28s %-s" % ("Name", "Shape"))
print("=================================================================")
print("%-28s %-s" % ("Image Data", image_data.shape))
print("%-28s %-s" % ("Labels Data", labels_data.shape))
print("=================================================================")

print('Split train and test data,p=%')
print("_________________________________________________________________")
print("%-28s %-s" % ("Name", "Shape"))
print("=================================================================")
print("%-28s %-s" % ("Train Images", train_images.shape))
print("%-28s %-s" % ("Test Images", test_images.shape))
print("%-28s %-s" % ("Train Labels", train_labels.shape))
print("%-28s %-s" % ("Test Labels", test_labels.shape))
print("=================================================================")

之后别忘了归一化哟:

print("Normalize images")
train_images = train_images / 255.0
test_images = test_images / 255.0

最后附上读取自定义数据的完整代码:

import os

import keras
import matplotlib.pyplot as plt
from PIL import Image
from keras.layers import *
from keras.models import *
from keras.optimizers import Adam
from keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
re_load = False
re_build = False
# re_load = True
re_build = True
epochs = 50
batch_size = 5
count = 0
max_size = 2000000000

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中返回字典键的值的values()方法使用
May 22 Python
使用python3.5仿微软记事本notepad
Jun 15 Python
浅谈用VSCode写python的正确姿势
Dec 16 Python
python实现二叉查找树实例代码
Feb 08 Python
python去掉空白行的多种实现代码
Mar 19 Python
Pytorch中的VGG实现修改最后一层FC
Jan 15 Python
Python使用Pandas库常见操作详解
Jan 16 Python
从训练好的tensorflow模型中打印训练变量实例
Jan 20 Python
基于python SMTP实现自动发送邮件教程解析
Jun 02 Python
pytorch快速搭建神经网络_Sequential操作
Jun 17 Python
python实现简单的学生管理系统
Feb 22 Python
Python中zipfile压缩包模块的使用
May 14 Python
在Python 字典中一键对应多个值的实例
Feb 03 #Python
Django csrf 两种方法设置form的实例
Feb 03 #Python
解决django前后端分离csrf验证的问题
Feb 03 #Python
Python利用heapq实现一个优先级队列的方法
Feb 03 #Python
对Python3中dict.keys()转换成list类型的方法详解
Feb 03 #Python
对python中字典keys,values,items的使用详解
Feb 03 #Python
python生成带有表格的图片实例
Feb 03 #Python
You might like
php中使用接口实现工厂设计模式的代码
2012/06/17 PHP
php curl优化下载微信头像的方法总结
2018/09/07 PHP
JSChart轻量级图形报表工具(内置函数中文参考)
2010/10/11 Javascript
再说AutoComplete自动补全之实现原理
2011/11/05 Javascript
JS实现简单的Canvas画图实例
2013/07/04 Javascript
javascript实现根据身份证号读取相关信息
2014/12/17 Javascript
JavaScript将一个数组插入到另一个数组的方法
2015/03/19 Javascript
微信小程序进行微信支付的步骤昂述
2016/12/01 Javascript
js获取隐藏元素的宽高
2017/02/24 Javascript
bootstrap插件treeview实现全选父节点下所有子节点和反选功能
2017/07/21 Javascript
Angular4如何自定义首屏的加载动画详解
2017/07/26 Javascript
详解vue项目首页加载速度优化
2017/10/18 Javascript
深入浅析vue-cli@3.0 使用及配置说明
2019/05/08 Javascript
详解VUE调用本地json的使用方法
2019/05/15 Javascript
axios 实现post请求时把对象obj数据转为formdata
2019/10/31 Javascript
Vue项目开发常见问题和解决方案总结
2020/09/11 Javascript
详解Vite的新体验
2021/02/22 Javascript
快速排序的算法思想及Python版快速排序的实现示例
2016/07/02 Python
Python subprocess模块功能与常见用法实例详解
2018/06/28 Python
python datetime中strptime用法详解
2019/08/29 Python
PYQT5 vscode联合操作qtdesigner的方法
2020/03/24 Python
HTML5实现音频和视频嵌入的方法
2018/08/22 HTML / CSS
海滩咖啡馆:Beach Cafe
2018/02/02 全球购物
澳大利亚有机化妆品网上商店:The Well Store
2020/02/20 全球购物
结婚典礼证婚词
2014/01/11 职场文书
手术室护士长竞聘书
2014/03/31 职场文书
党的群众路线教育实践活动查摆问题及整改措施
2014/10/10 职场文书
司法局群众路线教育实践活动整改措施思想汇报
2014/10/13 职场文书
简历自我评价优缺点
2015/03/11 职场文书
王亚平太空授课观后感
2015/06/12 职场文书
毕业感言怎么写
2015/07/31 职场文书
使用numpy nonzero 找出非0元素
2021/05/14 Python
Python操作CSV格式文件的方法大全
2021/07/15 Python
mybatis中注解与xml配置的对应关系和对比分析
2021/08/04 Java/Android
手写实现JS中的new
2021/11/07 Javascript
canvas 中如何实现物体的框选
2022/08/05 Javascript