探秘TensorFlow 和 NumPy 的 Broadcasting 机制


Posted in Python onMarch 13, 2020

在使用Tensorflow的过程中,我们经常遇到数组形状不同的情况,但有时候发现二者还能进行加减乘除的运算,在这背后,其实是Tensorflow的broadcast即广播机制帮了大忙。而Tensorflow中的广播机制其实是效仿的numpy中的广播机制。本篇,我们就来一同研究下numpy和Tensorflow中的广播机制。

1、numpy广播原理

1.1 数组和标量计算时的广播

标量和数组合并时就会发生简单的广播,标量会和数组中的每一个元素进行计算。

举个例子:

arr = np.arange(5)
arr * 4

得到的输出为:

array([ 0,  4,  8, 12, 16])

这个是很好理解的,我们重点来研究数组之间的广播

1.2 数组之间计算时的广播

用书中的话来介绍广播的规则:两个数组之间广播的规则:如果两个数组的后缘维度(即从末尾开始算起的维度)的轴长度相等或其中一方的长度为1,则认为他们是广播兼容的,广播会在缺失和(或)长度为1的维度上进行。

上面的规则挺拗口的,我们举几个例子吧:

二维的情况

假设有一个二维数组,我们想要减去它在0轴和1轴的均值,这时的广播是什么样的呢。

我们先来看减去0轴均值的情况:

arr = np.arange(12).reshape(4,3)
arr-arr.mean(0)

输出的结果为:

array([[-4.5, -4.5, -4.5],
       [-1.5, -1.5, -1.5],
       [ 1.5,  1.5,  1.5],
       [ 4.5,  4.5,  4.5]])

0轴的平均值为[4.5,5.5,6.5],形状为(3,),而原数组形状为(4,3),在进行广播时,从后往前比较两个数组的形状,首先是3=3,满足条件而继续比较,这时候发现其中一个数组的形状数组遍历完成,因此会在缺失轴即0轴上进行广播。

可以理解成将均值数组在0轴上复制4份,变成形状(4,3)的数组,再与原数组进行计算。

书中的图形象的表示了这个过程(数据不一样请忽略):

探秘TensorFlow 和 NumPy 的 Broadcasting 机制

我们再来看一下减去1轴平均值的情况,即每行都减去该行的平均值:

arr - arr.mean(1)

此时报错了:

探秘TensorFlow 和 NumPy 的 Broadcasting 机制

我们再来念叨一遍我们的广播规则,均值数组的形状为(4,),而原数组形状为(4,3),按照比较规则,4 != 3,因此不符合广播的条件,因此报错。

正确的做法是什么呢,因为原数组在0轴上的形状为4,我们的均值数组必须要先有一个值能够跟3比较同时满足我们的广播规则,这个值不用多想,就是1。因此我们需要先将均值数组变成(4,1)的形状,再去进行运算:

arr-arr.mean(1).reshape((4,1))

得到正确的结果:

array([[-1., 0., 1.],
    [-1., 0., 1.],
    [-1., 0., 1.],
    [-1., 0., 1.]])

三维的情况

理解了二维的情况,我们也就能很快的理解三维数组的情况。

首先看下图:

探秘TensorFlow 和 NumPy 的 Broadcasting 机制

根据广播原则分析:arr1的shape为(3,4,2),arr2的shape为(4,2),它们的后缘轴长度都为(4,2),所以可以在0轴进行广播。因此,arr2在0轴上复制三份,shape变为(3,4,2),再进行计算。

不只是0轴,1轴和2轴也都可以进行广播。但形状必须满足一定的条件。举个例子来说,我们arr1的shape为(8,5,3),想要在0轴上广播的话,arr2的shape是(1,5,3)或者(5,3),想要在1轴上进行广播的话,arr2的shape是(8,1,3),想要在2轴上广播的话,arr2的shape必须是(8,5,1)。

探秘TensorFlow 和 NumPy 的 Broadcasting 机制

我们来写几个例子吧:

arr2 = np.arange(24).reshape((2,3,4))
arr3_0 = np.arange(12).reshape((3,4))
print("0轴广播")
print(arr2 - arr3_0)

arr3_1 = np.arange(8).reshape((2,1,4))
print("1轴广播")
print(arr2 - arr3_1)

arr3_2 = np.arange(6).reshape((2,3,1))
print("2轴广播")
print(arr2 - arr3_2)

输出为:

0轴广播
[[[ 0  0  0  0]
  [ 0  0  0  0]
  [ 0  0  0  0]]

 [[12 12 12 12]
  [12 12 12 12]
  [12 12 12 12]]]
1轴广播
[[[ 0  0  0  0]
  [ 4  4  4  4]
  [ 8  8  8  8]]

 [[ 8  8  8  8]
  [12 12 12 12]
  [16 16 16 16]]]
2轴广播
[[[ 0  1  2  3]
  [ 3  4  5  6]
  [ 6  7  8  9]]

 [[ 9 10 11 12]
  [12 13 14 15]
  [15 16
 17 18]]]

如果我们想在两个轴上进行广播,那arr2的shape要满足什么条件呢?

arr1.shape 广播轴 arr2.shape
(8,5,3) 0,1 (3,),(1,3),(1,1,3)
(8,5,3) 0,2 (5,1),(1,5,1)
(8,5,3) 1,2 (8,1,1)

具体的例子就不给出啦,嘻嘻。

2、Tensorflow 广播举例

Tensorflow中的广播机制和numpy是一样的,因此我们给出一些简单的举例:

二维的情况

sess = tf.Session()
a = tf.Variable(tf.random_normal((2,3),0,0.1))
b = tf.Variable(tf.random_normal((2,1),0,0.1))
c = a - b
sess.run(tf.global_variables_initializer())
sess.run(c)

输出为:

array([[-0.1419442 ,  0.14135399,  0.22752595],
       [ 0.1382471 ,  0.28228047,  0.13102233]], dtype=float32)

三维的情况

sess = tf.Session()
a = tf.Variable(tf.random_normal((2,3,4),0,0.1))
b = tf.Variable(tf.random_normal((2,1,4),0,0.1))
c = a - b
sess.run(tf.global_variables_initializer())
sess.run(c)

输出为:

array([[[-0.0154749 , -0.02047186, -0.01022427, -0.08932371],
        [-0.12693939, -0.08069084, -0.15459496,  0.09405404],
        [ 0.09730847,  0.06936138,  0.04050628,  0.15374713]],

       [[-0.02691782, -0.26384184,  0.05825682, -0.07617196],
        [-0.02653179, -0.01997554, -0.06522765,  0.03028341],
        [-0.07577246,  0.03199019,  0.0321    , -0.12571403]]], dtype=float32)

错误示例

sess = tf.Session()
a = tf.Variable(tf.random_normal((2,3,4),0,0.1))
b = tf.Variable(tf.random_normal((2,4),0,0.1))
c = a - b
sess.run(tf.global_variables_initializer())
sess.run(c)

输出为:

ValueError: Dimensions must be equal, but are 3 and 2 for 'sub_2' (op: 'Sub') with input shapes: [2,3,4], [2,4].

到此这篇关于探秘TensorFlow 和 NumPy 的 Broadcasting 机制的文章就介绍到这了,更多相关TensorFlow 和NumPy 的Broadcasting 内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python设计模式之代理模式实例
Apr 26 Python
python中zip()方法应用实例分析
Apr 16 Python
深入理解Python中的内置常量
May 20 Python
tensorflow实现KNN识别MNIST
Mar 12 Python
Python实现的维尼吉亚密码算法示例
Apr 12 Python
python的set处理二维数组转一维数组的方法示例
May 31 Python
Python适配器模式代码实现解析
Aug 02 Python
简单分析python的类变量、实例变量
Aug 23 Python
Java如何基于wsimport调用wcf接口
Jun 17 Python
实例讲解Python 迭代器与生成器
Jul 08 Python
python 如何实现遗传算法
Sep 22 Python
Python实现小黑屋游戏的完整实例
Jan 06 Python
自定义Django Form中choicefield下拉菜单选取数据库内容实例
Mar 13 #Python
django处理select下拉表单实例(从model到前端到post到form)
Mar 13 #Python
python实现俄罗斯方块游戏(改进版)
Mar 13 #Python
Python之Django自动实现html代码(下拉框,数据选择)
Mar 13 #Python
Tensorflow中的dropout的使用方法
Mar 13 #Python
python实现简单俄罗斯方块
Mar 13 #Python
Python实现检测文件的MD5值来查找重复文件案例
Mar 12 #Python
You might like
十天学会php(1)
2006/10/09 PHP
CI框架学习笔记(一) - 环境安装、基本术语和框架流程
2014/10/26 PHP
ThinkPHP中Common/common.php文件常用函数功能分析
2016/05/20 PHP
在IIS下安装PHP扩展的方法(超简单)
2017/04/10 PHP
Javascript Cookie读写删除操作的函数
2010/03/02 Javascript
获取焦点时,利用js定时器设定时间执行动作
2010/04/02 Javascript
基于MooTools的很有创意的滚动条时钟动画
2010/11/14 Javascript
jQuery处理xml格式的返回数据(实例解析)
2013/11/28 Javascript
jQuery实现可编辑的表格实例讲解(2)
2015/09/17 Javascript
Highcharts学习之坐标轴
2016/08/02 Javascript
HTML页面,测试JS对C函数的调用简单实例
2016/08/09 Javascript
JS中的数组转变成JSON格式字符串的方法
2017/05/09 Javascript
javascript function(函数类型)使用与注意事项小结
2019/06/10 Javascript
vue实现日历备忘录功能
2020/09/24 Javascript
python实现的AES双向对称加密解密与用法分析
2017/05/02 Python
在python中实现将一张图片剪切成四份的方法
2018/12/05 Python
解决python多行注释引发缩进错误的问题
2019/08/23 Python
Python将视频或者动态图gif逐帧保存为图片的方法
2019/09/10 Python
Pycharm插件(Grep Console)自定义规则输出颜色日志的方法
2020/05/27 Python
python 删除excel表格重复行,数据预处理操作
2020/07/06 Python
英国一家专门出售品牌鞋子的网站:Allsole
2016/08/07 全球购物
eBay爱尔兰站:eBay.ie
2019/08/09 全球购物
十佳美德少年事迹材料
2014/02/05 职场文书
材料专业毕业生求职信
2014/02/26 职场文书
学习十八大报告感言
2014/02/28 职场文书
升旗仪式主持词
2014/03/19 职场文书
无刑事犯罪记录证明
2014/09/18 职场文书
某集团股份有限公司委托书样本
2014/09/24 职场文书
党的群众路线教育实践活动个人整改措施范文
2014/11/04 职场文书
2014年项目工作总结
2014/11/24 职场文书
2014年优秀班主任工作总结
2014/12/16 职场文书
上班旷工检讨书
2015/08/15 职场文书
企业文化学习心得体会
2016/01/21 职场文书
解除租赁合同协议书
2016/03/21 职场文书
Python中super().__init__()测试以及理解
2021/12/06 Python
el-form每行显示两列底部按钮居中效果的实现
2022/08/05 HTML / CSS