keras 两种训练模型方式详解fit和fit_generator(节省内存)


Posted in Python onJuly 03, 2020

第一种,fit

import keras
from keras.models import Sequential
from keras.layers import Dense
import numpy as np
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split

#读取数据
x_train = np.load("D:\\machineTest\\testmulPE_win7\\data_sprase.npy")[()]
y_train = np.load("D:\\machineTest\\testmulPE_win7\\lable_sprase.npy")

# 获取分类类别总数
classes = len(np.unique(y_train))

#对label进行one-hot编码,必须的
label_encoder = LabelEncoder()
integer_encoded = label_encoder.fit_transform(y_train)
onehot_encoder = OneHotEncoder(sparse=False)
integer_encoded = integer_encoded.reshape(len(integer_encoded), 1)
y_train = onehot_encoder.fit_transform(integer_encoded)

#shuffle
X_train, X_test, y_train, y_test = train_test_split(x_train, y_train, test_size=0.3, random_state=0)

model = Sequential()
model.add(Dense(units=1000, activation='relu', input_dim=784))
model.add(Dense(units=classes, activation='softmax'))
model.compile(loss='categorical_crossentropy',
    optimizer='sgd',
    metrics=['accuracy'])
model.fit(X_train, y_train, epochs=50, batch_size=128)
score = model.evaluate(X_test, y_test, batch_size=128)
# #fit参数详情
# keras.models.fit(
# self,
# x=None, #训练数据
# y=None, #训练数据label标签
# batch_size=None, #每经过多少个sample更新一次权重,defult 32
# epochs=1, #训练的轮数epochs
# verbose=1, #0为不在标准输出流输出日志信息,1为输出进度条记录,2为每个epoch输出一行记录
# callbacks=None,#list,list中的元素为keras.callbacks.Callback对象,在训练过程中会调用list中的回调函数
# validation_split=0., #浮点数0-1,将训练集中的一部分比例作为验证集,然后下面的验证集validation_data将不会起到作用
# validation_data=None, #验证集
# shuffle=True, #布尔值和字符串,如果为布尔值,表示是否在每一次epoch训练前随机打乱输入样本的顺序,如果为"batch",为处理HDF5数据
# class_weight=None, #dict,分类问题的时候,有的类别可能需要额外关注,分错的时候给的惩罚会比较大,所以权重会调高,体现在损失函数上面
# sample_weight=None, #array,和输入样本对等长度,对输入的每个特征+个权值,如果是时序的数据,则采用(samples,sequence_length)的矩阵
# initial_epoch=0, #如果之前做了训练,则可以从指定的epoch开始训练
# steps_per_epoch=None, #将一个epoch分为多少个steps,也就是划分一个batch_size多大,比如steps_per_epoch=10,则就是将训练集分为10份,不能和batch_size共同使用
# validation_steps=None, #当steps_per_epoch被启用的时候才有用,验证集的batch_size
# **kwargs #用于和后端交互
# )
# 
# 返回的是一个History对象,可以通过History.history来查看训练过程,loss值等等

第二种,fit_generator(节省内存)

# 第二种,可以节省内存
'''
Created on 2018-4-11
fit_generate.txt,后面两列为lable,已经one-hot编码
1 2 0 1
2 3 1 0
1 3 0 1
1 4 0 1
2 4 1 0
2 5 1 0

'''
import keras
from keras.models import Sequential
from keras.layers import Dense
import numpy as np
from sklearn.model_selection import train_test_split

count =1 
def generate_arrays_from_file(path):
 global count
 while 1:
  datas = np.loadtxt(path,delimiter=' ',dtype="int")
  x = datas[:,:2]
  y = datas[:,2:]
  print("count:"+str(count))
  count = count+1
  yield (x,y)
x_valid = np.array([[1,2],[2,3]])
y_valid = np.array([[0,1],[1,0]])
model = Sequential()
model.add(Dense(units=1000, activation='relu', input_dim=2))
model.add(Dense(units=2, activation='softmax'))
model.compile(loss='categorical_crossentropy',
    optimizer='sgd',
    metrics=['accuracy'])

model.fit_generator(generate_arrays_from_file("D:\\fit_generate.txt"),steps_per_epoch=10, epochs=2,max_queue_size=1,validation_data=(x_valid, y_valid),workers=1)
# steps_per_epoch 每执行一次steps,就去执行一次生产函数generate_arrays_from_file
# max_queue_size 从生产函数中出来的数据时可以缓存在queue队列中
# 输出如下:
# Epoch 1/2
# count:1
# count:2
# 
# 1/10 [==>...........................] - ETA: 2s - loss: 0.7145 - acc: 0.3333count:3
# count:4
# count:5
# count:6
# count:7
# 
# 7/10 [====================>.........] - ETA: 0s - loss: 0.7001 - acc: 0.4286count:8
# count:9
# count:10
# count:11
# 
# 10/10 [==============================] - 0s 36ms/step - loss: 0.6960 - acc: 0.4500 - val_loss: 0.6794 - val_acc: 0.5000
# Epoch 2/2
# 
# 1/10 [==>...........................] - ETA: 0s - loss: 0.6829 - acc: 0.5000count:12
# count:13
# count:14
# count:15
# 
# 5/10 [==============>...............] - ETA: 0s - loss: 0.6800 - acc: 0.5000count:16
# count:17
# count:18
# count:19
# count:20
# 
# 10/10 [==============================] - 0s 11ms/step - loss: 0.6766 - acc: 0.5000 - val_loss: 0.6662 - val_acc: 0.5000

补充知识:

自动生成数据还可以继承keras.utils.Sequence,然后写自己的生成数据类:

keras数据自动生成器,继承keras.utils.Sequence,结合fit_generator实现节约内存训练

#coding=utf-8
'''
Created on 2018-7-10
'''
import keras
import math
import os
import cv2
import numpy as np
from keras.models import Sequential
from keras.layers import Dense

class DataGenerator(keras.utils.Sequence):
 
 def __init__(self, datas, batch_size=1, shuffle=True):
  self.batch_size = batch_size
  self.datas = datas
  self.indexes = np.arange(len(self.datas))
  self.shuffle = shuffle

 def __len__(self):
  #计算每一个epoch的迭代次数
  return math.ceil(len(self.datas) / float(self.batch_size))

 def __getitem__(self, index):
  #生成每个batch数据,这里就根据自己对数据的读取方式进行发挥了
  # 生成batch_size个索引
  batch_indexs = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
  # 根据索引获取datas集合中的数据
  batch_datas = [self.datas[k] for k in batch_indexs]

  # 生成数据
  X, y = self.data_generation(batch_datas)

  return X, y

 def on_epoch_end(self):
  #在每一次epoch结束是否需要进行一次随机,重新随机一下index
  if self.shuffle == True:
   np.random.shuffle(self.indexes)

 def data_generation(self, batch_datas):
  images = []
  labels = []

  # 生成数据
  for i, data in enumerate(batch_datas):
   #x_train数据
   image = cv2.imread(data)
   image = list(image)
   images.append(image)
   #y_train数据 
   right = data.rfind("\\",0)
   left = data.rfind("\\",0,right)+1
   class_name = data[left:right]
   if class_name=="dog":
    labels.append([0,1])
   else: 
    labels.append([1,0])
  #如果为多输出模型,Y的格式要变一下,外层list格式包裹numpy格式是list[numpy_out1,numpy_out2,numpy_out3]
  return np.array(images), np.array(labels)
 
# 读取样本名称,然后根据样本名称去读取数据
class_num = 0
train_datas = [] 
for file in os.listdir("D:/xxx"):
 file_path = os.path.join("D:/xxx", file)
 if os.path.isdir(file_path):
  class_num = class_num + 1
  for sub_file in os.listdir(file_path):
   train_datas.append(os.path.join(file_path, sub_file))

# 数据生成器
training_generator = DataGenerator(train_datas)

#构建网络
model = Sequential()
model.add(Dense(units=64, activation='relu', input_dim=784))
model.add(Dense(units=2, activation='softmax'))
model.compile(loss='categorical_crossentropy',
    optimizer='sgd',
    metrics=['accuracy'])
model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])

model.fit_generator(training_generator, epochs=50,max_queue_size=10,workers=1)

以上这篇keras 两种训练模型方式详解fit和fit_generator(节省内存)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现从字符串中找出字符1的位置以及个数的方法
Aug 25 Python
Python3中的2to3转换工具使用示例
Jun 12 Python
Python中Collections模块的Counter容器类使用教程
May 31 Python
Python urls.py的三种配置写法实例详解
Apr 28 Python
python中cPickle类使用方法详解
Aug 27 Python
Python代码打开本地.mp4格式文件的方法
Jan 03 Python
Python地图绘制实操详解
Mar 04 Python
python装饰器简介---这一篇也许就够了(推荐)
Apr 01 Python
python异常处理和日志处理方式
Dec 24 Python
Python 实现训练集、测试集随机划分
Jan 08 Python
python GUI库图形界面开发之PyQt5窗口类QMainWindow详细使用方法
Feb 26 Python
python游戏开发的五个案例分享
Mar 09 Python
一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系
Jul 03 #Python
keras分类模型中的输入数据与标签的维度实例
Jul 03 #Python
keras自动编码器实现系列之卷积自动编码器操作
Jul 03 #Python
Python with语句用法原理详解
Jul 03 #Python
Keras搭建自编码器操作
Jul 03 #Python
python 识别登录验证码图片功能的实现代码(完整代码)
Jul 03 #Python
python图片验证码识别最新模块muggle_ocr的示例代码
Jul 03 #Python
You might like
php写一个函数,实现扫描并打印出自定目录下(含子目录)所有jpg文件名
2017/05/26 PHP
JS 退出系统并跳转到登录界面的实现代码
2013/06/29 Javascript
关于js遍历表格的实例
2013/07/10 Javascript
JavaScript动态改变表格单元格内容的方法
2015/03/30 Javascript
在Ubuntu系统上安装Ghost博客平台的教程
2015/06/17 Javascript
javascript实现在网页中运行本地程序的方法
2016/02/03 Javascript
ionic实现滑动的三种方式
2016/08/27 Javascript
关于react中组件通信的几种方式详解
2017/12/10 Javascript
浅谈Vue SPA 首屏加载优化实践
2017/12/15 Javascript
vue2.0 axios跨域并渲染的问题解决方法
2018/03/08 Javascript
基于vue实现可搜索下拉框定制组件
2020/03/26 Javascript
node实现socket链接与GPRS进行通信的方法
2019/05/20 Javascript
VScode格式化ESlint方法(最全最好用方法)
2019/09/10 Javascript
[03:41]DOTA2上海特锦赛小组赛第三日recap精彩回顾
2016/02/28 DOTA
Python程序员开发中常犯的10个错误
2014/07/07 Python
Python正确重载运算符的方法示例详解
2017/08/27 Python
python中datetime模块中strftime/strptime函数的使用
2018/07/03 Python
pytorch 使用单个GPU与多个GPU进行训练与测试的方法
2019/08/19 Python
python同步windows和linux文件
2019/08/29 Python
python kafka 多线程消费者&手动提交实例
2019/12/21 Python
Django-xadmin+rule对象级权限的实现方式
2020/03/30 Python
美国殿堂级滑板、冲浪、滑雪服装品牌:Volcom(钻石)
2017/04/20 全球购物
美国存储和组织商店:The Container Store
2017/08/16 全球购物
墨西哥皇宫度假村预订:Palace Resorts
2018/06/16 全球购物
FLIR美国官网:热成像, 夜视和红外摄像系统
2018/07/13 全球购物
Groupon荷兰官方网站:高达70%的折扣
2019/11/01 全球购物
Clos19英国:高档香槟、葡萄酒和烈酒在线购物平台
2020/07/10 全球购物
Linux如何压缩可执行文件
2014/03/27 面试题
新闻编辑专业自荐信
2014/07/02 职场文书
2014国庆黄金周超市促销活动方案
2014/09/21 职场文书
2014年城市管理工作总结
2014/12/02 职场文书
2014年小学数学工作总结
2014/12/12 职场文书
蓬莱阁导游词
2015/02/04 职场文书
民主评议教师党员自我评价
2015/03/04 职场文书
毕业论文答辩开场白
2015/05/27 职场文书
作文之亲情600字
2019/09/23 职场文书