浅谈keras2 predict和fit_generator的坑


Posted in Python onJune 17, 2020

1、使用predict时,必须设置batch_size,否则效率奇低。

查看keras文档中,predict函数原型:

predict(self, x, batch_size=32, verbose=0)

说明:

只使用batch_size=32,也就是说每次将batch_size=32的数据通过PCI总线传到GPU,然后进行预测。在一些问题中,batch_size=32明显是非常小的。而通过PCI传数据是非常耗时的。

所以,使用的时候会发现预测数据时效率奇低,其原因就是batch_size太小了。

经验:

使用predict时,必须人为设置好batch_size,否则PCI总线之间的数据传输次数过多,性能会非常低下。

2、fit_generator

说明:keras 中 fit_generator参数steps_per_epoch已经改变含义了,目前的含义是一个epoch分成多少个batch_size。旧版的含义是一个epoch的样本数目。

如果说训练样本树N=1000,steps_per_epoch = 10,那么相当于一个batch_size=100,如果还是按照旧版来设置,那么相当于

batch_size = 1,会性能非常低。

经验:

必须明确fit_generator参数steps_per_epoch

补充知识:Keras:创建自己的generator(适用于model.fit_generator),解决内存问题

为什么要使用model.fit_generator?

在现实的机器学习中,训练一个model往往需要数量巨大的数据,如果使用fit进行数据训练,很有可能导致内存不够,无法进行训练。

fit_generator的定义如下:

fit_generator(generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False, shuffle=True, initial_epoch=0)

其中各项的具体解释,请参考Keras中文文档

我们重点关注的是generator参数:

generator: 一个生成器,或者一个 Sequence (keras.utils.Sequence) 对象的实例, 以在使用多进程时避免数据的重复。 生成器的输出应该为以下之一:

一个 (inputs, targets) 元组

一个 (inputs, targets, sample_weights) 元组。

那么,问题来了,如何构建这个generator呢?有以下几种办法:

自己创建一个generator生成器

自己定义一个 Sequence (keras.utils.Sequence) 对象

使用Keras自带的ImageDataGenerator和.flow/.flow_from_dataframe/.flow_from_directory来生成一个generator

1.自己创建一个generator生成器

使用Keras自带的ImageDataGenerator和.flow/.flow_from_dataframe/.flow_from_directory 灵活度不高,只有当数据集满足一定格式(例如,按照分类文件夹存放)或者具备一定条件时,使用才使用才较为方便。

此时,自己创建一个generator就很重要了,关于python的generator是什么原理,怎么使用,就不加赘述,可以查看python的基本语法。

此处,我们用yield来返回数据组,标签组,从而使fit_generator可以调用我们的generator来成批处理数据。

具体实现如下:

def myGenerator(batch_size):
    # loading data
    X_train,Y_train=load_data(...)
    
    # data processing
    # ................
    
    total_size=X_train.size
    #batch_size means how many data you want to train one step
    
    while 1:
      for i in range(total_size//batch_size):
        yield x_train[i*batch_size:(i+1)*batch_size], y[i*batch_size:(i+1)*batch_size]
  return myGenerator

接着你可以调用该生成器:

self._model.fit_generator(myGenerator(batch_size),steps_per_epoch=total_size//batch_size, epochs=epoch_num)

以上这篇浅谈keras2 predict和fit_generator的坑就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
下载给定网页上图片的方法
Feb 18 Python
python 读写txt文件 json文件的实现方法
Oct 22 Python
Python深入06——python的内存管理详解
Dec 07 Python
Python中datetime模块参考手册
Jan 13 Python
Python求出0~100以内的所有素数
Jan 23 Python
Python使用folium excel绘制point
Jan 03 Python
Python 使用PyQt5 完成选择文件或目录的对话框方法
Jun 27 Python
python字符串的拼接方法总结
Nov 18 Python
python批量处理多DNS多域名的nslookup解析实现
Jun 28 Python
Python Selenium XPath根据文本内容查找元素的方法
Dec 07 Python
8g内存用python读取10文件_面试题-python 如何读取一个大于 10G 的txt文件?
May 28 Python
python实现手机推送 代码也就10行左右
Apr 12 Python
python能在浏览器能运行吗
Jun 17 #Python
python的pip有什么用
Jun 17 #Python
浅谈keras通过model.fit_generator训练模型(节省内存)
Jun 17 #Python
python用什么编辑器进行项目开发
Jun 17 #Python
在keras中model.fit_generator()和model.fit()的区别说明
Jun 17 #Python
python语言的优势是什么
Jun 17 #Python
python有几个版本
Jun 17 #Python
You might like
安装PHP可能遇到的问题“无法载入mysql扩展” 的解决方法
2007/04/16 PHP
php实现通过cookie换肤的方法
2015/07/13 PHP
php.ini中date.timezone设置详解
2016/11/20 PHP
利用php获得flv视频长度的实例代码
2017/10/26 PHP
PHP+redis实现的悲观锁机制示例
2018/06/12 PHP
优化网页之快速的呈现我们的网页
2007/06/29 Javascript
JavaScript获得表单target属性的方法
2015/04/02 Javascript
JS中from 表单序列化提交的代码
2017/01/20 Javascript
AngularJS实现自定义指令及指令配置项的方法
2017/11/20 Javascript
使用vue制作探探滑动堆叠组件的实例代码
2018/03/07 Javascript
Layer弹出层动态获取数据的方法
2018/08/20 Javascript
小程序数据通信方法大全(推荐)
2019/04/15 Javascript
探索JavaScript中私有成员的相关知识
2019/06/13 Javascript
vue实现打地鼠小游戏
2020/08/21 Javascript
[02:28]DOTA2英雄基础教程 灰烬之灵
2013/12/19 DOTA
[03:02]辉夜杯主赛事第二日 每日之星
2015/12/27 DOTA
Python中字典(dict)和列表(list)的排序方法实例
2014/06/16 Python
python如何在列表、字典中筛选数据
2018/03/19 Python
python+opencv 读取文件夹下的所有图像并批量保存ROI的方法
2019/01/10 Python
python+logging+yaml实现日志分割
2019/07/22 Python
python实现猜数字游戏
2020/03/25 Python
python实现超市管理系统(后台管理)
2019/10/25 Python
Python 实现顺序高斯消元法示例
2019/12/09 Python
TensorBoard 计算图的查看方式
2020/02/15 Python
Django在Model保存前记录日志实例
2020/05/14 Python
台湾流行服饰购物平台:OB严选
2018/01/21 全球购物
eBay澳大利亚站:eBay.com.au
2018/02/02 全球购物
WebSphere面试题:在WebSphere里面如何部署一个应用
2015/08/02 面试题
护理专业毕业生自我鉴定
2013/10/08 职场文书
计算机专业自我鉴定
2013/10/15 职场文书
应用艺术专业个人的自我评价
2014/01/03 职场文书
国庆节演讲稿范文2014
2014/09/19 职场文书
婚前协议书怎么写,才具有法律效力呢 ?
2019/06/28 职场文书
Python Pandas常用函数方法总结
2021/06/15 Python
MySQL修炼之联结与集合浅析
2021/10/05 MySQL
MySQL中varchar和char类型的区别
2021/11/17 MySQL