TensorFlow实现模型断点训练,checkpoint模型载入方式


Posted in Python onMay 26, 2020

深度学习中,模型训练一般都需要很长的时间,由于很多原因,导致模型中断训练,下面介绍继续断点训练的方法。

方法一:载入模型时,不必指定迭代次数,一般默认最新

# 保存模型
saver = tf.train.Saver(max_to_keep=1) # 最多保留最新的模型
 
# 开启会话
with tf.Session() as sess:
 # saver.restore(sess, './log/' + "model_savemodel.cpkt-" + str(20000))
 sess.run(tf.global_variables_initializer())
 ckpt = tf.train.get_checkpoint_state('./log/') # 注意此处是checkpoint存在的目录,千万不要写成‘./log'
 if ckpt and ckpt.model_checkpoint_path:
 saver.restore(sess,ckpt.model_checkpoint_path) # 自动恢复model_checkpoint_path保存模型一般是最新
 print("Model restored...")
 else:
 print('No Model')

方法二:载入时,指定想要载入模型的迭代次数

需要到Log文件夹中,查看当前迭代的次数,如下:此时为111000次。

TensorFlow实现模型断点训练,checkpoint模型载入方式

# 保存模型
saver = tf.train.Saver(max_to_keep=1)
# 开启会话
 
with tf.Session() as sess:
 saver.restore(sess, './log/' + "model_savemodel.cpkt-" + str(111000))
 sess.run(tf.global_variables_initializer())

载入模型后,会继续端点处的变量继续训练,那么是否可以减小剩余的需要的迭代次数?

模型断点训练效果展示:

训练到167000次后,载入模型重新训练。设置迭代次数为10000次,(d_step=1000)。原始设置的迭代的次数为1000000,已经训练了167000次。

Model restored...
Iter:0, D_loss:0.5139875411987305, G_loss:2.8023970127105713
Iter:1000, D_loss:0.4400891065597534, G_loss:2.781547784805298
Iter:2000, D_loss:0.5169454216957092, G_loss:2.58009934425354
Iter:3000, D_loss:0.4507023096084595, G_loss:2.584151268005371
Iter:4000, D_loss:0.5746167898178101, G_loss:2.5365757942199707
Iter:5000, D_loss:0.5288565158843994, G_loss:2.426676034927368
Iter:6000, D_loss:0.549595057964325, G_loss:2.820535659790039
Iter:7000, D_loss:0.32620012760162354, G_loss:2.540236473083496
Iter:8000, D_loss:0.4363398551940918, G_loss:2.5880446434020996
Iter:9000, D_loss:0.569464921951294, G_loss:2.5133447647094727
done!

保存的图片仍然从头开始编号,会覆盖掉之前的图片。

TensorFlow实现模型断点训练,checkpoint模型载入方式

以前对应编号的采样图片为:

TensorFlow实现模型断点训练,checkpoint模型载入方式

若有朋友有高见,还请不吝赐教。

补充知识:tensorflow加载训练好的模型及参数(读取checkpoint)

checkpoint 保存路径

model_path下存有包含多个迭代次数的模型

TensorFlow实现模型断点训练,checkpoint模型载入方式

1.获取最新保存的模型

即上图中的model-9400

import tensorflow as tf

graph=tf.get_default_graph()  # 获取当前图
sess=tf.Session()
sess.run(tf.global_variables_initializer())

checkpoint_file=tf.train.latest_checkpoint(model_path)
saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
saver.restore(sess,checkpoint_file)

2.获取某个迭代次数的模型

比如上图中的model-9200

import tensorflow as tf

graph=tf.get_default_graph()  # 获取当前图
sess=tf.Session()
sess.run(tf.global_variables_initializer())

checkpoint_file=os.path.join(model_path,'model-9200')
saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
saver.restore(sess,checkpoint_file)

获取变量值

## 得到当前图中所有变量的名称
tensor_name_list=[tensor.name for tensor in graph.as_graph_def().node] 
# 查看所有变量
print(tensor_name_list) 

# 获取input_x和input_y的变量值
input_x = graph.get_operation_by_name("input_x").outputs[0]
input_y = graph.get_operation_by_name("input_y").outputs[0]

以上这篇TensorFlow实现模型断点训练,checkpoint模型载入方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现的简单猜数字游戏
Apr 04 Python
Python中实现三目运算的方法
Jun 21 Python
python Django框架实现自定义表单提交
Mar 25 Python
python编写弹球游戏的实现代码
Mar 12 Python
python基础教程项目四之新闻聚合
Apr 02 Python
Django分页查询并返回jsons数据(中文乱码解决方法)
Aug 02 Python
Python时间和字符串转换操作实例分析
Mar 16 Python
Python自动生成代码 使用tkinter图形化操作并生成代码框架
Sep 18 Python
python selenium循环登陆网站的实现
Nov 04 Python
TensorFlow实现checkpoint文件转换为pb文件
Feb 10 Python
Pycharm 使用 Pipenv 新建的虚拟环境(图文详解)
Apr 16 Python
Python selenium模拟网页点击爬虫交管12123违章数据
May 26 Python
python 日志模块 日志等级设置失效的解决方案
May 26 #Python
python3.7+selenium模拟淘宝登录功能的实现
May 26 #Python
TensorFlow固化模型的实现操作
May 26 #Python
Python 如何批量更新已安装的库
May 26 #Python
tensorflow 20:搭网络,导出模型,运行模型的实例
May 26 #Python
Python自定义聚合函数merge与transform区别详解
May 26 #Python
Python Tornado实现WEB服务器Socket服务器共存并实现交互的方法
May 26 #Python
You might like
php使用ffmpeg向视频中添加文字字幕的实现方法
2016/05/23 PHP
thinkPHP实现基于ajax的评论回复功能
2018/06/22 PHP
从jQuery.camelCase()学习string.replace() 函数学习
2011/09/13 Javascript
javascript之典型高阶函数应用介绍二
2013/01/10 Javascript
JavaScript子类用Object.getPrototypeOf去调用父类方法解析
2013/12/05 Javascript
Json实现异步请求提交评论无需跳转其他页面
2014/10/11 Javascript
JS判断是否360安全浏览器极速内核的方法
2015/01/29 Javascript
javascript实现类似java中getClass()得到对象类名的方法
2015/07/27 Javascript
简单介绍jsonp 使用小结
2016/01/27 Javascript
AngularJS 2.0新特性有哪些
2016/02/18 Javascript
浅谈Jquery中Ajax异步请求中的async参数的作用
2016/06/06 Javascript
jqPlot jQuery绘图插件的使用
2016/06/18 Javascript
Vue.js快速入门教程
2016/09/07 Javascript
BootStrap Table 后台数据绑定、特殊列处理、排序功能
2017/05/27 Javascript
vue2.0 自定义 饼状图 (Echarts)组件的方法
2018/03/02 Javascript
webpack4之如何编写loader的方法步骤
2019/06/06 Javascript
VUE实现移动端列表筛选功能
2019/08/23 Javascript
layui固定下拉框的显示条数(有滚动条)的方法
2019/09/10 Javascript
javascript实现时间日期的格式化的方法汇总
2020/08/06 Javascript
prettier自动格式化去换行的实现代码
2020/08/25 Javascript
[05:31]DOTA2上海特级锦标赛主赛事第三日RECAP
2016/03/05 DOTA
[01:19:35]DOTA2上海特级锦标赛主赛事日 - 3 败者组第三轮#2Fnatic VS OG第二局
2016/03/05 DOTA
python在多玩图片上下载妹子图的实现代码
2013/08/13 Python
利用Python中的pandas库对cdn日志进行分析详解
2017/03/07 Python
python Flask实现restful api service
2017/12/04 Python
浅谈Python中range和xrange的区别
2017/12/20 Python
python socket网络编程之粘包问题详解
2018/04/28 Python
Linux(Redhat)安装python3.6虚拟环境(推荐)
2018/05/05 Python
在python Numpy中求向量和矩阵的范数实例
2019/08/26 Python
Python获取统计自己的qq群成员信息的方法
2019/11/15 Python
基于python爬取链家二手房信息代码示例
2020/10/21 Python
2014年移动公司工作总结
2014/12/08 职场文书
销售经理岗位职责
2015/01/31 职场文书
2019年员工旷工保证书!
2019/06/28 职场文书
Nginx配置https原理及实现过程详解
2021/03/31 Servers
浅析Python中的随机采样和概率分布
2021/12/06 Python