Broadcasting 是指,在运算中,不同大小的两个 array 应该怎样处理的操作。通常情况下,小一点的数组会被 broadcast 到大一点的,这样才能保持大小一致。Broadcasting 过程中的循环操作都在 C 底层进行,所以速度比较快。但也有一些情况下 Broadcasting 会带来性能上的下降。
两个 Tensors 只有在下列情况下才能进行 broadcasting 操作:
- 每个 tensor 至少有一维
- 遍历所有的维度,从尾部维度开始,每个对应的维度大小要么相同,要么其中一个是 1,要么其中一个不存在。
让我们来看一些代码示例。
1 2 3 | x=torch.empty(5,7,3) y=torch.empty(5,7,3) # 相同维度,一定可以 broadcasting |
1 2 3 | x=torch.empty((0,)) y=torch.empty(2,2) # x 没有符合“至少有一个维度”,所以不可以 broadcasting |
1 2 3 4 5 6 7 8 | # 按照尾部维度对齐 x=torch.empty(5,3,4,1) y=torch.empty( 3,1,1) # x 和 y 是 broadcastable # 1st 尾部维度: 都为 1 # 2nd 尾部维度: y 为 1 # 3rd 尾部维度: x 和 y 相同 # 4th 尾部维度: y 维度不存在 |
1 2 3 4 | # 但是: >>> x=torch.empty(5,2,4,1) >>> y=torch.empty( 3,1,1) # x 和 y 不能 broadcasting, 因为尾3维度 2 != 3 |
如果两个 tensors 可以 broadcasting,那么计算过程是这样的:
- 如果 x 和 y 的维度不同,那么对于维度较小的 tensor 的维度补 1,使它们维度相同。
- 然后,对于每个维度,计算结果的维度值就是 x 和 y 中较大的那个值。
1 2 3 4 5 6 | # 按照尾部维度对齐 x=torch.empty(5,1,4,1) y=torch.empty( 3,1,1) (x+y).size() # 结果维度 torch.Size([5, 3, 4, 1]) |
来看一个不对的例子:
1 2 3 | x=torch.empty(5,2,4,1) y=torch.empty(3,1,1) (x+y).size() |
报错:
1 2 3 4 5 6 7 8 9 10 11 | --------------------------------------------------------------------------- RuntimeError Traceback (most recent call last) <ipython-input-17-72fb34250db7> in <module>() 1 x=torch.empty(5,2,4,1) 2 y=torch.empty(3,1,1) ----> 3 (x+y).size() RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1 |
注意报错提示说:在 non-singleton 维度上,tensor a 和 b 的 维度应该相同。
本站微信群、QQ群(三群号 726282629):
评论列表(1)