tensorflow模型保存、加载之变量重命名实例


Posted in Python onJanuary 21, 2020

话不多说,干就完了。

变量重命名的用处?

简单定义:简单来说就是将模型A中的参数parameter_A赋给模型B中的parameter_B

使用场景:当需要使用已经训练好的模型参数,尤其是使用别人训练好的模型参数时,往往别人模型中的参数命名方式与自己当前的命名方式不同,所以在加载模型参数时需要对参数进行重命名,使得代码更简洁易懂。

实现方法:

1)、模型保存

import os
import tensorflow as tf
 
weights = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                            mean=0.0,
                            stddev=0.1),
           dtype=tf.float32,
           name="weights")
biases = tf.Variable(initial_value=tf.zeros(shape=[2]),
           dtype=tf.float32,
           name="biases")
 
weights_2 = tf.Variable(initial_value=weights.initialized_value(),
            dtype=tf.float32,
            name="weights_2")
 
# saver checkpoint
if os.path.exists("checkpoints") is False:
  os.makedirs("checkpoints")
 
saver = tf.train.Saver()
with tf.Session() as sess:
  init_op = [tf.global_variables_initializer()]
  sess.run(init_op)
  saver.save(sess=sess, save_path="checkpoints/variable.ckpt")

2)、模型加载(变量名称保持不变)

import tensorflow as tf
from matplotlib import pyplot as plt
import os
 
current_path = os.path.dirname(os.path.abspath(__file__))
 
def restore_variable(sess):
  # need not initilize variable, but need to define the same variable like checkpoint
  weights = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                              mean=0.0,
                              stddev=0.1),
             dtype=tf.float32,
             name="weights")
  biases = tf.Variable(initial_value=tf.zeros(shape=[2]),
             dtype=tf.float32,
             name="biases")
 
  weights_2 = tf.Variable(initial_value=weights.initialized_value(),
              dtype=tf.float32,
              name="weights_2")
 
  saver = tf.train.Saver()
 
  ckpt_path = os.path.join(current_path, "checkpoints", "variable.ckpt")
  saver.restore(sess=sess, save_path=ckpt_path)
 
  weights_val, weights_2_val = sess.run(
    [
      tf.reshape(weights, shape=[2048]),
      tf.reshape(weights_2, shape=[2048])
    ]
  )
 
  plt.subplot(1, 2, 1)
  plt.scatter([i for i in range(len(weights_val))], weights_val)
  plt.subplot(1, 2, 2)
  plt.scatter([i for i in range(len(weights_2_val))], weights_2_val)
  plt.show()
 
 
if __name__ == '__main__':
  with tf.Session() as sess:
    restore_variable(sess)

3)、模型加载(变量重命名)

import tensorflow as tf
from matplotlib import pyplot as plt
import os
 
current_path = os.path.dirname(os.path.abspath(__file__))
 
 
def restore_variable_renamed(sess):
  conv1_w = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                              mean=0.0,
                              stddev=0.1),
             dtype=tf.float32,
             name="conv1_w")
  conv1_b = tf.Variable(initial_value=tf.zeros(shape=[2]),
             dtype=tf.float32,
             name="conv1_b")
 
  conv2_w = tf.Variable(initial_value=conv1_w.initialized_value(),
             dtype=tf.float32,
             name="conv2_w")
 
  # variable named 'weights' in ckpt assigned to current variable conv1_w
  # variable named 'biases' in ckpt assigned to current variable conv1_b
  # variable named 'weights_2' in ckpt assigned to current variable conv2_w
  saver = tf.train.Saver({
    "weights": conv1_w,
    "biases": conv1_b,
    "weights_2": conv2_w
  })
 
  ckpt_path = os.path.join(current_path, "checkpoints", "variable.ckpt")
  saver.restore(sess=sess, save_path=ckpt_path)
 
  conv1_w__val, conv2_w__val = sess.run(
    [
      tf.reshape(conv1_w, shape=[2048]),
      tf.reshape(conv2_w, shape=[2048])
    ]
  )
 
  plt.subplot(1, 2, 1)
  plt.scatter([i for i in range(len(conv1_w__val))], conv1_w__val)
  plt.subplot(1, 2, 2)
  plt.scatter([i for i in range(len(conv2_w__val))], conv2_w__val)
  plt.show()
 
 
if __name__ == '__main__':
  with tf.Session() as sess:
    restore_variable_renamed(sess)

总结:

# 之前模型中叫 'weights'的变量赋值给当前的conv1_w变量

# 之前模型中叫 'biases' 的变量赋值给当前的conv1_b变量

# 之前模型中叫 'weights_2'的变量赋值给当前的conv2_w变量

saver = tf.train.Saver({

"weights": conv1_w,

"biases": conv1_b,

"weights_2": conv2_w

})

以上这篇tensorflow模型保存、加载之变量重命名实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
深入分析在Python模块顶层运行的代码引起的一个Bug
Jul 04 Python
详解Python如何获取列表(List)的中位数
Aug 12 Python
Python 模板引擎的注入问题分析
Jan 01 Python
Python中动态创建类实例的方法
Mar 24 Python
python学习必备知识汇总
Sep 08 Python
Python实现多线程的两种方式分析
Aug 29 Python
python try 异常处理(史上最全)
Mar 07 Python
Flask框架单例模式实现方法详解
Jul 31 Python
python飞机大战pygame游戏框架搭建操作详解
Dec 17 Python
使用Tensorboard工具查看Loss损失率
Feb 15 Python
python开发实例之python使用Websocket库开发简单聊天工具实例详解(python+Websocket+JS)
Mar 18 Python
python中upper是做什么用的
Jul 20 Python
tensorflow实现测试时读取任意指定的check point的网络参数
Jan 21 #Python
tensorflow如何继续训练之前保存的模型实例
Jan 21 #Python
在tensorflow中设置保存checkpoint的最大数量实例
Jan 21 #Python
TensorFlow——Checkpoint为模型添加检查点的实例
Jan 21 #Python
tensorflow estimator 使用hook实现finetune方式
Jan 21 #Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
You might like
PHP中func_get_args(),func_get_arg(),func_num_args()的区别
2013/09/30 PHP
php cli换行示例
2014/04/22 PHP
destoon网站转移服务器后搜索汉字出现乱码的解决方法
2014/06/21 PHP
ajax+php控制所有后台函数调用
2015/07/15 PHP
PHP实现的超长文本分页显示功能示例
2018/06/04 PHP
PHP rsa加密解密算法原理解析
2020/12/09 PHP
关于JS中的闭包浅谈
2013/08/23 Javascript
JavaScript中变量声明有var和没var的区别示例介绍
2014/09/15 Javascript
轻松创建nodejs服务器(2):nodejs服务器的构成分析
2014/12/18 NodeJs
jQuery制作简洁的多级联动Select下拉框
2014/12/23 Javascript
分享有关jQuery中animate、slide、fade等动画的连续触发、滞后反复执行的bug
2016/01/10 Javascript
JavaScript高级程序设计(第三版)学习笔记1~5章
2016/03/11 Javascript
jquery实现跳到底部,回到顶部效果的简单实例(类似锚)
2016/07/10 Javascript
AngularJS 中使用Swiper制作滚动图不能滑动的解决方法
2016/11/15 Javascript
详解angular中通过$location获取路径(参数)的写法
2017/03/21 Javascript
vue中动态绑定表单元素的属性方法
2018/02/23 Javascript
vue ssr服务端渲染(小白解惑)
2019/11/10 Javascript
[58:15]2018DOTA2亚洲邀请赛 4.1 小组赛 A组 NB vs Liquid
2018/04/02 DOTA
[50:28]2018DOTA2亚洲邀请赛 3.31 小组赛 A组 Newbee vs KG
2018/04/01 DOTA
Python处理文本文件中控制字符的方法
2017/02/07 Python
Python-Tkinter Text输入内容在界面显示的实例
2019/07/12 Python
Python lxml模块的基本使用方法分析
2019/12/21 Python
HTML5资源预加载(Link prefetch)详细介绍(给你的网页加速)
2014/05/07 HTML / CSS
印度最大的网上花店:Ferns N Petals(鲜花、礼品和蛋糕)
2017/10/16 全球购物
加拿大专业美发产品购物网站:Chatters
2021/02/28 全球购物
什么情况下你必须要把一个类定义为abstract的
2013/01/06 面试题
Linux如何修改文件和文件夹的权限
2013/09/05 面试题
瀑布模型都有哪些优缺点
2014/06/23 面试题
抽奖活动主持词
2014/03/31 职场文书
小学生操行评语
2014/04/22 职场文书
2014年审计工作总结
2014/11/17 职场文书
2014年城管个人工作总结
2014/12/08 职场文书
教师“一帮一”结对子活动总结
2015/05/07 职场文书
2015年班组建设工作总结
2015/05/13 职场文书
MySQL系列之十一 日志记录
2021/07/02 MySQL
Python函数式编程中itertools模块详解
2021/09/15 Python