keras和tensorflow使用fit_generator 批次训练操作


Posted in Python onJuly 03, 2020

fit_generator 是 keras 提供的用来进行批次训练的函数,使用方法如下:

model.fit_generator(generator, steps_per_epoch=None, epochs=1,
    verbose=1, callbacks=None, validation_data=None, validation_steps=None,
    class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False,
    shuffle=True, initial_epoch=0)

参数说明:

generator: 一个生成器,或者一个 Sequence (keras.utils.Sequence) 对象的实例, 以在使用多进程时避免数据的重复。 生成器的输出应该为以下之一:

一个(inputs, targets) 元组

一个 (inputs, targets, sample_weights) 元组。

这个元组(生成器的单个输出)组成了单个的 batch。 因此,这个元组中的所有数组长度必须相同(与这一个 batch 的大小相等)。 不同的 batch 可能大小不同。 例如,一个 epoch 的最后一个 batch 往往比其他 batch 要小, 如果数据集的尺寸不能被 batch size 整除。 生成器将无限地在数据集上循环。当运行到第steps_per_epoch 时,记一个 epoch 结束。

steps_per_epoch: 在声明一个 epoch 完成并开始下一个 epoch 之前从 generator产生的总步数(批次样本)。 它通常应该等于你的数据集的样本数量除以批量大小。 对于Sequence,它是可选的:如果未指定,将使用len(generator)作为步数。

epochs: 整数。训练模型的迭代总轮数。一个 epoch 是对所提供的整个数据的一轮迭代,如 steps_per_epoch 所定义。注意,与 initial_epoch 一起使用,epoch 应被理解为「最后一轮」。模型没有经历由 epochs 给出的多次迭代的训练,而仅仅是直到达到索引 epoch 的轮次。

verbose: 0, 1 或 2。日志显示模式。 0 = 安静模式, 1 = 进度条, 2 = 每轮一行。

callbacks: keras.callbacks.Callback 实例的列表。在训练时调用的一系列回调函数。

validation_data: 它可以是以下之一:

验证数据的生成器或Sequence实例

一个(inputs, targets) 元组

一个(inputs, targets, sample_weights) 元组。

在每个 epoch 结束时评估损失和任何模型指标。该模型不会对此数据进行训练。

validation_steps: 仅当 validation_data 是一个生成器时才可用。 在停止前 generator 生成的总步数(样本批数)。 对于 Sequence,它是可选的:如果未指定,将使用 len(generator) 作为步数。

class_weight: 可选的将类索引(整数)映射到权重(浮点)值的字典,用于加权损失函数(仅在训练期间)。 这可以用来告诉模型「更多地关注」来自代表性不足的类的样本。

max_queue_size: 整数。生成器队列的最大尺寸。 如未指定,max_queue_size 将默认为 10。

workers: 整数。使用的最大进程数量,如果使用基于进程的多线程。 如未指定,workers 将默认为 1。如果为 0,将在主线程上执行生成器。

use_multiprocessing: 布尔值。如果 True,则使用基于进程的多线程。 如未指定, use_multiprocessing 将默认为 False。 请注意,由于此实现依赖于多进程,所以不应将不可传递的参数传递给生成器,因为它们不能被轻易地传递给子进程。

shuffle: 是否在每轮迭代之前打乱 batch 的顺序。 只能与 Sequence (keras.utils.Sequence) 实例同用。

initial_epoch: 开始训练的轮次(有助于恢复之前的训练)。

补充知识:Keras中fit_generator 的多个分支输入时,需注意generator的格式 以及 输入序列的顺序

需要注意迭代器 yeild返回不能是[x1,x2],y 这样,而是要完整的字典格式的:

yield ({'input_1': x1, 'input_2': x2}, {'output': y})

这也不算坑 追进去 fit_generator也能看到示例

def generate_batch(x_train,y_train,batch_size,x_train2,randomFlag=True):
 ylen = len(y_train)
 loopcount = ylen // batch_size
 i=-1
 while True:
  if randomFlag:
   i = random.randint(0,loopcount-1)
  else:
   i=i+1
   i=i%loopcount

  yield ({'lstmInput': x_train[i*batch_size:(i+1)*batch_size], 
    'bgInput': x_train2[i*batch_size:(i+1)*batch_size]}, 
   {'prediction': y_train[i*batch_size:(i+1)*batch_size]})

ps: 因为要是tuple yield后的括号不能省

需注意的坑1是,validation data中如果用【】组成数组进行输入,是要按顺序的,按编译model前的设置model = Model(inputs=[simInput,lstmInput,bgInput], outputs=predictions),中数组的顺序来编译

需注意的坑2是,多输入input时,以后都用 inputs1=Input(batch_shape=(batchSize,TPeriod,dimIn,),name='input1LSTM')指定batchSize,不然跟stateful lstm结合时,会提示不匹配。

history=model.fit_generator(generate_batch(trainX,trainY,batchSize,trainX2),
   steps_per_epoch=len(trainX)//batchSize,
   validation_data=([testX,testX2],testY),
   epochs=epochs,
   callbacks=[tensorboard,checkpoint],initial_epoch=0,verbose=1) # Fit the LSTM network/拟合LSTM网络

以上这篇keras和tensorflow使用fit_generator 批次训练操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现的Kmeans++算法实例
Apr 26 Python
为Python程序添加图形化界面的教程
Apr 29 Python
Python解析json之ValueError: Expecting property name enclosed in double quotes: line 1 column 2(char 1)
Jul 06 Python
python利用paramiko连接远程服务器执行命令的方法
Oct 16 Python
Django项目实战之用户头像上传与访问的示例
Apr 21 Python
python 输入一个数n,求n个数求乘或求和的实例
Nov 13 Python
解决Pytorch 加载训练好的模型 遇到的error问题
Jan 10 Python
解决Tensorboard可视化错误:不显示数据 No scalar data was found
Feb 15 Python
Pytorch损失函数nn.NLLLoss2d()用法说明
Jul 07 Python
python实现定时发送邮件
Dec 23 Python
python如何修改文件时间属性
Feb 05 Python
Python实现曲线拟合的最小二乘法
Feb 19 Python
基于Python+QT的gui程序开发实现
Jul 03 #Python
keras 两种训练模型方式详解fit和fit_generator(节省内存)
Jul 03 #Python
一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系
Jul 03 #Python
keras分类模型中的输入数据与标签的维度实例
Jul 03 #Python
keras自动编码器实现系列之卷积自动编码器操作
Jul 03 #Python
Python with语句用法原理详解
Jul 03 #Python
Keras搭建自编码器操作
Jul 03 #Python
You might like
php轻松实现中英文混排字符串截取
2014/05/28 PHP
javascript iframe中打开文件,并检测iframe存在否
2008/12/28 Javascript
三级下拉菜单的js实现代码
2011/05/23 Javascript
js转化毫秒为时间格式代码
2014/04/10 Javascript
使用node.js半年来总结的 10 条经验
2014/08/18 Javascript
Javascript让DEDECMS告别手写Tag
2014/09/01 Javascript
JavaScript实现信用卡校验方法
2015/04/07 Javascript
jquery实现Ctrl+Enter提交表单的方法
2015/07/21 Javascript
js实现浏览器倒计时跳转页面效果
2016/08/12 Javascript
微信小程序 页面跳转传参详解
2016/10/28 Javascript
详解vue-cli快速构建项目以及引入bootstrap、jq
2017/05/26 Javascript
vue 使用axios 数据请求第三方插件的使用教程详解
2019/07/05 Javascript
Angular8基础应用之表单及其验证
2019/08/11 Javascript
解决node.js含有%百分号时发送get请求时浏览器地址自动编码的问题
2019/11/20 Javascript
解决vue项目中出现Invalid Host header的问题
2020/11/17 Javascript
python anaconda 安装 环境变量 升级 以及特殊库安装的方法
2017/06/21 Python
python实现隐马尔科夫模型HMM
2018/03/25 Python
基于python实现名片管理系统
2018/11/30 Python
Python零基础入门学习之输入与输出
2019/04/03 Python
Python实现的企业粉丝抽奖功能示例
2019/07/26 Python
python实现批量修改服务器密码的方法
2019/08/13 Python
解析Tensorflow之MNIST的使用
2020/06/30 Python
Python自动化操作实现图例绘制
2020/07/09 Python
css3实现背景模糊的三种方式(小结)
2020/05/15 HTML / CSS
基于HTML5的WebGL实现json和echarts图表展现在同一个界面
2017/10/26 HTML / CSS
荷兰超市:DEEN
2018/03/14 全球购物
聚网科技C++面试笔试题
2015/09/01 面试题
Java中有几种类型的流?JDK为每种类型的流提供了一些抽象类以供继承,请说出他们分别是哪些类?
2012/05/30 面试题
医学专业本科毕业生自我鉴定
2013/12/28 职场文书
个人简历自我评价
2014/02/02 职场文书
庆元旦文艺演出主持词
2014/03/27 职场文书
班级旅游计划书
2014/05/03 职场文书
2014年度安全生产目标管理责任书
2014/07/25 职场文书
意外死亡赔偿协议书
2014/10/14 职场文书
2014年仓管员工作总结
2014/11/18 职场文书
党员转正介绍人意见
2015/06/03 职场文书