在keras 中获取张量 tensor 的维度大小实例


Posted in Python onJune 10, 2020

在进行keras 网络计算时,有时候需要获取输入张量的维度来定义自己的层。但是由于keras是一个封闭的接口。因此在调用由于是张量不能直接用numpy 里的A.shape()。这样的形式来获取。这里需要调用一下keras 作为后端的方式来获取。当我们想要操作时第一时间就想到直接用 shape ()函数。其实keras 中真的有shape()这个函数。

shape(x)返回一个张量的符号shape,符号shape的意思是返回值本身也是一个tensor,

示例:

>>> from keras import backend as K
>>> tf_session = K.get_session()
>>> val = np.array([[1, 2], [3, 4]])
>>> kvar = K.variable(value=val)
>>> input = keras.backend.placeholder(shape=(2, 4, 5))
>>> K.shape(kvar)
<tf.Tensor 'Shape_8:0' shape=(2,) dtype=int32>
>>> K.shape(input)
<tf.Tensor 'Shape_9:0' shape=(3,) dtype=int32>
__To get integer shape (Instead, you can use K.int_shape(x))__
 
>>> K.shape(kvar).eval(session=tf_session)
array([2, 2], dtype=int32)
>>> K.shape(input).eval(session=tf_session)
array([2, 4, 5], dtype=int32)

如果直接调用这个出的不是我们想要的。我们想要的是tensor各个维度的大小。因此可以直接调用 int_shape(x) 函数。这个函数才是我们想要的。

>>> from keras import backend as K
>>> input = K.placeholder(shape=(2, 4, 5))
>>> K.int_shape(input)
(2, 4, 5)
>>> val = np.array([[1, 2], [3, 4]])
>>> kvar = K.variable(value=val)
>>> K.int_shape(kvar)
(2, 2)

最后这样我们就可以直接调用里面的大小。然后定义我们自己的keras 层了。

补充知识:获取Tensor的维度(x.shape和x.get_shape()的区别)

tf.shape(a)和a.get_shape()比较

相同点:都可以得到tensor a的尺寸

不同点:tf.shape()中a 数据的类型可以是tensor, list, array

a.get_shape()中a的数据类型只能是tensor,且返回的是一个元组(tuple)

import tensorflow as tf 
import numpy as np 

x=tf.constant([[1,2,3],[4,5,6]])
y=[[1,2,3],[4,5,6]] 
z=np.arange(24).reshape([2,3,4])

sess=tf.Session() 
# tf.shape() 
x_shape=tf.shape(x)          # x_shape 是一个tensor 
y_shape=tf.shape(y)          # <tf.Tensor 'Shape_2:0' shape=(2,) dtype=int32> 
z_shape=tf.shape(z)          # <tf.Tensor 'Shape_5:0' shape=(3,) dtype=int32> 
print(sess.run(x_shape))       # 结果:[2 3]
print(sess.run(y_shape))       # 结果:[2 3]
print(sess.run(z_shape) )       # 结果:[2 3 4]

x_shape=x.get_shape() 
print(x_shape)    # 返回的是TensorShape([Dimension(2), Dimension(3)]),不能使用 sess.run() 因为返回的不是tensor 或string,而是元组                            (2, 3)
x_shape=x.get_shape().as_list() 
print(x_shape) # 可以使用 as_list()得到具体的尺寸,x_shape=[2 3] 这是重点 返回列表方便参加其他代码的运算
# y_shape=y.get_shape() 
print(x_shape)# AttributeError: 'list' object has no attribute 'get_shape'
# z_shape=z.get_shape() 
print(x_shape)# AttributeError: 'numpy.ndarray' object has no attribute 'get_shape' 或者a.shape.as_list()

以上这篇在keras 中获取张量 tensor 的维度大小实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python远程桌面协议RDPY安装使用介绍
Apr 15 Python
PyCharm设置SSH远程调试的方法
Jul 17 Python
Python使用装饰器模拟用户登陆验证功能示例
Aug 24 Python
python实现写数字文件名的递增保存文件方法
Oct 25 Python
Python根据欧拉角求旋转矩阵的实例
Jan 28 Python
Python从入门到精通之环境搭建教程图解
Sep 26 Python
pygame编写音乐播放器的实现代码示例
Nov 19 Python
python Manager 之dict KeyError问题的解决
Dec 21 Python
python使用nibabel和sitk读取保存nii.gz文件实例
Jul 01 Python
Python下载网易云歌单歌曲的示例代码
Aug 12 Python
Jupyter Notebook添加代码自动补全功能的实现
Jan 07 Python
详解Python 3.10 中的新功能和变化
Apr 28 Python
Keras—embedding嵌入层的用法详解
Jun 10 #Python
Keras框架中的epoch、bacth、batch size、iteration使用介绍
Jun 10 #Python
Python3.9 beta2版本发布了,看看这7个新的PEP都是什么
Jun 10 #Python
JAVA及PYTHON质数计算代码对比解析
Jun 10 #Python
keras 使用Lambda 快速新建层 添加多个参数操作
Jun 10 #Python
matplotlib 生成的图像中无法显示中文字符的解决方法
Jun 10 #Python
Tensorflow中k.gradients()和tf.stop_gradient()用法说明
Jun 10 #Python
You might like
php 5.3.5安装memcache注意事项小结
2011/04/12 PHP
关于使用coreseek并为其做分页的介绍
2013/06/21 PHP
PHP使用strtotime获取上个月、下个月、本月的日期
2015/12/30 PHP
PHP空值检测函数与方法汇总
2017/11/19 PHP
PHP面向对象五大原则之依赖倒置原则(DIP)详解
2018/04/08 PHP
Apply an AutoFormat to an Excel Spreadsheet
2007/06/12 Javascript
jQuery语法总结和注意事项小结
2012/11/11 Javascript
解决JS中乘法的浮点错误的方法
2014/01/03 Javascript
jQuery语法小结(超实用)
2015/12/31 Javascript
javascript实现计时器的简单方法
2016/02/21 Javascript
NodeJS中的MongoDB快速入门详细教程
2016/11/11 NodeJs
Bootstrap CSS布局之代码
2016/12/17 Javascript
jQuery插件HighCharts实现的2D对数饼图效果示例【附demo源码下载】
2017/03/09 Javascript
Angular实现一个简单的多选复选框的弹出框指令实例
2017/04/25 Javascript
JavaScript函数表达式详解及实例
2017/05/05 Javascript
js学习总结之DOM2兼容处理重复问题的解决方法
2017/07/27 Javascript
利用angular自动编译andriod APK的绕坑经历分享
2019/03/08 Javascript
node中使用log4js4.x版本记录日志的方法
2019/08/20 Javascript
layui表格内容溢出的解决方法
2019/09/06 Javascript
Vue-cli3项目引入Typescript的实现方法
2019/10/18 Javascript
JS实现随机点名器
2020/04/12 Javascript
[49:05]OG vs Newbee 2019DOTA2国际邀请赛淘汰赛 胜者组 BO3 第二场 8.21.mp4
2020/07/19 DOTA
Python字典简介以及用法详解
2016/11/15 Python
Python实现基于多线程、多用户的FTP服务器与客户端功能完整实例
2017/08/18 Python
pandas数据处理基础之筛选指定行或者指定列的数据
2018/05/03 Python
Python UnboundLocalError和NameError错误根源案例解析
2018/10/31 Python
python3反转字符串的3种方法(小结)
2019/11/07 Python
CSS3为背景图设置遮罩并解决遮罩样式继承问题
2020/06/22 HTML / CSS
Topshop美国官网:英国快速时尚品牌
2019/05/16 全球购物
高一历史教学反思
2014/01/13 职场文书
元旦活动感言
2014/03/08 职场文书
电脑售后服务承诺书
2014/03/27 职场文书
罚款通知怎么写
2015/04/22 职场文书
升学宴家长致辞
2015/07/27 职场文书
Python循环之while无限迭代
2022/04/30 Python
React如何使用axios请求数据并把数据渲染到组件
2022/08/05 Javascript