TensorFlow中tf.batch_matmul()的用法


Posted in Python onJune 02, 2021

TensorFlow中tf.batch_matmul()用法

如果有两个三阶张量,size分别为

a.shape = [100, 3, 4]
b.shape = [100, 4, 5]
c = tf.batch_matmul(a, b)

则c.shape = [100, 3, 5] //将每一对 3x4 的矩阵与 4x5 的矩阵分别相乘。batch_size不变

100为张量的batch_size。剩下的两个维度为数据的维度。

不过新版的tensorflow已经移除了上面的函数,使用时换为tf.matmul就可以了。与上面注释的方式是同样的。

附: 如果是更高维度。例如(a, b, m, n) 与(a, b, n, k)之间做matmul运算。则结果的维度为(a, b, m, k)。

TensorFlow如何实现batch_matmul

我们知道,在tensorflow早期版本中有tf.batch_matmul()函数,可以实现多维tensor和低维tensor的直接相乘,这在使用过程中非常便捷。

但是最新版本的tensorflow现在只有tf.matmul()函数可以使用,不过只能实现同维度的tensor相乘, 下面的几种方法可以实现batch matmul的可能。

例如: tensor A(batch_size,m,n), tensor B(n,k),实现batch matmul 使得A * B。

方法1: 利用tf.matmul()

对tensor B 进行增维和扩展

A = tf.Variable(tf.random_normal(shape=(batch_size, 2, 3)))
B = tf.Variable(tf.random_normal(shape=(3, 5)))
B_exp = tf.tile(tf.expand_dims(B,0),[batch_size, 1, 1]) #先进行增维再扩展
C = tf.matmul(A, B_exp)

方法2: 利用tf.reshape()

对tensor A 进行reshape操作,然后利用tf.matmul()

A = tf.Variable(tf.random_normal(shape=(batch_size, 2, 3)))
B = tf.Variable(tf.random_normal(shape=(3, 5)))
A = tf.reshape(A, [-1, 3])
C = tf.reshape(tf.matmul(A, B), [-1, 2, 5])

方法3: 利用tf.scan()

利用tf.scan() 对tensor按第0维进行展开的特性

A = tf.Variable(tf.random_normal(shape=(batch_size, 2, 3)))
B = tf.Variable(tf.random_normal(shape=(3, 5)))
initializer = tf.Variable(tf.random_normal(shape=(2,5)))
C = tf.scan(lambda a,x: tf.matmul(x, B), A, initializer)

方法4: 利用tf.einsum()

A = tf.Variable(tf.random_normal(shape=(batch_size, 2, 3)))
B = tf.Variable(tf.random_normal(shape=(3, 5)))
C = tf.einsum('ijk,kl->ijl',A,B)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python采用requests库模拟登录和抓取数据的简单示例
Jul 05 Python
解析Python编程中的包结构
Oct 25 Python
Scrapy-redis爬虫分布式爬取的分析和实现
Feb 07 Python
TensorFlow神经网络优化策略学习
Mar 09 Python
pycharm 解除默认unittest模式的方法
Nov 30 Python
解决Pycharm后台indexing导致不能run的问题
Jun 27 Python
Python递归及尾递归优化操作实例分析
Feb 01 Python
Django-xadmin+rule对象级权限的实现方式
Mar 30 Python
30行Python代码实现高分辨率图像导航的方法
May 22 Python
Python3+RIDE+RobotFramework自动化测试框架搭建过程详解
Sep 23 Python
pandas处理csv文件的方法步骤
Oct 16 Python
python 如何把docker-compose.yaml导入到数据库相关条目里
Jan 15 Python
pytorch 运行一段时间后出现GPU OOM的问题
Jun 02 #Python
python flask开发的简单基金查询工具
python爬取网页版QQ空间,生成各类图表
Python爬虫实战之爬取携程评论
Pytorch DataLoader shuffle验证方式
python 爬取吉首大学网站成绩单
python 批量压缩图片的脚本
Jun 02 #Python
You might like
PHP DataGrid 实现代码
2009/08/12 PHP
php数组对百万数据进行排除重复数据的实现代码
2010/06/08 PHP
PHP二分查找算法示例【递归与非递归方法】
2016/09/29 PHP
javascript代码编写需要注意的7个小细节小结
2011/09/21 Javascript
扩展js对象数组的OrderByAsc和OrderByDesc方法实现思路
2013/05/17 Javascript
JS 精确统计网站访问量的实例代码
2013/07/05 Javascript
js 实现日期灵活格式化的小例子
2013/07/14 Javascript
JavaScript判断数组是否包含指定元素的方法
2015/07/01 Javascript
整理Javascript函数学习笔记
2015/12/01 Javascript
浅析jQuery Ajax请求参数和返回数据的处理
2016/02/24 Javascript
全面解析JavaScript中apply和call以及bind(推荐)
2016/06/15 Javascript
根据Bootstrap Paginator改写的js分页插件
2016/12/25 Javascript
微信小程序自定义底部导航带跳转功能
2018/11/27 Javascript
js实现下拉框二级联动
2018/12/04 Javascript
解决LayUI数据表格复选框不居中显示的问题
2019/09/25 Javascript
angular inputNumber指令输入框只能输入数字的实现
2019/12/03 Javascript
Vue跨域请求问题解决方案过程解析
2020/08/07 Javascript
JavaScript 判断数据类型的4种方法
2020/09/11 Javascript
JavaScript通如何过RGraph实现动态仪表盘
2020/10/15 Javascript
[06:16]DOTA2守卫传承者——职业选手谈心路历程
2015/02/26 DOTA
使用Python的PEAK来适配协议的教程
2015/04/14 Python
Python用Bottle轻量级框架进行Web开发
2016/06/08 Python
Python中类型检查的详细介绍
2017/02/13 Python
Python中对象的引用与复制代码示例
2017/12/04 Python
Python实现的多进程和多线程功能示例
2018/05/29 Python
Python实例方法、类方法、静态方法区别详解
2020/09/05 Python
CSS3属性选择符介绍
2008/10/17 HTML / CSS
小学生环保标语
2014/06/13 职场文书
学校周年庆活动方案
2014/08/22 职场文书
交通事故死亡赔偿协议书
2014/12/03 职场文书
南京导游词
2015/02/03 职场文书
消防演习通知
2015/04/25 职场文书
个人道歉信大全
2019/04/11 职场文书
想创业成功,需要掌握这些要点
2019/12/06 职场文书
浅谈vue2的$refs在vue3组合式API中的替代方法
2021/04/18 Vue.js
青岛市的收音机研制与生产
2022/04/07 无线电