Matrix multiplication and broadcasting mechanism in torch

1. Broadcasting mechanism

1. Rules for two tensors to be “broadcastable”:

  • Every tensor has at least one dimension.

  • When iterating over the dimension sizes, starting at the trailing (last) dimension, each pair of sizes must satisfy one of the following: (1) the sizes are equal, (2) one of them is 1, or (3) one of them does not exist.

Example:

x=torch.empty((0,))
y=torch.empty(2,2)
# x and y are not broadcastable: x has shape (0,), and its trailing size 0 is neither equal to 2 nor 1


x=torch.empty(5,7,3)
y=torch.empty(5,7,3)
# same shape is always broadcastable


x=torch.empty(5,3,4,1)
y=torch.empty(3,1,1)
# 1st trailing dimension: both sizes are 1
# 2nd trailing dimension: the size of y is 1
# 3rd trailing dimension: size of x == size of y
# 4th trailing dimension: the dimension does not exist in y
# All conditions hold, so x and y are broadcastable.

x=torch.empty(5,2,4,1)
y=torch.empty(3,1,1)
# x and y are not broadcastable, because 2 != 3 in the 3rd trailing dimension
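To make the trailing-dimension rule concrete, here is a minimal pure-Python sketch that checks it on shape tuples. The function name `is_broadcastable` is hypothetical (it is not a PyTorch API), and it only checks the trailing-dimension condition from the second bullet.

```python
from itertools import zip_longest


def is_broadcastable(shape_a, shape_b):
    """Hypothetical helper: check the trailing-dimension rule on two shape tuples."""
    # Walk both shapes from the last dimension backwards; a dimension that
    # exists in only one shape shows up as None and is always compatible.
    for a, b in zip_longest(reversed(shape_a), reversed(shape_b), fillvalue=None):
        if a is None or b is None:
            continue  # (3) one of the dimensions does not exist
        if a != b and a != 1 and b != 1:
            return False  # sizes differ and neither is 1
    return True


print(is_broadcastable((5, 7, 3), (5, 7, 3)))     # True
print(is_broadcastable((5, 3, 4, 1), (3, 1, 1)))  # True
print(is_broadcastable((5, 2, 4, 1), (3, 1, 1)))  # False: 2 vs 3
print(is_broadcastable((0,), (2, 2)))             # False: 0 vs 2
```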

2. Rules for computing the broadcast result shape:

If two tensors are broadcastable, the size of the resulting tensor is computed as follows:

  • If x and y have a different number of dimensions, prepend dimensions of size 1 to the shape of the tensor with fewer dimensions until both shapes have the same length.

  • Then, for each dimension, the resulting size is the maximum of the sizes of x and y along that dimension.

Example:

>>> x=torch.empty(5,1,4,1)
>>> y=torch.empty(3,1,1)
>>> (x + y).size()
torch.Size([5, 3, 4, 1])

>>> x=torch.empty(1)
>>> y=torch.empty(3,1,7)
>>> (x + y).size()
torch.Size([3, 1, 7])

>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty(3,1,1)
>>> (x + y).size()
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1
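The result-shape rule can be sketched the same way. `broadcast_shape` below is a hypothetical helper, not part of PyTorch: it pads the shorter shape with leading 1s and, for each dimension, keeps whichever size is not 1. That equals the maximum for ordinary sizes, and it also gives the right answer when a size-0 dimension meets a size-1 dimension.

```python
from itertools import zip_longest


def broadcast_shape(shape_a, shape_b):
    """Hypothetical helper: compute the broadcast result shape, or raise."""
    result = []
    # Align on the right; a missing leading dimension behaves like size 1.
    for a, b in zip_longest(reversed(shape_a), reversed(shape_b), fillvalue=1):
        if a != b and a != 1 and b != 1:
            raise ValueError(f"shapes {shape_a} and {shape_b} are not broadcastable")
        # Keep the non-1 size (same as max(a, b) unless a size is 0).
        result.append(b if a == 1 else a)
    return tuple(reversed(result))


print(broadcast_shape((5, 1, 4, 1), (3, 1, 1)))  # (5, 3, 4, 1)
print(broadcast_shape((1,), (3, 1, 7)))          # (3, 1, 7)
```

Recent PyTorch versions also provide torch.broadcast_shapes(), which performs this computation and returns a torch.Size.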

Note 1: In-place operations do not allow the in-place tensor to change shape as a result of broadcasting.

>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty(3,1,1)
>>> (x.add_(y)).size()
torch.Size([5, 3, 4, 1])

# but:
>>> x=torch.empty(1,3,1)
>>> y=torch.empty(3,1,7)
>>> (x.add_(y)).size()  # in-place
RuntimeError: The expanded size of the tensor (1) must match the existing size (7) at non-singleton dimension 2.

>>> (x + y).size()
torch.Size([3, 3, 7])
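If you really want x itself to end up with the broadcast shape, one workaround is to grow it first. This sketch assumes that is the goal: it computes the target with torch.broadcast_shapes, expands x with Tensor.expand, and then calls clone(), because an expanded tensor shares memory across its broadcast dimensions and cannot be written in place.

```python
import torch

x = torch.zeros(1, 3, 1)
y = torch.ones(3, 1, 7)

# The in-place add fails: x cannot grow from (1, 3, 1) to (3, 3, 7).
try:
    x.add_(y)
except RuntimeError as err:
    print("in-place add failed:", err)

# Workaround sketch: materialize x at the broadcast shape first.
target = torch.broadcast_shapes(x.shape, y.shape)  # torch.Size([3, 3, 7])
x = x.expand(target).clone()  # clone() gives x its own writable memory
x.add_(y)
print(x.shape)  # torch.Size([3, 3, 7])
```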

Note 2: When two tensors do not have the same shape but are broadcastable and have the same number of elements, the introduction of broadcasting can cause backwards-incompatible changes.

For example, torch.add(torch.ones(4,1), torch.ones(4)) previously produced a tensor of size torch.Size([4,1]), but now produces a tensor of size torch.Size([4,4]). To help identify places in your code where broadcasting may have introduced backwards incompatibilities, you can set torch.utils.backcompat.broadcast_warning.enabled to True, in which case Python warnings will be generated.

>>> torch.utils.backcompat.broadcast_warning.enabled=True
>>> torch.add(torch.ones(4,1), torch.ones(4))
__main__:1: UserWarning: self and other do not have the same shape, but are broadcastable, and have the same number of elements.
Changing behavior in a backwards incompatible manner to broadcasting rather than viewing as 1-dimensional.
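To see the incompatibility without the warning machinery, the sketch below compares today's broadcasting result with an emulation of the old behaviour. The emulation (flatten both tensors, add element by element, keep the first tensor's shape) is an assumption based on the warning text "viewing as 1-dimensional" and the torch.Size([4,1]) result quoted above; it is not code taken from PyTorch.

```python
import torch

a = torch.ones(4, 1)
b = torch.arange(4.0)  # tensor([0., 1., 2., 3.])

# Current behavior: (4, 1) broadcasts with (4,) to (4, 4).
new = torch.add(a, b)
print(new.shape)  # torch.Size([4, 4])

# Assumed old behavior: view both as 1-D, add pairwise, keep a's shape.
old = (a.view(-1) + b.view(-1)).view(a.shape)
print(old.shape)  # torch.Size([4, 1])
```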

2. Multiplication operations in torch

1. torch.dot(vec1, vec2)

It computes the dot product of two vectors. It does not support broadcasting and requires the two one-dimensional tensors to have the same number of elements.

$\text{vec1}(n),\ \text{vec2}(n) \rightarrow \text{torch.dot()} \rightarrow \text{out}()$

import torch

vec1 = torch.Tensor([1,2,3,4])
vec2 = torch.Tensor([5,6,7,8])
print(torch.dot(vec1, vec2))
# tensor(70.)
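As a quick check of what torch.dot computes, the sketch below reproduces it as the sum of an element-wise product and shows that vectors of different lengths are rejected rather than broadcast. It uses torch.tensor with float literals instead of the legacy torch.Tensor constructor; the values are the same.

```python
import torch

vec1 = torch.tensor([1.0, 2.0, 3.0, 4.0])
vec2 = torch.tensor([5.0, 6.0, 7.0, 8.0])

# torch.dot is the sum of element-wise products: 1*5 + 2*6 + 3*7 + 4*8
manual = (vec1 * vec2).sum()
print(manual)  # tensor(70.)
print(torch.equal(torch.dot(vec1, vec2), manual))  # True

# No broadcasting: inputs of different lengths raise an error.
try:
    torch.dot(vec1, torch.tensor([1.0, 2.0]))
except RuntimeError as err:
    print("torch.dot failed:", err)
```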

2. torch.mm(mat1, mat2)

It computes the matrix product of two 2-D matrices. It does not support broadcasting and requires the shapes of the two tensors to satisfy the requirements of matrix multiplication (the number of columns of mat1 must equal the number of rows of mat2).

$\text{mat1}(n \times m),\ \text{mat2}(m \times d) \rightarrow \text{torch.mm()} \rightarrow \text{out}(n \times d)$

mat1 = torch.randn(3, 4)
mat2 = torch.randn(4, 5)
out = torch.mm(mat1, mat2)
print(out.shape)
# torch.Size([3, 5])
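The shape rule n×m, m×d → n×d follows from how each entry is computed: entry (i, j) is a sum over the shared m dimension. The sketch below checks this with an explicit loop on random inputs (the 1e-5 tolerance is an arbitrary choice for float32) and shows torch.mm rejecting a 3-D input.

```python
import torch

mat1 = torch.randn(3, 4)
mat2 = torch.randn(4, 5)
out = torch.mm(mat1, mat2)

# Entry (i, j) sums mat1[i, k] * mat2[k, j] over the shared m axis (k).
manual = torch.zeros(3, 5)
for i in range(3):
    for j in range(5):
        manual[i, j] = (mat1[i, :] * mat2[:, j]).sum()
assert torch.allclose(out, manual, atol=1e-5)

# No broadcasting: torch.mm only accepts 2-D inputs.
try:
    torch.mm(torch.randn(2, 3, 4), mat2)
except RuntimeError as err:
    print("torch.mm failed:", err)
```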

3. torch.bmm(mat1, mat2)

It computes the batched matrix product of two 3-D tensors. It does not support broadcasting: both inputs must be 3-D, and their first dimensions (the batch dimension) must have the same size.

$\text{mat1}(b \times n \times m),\ \text{mat2}(b \times m \times d) \rightarrow \text{torch.bmm()} \rightarrow \text{out}(b \times n \times d)$

mat1 = torch.randn(2, 3, 4)
mat2 = torch.randn(2, 4, 5)
out = torch.bmm(mat1, mat2)
print(out.shape)
# torch.Size([2, 3, 5])
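torch.bmm can be read as one torch.mm per batch entry. This sketch verifies that on random inputs by stacking the per-batch products, and shows that a batch of size 1 is not broadcast to match a batch of size 2.

```python
import torch

mat1 = torch.randn(2, 3, 4)
mat2 = torch.randn(2, 4, 5)
out = torch.bmm(mat1, mat2)

# One torch.mm per batch index, stacked back along dim 0.
per_batch = torch.stack([torch.mm(mat1[k], mat2[k]) for k in range(mat1.shape[0])])
assert torch.allclose(out, per_batch, atol=1e-5)
print(per_batch.shape)  # torch.Size([2, 3, 5])

# No broadcasting: batch sizes 1 and 2 do not match.
try:
    torch.bmm(torch.randn(1, 3, 4), mat2)
except RuntimeError as err:
    print("torch.bmm failed:", err)
```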

4. torch.mv(mat, vec)

It computes the product of a matrix and a vector (the matrix comes first, the vector second). It does not support broadcasting, and the matrix and vector sizes must satisfy the requirements of matrix multiplication.

$\text{mat}(n \times m),\ \text{vec}(m) \rightarrow \text{torch.mv()} \rightarrow \text{out}(n)$

mat = torch.randn(3, 4)
vec = torch.randn(4)
output = torch.mv(mat, vec)
print(output.shape)
# torch.Size([3])
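torch.mv is the matrix product with the vector treated as an m×1 column. The sketch below checks that on random inputs by unsqueezing the vector, calling torch.mm, and squeezing the result back to shape (n,).

```python
import torch

mat = torch.randn(3, 4)
vec = torch.randn(4)

out = torch.mv(mat, vec)
# Same result: view vec as a (4 x 1) column, multiply, drop the extra dim.
via_mm = torch.mm(mat, vec.unsqueeze(1)).squeeze(1)
assert torch.allclose(out, via_mm, atol=1e-5)
print(via_mm.shape)  # torch.Size([3])
```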

5. torch.mul(a, b)

torch.multiply() is an alias of torch.mul().

It computes the element-wise product of a and b and supports broadcasting: as long as the shapes of a and b satisfy the broadcasting rules, element-wise multiplication can be performed. b may also be a Python scalar.

$\text{a},\ \text{b} \rightarrow \text{torch.mul()} \rightarrow \text{out}$

A = torch.randn(2,1,4)
B = torch.randn(3, 1) # matrix
print(torch.mul(A,B).shape)
# torch.Size([2, 3, 4])
b0 = 2 # scalar
print(torch.mul(A,b0).shape)
# torch.Size([2, 1, 4])
b1 = torch.tensor([1,2,3,4]) # row vector
print(torch.mul(A,b1).shape)
# torch.Size([2, 1, 4])
b2 = torch.Tensor([1,2,3]).reshape(-1,1) # column vector
print(torch.mul(A,b2).shape)
# torch.Size([2, 3, 4])
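To see what broadcasting does inside torch.mul, the sketch below expands both operands explicitly to the broadcast shape (2, 3, 4) before multiplying and checks that the result equals torch.mul(A, B) exactly. It also checks the torch.multiply alias. The inputs are random.

```python
import torch

A = torch.randn(2, 1, 4)
B = torch.randn(3, 1)

# B is aligned as (1, 3, 1); both operands then expand to (2, 3, 4).
expanded = A.expand(2, 3, 4) * B.expand(2, 3, 4)
assert torch.equal(torch.mul(A, B), expanded)

# torch.multiply is an alias of torch.mul.
assert torch.equal(torch.multiply(A, B), torch.mul(A, B))
print(expanded.shape)  # torch.Size([2, 3, 4])
```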

6. torch.matmul(mat1, mat2)

It handles almost every matrix/vector multiplication case and supports broadcasting; it can be thought of as a broadcasting version of torch.mm(). Which multiplication rule is applied depends on the number of dimensions of the two input tensors.

$\text{mat1}(j \times 1 \times n \times m),\ \text{mat2}(k \times m \times p) \rightarrow \text{torch.matmul()} \rightarrow \text{out}(j \times k \times n \times p)$

In particular, when the inputs have more than two dimensions, matmul() multiplies the last two dimensions of the two arguments as matrices and treats all other (leading) dimensions as batch dimensions, which are broadcast against each other.

mat1 = torch.randn(2,1,4,5)
mat2 = torch.randn(2,1,5,2)
out = torch.matmul(mat1, mat2)
print(out.shape)
# torch.Size([2, 1, 4, 2])
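The batch rule j×1×n×m, k×m×p → j×k×n×p can be checked directly. This sketch picks j = 2, k = 3 with random inputs and compares every batch entry of the result with a plain torch.mm on the matching slices.

```python
import torch

# Batch dims (2, 1) and (3,) broadcast to (2, 3);
# the last two dims are multiplied: (4 x 5) @ (5 x 6) -> (4 x 6).
mat1 = torch.randn(2, 1, 4, 5)
mat2 = torch.randn(3, 5, 6)
out = torch.matmul(mat1, mat2)
print(out.shape)  # torch.Size([2, 3, 4, 6])

# out[a, b] pairs mat1[a, 0] (the size-1 dim is broadcast) with mat2[b].
for a in range(2):
    for b in range(3):
        assert torch.allclose(out[a, b], torch.mm(mat1[a, 0], mat2[b]), atol=1e-5)
```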

If both tensors are one-dimensional, matmul() behaves like torch.dot() and returns the dot product of the two 1-D tensors.

vec1 = torch.Tensor([1,2,3,4])
vec2 = torch.Tensor([5,6,7,8])
print(torch.matmul(vec1, vec2))
# tensor(70.)
print(torch.dot(vec1, vec2))
# tensor(70.)

If both tensors are two-dimensional, matmul() behaves like torch.mm() and returns the matrix product of the two 2-D matrices.

mat1 = torch.randn(3, 4)
mat2 = torch.randn(4, 5)
out = torch.mm(mat1, mat2)
print(out.shape)
# torch.Size([3, 5])

out1 = torch.matmul(mat1, mat2)
print(out1.shape)
# torch.Size([3, 5])

If the first argument is a 2-D tensor (a matrix) and the second is a 1-D tensor (a vector), matmul() returns the matrix-vector product, just like torch.mv(); the matrix and vector sizes must satisfy the requirements of matrix multiplication.

mat = torch.randn(3, 4)
vec = torch.randn(4)
output = torch.mv(mat, vec)
print(output.shape)
# torch.Size([3])

output1 = torch.matmul(mat, vec)
print(output1.shape)
# torch.Size([3])

If the first argument is a 1-D tensor and the second is a 2-D tensor, a dimension of size 1 is prepended to the 1-D tensor for the matrix multiplication and removed again afterwards, so the result is 1-D.

vec = torch.randn(4)
mat = torch.randn(4,2)
print(torch.matmul(vec, mat).shape)
# torch.Size([2])
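The prepend-then-remove behaviour can be reproduced by hand: unsqueeze the vector to shape (1, 4), multiply with torch.mm, then squeeze the leading dimension away. The sketch below compares that with torch.matmul on random inputs.

```python
import torch

vec = torch.randn(4)
mat = torch.randn(4, 2)

out = torch.matmul(vec, mat)
# Manual version: (4,) -> (1, 4), then (1, 4) @ (4, 2) -> (1, 2), then -> (2,)
manual = torch.mm(vec.unsqueeze(0), mat).squeeze(0)
assert torch.allclose(out, manual, atol=1e-5)
print(out.shape)  # torch.Size([2])
```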

7. Operators

@ operator: It works like torch.matmul.

mat1 = torch.randn(2,1,4,5)
mat2 = torch.randn(2,1,5,2)
out = torch.matmul(mat1, mat2)
print(out.shape)
# torch.Size([2, 1, 4, 2])
out1 = mat1 @ mat2
print(out1.shape)
# torch.Size([2, 1, 4, 2])

* operator: It works like torch.mul.

A = torch.randn(2,1,4)
B = torch.randn(3, 1) # matrix
print(torch.mul(A,B).shape)
# torch.Size([2, 3, 4])

print((A * B).shape)
# torch.Size([2, 3, 4])

8. Extension: torch.einsum(): Einstein summation convention

Recommended reading: einsum is all you need!

If, like me, you find it hard to remember the names and signatures of the PyTorch/TensorFlow functions that compute dot products, outer products, transposes, matrix-vector products, and matrix-matrix products, einsum notation is a lifesaver. It is an elegant way to express all of these operations, as well as more complex tensor operations; essentially, einsum is a small domain-specific language. Once you understand and use it, you no longer need to memorize and repeatedly look up specific library functions, and you can write more compact and efficient code more quickly. Without einsum, it is easy to introduce unnecessary reshapes or transposes, as well as intermediate tensors that could have been avoided.
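As a taste of einsum in PyTorch, the sketch below expresses the products covered in this article with torch.einsum and checks each one against the dedicated function. The subscript strings are standard einsum notation; the tensors are random and their sizes are arbitrary.

```python
import torch

a = torch.randn(4)
b = torch.randn(4)
M = torch.randn(3, 4)
N = torch.randn(4, 5)
X = torch.randn(2, 3, 4)
Y = torch.randn(2, 4, 5)

# Letters name axes; a letter that appears in the inputs but not after "->"
# is summed over.
dot_out = torch.einsum("i,i->", a, b)          # like torch.dot(a, b)
outer_out = torch.einsum("i,j->ij", a, b)      # like torch.outer(a, b)
mv_out = torch.einsum("ij,j->i", M, a)         # like torch.mv(M, a)
mm_out = torch.einsum("ij,jk->ik", M, N)       # like torch.mm(M, N)
bmm_out = torch.einsum("bij,bjk->bik", X, Y)   # like torch.bmm(X, Y)
mul_out = torch.einsum("ij,ij->ij", M, M)      # like torch.mul(M, M)
t_out = torch.einsum("ij->ji", M)              # like M.t()

assert torch.allclose(dot_out, torch.dot(a, b), atol=1e-5)
assert torch.allclose(outer_out, torch.outer(a, b), atol=1e-5)
assert torch.allclose(mv_out, torch.mv(M, a), atol=1e-5)
assert torch.allclose(mm_out, torch.mm(M, N), atol=1e-5)
assert torch.allclose(bmm_out, torch.bmm(X, Y), atol=1e-5)
assert torch.allclose(mul_out, torch.mul(M, M), atol=1e-5)
assert torch.equal(t_out, M.t())
print(mm_out.shape)  # torch.Size([3, 5])
```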

Just a record of my learning!