tensorflow模型保存、加载之变量重命名实例


Posted in Python onJanuary 21, 2020

话不多说,干就完了。

变量重命名的用处?

简单定义:简单来说就是将模型A中的参数parameter_A赋给模型B中的parameter_B

使用场景:当需要使用已经训练好的模型参数,尤其是使用别人训练好的模型参数时,往往别人模型中的参数命名方式与自己当前的命名方式不同,所以在加载模型参数时需要对参数进行重命名,使得代码更简洁易懂。

实现方法:

1)、模型保存

import os
import tensorflow as tf
 
weights = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                            mean=0.0,
                            stddev=0.1),
           dtype=tf.float32,
           name="weights")
biases = tf.Variable(initial_value=tf.zeros(shape=[2]),
           dtype=tf.float32,
           name="biases")
 
weights_2 = tf.Variable(initial_value=weights.initialized_value(),
            dtype=tf.float32,
            name="weights_2")
 
# saver checkpoint
if os.path.exists("checkpoints") is False:
  os.makedirs("checkpoints")
 
saver = tf.train.Saver()
with tf.Session() as sess:
  init_op = [tf.global_variables_initializer()]
  sess.run(init_op)
  saver.save(sess=sess, save_path="checkpoints/variable.ckpt")

2)、模型加载(变量名称保持不变)

import tensorflow as tf
from matplotlib import pyplot as plt
import os
 
current_path = os.path.dirname(os.path.abspath(__file__))
 
def restore_variable(sess):
  # need not initilize variable, but need to define the same variable like checkpoint
  weights = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                              mean=0.0,
                              stddev=0.1),
             dtype=tf.float32,
             name="weights")
  biases = tf.Variable(initial_value=tf.zeros(shape=[2]),
             dtype=tf.float32,
             name="biases")
 
  weights_2 = tf.Variable(initial_value=weights.initialized_value(),
              dtype=tf.float32,
              name="weights_2")
 
  saver = tf.train.Saver()
 
  ckpt_path = os.path.join(current_path, "checkpoints", "variable.ckpt")
  saver.restore(sess=sess, save_path=ckpt_path)
 
  weights_val, weights_2_val = sess.run(
    [
      tf.reshape(weights, shape=[2048]),
      tf.reshape(weights_2, shape=[2048])
    ]
  )
 
  plt.subplot(1, 2, 1)
  plt.scatter([i for i in range(len(weights_val))], weights_val)
  plt.subplot(1, 2, 2)
  plt.scatter([i for i in range(len(weights_2_val))], weights_2_val)
  plt.show()
 
 
if __name__ == '__main__':
  with tf.Session() as sess:
    restore_variable(sess)

3)、模型加载(变量重命名)

import tensorflow as tf
from matplotlib import pyplot as plt
import os
 
current_path = os.path.dirname(os.path.abspath(__file__))
 
 
def restore_variable_renamed(sess):
  conv1_w = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                              mean=0.0,
                              stddev=0.1),
             dtype=tf.float32,
             name="conv1_w")
  conv1_b = tf.Variable(initial_value=tf.zeros(shape=[2]),
             dtype=tf.float32,
             name="conv1_b")
 
  conv2_w = tf.Variable(initial_value=conv1_w.initialized_value(),
             dtype=tf.float32,
             name="conv2_w")
 
  # variable named 'weights' in ckpt assigned to current variable conv1_w
  # variable named 'biases' in ckpt assigned to current variable conv1_b
  # variable named 'weights_2' in ckpt assigned to current variable conv2_w
  saver = tf.train.Saver({
    "weights": conv1_w,
    "biases": conv1_b,
    "weights_2": conv2_w
  })
 
  ckpt_path = os.path.join(current_path, "checkpoints", "variable.ckpt")
  saver.restore(sess=sess, save_path=ckpt_path)
 
  conv1_w__val, conv2_w__val = sess.run(
    [
      tf.reshape(conv1_w, shape=[2048]),
      tf.reshape(conv2_w, shape=[2048])
    ]
  )
 
  plt.subplot(1, 2, 1)
  plt.scatter([i for i in range(len(conv1_w__val))], conv1_w__val)
  plt.subplot(1, 2, 2)
  plt.scatter([i for i in range(len(conv2_w__val))], conv2_w__val)
  plt.show()
 
 
if __name__ == '__main__':
  with tf.Session() as sess:
    restore_variable_renamed(sess)

总结:

# 之前模型中叫 'weights'的变量赋值给当前的conv1_w变量

# 之前模型中叫 'biases' 的变量赋值给当前的conv1_b变量

# 之前模型中叫 'weights_2'的变量赋值给当前的conv2_w变量

saver = tf.train.Saver({

"weights": conv1_w,

"biases": conv1_b,

"weights_2": conv2_w

})

以上这篇tensorflow模型保存、加载之变量重命名实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现字符串连接的三种方法及其效率、适用场景详解
Jan 13 Python
tensorflow 获取变量&打印权值的实例讲解
Jun 14 Python
Python函数参数操作详解
Aug 03 Python
pandas读取csv文件,分隔符参数sep的实例
Dec 12 Python
windows下安装Python虚拟环境virtualenvwrapper-win
Jun 14 Python
python opencv 图像拼接的实现方法
Jun 27 Python
python自动化测试无法启动谷歌浏览器问题
Oct 10 Python
python获取系统内存占用信息的实例方法
Jul 17 Python
python 使用OpenCV进行简单的人像分割与合成
Feb 02 Python
Scrapy实现模拟登录的示例代码
Feb 21 Python
python中os.path.join()函数实例用法
May 26 Python
python面向对象版学生信息管理系统
Jun 24 Python
tensorflow实现测试时读取任意指定的check point的网络参数
Jan 21 #Python
tensorflow如何继续训练之前保存的模型实例
Jan 21 #Python
在tensorflow中设置保存checkpoint的最大数量实例
Jan 21 #Python
TensorFlow——Checkpoint为模型添加检查点的实例
Jan 21 #Python
tensorflow estimator 使用hook实现finetune方式
Jan 21 #Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
You might like
用PHP的超级变量$_GET获取HTML表单(Form) 数据
2011/05/07 PHP
php解析xml提示Invalid byte 1 of 1-byte UTF-8 sequence错误的处理方法
2013/11/14 PHP
PHP count_chars()函数讲解
2019/02/14 PHP
php设计模式之适配器模式原理、用法及注意事项详解
2019/09/24 PHP
thinkphp5 模型实例化获得数据对象的教程
2019/10/18 PHP
学习ExtJS Column布局
2009/10/08 Javascript
一步一步制作jquery插件Tabs实现过程
2010/07/06 Javascript
jquery和js实现对div的隐藏和显示方法
2014/09/26 Javascript
web前端设计师们常用的jQuery特效插件汇总
2014/12/07 Javascript
jquery利用命名空间移除绑定事件的方法
2015/03/11 Javascript
基于AngularJS+HTML+Groovy实现登录功能
2016/02/17 Javascript
浅析BootStrap中Modal(模态框)使用心得
2016/12/24 Javascript
JavaScript使用简单正则表达式的数据验证功能示例
2017/01/13 Javascript
关于javascript获取内联样式与嵌入式样式的实例
2017/06/01 Javascript
jQuery选择器之表单元素选择器详解
2017/09/19 jQuery
实现Vue的markdown文档可以在线运行的方法示例
2018/12/11 Javascript
jquery添加div实现消息聊天框
2020/02/08 jQuery
vue treeselect获取当前选中项的label实例
2020/08/31 Javascript
[01:23:24]DOTA2-DPC中国联赛 正赛 PSG.LGD vs Elephant BO3 第三场 2月7日
2021/03/11 DOTA
python记录程序运行时间的三种方法
2017/07/14 Python
windows下cx_Freeze生成Python可执行程序的详细步骤
2018/10/09 Python
python 爬虫 实现增量去重和定时爬取实例
2020/02/28 Python
阿根廷首家户外用品制造商和经销商:Montagne
2018/02/12 全球购物
Homestay中文官网:全球寄宿家庭
2018/10/18 全球购物
整个世界的设计师家具在哈恩:Designathome
2019/03/25 全球购物
POP文化和音乐灵感的时尚:Hot Topic
2019/06/19 全球购物
SheIn沙特阿拉伯:女装在线
2020/03/23 全球购物
毕业实习个人鉴定范文
2013/12/10 职场文书
理工大学毕业生自荐信范文
2014/02/22 职场文书
学习实践科学发展观心得体会
2014/09/10 职场文书
12.4法制宣传日标语
2014/10/08 职场文书
廉洁自律承诺书2015
2015/01/22 职场文书
出纳试用期工作总结2015
2015/05/28 职场文书
合作合同协议书
2016/03/21 职场文书
如何做好员工培训计划?
2019/07/09 职场文书
八年级作文之感悟亲情
2019/11/20 职场文书