keras和tensorflow使用fit_generator 批次训练操作


Posted in Python onJuly 03, 2020

fit_generator 是 keras 提供的用来进行批次训练的函数,使用方法如下:

model.fit_generator(generator, steps_per_epoch=None, epochs=1,
    verbose=1, callbacks=None, validation_data=None, validation_steps=None,
    class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False,
    shuffle=True, initial_epoch=0)

参数说明:

generator: 一个生成器,或者一个 Sequence (keras.utils.Sequence) 对象的实例, 以在使用多进程时避免数据的重复。 生成器的输出应该为以下之一:

一个(inputs, targets) 元组

一个 (inputs, targets, sample_weights) 元组。

这个元组(生成器的单个输出)组成了单个的 batch。 因此,这个元组中的所有数组长度必须相同(与这一个 batch 的大小相等)。 不同的 batch 可能大小不同。 例如,一个 epoch 的最后一个 batch 往往比其他 batch 要小, 如果数据集的尺寸不能被 batch size 整除。 生成器将无限地在数据集上循环。当运行到第steps_per_epoch 时,记一个 epoch 结束。

steps_per_epoch: 在声明一个 epoch 完成并开始下一个 epoch 之前从 generator产生的总步数(批次样本)。 它通常应该等于你的数据集的样本数量除以批量大小。 对于Sequence,它是可选的:如果未指定,将使用len(generator)作为步数。

epochs: 整数。训练模型的迭代总轮数。一个 epoch 是对所提供的整个数据的一轮迭代,如 steps_per_epoch 所定义。注意,与 initial_epoch 一起使用,epoch 应被理解为「最后一轮」。模型没有经历由 epochs 给出的多次迭代的训练,而仅仅是直到达到索引 epoch 的轮次。

verbose: 0, 1 或 2。日志显示模式。 0 = 安静模式, 1 = 进度条, 2 = 每轮一行。

callbacks: keras.callbacks.Callback 实例的列表。在训练时调用的一系列回调函数。

validation_data: 它可以是以下之一:

验证数据的生成器或Sequence实例

一个(inputs, targets) 元组

一个(inputs, targets, sample_weights) 元组。

在每个 epoch 结束时评估损失和任何模型指标。该模型不会对此数据进行训练。

validation_steps: 仅当 validation_data 是一个生成器时才可用。 在停止前 generator 生成的总步数(样本批数)。 对于 Sequence,它是可选的:如果未指定,将使用 len(generator) 作为步数。

class_weight: 可选的将类索引(整数)映射到权重(浮点)值的字典,用于加权损失函数(仅在训练期间)。 这可以用来告诉模型「更多地关注」来自代表性不足的类的样本。

max_queue_size: 整数。生成器队列的最大尺寸。 如未指定,max_queue_size 将默认为 10。

workers: 整数。使用的最大进程数量,如果使用基于进程的多线程。 如未指定,workers 将默认为 1。如果为 0,将在主线程上执行生成器。

use_multiprocessing: 布尔值。如果 True,则使用基于进程的多线程。 如未指定, use_multiprocessing 将默认为 False。 请注意,由于此实现依赖于多进程,所以不应将不可传递的参数传递给生成器,因为它们不能被轻易地传递给子进程。

shuffle: 是否在每轮迭代之前打乱 batch 的顺序。 只能与 Sequence (keras.utils.Sequence) 实例同用。

initial_epoch: 开始训练的轮次(有助于恢复之前的训练)。

补充知识:Keras中fit_generator 的多个分支输入时,需注意generator的格式 以及 输入序列的顺序

需要注意迭代器 yeild返回不能是[x1,x2],y 这样,而是要完整的字典格式的:

yield ({'input_1': x1, 'input_2': x2}, {'output': y})

这也不算坑 追进去 fit_generator也能看到示例

def generate_batch(x_train,y_train,batch_size,x_train2,randomFlag=True):
 ylen = len(y_train)
 loopcount = ylen // batch_size
 i=-1
 while True:
  if randomFlag:
   i = random.randint(0,loopcount-1)
  else:
   i=i+1
   i=i%loopcount

  yield ({'lstmInput': x_train[i*batch_size:(i+1)*batch_size], 
    'bgInput': x_train2[i*batch_size:(i+1)*batch_size]}, 
   {'prediction': y_train[i*batch_size:(i+1)*batch_size]})

ps: 因为要是tuple yield后的括号不能省

需注意的坑1是,validation data中如果用【】组成数组进行输入,是要按顺序的,按编译model前的设置model = Model(inputs=[simInput,lstmInput,bgInput], outputs=predictions),中数组的顺序来编译

需注意的坑2是,多输入input时,以后都用 inputs1=Input(batch_shape=(batchSize,TPeriod,dimIn,),name='input1LSTM')指定batchSize,不然跟stateful lstm结合时,会提示不匹配。

history=model.fit_generator(generate_batch(trainX,trainY,batchSize,trainX2),
   steps_per_epoch=len(trainX)//batchSize,
   validation_data=([testX,testX2],testY),
   epochs=epochs,
   callbacks=[tensorboard,checkpoint],initial_epoch=0,verbose=1) # Fit the LSTM network/拟合LSTM网络

以上这篇keras和tensorflow使用fit_generator 批次训练操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python的id()函数介绍
Feb 10 Python
Pycharm学习教程(2) 代码风格
May 02 Python
python操作xlsx文件的包openpyxl实例
May 03 Python
python进行两个表格对比的方法
Jun 27 Python
图解python全局变量与局部变量相关知识
Nov 02 Python
Django项目使用ckeditor详解(不使用admin)
Dec 17 Python
python3实现从kafka获取数据,并解析为json格式,写入到mysql中
Dec 23 Python
Python编程快速上手——强口令检测算法案例分析
Feb 29 Python
关于python 的legend图例,参数使用说明
Apr 17 Python
Keras在训练期间可视化训练误差和测试误差实例
Jun 16 Python
python zip()函数的使用示例
Sep 23 Python
python简单验证码识别的实现过程
Jun 20 Python
基于Python+QT的gui程序开发实现
Jul 03 #Python
keras 两种训练模型方式详解fit和fit_generator(节省内存)
Jul 03 #Python
一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系
Jul 03 #Python
keras分类模型中的输入数据与标签的维度实例
Jul 03 #Python
keras自动编码器实现系列之卷积自动编码器操作
Jul 03 #Python
Python with语句用法原理详解
Jul 03 #Python
Keras搭建自编码器操作
Jul 03 #Python
You might like
人大复印资料处理程序_输入篇
2006/10/09 PHP
PHP 面向对象 PHP5 中的常量
2010/05/05 PHP
php使用GeoIP库实例
2014/06/27 PHP
详解js异步文件加载器
2016/01/24 PHP
PHP利用超级全局变量$_GET来接收表单数据的实例
2016/11/05 PHP
PHP中十六进制颜色与RGB颜色值互转的方法
2019/03/18 PHP
javascript字典探测用户名工具
2006/10/05 Javascript
40款非常棒的jQuery 插件和制作教程(系列一)
2011/10/26 Javascript
JS实现控制表格内指定单元格内容对齐的方法
2015/03/30 Javascript
理解js对象继承的N种模式
2016/01/25 Javascript
简单实现jQuery多选框功能
2017/01/09 Javascript
vue 解决addRoutes动态添加路由后刷新失效问题
2018/07/02 Javascript
取消Bootstrap的dropdown-menu点击默认关闭事件方法
2018/08/10 Javascript
Node.js 使用request模块下载文件的实例
2018/09/05 Javascript
JS实现盒子跟着鼠标移动及键盘方向键控制盒子移动效果示例
2019/01/29 Javascript
JavaScript数据结构与算法之检索算法示例【二分查找法、计算重复次数】
2019/02/22 Javascript
Vue+ElementUI使用vue-pdf实现预览功能
2019/11/26 Javascript
mpvue实现微信小程序快递单号查询代码
2020/04/03 Javascript
JavaScript设计模式--简单工厂模式实例分析【XHR工厂案例】
2020/05/23 Javascript
vue项目中播放rtmp视频文件流的方法
2020/09/17 Javascript
vue created钩子函数与mounted钩子函数的用法区别
2020/11/05 Javascript
[01:03:31]DOTA2上海特级锦标赛B组资格赛#1 Alliance VS Fnatic第二局
2016/02/26 DOTA
简单的通用表达式求10乘阶示例
2014/03/03 Python
浅谈Python由__dict__和dir()引发的一些思考
2017/10/30 Python
Python实现的多线程同步与互斥锁功能示例
2017/11/30 Python
scikit-learn线性回归,多元回归,多项式回归的实现
2019/08/29 Python
python3实现在二叉树中找出和为某一值的所有路径(推荐)
2019/12/26 Python
python 爬取百度文库并下载(免费文章限定)
2020/12/04 Python
Html5移动端网页端适配(js+rem)
2021/02/03 HTML / CSS
影视动画专业个人的自我评价
2013/12/31 职场文书
大学新闻系应届生求职信
2014/06/02 职场文书
公积金具结保证书
2015/05/11 职场文书
2015年小学教导处工作总结
2015/05/26 职场文书
如何写通讯稿
2015/07/22 职场文书
mysql定时自动备份数据库的方法步骤
2021/07/07 MySQL
html5+实现plus.io进行拍照和图片等获取
2022/06/01 HTML / CSS