keras使用Sequence类调用大规模数据集进行训练的实现


Posted in Python onJune 22, 2020

使用Keras如果要使用大规模数据集对网络进行训练,就没办法先加载进内存再从内存直接传到显存了,除了使用Sequence类以外,还可以使用迭代器去生成数据,但迭代器无法在fit_generation里开启多进程,会影响数据的读取和预处理效率,在本文中就不在叙述了,有需要的可以另外去百度。

下面是我所使用的代码

class SequenceData(Sequence):
  def __init__(self, path, batch_size=32):
    self.path = path
    self.batch_size = batch_size
    f = open(path)
    self.datas = f.readlines()
    self.L = len(self.datas)
    self.index = random.sample(range(self.L), self.L)
  #返回长度,通过len(<你的实例>)调用
  def __len__(self):
    return self.L - self.batch_size
  #即通过索引获取a[0],a[1]这种
  def __getitem__(self, idx):
    batch_indexs = self.index[idx:(idx+self.batch_size)]
    batch_datas = [self.datas[k] for k in batch_indexs]
    img1s,img2s,audios,labels = self.data_generation(batch_datas)
    return ({'face1_input_1': img1s, 'face2_input_2': img2s, 'input_3':audios},{'activation_7':labels})

  def data_generation(self, batch_datas):
    #预处理操作
    return img1s,img2s,audios,labels

然后在代码里通过fit_generation函数调用并训练

这里要注意,use_multiprocessing参数是是否开启多进程,由于python的多线程不是真的多线程,所以多进程还是会获得比较客观的加速,但不支持windows,windows下python无法使用多进程。

D = SequenceData('train.csv')
model_train.fit_generator(generator=D,steps_per_epoch=int(len(D)), 
          epochs=2, workers=20, #callbacks=[checkpoint],
          use_multiprocessing=True, validation_data=SequenceData('vali.csv'),validation_steps=int(20000/32))

同样的,也可以在测试的时候使用

model.evaluate_generator(generator=SequenceData('face_test.csv'),steps=int(125100/32),workers=32)

补充知识:keras数据自动生成器,继承keras.utils.Sequence,结合fit_generator实现节约内存训练

我就废话不多说了,大家还是直接看代码吧~

#coding=utf-8
'''
Created on 2018-7-10
'''
import keras
import math
import os
import cv2
import numpy as np
from keras.models import Sequential
from keras.layers import Dense

class DataGenerator(keras.utils.Sequence):
  
  def __init__(self, datas, batch_size=1, shuffle=True):
    self.batch_size = batch_size
    self.datas = datas
    self.indexes = np.arange(len(self.datas))
    self.shuffle = shuffle

  def __len__(self):
    #计算每一个epoch的迭代次数
    return math.ceil(len(self.datas) / float(self.batch_size))

  def __getitem__(self, index):
    #生成每个batch数据,这里就根据自己对数据的读取方式进行发挥了
    # 生成batch_size个索引
    batch_indexs = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
    # 根据索引获取datas集合中的数据
    batch_datas = [self.datas[k] for k in batch_indexs]

    # 生成数据
    X, y = self.data_generation(batch_datas)

    return X, y

  def on_epoch_end(self):
    #在每一次epoch结束是否需要进行一次随机,重新随机一下index
    if self.shuffle == True:
      np.random.shuffle(self.indexes)

  def data_generation(self, batch_datas):
    images = []
    labels = []

    # 生成数据
    for i, data in enumerate(batch_datas):
      #x_train数据
      image = cv2.imread(data)
      image = list(image)
      images.append(image)
      #y_train数据 
      right = data.rfind("\\",0)
      left = data.rfind("\\",0,right)+1
      class_name = data[left:right]
      if class_name=="dog":
        labels.append([0,1])
      else: 
        labels.append([1,0])
    #如果为多输出模型,Y的格式要变一下,外层list格式包裹numpy格式是list[numpy_out1,numpy_out2,numpy_out3]
    return np.array(images), np.array(labels)
  
# 读取样本名称,然后根据样本名称去读取数据
class_num = 0
train_datas = [] 
for file in os.listdir("D:/xxx"):
  file_path = os.path.join("D:/xxx", file)
  if os.path.isdir(file_path):
    class_num = class_num + 1
    for sub_file in os.listdir(file_path):
      train_datas.append(os.path.join(file_path, sub_file))

# 数据生成器
training_generator = DataGenerator(train_datas)

#构建网络
model = Sequential()
model.add(Dense(units=64, activation='relu', input_dim=784))
model.add(Dense(units=2, activation='softmax'))
model.compile(loss='categorical_crossentropy',
       optimizer='sgd',
       metrics=['accuracy'])
model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit_generator(training_generator, epochs=50,max_queue_size=10,workers=1)

以上这篇keras使用Sequence类调用大规模数据集进行训练的实现就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现的udp协议Server和Client代码实例
Jun 04 Python
Python3中的列表,元组,字典,字符串相关知识小结
Nov 10 Python
使用python实现knn算法
Dec 20 Python
Python模拟脉冲星伪信号频率实例代码
Jan 03 Python
tensorflow实现图像的裁剪和填充方法
Jul 27 Python
面向初学者的Python编辑器Mu
Oct 08 Python
python 通过SSHTunnelForwarder隧道连接redis的方法
Feb 19 Python
Django之创建引擎索引报错及解决详解
Jul 17 Python
基于python操作ES实例详解
Nov 16 Python
python全局变量引用与修改过程解析
Jan 07 Python
详解Python的三种拷贝方式
Feb 11 Python
python 实现体质指数BMI计算
May 26 Python
Python socket服务常用操作代码实例
Jun 22 #Python
Python如何实现后端自定义认证并实现多条件登陆
Jun 22 #Python
零基础小白多久能学会python
Jun 22 #Python
Keras-多输入多输出实例(多任务)
Jun 22 #Python
python和c语言哪个更适合初学者
Jun 22 #Python
Virtualenv 搭建 Py项目运行环境的教程详解
Jun 22 #Python
终于搞懂了Keras中multiloss的对应关系介绍
Jun 22 #Python
You might like
php4的session功能评述(二)
2006/10/09 PHP
屏蔽浏览器缓存另类方法
2006/10/09 PHP
全世界最小的php网页木马一枚 附PHP木马的防范方法
2009/10/09 PHP
PHP 字符串正则替换函数preg_replace使用说明
2011/07/15 PHP
md5 16位二进制与32位字符串相互转换示例
2013/12/30 PHP
PHP单元测试框架PHPUnit用法详解
2019/01/23 PHP
PHP反射基础知识回顾
2020/09/10 PHP
iis6+javascript Add an Extension File
2007/06/13 Javascript
jQuery 常见开发使用技巧总结
2009/12/26 Javascript
javaScript(JS)替换节点实现思路介绍
2013/04/17 Javascript
深入理解javascript原型链和继承
2014/09/23 Javascript
使用RequireJS优化JavaScript引用代码的方法
2015/07/01 Javascript
JavaScript中style.left与offsetLeft的使用及区别详解
2016/06/08 Javascript
微信js-sdk界面操作接口用法示例
2016/10/12 Javascript
Vue.js绑定HTML class数组语法错误的原因分析
2016/10/19 Javascript
vue 如何添加全局函数或全局变量以及单页面的title设置总结
2017/06/01 Javascript
angular2+node.js express打包部署的实战
2017/07/27 Javascript
详解webpack多页面配置记录
2018/01/22 Javascript
webpack4 SCSS提取和懒加载的示例
2018/09/03 Javascript
Vue如何实现验证码输入交互
2020/12/07 Vue.js
Python实现的多线程端口扫描工具分享
2015/01/21 Python
Python实现批量检测HTTP服务的状态
2016/10/27 Python
Pyorch之numpy与torch之间相互转换方式
2019/12/31 Python
更新升级python和pip版本后不生效的问题解决
2020/04/17 Python
python中str内置函数用法总结
2020/12/27 Python
Qoo10台湾站:亚洲领先的在线市场
2018/05/15 全球购物
施华洛世奇意大利官网:SWAROVSKI意大利
2018/07/23 全球购物
先进工作者获奖感言
2014/02/08 职场文书
《可爱的动物》教学反思
2014/02/22 职场文书
幼儿园春季开学寄语
2014/04/03 职场文书
拒绝黄毒毒宣传标语
2014/06/26 职场文书
上下班时间调整通知
2015/04/23 职场文书
投资申请报告
2015/05/19 职场文书
党支部培养考察意见
2015/06/02 职场文书
《风娃娃》教学反思
2016/02/18 职场文书
少年的你:世界上没有如果,要在第一次就勇敢的反抗
2019/11/20 职场文书