tensorflow模型继续训练 fineturn实例


Posted in Python onJanuary 21, 2020

解决tensoflow如何在已训练模型上继续训练fineturn的问题。

训练代码

任务描述: x = 3.0, y = 100.0, 运算公式 x×W+b = y,求 W和b的最优解。

# -*- coding: utf-8 -*-)
import tensorflow as tf
 
 
# 声明占位变量x、y
x = tf.placeholder("float", shape=[None, 1])
y = tf.placeholder("float", [None, 1])
 
# 声明变量
W = tf.Variable(tf.zeros([1, 1]),name='w')
b = tf.Variable(tf.zeros([1]),name='b')
 
# 操作
result = tf.matmul(x, W) + b
 
# 损失函数
lost = tf.reduce_sum(tf.pow((result - y), 2))
 
# 优化
train_step = tf.train.GradientDescentOptimizer(0.0007).minimize(lost)
 
with tf.Session() as sess:
  # 初始化变量
  sess.run(tf.global_variables_initializer())
  saver = tf.train.Saver(max_to_keep=3)
 
  # 这里x、y给固定的值
  x_s = [[3.0]]
  y_s = [[100.0]]
 
  step = 0
  while (True):
    step += 1
    feed = {x: x_s, y: y_s}
    # 通过sess.run执行优化
    sess.run(train_step, feed_dict=feed)
 
    if step % 1000 == 0:
      print 'step: {0}, loss: {1}'.format(step, sess.run(lost, feed_dict=feed))
      if sess.run(lost, feed_dict=feed) < 1e-10 or step > 4e3:
        print ''
        # print 'final loss is: {}'.format(sess.run(lost, feed_dict=feed))
        print 'final result of {0} = {1}(目标值是100.0)'.format('x×W+b', 3.0 * sess.run(W) + sess.run(b))
        print ''
        print("模型保存的W值 : %f" % sess.run(W))
        print("模型保存的b : %f" % sess.run(b))
        break
  saver.save(sess, "./save_model/re-train", global_step=step) # 保存模型

训练完成之后生成模型文件:

tensorflow模型继续训练 fineturn实例

训练输出:

step: 1000, loss: 4.89526428282e-08
step: 2000, loss: 4.89526428282e-08
step: 3000, loss: 4.89526428282e-08
step: 4000, loss: 4.89526428282e-08
step: 5000, loss: 4.89526428282e-08
 
 
final result of x×W+b = [[99.99978]](目标值是100.0)
 
模型保存的W值 : 29.999931
模型保存的b : 9.999982

保存在模型中的W值是 29.999931,b是 9.999982。

以下代码从保存的模型中恢复出训练状态,继续训练

任务描述: x = 3.0, y = 200.0, 运算公式 x×W+b = y,从上次训练的模型中恢复出训练参数,继续训练,求 W和b的最优解。

# -*- coding: utf-8 -*-)
import tensorflow as tf
 
 
# 声明占位变量x、y
x = tf.placeholder("float", shape=[None, 1])
y = tf.placeholder("float", [None, 1])
 
with tf.Session() as sess:
 
  # 初始化变量
  sess.run(tf.global_variables_initializer())
 
  # saver = tf.train.Saver(max_to_keep=3)
  saver = tf.train.import_meta_graph(r'./save_model/re-train-5000.meta') # 加载模型图结构
  saver.restore(sess, tf.train.latest_checkpoint(r'./save_model')) # 恢复数据
 
  # 从保存模型中恢复变量
  graph = tf.get_default_graph()
  W = graph.get_tensor_by_name("w:0")
  b = graph.get_tensor_by_name("b:0")
 
  print("从保存的模型中恢复出来的W值 : %f" % sess.run("w:0"))
  print("从保存的模型中恢复出来的b值 : %f" % sess.run("b:0"))
 
  # 操作
  result = tf.matmul(x, W) + b
  # 损失函数
  lost = tf.reduce_sum(tf.pow((result - y), 2))
  # 优化
  train_step = tf.train.GradientDescentOptimizer(0.0007).minimize(lost)
 
  # 这里x、y给固定的值
  x_s = [[3.0]]
  y_s = [[200.0]]
 
  step = 0
  while (True):
    step += 1
    feed = {x: x_s, y: y_s}
    # 通过sess.run执行优化
    sess.run(train_step, feed_dict=feed)
    if step % 1000 == 0:
      print 'step: {0}, loss: {1}'.format(step, sess.run(lost, feed_dict=feed))
      if sess.run(lost, feed_dict=feed) < 1e-10 or step > 4e3:
        print ''
        # print 'final loss is: {}'.format(sess.run(lost, feed_dict=feed))
        print 'final result of {0} = {1}(目标值是200.0)'.format('x×W+b', 3.0 * sess.run(W) + sess.run(b))
        print("模型保存的W值 : %f" % sess.run(W))
        print("模型保存的b : %f" % sess.run(b))
        break
  saver.save(sess, "./save_mode/re-train", global_step=step) # 保存模型

训练输出:

从保存的模型中恢复出来的W值 : 29.999931
从保存的模型中恢复出来的b值 : 9.999982
step: 1000, loss: 1.95810571313e-07
step: 2000, loss: 1.95810571313e-07
step: 3000, loss: 1.95810571313e-07
step: 4000, loss: 1.95810571313e-07
step: 5000, loss: 1.95810571313e-07
 
 
final result of x×W+b = [[199.99956]](目标值是200.0)
模型保存的W值 : 59.999866
模型保存的b : 19.999958

从保存的模型中恢复出来的W值是 29.999931,b是 9.999982,跟模型保存的值一致,说明加载成功。

总结

从头开始训练一个模型,需要通过 tf.train.Saver创建一个保存器,完成之后使用save方法保存模型到本地:

saver = tf.train.Saver(max_to_keep=3)
……
saver.save(sess, "./save_model/re-train", global_step=step) # 保存模型

在训练好的模型上继续训练,fineturn一个模型,可以使用tf.train.import_meta_graph方法加载图结构,使用restore方法恢复训练数据,最后使用同样的save方法保存到本地:

saver = tf.train.import_meta_graph(r'./save_model/re-train-10050.meta') # 加载模型图结构
saver.restore(sess, tf.train.latest_checkpoint(r'./save_model')) # 恢复数据
saver.save(sess, "./save_mode/re-train", global_step=step) # 保存模型

注:特殊情况下(如本例)需要从恢复的模型中加载出数据:

# 从保存模型中恢复变量
graph = tf.get_default_graph()
W = graph.get_tensor_by_name("w:0")
b = graph.get_tensor_by_name("b:0")

以上这篇tensorflow模型继续训练 fineturn实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的Django基本命令实例详解
Jul 15 Python
python设置环境变量的原因和方法
Jun 24 Python
python中将两组数据放在一起按照某一固定顺序shuffle的实例
Jul 15 Python
python3 pillow模块实现简单验证码
Oct 31 Python
深入了解如何基于Python读写Kafka
Dec 31 Python
python字典和json.dumps()的遇到的坑分析
Mar 11 Python
PyQt使用QPropertyAnimation开发简单动画
Apr 02 Python
Pandas读取csv时如何设置列名
Jun 02 Python
keras中的loss、optimizer、metrics用法
Jun 15 Python
Keras-多输入多输出实例(多任务)
Jun 22 Python
Python用SSH连接到网络设备
Feb 18 Python
Python借助with语句实现代码段只执行有限次
Mar 23 Python
tensorflow ckpt模型和pb模型获取节点名称,及ckpt转pb模型实例
Jan 21 #Python
tensorflow查看ckpt各节点名称实例
Jan 21 #Python
python同义词替换的实现(jieba分词)
Jan 21 #Python
tensorflow模型保存、加载之变量重命名实例
Jan 21 #Python
tensorflow实现测试时读取任意指定的check point的网络参数
Jan 21 #Python
tensorflow如何继续训练之前保存的模型实例
Jan 21 #Python
在tensorflow中设置保存checkpoint的最大数量实例
Jan 21 #Python
You might like
建立文件交换功能的脚本(三)
2006/10/09 PHP
php图片裁剪函数
2018/10/31 PHP
初试jQuery EasyUI 使用介绍
2010/04/01 Javascript
探讨在JQuery和Js中,如何让ajax执行完后再继续往下执行
2013/07/09 Javascript
jquery $.each 和for怎么跳出循环终止本次循环
2013/09/27 Javascript
jquery操作 iframe的方法
2014/12/03 Javascript
JS从一组数据中找到指定的单条数据的方法
2016/06/02 Javascript
原生js实现水平方向无缝滚动
2017/01/10 Javascript
js正则表达式验证表单【完整版】
2017/03/06 Javascript
详解vue2.0组件通信各种情况总结与实例分析
2017/03/22 Javascript
AngularJS获取json数据的方法详解
2017/05/27 Javascript
深入理解Angular.JS中的Scope继承
2017/06/04 Javascript
jQuery实现简单日期格式化功能示例
2017/09/19 jQuery
jQuery动态操作表单示例【基于table表格】
2018/12/06 jQuery
vue项目引入ts步骤(小结)
2019/10/31 Javascript
在vue中利用全局路由钩子给url统一添加公共参数的例子
2019/11/01 Javascript
JavaScript常用工具函数汇总(浏览器环境)
2020/09/17 Javascript
分享Python文本生成二维码实例
2016/01/06 Python
linux平台使用Python制作BT种子并获取BT种子信息的方法
2017/01/20 Python
开源软件包和环境管理系统Anaconda的安装使用
2017/09/04 Python
Django在win10下的安装并创建工程
2017/11/20 Python
numpy中的delete删除数组整行和整列的实例
2018/05/09 Python
Python爬虫实现获取动态gif格式搞笑图片的方法示例
2018/12/24 Python
在Pycharm中对代码进行注释和缩进的方法详解
2019/01/20 Python
python学习将数据写入文件并保存方法
2020/06/07 Python
MATCHESFASHION.COM美国官网:英国奢侈品零售商
2018/10/29 全球购物
英国独特家具和家庭用品购物网站:Cuckooland
2020/08/30 全球购物
通信工程专业女生个人求职信
2013/09/21 职场文书
洗发水广告词
2014/03/13 职场文书
教师党的群众路线教育实践活动个人对照检查材料
2014/09/23 职场文书
运动会开幕词
2015/01/28 职场文书
稽核岗位职责
2015/02/10 职场文书
董事长秘书岗位职责
2015/02/13 职场文书
汤姆索亚历险记读书笔记
2015/06/29 职场文书
商业计划书格式、范文
2019/03/21 职场文书
5种方法告诉你如何使JavaScript 代码库更干净
2021/09/15 Javascript