浅谈keras保存模型中的save()和save_weights()区别


Posted in Python onMay 21, 2020

今天做了一个关于keras保存模型的实验,希望有助于大家了解keras保存模型的区别。

我们知道keras的模型一般保存为后缀名为h5的文件,比如final_model.h5。同样是h5文件用save()和save_weight()保存效果是不一样的。

我们用宇宙最通用的数据集MNIST来做这个实验,首先设计一个两层全连接网络:

inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)

然后,导入MNIST数据训练,分别用两种方式保存模型,在这里我还把未训练的模型也保存下来,如下:

from keras.models import Model
from keras.layers import Input, Dense
from keras.datasets import mnist
from keras.utils import np_utils
 
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train=x_train.reshape(x_train.shape[0],-1)/255.0
x_test=x_test.reshape(x_test.shape[0],-1)/255.0
y_train=np_utils.to_categorical(y_train,num_classes=10)
y_test=np_utils.to_categorical(y_test,num_classes=10)
 
inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)
 
model.save('m1.h5')
model.summary()
model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=32, epochs=10)
#loss,accuracy=model.evaluate(x_test,y_test)
 
model.save('m2.h5')
model.save_weights('m3.h5')

如上可见,我一共保存了m1.h5, m2.h5, m3.h5 这三个h5文件。那么,我们来看看这三个玩意儿有什么区别。首先,看看大小:

浅谈keras保存模型中的save()和save_weights()区别

m2表示save()保存的模型结果,它既保持了模型的图结构,又保存了模型的参数。所以它的size最大的。

m1表示save()保存的训练前的模型结果,它保存了模型的图结构,但应该没有保存模型的初始化参数,所以它的size要比m2小很多。

m3表示save_weights()保存的模型结果,它只保存了模型的参数,但并没有保存模型的图结构。所以它的size也要比m2小很多。

通过可视化工具,我们发现:(打开m1和m2均可以显示出以下结构)

浅谈keras保存模型中的save()和save_weights()区别

而打开m3的时候,可视化工具报错了。由此可以论证, save_weights()是不含有模型结构信息的。

加载模型

两种不同方法保存的模型文件也需要用不同的加载方法。

from keras.models import load_model
 
model = load_model('m1.h5')
#model = load_model('m2.h5')
#model = load_model('m3.h5')
model.summary()

只有加载m3.h5的时候,这段代码才会报错。其他输出如下:

浅谈keras保存模型中的save()和save_weights()区别

可见,由save()保存下来的h5文件才可以直接通过load_model()打开!

那么,我们保存下来的参数(m3.h5)该怎么打开呢?

这就稍微复杂一点了,因为m3不含有模型结构信息,所以我们需要把模型结构再描述一遍才可以加载m3,如下:

from keras.models import Model
from keras.layers import Input, Dense
 
inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)
model.load_weights('m3.h5')

以上把m3换成m1和m2也是没有问题的!可见,save()保存的模型除了占用内存大一点以外,其他的优点太明显了。所以,在不怎么缺硬盘空间的情况下,还是建议大家多用save()来存。

注意!如果要load_weights(),必须保证你描述的有参数计算结构与h5文件中完全一致!什么叫有参数计算结构呢?就是有参数坑,直接填进去就行了。我们把上面的非参数结构换了一下,发现h5文件依然可以加载成功,比如将softmax换成relu,依然不影响加载。

对于keras的save()和save_weights(),完全没问题了吧

以上这篇浅谈keras保存模型中的save()和save_weights()区别就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python的Tqdm模块的使用
Jan 10 Python
python数字图像处理实现直方图与均衡化
May 04 Python
python try except 捕获所有异常的实例
Oct 18 Python
修改python plot折线图的坐标轴刻度方法
Dec 13 Python
Python编写打字训练小程序
Sep 26 Python
python3 webp转gif格式的实现示例
Dec 10 Python
Numpy与Pytorch 矩阵操作方式
Dec 27 Python
Python+OpenCV实现旋转文本校正方式
Jan 09 Python
python中wx模块的具体使用方法
May 15 Python
Python3实现建造者模式的示例代码
Jun 28 Python
Django正则URL匹配实现流程解析
Nov 13 Python
Python进阶学习之带你探寻Python类的鼻祖-元类
May 08 Python
Python通过文本和图片生成词云图
May 21 #Python
解决在keras中使用model.save()函数保存模型失败的问题
May 21 #Python
Python 实现敏感目录扫描的示例代码
May 21 #Python
基于python检查矩阵计算结果
May 21 #Python
Django 解决由save方法引发的错误
May 21 #Python
Python OrderedDict字典排序方法详解
May 21 #Python
django中嵌套的try-except实例
May 21 #Python
You might like
简单的php写入数据库类代码分享
2011/07/26 PHP
解析php时间戳与日期的转换
2013/06/06 PHP
php echo, print, print_r, sprintf, var_dump, var_expor的使用区别
2013/06/20 PHP
Smarty foreach控制循环次数的实现详解
2013/07/03 PHP
php+ajax实现的点击浏览量加1
2015/04/16 PHP
ThinkPHP在Cli模式下使用模板引擎的方法
2015/09/25 PHP
ThinkPHP实现分页功能
2017/04/28 PHP
用jquery实现自定义风格的滑动条实现代码
2011/04/26 Javascript
循环 vs 递归浅谈
2013/02/28 Javascript
详解js运算符单竖杠“|”与“||”的用法和作用介绍
2016/11/04 Javascript
AngularJs上传前预览图片的实例代码
2017/01/20 Javascript
Angular实现图片裁剪工具ngImgCrop实践
2017/08/17 Javascript
详解vue引入子组件方法
2019/02/12 Javascript
vue实现的请求服务器端API接口示例
2019/05/25 Javascript
Vue代码整洁之去重方法整理
2019/08/06 Javascript
Vue实现图片与文字混输效果
2019/12/04 Javascript
javascript使用Blob对象实现的下载文件操作示例
2020/04/18 Javascript
js实现页面导航层级指示效果
2020/08/25 Javascript
[02:56]DOTA2矮人直升机 英雄基础教程
2013/11/26 DOTA
Python3实现的腾讯微博自动发帖小工具
2013/11/11 Python
遍历python字典几种方法总结(推荐)
2016/09/11 Python
Python3获取拉勾网招聘信息的方法实例
2019/04/03 Python
python实现DEM数据的阴影生成的方法
2019/07/23 Python
termux中matplotlib无法显示中文问题的解决方法
2021/01/11 Python
Pytorch之扩充tensor的操作
2021/03/04 Python
某公司部分笔试题
2013/11/05 面试题
2014年三八妇女节活动方案
2014/02/28 职场文书
诚信考试承诺书
2014/03/27 职场文书
自强自立美德少年事迹材料
2014/08/16 职场文书
2015年新农合工作总结
2015/03/30 职场文书
兴趣班停课通知
2015/04/24 职场文书
刑事撤诉申请书
2015/05/18 职场文书
2016年小学生迎国庆广播稿
2015/12/18 职场文书
党员干部学习心得体会
2016/01/23 职场文书
分享15个Webpack实用的插件!!!
2021/03/31 Javascript
Django+Celery实现定时任务的示例
2021/06/23 Python