Examples of dimension expansion and broadcasting operations on tensors in PyTorch

Broadcast operations allow you to perform element-wise operations on differently shaped tensors without explicit looping.
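
For instance, here is a minimal sketch (not part of the original example): a (3, 1) tensor and a (2,) tensor are combined element-wise, and the size-1 axes are stretched automatically:

import torch

a = torch.arange(3.0).reshape(3, 1)   # shape (3, 1)
b = torch.arange(2.0)                 # shape (2,)
c = a + b                             # broadcasting: (3, 1) + (2,) -> (3, 2)
print(c.shape)                        # torch.Size([3, 2])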

A practical example of discretizing molecular coordinates:

import torch
import torch.nn.functional as F

def cdists(mols, grid):
    '''
    Calculates the pairwise Euclidean distances between a set of molecules and a list
    of positions on a grid (uses inplace operations to minimize memory demands).

    Args:
        mols (torch.Tensor): data set (of molecules) with shape
            (batch_size x n_atoms x n_dims)
        grid (torch.Tensor): array (of positions) with shape (n_positions x n_dims)

    Returns:
        torch.Tensor: batch of distance matrices (batch_size x n_atoms x n_positions)
    '''
    if len(mols.size()) == len(grid.size()) + 1:
        grid = grid.unsqueeze(0) # add batch dimension
    return F.relu(torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1),
                  inplace=True).sqrt_()

So, why does the above code perform an operation like (mols[:, :, None, :] - grid[:, None, :, :])?

This code calculates the Euclidean distance between each atom in a batch of molecules (mols) and each point in a set of grid positions (grid). The expression (mols[:, :, None, :] - grid[:, None, :, :]) relies on tensor broadcasting; its purpose is to form the coordinate difference between every atom and every grid position in a single vectorized operation.
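
As a side note (a small sketch, not from the original code), indexing with None is equivalent to unsqueeze, so the indexing tricks used above simply insert size-1 axes:

import torch

mols = torch.rand(2, 5, 3)                                   # (batch_size, n_atoms, n_dims), made-up sizes
print(mols[:, :, None, :].shape)                             # torch.Size([2, 5, 1, 3])
print(torch.equal(mols[:, :, None, :], mols.unsqueeze(2)))   # True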

How this code works:

  1. The mols tensor has shape (batch_size x n_atoms x n_dims), where batch_size is the batch size, n_atoms is the number of atoms, and n_dims is the dimensionality of the atomic coordinates (usually 3, for xyz coordinates).
  2. The grid tensor has shape (n_positions x n_dims), where n_positions is the number of grid points and n_dims is again the coordinate dimensionality.

First, if the mols tensor has exactly one more dimension than the grid tensor, the code adds a leading size-1 dimension to grid via grid.unsqueeze(0) so that grid gains a batch dimension matching mols. This is what allows the subsequent indexing and broadcasting to work.
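
For example (a minimal sketch with made-up sizes), unsqueeze(0) only inserts a size-1 axis; no data is copied:

import torch

grid = torch.rand(10, 3)    # (n_positions, n_dims)
grid = grid.unsqueeze(0)    # (1, n_positions, n_dims); the size-1 batch axis broadcasts over batch_size
print(grid.shape)           # torch.Size([1, 10, 3])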

Next, the code uses broadcasting to compute the distance between each atom of each molecule and each grid position. Broadcasting performs element-wise operations on tensors of different (but compatible) shapes without explicit looping.

  • mols[:, :, None, :] has shape (batch_size x n_atoms x 1 x n_dims). The None inserts a new size-1 axis after the n_atoms dimension, which will broadcast against the n_positions dimension of grid[:, None, :, :].

  • grid[:, None, :, :] has shape (1 x 1 x n_positions x n_dims). After the unsqueeze, grid already carries a size-1 batch dimension; the None inserts an additional size-1 axis in the n_atoms position so that it can broadcast against mols[:, :, None, :].

  • The two tensors are then subtracted element-wise; broadcasting expands the size-1 axes so that the result holds the coordinate difference between every atom of every molecule and every grid point. The result is a tensor of shape (batch_size x n_atoms x n_positions x n_dims).

  • Finally, the differences are squared in place with .pow_(2), summed over the coordinate dimension with torch.sum(..., -1), clamped to be non-negative with F.relu (a numerical safeguard before taking the root), and turned into Euclidean distances in place with .sqrt_().

In summary, this code efficiently calculates the distance between each atom of each molecule and each grid position via broadcasting, without explicit loops. This helps improve computational efficiency, especially when working with large-scale data.
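
As a sanity check (a sketch not in the original article; it assumes the cdists function defined above is in scope), the same distances can be obtained with torch.cdist, which computes batched pairwise Euclidean distances directly:

import torch

mols = torch.rand(2, 5, 3)      # (batch_size, n_atoms, n_dims)
grid = torch.rand(10, 3)        # (n_positions, n_dims)
expected = torch.cdist(mols, grid.unsqueeze(0).expand(mols.size(0), -1, -1))
print(torch.allclose(cdists(mols, grid), expected, atol=1e-5))   # True (up to floating-point error)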

Now let's walk through the computation with randomly generated data:

grid_test.py:

import torch
import torch.nn.functional as F

# Sample data
batch_size = 2
n_atoms = 5
n_dims = 3
n_positions = 10

mols = torch.rand(batch_size, n_atoms, n_dims) # Randomly generate molecular coordinate data
grid = torch.rand(n_positions, n_dims) # Randomly generate grid position data

print("batch_size = ", batch_size," n_atoms = ", n_atoms," n_dims = ",n_dims," n_positions = ",n_positions)

# Print sample data
print("Example data mols:")
print("mols = torch.rand(batch_size, n_atoms, n_dims)")
print(mols)
print(mols.shape)

print("\nExample data grid:")
print("grid = torch.rand(n_positions, n_dims)")
print(grid)
print(grid.shape)

# If the dimensions of the mols tensor are 1 more than the dimensions of the grid tensor, add an extra dimension
if len(mols.size()) == len(grid.size()) + 1:
    grid = grid.unsqueeze(0)
    print("\ngrid after adding extra dimensions:")
    print(grid)
    print(grid.shape)

print("\nmols[:, :, None, :]")
print(mols[:, :, None, :])
print(mols[:, :, None, :].shape)

print("\ngrid[:, None, :, :]")
print(grid[:, None, :, :])
print(grid[:, None, :, :].shape)

print("\nmols[:, :, None, :] - grid[:, None, :, :]")
print(mols[:, :, None, :] - grid[:, None, :, :])
print((mols[:, :, None, :] - grid[:, None, :, :]).shape)

print("\n(mols[:, :, None, :] - grid[:, None, :, :]).pow_(2)")
print((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2))

print("\ntorch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1)")
print(torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1))
print((torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1)).shape)

print("\nF.relu(torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1), inplace=True)")
print(F.relu(torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1), inplace=True))
print((F.relu(torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1), inplace=True)).shape)

print("\nF.relu(torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1), inplace=True).sqrt_()")
print(F.relu(torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1), inplace=True).sqrt_())
print((F.relu(torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1), inplace=True).sqrt_()).shape)

# Calculate the distance between each atom of each molecule and each position
result = F.relu(torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1), inplace=True).sqrt_()

# Print calculation results
print("\nCalculation results:")
print(result)

Output:

$ python grid_test.py 
batch_size = 2 n_atoms = 5 n_dims = 3 n_positions = 10
Example data mols:
mols = torch.rand(batch_size, n_atoms, n_dims)
tensor([[[0.3787, 0.1093, 0.5062],
         [0.3149, 0.4295, 0.1202],
         [0.6499, 0.6533, 0.6489],
         [0.9395, 0.5027, 0.7664],
         [0.5991, 0.5733, 0.6474]],

        [[0.1370, 0.3499, 0.7365],
         [0.3564, 0.4096, 0.1820],
         [0.2576, 0.4737, 0.2487],
         [0.3169, 0.5875, 0.0414],
         [0.9958, 0.2101, 0.3953]]])
torch.Size([2, 5, 3])

Example data grid:
grid = torch.rand(n_positions, n_dims)
tensor([[0.0985, 0.5795, 0.3998],
        [0.3772, 0.8160, 0.7968],
        [0.9571, 0.0205, 0.5068],
        [0.7847, 0.9675, 0.3421],
        [0.9007, 0.0692, 0.3701],
        [0.8763, 0.4045, 0.2783],
        [0.5665, 0.1797, 0.8626],
        [0.4253, 0.6738, 0.3789],
        [0.3690, 0.3504, 0.3530],
        [0.1773, 0.4790, 0.9227]])
torch.Size([10, 3])

grid after adding extra dimensions:
tensor([[[0.0985, 0.5795, 0.3998],
         [0.3772, 0.8160, 0.7968],
         [0.9571, 0.0205, 0.5068],
         [0.7847, 0.9675, 0.3421],
         [0.9007, 0.0692, 0.3701],
         [0.8763, 0.4045, 0.2783],
         [0.5665, 0.1797, 0.8626],
         [0.4253, 0.6738, 0.3789],
         [0.3690, 0.3504, 0.3530],
         [0.1773, 0.4790, 0.9227]]])
torch.Size([1, 10, 3])

mols[:, :, None, :]
tensor([[[[0.3787, 0.1093, 0.5062]],

         [[0.3149, 0.4295, 0.1202]],

         [[0.6499, 0.6533, 0.6489]],

         [[0.9395, 0.5027, 0.7664]],

         [[0.5991, 0.5733, 0.6474]]],


        [[[0.1370, 0.3499, 0.7365]],

         [[0.3564, 0.4096, 0.1820]],

         [[0.2576, 0.4737, 0.2487]],

         [[0.3169, 0.5875, 0.0414]],

         [[0.9958, 0.2101, 0.3953]]]])
torch.Size([2, 5, 1, 3])

grid[:, None, :, :]
tensor([[[[0.0985, 0.5795, 0.3998],
          [0.3772, 0.8160, 0.7968],
          [0.9571, 0.0205, 0.5068],
          [0.7847, 0.9675, 0.3421],
          [0.9007, 0.0692, 0.3701],
          [0.8763, 0.4045, 0.2783],
          [0.5665, 0.1797, 0.8626],
          [0.4253, 0.6738, 0.3789],
          [0.3690, 0.3504, 0.3530],
          [0.1773, 0.4790, 0.9227]]]])
torch.Size([1, 1, 10, 3])

mols[:, :, None, :] - grid[:, None, :, :]
tensor([[[[ 2.8014e-01, -4.7024e-01, 1.0644e-01],
          [1.4571e-03, -7.0671e-01, -2.9065e-01],
          [-5.7839e-01, 8.8782e-02, -5.7650e-04],
          [-4.0609e-01, -8.5820e-01, 1.6411e-01],
          [-5.2208e-01, 4.0115e-02, 1.3611e-01],
          [-4.9762e-01, -2.9525e-01, 2.2792e-01],
          [-1.8779e-01, -7.0417e-02, -3.5645e-01],
          [-4.6644e-02, -5.6448e-01, 1.2727e-01],
          [9.6272e-03, -2.4109e-01, 1.5320e-01],
          [2.0133e-01, -3.6968e-01, -4.1653e-01]],

         [[ 2.1635e-01, -1.4999e-01, -2.7958e-01],
          [-6.2337e-02, -3.8646e-01, -6.7667e-01],
          [-6.4219e-01, 4.0903e-01, -3.8660e-01],
          [-4.6988e-01, -5.3796e-01, -2.2191e-01],
          [-5.8587e-01, 3.6036e-01, -2.4991e-01],
          [-5.6142e-01, 2.4999e-02, -1.5810e-01],
          [-2.5159e-01, 2.4983e-01, -7.4247e-01],
          [-1.1044e-01, -2.4423e-01, -2.5875e-01],
          [-5.4167e-02, 7.9156e-02, -2.3282e-01],
          [1.3754e-01, -4.9439e-02, -8.0255e-01]],

         [[ 5.5133e-01, 7.3810e-02, 2.4912e-01],
          [2.7265e-01, -1.6266e-01, -1.4796e-01],
          [-3.0720e-01, 6.3283e-01, 1.4211e-01],
          [-1.3490e-01, -3.1416e-01, 3.0679e-01],
          [-2.5089e-01, 5.8416e-01, 2.7880e-01],
          [-2.2644e-01, 2.4880e-01, 3.7061e-01],
          [8.3396e-02, 4.7363e-01, -2.1377e-01],
          [2.2454e-01, -2.0431e-02, 2.6995e-01],
          [2.8082e-01, 3.0296e-01, 2.9588e-01],
          [4.7252e-01, 1.7436e-01, -2.7384e-01]],

         [[ 8.4102e-01, -7.6851e-02, 3.6667e-01],
          [5.6233e-01, -3.1332e-01, -3.0420e-02],
          [-1.7516e-02, 4.8217e-01, 2.5965e-01],
          [1.5479e-01, -4.6482e-01, 4.2434e-01],
          [3.8800e-02, 4.3350e-01, 3.9634e-01],
          [6.3253e-02, 9.8140e-02, 4.8815e-01],
          [3.7309e-01, 3.2297e-01, -9.6225e-02],
          [5.1423e-01, -1.7109e-01, 3.8750e-01],
          [5.7050e-01, 1.5230e-01, 4.1342e-01],
          [7.6221e-01, 2.3702e-02, -1.5630e-01]],

         [[ 5.0062e-01, -6.2699e-03, 2.4768e-01],
          [2.2193e-01, -2.4274e-01, -1.4941e-01],
          [-3.5792e-01, 5.5275e-01, 1.4067e-01],
          [-1.8561e-01, -3.9424e-01, 3.0535e-01],
          [-3.0160e-01, 5.0408e-01, 2.7735e-01],
          [-2.7715e-01, 1.6872e-01, 3.6916e-01],
          [3.2683e-02, 3.9355e-01, -2.1521e-01],
          [1.7383e-01, -1.0051e-01, 2.6851e-01],
          [2.3010e-01, 2.2288e-01, 2.9444e-01],
          [4.2181e-01, 9.4283e-02, -2.7528e-01]]],


        [[[ 3.8479e-02, -2.2963e-01, 3.3670e-01],
          [-2.4021e-01, -4.6610e-01, -6.0388e-02],
          [-8.2006e-01, 3.2939e-01, 2.2968e-01],
          [-6.4775e-01, -6.1759e-01, 3.9437e-01],
          [-7.6374e-01, 2.8072e-01, 3.6637e-01],
          [-7.3929e-01, -5.4636e-02, 4.5818e-01],
          [-4.2946e-01, 1.7019e-01, -1.2619e-01],
          [-2.8831e-01, -3.2387e-01, 3.5753e-01],
          [-2.3204e-01, -4.7952e-04, 3.8346e-01],
          [-4.0334e-02, -1.2907e-01, -1.8627e-01]],

         [[ 2.5793e-01, -1.6992e-01, -2.1771e-01],
          [-2.0759e-02, -4.0639e-01, -6.1480e-01],
          [-6.0061e-01, 3.8910e-01, -3.2472e-01],
          [-4.2830e-01, -5.5788e-01, -1.6004e-01],
          [-5.4429e-01, 3.4043e-01, -1.8804e-01],
          [-5.1984e-01, 5.0723e-03, -9.6227e-02],
          [-2.1001e-01, 2.2990e-01, -6.8060e-01],
          [-6.8861e-02, -2.6416e-01, -1.9688e-01],
          [-1.2589e-02, 5.9229e-02, -1.7095e-01],
          [1.7911e-01, -6.9366e-02, -7.4067e-01]],

         [[ 1.5904e-01, -1.0583e-01, -1.5102e-01],
          [-1.1965e-01, -3.4230e-01, -5.4810e-01],
          [-6.9950e-01, 4.5319e-01, -2.5803e-01],
          [-5.2719e-01, -4.9380e-01, -9.3347e-02],
          [-6.4318e-01, 4.0452e-01, -1.2134e-01],
          [-6.1873e-01, 6.9158e-02, -2.9533e-02],
          [-3.0890e-01, 2.9399e-01, -6.1391e-01],
          [-1.6775e-01, -2.0007e-01, -1.3018e-01],
          [-1.1148e-01, 1.2331e-01, -1.0426e-01],
          [8.0225e-02, -5.2802e-03, -6.7398e-01]],

         [[ 2.1837e-01, 8.0081e-03, -3.5833e-01],
          [-6.0314e-02, -2.2846e-01, -7.5541e-01],
          [-6.4016e-01, 5.6703e-01, -4.6534e-01],
          [-4.6786e-01, -3.7996e-01, -3.0066e-01],
          [-5.8385e-01, 5.1836e-01, -3.2865e-01],
          [-5.5940e-01, 1.8300e-01, -2.3684e-01],
          [-2.4956e-01, 4.0783e-01, -8.2122e-01],
          [-1.0842e-01, -8.6233e-02, -3.3749e-01],
          [-5.2144e-02, 2.3716e-01, -3.1157e-01],
          [1.3956e-01, 1.0856e-01, -8.8129e-01]],

         [[ 8.9723e-01, -3.6938e-01, -4.4795e-03],
          [6.1855e-01, -6.0585e-01, -4.0157e-01],
          [3.8695e-02, 1.8964e-01, -1.1149e-01],
          [2.1100e-01, -7.5734e-01, 5.3189e-02],
          [9.5011e-02, 1.4097e-01, 2.5192e-02],
          [1.1946e-01, -1.9439e-01, 1.1700e-01],
          [4.2930e-01, 3.0443e-02, -4.6737e-01],
          [5.7044e-01, -4.6362e-01, 1.6351e-02],
          [6.2672e-01, -1.4023e-01, 4.2277e-02],
          [8.1842e-01, -2.6882e-01, -5.2744e-01]]]])
torch.Size([2, 5, 10, 3])

(mols[:, :, None, :] - grid[:, None, :, :]).pow_(2)
tensor([[[[7.8481e-02, 2.2112e-01, 1.1329e-02],
          [2.1231e-06, 4.9944e-01, 8.4477e-02],
          [3.3454e-01, 7.8823e-03, 3.3235e-07],
          [1.6491e-01, 7.3651e-01, 2.6931e-02],
          [2.7257e-01, 1.6092e-03, 1.8526e-02],
          [2.4763e-01, 8.7170e-02, 5.1948e-02],
          [3.5266e-02, 4.9586e-03, 1.2706e-01],
          [2.1757e-03, 3.1863e-01, 1.6198e-02],
          [9.2683e-05, 5.8124e-02, 2.3469e-02],
          [4.0534e-02, 1.3667e-01, 1.7349e-01]],

         [[4.6807e-02, 2.2498e-02, 7.8165e-02],
          [3.8859e-03, 1.4935e-01, 4.5788e-01],
          [4.1240e-01, 1.6730e-01, 1.4946e-01],
          [2.2079e-01, 2.8940e-01, 4.9245e-02],
          [3.4325e-01, 1.2986e-01, 6.2455e-02],
          [3.1519e-01, 6.2496e-04, 2.4995e-02],
          [6.3296e-02, 6.2414e-02, 5.5127e-01],
          [1.2197e-02, 5.9650e-02, 6.6951e-02],
          [2.9340e-03, 6.2657e-03, 5.4207e-02],
          [1.8916e-02, 2.4442e-03, 6.4408e-01]],

         [[3.0397e-01, 5.4479e-03, 6.2063e-02],
          [7.4335e-02, 2.6459e-02, 2.1893e-02],
          [9.4375e-02, 4.0047e-01, 2.0195e-02],
          [1.8198e-02, 9.8694e-02, 9.4122e-02],
          [6.2946e-02, 3.4124e-01, 7.7727e-02],
          [5.1273e-02, 6.1902e-02, 1.3735e-01],
          [6.9548e-03, 2.2433e-01, 4.5697e-02],
          [5.0420e-02, 4.1742e-04, 7.2876e-02],
          [7.8857e-02, 9.1784e-02, 8.7545e-02],
          [2.2327e-01, 3.0403e-02, 7.4989e-02]],

         [[7.0732e-01, 5.9061e-03, 1.3445e-01],
          [3.1622e-01, 9.8171e-02, 9.2536e-04],
          [3.0679e-04, 2.3249e-01, 6.7419e-02],
          [2.3960e-02, 2.1606e-01, 1.8006e-01],
          [1.5054e-03, 1.8792e-01, 1.5708e-01],
          [4.0009e-03, 9.6314e-03, 2.3829e-01],
          [1.3919e-01, 1.0431e-01, 9.2593e-03],
          [2.6444e-01, 2.9273e-02, 1.5016e-01],
          [3.2548e-01, 2.3194e-02, 1.7092e-01],
          [5.8096e-01, 5.6178e-04, 2.4429e-02]],

         [[2.5062e-01, 3.9311e-05, 6.1346e-02],
          [4.9254e-02, 5.8923e-02, 2.2322e-02],
          [1.2810e-01, 3.0553e-01, 1.9787e-02],
          [3.4452e-02, 1.5542e-01, 9.3239e-02],
          [9.0964e-02, 2.5410e-01, 7.6924e-02],
          [7.6812e-02, 2.8467e-02, 1.3628e-01],
          [1.0682e-03, 1.5488e-01, 4.6316e-02],
          [3.0217e-02, 1.0102e-02, 7.2099e-02],
          [5.2947e-02, 4.9675e-02, 8.6694e-02],
          [1.7792e-01, 8.8893e-03, 7.5782e-02]]],


        [[[1.4806e-03, 5.2729e-02, 1.1337e-01],
          [5.7700e-02, 2.1725e-01, 3.6467e-03],
          [6.7250e-01, 1.0850e-01, 5.2755e-02],
          [4.1958e-01, 3.8142e-01, 1.5553e-01],
          [5.8330e-01, 7.8806e-02, 1.3423e-01],
          [5.4655e-01, 2.9851e-03, 2.0993e-01],
          [1.8443e-01, 2.8965e-02, 1.5925e-02],
          [8.3122e-02, 1.0489e-01, 1.2783e-01],
          [5.3842e-02, 2.2994e-07, 1.4704e-01],
          [1.6269e-03, 1.6660e-02, 3.4695e-02]],

         [[6.6527e-02, 2.8872e-02, 4.7397e-02],
          [4.3095e-04, 1.6515e-01, 3.7797e-01],
          [3.6073e-01, 1.5140e-01, 1.0545e-01],
          [1.8344e-01, 3.1124e-01, 2.5613e-02],
          [2.9626e-01, 1.1589e-01, 3.5358e-02],
          [2.7023e-01, 2.5728e-05, 9.2596e-03],
          [4.4104e-02, 5.2854e-02, 4.6322e-01],
          [4.7418e-03, 6.9780e-02, 3.8761e-02],
          [1.5849e-04, 3.5081e-03, 2.9225e-02],
          [3.2082e-02, 4.8116e-03, 5.4860e-01]],

         [[2.5293e-02, 1.1201e-02, 2.2806e-02],
          [1.4316e-02, 1.1717e-01, 3.0042e-01],
          [4.8930e-01, 2.0538e-01, 6.6580e-02],
          [2.7793e-01, 2.4384e-01, 8.7136e-03],
          [4.1369e-01, 1.6363e-01, 1.4724e-02],
          [3.8283e-01, 4.7828e-03, 8.7222e-04],
          [9.5418e-02, 8.6428e-02, 3.7688e-01],
          [2.8140e-02, 4.0030e-02, 1.6948e-02],
          [1.2428e-02, 1.5206e-02, 1.0870e-02],
          [6.4360e-03, 2.7880e-05, 4.5425e-01]],

         [[4.7686e-02, 6.4129e-05, 1.2840e-01],
          [3.6378e-03, 5.2195e-02, 5.7065e-01],
          [4.0981e-01, 3.2152e-01, 2.1654e-01],
          [2.1889e-01, 1.4437e-01, 9.0394e-02],
          [3.4088e-01, 2.6870e-01, 1.0801e-01],
          [3.1292e-01, 3.3489e-02, 5.6095e-02],
          [6.2282e-02, 1.6632e-01, 6.7440e-01],
          [1.1754e-02, 7.4361e-03, 1.1390e-01],
          [2.7190e-03, 5.6243e-02, 9.7075e-02],
          [1.9477e-02, 1.1786e-02, 7.7667e-01]],

         [[8.0503e-01, 1.3644e-01, 2.0066e-05],
          [3.8260e-01, 3.6705e-01, 1.6126e-01],
          [1.4973e-03, 3.5964e-02, 1.2431e-02],
          [4.4522e-02, 5.7357e-01, 2.8291e-03],
          [9.0270e-03, 1.9874e-02, 6.3462e-04],
          [1.4272e-02, 3.7786e-02, 1.3690e-02],
          [1.8430e-01, 9.2680e-04, 2.1844e-01],
          [3.2541e-01, 2.1494e-01, 2.6736e-04],
          [3.9277e-01, 1.9664e-02, 1.7874e-03],
          [6.6981e-01, 7.2266e-02, 2.7820e-01]]]])

torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1)
tensor([[[0.3109, 0.5839, 0.3424, 0.9284, 0.2927, 0.3867, 0.1673, 0.3370,
          0.0817, 0.3507],
         [0.1475, 0.6111, 0.7292, 0.5594, 0.5356, 0.3408, 0.6770, 0.1388,
          0.0634, 0.6654],
         [0.3715, 0.1227, 0.5150, 0.2110, 0.4819, 0.2505, 0.2770, 0.1237,
          0.2582, 0.3287],
         [0.8477, 0.4153, 0.3002, 0.4201, 0.3465, 0.2519, 0.2528, 0.4439,
          0.5196, 0.6060],
         [0.3120, 0.1305, 0.4534, 0.2831, 0.4220, 0.2416, 0.2023, 0.1124,
          0.1893, 0.2626]],

        [[0.1676, 0.2786, 0.8337, 0.9565, 0.7963, 0.7595, 0.2293, 0.3158,
          0.2009, 0.0530],
         [0.1428, 0.5436, 0.6176, 0.5203, 0.4475, 0.2795, 0.5602, 0.1133,
          0.0329, 0.5855],
         [0.0593, 0.4319, 0.7613, 0.5305, 0.5920, 0.3885, 0.5587, 0.0851,
          0.0385, 0.4607],
         [0.1761, 0.6265, 0.9479, 0.4537, 0.7176, 0.4025, 0.9030, 0.1331,
          0.1560, 0.8079],
         [0.9415, 0.9109, 0.0499, 0.6209, 0.0295, 0.0657, 0.4037, 0.5406,
          0.4142, 1.0203]]])
torch.Size([2, 5, 10])

F.relu(torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1), inplace=True)
tensor([[[0.3109, 0.5839, 0.3424, 0.9284, 0.2927, 0.3867, 0.1673, 0.3370,
          0.0817, 0.3507],
         [0.1475, 0.6111, 0.7292, 0.5594, 0.5356, 0.3408, 0.6770, 0.1388,
          0.0634, 0.6654],
         [0.3715, 0.1227, 0.5150, 0.2110, 0.4819, 0.2505, 0.2770, 0.1237,
          0.2582, 0.3287],
         [0.8477, 0.4153, 0.3002, 0.4201, 0.3465, 0.2519, 0.2528, 0.4439,
          0.5196, 0.6060],
         [0.3120, 0.1305, 0.4534, 0.2831, 0.4220, 0.2416, 0.2023, 0.1124,
          0.1893, 0.2626]],

        [[0.1676, 0.2786, 0.8337, 0.9565, 0.7963, 0.7595, 0.2293, 0.3158,
          0.2009, 0.0530],
         [0.1428, 0.5436, 0.6176, 0.5203, 0.4475, 0.2795, 0.5602, 0.1133,
          0.0329, 0.5855],
         [0.0593, 0.4319, 0.7613, 0.5305, 0.5920, 0.3885, 0.5587, 0.0851,
          0.0385, 0.4607],
         [0.1761, 0.6265, 0.9479, 0.4537, 0.7176, 0.4025, 0.9030, 0.1331,
          0.1560, 0.8079],
         [0.9415, 0.9109, 0.0499, 0.6209, 0.0295, 0.0657, 0.4037, 0.5406,
          0.4142, 1.0203]]])
torch.Size([2, 5, 10])

F.relu(torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1), inplace=True).sqrt_()
tensor([[[0.5576, 0.7641, 0.5852, 0.9635, 0.5410, 0.6219, 0.4090, 0.5805,
          0.2858, 0.5922],
         [0.3840, 0.7817, 0.8539, 0.7480, 0.7318, 0.5838, 0.8228, 0.3726,
          0.2518, 0.8157],
         [0.6095, 0.3503, 0.7177, 0.4594, 0.6942, 0.5005, 0.5263, 0.3517,
          0.5081, 0.5733],
         [0.9207, 0.6445, 0.5479, 0.6481, 0.5887, 0.5019, 0.5028, 0.6662,
          0.7208, 0.7784],
         [0.5586, 0.3612, 0.6734, 0.5321, 0.6496, 0.4915, 0.4497, 0.3353,
          0.4351, 0.5124]],

        [[0.4094, 0.5278, 0.9131, 0.9780, 0.8924, 0.8715, 0.4789, 0.5620,
          0.4482, 0.2302],
         [0.3779, 0.7373, 0.7859, 0.7213, 0.6690, 0.5287, 0.7484, 0.3366,
          0.1814, 0.7652],
         [0.2435, 0.6572, 0.8725, 0.7283, 0.7694, 0.6233, 0.7475, 0.2918,
          0.1962, 0.6788],
         [0.4197, 0.7915, 0.9736, 0.6735, 0.8471, 0.6344, 0.9503, 0.3648,
          0.3950, 0.8989],
         [0.9703, 0.9544, 0.2234, 0.7880, 0.1719, 0.2564, 0.6353, 0.7353,
          0.6436, 1.0101]]])
torch.Size([2, 5, 10])

Calculation results:
tensor([[[0.5576, 0.7641, 0.5852, 0.9635, 0.5410, 0.6219, 0.4090, 0.5805,
          0.2858, 0.5922],
         [0.3840, 0.7817, 0.8539, 0.7480, 0.7318, 0.5838, 0.8228, 0.3726,
          0.2518, 0.8157],
         [0.6095, 0.3503, 0.7177, 0.4594, 0.6942, 0.5005, 0.5263, 0.3517,
          0.5081, 0.5733],
         [0.9207, 0.6445, 0.5479, 0.6481, 0.5887, 0.5019, 0.5028, 0.6662,
          0.7208, 0.7784],
         [0.5586, 0.3612, 0.6734, 0.5321, 0.6496, 0.4915, 0.4497, 0.3353,
          0.4351, 0.5124]],

        [[0.4094, 0.5278, 0.9131, 0.9780, 0.8924, 0.8715, 0.4789, 0.5620,
          0.4482, 0.2302],
         [0.3779, 0.7373, 0.7859, 0.7213, 0.6690, 0.5287, 0.7484, 0.3366,
          0.1814, 0.7652],
         [0.2435, 0.6572, 0.8725, 0.7283, 0.7694, 0.6233, 0.7475, 0.2918,
          0.1962, 0.6788],
         [0.4197, 0.7915, 0.9736, 0.6735, 0.8471, 0.6344, 0.9503, 0.3648,
          0.3950, 0.8989],
         [0.9703, 0.9544, 0.2234, 0.7880, 0.1719, 0.2564, 0.6353, 0.7353,
          0.6436, 1.0101]]])
