浅谈keras保存模型中的save()和save_weights()区别


Posted in Python onMay 21, 2020

今天做了一个关于keras保存模型的实验,希望有助于大家了解keras保存模型的区别。

我们知道keras的模型一般保存为后缀名为h5的文件,比如final_model.h5。同样是h5文件用save()和save_weight()保存效果是不一样的。

我们用宇宙最通用的数据集MNIST来做这个实验,首先设计一个两层全连接网络:

inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)

然后,导入MNIST数据训练,分别用两种方式保存模型,在这里我还把未训练的模型也保存下来,如下:

from keras.models import Model
from keras.layers import Input, Dense
from keras.datasets import mnist
from keras.utils import np_utils
 
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train=x_train.reshape(x_train.shape[0],-1)/255.0
x_test=x_test.reshape(x_test.shape[0],-1)/255.0
y_train=np_utils.to_categorical(y_train,num_classes=10)
y_test=np_utils.to_categorical(y_test,num_classes=10)
 
inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)
 
model.save('m1.h5')
model.summary()
model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=32, epochs=10)
#loss,accuracy=model.evaluate(x_test,y_test)
 
model.save('m2.h5')
model.save_weights('m3.h5')

如上可见,我一共保存了m1.h5, m2.h5, m3.h5 这三个h5文件。那么,我们来看看这三个玩意儿有什么区别。首先,看看大小:

浅谈keras保存模型中的save()和save_weights()区别

m2表示save()保存的模型结果,它既保持了模型的图结构,又保存了模型的参数。所以它的size最大的。

m1表示save()保存的训练前的模型结果,它保存了模型的图结构,但应该没有保存模型的初始化参数,所以它的size要比m2小很多。

m3表示save_weights()保存的模型结果,它只保存了模型的参数,但并没有保存模型的图结构。所以它的size也要比m2小很多。

通过可视化工具,我们发现:(打开m1和m2均可以显示出以下结构)

浅谈keras保存模型中的save()和save_weights()区别

而打开m3的时候,可视化工具报错了。由此可以论证, save_weights()是不含有模型结构信息的。

加载模型

两种不同方法保存的模型文件也需要用不同的加载方法。

from keras.models import load_model
 
model = load_model('m1.h5')
#model = load_model('m2.h5')
#model = load_model('m3.h5')
model.summary()

只有加载m3.h5的时候,这段代码才会报错。其他输出如下:

浅谈keras保存模型中的save()和save_weights()区别

可见,由save()保存下来的h5文件才可以直接通过load_model()打开!

那么,我们保存下来的参数(m3.h5)该怎么打开呢?

这就稍微复杂一点了,因为m3不含有模型结构信息,所以我们需要把模型结构再描述一遍才可以加载m3,如下:

from keras.models import Model
from keras.layers import Input, Dense
 
inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)
model.load_weights('m3.h5')

以上把m3换成m1和m2也是没有问题的!可见,save()保存的模型除了占用内存大一点以外,其他的优点太明显了。所以,在不怎么缺硬盘空间的情况下,还是建议大家多用save()来存。

注意!如果要load_weights(),必须保证你描述的有参数计算结构与h5文件中完全一致!什么叫有参数计算结构呢?就是有参数坑,直接填进去就行了。我们把上面的非参数结构换了一下,发现h5文件依然可以加载成功,比如将softmax换成relu,依然不影响加载。

对于keras的save()和save_weights(),完全没问题了吧

以上这篇浅谈keras保存模型中的save()和save_weights()区别就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python代码的打包与发布详解
Jul 30 Python
Python中使用异常处理来判断运行的操作系统平台方法
Jan 22 Python
Python中死锁的形成示例及死锁情况的防止
Jun 14 Python
Python reduce()函数的用法小结
Nov 15 Python
详解如何利用Cython为Python代码加速
Jan 27 Python
python3使用SMTP发送HTML格式邮件
Jun 19 Python
Python给定一个句子倒序输出单词以及字母的方法
Dec 20 Python
Django 中自定义 Admin 样式与功能的实现方法
Jul 04 Python
python实现图片横向和纵向拼接
Mar 05 Python
Python实现鼠标自动在屏幕上随机移动功能
Mar 14 Python
python实现图像全景拼接
Mar 27 Python
python文件读取失败怎么处理
Jun 23 Python
Python通过文本和图片生成词云图
May 21 #Python
解决在keras中使用model.save()函数保存模型失败的问题
May 21 #Python
Python 实现敏感目录扫描的示例代码
May 21 #Python
基于python检查矩阵计算结果
May 21 #Python
Django 解决由save方法引发的错误
May 21 #Python
Python OrderedDict字典排序方法详解
May 21 #Python
django中嵌套的try-except实例
May 21 #Python
You might like
PHP 数组入门教程小结
2009/05/20 PHP
PHP使用DOMDocument类生成HTML实例(包含常见标签元素)
2014/06/25 PHP
thinkphp ajaxfileupload实现异步上传图片的示例
2017/08/28 PHP
PHP大文件分割上传 PHP分片上传
2017/08/28 PHP
PHP开发实现快递查询功能详解
2019/04/08 PHP
javascript fullscreen全屏实现代码
2009/04/09 Javascript
javascript encodeURI和encodeURIComponent的比较
2010/04/03 Javascript
functional继承模式 摘自javascript:the good parts
2011/06/20 Javascript
JS高级笔记
2011/07/13 Javascript
JavaScript和CSS交互的方法汇总
2014/12/02 Javascript
jQuery EasyUi实战教程之布局篇
2016/01/26 Javascript
Vue制作Todo List网页
2017/04/26 Javascript
捕获未处理的Promise错误方法
2017/10/13 Javascript
ES6之模版字符串的具体使用
2018/05/17 Javascript
MVVM框架下实现分页功能示例
2018/06/14 Javascript
浅谈JavaScript中this的指向更改
2020/07/28 Javascript
[06:45]DOTA2-DPC中国联赛 正赛 Magma vs LBZS 选手采访
2021/03/11 DOTA
Python enumerate遍历数组示例应用
2008/09/06 Python
Python 使用os.remove删除文件夹时报错的解决方法
2017/01/13 Python
python  创建一个保留重复值的列表的补码
2018/10/15 Python
python实现Dijkstra算法的最短路径问题
2019/06/21 Python
Django处理Ajax发送的Get请求代码详解
2019/07/29 Python
python实现加密的方式总结
2020/01/19 Python
Python 执行矩阵与线性代数运算
2020/08/01 Python
Python classmethod装饰器原理及用法解析
2020/10/17 Python
详解Html5 Canvas画线有毛边解决方法
2018/03/01 HTML / CSS
html5在移动端的屏幕适应问题示例探讨
2014/06/15 HTML / CSS
Web前端页面跳转并取到值
2017/04/24 HTML / CSS
小学教师师德反思
2014/02/03 职场文书
构建高效课堂实施方案
2014/03/13 职场文书
教师个人自我评价范文
2014/04/13 职场文书
工伤事故赔偿协议书
2014/04/15 职场文书
会计毕业生自荐书
2014/06/12 职场文书
关于感恩的演讲稿200字
2014/08/26 职场文书
个人查摆问题自查报告
2014/10/16 职场文书
化验室岗位职责
2015/02/14 职场文书