tensorflow模型继续训练 fineturn实例


Posted in Python onJanuary 21, 2020

解决tensoflow如何在已训练模型上继续训练fineturn的问题。

训练代码

任务描述: x = 3.0, y = 100.0, 运算公式 x×W+b = y,求 W和b的最优解。

# -*- coding: utf-8 -*-)
import tensorflow as tf
 
 
# 声明占位变量x、y
x = tf.placeholder("float", shape=[None, 1])
y = tf.placeholder("float", [None, 1])
 
# 声明变量
W = tf.Variable(tf.zeros([1, 1]),name='w')
b = tf.Variable(tf.zeros([1]),name='b')
 
# 操作
result = tf.matmul(x, W) + b
 
# 损失函数
lost = tf.reduce_sum(tf.pow((result - y), 2))
 
# 优化
train_step = tf.train.GradientDescentOptimizer(0.0007).minimize(lost)
 
with tf.Session() as sess:
  # 初始化变量
  sess.run(tf.global_variables_initializer())
  saver = tf.train.Saver(max_to_keep=3)
 
  # 这里x、y给固定的值
  x_s = [[3.0]]
  y_s = [[100.0]]
 
  step = 0
  while (True):
    step += 1
    feed = {x: x_s, y: y_s}
    # 通过sess.run执行优化
    sess.run(train_step, feed_dict=feed)
 
    if step % 1000 == 0:
      print 'step: {0}, loss: {1}'.format(step, sess.run(lost, feed_dict=feed))
      if sess.run(lost, feed_dict=feed) < 1e-10 or step > 4e3:
        print ''
        # print 'final loss is: {}'.format(sess.run(lost, feed_dict=feed))
        print 'final result of {0} = {1}(目标值是100.0)'.format('x×W+b', 3.0 * sess.run(W) + sess.run(b))
        print ''
        print("模型保存的W值 : %f" % sess.run(W))
        print("模型保存的b : %f" % sess.run(b))
        break
  saver.save(sess, "./save_model/re-train", global_step=step) # 保存模型

训练完成之后生成模型文件:

tensorflow模型继续训练 fineturn实例

训练输出:

step: 1000, loss: 4.89526428282e-08
step: 2000, loss: 4.89526428282e-08
step: 3000, loss: 4.89526428282e-08
step: 4000, loss: 4.89526428282e-08
step: 5000, loss: 4.89526428282e-08
 
 
final result of x×W+b = [[99.99978]](目标值是100.0)
 
模型保存的W值 : 29.999931
模型保存的b : 9.999982

保存在模型中的W值是 29.999931,b是 9.999982。

以下代码从保存的模型中恢复出训练状态,继续训练

任务描述: x = 3.0, y = 200.0, 运算公式 x×W+b = y,从上次训练的模型中恢复出训练参数,继续训练,求 W和b的最优解。

# -*- coding: utf-8 -*-)
import tensorflow as tf
 
 
# 声明占位变量x、y
x = tf.placeholder("float", shape=[None, 1])
y = tf.placeholder("float", [None, 1])
 
with tf.Session() as sess:
 
  # 初始化变量
  sess.run(tf.global_variables_initializer())
 
  # saver = tf.train.Saver(max_to_keep=3)
  saver = tf.train.import_meta_graph(r'./save_model/re-train-5000.meta') # 加载模型图结构
  saver.restore(sess, tf.train.latest_checkpoint(r'./save_model')) # 恢复数据
 
  # 从保存模型中恢复变量
  graph = tf.get_default_graph()
  W = graph.get_tensor_by_name("w:0")
  b = graph.get_tensor_by_name("b:0")
 
  print("从保存的模型中恢复出来的W值 : %f" % sess.run("w:0"))
  print("从保存的模型中恢复出来的b值 : %f" % sess.run("b:0"))
 
  # 操作
  result = tf.matmul(x, W) + b
  # 损失函数
  lost = tf.reduce_sum(tf.pow((result - y), 2))
  # 优化
  train_step = tf.train.GradientDescentOptimizer(0.0007).minimize(lost)
 
  # 这里x、y给固定的值
  x_s = [[3.0]]
  y_s = [[200.0]]
 
  step = 0
  while (True):
    step += 1
    feed = {x: x_s, y: y_s}
    # 通过sess.run执行优化
    sess.run(train_step, feed_dict=feed)
    if step % 1000 == 0:
      print 'step: {0}, loss: {1}'.format(step, sess.run(lost, feed_dict=feed))
      if sess.run(lost, feed_dict=feed) < 1e-10 or step > 4e3:
        print ''
        # print 'final loss is: {}'.format(sess.run(lost, feed_dict=feed))
        print 'final result of {0} = {1}(目标值是200.0)'.format('x×W+b', 3.0 * sess.run(W) + sess.run(b))
        print("模型保存的W值 : %f" % sess.run(W))
        print("模型保存的b : %f" % sess.run(b))
        break
  saver.save(sess, "./save_mode/re-train", global_step=step) # 保存模型

训练输出:

从保存的模型中恢复出来的W值 : 29.999931
从保存的模型中恢复出来的b值 : 9.999982
step: 1000, loss: 1.95810571313e-07
step: 2000, loss: 1.95810571313e-07
step: 3000, loss: 1.95810571313e-07
step: 4000, loss: 1.95810571313e-07
step: 5000, loss: 1.95810571313e-07
 
 
final result of x×W+b = [[199.99956]](目标值是200.0)
模型保存的W值 : 59.999866
模型保存的b : 19.999958

从保存的模型中恢复出来的W值是 29.999931,b是 9.999982,跟模型保存的值一致,说明加载成功。

总结

从头开始训练一个模型,需要通过 tf.train.Saver创建一个保存器,完成之后使用save方法保存模型到本地:

saver = tf.train.Saver(max_to_keep=3)
……
saver.save(sess, "./save_model/re-train", global_step=step) # 保存模型

在训练好的模型上继续训练,fineturn一个模型,可以使用tf.train.import_meta_graph方法加载图结构,使用restore方法恢复训练数据,最后使用同样的save方法保存到本地:

saver = tf.train.import_meta_graph(r'./save_model/re-train-10050.meta') # 加载模型图结构
saver.restore(sess, tf.train.latest_checkpoint(r'./save_model')) # 恢复数据
saver.save(sess, "./save_mode/re-train", global_step=step) # 保存模型

注:特殊情况下(如本例)需要从恢复的模型中加载出数据:

# 从保存模型中恢复变量
graph = tf.get_default_graph()
W = graph.get_tensor_by_name("w:0")
b = graph.get_tensor_by_name("b:0")

以上这篇tensorflow模型继续训练 fineturn实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现备份文件实例
Sep 16 Python
python使用calendar输出指定年份全年日历的方法
Apr 04 Python
使用Python的Treq on Twisted来进行HTTP压力测试
Apr 16 Python
详解Python发送邮件实例
Jan 10 Python
TensorFlow模型保存/载入的两种方法
Mar 08 Python
python登录WeChat 实现自动回复实例详解
May 28 Python
Python3的高阶函数map,reduce,filter的示例详解
Jul 23 Python
win10安装tesserocr配置 Python使用tesserocr识别字母数字验证码
Jan 16 Python
Django 解决上传文件时,request.FILES为空的问题
May 20 Python
python3通过subprocess模块调用脚本并和脚本交互的操作
Dec 05 Python
pandas按照列的值排序(某一列或者多列)
Dec 13 Python
python3 kubernetes api的使用示例
Jan 12 Python
tensorflow ckpt模型和pb模型获取节点名称,及ckpt转pb模型实例
Jan 21 #Python
tensorflow查看ckpt各节点名称实例
Jan 21 #Python
python同义词替换的实现(jieba分词)
Jan 21 #Python
tensorflow模型保存、加载之变量重命名实例
Jan 21 #Python
tensorflow实现测试时读取任意指定的check point的网络参数
Jan 21 #Python
tensorflow如何继续训练之前保存的模型实例
Jan 21 #Python
在tensorflow中设置保存checkpoint的最大数量实例
Jan 21 #Python
You might like
php循环创建目录示例分享(php创建多级目录)
2014/03/04 PHP
如何使用PHP Embed SAPI实现Opcodes查看器
2015/11/10 PHP
微信公众号开发之获取位置信息php代码
2018/06/13 PHP
PHP中define() 与 const定义常量的区别详解
2019/06/25 PHP
PHP实现微信提现功能(微信商城)
2019/11/21 PHP
javascript编程起步(第七课)
2007/02/27 Javascript
jQuery 性能优化指南(2)
2009/05/21 Javascript
传智播客学习之java 反射
2009/11/22 Javascript
关于jQuery中的each方法(jQuery到底干了什么)
2014/03/05 Javascript
JavaScript实现自动生成网页元素功能(按钮、文本等)
2015/11/21 Javascript
Angular2中select用法之设置默认值与事件详解
2017/05/07 Javascript
详解基于webpack搭建react运行环境
2017/06/01 Javascript
React-router v4 路由配置方法小结
2017/08/08 Javascript
解决node修改后需频繁手动重启的问题
2018/05/13 Javascript
微信小程序自定义对话框弹出和隐藏动画
2018/07/19 Javascript
vue组件中的样式属性scoped实例详解
2018/10/30 Javascript
JavaScript数据结构与算法之检索算法示例【二分查找法、计算重复次数】
2019/02/22 Javascript
解决vue-cli 打包后自定义动画未执行的问题
2019/11/12 Javascript
[12:21]VICI vs TNC (BO3)
2018/06/07 DOTA
[10:18]2018DOTA2国际邀请赛寻真——找回自信的TNCPredator
2018/08/13 DOTA
使用Python操作Elasticsearch数据索引的教程
2015/04/08 Python
编写Python爬虫抓取豆瓣电影TOP100及用户头像的方法
2016/01/20 Python
Python使用functools模块中的partial函数生成偏函数
2016/07/02 Python
Python简单操作sqlite3的方法示例
2017/03/22 Python
Python爬虫信息输入及页面的切换方法
2018/05/11 Python
DES加密解密算法之python实现版(图文并茂)
2018/12/06 Python
从0开始的Python学习016异常
2019/04/08 Python
2020最新pycharm汉化安装(python工程狮亲测有效)
2020/04/26 Python
Python使用Opencv实现边缘检测以及轮廓检测的实现
2020/12/31 Python
使用sublime text3搭建Python编辑环境的实现
2021/01/12 Python
css3 box-shadow阴影(外阴影与外发光)图示讲解
2017/08/11 HTML / CSS
英国最大的在线运动补充剂商店:Discount Supplements
2017/06/03 全球购物
店长岗位的工作内容
2013/11/12 职场文书
机械专业毕业生推荐信范文
2013/11/25 职场文书
2014年教师节演讲稿范文
2014/09/10 职场文书
基层党员学习党的群众路线教育实践活动心得体会
2014/11/04 职场文书