keras实现图像预处理并生成一个generator的案例


Posted in Python onJune 17, 2020

如下所示:

keras实现图像预处理并生成一个generator的案例

接下来,给出我自己目前积累的代码,从目录中自动读取图像,并产生generator:

第一步:建立好目录结构和图像

keras实现图像预处理并生成一个generator的案例

可以看到目录images_keras_dict下有次级目录,次级目录下就直接包含照片了

**第二步:写代码建立预处理程序

# 先进行预处理图像
train_datagen = ImageDataGenerator(rescale=1./255, 
                  rotation_range=50,
                  height_shift_range=[-0.005, 0, 0.005],
                  width_shift_range=[-0.005, 0, 0.005],
                  horizontal_flip=True, 
                  fill_mode='reflect')
#再对预处理图像指定从目录中读取数据,可以看到我的目录最核心的地方是images_keras_dict(可以对照上一张图片)
train_generator = train_datagen.flow_from_directory('AgriculturalDisease_trainingset/images_keras_dict',
                          target_size=(height, width), batch_size=16)

val_datagen = ImageDataGenerator(rescale=1./255)
val_generator = val_datagen.flow_from_directory('AgriculturalDisease_validationset/images_keras_dict', target_size=(height, width),
                        batch_size=64)

save_weights = ModelCheckpoint(filepath='models/best_weights.hdf5',monitor='val_loss', verbose=1, save_best_only=True)

# 最后在fit_generator 中放入生成器的函数train_generator
model.fit_generator(train_generator,
          steps_per_epoch=times_train,
          verbose=1,
          epochs=300,
          initial_epoch=0,
          validation_data=val_generator,
          validation_steps=times_val,
          callbacks=[save_weights, TrainValTensorBoard(write_graph=False)])

第三步:写入fit_generator进行训练

已经写在上一个代码中。

第四步:写predict_generator进行预测**

首先我们需要建立同样的目录结构。把包含预测图片的次级目录放在一个文件夹下,这个文件夹名就是关键文件夹。

这里我的关键文件夹是test文件夹

# 建立预处理
predict_datagen = ImageDataGenerator(rescale=1./255)
predict_generator = predict_datagen.flow_from_directory('AgriculturalDisease_validationset/test',
                            target_size=(height, width), batch_size=128)
# predict_generator.reset()
# 利用predict_generator进行预测
pred = model.predict_generator(predict_generator, max_queue_size=10, workers=1, verbose=1)

# 利用几个属性来读取文件夹和对应的分类
train_datagen = ImageDataGenerator(rescale=1./255, rotation_range=40, fill_mode='wrap')
train_generator = train_datagen.flow_from_directory('new_images', target_size=(height, width), batch_size=96)
labels = (train_generator.class_indices)
labels = dict((v,k) for k,v in labels.items())
predictions = [labels[k] for k in predicted_class_indices]

# 还可以知道图片的名字
filenames = predict_generator.filenames

补充知识:[TensorFlow 2] [Keras] fit()、fit_generator() 和 train_on_batch() 分析与应用

前言

是的,除了水报错文,我也来写点其他的。本文主要介绍Keras中以下三个函数的用法:

1、fit()

2、fit_generator()

3、train_on_batch()

当然,与上述三个函数相似的evaluate、predict、test_on_batch、predict_on_batch、evaluate_generator和predict_generator等就不详细说了,举一反三嘛。

环境

本文的代码是在以下环境下进行测试的:

Windows 10

Python 3.6

TensorFlow 2.0 Alpha

异同

大家用Keras也就图个简单快捷,但是在享受简单快捷的时候,也常常需要些定制化需求,除了model.fit(),有时候model.fit_generator()和model.train_on_batch()也很重要。

那么,这三个函数有什么异同呢?Adrian Rosebrock [1] 有如下总结:

当你使用.fit()函数时,意味着如下两个假设:

训练数据可以 完整地 放入到内存(RAM)里

数据已经不需要再进行任何处理了

这两个原因解释的非常好,之前我运行程序的时候,由于数据集太大(实际中的数据集显然不会都像 TensorFlow 官方教程里经常使用的 MNIST 数据集那样小),一次性加载训练数据到fit()函数里根本行不通:

history = model.fit(train_data, train_label) // Bomb!!!

于是我想,能不能先加载一个batch训练,然后再加载一个batch,如此往复。于是我就注意到了fit_generator()函数。什么时候该使用fit_generator函数呢?Adrian Rosebrock 的总结道:

内存不足以一次性加载整个训练数据的时候

需要一些数据预处理(例如旋转和平移图片、增加噪音、扩大数据集等操作)

在生成batch的时候需要更多的处理

对于我自己来说,除了数据集太大的缘故之外,我需要在生成batch的时候,对输入数据进行padding,所以fit_generator()就派上了用场。下面介绍如何使用这三种函数。

fit()函数

fit()函数其实没什么好说的,大家在看TensorFlow教程的时候已经见识过了。此外插一句话,tf.data.Dataset对不规则的序列数据真是不友好。

import tensorflow as tf
model = tf.keras.models.Sequential([
 ... // 你的模型
])
model.fit(train_x, // 训练输入
  train_y, // 训练标签
  epochs=5 // 训练5轮
)

fit_generator()函数

fit_generator()函数就比较重要了,也是本文讨论的重点。fit_generator()与fit()的主要区别就在一个generator上。之前,我们把整个训练数据都输入到fit()里,我们也不需要考虑batch的细节;现在,我们使用一个generator,每次生成一个batch送给fit_generator()训练。

def generator(x, y, b_size):
 ... // 处理函数

model.fit_generator(generator(train_x, train_y, batch_size), 
   step_per_epochs=np.ceil(len(train_x)/batch_size), 
   epochs=5
)

从上述代码中,我们发现有两处不同:

一个我们自定义的generator()函数,作为fit_generator()函数的第一个参数;

fit_generator()函数的step_per_epochs参数

自定义的generator()函数

该函数即是我们数据的生成器,在训练的时候,fit_generator()函数会不断地执行generator()函数,获取一个个的batch。

def generator(x, y, b_size):
 """Generates batch and batch and batch then feed into models.
 Args:
 x: input data;
 y: input labels;
 b_size: batch_size.
 Yield:
 (batch_x, batch_label): batched x and y.
 """
 while 1: // 死循环
 idx = ...
 batch_x = ...
 batch_y = ...
 ... // 任何你想要对这个`batch`中的数据执行的操作
 yield (batch_x, batch_y)

需要注意的是,不要使用return或者exit。

step_per_epochs参数

由于generator()函数的循环没有终止条件,fit_generator也不知道一个epoch什么时候结束,所以我们需要手动指定step_per_epochs参数,一般的数值即为len(y)//batch_size。如果数据集大小不能整除batch_size,而且你打算使用最后一个batch的数据(该batch比batch_size要小),此时使用np.ceil(len(y)/batch_size)。

keras.utils.Sequence类(2019年6月10日更新)

除了写generator()函数,我们还可以利用keras.utils.Sequence类来生成batch。先扔代码:

class Generator(keras.utils.Sequence):
 def __init__(self, x, y, b_size):
 self.x, self.y = x, y
 self.batch_size = b_size
 
 def __len__(self):
 return math.ceil(len(self.y)/self.batch_size

 def __getitem__(self, idx):
 b_x = self.x[idx*self.batch_size:(idx+1)*self.batch_size]
 b_y = self.y[idx*self.batch_size:(idx+1)*self.batch_size]
 ... // 对`batch`的其余操作
 return np.array(b_x), np.array(b_y)
 
 def on_epoch_end(self):
 """执行完一个`epoch`之后,还可以做一些其他的事情!"""
 ...

我们首先定义__init__函数,读取训练集数据,然后定义__len__函数,返回一个epoch中需要执行的step数(此时在fit_generator()函数中就不需要指定steps_per_epoch参数了),最后定义__getitem__函数,返回一个batch的数据。代码如下:

train_generator = Generator(train_x, train_y, batch_size)
val_generator = Generator(val_x, val_y, batch_size)

model.fit_generator(generator=train_generator, 
   epochs=3197747, 
   validation_data=val_generator
   )

根据官方 [2] 的说法,使用Sequence类可以保证在多进程的情况下,每个epoch中的样本只会被训练一次。总之,使用keras.utils.Sequence也是很方便的啦!

train_on_batch()函数

train_on_batch()函数接受一个batch的输入和标签,然后开始反向传播,更新参数等。大部分情况下你都不需要用到train_on_batch()函数,除非你有着充足的理由去定制化你的模型的训练流程。

结语

本文到此结束啦!希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python使用urllib2获取网络资源实例讲解
Dec 02 Python
python的dict,set,list,tuple应用详解
Jul 24 Python
Python随机生成彩票号码的方法
Mar 05 Python
Python字符串处理实现单词反转
Jun 14 Python
python将字典内容存入mysql实例代码
Jan 18 Python
python3如何将docx转换成pdf文件
Mar 23 Python
Python实现的爬虫刷回复功能示例
Jun 07 Python
python粘包问题及socket套接字编程详解
Jun 29 Python
python 定时器每天就执行一次的实现代码
Aug 14 Python
python同步两个文件夹下的内容
Aug 29 Python
Django更新models数据库结构步骤
Apr 01 Python
Python json格式化打印实现过程解析
Jul 21 Python
pytorch快速搭建神经网络_Sequential操作
Jun 17 #Python
浅谈Keras的Sequential与PyTorch的Sequential的区别
Jun 17 #Python
Keras之fit_generator与train_on_batch用法
Jun 17 #Python
基于Keras的格式化输出Loss实现方式
Jun 17 #Python
Tensorflow之MNIST CNN实现并保存、加载模型
Jun 17 #Python
tensorflow使用CNN分析mnist手写体数字数据集
Jun 17 #Python
解决Alexnet训练模型在每个epoch中准确率和loss都会一升一降问题
Jun 17 #Python
You might like
显示youtube视频缩略图和Vimeo视频缩略图代码分享
2014/02/13 PHP
Codeigniter实现处理用户登录验证后的URL跳转
2014/06/12 PHP
php HTML无刷新提交表单
2016/04/05 PHP
php获取ajax的headers方法与内容实例
2017/12/27 PHP
最新28个很棒的jQuery 教程
2011/05/28 Javascript
Javascript学习笔记 delete运算符
2011/09/13 Javascript
javascript中全局对象的isNaN()方法使用介绍
2013/12/19 Javascript
NodeJS学习笔记之Connect中间件模块(一)
2015/01/27 NodeJs
js控制网页前进和后退的方法
2015/06/08 Javascript
浅析jQuery中使用$所引发的问题
2016/05/29 Javascript
jquery对所有input type=text的控件赋值实现方法
2016/12/02 Javascript
微信小程序websocket实现聊天功能
2020/03/30 Javascript
JavaScript基于数组实现的栈与队列操作示例
2018/12/22 Javascript
Vue表单绑定的实例代码(单选按钮,选择框(单选时,多选时,用 v-for 渲染的动态选项)
2019/05/13 Javascript
JavaScript+HTML5 canvas实现放大镜效果完整示例
2019/05/15 Javascript
[01:52]2014DOTA2西雅图邀请赛 V社开大会你不知道的小秘密
2014/07/08 DOTA
Django如何实现内容缓存示例详解
2017/09/24 Python
解决nohup执行python程序log文件写入不及时的问题
2019/01/14 Python
python使用正则表达式匹配txt特定字符串(有换行)
2020/12/09 Python
Python非单向递归函数如何返回全部结果
2020/12/18 Python
Python爬虫分析微博热搜关键词的实现代码
2021/02/22 Python
任意一块网页内容实现“活”的背景(目前火狐浏览器专有)
2014/05/07 HTML / CSS
通过Canvas及File API缩放并上传图片完整示例
2013/08/08 HTML / CSS
ROSEFIELD手表荷兰官方网上商店:北欧极简设计女士腕表品牌
2018/01/24 全球购物
Pretty Little Thing美国:时尚女性服饰
2018/08/27 全球购物
在C语言中"指针和数组等价"到底是什么意思?
2014/03/24 面试题
校园新闻广播稿
2014/01/10 职场文书
JAVA程序员自荐书
2014/01/30 职场文书
护理专业自荐书
2014/06/04 职场文书
关于感恩的演讲稿400字
2014/08/26 职场文书
机关作风建设工作总结
2014/10/23 职场文书
Nginx配置并兼容HTTP实现代码解析
2021/03/31 Servers
python实现过滤敏感词
2021/05/08 Python
Python中Selenium对Cookie的操作方法
2021/07/09 Python
javascript之Object.assign()的痛点分析
2022/03/03 Javascript
MySql中的json_extract函数处理json字段详情
2022/06/05 MySQL