tensorflow模型继续训练 fineturn实例


Posted in Python onJanuary 21, 2020

解决tensoflow如何在已训练模型上继续训练fineturn的问题。

训练代码

任务描述: x = 3.0, y = 100.0, 运算公式 x×W+b = y,求 W和b的最优解。

# -*- coding: utf-8 -*-)
import tensorflow as tf
 
 
# 声明占位变量x、y
x = tf.placeholder("float", shape=[None, 1])
y = tf.placeholder("float", [None, 1])
 
# 声明变量
W = tf.Variable(tf.zeros([1, 1]),name='w')
b = tf.Variable(tf.zeros([1]),name='b')
 
# 操作
result = tf.matmul(x, W) + b
 
# 损失函数
lost = tf.reduce_sum(tf.pow((result - y), 2))
 
# 优化
train_step = tf.train.GradientDescentOptimizer(0.0007).minimize(lost)
 
with tf.Session() as sess:
  # 初始化变量
  sess.run(tf.global_variables_initializer())
  saver = tf.train.Saver(max_to_keep=3)
 
  # 这里x、y给固定的值
  x_s = [[3.0]]
  y_s = [[100.0]]
 
  step = 0
  while (True):
    step += 1
    feed = {x: x_s, y: y_s}
    # 通过sess.run执行优化
    sess.run(train_step, feed_dict=feed)
 
    if step % 1000 == 0:
      print 'step: {0}, loss: {1}'.format(step, sess.run(lost, feed_dict=feed))
      if sess.run(lost, feed_dict=feed) < 1e-10 or step > 4e3:
        print ''
        # print 'final loss is: {}'.format(sess.run(lost, feed_dict=feed))
        print 'final result of {0} = {1}(目标值是100.0)'.format('x×W+b', 3.0 * sess.run(W) + sess.run(b))
        print ''
        print("模型保存的W值 : %f" % sess.run(W))
        print("模型保存的b : %f" % sess.run(b))
        break
  saver.save(sess, "./save_model/re-train", global_step=step) # 保存模型

训练完成之后生成模型文件:

tensorflow模型继续训练 fineturn实例

训练输出:

step: 1000, loss: 4.89526428282e-08
step: 2000, loss: 4.89526428282e-08
step: 3000, loss: 4.89526428282e-08
step: 4000, loss: 4.89526428282e-08
step: 5000, loss: 4.89526428282e-08
 
 
final result of x×W+b = [[99.99978]](目标值是100.0)
 
模型保存的W值 : 29.999931
模型保存的b : 9.999982

保存在模型中的W值是 29.999931,b是 9.999982。

以下代码从保存的模型中恢复出训练状态,继续训练

任务描述: x = 3.0, y = 200.0, 运算公式 x×W+b = y,从上次训练的模型中恢复出训练参数,继续训练,求 W和b的最优解。

# -*- coding: utf-8 -*-)
import tensorflow as tf
 
 
# 声明占位变量x、y
x = tf.placeholder("float", shape=[None, 1])
y = tf.placeholder("float", [None, 1])
 
with tf.Session() as sess:
 
  # 初始化变量
  sess.run(tf.global_variables_initializer())
 
  # saver = tf.train.Saver(max_to_keep=3)
  saver = tf.train.import_meta_graph(r'./save_model/re-train-5000.meta') # 加载模型图结构
  saver.restore(sess, tf.train.latest_checkpoint(r'./save_model')) # 恢复数据
 
  # 从保存模型中恢复变量
  graph = tf.get_default_graph()
  W = graph.get_tensor_by_name("w:0")
  b = graph.get_tensor_by_name("b:0")
 
  print("从保存的模型中恢复出来的W值 : %f" % sess.run("w:0"))
  print("从保存的模型中恢复出来的b值 : %f" % sess.run("b:0"))
 
  # 操作
  result = tf.matmul(x, W) + b
  # 损失函数
  lost = tf.reduce_sum(tf.pow((result - y), 2))
  # 优化
  train_step = tf.train.GradientDescentOptimizer(0.0007).minimize(lost)
 
  # 这里x、y给固定的值
  x_s = [[3.0]]
  y_s = [[200.0]]
 
  step = 0
  while (True):
    step += 1
    feed = {x: x_s, y: y_s}
    # 通过sess.run执行优化
    sess.run(train_step, feed_dict=feed)
    if step % 1000 == 0:
      print 'step: {0}, loss: {1}'.format(step, sess.run(lost, feed_dict=feed))
      if sess.run(lost, feed_dict=feed) < 1e-10 or step > 4e3:
        print ''
        # print 'final loss is: {}'.format(sess.run(lost, feed_dict=feed))
        print 'final result of {0} = {1}(目标值是200.0)'.format('x×W+b', 3.0 * sess.run(W) + sess.run(b))
        print("模型保存的W值 : %f" % sess.run(W))
        print("模型保存的b : %f" % sess.run(b))
        break
  saver.save(sess, "./save_mode/re-train", global_step=step) # 保存模型

训练输出:

从保存的模型中恢复出来的W值 : 29.999931
从保存的模型中恢复出来的b值 : 9.999982
step: 1000, loss: 1.95810571313e-07
step: 2000, loss: 1.95810571313e-07
step: 3000, loss: 1.95810571313e-07
step: 4000, loss: 1.95810571313e-07
step: 5000, loss: 1.95810571313e-07
 
 
final result of x×W+b = [[199.99956]](目标值是200.0)
模型保存的W值 : 59.999866
模型保存的b : 19.999958

从保存的模型中恢复出来的W值是 29.999931,b是 9.999982,跟模型保存的值一致,说明加载成功。

总结

从头开始训练一个模型,需要通过 tf.train.Saver创建一个保存器,完成之后使用save方法保存模型到本地:

saver = tf.train.Saver(max_to_keep=3)
……
saver.save(sess, "./save_model/re-train", global_step=step) # 保存模型

在训练好的模型上继续训练,fineturn一个模型,可以使用tf.train.import_meta_graph方法加载图结构,使用restore方法恢复训练数据,最后使用同样的save方法保存到本地:

saver = tf.train.import_meta_graph(r'./save_model/re-train-10050.meta') # 加载模型图结构
saver.restore(sess, tf.train.latest_checkpoint(r'./save_model')) # 恢复数据
saver.save(sess, "./save_mode/re-train", global_step=step) # 保存模型

注:特殊情况下(如本例)需要从恢复的模型中加载出数据:

# 从保存模型中恢复变量
graph = tf.get_default_graph()
W = graph.get_tensor_by_name("w:0")
b = graph.get_tensor_by_name("b:0")

以上这篇tensorflow模型继续训练 fineturn实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
教大家使用Python SqlAlchemy
Feb 12 Python
python实现12306抢票及自动邮件发送提醒付款功能
Mar 08 Python
Python unittest模块用法实例分析
May 25 Python
解决python中使用plot画图,图不显示的问题
Jul 04 Python
python使用epoll实现服务端的方法
Oct 16 Python
python 画二维、三维点之间的线段实现方法
Jul 07 Python
django解决订单并发问题【推荐】
Jul 31 Python
python批量读取文件名并写入txt文件中
Sep 05 Python
softmax及python实现过程解析
Sep 30 Python
对Python 字典元素进行删除的方法
Jul 31 Python
Python读写锁实现实现代码解析
Nov 28 Python
Python爬虫基础讲解之请求
May 13 Python
tensorflow ckpt模型和pb模型获取节点名称,及ckpt转pb模型实例
Jan 21 #Python
tensorflow查看ckpt各节点名称实例
Jan 21 #Python
python同义词替换的实现(jieba分词)
Jan 21 #Python
tensorflow模型保存、加载之变量重命名实例
Jan 21 #Python
tensorflow实现测试时读取任意指定的check point的网络参数
Jan 21 #Python
tensorflow如何继续训练之前保存的模型实例
Jan 21 #Python
在tensorflow中设置保存checkpoint的最大数量实例
Jan 21 #Python
You might like
Smarty模板快速入门
2007/01/04 PHP
PHP 缓存实现代码及详细注释
2010/05/16 PHP
神盾加密解密教程(二)PHP 神盾解密
2014/06/08 PHP
简单分析ucenter 会员同步登录通信原理
2014/08/25 PHP
Thinkphp多文件上传实现方法
2014/10/31 PHP
ThinkPHP框架实现的MySQL数据库备份功能示例
2018/05/24 PHP
php的RSA加密解密算法原理与用法分析
2020/01/23 PHP
JavaScript中令你抓狂的魔术变量
2006/11/30 Javascript
Dojo之路:如何利用Dojo实现Drag and Drop效果
2007/04/10 Javascript
JS 文件本身编码转换 图文教程
2009/10/12 Javascript
JQuery Tips(3) 关于$()包装集内元素的改变
2009/12/14 Javascript
JavaScript Event学习第十章 一些可替换的事件对
2010/02/10 Javascript
javascript禁用Tab键脚本实例
2013/11/22 Javascript
JS实现的用来对比两个用指定分隔符分割的字符串是否相同
2014/09/19 Javascript
跟我学习javascript的作用域与作用域链
2015/11/19 Javascript
javascript回到顶部特效
2016/07/30 Javascript
JS实现课堂随机点名和顺序点名
2017/03/09 Javascript
ES6新特性之字符串的扩展实例分析
2017/04/01 Javascript
bootstrap select插件封装成Vue2.0组件
2017/04/17 Javascript
浅析从vue源码看观察者模式
2018/01/29 Javascript
Vue单页应用引用单独的样式文件的两种方式
2018/03/30 Javascript
详解jQuery设置内容和属性
2019/04/11 jQuery
vue登录页面cookie的使用及页面跳转代码
2019/07/10 Javascript
js消除图片小游戏代码
2019/12/11 Javascript
[02:42]DOTA2英雄基础教程 杰奇洛
2013/12/23 DOTA
python深度优先搜索和广度优先搜索
2018/02/07 Python
Python异常处理操作实例详解
2018/08/28 Python
np.dot()函数的用法详解
2020/01/17 Python
浅谈python中频繁的print到底能浪费多长时间
2020/02/21 Python
python 密码学示例——理解哈希(Hash)算法
2020/09/21 Python
python修改微信和支付宝步数的示例代码
2020/10/12 Python
南非领先的在线旅行社:Travelstart南非
2016/09/04 全球购物
2015年上半年党建工作总结
2015/03/30 职场文书
2015年党员岗位承诺书
2015/04/27 职场文书
Python anaconda安装库命令详解
2021/10/16 Python
python 判断字符串当中是否包含字符(str.contain)
2022/06/01 Python