浅谈keras通过model.fit_generator训练模型(节省内存)


Posted in Python onJune 17, 2020

前言

前段时间在训练模型的时候,发现当训练集的数量过大,并且输入的图片维度过大时,很容易就超内存了,举个简单例子,如果我们有20000个样本,输入图片的维度是224x224x3,用float32存储,那么如果我们一次性将全部数据载入内存的话,总共就需要20000x224x224x3x32bit/8=11.2GB 这么大的内存,所以如果一次性要加载全部数据集的话是需要很大内存的。

如果我们直接用keras的fit函数来训练模型的话,是需要传入全部训练数据,但是好在提供了fit_generator,可以分批次的读取数据,节省了我们的内存,我们唯一要做的就是实现一个生成器(generator)。

1.fit_generator函数简介

fit_generator(generator, 
steps_per_epoch=None, 
epochs=1, 
verbose=1, 
callbacks=None, 
validation_data=None, 
validation_steps=None, 
class_weight=None, 
max_queue_size=10, 
workers=1, 
use_multiprocessing=False, 
shuffle=True, 
initial_epoch=0)

参数:

generator:一个生成器,或者一个 Sequence (keras.utils.Sequence) 对象的实例。这是我们实现的重点,后面会着介绍生成器和sequence的两种实现方式。

steps_per_epoch:这个是我们在每个epoch中需要执行多少次生成器来生产数据,fit_generator函数没有batch_size这个参数,是通过steps_per_epoch来实现的,每次生产的数据就是一个batch,因此steps_per_epoch的值我们通过会设为(样本数/batch_size)。如果我们的generator是sequence类型,那么这个参数是可选的,默认使用len(generator) 。

epochs:即我们训练的迭代次数。

verbose:0, 1 或 2。日志显示模式。 0 = 安静模式, 1 = 进度条, 2 = 每轮一行

callbacks:在训练时调用的一系列回调函数。

validation_data:和我们的generator类似,只是这个使用于验证的,不参与训练。

validation_steps:和前面的steps_per_epoch类似。

class_weight:可选的将类索引(整数)映射到权重(浮点)值的字典,用于加权损失函数(仅在训练期间)。 这可以用来告诉模型「更多地关注」来自代表性不足的类的样本。(感觉这个参数用的比较少)

max_queue_size:整数。生成器队列的最大尺寸。默认为10.

workers:整数。使用的最大进程数量,如果使用基于进程的多线程。 如未指定,workers 将默认为 1。如果为 0,将在主线程上执行生成器。

use_multiprocessing:布尔值。如果 True,则使用基于进程的多线程。默认为False。

shuffle:是否在每轮迭代之前打乱 batch 的顺序。 只能与Sequence(keras.utils.Sequence) 实例同用。

initial_epoch: 开始训练的轮次(有助于恢复之前的训练)

2.generator实现

2.1生成器的实现方式

样例代码:

import keras
from keras.models import Sequential
from keras.layers import Dense
import numpy as np
from sklearn.model_selection import train_test_split
from PIL import Image

def process_x(path):
 img = Image.open(path)
 img = img.resize((96,96))
 img = img.convert('RGB')
 img = np.array(img)

 img = np.asarray(img, np.float32) / 255.0
 #也可以进行进行一些数据数据增强的处理
 return img

count =1
def generate_arrays_from_file(x_y):
 #x_y 是我们的训练集包括标签,每一行的第一个是我们的图片路径,后面的是我们的独热化后的标签

 global count
 batch_size = 8
 while 1:
  batch_x = x_y[(count - 1) * batch_size:count * batch_size, 0]
  batch_y = x_y[(count - 1) * batch_size:count * batch_size, 1:]

  batch_x = np.array([process_x(img_path) for img_path in batch_x])
  batch_y = np.array(batch_y).astype(np.float32)
  print("count:"+str(count))
  count = count+1
  yield (batch_x, batch_y)

model = Sequential()
model.add(Dense(units=1000, activation='relu', input_dim=2))
model.add(Dense(units=2, activation='softmax'))
model.compile(loss='categorical_crossentropy',optimizer='sgd',metrics=['accuracy'])

x_y = []
model.fit_generator(generate_arrays_from_file(x_y),steps_per_epoch=10, epochs=2,max_queue_size=1,workers=1)

在理解上面代码之前我们需要首先了解yield的用法。

yield关键字:

我们先通过一个例子看一下yield的用法:

def foo():
 print("starting...")
 while True:
  res = yield 4
  print("res:",res)
g = foo()
print(next(g))
print("----------")
print(next(g))

运行结果:

starting...
4
----------
res: None
4

带yield的函数是一个生成器,而不是一个函数。因为foo函数中有yield关键字,所以foo函数并不会真的执行,而是先得到一个生成器的实例,当我们第一次调用next函数的时候,foo函数才开始行,首先先执行foo函数中的print方法,然后进入while循环,循环执行到yield时,yield其实相当于return,函数返回4,程序停止。所以我们第一次调用next(g)的输出结果是前面两行。

然后当我们再次调用next(g)时,这个时候是从上一次停止的地方继续执行,也就是要执行res的赋值操作,因为4已经在上一次执行被return了,随意赋值res为None,然后执行print(“res:”,res)打印res: None,再次循环到yield返回4,程序停止。

所以yield关键字的作用就是我们能够从上一次程序停止的地方继续执行,这样我们用作生成器的时候,就避免一次性读入数据造成内存不足的情况。

现在看到上面的示例代码:

generate_arrays_from_file函数就是我们的生成器,每次循环读取一个batch大小的数据,然后处理数据,并返回。x_y是我们的把路径和标签合并后的训练集,类似于如下形式:

['data/img\\fimg_4092.jpg' '0' '1' '0' '0' '0' ]

至于格式不一定要这样,可以是自己的格式,至于怎么处理,根于自己的格式,在process_x进行处理,这里因为是存放的图片路径,所以在process_x函数的主要作用就是读取图片并进行归一化等操作,也可以在这里定义自己需要进行的操作,例如对图像进行实时数据增强。

2.2使用Sequence实现generator

示例代码:

class BaseSequence(Sequence):
 """
 基础的数据流生成器,每次迭代返回一个batch
 BaseSequence可直接用于fit_generator的generator参数
 fit_generator会将BaseSequence再次封装为一个多进程的数据流生成器
 而且能保证在多进程下的一个epoch中不会重复取相同的样本
 """
 def __init__(self, img_paths, labels, batch_size, img_size):
  #np.hstack在水平方向上平铺
  self.x_y = np.hstack((np.array(img_paths).reshape(len(img_paths), 1), np.array(labels)))
  self.batch_size = batch_size
  self.img_size = img_size

 def __len__(self):
  #math.ceil表示向上取整
  #调用len(BaseSequence)时返回,返回的是每个epoch我们需要读取数据的次数
  return math.ceil(len(self.x_y) / self.batch_size)

 def preprocess_img(self, img_path):

  img = Image.open(img_path)
  resize_scale = self.img_size[0] / max(img.size[:2])
  img = img.resize((self.img_size[0], self.img_size[0]))
  img = img.convert('RGB')
  img = np.array(img)

  # 数据归一化
  img = np.asarray(img, np.float32) / 255.0
  return img

 def __getitem__(self, idx):
  batch_x = self.x_y[idx * self.batch_size: (idx + 1) * self.batch_size, 0]
  batch_y = self.x_y[idx * self.batch_size: (idx + 1) * self.batch_size, 1:]
  batch_x = np.array([self.preprocess_img(img_path) for img_path in batch_x])
  batch_y = np.array(batch_y).astype(np.float32)
  print(batch_x.shape)
  return batch_x, batch_y
 #重写的父类Sequence中的on_epoch_end方法,在每次迭代完后调用。
 def on_epoch_end(self):
  #每次迭代后重新打乱训练集数据
  np.random.shuffle(self.x_y)

在上面代码中,__len __和__getitem __,是我们重写的魔法方法,__len __是当我们调用len(BaseSequence)函数时调用,这里我们返回(样本总量/batch_size),供我们传入fit_generator中的steps_per_epoch参数;__getitem __可以让对象实现迭代功能,这样在将BaseSequence的对象传入fit_generator中后,不断执行generator就可循环的读取数据了。

举个例子说明一下getitem的作用:

class Animal:
 def __init__(self, animal_list):
  self.animals_name = animal_list

 def __getitem__(self, index):
  return self.animals_name[index]

animals = Animal(["dog","cat","fish"])
for animal in animals:
 print(animal)

输出结果:

dog
cat
fish

并且使用Sequence类可以保证在多进程的情况下,每个epoch中的样本只会被训练一次。

以上这篇浅谈keras通过model.fit_generator训练模型(节省内存)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python代码检查工具pylint 让你的python更规范
Sep 05 Python
从零学Python之入门(四)运算
May 27 Python
Python实现获取网站PR及百度权重
Jan 21 Python
Python实现统计英文单词个数及字符串分割代码
May 28 Python
浅谈django中的认证与登录
Oct 31 Python
python虚拟环境virtualenv的使用教程
Oct 20 Python
如何优雅地改进Django中的模板碎片缓存详解
Jul 04 Python
python实现自动化上线脚本的示例
Jul 01 Python
python3实现的zip格式压缩文件夹操作示例
Aug 17 Python
详解Python并发编程之创建多线程的几种方法
Aug 23 Python
python 操作hive pyhs2方式
Dec 21 Python
浅谈Pytorch torch.optim优化器个性化的使用
Feb 20 Python
python用什么编辑器进行项目开发
Jun 17 #Python
在keras中model.fit_generator()和model.fit()的区别说明
Jun 17 #Python
python语言的优势是什么
Jun 17 #Python
python有几个版本
Jun 17 #Python
python实例化对象的具体方法
Jun 17 #Python
python和php学习哪个更有发展
Jun 17 #Python
python中线程和进程有何区别
Jun 17 #Python
You might like
最省空间的计数器
2006/10/09 PHP
一个PHP+MSSQL分页的例子
2006/10/09 PHP
Smarty的配置与高级缓存技术分享
2012/06/05 PHP
PHP实现PDO的mysql数据库操作类
2014/12/12 PHP
PHP抓取网页、解析HTML常用的方法总结
2015/07/01 PHP
大家须知简单的php性能优化注意点
2016/01/04 PHP
PHP基于单例模式编写PDO类的方法
2016/09/13 PHP
php制作基于xml的RSS订阅源功能示例
2017/02/08 PHP
PHP设计模式之装饰器模式实例详解
2018/02/07 PHP
PHP dirname(__FILE__)原理及用法解析
2020/10/28 PHP
js 效率组装字符串 StringBuffer
2009/12/23 Javascript
实例详解jQuery结合GridView控件的使用方法
2016/01/04 Javascript
浅谈bootstrap源码分析之scrollspy(滚动侦听)
2016/06/06 Javascript
浅谈JS和jQuery的区别
2019/03/27 jQuery
Flutter部件内部状态管理小结之实现Vue的v-model功能
2019/06/11 Javascript
在Node.js中将SVG图像转换为PNG,JPEG,TIFF,WEBP和HEIF格式的方法
2019/08/22 Javascript
js通过循环多张图片实现动画效果
2019/12/19 Javascript
11个Javascript小技巧帮你提升代码质量(小结)
2020/12/28 Javascript
vue-router路由懒加载及实现的3种方式
2021/02/28 Vue.js
python实现的登陆Discuz!论坛通用代码分享
2014/07/11 Python
Python中变量交换的例子
2014/08/25 Python
Python中使用ElementTree解析XML示例
2015/06/02 Python
Python使用plotly绘制数据图表的方法
2017/07/18 Python
python截取两个单词之间的内容方法
2018/12/25 Python
Python这样操作能存储100多万行的xlsx文件
2019/04/16 Python
关于Python3 lambda函数的深入浅出
2019/11/27 Python
PyCharm License Activation激活码失效问题的解决方法(图文详解)
2020/03/12 Python
python 图像判断,清晰度(明暗),彩色与黑白实例
2020/06/04 Python
Python使用Selenium实现淘宝抢单的流程分析
2020/06/23 Python
anaconda3安装及jupyter环境配置全教程
2020/08/24 Python
HTML5 Canvas中使用用路径描画圆弧
2015/01/01 HTML / CSS
Guess欧洲官网:美国服饰品牌
2019/08/06 全球购物
意大利和国际奢侈品牌购物网站:Suitnegozi.com
2021/01/15 全球购物
材料加工硕士生求职信
2013/10/10 职场文书
打架检讨书500字
2014/01/29 职场文书
借款民事起诉状范文
2015/05/19 职场文书