浅谈keras2 predict和fit_generator的坑


Posted in Python onJune 17, 2020

1、使用predict时,必须设置batch_size,否则效率奇低。

查看keras文档中,predict函数原型:

predict(self, x, batch_size=32, verbose=0)

说明:

只使用batch_size=32,也就是说每次将batch_size=32的数据通过PCI总线传到GPU,然后进行预测。在一些问题中,batch_size=32明显是非常小的。而通过PCI传数据是非常耗时的。

所以,使用的时候会发现预测数据时效率奇低,其原因就是batch_size太小了。

经验:

使用predict时,必须人为设置好batch_size,否则PCI总线之间的数据传输次数过多,性能会非常低下。

2、fit_generator

说明:keras 中 fit_generator参数steps_per_epoch已经改变含义了,目前的含义是一个epoch分成多少个batch_size。旧版的含义是一个epoch的样本数目。

如果说训练样本树N=1000,steps_per_epoch = 10,那么相当于一个batch_size=100,如果还是按照旧版来设置,那么相当于

batch_size = 1,会性能非常低。

经验:

必须明确fit_generator参数steps_per_epoch

补充知识:Keras:创建自己的generator(适用于model.fit_generator),解决内存问题

为什么要使用model.fit_generator?

在现实的机器学习中,训练一个model往往需要数量巨大的数据,如果使用fit进行数据训练,很有可能导致内存不够,无法进行训练。

fit_generator的定义如下:

fit_generator(generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False, shuffle=True, initial_epoch=0)

其中各项的具体解释,请参考Keras中文文档

我们重点关注的是generator参数:

generator: 一个生成器,或者一个 Sequence (keras.utils.Sequence) 对象的实例, 以在使用多进程时避免数据的重复。 生成器的输出应该为以下之一:

一个 (inputs, targets) 元组

一个 (inputs, targets, sample_weights) 元组。

那么,问题来了,如何构建这个generator呢?有以下几种办法:

自己创建一个generator生成器

自己定义一个 Sequence (keras.utils.Sequence) 对象

使用Keras自带的ImageDataGenerator和.flow/.flow_from_dataframe/.flow_from_directory来生成一个generator

1.自己创建一个generator生成器

使用Keras自带的ImageDataGenerator和.flow/.flow_from_dataframe/.flow_from_directory 灵活度不高,只有当数据集满足一定格式(例如,按照分类文件夹存放)或者具备一定条件时,使用才使用才较为方便。

此时,自己创建一个generator就很重要了,关于python的generator是什么原理,怎么使用,就不加赘述,可以查看python的基本语法。

此处,我们用yield来返回数据组,标签组,从而使fit_generator可以调用我们的generator来成批处理数据。

具体实现如下:

def myGenerator(batch_size):
    # loading data
    X_train,Y_train=load_data(...)
    
    # data processing
    # ................
    
    total_size=X_train.size
    #batch_size means how many data you want to train one step
    
    while 1:
      for i in range(total_size//batch_size):
        yield x_train[i*batch_size:(i+1)*batch_size], y[i*batch_size:(i+1)*batch_size]
  return myGenerator

接着你可以调用该生成器:

self._model.fit_generator(myGenerator(batch_size),steps_per_epoch=total_size//batch_size, epochs=epoch_num)

以上这篇浅谈keras2 predict和fit_generator的坑就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python编写简单的端口扫描器的实例分享
Dec 18 Python
Python字符编码判断方法分析
Jul 01 Python
python文件特定行插入和替换实例详解
Jul 12 Python
python机器学习理论与实战(二)决策树
Jan 19 Python
如何利用python查找电脑文件
Apr 27 Python
Python字符串、整数、和浮点型数相互转换实例
Aug 04 Python
深入理解python中sort()与sorted()的区别
Aug 29 Python
基于python if 判断选择结构的实例详解
May 06 Python
python时间与Unix时间戳相互转换方法详解
Feb 13 Python
python+playwright微软自动化工具的使用
Feb 02 Python
在python3.9下如何安装scrapy的方法
Feb 03 Python
Python软件包安装的三种常见方法
Jul 07 Python
python能在浏览器能运行吗
Jun 17 #Python
python的pip有什么用
Jun 17 #Python
浅谈keras通过model.fit_generator训练模型(节省内存)
Jun 17 #Python
python用什么编辑器进行项目开发
Jun 17 #Python
在keras中model.fit_generator()和model.fit()的区别说明
Jun 17 #Python
python语言的优势是什么
Jun 17 #Python
python有几个版本
Jun 17 #Python
You might like
实现php加速的eAccelerator dll支持文件打包下载
2007/09/30 PHP
为IP查询添加GOOGLE地图功能的代码
2010/08/08 PHP
php与java通过socket通信的实现代码
2013/10/21 PHP
Thinkphp自定义代码生成工具及用法说明(附下载地址)
2016/05/27 PHP
php图片添加水印例子
2016/07/20 PHP
thinkPHP5框架中widget的功能与用法详解
2018/06/11 PHP
js类后台管理菜单类-MenuSwitch
2007/09/12 Javascript
Chosen 基于jquery的选择框插件使用方法
2012/05/30 Javascript
目前流行的JavaScript库的介绍及对比
2013/09/29 Javascript
简单的ajax连接库分享(不用jquery的ajax)
2014/01/19 Javascript
jQuery插件EasyUI获取当前Tab中iframe窗体对象的方法
2016/08/05 Javascript
滚动条的监听与内容随着滚动条动态加载的实现
2017/02/08 Javascript
jQuery实现搜索页面关键字的功能
2017/02/16 Javascript
浅谈js函数三种定义方式 & 四种调用方式 & 调用顺序
2017/02/19 Javascript
Thinkphp5微信小程序获取用户信息接口的实例详解
2017/09/26 Javascript
jQuery实现判断滚动条滚动到document底部的方法分析
2019/08/27 jQuery
Python获取单个程序CPU使用情况趋势图
2015/03/10 Python
Python实现曲线点抽稀算法的示例
2017/10/12 Python
Python机器学习之决策树算法
2017/12/22 Python
Python3中详解fabfile的编写
2018/06/24 Python
修改默认的pip版本为对应python2.7的方法
2018/11/06 Python
python实现银联支付和支付宝支付接入
2019/05/07 Python
python读取图片的方式,以及将图片以三维数组的形式输出方法
2019/07/03 Python
numpy 返回函数的上三角矩阵实例
2019/11/25 Python
Pandas把dataframe或series转换成list的方法
2020/06/14 Python
HTML5离线缓存在tomcat下部署可实现图片flash等离线浏览
2012/12/13 HTML / CSS
介绍一下JMS编程步骤
2015/09/22 面试题
《小儿垂钓》教学反思
2014/02/23 职场文书
党支部反对四风思想汇报
2014/10/10 职场文书
2015年发展党员工作总结报告
2015/03/31 职场文书
实施意见格式范本
2015/06/05 职场文书
雷锋电影观后感
2015/06/10 职场文书
法人代表资格证明书
2015/06/18 职场文书
MySQL8.0的WITH查询详情
2021/08/30 MySQL
小程序实现悬浮按钮的全过程记录
2021/10/16 HTML / CSS
Go语言grpc和protobuf
2022/04/13 Golang