TensorFlow实现模型断点训练,checkpoint模型载入方式


Posted in Python onMay 26, 2020

深度学习中,模型训练一般都需要很长的时间,由于很多原因,导致模型中断训练,下面介绍继续断点训练的方法。

方法一:载入模型时,不必指定迭代次数,一般默认最新

# 保存模型
saver = tf.train.Saver(max_to_keep=1) # 最多保留最新的模型
 
# 开启会话
with tf.Session() as sess:
 # saver.restore(sess, './log/' + "model_savemodel.cpkt-" + str(20000))
 sess.run(tf.global_variables_initializer())
 ckpt = tf.train.get_checkpoint_state('./log/') # 注意此处是checkpoint存在的目录,千万不要写成‘./log'
 if ckpt and ckpt.model_checkpoint_path:
 saver.restore(sess,ckpt.model_checkpoint_path) # 自动恢复model_checkpoint_path保存模型一般是最新
 print("Model restored...")
 else:
 print('No Model')

方法二:载入时,指定想要载入模型的迭代次数

需要到Log文件夹中,查看当前迭代的次数,如下:此时为111000次。

TensorFlow实现模型断点训练,checkpoint模型载入方式

# 保存模型
saver = tf.train.Saver(max_to_keep=1)
# 开启会话
 
with tf.Session() as sess:
 saver.restore(sess, './log/' + "model_savemodel.cpkt-" + str(111000))
 sess.run(tf.global_variables_initializer())

载入模型后,会继续端点处的变量继续训练,那么是否可以减小剩余的需要的迭代次数?

模型断点训练效果展示:

训练到167000次后,载入模型重新训练。设置迭代次数为10000次,(d_step=1000)。原始设置的迭代的次数为1000000,已经训练了167000次。

Model restored...
Iter:0, D_loss:0.5139875411987305, G_loss:2.8023970127105713
Iter:1000, D_loss:0.4400891065597534, G_loss:2.781547784805298
Iter:2000, D_loss:0.5169454216957092, G_loss:2.58009934425354
Iter:3000, D_loss:0.4507023096084595, G_loss:2.584151268005371
Iter:4000, D_loss:0.5746167898178101, G_loss:2.5365757942199707
Iter:5000, D_loss:0.5288565158843994, G_loss:2.426676034927368
Iter:6000, D_loss:0.549595057964325, G_loss:2.820535659790039
Iter:7000, D_loss:0.32620012760162354, G_loss:2.540236473083496
Iter:8000, D_loss:0.4363398551940918, G_loss:2.5880446434020996
Iter:9000, D_loss:0.569464921951294, G_loss:2.5133447647094727
done!

保存的图片仍然从头开始编号,会覆盖掉之前的图片。

TensorFlow实现模型断点训练,checkpoint模型载入方式

以前对应编号的采样图片为:

TensorFlow实现模型断点训练,checkpoint模型载入方式

若有朋友有高见,还请不吝赐教。

补充知识:tensorflow加载训练好的模型及参数(读取checkpoint)

checkpoint 保存路径

model_path下存有包含多个迭代次数的模型

TensorFlow实现模型断点训练,checkpoint模型载入方式

1.获取最新保存的模型

即上图中的model-9400

import tensorflow as tf

graph=tf.get_default_graph()  # 获取当前图
sess=tf.Session()
sess.run(tf.global_variables_initializer())

checkpoint_file=tf.train.latest_checkpoint(model_path)
saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
saver.restore(sess,checkpoint_file)

2.获取某个迭代次数的模型

比如上图中的model-9200

import tensorflow as tf

graph=tf.get_default_graph()  # 获取当前图
sess=tf.Session()
sess.run(tf.global_variables_initializer())

checkpoint_file=os.path.join(model_path,'model-9200')
saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
saver.restore(sess,checkpoint_file)

获取变量值

## 得到当前图中所有变量的名称
tensor_name_list=[tensor.name for tensor in graph.as_graph_def().node] 
# 查看所有变量
print(tensor_name_list) 

# 获取input_x和input_y的变量值
input_x = graph.get_operation_by_name("input_x").outputs[0]
input_y = graph.get_operation_by_name("input_y").outputs[0]

以上这篇TensorFlow实现模型断点训练,checkpoint模型载入方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 正则表达式 概述及常用字符
May 04 Python
python基础教程之获取本机ip数据包示例
Feb 10 Python
python完成FizzBuzzWhizz问题(拉勾网面试题)示例
May 05 Python
从Python的源码浅要剖析Python的内存管理
Apr 16 Python
python简单的函数定义和用法实例
May 07 Python
Python并发编程协程(Coroutine)之Gevent详解
Dec 27 Python
Python日志模块logging基本用法分析
Aug 23 Python
python实现简单图书管理系统
Nov 22 Python
python将数组n等分的实例
Dec 02 Python
基于numpy中的expand_dims函数用法
Dec 18 Python
selenium WebDriverWait类等待机制的实现
Mar 18 Python
基于pycharm 项目和项目文件命名规则的介绍
Jan 15 Python
python 日志模块 日志等级设置失效的解决方案
May 26 #Python
python3.7+selenium模拟淘宝登录功能的实现
May 26 #Python
TensorFlow固化模型的实现操作
May 26 #Python
Python 如何批量更新已安装的库
May 26 #Python
tensorflow 20:搭网络,导出模型,运行模型的实例
May 26 #Python
Python自定义聚合函数merge与transform区别详解
May 26 #Python
Python Tornado实现WEB服务器Socket服务器共存并实现交互的方法
May 26 #Python
You might like
967 个函式
2006/10/09 PHP
PHP导航下拉菜单的实现如此简单
2013/09/22 PHP
Android AsyncTack 异步任务实例详解
2016/11/02 PHP
PHP环境搭建(php+Apache+mysql)
2016/11/14 PHP
JS的反射问题
2010/04/07 Javascript
Dojo 学习要点
2010/09/03 Javascript
Jquery ajax传递复杂参数给WebService的实现代码
2011/08/08 Javascript
JS图片预加载 JS实现图片预加载应用
2012/12/03 Javascript
JS将表单导出成EXCEL的实例代码
2013/11/11 Javascript
AngularJS入门教程之Hello World!
2014/12/06 Javascript
jQuery设置和移除文本框默认值的方法
2015/03/09 Javascript
jQuery标签编辑插件Tagit使用指南
2015/04/21 Javascript
jQuery实现分隔条左右拖动功能
2015/11/21 Javascript
JavaScript字符串常用的方法
2016/03/10 Javascript
jQuery代码实现图片墙自动+手动淡入淡出切换效果
2016/05/09 Javascript
JS禁止查看网页源代码的实现方法
2016/10/12 Javascript
根据Bootstrap Paginator改写的js分页插件
2016/12/25 Javascript
如何利用JQuery实现从底部回到顶部的功能
2016/12/27 Javascript
深入理解基于vue-cli的vuex配置
2017/07/24 Javascript
ES6 class的应用实例分析
2019/06/27 Javascript
收藏整理的一些Python常用方法和技巧
2015/05/18 Python
利用Tkinter(python3.6)实现一个简单计算器
2017/12/21 Python
处理Selenium3+python3定位鼠标悬停才显示的元素
2019/07/31 Python
Python 中pandas索引切片读取数据缺失数据处理问题
2019/10/09 Python
python中设置超时跳过,超时退出的方式
2019/12/13 Python
Python计算指定日期是今年的第几天(三种方法)
2020/03/26 Python
浅谈keras中的目标函数和优化函数MSE用法
2020/06/10 Python
Python之字典添加元素的几种方法
2020/09/30 Python
几款Python编译器比较与推荐(小结)
2020/10/15 Python
Claire’s法国:时尚配饰、美容、珠宝、头发
2021/01/16 全球购物
电大自我鉴定范文
2013/10/01 职场文书
2014端午节活动策划方案
2014/01/27 职场文书
《明天,我们毕业》教学反思
2014/04/24 职场文书
暑期社会实践个人总结
2015/03/06 职场文书
win11系统中dhcp服务异常什么意思? Win11 DHCP服务异常修复方法
2022/04/08 数码科技
Python测试框架pytest高阶用法全面详解
2022/06/01 Python