tensorflow 获取变量&打印权值的实例讲解


Posted in Python onJune 14, 2018

在使用tensorflow中,我们常常需要获取某个变量的值,比如:打印某一层的权重,通常我们可以直接利用变量的name属性来获取,但是当我们利用一些第三方的库来构造神经网络的layer时,存在一种情况:就是我们自己无法定义该层的变量,因为是自动进行定义的。

比如用tensorflow的slim库时:

<span style="font-size:14px;">def resnet_stack(images, output_shape, hparams, scope=None):</span>
<span style="font-size:14px;"> """Create a resnet style transfer block.</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;"> Args:</span>
<span style="font-size:14px;"> images: [batch-size, height, width, channels] image tensor to feed as input</span>
<span style="font-size:14px;"> output_shape: output image shape in form [height, width, channels]</span>
<span style="font-size:14px;"> hparams: hparams objects</span>
<span style="font-size:14px;"> scope: Variable scope</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;"> Returns:</span>
<span style="font-size:14px;"> Images after processing with resnet blocks.</span>
<span style="font-size:14px;"> """</span>
<span style="font-size:14px;"> end_points = {}</span>
<span style="font-size:14px;"> if hparams.noise_channel:</span>
<span style="font-size:14px;"> # separate the noise for visualization</span>
<span style="font-size:14px;"> end_points['noise'] = images[:, :, :, -1]</span>
<span style="font-size:14px;"> assert images.shape.as_list()[1:3] == output_shape[0:2]</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;"> with tf.variable_scope(scope, 'resnet_style_transfer', [images]):</span>
<span style="font-size:14px;"> with slim.arg_scope(</span>
<span style="font-size:14px;">  [slim.conv2d],</span>
<span style="font-size:14px;">  normalizer_fn=slim.batch_norm,</span>
<span style="font-size:14px;">  kernel_size=[hparams.generator_kernel_size] * 2,</span>
<span style="font-size:14px;">  stride=1):</span>
<span style="font-size:14px;">  net = slim.conv2d(</span>
<span style="font-size:14px;">   images,</span>
<span style="font-size:14px;">   hparams.resnet_filters,</span>
<span style="font-size:14px;">   normalizer_fn=None,</span>
<span style="font-size:14px;">   activation_fn=tf.nn.relu)</span>
<span style="font-size:14px;">  for block in range(hparams.resnet_blocks):</span>
<span style="font-size:14px;">  net = resnet_block(net, hparams)</span>
<span style="font-size:14px;">  end_points['resnet_block_{}'.format(block)] = net</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">  net = slim.conv2d(</span>
<span style="font-size:14px;">   net,</span>
<span style="font-size:14px;">   output_shape[-1],</span>
<span style="font-size:14px;">   kernel_size=[1, 1],</span>
<span style="font-size:14px;">   normalizer_fn=None,</span>
<span style="font-size:14px;">   activation_fn=tf.nn.tanh,</span>
<span style="font-size:14px;">   scope='conv_out')</span>
<span style="font-size:14px;">  end_points['transferred_images'] = net</span>
<span style="font-size:14px;"> return net, end_points</span>

我们希望获取第一个卷积层的权重weight,该怎么办呢??

在训练时,这些可训练的变量会被tensorflow保存在 tf.trainable_variables() 中,于是我们就可以通过打印 tf.trainable_variables() 来获取该卷积层的名称(或者你也可以自己根据scope来看出来该变量的name ),然后利用tf.get_default_grap().get_tensor_by_name 来获取该变量。

举个简单的例子:

<span style="font-size:14px;">import tensorflow as tf</span>
<span style="font-size:14px;">with tf.variable_scope("generate"):</span>
<span style="font-size:14px;"> with tf.variable_scope("resnet_stack"):</span>
<span style="font-size:14px;">  #简单起见,这里没有用第三方库来说明,</span>
<span style="font-size:14px;">  bias = tf.Variable(0.0,name="bias")</span>
<span style="font-size:14px;">  weight = tf.Variable(0.0,name="weight")</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">for tv in tf.trainable_variables():</span>
<span style="font-size:14px;"> print (tv.name)</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">b = tf.get_default_graph().get_tensor_by_name("generate/resnet_stack/bias:0")</span>
<span style="font-size:14px;">w = tf.get_default_graph().get_tensor_by_name("generate/resnet_stack/weight:0")</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">with tf.Session() as sess:</span>
<span style="font-size:14px;"> tf.global_variables_initializer().run()</span>
<span style="font-size:14px;"> print(sess.run(b))</span>
<span style="font-size:14px;"> print(sess.run(w))
</span>

结果如下:

tensorflow 获取变量&amp;打印权值的实例讲解

以上这篇tensorflow 获取变量&打印权值的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python异步任务队列示例
Apr 01 Python
Python 的内置字符串方法小结
Mar 15 Python
Python多进程同步简单实现代码
Apr 27 Python
python 异常处理总结
Oct 18 Python
详解django中自定义标签和过滤器
Jul 03 Python
Python如何通过subprocess调用adb命令详解
Aug 27 Python
Python中的单继承与多继承实例分析
May 10 Python
python matplotlib 在指定的两个点之间连线方法
May 25 Python
Python3数据库操作包pymysql的操作方法
Jul 16 Python
python实现超市商品销售管理系统
Oct 25 Python
Pytorch mask_select 函数的用法详解
Feb 18 Python
django rest framework 过滤时间操作
Jul 12 Python
利用python对Excel中的特定数据提取并写入新表的方法
Jun 14 #Python
Python基于最小二乘法实现曲线拟合示例
Jun 14 #Python
详解python之协程gevent模块
Jun 14 #Python
python 筛选数据集中列中value长度大于20的数据集方法
Jun 14 #Python
浅谈Tensorflow由于版本问题出现的几种错误及解决方法
Jun 13 #Python
tensorflow: 查看 tensor详细数值方法
Jun 13 #Python
终端命令查看TensorFlow版本号及路径的方法
Jun 13 #Python
You might like
PHP:风雨欲来 路在何方?
2006/10/09 PHP
通过JavaScript或PHP检测Android设备的代码
2011/03/09 PHP
用php实现选择排序的解决方法
2013/05/04 PHP
php返回字符串中所有单词的方法
2015/03/09 PHP
基于Laravel(5.4版本)的基本增删改查操作方法
2019/10/11 PHP
javascript 自动填写表单的实现方法
2010/04/09 Javascript
Jquery UI震动效果实现原理及步骤
2013/02/04 Javascript
jquery验证表单中的单选与多选实例
2013/08/18 Javascript
jquery中ready()函数执行的时机和window的load事件比较
2015/06/22 Javascript
详解JavaScript中常用的函数类型
2015/11/18 Javascript
jQuery删除节点用法示例(remove方法)
2016/09/08 Javascript
BootStrap的两种模态框方式
2017/05/10 Javascript
JavaScript之Map和Set_动力节点Java学院整理
2017/06/29 Javascript
解决jquery appaend元素中id绑定事件失效的问题
2017/09/12 jQuery
React中上传图片到七牛的示例代码
2017/10/10 Javascript
JS基于对象的特性实现去除数组中重复项功能详解
2017/11/17 Javascript
vue axios 给生产环境和发布环境配置不同的接口地址(推荐)
2018/05/08 Javascript
Javascript之高级数组API的使用实例
2019/03/08 Javascript
Python 获取新浪微博的最新公共微博实例分享
2014/07/03 Python
在python中按照特定顺序访问字典的方法详解
2018/12/14 Python
Django项目主urls导入应用中views的红线问题解决
2019/08/10 Python
解决os.path.isdir() 判断文件夹却返回false的问题
2019/11/29 Python
python+Selenium自动化测试——输入,点击操作
2020/03/06 Python
Python使用Numpy模块读取文件并绘制图片
2020/05/13 Python
VSCode配合pipenv搞定虚拟环境的实现方法
2020/05/17 Python
CSS3教程:边框属性border的极致应用
2009/04/02 HTML / CSS
使用phonegap克隆和删除联系人的实现方法
2017/03/31 HTML / CSS
canvas拼图功能实现代码示例
2018/11/21 HTML / CSS
投标承诺书怎么写
2014/05/24 职场文书
部门活动策划方案
2014/08/16 职场文书
环保项目建议书
2014/08/26 职场文书
师德师风主题教育活动总结
2015/05/07 职场文书
驾驶员安全责任协议书
2016/03/22 职场文书
《成长的天空》读后感3篇
2019/12/06 职场文书
yolov5返回坐标的方法实例
2022/03/17 Python
JS实现页面炫酷的时钟特效示例
2022/08/14 Javascript