Keras保存模型并载入模型继续训练的实现


Posted in Python onFebruary 20, 2021

我们以MNIST手写数字识别为例

import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD
 
# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000,28,28)
print('x_shape:',x_train.shape)
# (60000)
print('y_shape:',y_train.shape)
# (60000,28,28)->(60000,784)
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)
 
# 创建模型,输入784个神经元,输出10个神经元
model = Sequential([
    Dense(units=10,input_dim=784,bias_initializer='one',activation='softmax')
  ])
 
# 定义优化器
sgd = SGD(lr=0.2)
 
# 定义优化器,loss function,训练过程中计算准确率
model.compile(
  optimizer = sgd,
  loss = 'mse',
  metrics=['accuracy'],
)
 
# 训练模型
model.fit(x_train,y_train,batch_size=64,epochs=5)
 
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
 
print('\ntest loss',loss)
print('accuracy',accuracy)
 
# 保存模型
model.save('model.h5')  # HDF5文件,pip install h5py

Keras保存模型并载入模型继续训练的实现

Keras保存模型并载入模型继续训练的实现

载入初次训练的模型,再训练

import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD
from keras.models import load_model
# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000,28,28)
print('x_shape:',x_train.shape)
# (60000)
print('y_shape:',y_train.shape)
# (60000,28,28)->(60000,784)
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)
 
# 载入模型
model = load_model('model.h5')
 
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
 
print('\ntest loss',loss)
print('accuracy',accuracy)
 
# 训练模型
model.fit(x_train,y_train,batch_size=64,epochs=2)
 
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
 
print('\ntest loss',loss)
print('accuracy',accuracy)
 
# 保存参数,载入参数
model.save_weights('my_model_weights.h5')
model.load_weights('my_model_weights.h5')
# 保存网络结构,载入网络结构
from keras.models import model_from_json
json_string = model.to_json()
model = model_from_json(json_string)
 
print(json_string)

关于compile和load_model()的使用顺序

这一段落主要是为了解决我们fit、evaluate、predict之前还是之后使用compile。想要弄明白,首先我们要清楚compile在程序中是做什么的?都做了什么?

compile做什么?

compile定义了loss function损失函数、optimizer优化器和metrics度量。它与权重无关,也就是说compile并不会影响权重,不会影响之前训练的问题。

如果我们要训练模型或者评估模型evaluate,则需要compile,因为训练要使用损失函数和优化器,评估要使用度量方法;如果我们要预测,则没有必要compile模型。

是否需要多次编译?

除非我们要更改其中之一:损失函数、优化器 / 学习率、度量

又或者我们加载了尚未编译的模型。或者您的加载/保存方法没有考虑以前的编译。

再次compile的后果?

如果再次编译模型,将会丢失优化器状态.

这意味着您的训练在开始时会受到一点影响,直到调整学习率,动量等为止。但是绝对不会对重量造成损害(除非您的初始学习率如此之大,以至于第一次训练步骤疯狂地更改微调的权重)。

到此这篇关于Keras保存模型并载入模型继续训练的实现的文章就介绍到这了,更多相关Keras保存模型并加载模型内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
教你如何在Django 1.6中正确使用 Signal
Jun 22 Python
Python手机号码归属地查询代码
May 04 Python
深入了解Python数据类型之列表
Jun 24 Python
python 排序算法总结及实例详解
Sep 28 Python
python Socket之客户端和服务端握手详解
Sep 18 Python
Python编程实现线性回归和批量梯度下降法代码实例
Jan 04 Python
python如何让类支持比较运算
Mar 20 Python
Django 表单模型选择框如何使用分组
May 16 Python
python set内置函数的具体使用
Jul 02 Python
简单了解python元组tuple相关原理
Dec 02 Python
使用pyqt 实现重复打开多个相同界面
Dec 13 Python
Python如何在windows环境安装pip及rarfile
Jun 15 Python
TensorFlow2.0使用keras训练模型的实现
Feb 20 #Python
tensorflow2.0教程之Keras快速入门
Feb 20 #Python
在Pycharm中安装Pandas库方法(简单易懂)
Feb 20 #Python
Python3爬虫RedisDump的安装步骤
Feb 20 #Python
python爬取2021猫眼票房字体加密实例
Feb 19 #Python
Python之Sklearn使用入门教程
Feb 19 #Python
Python爬虫UA伪装爬取的实例讲解
Feb 19 #Python
You might like
PHP strtotime函数详解
2009/12/18 PHP
删除无限分类并同时删除它下面的所有子分类的方法
2010/08/08 PHP
destoon复制新模块的方法
2014/06/21 PHP
php实现网站文件批量压缩下载功能
2015/10/28 PHP
php关键字仅替换一次的实现函数
2015/10/29 PHP
CodeIgniter扩展核心类实例详解
2016/01/20 PHP
深入理解php printf() 输出格式化的字符串
2016/05/23 PHP
php xhprof使用实例详解
2019/04/15 PHP
one.php 多项目、函数库、类库 统一为一个版本的方法
2020/08/24 PHP
在JS数组特定索引处指定位置插入元素的技巧
2014/08/24 Javascript
用C/C++来实现 Node.js 的模块(二)
2014/09/24 Javascript
jQuery下拉美化搜索表单效果代码分享
2015/08/25 Javascript
JS实现具备延时功能的滑动门菜单效果
2015/09/17 Javascript
Javascript6中字符串的四个新用法分享
2016/09/11 Javascript
footer定位页面底部(代码分享)
2017/03/07 Javascript
easyui-datagrid特殊字符不能显示的处理方法
2017/04/12 Javascript
AngularJS模态框模板ngDialog的使用详解
2018/05/11 Javascript
记录一次开发微信网页分享的步骤
2019/05/07 Javascript
JavaScript实现左右滚动电影画布
2020/02/06 Javascript
JavaScript常用工具函数库汇总
2020/09/17 Javascript
python使用WMI检测windows系统信息、硬盘信息、网卡信息的方法
2015/05/15 Python
利用PyCharm操作Github(仓库新建、更新,代码回滚)
2019/12/18 Python
瑞典首都斯德哥尔摩的多元奢侈时尚品牌:Acne Studios
2017/07/09 全球购物
德国高端单身人士交友网站:ElitePartner
2018/12/02 全球购物
自荐信格式简述
2014/01/25 职场文书
学生操行评语大全
2014/04/24 职场文书
安全施工标语
2014/06/07 职场文书
相亲活动方案
2014/08/26 职场文书
科学育儿宣传标语
2014/10/08 职场文书
四川省传达学习贯彻党的群众路线教育实践活动总结大会精神新闻稿
2014/10/26 职场文书
2014业务员年终工作总结
2014/12/09 职场文书
大三学生英语考试作弊检讨书
2015/01/01 职场文书
2015年国庆放假通知范文
2015/08/18 职场文书
《玩出了名堂》教学反思
2016/02/17 职场文书
python实现网络五子棋
2021/04/11 Python
Python语法学习之进程的创建与常用方法详解
2022/04/08 Python