tensorflow如何继续训练之前保存的模型实例


Posted in Python onJanuary 21, 2020

一:需重定义神经网络继续训练的方法

1.训练代码

import numpy as np
import tensorflow as tf
x_data=np.random.rand(100).astype(np.float32) 
y_data=x_data*0.1+0.3
weight=tf.Variable(tf.random_uniform([1],-1.0,1.0),name="w")
biases=tf.Variable(tf.zeros([1]),name="b")
 
y=weight*x_data+biases
 
loss=tf.reduce_mean(tf.square(y-y_data)) #loss
optimizer=tf.train.GradientDescentOptimizer(0.5)
train=optimizer.minimize(loss)
 
 
init=tf.global_variables_initializer() 
sess=tf.Session()
sess.run(init)
saver=tf.train.Saver(max_to_keep=0)
for step in range(10):
  sess.run(train)
  saver.save(sess,"./save_mode",global_step=step) #保存
  print("当前进行:",step)

第一次训练截图:

tensorflow如何继续训练之前保存的模型实例

2.恢复上一次的训练

import numpy as np
 
import tensorflow as tf
 
sess=tf.Session()
saver=tf.train.import_meta_graph(r'save_mode-9.meta')
saver.restore(sess,tf.train.latest_checkpoint(r'./'))
 
print(sess.run("w:0"),sess.run("b:0"))
 
 
 
graph=tf.get_default_graph() 
weight=graph.get_tensor_by_name("w:0") 
biases=graph.get_tensor_by_name("b:0")
 
 
x_data=np.random.rand(100).astype(np.float32)
y_data=x_data*0.1+0.3
y=weight*x_data+biases
 
 
loss=tf.reduce_mean(tf.square(y-y_data))
optimizer=tf.train.GradientDescentOptimizer(0.5)
train=optimizer.minimize(loss)
saver=tf.train.Saver(max_to_keep=0)
for step in range(10):
  sess.run(train)
  saver.save(sess,r"./save_new_mode",global_step=step)
  print("当前进行:",step," ",sess.run(weight),sess.run(biases))

使用上次保存下的数据进行继续训练和保存:

tensorflow如何继续训练之前保存的模型实例

#最后要提一下的是:

checkpoint文件

meta保存了TensorFlow计算图的结构信息

datat保存每个变量的取值

index保存了 表

加载restore时的文件路径名是以checkpoint文件中的“model_checkpoint_path”值决定的

这个方法需要重新定义神经网络

二:不需要重新定义神经网络的方法:

在上面训练的代码中加入:tf.add_to_collection("name",参数)

import numpy as np
import tensorflow as tf
x_data=np.random.rand(100).astype(np.float32)
 
y_data=x_data*0.1+0.3
weight=tf.Variable(tf.random_uniform([1],-1.0,1.0),name="w")
biases=tf.Variable(tf.zeros([1]),name="b")
y=weight*x_data+biases
 
loss=tf.reduce_mean(tf.square(y-y_data))
optimizer=tf.train.GradientDescentOptimizer(0.5)
train=optimizer.minimize(loss)
 
tf.add_to_collection("new_way",train)
init=tf.global_variables_initializer()
sess=tf.Session()
sess.run(init)
saver=tf.train.Saver(max_to_keep=0)
 
for step in range(10):
  sess.run(train)
  saver.save(sess,"./save_mode",global_step=step)
  print("当前进行:",step)

在下面的载入代码中加入:tf.get_collection("name"),就可以直接使用了

import numpy as np
import tensorflow as tf
sess=tf.Session()
saver=tf.train.import_meta_graph(r'save_mode-9.meta')
saver.restore(sess,tf.train.latest_checkpoint(r'./'))
print(sess.run("w:0"),sess.run("b:0"))
graph=tf.get_default_graph()
weight=graph.get_tensor_by_name("w:0")
biases=graph.get_tensor_by_name("b:0")
 
y=tf.get_collection("new_way")[0]
 
saver=tf.train.Saver(max_to_keep=0)
for step in range(10):
  sess.run(y)
  saver.save(sess,r"./save_new_mode",global_step=step)
  print("当前进行:",step," ",sess.run(weight),sess.run(biases))

总的来说,下面这种方法好像是要便利一些

以上这篇tensorflow如何继续训练之前保存的模型实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python字典操作简明总结
Apr 13 Python
Python中time模块与datetime模块在使用中的不同之处
Nov 24 Python
Python2实现的LED大数字显示效果示例
Sep 04 Python
学习Python selenium自动化网页抓取器
Jan 20 Python
将python代码和注释分离的方法
Apr 21 Python
解决Ubuntu pip 安装 mysql-python包出错的问题
Jun 11 Python
python爱心表白 每天都是浪漫七夕!
Aug 18 Python
Pycharm之快速定位到某行快捷键的方法
Jan 20 Python
代码实例讲解python3的编码问题
Jul 08 Python
Python统计分析模块statistics用法示例
Sep 06 Python
jenkins配置python脚本定时任务过程图解
Oct 29 Python
python交互模式基础知识点学习
Jun 18 Python
在tensorflow中设置保存checkpoint的最大数量实例
Jan 21 #Python
TensorFlow——Checkpoint为模型添加检查点的实例
Jan 21 #Python
tensorflow estimator 使用hook实现finetune方式
Jan 21 #Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
TensorFlow dataset.shuffle、batch、repeat的使用详解
Jan 21 #Python
使用 tf.nn.dynamic_rnn 展开时间维度方式
Jan 21 #Python
You might like
php 魔术方法使用说明
2009/10/20 PHP
PHP终止脚本运行三种实现方法详解
2020/09/01 PHP
Nigma vs Alliance BO5 第四场2.14
2021/03/10 DOTA
jQuery Ajax文件上传(php)
2009/06/16 Javascript
理解Javascript_08_函数对象
2010/10/15 Javascript
jQuery 表单验证扩展代码(二)
2010/10/20 Javascript
js获取php变量的实现代码
2013/08/10 Javascript
原生js ActiveXObject获取execl里面的值
2013/11/01 Javascript
带左右箭头图片轮播的JS代码
2013/12/18 Javascript
Javascript 浮点运算精度问题分析与解决
2014/03/26 Javascript
当jquery ajax遇上401请求的解决方法
2016/05/19 Javascript
自己动手制作基于jQuery的Web页面加载进度条插件
2016/06/03 Javascript
jQuery Easyui datagrid editor为combobox时指定数据源实例
2016/12/19 Javascript
详解Vue.js动态绑定class
2016/12/20 Javascript
微信小程序Redux绑定实例详解
2017/06/07 Javascript
axios中cookie跨域及相关配置示例详解
2017/12/20 Javascript
微信小程序点击顶部导航栏切换样式代码实例
2019/11/12 Javascript
python的描述符(descriptor)、装饰器(property)造成的一个无限递归问题分享
2014/07/09 Python
Python 正则表达式(转义问题)
2014/12/15 Python
Python编程入门的一些基本知识
2015/05/13 Python
python二分查找算法的递归实现方法
2016/05/12 Python
python3+PyQt5实现使用剪贴板做复制与粘帖示例
2017/01/24 Python
Python算法之图的遍历
2017/11/16 Python
tensorflow 加载部分变量的实例讲解
2018/07/27 Python
Python实现的插入排序,冒泡排序,快速排序,选择排序算法示例
2019/05/04 Python
Pytorch实现各种2d卷积示例
2019/12/30 Python
python 实现人和电脑猜拳的示例代码
2020/03/02 Python
师德师风自我评价范文
2014/09/11 职场文书
2014年生活老师工作总结
2014/12/23 职场文书
辩论赛主持人开场白
2015/05/29 职场文书
2015暑假实习报告范文
2015/07/13 职场文书
关于开学的感想
2015/08/10 职场文书
初中班主任培训心得体会
2016/01/07 职场文书
演讲稿:​快乐,从不抱怨开始!
2019/04/02 职场文书
2021年pycharm的最新安装教程及基本使用图文详解
2021/04/03 Python
Element-ui Layout布局(Row和Col组件)的实现
2021/12/06 Vue.js