TensorFlow实现模型断点训练,checkpoint模型载入方式


Posted in Python onMay 26, 2020

深度学习中,模型训练一般都需要很长的时间,由于很多原因,导致模型中断训练,下面介绍继续断点训练的方法。

方法一:载入模型时,不必指定迭代次数,一般默认最新

# 保存模型
saver = tf.train.Saver(max_to_keep=1) # 最多保留最新的模型
 
# 开启会话
with tf.Session() as sess:
 # saver.restore(sess, './log/' + "model_savemodel.cpkt-" + str(20000))
 sess.run(tf.global_variables_initializer())
 ckpt = tf.train.get_checkpoint_state('./log/') # 注意此处是checkpoint存在的目录,千万不要写成‘./log'
 if ckpt and ckpt.model_checkpoint_path:
 saver.restore(sess,ckpt.model_checkpoint_path) # 自动恢复model_checkpoint_path保存模型一般是最新
 print("Model restored...")
 else:
 print('No Model')

方法二:载入时,指定想要载入模型的迭代次数

需要到Log文件夹中,查看当前迭代的次数,如下:此时为111000次。

TensorFlow实现模型断点训练,checkpoint模型载入方式

# 保存模型
saver = tf.train.Saver(max_to_keep=1)
# 开启会话
 
with tf.Session() as sess:
 saver.restore(sess, './log/' + "model_savemodel.cpkt-" + str(111000))
 sess.run(tf.global_variables_initializer())

载入模型后,会继续端点处的变量继续训练,那么是否可以减小剩余的需要的迭代次数?

模型断点训练效果展示:

训练到167000次后,载入模型重新训练。设置迭代次数为10000次,(d_step=1000)。原始设置的迭代的次数为1000000,已经训练了167000次。

Model restored...
Iter:0, D_loss:0.5139875411987305, G_loss:2.8023970127105713
Iter:1000, D_loss:0.4400891065597534, G_loss:2.781547784805298
Iter:2000, D_loss:0.5169454216957092, G_loss:2.58009934425354
Iter:3000, D_loss:0.4507023096084595, G_loss:2.584151268005371
Iter:4000, D_loss:0.5746167898178101, G_loss:2.5365757942199707
Iter:5000, D_loss:0.5288565158843994, G_loss:2.426676034927368
Iter:6000, D_loss:0.549595057964325, G_loss:2.820535659790039
Iter:7000, D_loss:0.32620012760162354, G_loss:2.540236473083496
Iter:8000, D_loss:0.4363398551940918, G_loss:2.5880446434020996
Iter:9000, D_loss:0.569464921951294, G_loss:2.5133447647094727
done!

保存的图片仍然从头开始编号,会覆盖掉之前的图片。

TensorFlow实现模型断点训练,checkpoint模型载入方式

以前对应编号的采样图片为:

TensorFlow实现模型断点训练,checkpoint模型载入方式

若有朋友有高见,还请不吝赐教。

补充知识:tensorflow加载训练好的模型及参数(读取checkpoint)

checkpoint 保存路径

model_path下存有包含多个迭代次数的模型

TensorFlow实现模型断点训练,checkpoint模型载入方式

1.获取最新保存的模型

即上图中的model-9400

import tensorflow as tf

graph=tf.get_default_graph()  # 获取当前图
sess=tf.Session()
sess.run(tf.global_variables_initializer())

checkpoint_file=tf.train.latest_checkpoint(model_path)
saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
saver.restore(sess,checkpoint_file)

2.获取某个迭代次数的模型

比如上图中的model-9200

import tensorflow as tf

graph=tf.get_default_graph()  # 获取当前图
sess=tf.Session()
sess.run(tf.global_variables_initializer())

checkpoint_file=os.path.join(model_path,'model-9200')
saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
saver.restore(sess,checkpoint_file)

获取变量值

## 得到当前图中所有变量的名称
tensor_name_list=[tensor.name for tensor in graph.as_graph_def().node] 
# 查看所有变量
print(tensor_name_list) 

# 获取input_x和input_y的变量值
input_x = graph.get_operation_by_name("input_x").outputs[0]
input_y = graph.get_operation_by_name("input_y").outputs[0]

以上这篇TensorFlow实现模型断点训练,checkpoint模型载入方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python批量提交沙箱问题实例
Oct 08 Python
分析用Python脚本关闭文件操作的机制
Jun 28 Python
python 读入多行数据的实例
Apr 19 Python
python如何创建TCP服务端和客户端
Aug 26 Python
Python编程在flask中模拟进行Restful的CRUD操作
Dec 28 Python
Python调用百度根据经纬度查询地址的示例代码
Jul 07 Python
python过滤中英文标点符号的实例代码
Jul 15 Python
python爬虫豆瓣网的模拟登录实现
Aug 21 Python
pytorch实现建立自己的数据集(以mnist为例)
Jan 18 Python
python实现井字棋小游戏
Mar 04 Python
Python代码风格与编程习惯重要吗?
Jun 03 Python
Python实现批量将文件复制到新的目录中再修改名称
Apr 12 Python
python 日志模块 日志等级设置失效的解决方案
May 26 #Python
python3.7+selenium模拟淘宝登录功能的实现
May 26 #Python
TensorFlow固化模型的实现操作
May 26 #Python
Python 如何批量更新已安装的库
May 26 #Python
tensorflow 20:搭网络,导出模型,运行模型的实例
May 26 #Python
Python自定义聚合函数merge与transform区别详解
May 26 #Python
Python Tornado实现WEB服务器Socket服务器共存并实现交互的方法
May 26 #Python
You might like
php XPath对XML文件查找及修改实现代码
2011/07/27 PHP
免费的ip数据库淘宝IP地址库简介和PHP调用实例
2014/04/08 PHP
PHP面试常用算法(推荐)
2016/07/22 PHP
PHP中ajax无刷新上传图片与图片下载功能
2017/02/21 PHP
js wmp操作代码小结(音乐连播功能)
2008/11/08 Javascript
jQuery写的日历(包括日历的样式及功能)
2013/04/23 Javascript
jQuery表格排序组件-tablesorter使用示例
2014/05/26 Javascript
JavaScript获取网页表单action属性的方法
2015/04/02 Javascript
JavaScript+html5 canvas制作色彩斑斓的正方形效果
2016/01/27 Javascript
Bootstrap表单组件教程详解
2016/04/26 Javascript
JS使用cookie设置样式的方法
2016/06/30 Javascript
完美解决jQuery fancybox ie 无法显示关闭按钮的问题
2016/11/29 Javascript
教你如何编写Vue.js的单元测试的方法
2018/10/17 Javascript
vue-next/runtime-core 源码阅读指南详解
2019/10/25 Javascript
vue各种事件监听实例(小结)
2020/06/24 Javascript
微信小程序实现签到弹窗动画
2020/09/21 Javascript
分享15个最受欢迎的Python开源框架
2014/07/13 Python
Python利用带权重随机数解决抽奖和游戏爆装备问题
2016/06/16 Python
zookeeper python接口实例详解
2018/01/18 Python
Python对象属性自动更新操作示例
2018/06/15 Python
详解Python中的format格式化函数的使用方法
2019/11/20 Python
python实现画出e指数函数的图像
2019/11/21 Python
python 中值滤波,椒盐去噪,图片增强实例
2019/12/18 Python
python高级特性简介
2020/08/13 Python
通往英国高街的商店橱窗:Down Your High Street
2020/07/19 全球购物
linux下进程间通信的方式
2013/01/23 面试题
应聘医药代表职位求职信
2013/10/21 职场文书
销售经理工作职责
2014/02/03 职场文书
2014年招生工作总结
2014/11/26 职场文书
金秋助学感谢信
2015/01/21 职场文书
班主任高考寄语
2015/02/26 职场文书
质量保证书格式
2015/02/27 职场文书
2015年财务部年度工作总结
2015/05/19 职场文书
Python常用配置文件ini、json、yaml读写总结
2021/07/09 Python
手把手带你彻底卸载MySQL数据库
2022/06/14 MySQL
浅谈为什么我的 z-index 又不生效了
2022/07/15 HTML / CSS