TensorFlow——Checkpoint为模型添加检查点的实例


Posted in Python onJanuary 21, 2020

1.检查点

保存模型并不限于在训练模型后,在训练模型之中也需要保存,因为TensorFlow训练模型时难免会出现中断的情况,我们自然希望能够将训练得到的参数保存下来,否则下次又要重新训练。

这种在训练中保存模型,习惯上称之为保存检查点。

2.添加保存点

通过添加检查点,可以生成载入检查点文件,并能够指定生成检查文件的个数,例如使用saver的另一个参数——max_to_keep=1,表明最多只保存一个检查点文件,在保存时使用如下的代码传入迭代次数。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os

train_x = np.linspace(-5, 3, 50)
train_y = train_x * 5 + 10 + np.random.random(50) * 10 - 5

plt.plot(train_x, train_y, 'r.')
plt.grid(True)
plt.show()

tf.reset_default_graph()

X = tf.placeholder(dtype=tf.float32)
Y = tf.placeholder(dtype=tf.float32)

w = tf.Variable(tf.random.truncated_normal([1]), name='Weight')
b = tf.Variable(tf.random.truncated_normal([1]), name='bias')

z = tf.multiply(X, w) + b

cost = tf.reduce_mean(tf.square(Y - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

init = tf.global_variables_initializer()

training_epochs = 20
display_step = 2


saver = tf.train.Saver(max_to_keep=15)
savedir = "model/"


if __name__ == '__main__':
 with tf.Session() as sess:
  sess.run(init)
  loss_list = []
  for epoch in range(training_epochs):
   for (x, y) in zip(train_x, train_y):
    sess.run(optimizer, feed_dict={X: x, Y: y})

   if epoch % display_step == 0:
    loss = sess.run(cost, feed_dict={X: x, Y: y})
    loss_list.append(loss)
    print('Iter: ', epoch, ' Loss: ', loss)

   w_, b_ = sess.run([w, b], feed_dict={X: x, Y: y})

   saver.save(sess, savedir + "linear.cpkt", global_step=epoch)

  print(" Finished ")
  print("W: ", w_, " b: ", b_, " loss: ", loss)
  plt.plot(train_x, train_x * w_ + b_, 'g-', train_x, train_y, 'r.')
  plt.grid(True)
  plt.show()

 load_epoch = 10

 with tf.Session() as sess2:
  sess2.run(tf.global_variables_initializer())
  saver.restore(sess2, savedir + "linear.cpkt-" + str(load_epoch))
  print(sess2.run([w, b], feed_dict={X: train_x, Y: train_y}))

在上述的代码中,我们使用saver.save(sess, savedir + "linear.cpkt", global_step=epoch)将训练的参数传入检查点进行保存,saver = tf.train.Saver(max_to_keep=1)表示只保存一个文件,这样在训练过程中得到的新的模型就会覆盖以前的模型。

cpkt = tf.train.get_checkpoint_state(savedir)
if cpkt and cpkt.model_checkpoint_path:
saver.restore(sess2, cpkt.model_checkpoint_path)

kpt = tf.train.latest_checkpoint(savedir)
saver.restore(sess2, kpt)

上述的两种方法也可以对checkpoint文件进行加载,tf.train.latest_checkpoint(savedir)为加载最后的检查点文件。这种方式,我们可以通过保存指定训练次数的检查点,比如保存5的倍数次保存一下检查点。

3.简便保存检查点

我们还可以用更加简单的方法进行检查点的保存,tf.train.MonitoredTrainingSession()函数,该函数可以直接实现保存载入检查点模型的文件,与前面的方法不同的是,它是按照训练时间来保存检查点的,可以通过指定save_checkpoint_secs参数的具体秒数,设置多久保存一次检查点。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os

train_x = np.linspace(-5, 3, 50)
train_y = train_x * 5 + 10 + np.random.random(50) * 10 - 5

# plt.plot(train_x, train_y, 'r.')
# plt.grid(True)
# plt.show()

tf.reset_default_graph()

X = tf.placeholder(dtype=tf.float32)
Y = tf.placeholder(dtype=tf.float32)

w = tf.Variable(tf.random.truncated_normal([1]), name='Weight')
b = tf.Variable(tf.random.truncated_normal([1]), name='bias')

z = tf.multiply(X, w) + b

cost = tf.reduce_mean(tf.square(Y - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

init = tf.global_variables_initializer()

training_epochs = 30
display_step = 2


global_step = tf.train.get_or_create_global_step()

step = tf.assign_add(global_step, 1)

saver = tf.train.Saver()

savedir = "check-point/"

if __name__ == '__main__':
 with tf.train.MonitoredTrainingSession(checkpoint_dir=savedir + 'linear.cpkt', save_checkpoint_secs=5) as sess:
  sess.run(init)
  loss_list = []
  for epoch in range(training_epochs):
   sess.run(global_step)
   for (x, y) in zip(train_x, train_y):
    sess.run(optimizer, feed_dict={X: x, Y: y})

   if epoch % display_step == 0:
    loss = sess.run(cost, feed_dict={X: x, Y: y})
    loss_list.append(loss)
    print('Iter: ', epoch, ' Loss: ', loss)

   w_, b_ = sess.run([w, b], feed_dict={X: x, Y: y})
   sess.run(step)

  print(" Finished ")
  print("W: ", w_, " b: ", b_, " loss: ", loss)
  plt.plot(train_x, train_x * w_ + b_, 'g-', train_x, train_y, 'r.')
  plt.grid(True)
  plt.show()

 load_epoch = 10

 with tf.Session() as sess2:
  sess2.run(tf.global_variables_initializer())

  # saver.restore(sess2, savedir + 'linear.cpkt-' + str(load_epoch))

  # cpkt = tf.train.get_checkpoint_state(savedir)
  # if cpkt and cpkt.model_checkpoint_path:
  #  saver.restore(sess2, cpkt.model_checkpoint_path)
  #
  kpt = tf.train.latest_checkpoint(savedir + 'linear.cpkt')

  saver.restore(sess2, kpt)

  print(sess2.run([w, b], feed_dict={X: train_x, Y: train_y}))

上述的代码中,我们设置了没训练了5秒中之后,就保存一次检查点,它默认的保存时间间隔是10分钟,这种按照时间的保存模式更适合使用大型数据集训练复杂模型的情况,注意在使用上述的方法时,要定义global_step变量,在训练完一个批次或者一个样本之后,要将其进行加1的操作,否则将会报错。

TensorFlow——Checkpoint为模型添加检查点的实例

以上这篇TensorFlow——Checkpoint为模型添加检查点的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中计算三角函数之cos()方法的使用简介
May 15 Python
Python装饰器实现几类验证功能做法实例
May 18 Python
通过python顺序修改文件名字的方法
Jul 11 Python
详解Python 装饰器执行顺序迷思
Aug 08 Python
python对于requests的封装方法详解
Jan 03 Python
使用Python实现企业微信的自动打卡功能
Apr 30 Python
numpy求平均值的维度设定的例子
Aug 24 Python
django-crontab实现服务端的定时任务的示例代码
Feb 17 Python
python ETL工具 pyetl
Jun 07 Python
python中pandas库中DataFrame对行和列的操作使用方法示例
Jun 14 Python
在tensorflow以及keras安装目录查询操作(windows下)
Jun 19 Python
如何使用Python提取Chrome浏览器保存的密码
Jun 09 Python
tensorflow estimator 使用hook实现finetune方式
Jan 21 #Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
TensorFlow dataset.shuffle、batch、repeat的使用详解
Jan 21 #Python
使用 tf.nn.dynamic_rnn 展开时间维度方式
Jan 21 #Python
python爬取本站电子书信息并入库的实现代码
Jan 20 #Python
浅谈Tensorflow 动态双向RNN的输出问题
Jan 20 #Python
You might like
星际争霸 Starcraft 秘技补丁
2020/03/14 星际争霸
PHP的栏目导航程序
2006/10/09 PHP
PHP页面跳转实现延时跳转的方法
2016/12/10 PHP
Yii 使用intervention/image拓展实现图像处理功能
2019/06/22 PHP
Ext JS Grid在IE6 下宽度的问题解决方法
2009/02/15 Javascript
js实现简单登录功能的实例代码
2013/11/09 Javascript
获取select元素被选中的文本内容的js代码
2014/01/29 Javascript
javascript进行数组追加方法小结
2014/06/16 Javascript
node.js下when.js 的异步编程实践
2014/12/03 Javascript
jQuery中removeAttr()方法用法实例
2015/01/05 Javascript
理解JavaScript的变量的入门教程
2015/07/07 Javascript
jQuery Ajax 加载数据时异步显示加载动画
2016/08/01 Javascript
JavaScript中boolean类型之三种情景实例代码
2016/11/21 Javascript
JavaScript函数参数的传递方式详解
2017/03/06 Javascript
JavaScript正则表达式函数总结(常用)
2018/02/22 Javascript
Postman模拟发送带token的请求方法
2018/03/31 Javascript
axios的拦截请求与响应方法
2018/08/11 Javascript
微信小程序实现watch监听
2020/06/04 Javascript
[05:08]2014DOTA2国际邀请赛 Hao专访复仇的胜利很爽
2014/07/15 DOTA
玩转python爬虫之cookie使用方法
2016/02/17 Python
python批量实现Word文件转换为PDF文件
2018/03/15 Python
Python实现简单求解给定整数的质因数算法示例
2018/03/25 Python
浅谈python 导入模块和解决文件句柄找不到问题
2018/12/15 Python
python itchat实现调用微信接口的第三方模块方法
2019/06/11 Python
Pyqt5 实现跳转界面并关闭当前界面的方法
2019/06/19 Python
python 返回一个列表中第二大的数方法
2019/07/09 Python
keras 使用Lambda 快速新建层 添加多个参数操作
2020/06/10 Python
快速实现一个简单的canvas迷宫游戏的示例
2018/07/04 HTML / CSS
菲律宾领先的在线时尚商店:Zalora菲律宾
2018/02/08 全球购物
安全责任书模板
2014/07/22 职场文书
有子女的离婚协议书怎么写(范本)
2014/09/29 职场文书
如何利用JavaScript实现二叉搜索树
2021/04/02 Javascript
如何搭建 MySQL 高可用高性能集群
2021/06/21 MySQL
解决mysql:ERROR 1045 (28000): Access denied for user 'root'@'localhost' (using password: NO/YES)
2021/06/26 MySQL
Log4j.properties配置及其使用
2021/08/02 Java/Android
python数字图像处理之图像自动阈值分割示例
2022/06/28 Python