tensorflow模型保存、加载之变量重命名实例


Posted in Python onJanuary 21, 2020

话不多说,干就完了。

变量重命名的用处?

简单定义:简单来说就是将模型A中的参数parameter_A赋给模型B中的parameter_B

使用场景:当需要使用已经训练好的模型参数,尤其是使用别人训练好的模型参数时,往往别人模型中的参数命名方式与自己当前的命名方式不同,所以在加载模型参数时需要对参数进行重命名,使得代码更简洁易懂。

实现方法:

1)、模型保存

import os
import tensorflow as tf
 
weights = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                            mean=0.0,
                            stddev=0.1),
           dtype=tf.float32,
           name="weights")
biases = tf.Variable(initial_value=tf.zeros(shape=[2]),
           dtype=tf.float32,
           name="biases")
 
weights_2 = tf.Variable(initial_value=weights.initialized_value(),
            dtype=tf.float32,
            name="weights_2")
 
# saver checkpoint
if os.path.exists("checkpoints") is False:
  os.makedirs("checkpoints")
 
saver = tf.train.Saver()
with tf.Session() as sess:
  init_op = [tf.global_variables_initializer()]
  sess.run(init_op)
  saver.save(sess=sess, save_path="checkpoints/variable.ckpt")

2)、模型加载(变量名称保持不变)

import tensorflow as tf
from matplotlib import pyplot as plt
import os
 
current_path = os.path.dirname(os.path.abspath(__file__))
 
def restore_variable(sess):
  # need not initilize variable, but need to define the same variable like checkpoint
  weights = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                              mean=0.0,
                              stddev=0.1),
             dtype=tf.float32,
             name="weights")
  biases = tf.Variable(initial_value=tf.zeros(shape=[2]),
             dtype=tf.float32,
             name="biases")
 
  weights_2 = tf.Variable(initial_value=weights.initialized_value(),
              dtype=tf.float32,
              name="weights_2")
 
  saver = tf.train.Saver()
 
  ckpt_path = os.path.join(current_path, "checkpoints", "variable.ckpt")
  saver.restore(sess=sess, save_path=ckpt_path)
 
  weights_val, weights_2_val = sess.run(
    [
      tf.reshape(weights, shape=[2048]),
      tf.reshape(weights_2, shape=[2048])
    ]
  )
 
  plt.subplot(1, 2, 1)
  plt.scatter([i for i in range(len(weights_val))], weights_val)
  plt.subplot(1, 2, 2)
  plt.scatter([i for i in range(len(weights_2_val))], weights_2_val)
  plt.show()
 
 
if __name__ == '__main__':
  with tf.Session() as sess:
    restore_variable(sess)

3)、模型加载(变量重命名)

import tensorflow as tf
from matplotlib import pyplot as plt
import os
 
current_path = os.path.dirname(os.path.abspath(__file__))
 
 
def restore_variable_renamed(sess):
  conv1_w = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                              mean=0.0,
                              stddev=0.1),
             dtype=tf.float32,
             name="conv1_w")
  conv1_b = tf.Variable(initial_value=tf.zeros(shape=[2]),
             dtype=tf.float32,
             name="conv1_b")
 
  conv2_w = tf.Variable(initial_value=conv1_w.initialized_value(),
             dtype=tf.float32,
             name="conv2_w")
 
  # variable named 'weights' in ckpt assigned to current variable conv1_w
  # variable named 'biases' in ckpt assigned to current variable conv1_b
  # variable named 'weights_2' in ckpt assigned to current variable conv2_w
  saver = tf.train.Saver({
    "weights": conv1_w,
    "biases": conv1_b,
    "weights_2": conv2_w
  })
 
  ckpt_path = os.path.join(current_path, "checkpoints", "variable.ckpt")
  saver.restore(sess=sess, save_path=ckpt_path)
 
  conv1_w__val, conv2_w__val = sess.run(
    [
      tf.reshape(conv1_w, shape=[2048]),
      tf.reshape(conv2_w, shape=[2048])
    ]
  )
 
  plt.subplot(1, 2, 1)
  plt.scatter([i for i in range(len(conv1_w__val))], conv1_w__val)
  plt.subplot(1, 2, 2)
  plt.scatter([i for i in range(len(conv2_w__val))], conv2_w__val)
  plt.show()
 
 
if __name__ == '__main__':
  with tf.Session() as sess:
    restore_variable_renamed(sess)

总结:

# 之前模型中叫 'weights'的变量赋值给当前的conv1_w变量

# 之前模型中叫 'biases' 的变量赋值给当前的conv1_b变量

# 之前模型中叫 'weights_2'的变量赋值给当前的conv2_w变量

saver = tf.train.Saver({

"weights": conv1_w,

"biases": conv1_b,

"weights_2": conv2_w

})

以上这篇tensorflow模型保存、加载之变量重命名实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
零基础写python爬虫之抓取糗事百科代码分享
Nov 06 Python
解决出现Incorrect integer value: '' for column 'id' at row 1的问题
Oct 29 Python
Python编程之gui程序实现简单文件浏览器代码
Dec 08 Python
matplotlib绘制动画代码示例
Jan 02 Python
windows下python和pip安装教程
May 25 Python
解决python中画图时x,y轴名称出现中文乱码的问题
Jan 29 Python
Pycharm+django2.2+python3.6+MySQL实现简单的考试报名系统
Sep 05 Python
Python3合并两个有序数组代码实例
Aug 11 Python
解决PyCharm不在run输出运行结果而不是再Console里输出的问题
Sep 21 Python
如何基于Python按行合并两个txt
Nov 03 Python
Django debug为True时,css加载失败的解决方案
Apr 24 Python
如何理解python接口自动化之logging日志模块
Jun 15 Python
tensorflow实现测试时读取任意指定的check point的网络参数
Jan 21 #Python
tensorflow如何继续训练之前保存的模型实例
Jan 21 #Python
在tensorflow中设置保存checkpoint的最大数量实例
Jan 21 #Python
TensorFlow——Checkpoint为模型添加检查点的实例
Jan 21 #Python
tensorflow estimator 使用hook实现finetune方式
Jan 21 #Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
You might like
PHP5 的对象赋值机制介绍
2011/08/02 PHP
destoon实现调用图文新闻的方法
2014/08/21 PHP
PHP实现财务审核通过后返现金额到客户的功能
2019/07/04 PHP
Yii框架通过请求组件处理get,post请求的方法分析
2019/09/03 PHP
window.parent调用父框架时 ie跟火狐不兼容问题
2009/07/30 Javascript
javascript 精粹笔记
2010/05/09 Javascript
一个js导致的jquery失效问题的解决方法
2013/11/27 Javascript
鼠标移入移出事件改变图片的分辨率的两种方法
2013/12/17 Javascript
基于replaceChild制作简单的吞噬特效
2015/09/21 Javascript
js操作数组函数实例小结
2015/12/10 Javascript
jQuery position() 函数详解以及jQuery中position函数的应用
2015/12/14 Javascript
如何使用PHP+jQuery+MySQL实现异步加载ECharts地图数据(附源码下载)
2016/02/23 Javascript
jquery设置表单元素为不可用的简单代码
2016/07/04 Javascript
Bootstrap 网站实例之单页营销网站
2016/10/20 Javascript
Bootstrap弹出框modal上层的输入框不能获得焦点问题的解决方法
2016/12/13 Javascript
用最简单的方法判断JavaScript中this的指向(推荐)
2017/09/04 Javascript
深入理解NodeJS 多进程和集群
2018/10/17 NodeJs
微信小程序功能之全屏滚动效果的实现代码
2018/11/22 Javascript
antd配置config-overrides.js文件的操作
2020/10/31 Javascript
Python性能提升之延迟初始化
2016/12/04 Python
Python虚拟环境项目实例
2017/11/20 Python
使用Python+Splinter自动刷新抢12306火车票
2018/01/03 Python
python如何查看微信消息撤回
2018/11/27 Python
django多个APP的urls设置方法(views重复问题解决)
2019/07/19 Python
python读取Excel表格文件的方法
2019/09/02 Python
如何用Python徒手写线性回归
2021/01/25 Python
css3新单位vw、vh的使用教程
2018/03/23 HTML / CSS
沪江旗下的海量优质课程平台:沪江网校
2017/11/07 全球购物
英国领先的游戏零售商:GAME
2019/09/24 全球购物
人力资源经理自我评价
2014/01/04 职场文书
关爱残疾人标语
2014/06/25 职场文书
影视广告专业求职信
2014/09/02 职场文书
党员创先争优心得体会
2014/09/11 职场文书
个人查摆剖析材料
2014/10/04 职场文书
药店营业员岗位职责
2015/04/14 职场文书
六一儿童节致辞
2015/07/31 职场文书