TensorFlow中tf.batch_matmul()的用法


Posted in Python onJune 02, 2021

TensorFlow中tf.batch_matmul()用法

如果有两个三阶张量,size分别为

a.shape = [100, 3, 4]
b.shape = [100, 4, 5]
c = tf.batch_matmul(a, b)

则c.shape = [100, 3, 5] //将每一对 3x4 的矩阵与 4x5 的矩阵分别相乘。batch_size不变

100为张量的batch_size。剩下的两个维度为数据的维度。

不过新版的tensorflow已经移除了上面的函数,使用时换为tf.matmul就可以了。与上面注释的方式是同样的。

附: 如果是更高维度。例如(a, b, m, n) 与(a, b, n, k)之间做matmul运算。则结果的维度为(a, b, m, k)。

TensorFlow如何实现batch_matmul

我们知道,在tensorflow早期版本中有tf.batch_matmul()函数,可以实现多维tensor和低维tensor的直接相乘,这在使用过程中非常便捷。

但是最新版本的tensorflow现在只有tf.matmul()函数可以使用,不过只能实现同维度的tensor相乘, 下面的几种方法可以实现batch matmul的可能。

例如: tensor A(batch_size,m,n), tensor B(n,k),实现batch matmul 使得A * B。

方法1: 利用tf.matmul()

对tensor B 进行增维和扩展

A = tf.Variable(tf.random_normal(shape=(batch_size, 2, 3)))
B = tf.Variable(tf.random_normal(shape=(3, 5)))
B_exp = tf.tile(tf.expand_dims(B,0),[batch_size, 1, 1]) #先进行增维再扩展
C = tf.matmul(A, B_exp)

方法2: 利用tf.reshape()

对tensor A 进行reshape操作,然后利用tf.matmul()

A = tf.Variable(tf.random_normal(shape=(batch_size, 2, 3)))
B = tf.Variable(tf.random_normal(shape=(3, 5)))
A = tf.reshape(A, [-1, 3])
C = tf.reshape(tf.matmul(A, B), [-1, 2, 5])

方法3: 利用tf.scan()

利用tf.scan() 对tensor按第0维进行展开的特性

A = tf.Variable(tf.random_normal(shape=(batch_size, 2, 3)))
B = tf.Variable(tf.random_normal(shape=(3, 5)))
initializer = tf.Variable(tf.random_normal(shape=(2,5)))
C = tf.scan(lambda a,x: tf.matmul(x, B), A, initializer)

方法4: 利用tf.einsum()

A = tf.Variable(tf.random_normal(shape=(batch_size, 2, 3)))
B = tf.Variable(tf.random_normal(shape=(3, 5)))
C = tf.einsum('ijk,kl->ijl',A,B)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中的函数用法入门教程
Sep 02 Python
Python中endswith()函数的基本使用
Apr 07 Python
Python实现的数据结构与算法之链表详解
Apr 22 Python
Python中的ConfigParser模块使用详解
May 04 Python
Python解析excel文件存入sqlite数据库的方法
Nov 15 Python
浅谈Python由__dict__和dir()引发的一些思考
Oct 30 Python
python 设置文件编码格式的实现方法
Dec 21 Python
python使用json序列化datetime类型实例解析
Feb 11 Python
Python subprocess库的使用详解
Oct 26 Python
使用keras2.0 将Merge层改为函数式
May 23 Python
Jmeter调用Python脚本实现参数互相传递的实现
Jan 22 Python
pytorch常用数据类型所占字节数对照表一览
May 17 Python
pytorch 运行一段时间后出现GPU OOM的问题
Jun 02 #Python
python flask开发的简单基金查询工具
python爬取网页版QQ空间,生成各类图表
Python爬虫实战之爬取携程评论
Pytorch DataLoader shuffle验证方式
python 爬取吉首大学网站成绩单
python 批量压缩图片的脚本
Jun 02 #Python
You might like
解析php中heredoc的使用方法
2013/06/17 PHP
PHP Ajax实现无刷新附件上传
2016/08/17 PHP
PHP获取文件扩展名的常用方法小结【五种方式】
2018/04/27 PHP
使用Laravel中的查询构造器实现增删改查功能
2019/09/03 PHP
小议Function.apply() 之一------(函数的劫持与对象的复制)
2006/11/30 Javascript
jquery 查找iframe父级页面元素的实现代码
2011/08/28 Javascript
JS跨域总结
2012/08/30 Javascript
Javascript动态引用CSS文件的2种方法介绍
2014/06/06 Javascript
js在指定位置增加节点函数insertBefore()用法实例
2015/01/12 Javascript
js+CSS实现模拟华丽的select控件下拉菜单效果
2015/09/01 Javascript
js实现带三角符的手风琴效果
2017/03/01 Javascript
ExtJs的Ext.Ajax.request实现waitMsg等待提示效果
2017/06/14 Javascript
node结合swig渲染摸板的方法
2018/04/11 Javascript
微信小程序实现的一键复制功能示例
2019/04/24 Javascript
nodejs中各种加密算法的实现详解
2019/07/11 NodeJs
NodeJS http模块用法示例【创建web服务器/客户端】
2019/11/05 NodeJs
[01:33]一分钟玩转DOTA2第三弹:DOTA2&DotA快捷操作大对比
2014/06/04 DOTA
python数据结构之二叉树的建立实例
2014/04/29 Python
Python实现堆排序的方法详解
2016/05/03 Python
Python PyQt5标准对话框用法示例
2017/08/23 Python
python for循环输入一个矩阵的实例
2018/11/14 Python
Python中作用域的深入讲解
2018/12/10 Python
使用Python刷淘宝喵币(低阶入门版)
2019/10/30 Python
pymysql模块的操作实例
2019/12/17 Python
如何在windows下安装Pycham2020软件(方法步骤详解)
2020/05/03 Python
Python爬取微信小程序Charles实现过程图解
2020/09/29 Python
css3截图_动力节点Java学院整理
2017/07/11 HTML / CSS
美国名牌太阳镜折扣网站:Eyedictive
2017/05/15 全球购物
销售业务实习自我鉴定
2013/09/23 职场文书
硕士研究生自我鉴定
2013/11/08 职场文书
养殖行业的创业计划书
2014/01/05 职场文书
教师新年寄语
2014/04/03 职场文书
活动总结报告格式
2014/05/09 职场文书
驾驶员安全责任协议书
2016/03/22 职场文书
2016年村党支部公开承诺书
2016/03/24 职场文书
五年级作文之成长
2019/09/16 职场文书