浅谈keras通过model.fit_generator训练模型(节省内存)


Posted in Python onJune 17, 2020

前言

前段时间在训练模型的时候,发现当训练集的数量过大,并且输入的图片维度过大时,很容易就超内存了,举个简单例子,如果我们有20000个样本,输入图片的维度是224x224x3,用float32存储,那么如果我们一次性将全部数据载入内存的话,总共就需要20000x224x224x3x32bit/8=11.2GB 这么大的内存,所以如果一次性要加载全部数据集的话是需要很大内存的。

如果我们直接用keras的fit函数来训练模型的话,是需要传入全部训练数据,但是好在提供了fit_generator,可以分批次的读取数据,节省了我们的内存,我们唯一要做的就是实现一个生成器(generator)。

1.fit_generator函数简介

fit_generator(generator, 
steps_per_epoch=None, 
epochs=1, 
verbose=1, 
callbacks=None, 
validation_data=None, 
validation_steps=None, 
class_weight=None, 
max_queue_size=10, 
workers=1, 
use_multiprocessing=False, 
shuffle=True, 
initial_epoch=0)

参数:

generator:一个生成器,或者一个 Sequence (keras.utils.Sequence) 对象的实例。这是我们实现的重点,后面会着介绍生成器和sequence的两种实现方式。

steps_per_epoch:这个是我们在每个epoch中需要执行多少次生成器来生产数据,fit_generator函数没有batch_size这个参数,是通过steps_per_epoch来实现的,每次生产的数据就是一个batch,因此steps_per_epoch的值我们通过会设为(样本数/batch_size)。如果我们的generator是sequence类型,那么这个参数是可选的,默认使用len(generator) 。

epochs:即我们训练的迭代次数。

verbose:0, 1 或 2。日志显示模式。 0 = 安静模式, 1 = 进度条, 2 = 每轮一行

callbacks:在训练时调用的一系列回调函数。

validation_data:和我们的generator类似,只是这个使用于验证的,不参与训练。

validation_steps:和前面的steps_per_epoch类似。

class_weight:可选的将类索引(整数)映射到权重(浮点)值的字典,用于加权损失函数(仅在训练期间)。 这可以用来告诉模型「更多地关注」来自代表性不足的类的样本。(感觉这个参数用的比较少)

max_queue_size:整数。生成器队列的最大尺寸。默认为10.

workers:整数。使用的最大进程数量,如果使用基于进程的多线程。 如未指定,workers 将默认为 1。如果为 0,将在主线程上执行生成器。

use_multiprocessing:布尔值。如果 True,则使用基于进程的多线程。默认为False。

shuffle:是否在每轮迭代之前打乱 batch 的顺序。 只能与Sequence(keras.utils.Sequence) 实例同用。

initial_epoch: 开始训练的轮次(有助于恢复之前的训练)

2.generator实现

2.1生成器的实现方式

样例代码:

import keras
from keras.models import Sequential
from keras.layers import Dense
import numpy as np
from sklearn.model_selection import train_test_split
from PIL import Image

def process_x(path):
 img = Image.open(path)
 img = img.resize((96,96))
 img = img.convert('RGB')
 img = np.array(img)

 img = np.asarray(img, np.float32) / 255.0
 #也可以进行进行一些数据数据增强的处理
 return img

count =1
def generate_arrays_from_file(x_y):
 #x_y 是我们的训练集包括标签,每一行的第一个是我们的图片路径,后面的是我们的独热化后的标签

 global count
 batch_size = 8
 while 1:
  batch_x = x_y[(count - 1) * batch_size:count * batch_size, 0]
  batch_y = x_y[(count - 1) * batch_size:count * batch_size, 1:]

  batch_x = np.array([process_x(img_path) for img_path in batch_x])
  batch_y = np.array(batch_y).astype(np.float32)
  print("count:"+str(count))
  count = count+1
  yield (batch_x, batch_y)

model = Sequential()
model.add(Dense(units=1000, activation='relu', input_dim=2))
model.add(Dense(units=2, activation='softmax'))
model.compile(loss='categorical_crossentropy',optimizer='sgd',metrics=['accuracy'])

x_y = []
model.fit_generator(generate_arrays_from_file(x_y),steps_per_epoch=10, epochs=2,max_queue_size=1,workers=1)

在理解上面代码之前我们需要首先了解yield的用法。

yield关键字:

我们先通过一个例子看一下yield的用法:

def foo():
 print("starting...")
 while True:
  res = yield 4
  print("res:",res)
g = foo()
print(next(g))
print("----------")
print(next(g))

运行结果:

starting...
4
----------
res: None
4

带yield的函数是一个生成器,而不是一个函数。因为foo函数中有yield关键字,所以foo函数并不会真的执行,而是先得到一个生成器的实例,当我们第一次调用next函数的时候,foo函数才开始行,首先先执行foo函数中的print方法,然后进入while循环,循环执行到yield时,yield其实相当于return,函数返回4,程序停止。所以我们第一次调用next(g)的输出结果是前面两行。

然后当我们再次调用next(g)时,这个时候是从上一次停止的地方继续执行,也就是要执行res的赋值操作,因为4已经在上一次执行被return了,随意赋值res为None,然后执行print(“res:”,res)打印res: None,再次循环到yield返回4,程序停止。

所以yield关键字的作用就是我们能够从上一次程序停止的地方继续执行,这样我们用作生成器的时候,就避免一次性读入数据造成内存不足的情况。

现在看到上面的示例代码:

generate_arrays_from_file函数就是我们的生成器,每次循环读取一个batch大小的数据,然后处理数据,并返回。x_y是我们的把路径和标签合并后的训练集,类似于如下形式:

['data/img\\fimg_4092.jpg' '0' '1' '0' '0' '0' ]

至于格式不一定要这样,可以是自己的格式,至于怎么处理,根于自己的格式,在process_x进行处理,这里因为是存放的图片路径,所以在process_x函数的主要作用就是读取图片并进行归一化等操作,也可以在这里定义自己需要进行的操作,例如对图像进行实时数据增强。

2.2使用Sequence实现generator

示例代码:

class BaseSequence(Sequence):
 """
 基础的数据流生成器,每次迭代返回一个batch
 BaseSequence可直接用于fit_generator的generator参数
 fit_generator会将BaseSequence再次封装为一个多进程的数据流生成器
 而且能保证在多进程下的一个epoch中不会重复取相同的样本
 """
 def __init__(self, img_paths, labels, batch_size, img_size):
  #np.hstack在水平方向上平铺
  self.x_y = np.hstack((np.array(img_paths).reshape(len(img_paths), 1), np.array(labels)))
  self.batch_size = batch_size
  self.img_size = img_size

 def __len__(self):
  #math.ceil表示向上取整
  #调用len(BaseSequence)时返回,返回的是每个epoch我们需要读取数据的次数
  return math.ceil(len(self.x_y) / self.batch_size)

 def preprocess_img(self, img_path):

  img = Image.open(img_path)
  resize_scale = self.img_size[0] / max(img.size[:2])
  img = img.resize((self.img_size[0], self.img_size[0]))
  img = img.convert('RGB')
  img = np.array(img)

  # 数据归一化
  img = np.asarray(img, np.float32) / 255.0
  return img

 def __getitem__(self, idx):
  batch_x = self.x_y[idx * self.batch_size: (idx + 1) * self.batch_size, 0]
  batch_y = self.x_y[idx * self.batch_size: (idx + 1) * self.batch_size, 1:]
  batch_x = np.array([self.preprocess_img(img_path) for img_path in batch_x])
  batch_y = np.array(batch_y).astype(np.float32)
  print(batch_x.shape)
  return batch_x, batch_y
 #重写的父类Sequence中的on_epoch_end方法,在每次迭代完后调用。
 def on_epoch_end(self):
  #每次迭代后重新打乱训练集数据
  np.random.shuffle(self.x_y)

在上面代码中,__len __和__getitem __,是我们重写的魔法方法,__len __是当我们调用len(BaseSequence)函数时调用,这里我们返回(样本总量/batch_size),供我们传入fit_generator中的steps_per_epoch参数;__getitem __可以让对象实现迭代功能,这样在将BaseSequence的对象传入fit_generator中后,不断执行generator就可循环的读取数据了。

举个例子说明一下getitem的作用:

class Animal:
 def __init__(self, animal_list):
  self.animals_name = animal_list

 def __getitem__(self, index):
  return self.animals_name[index]

animals = Animal(["dog","cat","fish"])
for animal in animals:
 print(animal)

输出结果:

dog
cat
fish

并且使用Sequence类可以保证在多进程的情况下,每个epoch中的样本只会被训练一次。

以上这篇浅谈keras通过model.fit_generator训练模型(节省内存)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python编写Linux系统守护进程实例
Feb 03 Python
python多进程和多线程究竟谁更快(详解)
May 29 Python
pandas 获取季度,月度,年度首尾日期的方法
Apr 11 Python
Python实现确认字符串是否包含指定字符串的实例
May 02 Python
python使用PIL实现多张图片垂直合并
Jan 15 Python
Django Celery异步任务队列的实现
Jul 24 Python
Python中zip()函数的简单用法举例
Sep 02 Python
python多线程案例之多任务copy文件完整实例
Oct 29 Python
tensorflow 实现数据类型转换
Feb 17 Python
浅谈Pytorch中的自动求导函数backward()所需参数的含义
Feb 29 Python
python用分数表示矩阵的方法实例
Jan 11 Python
python​格式化字符串
Apr 20 Python
python用什么编辑器进行项目开发
Jun 17 #Python
在keras中model.fit_generator()和model.fit()的区别说明
Jun 17 #Python
python语言的优势是什么
Jun 17 #Python
python有几个版本
Jun 17 #Python
python实例化对象的具体方法
Jun 17 #Python
python和php学习哪个更有发展
Jun 17 #Python
python中线程和进程有何区别
Jun 17 #Python
You might like
笑谈配置,使用Smarty技术
2007/01/04 PHP
详谈php静态方法及普通方法的区别
2016/10/04 PHP
页面只有一个text的时候,回车自动submit的解决方法
2010/08/12 Javascript
jQuery实现数字加减效果汇总
2014/12/16 Javascript
PHP中CURL的几个经典应用实例
2015/01/23 Javascript
js生成随机数的过程解析
2015/11/24 Javascript
js实现仿微博滚动显示信息的效果
2015/12/21 Javascript
ArtEditor富文本编辑器增加表单提交功能
2016/04/18 Javascript
使用BootStrap实现表格隔行变色及hover变色并在需要时出现滚动条
2017/01/04 Javascript
基于构造函数的五种继承方法小结
2017/07/27 Javascript
JS实现手写parseInt的方法示例
2017/09/24 Javascript
Vue中正确使用jQuery的方法
2017/10/30 jQuery
使用vue-route 的 beforeEach 实现导航守卫(路由跳转前验证登录)功能
2018/03/22 Javascript
JS与jQuery实现ListBox上移,下移,左移,右移操作功能示例
2018/05/31 jQuery
jQuery实现雪花飘落效果
2020/08/02 jQuery
python使用新浪微博api上传图片到微博示例
2014/01/10 Python
python中使用OpenCV进行人脸检测的例子
2014/04/18 Python
跟老齐学Python之正规地说一句话
2014/09/28 Python
Python遍历指定文件及文件夹的方法
2015/05/09 Python
Django数据库连接丢失问题的解决方法
2018/12/29 Python
手把手教你pycharm专业版安装破解教程(linux版)
2019/09/26 Python
python中seaborn包常用图形使用详解
2019/11/25 Python
解决Pytorch自定义层出现多Variable共享内存错误问题
2020/06/28 Python
用sleep间隔进行python反爬虫的实例讲解
2020/11/30 Python
HTML5 画布canvas使用方法
2016/03/18 HTML / CSS
html2 canvas生成清晰的图片实现打印功能
2019/09/23 HTML / CSS
政府法律服务方案
2014/06/14 职场文书
小学德育工作总结2015
2015/05/12 职场文书
实习证明模板
2015/06/16 职场文书
领导离职感言
2015/08/03 职场文书
2016年入党心得体会范文
2016/01/23 职场文书
2016年社区服务活动总结
2016/04/06 职场文书
详解PHP设计模式之依赖注入模式
2021/05/25 PHP
Python基础教程,Python入门教程(超详细)
2021/06/24 Python
python基础之类属性和实例属性
2021/10/24 Python
Android超详细讲解组件ScrollView的使用
2022/03/31 Java/Android