Keras保存模型并载入模型继续训练的实现


Posted in Python onFebruary 20, 2021

我们以MNIST手写数字识别为例

import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD
 
# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000,28,28)
print('x_shape:',x_train.shape)
# (60000)
print('y_shape:',y_train.shape)
# (60000,28,28)->(60000,784)
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)
 
# 创建模型,输入784个神经元,输出10个神经元
model = Sequential([
    Dense(units=10,input_dim=784,bias_initializer='one',activation='softmax')
  ])
 
# 定义优化器
sgd = SGD(lr=0.2)
 
# 定义优化器,loss function,训练过程中计算准确率
model.compile(
  optimizer = sgd,
  loss = 'mse',
  metrics=['accuracy'],
)
 
# 训练模型
model.fit(x_train,y_train,batch_size=64,epochs=5)
 
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
 
print('\ntest loss',loss)
print('accuracy',accuracy)
 
# 保存模型
model.save('model.h5')  # HDF5文件,pip install h5py

Keras保存模型并载入模型继续训练的实现

Keras保存模型并载入模型继续训练的实现

载入初次训练的模型,再训练

import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD
from keras.models import load_model
# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000,28,28)
print('x_shape:',x_train.shape)
# (60000)
print('y_shape:',y_train.shape)
# (60000,28,28)->(60000,784)
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)
 
# 载入模型
model = load_model('model.h5')
 
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
 
print('\ntest loss',loss)
print('accuracy',accuracy)
 
# 训练模型
model.fit(x_train,y_train,batch_size=64,epochs=2)
 
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
 
print('\ntest loss',loss)
print('accuracy',accuracy)
 
# 保存参数,载入参数
model.save_weights('my_model_weights.h5')
model.load_weights('my_model_weights.h5')
# 保存网络结构,载入网络结构
from keras.models import model_from_json
json_string = model.to_json()
model = model_from_json(json_string)
 
print(json_string)

关于compile和load_model()的使用顺序

这一段落主要是为了解决我们fit、evaluate、predict之前还是之后使用compile。想要弄明白,首先我们要清楚compile在程序中是做什么的?都做了什么?

compile做什么?

compile定义了loss function损失函数、optimizer优化器和metrics度量。它与权重无关,也就是说compile并不会影响权重,不会影响之前训练的问题。

如果我们要训练模型或者评估模型evaluate,则需要compile,因为训练要使用损失函数和优化器,评估要使用度量方法;如果我们要预测,则没有必要compile模型。

是否需要多次编译?

除非我们要更改其中之一:损失函数、优化器 / 学习率、度量

又或者我们加载了尚未编译的模型。或者您的加载/保存方法没有考虑以前的编译。

再次compile的后果?

如果再次编译模型,将会丢失优化器状态.

这意味着您的训练在开始时会受到一点影响,直到调整学习率,动量等为止。但是绝对不会对重量造成损害(除非您的初始学习率如此之大,以至于第一次训练步骤疯狂地更改微调的权重)。

到此这篇关于Keras保存模型并载入模型继续训练的实现的文章就介绍到这了,更多相关Keras保存模型并加载模型内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python抓取网页图片示例(python爬虫)
Apr 27 Python
python中的字典详细介绍
Sep 18 Python
python操作CouchDB的方法
Oct 08 Python
Python赋值语句后逗号的作用分析
Jun 08 Python
编写Python爬虫抓取豆瓣电影TOP100及用户头像的方法
Jan 20 Python
Python的Asyncore异步Socket模块及实现端口转发的例子
Jun 14 Python
python 阶乘累加和的实例
Feb 01 Python
python算法与数据结构之单链表的实现代码
Jun 27 Python
python pandas模块基础学习详解
Jul 03 Python
Python Process多进程实现过程
Oct 22 Python
Django1.11自带分页器paginator的使用方法
Oct 31 Python
python开发飞机大战游戏
Jul 15 Python
TensorFlow2.0使用keras训练模型的实现
Feb 20 #Python
tensorflow2.0教程之Keras快速入门
Feb 20 #Python
在Pycharm中安装Pandas库方法(简单易懂)
Feb 20 #Python
Python3爬虫RedisDump的安装步骤
Feb 20 #Python
python爬取2021猫眼票房字体加密实例
Feb 19 #Python
Python之Sklearn使用入门教程
Feb 19 #Python
Python爬虫UA伪装爬取的实例讲解
Feb 19 #Python
You might like
先进的自动咖啡技术,真的可以取代咖啡师吗?
2021/03/06 冲泡冲煮
jquery不支持toggle()高(新)版本的问题解决
2016/09/24 PHP
php+mysql+jquery实现日历签到功能
2017/02/27 PHP
laravel框架路由分组,中间件,命名空间,子域名,路由前缀实例分析
2020/02/18 PHP
PHP设计模式之 策略模式Strategy详解【对象行为型】
2020/05/01 PHP
jQuery学习笔记 操作jQuery对象 CSS处理
2012/09/19 Javascript
AngularJS iframe跨域打开内容时报错误的解决办法
2015/01/26 Javascript
如何使用jQuery技术开发ios风格的页面导航菜单
2015/07/29 Javascript
Windows 系统下设置Nodejs NPM全局路径
2016/04/26 NodeJs
微信小程序 教程之条件渲染
2016/10/18 Javascript
Bootstrap基本组件学习笔记之按钮组(8)
2016/12/07 Javascript
JavaScript实现两个select下拉框选项左移右移
2017/03/09 Javascript
jquery submit()不能提交表单的解决方法
2017/04/24 jQuery
浅谈JS函数节流防抖
2017/10/18 Javascript
在vue项目中使用sass的配置方法
2018/03/20 Javascript
Angular父组件调用子组件的方法
2018/04/02 Javascript
通过扫小程序码实现网站登陆功能
2019/08/22 Javascript
vue 动态生成拓扑图的示例
2021/01/03 Vue.js
如何在 Vue 表单中处理图片
2021/01/26 Vue.js
使用Python代码实现Linux中的ls遍历目录命令的实例代码
2019/09/07 Python
Django模板语言 Tags使用详解
2019/09/09 Python
Python编写打字训练小程序
2019/09/26 Python
python实现输出一个序列的所有子序列示例
2019/11/18 Python
解决python父线程关闭后子线程不关闭问题
2020/04/25 Python
css3 position fixed固定居中问题解决方案
2014/08/19 HTML / CSS
H5仿微信界面教程(一)
2017/07/05 HTML / CSS
H5离线存储Manifest原理及使用
2020/04/28 HTML / CSS
迪拜航空官方网站:flydubai
2017/04/20 全球购物
Ibood荷兰:互联网每日最佳在线优惠
2019/02/28 全球购物
英国奢侈品在线精品店:Hervia
2020/09/03 全球购物
C#和SQL Server的面试题
2016/08/12 面试题
Python里面如何拷贝一个对象
2014/02/17 面试题
学校食堂食品安全责任书
2014/07/28 职场文书
教师辞职书范文
2015/02/26 职场文书
出纳试用期自我评价
2015/03/10 职场文书
2015年教师节慰问信
2015/03/23 职场文书