tensorflow 获取变量&打印权值的实例讲解


Posted in Python onJune 14, 2018

在使用tensorflow中,我们常常需要获取某个变量的值,比如:打印某一层的权重,通常我们可以直接利用变量的name属性来获取,但是当我们利用一些第三方的库来构造神经网络的layer时,存在一种情况:就是我们自己无法定义该层的变量,因为是自动进行定义的。

比如用tensorflow的slim库时:

<span style="font-size:14px;">def resnet_stack(images, output_shape, hparams, scope=None):</span>
<span style="font-size:14px;"> """Create a resnet style transfer block.</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;"> Args:</span>
<span style="font-size:14px;"> images: [batch-size, height, width, channels] image tensor to feed as input</span>
<span style="font-size:14px;"> output_shape: output image shape in form [height, width, channels]</span>
<span style="font-size:14px;"> hparams: hparams objects</span>
<span style="font-size:14px;"> scope: Variable scope</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;"> Returns:</span>
<span style="font-size:14px;"> Images after processing with resnet blocks.</span>
<span style="font-size:14px;"> """</span>
<span style="font-size:14px;"> end_points = {}</span>
<span style="font-size:14px;"> if hparams.noise_channel:</span>
<span style="font-size:14px;"> # separate the noise for visualization</span>
<span style="font-size:14px;"> end_points['noise'] = images[:, :, :, -1]</span>
<span style="font-size:14px;"> assert images.shape.as_list()[1:3] == output_shape[0:2]</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;"> with tf.variable_scope(scope, 'resnet_style_transfer', [images]):</span>
<span style="font-size:14px;"> with slim.arg_scope(</span>
<span style="font-size:14px;">  [slim.conv2d],</span>
<span style="font-size:14px;">  normalizer_fn=slim.batch_norm,</span>
<span style="font-size:14px;">  kernel_size=[hparams.generator_kernel_size] * 2,</span>
<span style="font-size:14px;">  stride=1):</span>
<span style="font-size:14px;">  net = slim.conv2d(</span>
<span style="font-size:14px;">   images,</span>
<span style="font-size:14px;">   hparams.resnet_filters,</span>
<span style="font-size:14px;">   normalizer_fn=None,</span>
<span style="font-size:14px;">   activation_fn=tf.nn.relu)</span>
<span style="font-size:14px;">  for block in range(hparams.resnet_blocks):</span>
<span style="font-size:14px;">  net = resnet_block(net, hparams)</span>
<span style="font-size:14px;">  end_points['resnet_block_{}'.format(block)] = net</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">  net = slim.conv2d(</span>
<span style="font-size:14px;">   net,</span>
<span style="font-size:14px;">   output_shape[-1],</span>
<span style="font-size:14px;">   kernel_size=[1, 1],</span>
<span style="font-size:14px;">   normalizer_fn=None,</span>
<span style="font-size:14px;">   activation_fn=tf.nn.tanh,</span>
<span style="font-size:14px;">   scope='conv_out')</span>
<span style="font-size:14px;">  end_points['transferred_images'] = net</span>
<span style="font-size:14px;"> return net, end_points</span>

我们希望获取第一个卷积层的权重weight,该怎么办呢??

在训练时,这些可训练的变量会被tensorflow保存在 tf.trainable_variables() 中,于是我们就可以通过打印 tf.trainable_variables() 来获取该卷积层的名称(或者你也可以自己根据scope来看出来该变量的name ),然后利用tf.get_default_grap().get_tensor_by_name 来获取该变量。

举个简单的例子:

<span style="font-size:14px;">import tensorflow as tf</span>
<span style="font-size:14px;">with tf.variable_scope("generate"):</span>
<span style="font-size:14px;"> with tf.variable_scope("resnet_stack"):</span>
<span style="font-size:14px;">  #简单起见,这里没有用第三方库来说明,</span>
<span style="font-size:14px;">  bias = tf.Variable(0.0,name="bias")</span>
<span style="font-size:14px;">  weight = tf.Variable(0.0,name="weight")</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">for tv in tf.trainable_variables():</span>
<span style="font-size:14px;"> print (tv.name)</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">b = tf.get_default_graph().get_tensor_by_name("generate/resnet_stack/bias:0")</span>
<span style="font-size:14px;">w = tf.get_default_graph().get_tensor_by_name("generate/resnet_stack/weight:0")</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">with tf.Session() as sess:</span>
<span style="font-size:14px;"> tf.global_variables_initializer().run()</span>
<span style="font-size:14px;"> print(sess.run(b))</span>
<span style="font-size:14px;"> print(sess.run(w))
</span>

结果如下:

tensorflow 获取变量&amp;打印权值的实例讲解

以上这篇tensorflow 获取变量&打印权值的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现截屏的函数
Jul 25 Python
Python网络编程之TCP与UDP协议套接字用法示例
Feb 02 Python
matplotlib 纵坐标轴显示数据值的实例
May 25 Python
Python3正则匹配re.split,re.finditer及re.findall函数用法详解
Jun 11 Python
利用python打开摄像头及颜色检测方法
Aug 03 Python
让代码变得更易维护的7个Python库
Oct 09 Python
pandas 对日期类型数据的处理方法详解
Aug 08 Python
浅析使用Python搭建http服务器
Oct 27 Python
python去除删除数据中\u0000\u0001等unicode字符串的代码
Mar 06 Python
django model的update时auto_now不被更新的原因及解决方式
Apr 01 Python
基于FME使用Python过程图解
May 13 Python
利用python实时刷新基金估值(摸鱼小工具)
Sep 15 Python
利用python对Excel中的特定数据提取并写入新表的方法
Jun 14 #Python
Python基于最小二乘法实现曲线拟合示例
Jun 14 #Python
详解python之协程gevent模块
Jun 14 #Python
python 筛选数据集中列中value长度大于20的数据集方法
Jun 14 #Python
浅谈Tensorflow由于版本问题出现的几种错误及解决方法
Jun 13 #Python
tensorflow: 查看 tensor详细数值方法
Jun 13 #Python
终端命令查看TensorFlow版本号及路径的方法
Jun 13 #Python
You might like
PHP开发框架Laravel数据库操作方法总结
2014/09/03 PHP
php绘制一个矩形的方法
2015/01/24 PHP
Linux下安装Memcached服务器和客户端与PHP使用示例
2019/04/15 PHP
phpstudy隐藏index.php的方法
2020/09/21 PHP
CodeMirror2 IE7/IE8 下面未知运行时错误的解决方法
2012/03/29 Javascript
关于JS中的闭包浅谈
2013/08/23 Javascript
JS方法调用括号的问题探讨
2014/01/24 Javascript
JS操作JSON方法总结(推荐)
2016/06/14 Javascript
url传递的参数值中包含&amp;时,url自动截断问题的解决方法
2016/08/02 Javascript
关于angularJs指令的Scope(作用域)介绍
2016/10/25 Javascript
jQuery双向列表选择器DIV模拟版
2016/11/01 Javascript
微信小程序 网络API Websocket详解
2016/11/09 Javascript
elementUI中Table表格问题的解决方法
2018/12/04 Javascript
JavaScript刷新页面的几种方法总结
2019/03/28 Javascript
vue多页面项目中路由使用history模式的方法
2019/09/23 Javascript
JS通用方法触发点击事件代码实例
2020/02/17 Javascript
简介JavaScript错误处理机制
2020/08/04 Javascript
vue中解决微信html5原生ios虚拟键返回不刷新问题
2020/10/20 Javascript
Python中实现对list做减法操作介绍
2015/01/09 Python
使用Python的Scrapy框架十分钟爬取美女图
2016/12/26 Python
python利用socketserver实现并发套接字功能
2018/01/26 Python
Python异步操作MySQL示例【使用aiomysql】
2019/05/16 Python
python经典趣味24点游戏程序设计
2019/07/26 Python
利用python list完成最简单的DB连接池方法
2019/08/09 Python
dpn网络的pytorch实现方式
2020/01/14 Python
Pyinstaller打包Scrapy项目的实现步骤
2020/09/22 Python
html5表单及新增的改良元素详解
2016/06/07 HTML / CSS
法国和欧洲海边和滑雪度假:Pierre & Vacances
2017/01/04 全球购物
中学生在校期间的自我评价分享
2013/11/13 职场文书
婚前协议书
2014/04/15 职场文书
本科应届生自荐信
2014/06/29 职场文书
学校教师读书活动总结
2014/07/08 职场文书
教师群众路线剖析材料
2014/09/29 职场文书
MySQL之PXC集群搭建的方法步骤
2021/05/25 MySQL
解决Navicat for Mysql连接报错1251的问题(连接失败)
2021/05/27 MySQL
python读取并查看npz/npy文件数据以及数据显示方法
2022/04/14 Python