tensorflow2.0保存和恢复模型3种方法


Posted in Python onFebruary 03, 2020

方法1:只保存模型的权重和偏置

这种方法不会保存整个网络的结构,只是保存模型的权重和偏置,所以在后期恢复模型之前,必须手动创建和之前模型一模一样的模型,以保证权重和偏置的维度和保存之前的相同。

tf.keras.model类中的save_weights方法和load_weights方法,参数解释我就直接搬运官网的内容了。

save_weights(
 filepath,
 overwrite=True,
 save_format=None
)

Arguments:

filepath: String, path to the file to save the weights to. When saving in TensorFlow format, this is the prefix used for checkpoint files (multiple files are generated). Note that the '.h5' suffix causes weights to be saved in HDF5 format.

overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt.

save_format: Either 'tf' or 'h5'. A filepath ending in '.h5' or '.keras' will default to HDF5 if save_format is None. Otherwise None defaults to 'tf'.

load_weights(
 filepath,
 by_name=False
)

实例1:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers
 
# step1 加载训练集和测试集合
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
 
 
# step2 创建模型
def create_model():
 return tf.keras.models.Sequential([
 tf.keras.layers.Flatten(input_shape=(28, 28)),
 tf.keras.layers.Dense(512, activation='relu'),
 tf.keras.layers.Dropout(0.2),
 tf.keras.layers.Dense(10, activation='softmax')
 ])
model = create_model()
 
# step3 编译模型 主要是确定优化方法,损失函数等
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step4 模型训练 训练一个epochs
model.fit(x=x_train,
  y=y_train,
  epochs=1,
  )
 
# step5 模型测试
loss, acc = model.evaluate(x_test, y_test)
print("train model, accuracy:{:5.2f}%".format(100 * acc))
 
# step6 保存模型的权重和偏置
model.save_weights('./save_weights/my_save_weights')
 
# step7 删除模型
del model
 
# step8 重新创建模型
model = create_model()
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step9 恢复权重
model.load_weights('./save_weights/my_save_weights')
 
# step10 测试模型
loss, acc = model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(100 * acc))

train model, accuracy:96.55%

Restored model, accuracy:96.55%

可以看到在模型的权重和偏置恢复之后,在测试集合上同样达到了训练之前相同的准确率。

方法2:直接保存整个模型

这种方法会将网络的结构,权重和优化器的状态等参数全部保存下来,后期恢复的时候就没必要创建新的网络了。

tf.keras.model类中的save方法和load_model方法

save(
 filepath,
 overwrite=True,
 include_optimizer=True,
 save_format=None
)

Arguments:

filepath: String, path to SavedModel or H5 file to save the model.

overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt.

include_optimizer: If True, save optimizer's state together.

save_format: Either 'tf' or 'h5', indicating whether to save the model to Tensorflow SavedModel or HDF5. The default is currently 'h5', but will switch to 'tf' in TensorFlow 2.0. The 'tf' option is currently disabled (use tf.keras.experimental.export_saved_model instead).

实例2:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers
 
 
# step1 加载训练集和测试集合
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
 
 
# step2 创建模型
def create_model():
 return tf.keras.models.Sequential([
 tf.keras.layers.Flatten(input_shape=(28, 28)),
 tf.keras.layers.Dense(512, activation='relu'),
 tf.keras.layers.Dropout(0.2),
 tf.keras.layers.Dense(10, activation='softmax')
 ])
model = create_model()
 
# step3 编译模型 主要是确定优化方法,损失函数等
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step4 模型训练 训练一个epochs
model.fit(x=x_train,
  y=y_train,
  epochs=1,
  )
 
# step5 模型测试
loss, acc = model.evaluate(x_test, y_test)
print("train model, accuracy:{:5.2f}%".format(100 * acc))
 
# step6 保存模型的权重和偏置
model.save('my_model.h5') # creates a HDF5 file 'my_model.h5'
 
# step7 删除模型
del model # deletes the existing model
 
 
# step8 恢复模型
# returns a compiled model
# identical to the previous one
restored_model = tf.keras.models.load_model('my_model.h5')
 
# step9 测试模型
loss, acc = restored_model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(100 * acc))

train model, accuracy:96.94%

Restored model, accuracy:96.94%

方法3:使用tf.keras.callbacks.ModelCheckpoint方法在训练过程中保存模型

该方法继承自tf.keras.callbacks类,一般配合mode.fit函数使用

以上这篇tensorflow2.0保存和恢复模型3种方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python sys模块sys.path使用方法示例
Dec 04 Python
php使用递归与迭代实现快速排序示例
Jan 23 Python
python处理cookie详解
Feb 07 Python
举例讲解Python中的死锁、可重入锁和互斥锁
Nov 05 Python
PyCharm在win10的64位系统安装实例
Nov 26 Python
30秒轻松实现TensorFlow物体检测
Mar 14 Python
PyTorch上实现卷积神经网络CNN的方法
Apr 28 Python
浅谈python 读excel数值为浮点型的问题
Dec 25 Python
Django Admin中增加导出CSV功能过程解析
Sep 04 Python
Python标准库:内置函数max(iterable, *[, key, default])说明
Apr 25 Python
基于Python-turtle库绘制路飞的草帽骷髅旗、美国队长的盾牌、高达的源码
Feb 18 Python
Python List remove()实例用法详解
Aug 02 Python
详解字符串在Python内部是如何省内存的
Feb 03 #Python
python自动化unittest yaml使用过程解析
Feb 03 #Python
Python类如何定义私有变量
Feb 03 #Python
python异常处理try except过程解析
Feb 03 #Python
利用Python脚本实现自动刷网课
Feb 03 #Python
tensorflow 限制显存大小的实现
Feb 03 #Python
基于tensorflow指定GPU运行及GPU资源分配的几种方式小结
Feb 03 #Python
You might like
php file_put_contents()功能函数(集成了fopen、fwrite、fclose)
2011/05/24 PHP
php数据结构与算法(PHP描述) 快速排序 quick sort
2012/06/21 PHP
ThinkPHP视图查询详解
2014/06/30 PHP
Yii2 中实现单点登录的方法
2018/03/09 PHP
js打印纸函数代码(递归)
2010/06/18 Javascript
jquery异步请求实例代码
2011/06/21 Javascript
元素未显示设置width/height时IE中使用currentStyle获取为auto
2014/05/04 Javascript
js 获取浏览器版本以此来调整CSS的样式
2014/06/03 Javascript
javascript生成大小写字母
2015/07/03 Javascript
浅析如何利用angular结合translate为项目实现国际化
2016/12/08 Javascript
bootstrap table之通用方法( 时间控件,导出,动态下拉框, 表单验证 ,选中与获取信息)代码分享
2017/01/24 Javascript
Vue2.0 v-for filter列表过滤功能的实现
2018/09/07 Javascript
JS实现获取数组中最大值或最小值功能示例
2019/03/02 Javascript
在JavaScript中使用严格模式(Strict Mode)
2019/06/13 Javascript
JavaScript 严格模式(use strict)用法实例分析
2020/03/04 Javascript
JavaScript编写开发动态时钟
2020/07/29 Javascript
编写v-for循环的技巧汇总
2020/12/01 Javascript
django自定义Field实现一个字段存储以逗号分隔的字符串
2014/04/27 Python
深入浅析Python中的yield关键字
2018/01/24 Python
python脚本生成caffe train_list.txt的方法
2018/04/27 Python
Python实现调用另一个路径下py文件中的函数方法总结
2018/06/07 Python
python3解析库lxml的安装与基本使用
2018/06/27 Python
详解django中使用定时任务的方法
2018/09/27 Python
python matplotlib库绘制散点图例题解析
2019/08/10 Python
Python测试模块doctest使用解析
2019/08/10 Python
Python实现打印实心和空心菱形
2019/11/23 Python
python 利用已有Ner模型进行数据清洗合并代码
2019/12/24 Python
Win10下用Anaconda安装TensorFlow(图文教程)
2020/06/18 Python
利用CSS3实现开门效果实例源码
2016/08/22 HTML / CSS
Huda Beauty官方商店:化妆和美容产品
2020/09/05 全球购物
旅游专业职业生涯规划范文
2014/01/13 职场文书
保健品市场营销方案
2014/03/31 职场文书
区域销售主管岗位职责
2014/06/15 职场文书
出纳年终工作总结2014
2014/12/05 职场文书
对领导班子的意见和建议
2015/06/08 职场文书
幼儿园庆元旦主持词
2015/07/06 职场文书