keras的ImageDataGenerator和flow()的用法说明


Posted in Python onJuly 03, 2020

ImageDataGenerator的参数自己看文档

from keras.preprocessing import image
import numpy as np

X_train=np.ones((3,123,123,1))
Y_train=np.array([[1],[2],[2]])
generator=image.ImageDataGenerator(featurewise_center=False,
  samplewise_center=False,
  featurewise_std_normalization=False,
  samplewise_std_normalization=False,
  zca_whitening=False,
  zca_epsilon=1e-6,
  rotation_range=180,
  width_shift_range=0.2,
  height_shift_range=0.2,
  shear_range=0,
  zoom_range=0.001,
  channel_shift_range=0,
  fill_mode='nearest',
  cval=0.,
  horizontal_flip=True,
  vertical_flip=True,
  rescale=None,
  preprocessing_function=None,
  data_format='channels_last')

a=generator.flow(X_train,Y_train,batch_size=20)#生成的是一个迭代器,可直接用于for循环
'''
batch_size如果小于X的第一维m,next生成的多维矩阵的第一维是为batch_size,输出是从输入中随机选取batch_size个数据
batch_size如果大于X的第一维m,next生成的多维矩阵的第一维是m,输出是m个数据,不过顺序随机
,输出的X,Y是一一对对应的
如果要直接用于tf.placeholder(),要求生成的矩阵和要与tf.placeholder相匹配

'''
X,Y=next(a)

print(Y)
X,Y=next(a)

print(Y)
X,Y=next(a)

print(Y)
X,Y=next(a)

输出

[[2]
 [1]
 [2]]

[[2]
 [2]
 [1]]

[[2]
 [2]
 [1]]

[[2]
 [2]
 [1]]

补充知识:tensorflow 与keras 混用之坑

在使用tensorflow与keras混用是model.save 是正常的但是在load_model的时候报错了在这里mark 一下

其中错误为:TypeError: tuple indices must be integers, not list

再一一番百度后无结果,上谷歌后找到了类似的问题。但是是一对鸟文不知道什么东西(翻译后发现是俄文)。后来谷歌翻译了一下找到了解决方法。故将原始问题文章贴上来警示一下

原训练代码

from tensorflow.python.keras.preprocessing.image import ImageDataGenerator
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.layers import Conv2D, MaxPooling2D, BatchNormalization
from tensorflow.python.keras.layers import Activation, Dropout, Flatten, Dense
 
#Каталог с данными для обучения
train_dir = 'train'
# Каталог с данными для проверки
val_dir = 'val'
# Каталог с данными для тестирования
test_dir = 'val'
 
# Размеры изображения
img_width, img_height = 800, 800
# Размерность тензора на основе изображения для входных данных в нейронную сеть
# backend Tensorflow, channels_last
input_shape = (img_width, img_height, 3)
# Количество эпох
epochs = 1
# Размер мини-выборки
batch_size = 4
# Количество изображений для обучения
nb_train_samples = 300
# Количество изображений для проверки
nb_validation_samples = 25
# Количество изображений для тестирования
nb_test_samples = 25
 
model = Sequential()
 
model.add(Conv2D(32, (7, 7), padding="same", input_shape=input_shape))
model.add(BatchNormalization())
model.add(Activation('tanh'))
model.add(MaxPooling2D(pool_size=(10, 10)))
 
model.add(Conv2D(64, (5, 5), padding="same"))
model.add(BatchNormalization())
model.add(Activation('tanh'))
model.add(MaxPooling2D(pool_size=(10, 10)))
 
model.add(Flatten())
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(10, activation='softmax'))
 
model.compile(loss='categorical_crossentropy',
       optimizer="Nadam",
       metrics=['accuracy'])
print(model.summary())
datagen = ImageDataGenerator(rescale=1. / 255)
 
train_generator = datagen.flow_from_directory(
  train_dir,
  target_size=(img_width, img_height),
  batch_size=batch_size,
  class_mode='categorical')
 
val_generator = datagen.flow_from_directory(
  val_dir,
  target_size=(img_width, img_height),
  batch_size=batch_size,
  class_mode='categorical')
 
test_generator = datagen.flow_from_directory(
  test_dir,
  target_size=(img_width, img_height),
  batch_size=batch_size,
  class_mode='categorical')
 
model.fit_generator(
  train_generator,
  steps_per_epoch=nb_train_samples // batch_size,
  epochs=epochs,
  validation_data=val_generator,
  validation_steps=nb_validation_samples // batch_size)
 
print('Сохраняем сеть')
model.save("grib.h5")
print("Сохранение завершено!")

模型载入

from tensorflow.python.keras.preprocessing.image import ImageDataGenerator
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.layers import Conv2D, MaxPooling2D, BatchNormalization
from tensorflow.python.keras.layers import Activation, Dropout, Flatten, Dense
from keras.models import load_model
 
print("Загрузка сети")
model = load_model("grib.h5")
print("Загрузка завершена!")

报错

/usr/bin/python3.5 /home/disk2/py/neroset/do.py
/home/mama/.local/lib/python3.5/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
 from ._conv import register_converters as _register_converters
Using TensorFlow backend.
Загрузка сети
Traceback (most recent call last):
 File "/home/disk2/py/neroset/do.py", line 13, in <module>
  model = load_model("grib.h5")
 File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 243, in load_model
  model = model_from_config(model_config, custom_objects=custom_objects)
 File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 317, in model_from_config
  return layer_module.deserialize(config, custom_objects=custom_objects)
 File "/usr/local/lib/python3.5/dist-packages/keras/layers/__init__.py", line 55, in deserialize
  printable_module_name='layer')
 File "/usr/local/lib/python3.5/dist-packages/keras/utils/generic_utils.py", line 144, in deserialize_keras_object
  list(custom_objects.items())))
 File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 1350, in from_config
  model.add(layer)
 File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 492, in add
  output_tensor = layer(self.outputs[0])
 File "/usr/local/lib/python3.5/dist-packages/keras/engine/topology.py", line 590, in __call__
  self.build(input_shapes[0])
 File "/usr/local/lib/python3.5/dist-packages/keras/layers/normalization.py", line 92, in build
  dim = input_shape[self.axis]
TypeError: tuple indices must be integers or slices, not list
 
Process finished with exit code 1

战斗种族解释

убераю BatchNormalization всё работает хорошо. Не подскажите в чём ошибка?Выяснил что сохранение keras и нормализация tensorflow не работают вместе нужно просто изменить строку импорта.(译文:整理BatchNormalization一切正常。 不要告诉我错误是什么?我发现保存keras和规范化tensorflow不能一起工作;只需更改导入字符串即可。)

强调文本 强调文本

keras.preprocessing.image import ImageDataGenerator
keras.models import Sequential
keras.layers import Conv2D, MaxPooling2D, BatchNormalization
keras.layers import Activation, Dropout, Flatten, Dense

##完美解决

##附上原文链接

https://qa-help.ru/questions/keras-batchnormalization

以上这篇keras的ImageDataGenerator和flow()的用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现list反转实例汇总
Nov 11 Python
在Django中进行用户注册和邮箱验证的方法
May 09 Python
解决pandas无法在pycharm中使用plot()方法显示图像的问题
May 24 Python
python实现京东秒杀功能
Jul 30 Python
Python3爬取英雄联盟英雄皮肤大图实例代码
Nov 14 Python
Django项目中添加ldap登陆认证功能的实现
Apr 04 Python
Python3获取cookie常用三种方案
Oct 05 Python
python+requests实现接口测试的完整步骤
Oct 27 Python
Scrapy-Redis之RedisSpider与RedisCrawlSpider详解
Nov 18 Python
Python实现视频中添加音频工具详解
Dec 06 Python
OpenCV实现常见的四种图像几何变换
Apr 01 Python
Python实现为PDF去除水印的示例代码
Apr 03 Python
python如何安装下载后的模块
Jul 03 #Python
python中id函数运行方式
Jul 03 #Python
Keras 数据增强ImageDataGenerator多输入多输出实例
Jul 03 #Python
keras和tensorflow使用fit_generator 批次训练操作
Jul 03 #Python
基于Python+QT的gui程序开发实现
Jul 03 #Python
keras 两种训练模型方式详解fit和fit_generator(节省内存)
Jul 03 #Python
一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系
Jul 03 #Python
You might like
php学习 字符串课件
2008/06/15 PHP
php相当简单的分页类
2008/10/02 PHP
PHP中把stdClass Object转array的几个方法
2014/05/08 PHP
thinkPHP实现瀑布流的方法
2014/11/29 PHP
php实现curl模拟ftp上传的方法
2015/07/29 PHP
php使用curl下载指定大小的文件实例代码
2017/09/30 PHP
基于jquery的分页控件(C#)
2011/01/06 Javascript
myEvent.js javascript跨浏览器事件框架
2011/10/24 Javascript
Bootstrap每天必学之轮播(Carousel)插件
2016/04/25 Javascript
设置点击文本框或图片弹出日历控件的实现代码
2016/05/12 Javascript
AngularJS基础 ng-keypress 指令简单示例
2016/08/02 Javascript
jQuery操作json常用方法示例
2017/01/04 Javascript
详解如何在NodeJS项目中优雅的使用ES6
2017/04/22 NodeJs
JavaScript实现无刷新上传预览图片功能
2017/08/02 Javascript
vue按需加载实例详解
2019/09/06 Javascript
js实现数字滚动特效
2019/12/16 Javascript
VUE table表格动态添加一列数据,新增的这些数据不可以编辑(v-model绑定的数据不能实时更新)
2020/04/03 Javascript
[02:43]2018DOTA2亚洲邀请赛主赛事首日TOP5
2018/04/04 DOTA
[34:41]夜魇凡尔赛茶话会 第二期02:你画我猜
2021/03/11 DOTA
Python导出DBF文件到Excel的方法
2015/07/25 Python
Python随机生成数据后插入到PostgreSQL
2016/07/28 Python
Pycharm无法使用已经安装Selenium的解决方法
2018/10/13 Python
Python列表list排列组合操作示例
2018/12/18 Python
Tensorflow加载Vgg预训练模型操作
2020/05/26 Python
Android Q之气泡弹窗的实现示例
2020/06/23 Python
Python建造者模式案例运行原理解析
2020/06/29 Python
css3 线性渐变和径向渐变示例附图
2014/04/08 HTML / CSS
英国香水店:The Perfume Shop
2017/03/27 全球购物
欧洲有机婴儿食品最大的市场:Organic Baby Food(供美国和加拿大)
2018/03/28 全球购物
为女性购买传统的印度服装和婚纱:Kalkifashion
2019/07/22 全球购物
小学生暑假安全保证书
2015/07/13 职场文书
秋季运动会加油词
2015/07/18 职场文书
教务处教学工作总结
2015/08/10 职场文书
导游词之青城山景区
2019/09/27 职场文书
使用react+redux实现计数器功能及遇到问题
2021/06/02 Javascript
Java对文件的读写操作方法
2022/04/29 Java/Android