TensorFlow——Checkpoint为模型添加检查点的实例


Posted in Python onJanuary 21, 2020

1.检查点

保存模型并不限于在训练模型后,在训练模型之中也需要保存,因为TensorFlow训练模型时难免会出现中断的情况,我们自然希望能够将训练得到的参数保存下来,否则下次又要重新训练。

这种在训练中保存模型,习惯上称之为保存检查点。

2.添加保存点

通过添加检查点,可以生成载入检查点文件,并能够指定生成检查文件的个数,例如使用saver的另一个参数——max_to_keep=1,表明最多只保存一个检查点文件,在保存时使用如下的代码传入迭代次数。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os

train_x = np.linspace(-5, 3, 50)
train_y = train_x * 5 + 10 + np.random.random(50) * 10 - 5

plt.plot(train_x, train_y, 'r.')
plt.grid(True)
plt.show()

tf.reset_default_graph()

X = tf.placeholder(dtype=tf.float32)
Y = tf.placeholder(dtype=tf.float32)

w = tf.Variable(tf.random.truncated_normal([1]), name='Weight')
b = tf.Variable(tf.random.truncated_normal([1]), name='bias')

z = tf.multiply(X, w) + b

cost = tf.reduce_mean(tf.square(Y - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

init = tf.global_variables_initializer()

training_epochs = 20
display_step = 2


saver = tf.train.Saver(max_to_keep=15)
savedir = "model/"


if __name__ == '__main__':
 with tf.Session() as sess:
  sess.run(init)
  loss_list = []
  for epoch in range(training_epochs):
   for (x, y) in zip(train_x, train_y):
    sess.run(optimizer, feed_dict={X: x, Y: y})

   if epoch % display_step == 0:
    loss = sess.run(cost, feed_dict={X: x, Y: y})
    loss_list.append(loss)
    print('Iter: ', epoch, ' Loss: ', loss)

   w_, b_ = sess.run([w, b], feed_dict={X: x, Y: y})

   saver.save(sess, savedir + "linear.cpkt", global_step=epoch)

  print(" Finished ")
  print("W: ", w_, " b: ", b_, " loss: ", loss)
  plt.plot(train_x, train_x * w_ + b_, 'g-', train_x, train_y, 'r.')
  plt.grid(True)
  plt.show()

 load_epoch = 10

 with tf.Session() as sess2:
  sess2.run(tf.global_variables_initializer())
  saver.restore(sess2, savedir + "linear.cpkt-" + str(load_epoch))
  print(sess2.run([w, b], feed_dict={X: train_x, Y: train_y}))

在上述的代码中,我们使用saver.save(sess, savedir + "linear.cpkt", global_step=epoch)将训练的参数传入检查点进行保存,saver = tf.train.Saver(max_to_keep=1)表示只保存一个文件,这样在训练过程中得到的新的模型就会覆盖以前的模型。

cpkt = tf.train.get_checkpoint_state(savedir)
if cpkt and cpkt.model_checkpoint_path:
saver.restore(sess2, cpkt.model_checkpoint_path)

kpt = tf.train.latest_checkpoint(savedir)
saver.restore(sess2, kpt)

上述的两种方法也可以对checkpoint文件进行加载,tf.train.latest_checkpoint(savedir)为加载最后的检查点文件。这种方式,我们可以通过保存指定训练次数的检查点,比如保存5的倍数次保存一下检查点。

3.简便保存检查点

我们还可以用更加简单的方法进行检查点的保存,tf.train.MonitoredTrainingSession()函数,该函数可以直接实现保存载入检查点模型的文件,与前面的方法不同的是,它是按照训练时间来保存检查点的,可以通过指定save_checkpoint_secs参数的具体秒数,设置多久保存一次检查点。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os

train_x = np.linspace(-5, 3, 50)
train_y = train_x * 5 + 10 + np.random.random(50) * 10 - 5

# plt.plot(train_x, train_y, 'r.')
# plt.grid(True)
# plt.show()

tf.reset_default_graph()

X = tf.placeholder(dtype=tf.float32)
Y = tf.placeholder(dtype=tf.float32)

w = tf.Variable(tf.random.truncated_normal([1]), name='Weight')
b = tf.Variable(tf.random.truncated_normal([1]), name='bias')

z = tf.multiply(X, w) + b

cost = tf.reduce_mean(tf.square(Y - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

init = tf.global_variables_initializer()

training_epochs = 30
display_step = 2


global_step = tf.train.get_or_create_global_step()

step = tf.assign_add(global_step, 1)

saver = tf.train.Saver()

savedir = "check-point/"

if __name__ == '__main__':
 with tf.train.MonitoredTrainingSession(checkpoint_dir=savedir + 'linear.cpkt', save_checkpoint_secs=5) as sess:
  sess.run(init)
  loss_list = []
  for epoch in range(training_epochs):
   sess.run(global_step)
   for (x, y) in zip(train_x, train_y):
    sess.run(optimizer, feed_dict={X: x, Y: y})

   if epoch % display_step == 0:
    loss = sess.run(cost, feed_dict={X: x, Y: y})
    loss_list.append(loss)
    print('Iter: ', epoch, ' Loss: ', loss)

   w_, b_ = sess.run([w, b], feed_dict={X: x, Y: y})
   sess.run(step)

  print(" Finished ")
  print("W: ", w_, " b: ", b_, " loss: ", loss)
  plt.plot(train_x, train_x * w_ + b_, 'g-', train_x, train_y, 'r.')
  plt.grid(True)
  plt.show()

 load_epoch = 10

 with tf.Session() as sess2:
  sess2.run(tf.global_variables_initializer())

  # saver.restore(sess2, savedir + 'linear.cpkt-' + str(load_epoch))

  # cpkt = tf.train.get_checkpoint_state(savedir)
  # if cpkt and cpkt.model_checkpoint_path:
  #  saver.restore(sess2, cpkt.model_checkpoint_path)
  #
  kpt = tf.train.latest_checkpoint(savedir + 'linear.cpkt')

  saver.restore(sess2, kpt)

  print(sess2.run([w, b], feed_dict={X: train_x, Y: train_y}))

上述的代码中,我们设置了没训练了5秒中之后,就保存一次检查点,它默认的保存时间间隔是10分钟,这种按照时间的保存模式更适合使用大型数据集训练复杂模型的情况,注意在使用上述的方法时,要定义global_step变量,在训练完一个批次或者一个样本之后,要将其进行加1的操作,否则将会报错。

TensorFlow——Checkpoint为模型添加检查点的实例

以上这篇TensorFlow——Checkpoint为模型添加检查点的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python类参数self使用示例
Feb 17 Python
解密Python中的描述符(descriptor)
Jun 03 Python
Python使用functools模块中的partial函数生成偏函数
Jul 02 Python
Python读取和处理文件后缀为.sqlite的数据文件(实例讲解)
Jun 27 Python
Python使用numpy模块创建数组操作示例
Jun 20 Python
Python Flask框架扩展操作示例
May 03 Python
python gdal安装与简单使用
Aug 01 Python
Python读取JSON数据操作实例解析
May 18 Python
Python字典实现伪切片功能
Oct 28 Python
Pycharm安装第三方库失败解决方案
Nov 17 Python
Python3利用scapy局域网实现自动多线程arp扫描功能
Jan 21 Python
python3.7.2 tkinter entry框限定输入数字的操作
May 22 Python
tensorflow estimator 使用hook实现finetune方式
Jan 21 #Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
TensorFlow dataset.shuffle、batch、repeat的使用详解
Jan 21 #Python
使用 tf.nn.dynamic_rnn 展开时间维度方式
Jan 21 #Python
python爬取本站电子书信息并入库的实现代码
Jan 20 #Python
浅谈Tensorflow 动态双向RNN的输出问题
Jan 20 #Python
You might like
PHP获取网卡地址的代码
2008/04/09 PHP
php读取数据库信息的几种方法
2008/05/24 PHP
php循环语句 for()与foreach()用法区别介绍
2012/09/05 PHP
PHP实现指定字段的多维数组排序函数分享
2015/03/09 PHP
php源码分析之DZX1.5加密解密函数authcode用法
2015/06/17 PHP
php面向对象与面向过程两种方法给图片添加文字水印
2015/08/26 PHP
thinkphp中多表查询中防止数据重复的sql语句(必看)
2016/09/22 PHP
JavaScript 获得选中文本内容的方法
2009/02/15 Javascript
Array的push与unshift方法性能比较分析
2011/03/05 Javascript
解决jquery插件冲突的问题
2014/01/23 Javascript
Vue 父子组件、组件间通信
2017/03/08 Javascript
原生JS实现不断变化的标签
2017/05/22 Javascript
详解小程序rich-text对富文本支持方案
2018/11/28 Javascript
vue实现axios图片上传功能
2019/08/20 Javascript
vue实现权限控制路由(vue-router 动态添加路由)
2019/11/04 Javascript
vue prop属性传值与传引用示例
2019/11/13 Javascript
[52:05]EG vs OG 2019国际邀请赛小组赛 BO2 第二场 8.16
2019/08/18 DOTA
Python数据结构之单链表详解
2017/09/12 Python
Python实现对文件进行单词划分并去重排序操作示例
2018/07/10 Python
python requests post多层字典的方法
2018/12/27 Python
对Django项目中的ORM映射与模糊查询的使用详解
2019/07/18 Python
Python实用库 PrettyTable 学习笔记
2019/08/06 Python
Django实现将views.py中的数据传递到前端html页面,并展示
2020/03/16 Python
Python虚拟环境venv用法详解
2020/05/25 Python
python爬虫看看虎牙女主播中谁最“顶”步骤详解
2020/12/01 Python
美国著名的团购网站:Woot
2016/08/02 全球购物
Microsoft新加坡官方网站:购买微软最新软件和技术产品
2016/10/28 全球购物
建筑施工员岗位职责
2013/11/26 职场文书
大学生个人自荐信样本
2014/03/02 职场文书
银行竞聘演讲稿范文
2014/04/23 职场文书
难忘的一天教学反思
2014/04/30 职场文书
2015年清明节网上祭英烈留言寄语
2015/03/04 职场文书
公司员工辞职信范文
2015/05/12 职场文书
2015年档案管理员工作总结
2015/05/13 职场文书
海弦WR-800F
2022/04/05 无线电
Golang bufio详细讲解
2022/04/21 Golang