探秘TensorFlow 和 NumPy 的 Broadcasting 机制


Posted in Python onMarch 13, 2020

在使用Tensorflow的过程中,我们经常遇到数组形状不同的情况,但有时候发现二者还能进行加减乘除的运算,在这背后,其实是Tensorflow的broadcast即广播机制帮了大忙。而Tensorflow中的广播机制其实是效仿的numpy中的广播机制。本篇,我们就来一同研究下numpy和Tensorflow中的广播机制。

1、numpy广播原理

1.1 数组和标量计算时的广播

标量和数组合并时就会发生简单的广播,标量会和数组中的每一个元素进行计算。

举个例子:

arr = np.arange(5)
arr * 4

得到的输出为:

array([ 0,  4,  8, 12, 16])

这个是很好理解的,我们重点来研究数组之间的广播

1.2 数组之间计算时的广播

用书中的话来介绍广播的规则:两个数组之间广播的规则:如果两个数组的后缘维度(即从末尾开始算起的维度)的轴长度相等或其中一方的长度为1,则认为他们是广播兼容的,广播会在缺失和(或)长度为1的维度上进行。

上面的规则挺拗口的,我们举几个例子吧:

二维的情况

假设有一个二维数组,我们想要减去它在0轴和1轴的均值,这时的广播是什么样的呢。

我们先来看减去0轴均值的情况:

arr = np.arange(12).reshape(4,3)
arr-arr.mean(0)

输出的结果为:

array([[-4.5, -4.5, -4.5],
       [-1.5, -1.5, -1.5],
       [ 1.5,  1.5,  1.5],
       [ 4.5,  4.5,  4.5]])

0轴的平均值为[4.5,5.5,6.5],形状为(3,),而原数组形状为(4,3),在进行广播时,从后往前比较两个数组的形状,首先是3=3,满足条件而继续比较,这时候发现其中一个数组的形状数组遍历完成,因此会在缺失轴即0轴上进行广播。

可以理解成将均值数组在0轴上复制4份,变成形状(4,3)的数组,再与原数组进行计算。

书中的图形象的表示了这个过程(数据不一样请忽略):

探秘TensorFlow 和 NumPy 的 Broadcasting 机制

我们再来看一下减去1轴平均值的情况,即每行都减去该行的平均值:

arr - arr.mean(1)

此时报错了:

探秘TensorFlow 和 NumPy 的 Broadcasting 机制

我们再来念叨一遍我们的广播规则,均值数组的形状为(4,),而原数组形状为(4,3),按照比较规则,4 != 3,因此不符合广播的条件,因此报错。

正确的做法是什么呢,因为原数组在0轴上的形状为4,我们的均值数组必须要先有一个值能够跟3比较同时满足我们的广播规则,这个值不用多想,就是1。因此我们需要先将均值数组变成(4,1)的形状,再去进行运算:

arr-arr.mean(1).reshape((4,1))

得到正确的结果:

array([[-1., 0., 1.],
    [-1., 0., 1.],
    [-1., 0., 1.],
    [-1., 0., 1.]])

三维的情况

理解了二维的情况,我们也就能很快的理解三维数组的情况。

首先看下图:

探秘TensorFlow 和 NumPy 的 Broadcasting 机制

根据广播原则分析:arr1的shape为(3,4,2),arr2的shape为(4,2),它们的后缘轴长度都为(4,2),所以可以在0轴进行广播。因此,arr2在0轴上复制三份,shape变为(3,4,2),再进行计算。

不只是0轴,1轴和2轴也都可以进行广播。但形状必须满足一定的条件。举个例子来说,我们arr1的shape为(8,5,3),想要在0轴上广播的话,arr2的shape是(1,5,3)或者(5,3),想要在1轴上进行广播的话,arr2的shape是(8,1,3),想要在2轴上广播的话,arr2的shape必须是(8,5,1)。

探秘TensorFlow 和 NumPy 的 Broadcasting 机制

我们来写几个例子吧:

arr2 = np.arange(24).reshape((2,3,4))
arr3_0 = np.arange(12).reshape((3,4))
print("0轴广播")
print(arr2 - arr3_0)

arr3_1 = np.arange(8).reshape((2,1,4))
print("1轴广播")
print(arr2 - arr3_1)

arr3_2 = np.arange(6).reshape((2,3,1))
print("2轴广播")
print(arr2 - arr3_2)

输出为:

0轴广播
[[[ 0  0  0  0]
  [ 0  0  0  0]
  [ 0  0  0  0]]

 [[12 12 12 12]
  [12 12 12 12]
  [12 12 12 12]]]
1轴广播
[[[ 0  0  0  0]
  [ 4  4  4  4]
  [ 8  8  8  8]]

 [[ 8  8  8  8]
  [12 12 12 12]
  [16 16 16 16]]]
2轴广播
[[[ 0  1  2  3]
  [ 3  4  5  6]
  [ 6  7  8  9]]

 [[ 9 10 11 12]
  [12 13 14 15]
  [15 16
 17 18]]]

如果我们想在两个轴上进行广播,那arr2的shape要满足什么条件呢?

arr1.shape 广播轴 arr2.shape
(8,5,3) 0,1 (3,),(1,3),(1,1,3)
(8,5,3) 0,2 (5,1),(1,5,1)
(8,5,3) 1,2 (8,1,1)

具体的例子就不给出啦,嘻嘻。

2、Tensorflow 广播举例

Tensorflow中的广播机制和numpy是一样的,因此我们给出一些简单的举例:

二维的情况

sess = tf.Session()
a = tf.Variable(tf.random_normal((2,3),0,0.1))
b = tf.Variable(tf.random_normal((2,1),0,0.1))
c = a - b
sess.run(tf.global_variables_initializer())
sess.run(c)

输出为:

array([[-0.1419442 ,  0.14135399,  0.22752595],
       [ 0.1382471 ,  0.28228047,  0.13102233]], dtype=float32)

三维的情况

sess = tf.Session()
a = tf.Variable(tf.random_normal((2,3,4),0,0.1))
b = tf.Variable(tf.random_normal((2,1,4),0,0.1))
c = a - b
sess.run(tf.global_variables_initializer())
sess.run(c)

输出为:

array([[[-0.0154749 , -0.02047186, -0.01022427, -0.08932371],
        [-0.12693939, -0.08069084, -0.15459496,  0.09405404],
        [ 0.09730847,  0.06936138,  0.04050628,  0.15374713]],

       [[-0.02691782, -0.26384184,  0.05825682, -0.07617196],
        [-0.02653179, -0.01997554, -0.06522765,  0.03028341],
        [-0.07577246,  0.03199019,  0.0321    , -0.12571403]]], dtype=float32)

错误示例

sess = tf.Session()
a = tf.Variable(tf.random_normal((2,3,4),0,0.1))
b = tf.Variable(tf.random_normal((2,4),0,0.1))
c = a - b
sess.run(tf.global_variables_initializer())
sess.run(c)

输出为:

ValueError: Dimensions must be equal, but are 3 and 2 for 'sub_2' (op: 'Sub') with input shapes: [2,3,4], [2,4].

到此这篇关于探秘TensorFlow 和 NumPy 的 Broadcasting 机制的文章就介绍到这了,更多相关TensorFlow 和NumPy 的Broadcasting 内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
提升Python程序运行效率的6个方法
Mar 31 Python
Python实现从URL地址提取文件名的方法
May 15 Python
Windows下为Python安装Matplotlib模块
Nov 06 Python
详解Python中的文件操作
Aug 28 Python
Python打印“菱形”星号代码方法
Feb 05 Python
python实现爬取图书封面
Jul 05 Python
使用Python实现一个栈判断括号是否平衡
Aug 23 Python
Python从ZabbixAPI获取信息及实现Zabbix-API 监控的方法
Sep 17 Python
浅谈python中拼接路径os.path.join斜杠的问题
Oct 23 Python
python生成lmdb格式的文件实例
Nov 08 Python
pandas取dataframe特定行列的实现方法
May 24 Python
Python实现单例模式的5种方法
Jun 15 Python
自定义Django Form中choicefield下拉菜单选取数据库内容实例
Mar 13 #Python
django处理select下拉表单实例(从model到前端到post到form)
Mar 13 #Python
python实现俄罗斯方块游戏(改进版)
Mar 13 #Python
Python之Django自动实现html代码(下拉框,数据选择)
Mar 13 #Python
Tensorflow中的dropout的使用方法
Mar 13 #Python
python实现简单俄罗斯方块
Mar 13 #Python
Python实现检测文件的MD5值来查找重复文件案例
Mar 12 #Python
You might like
php使用function_exists判断函数可用的方法
2014/11/19 PHP
php操作memcache缓存方法分享
2015/06/03 PHP
php+js实现百度地图多点标注的方法
2016/11/30 PHP
利用php的ob缓存机制实现页面静态化方法
2017/07/09 PHP
PHP实现模拟http请求的方法分析
2017/12/20 PHP
laravel通用化的CURD的实现
2019/12/13 PHP
IE中radio 或checkbox的checked属性初始状态下不能选中显示问题
2009/07/25 Javascript
jquery批量控制form禁用的代码
2013/08/06 Javascript
JS实现静止元素自动移动示例
2014/04/14 Javascript
基于jquery实现发送文章到手机的代码
2014/12/26 Javascript
用js编写的简单的计算器代码程序
2015/08/04 Javascript
jQuery获取复选框被选中数量及判断选择值的方法详解
2016/05/25 Javascript
基于BootStrap实现局部刷新分页实例代码
2016/08/08 Javascript
KnockoutJS 3.X API 第四章之表单value绑定
2016/10/10 Javascript
实现点击下箭头变上箭头来回切换的两种方法【推荐】
2016/12/14 Javascript
BootStrap实现带关闭按钮功能
2017/02/15 Javascript
Angular.JS去掉访问路径URL中的#号详解
2017/03/30 Javascript
angular或者js怎么确定选中ul中的哪几个li
2017/08/16 Javascript
使用async、enterproxy控制并发数量的方法详解
2018/01/02 Javascript
详解Angular系列之变化检测(Change Detection)
2018/02/26 Javascript
vue中实现先请求数据再渲染dom分享
2018/03/17 Javascript
Python中文分词实现方法(安装pymmseg)
2016/06/14 Python
python 读取excel文件生成sql文件实例详解
2017/05/12 Python
Python同步遍历多个列表的示例
2019/02/19 Python
python3利用Socket实现通信的方法示例
2019/05/06 Python
纯CSS3制作的简洁蓝白风格的登录模板(非IE效果更好)
2013/08/11 HTML / CSS
HTML5之SVG 2D入门6—视窗坐标系与用户坐标系及变换概述
2013/01/30 HTML / CSS
韩国邮政旗下生鲜食品网上超市:epost
2016/08/27 全球购物
魅力惠奢品线上平台:MEI.COM
2016/11/29 全球购物
法律工作求职自荐信
2013/10/31 职场文书
经销商会议欢迎词
2014/01/11 职场文书
中式结婚主持词
2014/03/14 职场文书
领导班子专题民主生活会情况想汇报
2014/09/30 职场文书
2015年社区计生工作总结
2015/04/21 职场文书
小学班主任工作经验交流材料
2015/11/02 职场文书
python实现简单反弹球游戏
2021/04/12 Python