Keras保存模型并载入模型继续训练的实现


Posted in Python onFebruary 20, 2021

我们以MNIST手写数字识别为例

import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD
 
# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000,28,28)
print('x_shape:',x_train.shape)
# (60000)
print('y_shape:',y_train.shape)
# (60000,28,28)->(60000,784)
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)
 
# 创建模型,输入784个神经元,输出10个神经元
model = Sequential([
    Dense(units=10,input_dim=784,bias_initializer='one',activation='softmax')
  ])
 
# 定义优化器
sgd = SGD(lr=0.2)
 
# 定义优化器,loss function,训练过程中计算准确率
model.compile(
  optimizer = sgd,
  loss = 'mse',
  metrics=['accuracy'],
)
 
# 训练模型
model.fit(x_train,y_train,batch_size=64,epochs=5)
 
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
 
print('\ntest loss',loss)
print('accuracy',accuracy)
 
# 保存模型
model.save('model.h5')  # HDF5文件,pip install h5py

Keras保存模型并载入模型继续训练的实现

Keras保存模型并载入模型继续训练的实现

载入初次训练的模型,再训练

import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD
from keras.models import load_model
# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000,28,28)
print('x_shape:',x_train.shape)
# (60000)
print('y_shape:',y_train.shape)
# (60000,28,28)->(60000,784)
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)
 
# 载入模型
model = load_model('model.h5')
 
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
 
print('\ntest loss',loss)
print('accuracy',accuracy)
 
# 训练模型
model.fit(x_train,y_train,batch_size=64,epochs=2)
 
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
 
print('\ntest loss',loss)
print('accuracy',accuracy)
 
# 保存参数,载入参数
model.save_weights('my_model_weights.h5')
model.load_weights('my_model_weights.h5')
# 保存网络结构,载入网络结构
from keras.models import model_from_json
json_string = model.to_json()
model = model_from_json(json_string)
 
print(json_string)

关于compile和load_model()的使用顺序

这一段落主要是为了解决我们fit、evaluate、predict之前还是之后使用compile。想要弄明白,首先我们要清楚compile在程序中是做什么的?都做了什么?

compile做什么?

compile定义了loss function损失函数、optimizer优化器和metrics度量。它与权重无关,也就是说compile并不会影响权重,不会影响之前训练的问题。

如果我们要训练模型或者评估模型evaluate,则需要compile,因为训练要使用损失函数和优化器,评估要使用度量方法;如果我们要预测,则没有必要compile模型。

是否需要多次编译?

除非我们要更改其中之一:损失函数、优化器 / 学习率、度量

又或者我们加载了尚未编译的模型。或者您的加载/保存方法没有考虑以前的编译。

再次compile的后果?

如果再次编译模型,将会丢失优化器状态.

这意味着您的训练在开始时会受到一点影响,直到调整学习率,动量等为止。但是绝对不会对重量造成损害(除非您的初始学习率如此之大,以至于第一次训练步骤疯狂地更改微调的权重)。

到此这篇关于Keras保存模型并载入模型继续训练的实现的文章就介绍到这了,更多相关Keras保存模型并加载模型内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
在Django中创建第一个静态视图
Jul 15 Python
Python中第三方库Requests库的高级用法详解
Mar 12 Python
Python函数和模块的使用总结
May 20 Python
python解压TAR文件至指定文件夹的实例
Jun 10 Python
对Pytorch中nn.ModuleList 和 nn.Sequential详解
Aug 18 Python
python实现通过flask和前端进行数据收发
Aug 22 Python
opencv实现简单人脸识别
Feb 19 Python
PYTHON绘制雷达图代码实例
Oct 15 Python
Python图像处理库PIL的ImageDraw模块介绍详解
Feb 26 Python
Django 项目布局方法(值得推荐)
Mar 22 Python
python搜索算法原理及实例讲解
Nov 18 Python
python爬取招聘要求等信息实例
Nov 20 Python
TensorFlow2.0使用keras训练模型的实现
Feb 20 #Python
tensorflow2.0教程之Keras快速入门
Feb 20 #Python
在Pycharm中安装Pandas库方法(简单易懂)
Feb 20 #Python
Python3爬虫RedisDump的安装步骤
Feb 20 #Python
python爬取2021猫眼票房字体加密实例
Feb 19 #Python
Python之Sklearn使用入门教程
Feb 19 #Python
Python爬虫UA伪装爬取的实例讲解
Feb 19 #Python
You might like
php file_get_contents函数轻松采集html数据
2010/04/22 PHP
window+nginx+php环境配置 附配置搭配说明
2010/12/29 PHP
PHP可变函数学习小结
2015/11/29 PHP
Yii2中cookie用法示例分析
2016/07/18 PHP
微信公众平台开发教程⑤ 微信扫码支付模式介绍
2019/04/10 PHP
文本框的字数限制功能jquery插件
2009/11/24 Javascript
js下通过getList函数实现分页效果的代码
2010/09/17 Javascript
JSChart轻量级图形报表工具(内置函数中文参考)
2010/10/11 Javascript
用jquery与css打造个性化的单选框和复选框
2010/10/20 Javascript
jQuery LigerUI 使用教程入门篇
2012/01/18 Javascript
js去字符串前后空格5种实现方法及比较
2013/04/03 Javascript
html中使用javascript调用本地程序(exe、doc等)实现代码
2013/04/26 Javascript
js onload处理html页面加载之后的事件
2013/10/30 Javascript
深入理解JavaScript系列(46):代码复用模式(推荐篇)详解
2015/03/04 Javascript
Jquery网页内滑动缓冲导航的实现代码
2015/04/05 Javascript
js判断当前页面用什么浏览器打开的方法
2016/01/06 Javascript
jquery弹出遮掩层效果【附实例代码】
2016/04/28 Javascript
jQuery图片左右滚动代码 有左右按钮实例
2016/06/20 Javascript
关于微信上网页图片点击全屏放大效果
2016/12/19 Javascript
微信小程序中实现手指缩放图片的示例代码
2018/03/13 Javascript
轻量级富文本编辑器wangEditor结合vue使用方法示例
2018/10/10 Javascript
layui的layedit富文本赋值方法
2019/09/18 Javascript
Python使用Selenium模块实现模拟浏览器抓取淘宝商品美食信息功能示例
2018/07/18 Python
Tensorflow模型实现预测或识别单张图片
2019/07/19 Python
pytorch中tensor.expand()和tensor.expand_as()函数详解
2019/12/27 Python
Pycharm配置autopep8实现流程解析
2020/11/28 Python
美国最顶级的精品店之一:Hampden Clothing
2016/12/22 全球购物
C#面试题
2016/05/06 面试题
拖鞋店创业计划书
2014/01/15 职场文书
趣味比赛活动方案
2014/02/15 职场文书
幼儿园庆六一活动方案
2014/03/06 职场文书
小学生保护环境倡议书
2014/05/15 职场文书
领导班子作风建设剖析材料
2014/10/11 职场文书
2014年检验员工作总结
2014/11/19 职场文书
高三英语教学计划
2015/01/23 职场文书
2015年食品安全工作总结
2015/05/15 职场文书