tensorflow模型继续训练 fineturn实例


Posted in Python onJanuary 21, 2020

解决tensoflow如何在已训练模型上继续训练fineturn的问题。

训练代码

任务描述: x = 3.0, y = 100.0, 运算公式 x×W+b = y,求 W和b的最优解。

# -*- coding: utf-8 -*-)
import tensorflow as tf
 
 
# 声明占位变量x、y
x = tf.placeholder("float", shape=[None, 1])
y = tf.placeholder("float", [None, 1])
 
# 声明变量
W = tf.Variable(tf.zeros([1, 1]),name='w')
b = tf.Variable(tf.zeros([1]),name='b')
 
# 操作
result = tf.matmul(x, W) + b
 
# 损失函数
lost = tf.reduce_sum(tf.pow((result - y), 2))
 
# 优化
train_step = tf.train.GradientDescentOptimizer(0.0007).minimize(lost)
 
with tf.Session() as sess:
  # 初始化变量
  sess.run(tf.global_variables_initializer())
  saver = tf.train.Saver(max_to_keep=3)
 
  # 这里x、y给固定的值
  x_s = [[3.0]]
  y_s = [[100.0]]
 
  step = 0
  while (True):
    step += 1
    feed = {x: x_s, y: y_s}
    # 通过sess.run执行优化
    sess.run(train_step, feed_dict=feed)
 
    if step % 1000 == 0:
      print 'step: {0}, loss: {1}'.format(step, sess.run(lost, feed_dict=feed))
      if sess.run(lost, feed_dict=feed) < 1e-10 or step > 4e3:
        print ''
        # print 'final loss is: {}'.format(sess.run(lost, feed_dict=feed))
        print 'final result of {0} = {1}(目标值是100.0)'.format('x×W+b', 3.0 * sess.run(W) + sess.run(b))
        print ''
        print("模型保存的W值 : %f" % sess.run(W))
        print("模型保存的b : %f" % sess.run(b))
        break
  saver.save(sess, "./save_model/re-train", global_step=step) # 保存模型

训练完成之后生成模型文件:

tensorflow模型继续训练 fineturn实例

训练输出:

step: 1000, loss: 4.89526428282e-08
step: 2000, loss: 4.89526428282e-08
step: 3000, loss: 4.89526428282e-08
step: 4000, loss: 4.89526428282e-08
step: 5000, loss: 4.89526428282e-08
 
 
final result of x×W+b = [[99.99978]](目标值是100.0)
 
模型保存的W值 : 29.999931
模型保存的b : 9.999982

保存在模型中的W值是 29.999931,b是 9.999982。

以下代码从保存的模型中恢复出训练状态,继续训练

任务描述: x = 3.0, y = 200.0, 运算公式 x×W+b = y,从上次训练的模型中恢复出训练参数,继续训练,求 W和b的最优解。

# -*- coding: utf-8 -*-)
import tensorflow as tf
 
 
# 声明占位变量x、y
x = tf.placeholder("float", shape=[None, 1])
y = tf.placeholder("float", [None, 1])
 
with tf.Session() as sess:
 
  # 初始化变量
  sess.run(tf.global_variables_initializer())
 
  # saver = tf.train.Saver(max_to_keep=3)
  saver = tf.train.import_meta_graph(r'./save_model/re-train-5000.meta') # 加载模型图结构
  saver.restore(sess, tf.train.latest_checkpoint(r'./save_model')) # 恢复数据
 
  # 从保存模型中恢复变量
  graph = tf.get_default_graph()
  W = graph.get_tensor_by_name("w:0")
  b = graph.get_tensor_by_name("b:0")
 
  print("从保存的模型中恢复出来的W值 : %f" % sess.run("w:0"))
  print("从保存的模型中恢复出来的b值 : %f" % sess.run("b:0"))
 
  # 操作
  result = tf.matmul(x, W) + b
  # 损失函数
  lost = tf.reduce_sum(tf.pow((result - y), 2))
  # 优化
  train_step = tf.train.GradientDescentOptimizer(0.0007).minimize(lost)
 
  # 这里x、y给固定的值
  x_s = [[3.0]]
  y_s = [[200.0]]
 
  step = 0
  while (True):
    step += 1
    feed = {x: x_s, y: y_s}
    # 通过sess.run执行优化
    sess.run(train_step, feed_dict=feed)
    if step % 1000 == 0:
      print 'step: {0}, loss: {1}'.format(step, sess.run(lost, feed_dict=feed))
      if sess.run(lost, feed_dict=feed) < 1e-10 or step > 4e3:
        print ''
        # print 'final loss is: {}'.format(sess.run(lost, feed_dict=feed))
        print 'final result of {0} = {1}(目标值是200.0)'.format('x×W+b', 3.0 * sess.run(W) + sess.run(b))
        print("模型保存的W值 : %f" % sess.run(W))
        print("模型保存的b : %f" % sess.run(b))
        break
  saver.save(sess, "./save_mode/re-train", global_step=step) # 保存模型

训练输出:

从保存的模型中恢复出来的W值 : 29.999931
从保存的模型中恢复出来的b值 : 9.999982
step: 1000, loss: 1.95810571313e-07
step: 2000, loss: 1.95810571313e-07
step: 3000, loss: 1.95810571313e-07
step: 4000, loss: 1.95810571313e-07
step: 5000, loss: 1.95810571313e-07
 
 
final result of x×W+b = [[199.99956]](目标值是200.0)
模型保存的W值 : 59.999866
模型保存的b : 19.999958

从保存的模型中恢复出来的W值是 29.999931,b是 9.999982,跟模型保存的值一致,说明加载成功。

总结

从头开始训练一个模型,需要通过 tf.train.Saver创建一个保存器,完成之后使用save方法保存模型到本地:

saver = tf.train.Saver(max_to_keep=3)
……
saver.save(sess, "./save_model/re-train", global_step=step) # 保存模型

在训练好的模型上继续训练,fineturn一个模型,可以使用tf.train.import_meta_graph方法加载图结构,使用restore方法恢复训练数据,最后使用同样的save方法保存到本地:

saver = tf.train.import_meta_graph(r'./save_model/re-train-10050.meta') # 加载模型图结构
saver.restore(sess, tf.train.latest_checkpoint(r'./save_model')) # 恢复数据
saver.save(sess, "./save_mode/re-train", global_step=step) # 保存模型

注:特殊情况下(如本例)需要从恢复的模型中加载出数据:

# 从保存模型中恢复变量
graph = tf.get_default_graph()
W = graph.get_tensor_by_name("w:0")
b = graph.get_tensor_by_name("b:0")

以上这篇tensorflow模型继续训练 fineturn实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python操作sqlite3快速、安全插入数据(防注入)的实例
Apr 26 Python
Python读取ini文件、操作mysql、发送邮件实例
Jan 01 Python
Python实现的摇骰子猜大小功能小游戏示例
Dec 18 Python
Python实现的视频播放器功能完整示例
Feb 01 Python
python写一个md5解密器示例
Feb 23 Python
Python+OpenCV实现图像融合的原理及代码
Dec 03 Python
在Qt中正确的设置窗体的背景图片的几种方法总结
Jun 19 Python
使用python将mysql数据库的数据转换为json数据的方法
Jul 01 Python
python3+opencv 使用灰度直方图来判断图片的亮暗操作
Jun 02 Python
详解pandas.DataFrame.plot() 画图函数
Jun 14 Python
python logging模块的使用
Sep 07 Python
selenium判断元素是否存在的两种方法小结
Dec 07 Python
tensorflow ckpt模型和pb模型获取节点名称,及ckpt转pb模型实例
Jan 21 #Python
tensorflow查看ckpt各节点名称实例
Jan 21 #Python
python同义词替换的实现(jieba分词)
Jan 21 #Python
tensorflow模型保存、加载之变量重命名实例
Jan 21 #Python
tensorflow实现测试时读取任意指定的check point的网络参数
Jan 21 #Python
tensorflow如何继续训练之前保存的模型实例
Jan 21 #Python
在tensorflow中设置保存checkpoint的最大数量实例
Jan 21 #Python
You might like
基于AppServ,XAMPP,WAMP配置php.ini去掉警告信息(NOTICE)的方法详解
2013/05/07 PHP
PHP+MYSQL实现用户的增删改查
2015/03/24 PHP
php生成验证码函数
2015/10/20 PHP
WordPress特定文章对搜索引擎隐藏或只允许搜索引擎查看
2015/12/31 PHP
PHP实现对二维数组某个键排序的方法
2016/09/14 PHP
PHP的自定义模板引擎
2017/03/24 PHP
PHP实现的激活用户注册验证邮箱功能示例
2017/06/06 PHP
PHP实现微信图片上传到服务器的方法示例
2017/06/29 PHP
PHP+JS实现的实时搜索提示功能
2018/03/13 PHP
laravel利用中间件防止未登录用户直接访问后台的方法
2019/09/30 PHP
也说JavaScript中String类的replace函数
2011/09/22 Javascript
jquery操作checkbox实现全选和取消全选
2014/05/02 Javascript
javascript常用方法总结
2015/05/14 Javascript
javascript基础知识
2016/06/07 Javascript
纯javascript版日历控件
2016/11/24 Javascript
详解vuex的简单使用
2018/03/12 Javascript
浅谈JavaScript面向对象--继承
2019/03/20 Javascript
开源一个微信小程序仪表盘组件过程解析
2019/07/30 Javascript
微信小程序中weui用法解析
2019/10/21 Javascript
[02:17]2016国际邀请赛中国区预选赛VG战队领队采访
2016/06/26 DOTA
python的三目运算符和not in运算符使用示例
2014/03/03 Python
Python with的用法
2014/08/22 Python
Python实现的NN神经网络算法完整示例
2018/06/19 Python
在Python中定义一个常量的方法
2018/11/10 Python
Python爬虫程序架构和运行流程原理解析
2020/03/09 Python
浅析关于Keras的安装(pycharm)和初步理解
2020/10/23 Python
Python高并发和多线程有什么关系
2020/11/14 Python
python爬虫如何解决图片验证码
2021/02/14 Python
详解通过HTML5 Canvas实现图片的平移及旋转变化的方法
2016/03/22 HTML / CSS
元旦晚会邀请函
2014/01/27 职场文书
高二物理教学反思
2014/02/08 职场文书
入党综合考察材料
2014/06/02 职场文书
阅兵口号
2014/06/19 职场文书
村党建工作汇报材料
2014/11/02 职场文书
教您怎么制定西餐厅运营方案 ?
2019/07/05 职场文书
SQL实战演练之网上商城数据库商品类别数据操作
2021/10/24 MySQL