浅谈keras保存模型中的save()和save_weights()区别


Posted in Python onMay 21, 2020

今天做了一个关于keras保存模型的实验,希望有助于大家了解keras保存模型的区别。

我们知道keras的模型一般保存为后缀名为h5的文件,比如final_model.h5。同样是h5文件用save()和save_weight()保存效果是不一样的。

我们用宇宙最通用的数据集MNIST来做这个实验,首先设计一个两层全连接网络:

inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)

然后,导入MNIST数据训练,分别用两种方式保存模型,在这里我还把未训练的模型也保存下来,如下:

from keras.models import Model
from keras.layers import Input, Dense
from keras.datasets import mnist
from keras.utils import np_utils
 
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train=x_train.reshape(x_train.shape[0],-1)/255.0
x_test=x_test.reshape(x_test.shape[0],-1)/255.0
y_train=np_utils.to_categorical(y_train,num_classes=10)
y_test=np_utils.to_categorical(y_test,num_classes=10)
 
inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)
 
model.save('m1.h5')
model.summary()
model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=32, epochs=10)
#loss,accuracy=model.evaluate(x_test,y_test)
 
model.save('m2.h5')
model.save_weights('m3.h5')

如上可见,我一共保存了m1.h5, m2.h5, m3.h5 这三个h5文件。那么,我们来看看这三个玩意儿有什么区别。首先,看看大小:

浅谈keras保存模型中的save()和save_weights()区别

m2表示save()保存的模型结果,它既保持了模型的图结构,又保存了模型的参数。所以它的size最大的。

m1表示save()保存的训练前的模型结果,它保存了模型的图结构,但应该没有保存模型的初始化参数,所以它的size要比m2小很多。

m3表示save_weights()保存的模型结果,它只保存了模型的参数,但并没有保存模型的图结构。所以它的size也要比m2小很多。

通过可视化工具,我们发现:(打开m1和m2均可以显示出以下结构)

浅谈keras保存模型中的save()和save_weights()区别

而打开m3的时候,可视化工具报错了。由此可以论证, save_weights()是不含有模型结构信息的。

加载模型

两种不同方法保存的模型文件也需要用不同的加载方法。

from keras.models import load_model
 
model = load_model('m1.h5')
#model = load_model('m2.h5')
#model = load_model('m3.h5')
model.summary()

只有加载m3.h5的时候,这段代码才会报错。其他输出如下:

浅谈keras保存模型中的save()和save_weights()区别

可见,由save()保存下来的h5文件才可以直接通过load_model()打开!

那么,我们保存下来的参数(m3.h5)该怎么打开呢?

这就稍微复杂一点了,因为m3不含有模型结构信息,所以我们需要把模型结构再描述一遍才可以加载m3,如下:

from keras.models import Model
from keras.layers import Input, Dense
 
inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)
model.load_weights('m3.h5')

以上把m3换成m1和m2也是没有问题的!可见,save()保存的模型除了占用内存大一点以外,其他的优点太明显了。所以,在不怎么缺硬盘空间的情况下,还是建议大家多用save()来存。

注意!如果要load_weights(),必须保证你描述的有参数计算结构与h5文件中完全一致!什么叫有参数计算结构呢?就是有参数坑,直接填进去就行了。我们把上面的非参数结构换了一下,发现h5文件依然可以加载成功,比如将softmax换成relu,依然不影响加载。

对于keras的save()和save_weights(),完全没问题了吧

以上这篇浅谈keras保存模型中的save()和save_weights()区别就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 中的列表解析和生成表达式
Mar 10 Python
35个Python编程小技巧
Apr 01 Python
介绍Python中的一些高级编程技巧
Apr 02 Python
Python中的异常处理简明介绍
Apr 13 Python
python中偏函数partial用法实例分析
Jul 08 Python
python监控文件或目录变化
Jun 07 Python
Python设计模式之代理模式简单示例
Jan 09 Python
Python搭建FTP服务器的方法示例
Jan 19 Python
CentOS7下python3.7.0安装教程
Jul 30 Python
Python笔试面试题小结
Sep 07 Python
Python实现自动签到脚本功能
Aug 20 Python
Python 实现Mac 屏幕截图详解
Oct 05 Python
Python通过文本和图片生成词云图
May 21 #Python
解决在keras中使用model.save()函数保存模型失败的问题
May 21 #Python
Python 实现敏感目录扫描的示例代码
May 21 #Python
基于python检查矩阵计算结果
May 21 #Python
Django 解决由save方法引发的错误
May 21 #Python
Python OrderedDict字典排序方法详解
May 21 #Python
django中嵌套的try-except实例
May 21 #Python
You might like
mysql4.1以上版本连接时出现Client does not support authentication protocol问题解决办法
2007/03/15 PHP
WordPress中限制非管理员用户在文章后只能评论一次
2015/12/31 PHP
PHP 接入支付宝即时到账功能
2016/09/18 PHP
json数据处理技巧(字段带空格、增加字段、排序等等)
2013/06/14 Javascript
JavaScript DOM 编程艺术(第2版)读书笔记(JavaScript的最佳实践)
2013/10/01 Javascript
jquery中ajax函数执行顺序问题之如何设置同步
2014/02/28 Javascript
javascript实现多级联动下拉菜单的方法
2015/02/06 Javascript
jQuery实现带动画效果的二级下拉导航方法
2015/03/11 Javascript
jQuery判断一个元素是否可见的方法
2015/06/05 Javascript
基于JQuery实现分隔条的功能
2016/06/17 Javascript
JavaScript常见的五种数组去重的方式
2016/12/15 Javascript
理解javascript中的Function.prototype.bind的方法
2017/02/03 Javascript
从零开始学习Node.js系列教程五:服务器监听方法示例
2017/04/13 Javascript
JS SetInterval 代码实现页面轮询
2017/08/11 Javascript
js实现可以点击收缩或张开的悬浮窗
2017/09/18 Javascript
浅谈Vue响应式(数组变异方法)
2018/05/07 Javascript
AngularJS ui-router刷新子页面路由的方法
2018/07/23 Javascript
jquery.pagination.js分页使用教程
2018/10/23 jQuery
jQuery 筛选器简单操作示例
2019/10/02 jQuery
vue页面更新patch的实现示例
2020/03/25 Javascript
JS中的变量作用域(console版)
2020/07/18 Javascript
Python 命令行非阻塞输入的小例子
2013/09/27 Python
Python实现遍历windows所有窗口并输出窗口标题的方法
2015/03/13 Python
连接Python程序与MySQL的教程
2015/04/29 Python
Python写的一个简单监控系统
2015/06/19 Python
python重试装饰器的简单实现方法
2019/01/31 Python
详解Django admin高级用法
2019/11/06 Python
在Python中使用filter去除列表中值为假及空字符串的例子
2019/11/18 Python
python实现数据清洗(缺失值与异常值处理)
2019/12/02 Python
基于Python-Pycharm实现的猴子摘桃小游戏(源代码)
2021/02/20 Python
New Balance天猫官方旗舰店:始于1906年,百年慢跑品牌
2017/11/15 全球购物
澳大利亚自然和有机的健康美容产品一站式商店:Ziani Beauty
2017/12/28 全球购物
部队2015年终工作总结
2015/04/02 职场文书
2015年小学教科研工作总结
2015/07/20 职场文书
高中诗歌鉴赏教学反思
2016/02/16 职场文书
Mysql中 unique列插入重复值该怎么解决呢
2021/05/26 MySQL