Broadcast operations allow you to perform element-wise operations on differently shaped tensors without explicit looping.
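To make the shape rule behind broadcasting concrete, here is a minimal pure-Python sketch. The helper name broadcast_shape is hypothetical (it is not a PyTorch function); it only mimics the usual rule that shapes are aligned from the last dimension and that each pair of sizes must either match or contain a 1.

```python
def broadcast_shape(shape_a, shape_b):
    """Return the shape that results from broadcasting two shapes.

    Hypothetical helper for illustration only: shapes are aligned from
    the right, and each pair of sizes must be equal or contain a 1.
    """
    result = []
    # Walk both shapes from the last dimension to the first.
    for i in range(1, max(len(shape_a), len(shape_b)) + 1):
        a = shape_a[-i] if i <= len(shape_a) else 1  # missing dims count as 1
        b = shape_b[-i] if i <= len(shape_b) else 1
        if a != b and a != 1 and b != 1:
            raise ValueError(f"cannot broadcast {shape_a} with {shape_b}")
        result.append(b if a == 1 else a)
    return tuple(reversed(result))


print(broadcast_shape((3, 1), (1, 4)))               # (3, 4)
print(broadcast_shape((2, 5, 1, 3), (1, 1, 10, 3)))  # (2, 5, 10, 3)
```

The second call uses sizes like those in the molecule example: every size-1 dimension is stretched to match its partner, and no data needs to be copied by hand.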
A practical example of discretizing molecular coordinates:
import torch
import torch.nn.functional as F


def cdists(mols, grid):
    '''
    Calculates the pairwise Euclidean distances between a set of molecules and a
    list of positions on a grid (uses inplace operations to minimize memory demands).

    Args:
        mols (torch.Tensor): data set (of molecules) with shape
            (batch_size x n_atoms x n_dims)
        grid (torch.Tensor): array (of positions) with shape
            (n_positions x n_dims)

    Returns:
        torch.Tensor: batch of distance matrices
            (batch_size x n_atoms x n_positions)
    '''
    if len(mols.size()) == len(grid.size()) + 1:
        grid = grid.unsqueeze(0)  # add batch dimension
    return F.relu(torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1),
                  inplace=True).sqrt_()
So why does the code above perform an operation like (mols[:, :, None, :] - grid[:, None, :, :])?
This code calculates the Euclidean distance between a set of molecules (mols) and a set of positions on a grid (grid). The expression (mols[:, :, None, :] - grid[:, None, :, :]) relies on tensor broadcasting; its purpose is to compute the coordinate difference between every atom and every grid position, from which the distances are then obtained.
How this code works:
- The shape of the mols tensor is (batch_size x n_atoms x n_dims), where batch_size is the batch size, n_atoms is the number of atoms, and n_dims is the coordinate dimension of each atom (usually 3, for xyz coordinates).
- The shape of the grid tensor is (n_positions x n_dims), where n_positions is the number of grid positions and n_dims is, again, the coordinate dimension.
First, if the mols tensor has one more dimension than the grid tensor, the code adds a leading dimension via grid.unsqueeze(0) so that grid lines up with the batch dimension of mols. This prepares the shapes for broadcasting.
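The effect of this check can be sketched on small tensors. The sizes below are arbitrary, and mols.dim() is used as a shorthand for len(mols.size()). Note that the later index grid[:, None, :, :] contains three slices, so it needs a 3-dimensional grid to work at all.

```python
import torch

mols = torch.rand(2, 5, 3)   # (batch_size, n_atoms, n_dims)
grid = torch.rand(10, 3)     # (n_positions, n_dims)

# Same condition as in cdists: mols has exactly one more dimension than grid.
if mols.dim() == grid.dim() + 1:
    grid = grid.unsqueeze(0)  # insert a new leading axis of size 1

print(grid.shape)  # torch.Size([1, 10, 3])
```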
Next, the code uses broadcasting to compute the difference between each atom of each molecule and each grid position, and then turns those differences into distances, in the following steps:
- The shape of mols[:, :, None, :] becomes (batch_size x n_atoms x 1 x n_dims). Indexing with None inserts a size-1 dimension after the n_atoms dimension so that the tensor can broadcast against grid[:, None, :, :].
- The shape of grid[:, None, :, :] becomes (1 x 1 x n_positions x n_dims). The leading 1 is the batch dimension added by unsqueeze(0); None inserts a second size-1 dimension in the n_atoms position so that the tensor can broadcast against mols[:, :, None, :].
- The two tensors are then subtracted element by element, giving the coordinate difference between each atom of each molecule and each grid position. The result is a tensor with shape (batch_size x n_atoms x n_positions x n_dims).
- Finally, the differences are squared in place with .pow_(2), summed over the last dimension with torch.sum(..., -1), passed through F.relu (which leaves a sum of squares unchanged, since it cannot be negative), and square-rooted in place with .sqrt_(). The result is the Euclidean distance between each atom of each molecule and each grid position, with shape (batch_size x n_atoms x n_positions).
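The steps above can be sketched on small random tensors. This is an illustration rather than the original function: it uses out-of-place pow and sqrt for readability, the seed and sizes are arbitrary, and torch.cdist is used only as an independent cross-check of the result.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only to make the run repeatable
mols = torch.rand(2, 5, 3)    # (batch_size, n_atoms, n_dims)
grid = torch.rand(1, 10, 3)   # (1, n_positions, n_dims), batch axis already added

a = mols[:, :, None, :]       # (2, 5, 1, 3)
b = grid[:, None, :, :]       # (1, 1, 10, 3)
diff = a - b                  # broadcast to (2, 5, 10, 3)
dist = diff.pow(2).sum(-1).sqrt()  # (2, 5, 10)

# Cross-check against PyTorch's built-in pairwise distance,
# repeating the grid once per molecule so both inputs have a batch of 2.
reference = torch.cdist(mols, grid.repeat(2, 1, 1))

print(diff.shape, dist.shape)
print(torch.allclose(dist, reference, atol=1e-5))
```

If the two results agree, the broadcast expression and the built-in function are computing the same pairwise distances.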
In summary, this code calculates the distance between each atom of each molecule and each grid position through broadcasting, without the need for explicit loops. This helps improve computational efficiency, especially when working with large-scale data.
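For comparison, the explicit loops that broadcasting replaces might look like the following pure-Python sketch. The function name cdists_loop and the tiny hand-written coordinates are made up for illustration; math.dist requires Python 3.8 or later.

```python
import math


def cdists_loop(mols, grid):
    """Loop-based reference (hypothetical helper, not from the original code).

    mols: nested list with layout (batch x atoms x dims)
    grid: nested list with layout (positions x dims)
    Returns a nested list with layout (batch x atoms x positions).
    """
    return [
        [[math.dist(atom, pos) for pos in grid] for atom in mol]
        for mol in mols
    ]


mols = [[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]]                 # 1 molecule, 2 atoms
grid = [[0.0, 0.0, 0.0], [0.0, 3.0, 4.0], [1.0, 0.0, 0.0]]  # 3 positions
print(cdists_loop(mols, grid))
```

Three nested loops are needed here (molecules, atoms, positions), and each distance is computed one at a time in the interpreter; the broadcast version expresses the same computation as a handful of whole-tensor operations.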
Next, let's run the same steps on randomly generated data as an example:
grid_test.py:
import torch
import torch.nn.functional as F

# Sample data
batch_size = 2
n_atoms = 5
n_dims = 3
n_positions = 10
mols = torch.rand(batch_size, n_atoms, n_dims)  # Randomly generate molecular coordinate data
grid = torch.rand(n_positions, n_dims)  # Randomly generate grid position data
print("batch_size = ", batch_size, " n_atoms = ", n_atoms, " n_dims = ", n_dims, " n_positions = ", n_positions)

# Print sample data
print("Example data mols:")
print("mols = torch.rand(batch_size, n_atoms, n_dims)")
print(mols)
print(mols.shape)
print("\nExample data grid:")
print("grid = torch.rand(n_positions, n_dims)")
print(grid)
print(grid.shape)

# If the mols tensor has one more dimension than the grid tensor, add a batch dimension
if len(mols.size()) == len(grid.size()) + 1:
    grid = grid.unsqueeze(0)
    print("\ngrid after adding extra dimensions:")
    print(grid)
    print(grid.shape)

print("\nmols[:, :, None, :]")
print(mols[:, :, None, :])
print(mols[:, :, None, :].shape)
print("\ngrid[:, None, :, :]")
print(grid[:, None, :, :])
print(grid[:, None, :, :].shape)
print("\nmols[:, :, None, :] - grid[:, None, :, :]")
print(mols[:, :, None, :] - grid[:, None, :, :])
print((mols[:, :, None, :] - grid[:, None, :, :]).shape)
print("\n(mols[:, :, None, :] - grid[:, None, :, :]).pow_(2)")
print((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2))
print("\ntorch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1)")
print(torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1))
print((torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1)).shape)
print("\nF.relu(torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1), inplace=True)")
print(F.relu(torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1), inplace=True))
print((F.relu(torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1), inplace=True)).shape)
print("\nF.relu(torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1), inplace=True).sqrt_()")
print(F.relu(torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1), inplace=True).sqrt_())
print((F.relu(torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1), inplace=True).sqrt_()).shape)

# Calculate the distance between each atom of each molecule and each position
result = F.relu(torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1), inplace=True).sqrt_()

# Print calculation results
print("\nCalculation results:")
print(result)
Output results:
$ python grid_test.py
batch_size = 2 n_atoms = 5 n_dims = 3 n_positions = 10 Example data mols: mols = torch.rand(batch_size, n_atoms, n_dims) tensor([[[0.3787, 0.1093, 0.5062], [0.3149, 0.4295, 0.1202], [0.6499, 0.6533, 0.6489], [0.9395, 0.5027, 0.7664], [0.5991, 0.5733, 0.6474]], [[0.1370, 0.3499, 0.7365], [0.3564, 0.4096, 0.1820], [0.2576, 0.4737, 0.2487], [0.3169, 0.5875, 0.0414], [0.9958, 0.2101, 0.3953]]]) torch.Size([2, 5, 3]) Example data grid: grid = torch.rand(n_positions, n_dims) tensor([[0.0985, 0.5795, 0.3998], [0.3772, 0.8160, 0.7968], [0.9571, 0.0205, 0.5068], [0.7847, 0.9675, 0.3421], [0.9007, 0.0692, 0.3701], [0.8763, 0.4045, 0.2783], [0.5665, 0.1797, 0.8626], [0.4253, 0.6738, 0.3789], [0.3690, 0.3504, 0.3530], [0.1773, 0.4790, 0.9227]]) torch.Size([10, 3]) grid after adding extra dimensions: tensor([[[0.0985, 0.5795, 0.3998], [0.3772, 0.8160, 0.7968], [0.9571, 0.0205, 0.5068], [0.7847, 0.9675, 0.3421], [0.9007, 0.0692, 0.3701], [0.8763, 0.4045, 0.2783], [0.5665, 0.1797, 0.8626], [0.4253, 0.6738, 0.3789], [0.3690, 0.3504, 0.3530], [0.1773, 0.4790, 0.9227]]]) torch.Size([1, 10, 3]) mols[:, :, None, :] tensor([[[[0.3787, 0.1093, 0.5062]], [[0.3149, 0.4295, 0.1202]], [[0.6499, 0.6533, 0.6489]], [[0.9395, 0.5027, 0.7664]], [[0.5991, 0.5733, 0.6474]]], [[[0.1370, 0.3499, 0.7365]], [[0.3564, 0.4096, 0.1820]], [[0.2576, 0.4737, 0.2487]], [[0.3169, 0.5875, 0.0414]], [[0.9958, 0.2101, 0.3953]]]]) torch.Size([2, 5, 1, 3]) grid[:, None, :, :] tensor([[[[0.0985, 0.5795, 0.3998], [0.3772, 0.8160, 0.7968], [0.9571, 0.0205, 0.5068], [0.7847, 0.9675, 0.3421], [0.9007, 0.0692, 0.3701], [0.8763, 0.4045, 0.2783], [0.5665, 0.1797, 0.8626], [0.4253, 0.6738, 0.3789], [0.3690, 0.3504, 0.3530], [0.1773, 0.4790, 0.9227]]]]) torch.Size([1, 1, 10, 3]) mols[:, :, None, :] - grid[:, None, :, :] tensor([[[[ 2.8014e-01, -4.7024e-01, 1.0644e-01], [1.4571e-03, -7.0671e-01, -2.9065e-01], [-5.7839e-01, 8.8782e-02, -5.7650e-04], [-4.0609e-01, -8.5820e-01, 1.6411e-01], [-5.2208e-01, 4.0115e-02, 
1.3611e-01], [-4.9762e-01, -2.9525e-01, 2.2792e-01], [-1.8779e-01, -7.0417e-02, -3.5645e-01], [-4.6644e-02, -5.6448e-01, 1.2727e-01], [9.6272e-03, -2.4109e-01, 1.5320e-01], [2.0133e-01, -3.6968e-01, -4.1653e-01]], [[ 2.1635e-01, -1.4999e-01, -2.7958e-01], [-6.2337e-02, -3.8646e-01, -6.7667e-01], [-6.4219e-01, 4.0903e-01, -3.8660e-01], [-4.6988e-01, -5.3796e-01, -2.2191e-01], [-5.8587e-01, 3.6036e-01, -2.4991e-01], [-5.6142e-01, 2.4999e-02, -1.5810e-01], [-2.5159e-01, 2.4983e-01, -7.4247e-01], [-1.1044e-01, -2.4423e-01, -2.5875e-01], [-5.4167e-02, 7.9156e-02, -2.3282e-01], [1.3754e-01, -4.9439e-02, -8.0255e-01]], [[ 5.5133e-01, 7.3810e-02, 2.4912e-01], [2.7265e-01, -1.6266e-01, -1.4796e-01], [-3.0720e-01, 6.3283e-01, 1.4211e-01], [-1.3490e-01, -3.1416e-01, 3.0679e-01], [-2.5089e-01, 5.8416e-01, 2.7880e-01], [-2.2644e-01, 2.4880e-01, 3.7061e-01], [8.3396e-02, 4.7363e-01, -2.1377e-01], [2.2454e-01, -2.0431e-02, 2.6995e-01], [2.8082e-01, 3.0296e-01, 2.9588e-01], [4.7252e-01, 1.7436e-01, -2.7384e-01]], [[ 8.4102e-01, -7.6851e-02, 3.6667e-01], [5.6233e-01, -3.1332e-01, -3.0420e-02], [-1.7516e-02, 4.8217e-01, 2.5965e-01], [1.5479e-01, -4.6482e-01, 4.2434e-01], [3.8800e-02, 4.3350e-01, 3.9634e-01], [6.3253e-02, 9.8140e-02, 4.8815e-01], [3.7309e-01, 3.2297e-01, -9.6225e-02], [5.1423e-01, -1.7109e-01, 3.8750e-01], [5.7050e-01, 1.5230e-01, 4.1342e-01], [7.6221e-01, 2.3702e-02, -1.5630e-01]], [[ 5.0062e-01, -6.2699e-03, 2.4768e-01], [2.2193e-01, -2.4274e-01, -1.4941e-01], [-3.5792e-01, 5.5275e-01, 1.4067e-01], [-1.8561e-01, -3.9424e-01, 3.0535e-01], [-3.0160e-01, 5.0408e-01, 2.7735e-01], [-2.7715e-01, 1.6872e-01, 3.6916e-01], [3.2683e-02, 3.9355e-01, -2.1521e-01], [1.7383e-01, -1.0051e-01, 2.6851e-01], [2.3010e-01, 2.2288e-01, 2.9444e-01], [4.2181e-01, 9.4283e-02, -2.7528e-01]]], [[[ 3.8479e-02, -2.2963e-01, 3.3670e-01], [-2.4021e-01, -4.6610e-01, -6.0388e-02], [-8.2006e-01, 3.2939e-01, 2.2968e-01], [-6.4775e-01, -6.1759e-01, 3.9437e-01], [-7.6374e-01, 2.8072e-01, 3.6637e-01], 
[-7.3929e-01, -5.4636e-02, 4.5818e-01], [-4.2946e-01, 1.7019e-01, -1.2619e-01], [-2.8831e-01, -3.2387e-01, 3.5753e-01], [-2.3204e-01, -4.7952e-04, 3.8346e-01], [-4.0334e-02, -1.2907e-01, -1.8627e-01]], [[ 2.5793e-01, -1.6992e-01, -2.1771e-01], [-2.0759e-02, -4.0639e-01, -6.1480e-01], [-6.0061e-01, 3.8910e-01, -3.2472e-01], [-4.2830e-01, -5.5788e-01, -1.6004e-01], [-5.4429e-01, 3.4043e-01, -1.8804e-01], [-5.1984e-01, 5.0723e-03, -9.6227e-02], [-2.1001e-01, 2.2990e-01, -6.8060e-01], [-6.8861e-02, -2.6416e-01, -1.9688e-01], [-1.2589e-02, 5.9229e-02, -1.7095e-01], [1.7911e-01, -6.9366e-02, -7.4067e-01]], [[ 1.5904e-01, -1.0583e-01, -1.5102e-01], [-1.1965e-01, -3.4230e-01, -5.4810e-01], [-6.9950e-01, 4.5319e-01, -2.5803e-01], [-5.2719e-01, -4.9380e-01, -9.3347e-02], [-6.4318e-01, 4.0452e-01, -1.2134e-01], [-6.1873e-01, 6.9158e-02, -2.9533e-02], [-3.0890e-01, 2.9399e-01, -6.1391e-01], [-1.6775e-01, -2.0007e-01, -1.3018e-01], [-1.1148e-01, 1.2331e-01, -1.0426e-01], [8.0225e-02, -5.2802e-03, -6.7398e-01]], [[ 2.1837e-01, 8.0081e-03, -3.5833e-01], [-6.0314e-02, -2.2846e-01, -7.5541e-01], [-6.4016e-01, 5.6703e-01, -4.6534e-01], [-4.6786e-01, -3.7996e-01, -3.0066e-01], [-5.8385e-01, 5.1836e-01, -3.2865e-01], [-5.5940e-01, 1.8300e-01, -2.3684e-01], [-2.4956e-01, 4.0783e-01, -8.2122e-01], [-1.0842e-01, -8.6233e-02, -3.3749e-01], [-5.2144e-02, 2.3716e-01, -3.1157e-01], [1.3956e-01, 1.0856e-01, -8.8129e-01]], [[ 8.9723e-01, -3.6938e-01, -4.4795e-03], [6.1855e-01, -6.0585e-01, -4.0157e-01], [3.8695e-02, 1.8964e-01, -1.1149e-01], [2.1100e-01, -7.5734e-01, 5.3189e-02], [9.5011e-02, 1.4097e-01, 2.5192e-02], [1.1946e-01, -1.9439e-01, 1.1700e-01], [4.2930e-01, 3.0443e-02, -4.6737e-01], [5.7044e-01, -4.6362e-01, 1.6351e-02], [6.2672e-01, -1.4023e-01, 4.2277e-02], [8.1842e-01, -2.6882e-01, -5.2744e-01]]]]) torch.Size([2, 5, 10, 3]) (mols[:, :, None, :] - grid[:, None, :, :]).pow_(2) tensor([[[[7.8481e-02, 2.2112e-01, 1.1329e-02], [2.1231e-06, 4.9944e-01, 8.4477e-02], [3.3454e-01, 
7.8823e-03, 3.3235e-07], [1.6491e-01, 7.3651e-01, 2.6931e-02], [2.7257e-01, 1.6092e-03, 1.8526e-02], [2.4763e-01, 8.7170e-02, 5.1948e-02], [3.5266e-02, 4.9586e-03, 1.2706e-01], [2.1757e-03, 3.1863e-01, 1.6198e-02], [9.2683e-05, 5.8124e-02, 2.3469e-02], [4.0534e-02, 1.3667e-01, 1.7349e-01]], [[4.6807e-02, 2.2498e-02, 7.8165e-02], [3.8859e-03, 1.4935e-01, 4.5788e-01], [4.1240e-01, 1.6730e-01, 1.4946e-01], [2.2079e-01, 2.8940e-01, 4.9245e-02], [3.4325e-01, 1.2986e-01, 6.2455e-02], [3.1519e-01, 6.2496e-04, 2.4995e-02], [6.3296e-02, 6.2414e-02, 5.5127e-01], [1.2197e-02, 5.9650e-02, 6.6951e-02], [2.9340e-03, 6.2657e-03, 5.4207e-02], [1.8916e-02, 2.4442e-03, 6.4408e-01]], [[3.0397e-01, 5.4479e-03, 6.2063e-02], [7.4335e-02, 2.6459e-02, 2.1893e-02], [9.4375e-02, 4.0047e-01, 2.0195e-02], [1.8198e-02, 9.8694e-02, 9.4122e-02], [6.2946e-02, 3.4124e-01, 7.7727e-02], [5.1273e-02, 6.1902e-02, 1.3735e-01], [6.9548e-03, 2.2433e-01, 4.5697e-02], [5.0420e-02, 4.1742e-04, 7.2876e-02], [7.8857e-02, 9.1784e-02, 8.7545e-02], [2.2327e-01, 3.0403e-02, 7.4989e-02]], [[7.0732e-01, 5.9061e-03, 1.3445e-01], [3.1622e-01, 9.8171e-02, 9.2536e-04], [3.0679e-04, 2.3249e-01, 6.7419e-02], [2.3960e-02, 2.1606e-01, 1.8006e-01], [1.5054e-03, 1.8792e-01, 1.5708e-01], [4.0009e-03, 9.6314e-03, 2.3829e-01], [1.3919e-01, 1.0431e-01, 9.2593e-03], [2.6444e-01, 2.9273e-02, 1.5016e-01], [3.2548e-01, 2.3194e-02, 1.7092e-01], [5.8096e-01, 5.6178e-04, 2.4429e-02]], [[2.5062e-01, 3.9311e-05, 6.1346e-02], [4.9254e-02, 5.8923e-02, 2.2322e-02], [1.2810e-01, 3.0553e-01, 1.9787e-02], [3.4452e-02, 1.5542e-01, 9.3239e-02], [9.0964e-02, 2.5410e-01, 7.6924e-02], [7.6812e-02, 2.8467e-02, 1.3628e-01], [1.0682e-03, 1.5488e-01, 4.6316e-02], [3.0217e-02, 1.0102e-02, 7.2099e-02], [5.2947e-02, 4.9675e-02, 8.6694e-02], [1.7792e-01, 8.8893e-03, 7.5782e-02]]], [[[1.4806e-03, 5.2729e-02, 1.1337e-01], [5.7700e-02, 2.1725e-01, 3.6467e-03], [6.7250e-01, 1.0850e-01, 5.2755e-02], [4.1958e-01, 3.8142e-01, 1.5553e-01], [5.8330e-01, 7.8806e-02, 
1.3423e-01], [5.4655e-01, 2.9851e-03, 2.0993e-01], [1.8443e-01, 2.8965e-02, 1.5925e-02], [8.3122e-02, 1.0489e-01, 1.2783e-01], [5.3842e-02, 2.2994e-07, 1.4704e-01], [1.6269e-03, 1.6660e-02, 3.4695e-02]], [[6.6527e-02, 2.8872e-02, 4.7397e-02], [4.3095e-04, 1.6515e-01, 3.7797e-01], [3.6073e-01, 1.5140e-01, 1.0545e-01], [1.8344e-01, 3.1124e-01, 2.5613e-02], [2.9626e-01, 1.1589e-01, 3.5358e-02], [2.7023e-01, 2.5728e-05, 9.2596e-03], [4.4104e-02, 5.2854e-02, 4.6322e-01], [4.7418e-03, 6.9780e-02, 3.8761e-02], [1.5849e-04, 3.5081e-03, 2.9225e-02], [3.2082e-02, 4.8116e-03, 5.4860e-01]], [[2.5293e-02, 1.1201e-02, 2.2806e-02], [1.4316e-02, 1.1717e-01, 3.0042e-01], [4.8930e-01, 2.0538e-01, 6.6580e-02], [2.7793e-01, 2.4384e-01, 8.7136e-03], [4.1369e-01, 1.6363e-01, 1.4724e-02], [3.8283e-01, 4.7828e-03, 8.7222e-04], [9.5418e-02, 8.6428e-02, 3.7688e-01], [2.8140e-02, 4.0030e-02, 1.6948e-02], [1.2428e-02, 1.5206e-02, 1.0870e-02], [6.4360e-03, 2.7880e-05, 4.5425e-01]], [[4.7686e-02, 6.4129e-05, 1.2840e-01], [3.6378e-03, 5.2195e-02, 5.7065e-01], [4.0981e-01, 3.2152e-01, 2.1654e-01], [2.1889e-01, 1.4437e-01, 9.0394e-02], [3.4088e-01, 2.6870e-01, 1.0801e-01], [3.1292e-01, 3.3489e-02, 5.6095e-02], [6.2282e-02, 1.6632e-01, 6.7440e-01], [1.1754e-02, 7.4361e-03, 1.1390e-01], [2.7190e-03, 5.6243e-02, 9.7075e-02], [1.9477e-02, 1.1786e-02, 7.7667e-01]], [[8.0503e-01, 1.3644e-01, 2.0066e-05], [3.8260e-01, 3.6705e-01, 1.6126e-01], [1.4973e-03, 3.5964e-02, 1.2431e-02], [4.4522e-02, 5.7357e-01, 2.8291e-03], [9.0270e-03, 1.9874e-02, 6.3462e-04], [1.4272e-02, 3.7786e-02, 1.3690e-02], [1.8430e-01, 9.2680e-04, 2.1844e-01], [3.2541e-01, 2.1494e-01, 2.6736e-04], [3.9277e-01, 1.9664e-02, 1.7874e-03], [6.6981e-01, 7.2266e-02, 2.7820e-01]]]]) torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1) tensor([[[0.3109, 0.5839, 0.3424, 0.9284, 0.2927, 0.3867, 0.1673, 0.3370, 0.0817, 0.3507], [0.1475, 0.6111, 0.7292, 0.5594, 0.5356, 0.3408, 0.6770, 0.1388, 0.0634, 0.6654], [0.3715, 0.1227, 0.5150, 
0.2110, 0.4819, 0.2505, 0.2770, 0.1237, 0.2582, 0.3287], [0.8477, 0.4153, 0.3002, 0.4201, 0.3465, 0.2519, 0.2528, 0.4439, 0.5196, 0.6060], [0.3120, 0.1305, 0.4534, 0.2831, 0.4220, 0.2416, 0.2023, 0.1124, 0.1893, 0.2626]], [[0.1676, 0.2786, 0.8337, 0.9565, 0.7963, 0.7595, 0.2293, 0.3158, 0.2009, 0.0530], [0.1428, 0.5436, 0.6176, 0.5203, 0.4475, 0.2795, 0.5602, 0.1133, 0.0329, 0.5855], [0.0593, 0.4319, 0.7613, 0.5305, 0.5920, 0.3885, 0.5587, 0.0851, 0.0385, 0.4607], [0.1761, 0.6265, 0.9479, 0.4537, 0.7176, 0.4025, 0.9030, 0.1331, 0.1560, 0.8079], [0.9415, 0.9109, 0.0499, 0.6209, 0.0295, 0.0657, 0.4037, 0.5406, 0.4142, 1.0203]]]) torch.Size([2, 5, 10]) F.relu(torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1), inplace=True) tensor([[[0.3109, 0.5839, 0.3424, 0.9284, 0.2927, 0.3867, 0.1673, 0.3370, 0.0817, 0.3507], [0.1475, 0.6111, 0.7292, 0.5594, 0.5356, 0.3408, 0.6770, 0.1388, 0.0634, 0.6654], [0.3715, 0.1227, 0.5150, 0.2110, 0.4819, 0.2505, 0.2770, 0.1237, 0.2582, 0.3287], [0.8477, 0.4153, 0.3002, 0.4201, 0.3465, 0.2519, 0.2528, 0.4439, 0.5196, 0.6060], [0.3120, 0.1305, 0.4534, 0.2831, 0.4220, 0.2416, 0.2023, 0.1124, 0.1893, 0.2626]], [[0.1676, 0.2786, 0.8337, 0.9565, 0.7963, 0.7595, 0.2293, 0.3158, 0.2009, 0.0530], [0.1428, 0.5436, 0.6176, 0.5203, 0.4475, 0.2795, 0.5602, 0.1133, 0.0329, 0.5855], [0.0593, 0.4319, 0.7613, 0.5305, 0.5920, 0.3885, 0.5587, 0.0851, 0.0385, 0.4607], [0.1761, 0.6265, 0.9479, 0.4537, 0.7176, 0.4025, 0.9030, 0.1331, 0.1560, 0.8079], [0.9415, 0.9109, 0.0499, 0.6209, 0.0295, 0.0657, 0.4037, 0.5406, 0.4142, 1.0203]]]) torch.Size([2, 5, 10]) F.relu(torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1), inplace=True).sqrt_() tensor([[[0.5576, 0.7641, 0.5852, 0.9635, 0.5410, 0.6219, 0.4090, 0.5805, 0.2858, 0.5922], [0.3840, 0.7817, 0.8539, 0.7480, 0.7318, 0.5838, 0.8228, 0.3726, 0.2518, 0.8157], [0.6095, 0.3503, 0.7177, 0.4594, 0.6942, 0.5005, 0.5263, 0.3517, 0.5081, 0.5733], [0.9207, 0.6445, 0.5479, 0.6481, 0.5887, 
0.5019, 0.5028, 0.6662, 0.7208, 0.7784], [0.5586, 0.3612, 0.6734, 0.5321, 0.6496, 0.4915, 0.4497, 0.3353, 0.4351, 0.5124]], [[0.4094, 0.5278, 0.9131, 0.9780, 0.8924, 0.8715, 0.4789, 0.5620, 0.4482, 0.2302], [0.3779, 0.7373, 0.7859, 0.7213, 0.6690, 0.5287, 0.7484, 0.3366, 0.1814, 0.7652], [0.2435, 0.6572, 0.8725, 0.7283, 0.7694, 0.6233, 0.7475, 0.2918, 0.1962, 0.6788], [0.4197, 0.7915, 0.9736, 0.6735, 0.8471, 0.6344, 0.9503, 0.3648, 0.3950, 0.8989], [0.9703, 0.9544, 0.2234, 0.7880, 0.1719, 0.2564, 0.6353, 0.7353, 0.6436, 1.0101]]]) torch.Size([2, 5, 10]) Calculation results: tensor([[[0.5576, 0.7641, 0.5852, 0.9635, 0.5410, 0.6219, 0.4090, 0.5805, 0.2858, 0.5922], [0.3840, 0.7817, 0.8539, 0.7480, 0.7318, 0.5838, 0.8228, 0.3726, 0.2518, 0.8157], [0.6095, 0.3503, 0.7177, 0.4594, 0.6942, 0.5005, 0.5263, 0.3517, 0.5081, 0.5733], [0.9207, 0.6445, 0.5479, 0.6481, 0.5887, 0.5019, 0.5028, 0.6662, 0.7208, 0.7784], [0.5586, 0.3612, 0.6734, 0.5321, 0.6496, 0.4915, 0.4497, 0.3353, 0.4351, 0.5124]], [[0.4094, 0.5278, 0.9131, 0.9780, 0.8924, 0.8715, 0.4789, 0.5620, 0.4482, 0.2302], [0.3779, 0.7373, 0.7859, 0.7213, 0.6690, 0.5287, 0.7484, 0.3366, 0.1814, 0.7652], [0.2435, 0.6572, 0.8725, 0.7283, 0.7694, 0.6233, 0.7475, 0.2918, 0.1962, 0.6788], [0.4197, 0.7915, 0.9736, 0.6735, 0.8471, 0.6344, 0.9503, 0.3648, 0.3950, 0.8989], [0.9703, 0.9544, 0.2234, 0.7880, 0.1719, 0.2564, 0.6353, 0.7353, 0.6436, 1.0101]]])
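As a sanity check, one entry of the result can be recomputed by hand. The sketch below copies the first atom of the first molecule and the first grid position from the printed output (values rounded to 4 decimals, so the check is only accurate to about that precision) and applies the square, sum, and square-root steps in plain Python.

```python
import math

# Values copied from the printed output above (rounded to 4 decimals).
atom = [0.3787, 0.1093, 0.5062]      # mols[0, 0]
position = [0.0985, 0.5795, 0.3998]  # grid[0]

# Same steps as the tensor code: subtract, square, sum, square root.
squared = [(a - p) ** 2 for a, p in zip(atom, position)]
distance = math.sqrt(sum(squared))

print(round(sum(squared), 4))  # 0.3109, the first entry of the summed tensor
print(round(distance, 4))      # 0.5576, the first entry of the final result
```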