keras使用Sequence类调用大规模数据集进行训练的实现


Posted in Python onJune 22, 2020

使用Keras如果要使用大规模数据集对网络进行训练,就没办法先加载进内存再从内存直接传到显存了,除了使用Sequence类以外,还可以使用迭代器去生成数据,但迭代器无法在fit_generation里开启多进程,会影响数据的读取和预处理效率,在本文中就不在叙述了,有需要的可以另外去百度。

下面是我所使用的代码

class SequenceData(Sequence):
  def __init__(self, path, batch_size=32):
    self.path = path
    self.batch_size = batch_size
    f = open(path)
    self.datas = f.readlines()
    self.L = len(self.datas)
    self.index = random.sample(range(self.L), self.L)
  #返回长度,通过len(<你的实例>)调用
  def __len__(self):
    return self.L - self.batch_size
  #即通过索引获取a[0],a[1]这种
  def __getitem__(self, idx):
    batch_indexs = self.index[idx:(idx+self.batch_size)]
    batch_datas = [self.datas[k] for k in batch_indexs]
    img1s,img2s,audios,labels = self.data_generation(batch_datas)
    return ({'face1_input_1': img1s, 'face2_input_2': img2s, 'input_3':audios},{'activation_7':labels})

  def data_generation(self, batch_datas):
    #预处理操作
    return img1s,img2s,audios,labels

然后在代码里通过fit_generation函数调用并训练

这里要注意,use_multiprocessing参数是是否开启多进程,由于python的多线程不是真的多线程,所以多进程还是会获得比较客观的加速,但不支持windows,windows下python无法使用多进程。

D = SequenceData('train.csv')
model_train.fit_generator(generator=D,steps_per_epoch=int(len(D)), 
          epochs=2, workers=20, #callbacks=[checkpoint],
          use_multiprocessing=True, validation_data=SequenceData('vali.csv'),validation_steps=int(20000/32))

同样的,也可以在测试的时候使用

model.evaluate_generator(generator=SequenceData('face_test.csv'),steps=int(125100/32),workers=32)

补充知识:keras数据自动生成器,继承keras.utils.Sequence,结合fit_generator实现节约内存训练

我就废话不多说了,大家还是直接看代码吧~

#coding=utf-8
'''
Created on 2018-7-10
'''
import keras
import math
import os
import cv2
import numpy as np
from keras.models import Sequential
from keras.layers import Dense

class DataGenerator(keras.utils.Sequence):
  
  def __init__(self, datas, batch_size=1, shuffle=True):
    self.batch_size = batch_size
    self.datas = datas
    self.indexes = np.arange(len(self.datas))
    self.shuffle = shuffle

  def __len__(self):
    #计算每一个epoch的迭代次数
    return math.ceil(len(self.datas) / float(self.batch_size))

  def __getitem__(self, index):
    #生成每个batch数据,这里就根据自己对数据的读取方式进行发挥了
    # 生成batch_size个索引
    batch_indexs = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
    # 根据索引获取datas集合中的数据
    batch_datas = [self.datas[k] for k in batch_indexs]

    # 生成数据
    X, y = self.data_generation(batch_datas)

    return X, y

  def on_epoch_end(self):
    #在每一次epoch结束是否需要进行一次随机,重新随机一下index
    if self.shuffle == True:
      np.random.shuffle(self.indexes)

  def data_generation(self, batch_datas):
    images = []
    labels = []

    # 生成数据
    for i, data in enumerate(batch_datas):
      #x_train数据
      image = cv2.imread(data)
      image = list(image)
      images.append(image)
      #y_train数据 
      right = data.rfind("\\",0)
      left = data.rfind("\\",0,right)+1
      class_name = data[left:right]
      if class_name=="dog":
        labels.append([0,1])
      else: 
        labels.append([1,0])
    #如果为多输出模型,Y的格式要变一下,外层list格式包裹numpy格式是list[numpy_out1,numpy_out2,numpy_out3]
    return np.array(images), np.array(labels)
  
# 读取样本名称,然后根据样本名称去读取数据
class_num = 0
train_datas = [] 
for file in os.listdir("D:/xxx"):
  file_path = os.path.join("D:/xxx", file)
  if os.path.isdir(file_path):
    class_num = class_num + 1
    for sub_file in os.listdir(file_path):
      train_datas.append(os.path.join(file_path, sub_file))

# 数据生成器
training_generator = DataGenerator(train_datas)

#构建网络
model = Sequential()
model.add(Dense(units=64, activation='relu', input_dim=784))
model.add(Dense(units=2, activation='softmax'))
model.compile(loss='categorical_crossentropy',
       optimizer='sgd',
       metrics=['accuracy'])
model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit_generator(training_generator, epochs=50,max_queue_size=10,workers=1)

以上这篇keras使用Sequence类调用大规模数据集进行训练的实现就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现从字符串中找出字符1的位置以及个数的方法
Aug 25 Python
详解MySQL数据类型int(M)中M的含义
Nov 20 Python
Python3中条件控制、循环与函数的简易教程
Nov 21 Python
Python3之文件读写操作的实例讲解
Jan 23 Python
python如何实现int函数的方法示例
Feb 19 Python
解决pip install的时候报错timed out的问题
Jun 12 Python
Python中logging.NullHandler 的使用教程
Nov 29 Python
使用Fabric自动化部署Django项目的实现
Sep 27 Python
pytorch1.0中torch.nn.Conv2d用法详解
Jan 10 Python
opencv中图像叠加/图像融合/按位操作的实现
Apr 01 Python
详解python中[-1]、[:-1]、[::-1]、[n::-1]使用方法
Apr 25 Python
Python中json.load()和json.loads()有哪些区别
Jun 07 Python
Python socket服务常用操作代码实例
Jun 22 #Python
Python如何实现后端自定义认证并实现多条件登陆
Jun 22 #Python
零基础小白多久能学会python
Jun 22 #Python
Keras-多输入多输出实例(多任务)
Jun 22 #Python
python和c语言哪个更适合初学者
Jun 22 #Python
Virtualenv 搭建 Py项目运行环境的教程详解
Jun 22 #Python
终于搞懂了Keras中multiloss的对应关系介绍
Jun 22 #Python
You might like
《魔兽世界》惊魂幻象将获得调整
2020/03/08 其他游戏
使用 php4 加速 web 传输
2006/10/09 PHP
使用php测试硬盘写入速度示例
2014/01/27 PHP
thinkphp使用literal防止模板标签被解析的方法
2014/11/22 PHP
php判断并删除空目录及空子目录的方法
2015/02/11 PHP
PHP使用stream_context_create()模拟POST/GET请求的方法
2016/04/02 PHP
PHP编程快速实现数组去重的方法详解
2017/07/22 PHP
javascript下有关dom以及xml节点访问兼容问题
2007/11/26 Javascript
js面向对象编程之如何实现方法重载
2014/07/02 Javascript
JavaScript设置获取和设置属性的方法
2015/03/04 Javascript
Jquery ajax基础教程
2015/11/20 Javascript
微信小程序 less文件编译成wxss文件实现办法
2016/12/05 Javascript
JS中的phototype详解
2017/02/04 Javascript
jQuery插件HighCharts绘制简单2D折线图效果示例【附demo源码】
2017/03/21 jQuery
Node.js进阶之核心模块https入门
2018/05/23 Javascript
微信小程序城市选择及搜索功能的方法
2019/03/22 Javascript
ZK中使用JS读取客户端txt文件内容问题
2019/11/07 Javascript
[54:53]完美世界DOTA2联赛PWL S2 GXR vs PXG 第二场 11.18
2020/11/18 DOTA
[01:03:50]DOTA2-DPC中国联赛 正赛 CDEC vs DLG BO3 第二场 2月7日
2021/03/11 DOTA
python检测lvs real server状态
2014/01/22 Python
python网络编程之数据传输UDP实例分析
2015/05/20 Python
Python基于PycURL实现POST的方法
2015/07/25 Python
python实现傅里叶级数展开的实现
2018/07/21 Python
总结python中pass的作用
2019/02/27 Python
使用keras实现孪生网络中的权值共享教程
2020/06/11 Python
python3实现飞机大战
2020/11/29 Python
美国最大的存储市场:SpareFoot
2018/07/23 全球购物
抽象方法、抽象类怎样声明
2014/10/25 面试题
办公室人员先进事迹
2014/01/27 职场文书
违反工作纪律检讨书
2014/02/15 职场文书
文秘应聘自荐书范文
2014/02/18 职场文书
人事行政经理岗位职责
2014/06/18 职场文书
小学生暑假安全保证书
2015/07/13 职场文书
2016年端午节红领巾广播稿
2015/12/18 职场文书
幼儿园教师暑期培训心得体会
2016/01/09 职场文书
vue elementUI批量上传文件
2022/04/26 Vue.js