tensorflow2.0保存和恢复模型3种方法


Posted in Python onFebruary 03, 2020

方法1:只保存模型的权重和偏置

这种方法不会保存整个网络的结构,只是保存模型的权重和偏置,所以在后期恢复模型之前,必须手动创建和之前模型一模一样的模型,以保证权重和偏置的维度和保存之前的相同。

tf.keras.model类中的save_weights方法和load_weights方法,参数解释我就直接搬运官网的内容了。

save_weights(
 filepath,
 overwrite=True,
 save_format=None
)

Arguments:

filepath: String, path to the file to save the weights to. When saving in TensorFlow format, this is the prefix used for checkpoint files (multiple files are generated). Note that the '.h5' suffix causes weights to be saved in HDF5 format.

overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt.

save_format: Either 'tf' or 'h5'. A filepath ending in '.h5' or '.keras' will default to HDF5 if save_format is None. Otherwise None defaults to 'tf'.

load_weights(
 filepath,
 by_name=False
)

实例1:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers
 
# step1 加载训练集和测试集合
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
 
 
# step2 创建模型
def create_model():
 return tf.keras.models.Sequential([
 tf.keras.layers.Flatten(input_shape=(28, 28)),
 tf.keras.layers.Dense(512, activation='relu'),
 tf.keras.layers.Dropout(0.2),
 tf.keras.layers.Dense(10, activation='softmax')
 ])
model = create_model()
 
# step3 编译模型 主要是确定优化方法,损失函数等
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step4 模型训练 训练一个epochs
model.fit(x=x_train,
  y=y_train,
  epochs=1,
  )
 
# step5 模型测试
loss, acc = model.evaluate(x_test, y_test)
print("train model, accuracy:{:5.2f}%".format(100 * acc))
 
# step6 保存模型的权重和偏置
model.save_weights('./save_weights/my_save_weights')
 
# step7 删除模型
del model
 
# step8 重新创建模型
model = create_model()
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step9 恢复权重
model.load_weights('./save_weights/my_save_weights')
 
# step10 测试模型
loss, acc = model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(100 * acc))

train model, accuracy:96.55%

Restored model, accuracy:96.55%

可以看到在模型的权重和偏置恢复之后,在测试集合上同样达到了训练之前相同的准确率。

方法2:直接保存整个模型

这种方法会将网络的结构,权重和优化器的状态等参数全部保存下来,后期恢复的时候就没必要创建新的网络了。

tf.keras.model类中的save方法和load_model方法

save(
 filepath,
 overwrite=True,
 include_optimizer=True,
 save_format=None
)

Arguments:

filepath: String, path to SavedModel or H5 file to save the model.

overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt.

include_optimizer: If True, save optimizer's state together.

save_format: Either 'tf' or 'h5', indicating whether to save the model to Tensorflow SavedModel or HDF5. The default is currently 'h5', but will switch to 'tf' in TensorFlow 2.0. The 'tf' option is currently disabled (use tf.keras.experimental.export_saved_model instead).

实例2:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers
 
 
# step1 加载训练集和测试集合
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
 
 
# step2 创建模型
def create_model():
 return tf.keras.models.Sequential([
 tf.keras.layers.Flatten(input_shape=(28, 28)),
 tf.keras.layers.Dense(512, activation='relu'),
 tf.keras.layers.Dropout(0.2),
 tf.keras.layers.Dense(10, activation='softmax')
 ])
model = create_model()
 
# step3 编译模型 主要是确定优化方法,损失函数等
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step4 模型训练 训练一个epochs
model.fit(x=x_train,
  y=y_train,
  epochs=1,
  )
 
# step5 模型测试
loss, acc = model.evaluate(x_test, y_test)
print("train model, accuracy:{:5.2f}%".format(100 * acc))
 
# step6 保存模型的权重和偏置
model.save('my_model.h5') # creates a HDF5 file 'my_model.h5'
 
# step7 删除模型
del model # deletes the existing model
 
 
# step8 恢复模型
# returns a compiled model
# identical to the previous one
restored_model = tf.keras.models.load_model('my_model.h5')
 
# step9 测试模型
loss, acc = restored_model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(100 * acc))

train model, accuracy:96.94%

Restored model, accuracy:96.94%

方法3:使用tf.keras.callbacks.ModelCheckpoint方法在训练过程中保存模型

该方法继承自tf.keras.callbacks类,一般配合mode.fit函数使用

以上这篇tensorflow2.0保存和恢复模型3种方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python读取Android permission文件
Nov 01 Python
Python通过PIL获取图片主要颜色并和颜色库进行对比的方法
Mar 19 Python
Python实现读取邮箱中的邮件功能示例【含文本及附件】
Aug 05 Python
在python3环境下的Django中使用MySQL数据库的实例
Aug 29 Python
浅谈Python实现2种文件复制的方法
Jan 19 Python
Python3.6通过自带的urllib通过get或post方法请求url的实例
May 10 Python
python自动截取需要区域,进行图像识别的方法
May 17 Python
python3.6的venv模块使用详解
Aug 01 Python
对python中词典的values值的修改或新增KEY详解
Jan 20 Python
Python实现删除排序数组中重复项的两种方法示例
Jan 31 Python
python实现交并比IOU教程
Apr 16 Python
python实现每天自动签到领积分的示例代码
Aug 18 Python
详解字符串在Python内部是如何省内存的
Feb 03 #Python
python自动化unittest yaml使用过程解析
Feb 03 #Python
Python类如何定义私有变量
Feb 03 #Python
python异常处理try except过程解析
Feb 03 #Python
利用Python脚本实现自动刷网课
Feb 03 #Python
tensorflow 限制显存大小的实现
Feb 03 #Python
基于tensorflow指定GPU运行及GPU资源分配的几种方式小结
Feb 03 #Python
You might like
PHP字符串word末字符实现大小写互换的方法
2014/11/10 PHP
如何在HTML 中嵌入 PHP 代码
2015/05/13 PHP
老司机传授Ubuntu下Apache+PHP+MySQL环境搭建攻略
2016/03/20 PHP
PHP+AjaxForm异步带进度条上传文件实例代码
2017/08/14 PHP
PHP中关于php.ini参数优化详解
2020/02/28 PHP
JQury slideToggle闪烁问题及解决办法
2011/07/05 Javascript
Extjs4 消息框去掉关闭按钮(类似Ext.Msg.alert)
2013/04/02 Javascript
jQuery之ajax删除详解
2014/02/27 Javascript
JS调试必备的5个debug技巧
2014/03/07 Javascript
nodejs教程之异步I/O
2014/11/21 NodeJs
4种JavaScript实现简单tab选项卡切换的方法
2016/01/06 Javascript
JS实现将Asp.Net的DateTime Json类型转换为标准时间的方法
2016/08/02 Javascript
jQuery上传多张图片带进度条样式(DEMO)
2017/03/02 Javascript
JavaScript字符串检索字符的方法
2017/06/23 Javascript
Vue利用canvas实现移动端手写板的方法
2018/05/03 Javascript
vue-cli3 karma单元测试的实现
2019/01/18 Javascript
js如何获取图片url的Blob值并预览示例代码
2019/03/07 Javascript
12个提高JavaScript技能的概念(小结)
2019/05/09 Javascript
浅谈一种让小程序支持JSX语法的新思路
2019/06/16 Javascript
浅谈layer的Icon样式以及一些常用的layer窗口使用方法
2019/09/11 Javascript
JS关闭子窗口并且刷新上一个窗口的实现示例
2020/03/10 Javascript
微信小程序自定义扫码功能界面的实现代码
2020/07/02 Javascript
解决vue加scoped后就无法修改vant的UI组件的样式问题
2020/09/07 Javascript
python双向链表实现实例代码
2013/11/21 Python
python模块简介之有序字典(OrderedDict)
2016/12/01 Python
Tensorflow使用支持向量机拟合线性回归
2018/09/07 Python
Python调用.NET库的方法步骤
2019/12/27 Python
关于初始种子自动选取的区域生长实例(python+opencv)
2020/01/16 Python
Tensorflow 卷积的梯度反向传播过程
2020/02/10 Python
Python如何使用turtle库绘制图形
2020/02/26 Python
速比涛英国官网:Speedo英国
2019/07/15 全球购物
竞选学生会主席演讲稿
2014/04/24 职场文书
小学教师师德承诺书
2014/05/23 职场文书
践行三严三实心得体会
2014/10/13 职场文书
趵突泉导游词
2015/02/03 职场文书
JS高级程序设计之class继承重点详解
2022/07/07 Javascript