keras使用Sequence类调用大规模数据集进行训练的实现


Posted in Python onJune 22, 2020

使用Keras如果要使用大规模数据集对网络进行训练,就没办法先加载进内存再从内存直接传到显存了,除了使用Sequence类以外,还可以使用迭代器去生成数据,但迭代器无法在fit_generation里开启多进程,会影响数据的读取和预处理效率,在本文中就不在叙述了,有需要的可以另外去百度。

下面是我所使用的代码

class SequenceData(Sequence):
  def __init__(self, path, batch_size=32):
    self.path = path
    self.batch_size = batch_size
    f = open(path)
    self.datas = f.readlines()
    self.L = len(self.datas)
    self.index = random.sample(range(self.L), self.L)
  #返回长度,通过len(<你的实例>)调用
  def __len__(self):
    return self.L - self.batch_size
  #即通过索引获取a[0],a[1]这种
  def __getitem__(self, idx):
    batch_indexs = self.index[idx:(idx+self.batch_size)]
    batch_datas = [self.datas[k] for k in batch_indexs]
    img1s,img2s,audios,labels = self.data_generation(batch_datas)
    return ({'face1_input_1': img1s, 'face2_input_2': img2s, 'input_3':audios},{'activation_7':labels})

  def data_generation(self, batch_datas):
    #预处理操作
    return img1s,img2s,audios,labels

然后在代码里通过fit_generation函数调用并训练

这里要注意,use_multiprocessing参数是是否开启多进程,由于python的多线程不是真的多线程,所以多进程还是会获得比较客观的加速,但不支持windows,windows下python无法使用多进程。

D = SequenceData('train.csv')
model_train.fit_generator(generator=D,steps_per_epoch=int(len(D)), 
          epochs=2, workers=20, #callbacks=[checkpoint],
          use_multiprocessing=True, validation_data=SequenceData('vali.csv'),validation_steps=int(20000/32))

同样的,也可以在测试的时候使用

model.evaluate_generator(generator=SequenceData('face_test.csv'),steps=int(125100/32),workers=32)

补充知识:keras数据自动生成器,继承keras.utils.Sequence,结合fit_generator实现节约内存训练

我就废话不多说了,大家还是直接看代码吧~

#coding=utf-8
'''
Created on 2018-7-10
'''
import keras
import math
import os
import cv2
import numpy as np
from keras.models import Sequential
from keras.layers import Dense

class DataGenerator(keras.utils.Sequence):
  
  def __init__(self, datas, batch_size=1, shuffle=True):
    self.batch_size = batch_size
    self.datas = datas
    self.indexes = np.arange(len(self.datas))
    self.shuffle = shuffle

  def __len__(self):
    #计算每一个epoch的迭代次数
    return math.ceil(len(self.datas) / float(self.batch_size))

  def __getitem__(self, index):
    #生成每个batch数据,这里就根据自己对数据的读取方式进行发挥了
    # 生成batch_size个索引
    batch_indexs = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
    # 根据索引获取datas集合中的数据
    batch_datas = [self.datas[k] for k in batch_indexs]

    # 生成数据
    X, y = self.data_generation(batch_datas)

    return X, y

  def on_epoch_end(self):
    #在每一次epoch结束是否需要进行一次随机,重新随机一下index
    if self.shuffle == True:
      np.random.shuffle(self.indexes)

  def data_generation(self, batch_datas):
    images = []
    labels = []

    # 生成数据
    for i, data in enumerate(batch_datas):
      #x_train数据
      image = cv2.imread(data)
      image = list(image)
      images.append(image)
      #y_train数据 
      right = data.rfind("\\",0)
      left = data.rfind("\\",0,right)+1
      class_name = data[left:right]
      if class_name=="dog":
        labels.append([0,1])
      else: 
        labels.append([1,0])
    #如果为多输出模型,Y的格式要变一下,外层list格式包裹numpy格式是list[numpy_out1,numpy_out2,numpy_out3]
    return np.array(images), np.array(labels)
  
# 读取样本名称,然后根据样本名称去读取数据
class_num = 0
train_datas = [] 
for file in os.listdir("D:/xxx"):
  file_path = os.path.join("D:/xxx", file)
  if os.path.isdir(file_path):
    class_num = class_num + 1
    for sub_file in os.listdir(file_path):
      train_datas.append(os.path.join(file_path, sub_file))

# 数据生成器
training_generator = DataGenerator(train_datas)

#构建网络
model = Sequential()
model.add(Dense(units=64, activation='relu', input_dim=784))
model.add(Dense(units=2, activation='softmax'))
model.compile(loss='categorical_crossentropy',
       optimizer='sgd',
       metrics=['accuracy'])
model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit_generator(training_generator, epochs=50,max_queue_size=10,workers=1)

以上这篇keras使用Sequence类调用大规模数据集进行训练的实现就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python通过colorama模块在控制台输出彩色文字的方法
Mar 19 Python
python实用代码片段收集贴
Jun 03 Python
Python编程中time模块的一些关键用法解析
Jan 19 Python
利用Python暴力破解zip文件口令的方法详解
Dec 21 Python
python2和python3在处理字符串上的区别详解
May 29 Python
python redis 批量设置过期key过程解析
Nov 26 Python
python 统计文件中的字符串数目示例
Dec 24 Python
python实现异常信息堆栈输出到日志文件
Dec 26 Python
Jupyter notebook无法导入第三方模块的解决方式
Apr 15 Python
使用K.function()调试keras操作
Jun 17 Python
PIP和conda 更换国内安装源的方法步骤
Sep 21 Python
Python 数据可视化工具 Pyecharts 安装及应用
Apr 20 Python
Python socket服务常用操作代码实例
Jun 22 #Python
Python如何实现后端自定义认证并实现多条件登陆
Jun 22 #Python
零基础小白多久能学会python
Jun 22 #Python
Keras-多输入多输出实例(多任务)
Jun 22 #Python
python和c语言哪个更适合初学者
Jun 22 #Python
Virtualenv 搭建 Py项目运行环境的教程详解
Jun 22 #Python
终于搞懂了Keras中multiloss的对应关系介绍
Jun 22 #Python
You might like
PHP5中新增stdClass 内部保留类
2011/06/13 PHP
PHP截断标题且兼容utf8和gb2312编码
2013/09/22 PHP
php过滤html中的其他网站链接的方法(域名白名单功能)
2014/04/24 PHP
ThinkPHP的I方法使用详解
2014/06/18 PHP
php常用日期时间函数实例小结
2019/07/04 PHP
javascript cookies 设置、读取、删除实例代码
2010/04/12 Javascript
25个好玩的JavaScript小游戏分享
2011/04/22 Javascript
拉动滚动条加载数据的jquery代码
2012/05/03 Javascript
Jquery多选框互相内容交换的实例代码
2013/07/04 Javascript
利用window.name实现windowStorage代码分享
2014/01/02 Javascript
jquery教程ajax请求json数据示例
2014/01/13 Javascript
JS鼠标拖拽实例分析
2015/11/23 Javascript
超实用的JavaScript表单代码段
2016/02/26 Javascript
JavaScript实战之菜单特效
2016/08/16 Javascript
jQuery EasyUI常用数据验证汇总
2016/09/18 Javascript
利用angularjs1.4制作的简易滑动门效果
2017/02/28 Javascript
node.js平台下的mysql数据库配置及连接
2017/03/31 Javascript
bootstrap multiselect下拉列表功能
2017/08/22 Javascript
three.js实现3D影院的原理的代码分析
2017/12/18 Javascript
js数组方法reduce经典用法代码分享
2018/01/07 Javascript
VUE:vuex 用户登录信息的数据写入与获取方式
2019/11/11 Javascript
[06:45]2018DOTA2亚洲邀请赛 4.5 SOLO赛 Sccc vs Maybe
2018/04/06 DOTA
[08:08]DOTA2-DPC中国联赛2月28日Recap集锦
2021/03/11 DOTA
Python中isnumeric()方法的使用简介
2015/05/19 Python
python 全局变量的import机制介绍
2017/09/07 Python
解决c++调用python中文乱码问题
2020/07/29 Python
手把手教你配置JupyterLab 环境的实现
2021/02/02 Python
Omio法国:全欧洲低价大巴、火车和航班搜索和比价
2017/11/13 全球购物
几道Java和数据库的面试题
2013/05/30 面试题
环保倡议书300字
2014/05/15 职场文书
工作经常出错的检讨书
2014/09/13 职场文书
不服从公司安排检讨书
2014/09/24 职场文书
2014年话务员工作总结
2014/11/19 职场文书
2015年基建工作总结范文
2015/05/23 职场文书
2015年公路路政个人工作总结
2015/07/24 职场文书
Python 数据可视化神器Pyecharts绘制图像练习
2022/02/28 Python