SwinGAN (2) – [Sub-module] ciRPE (contextual image relative position encoding)

Calculation subject: rpe_k(q), where q.shape = (batch_size, num_heads, patches_num_each_window, head_dim)

1. Absolute position

  • Input: the height and width of the window.
  • How obtained: sqrt(patches_num_each_window).
  • Output: the coordinates of each patch within the window.
  • [sqrt(L), sqrt(L), 2]
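To make step 1 concrete, here is a minimal sketch using an assumed 3×3 window (the code at the bottom uses the real 8×8 one):

```python
import torch

# Minimal sketch of step 1 for an assumed 3x3 window (L = 9 patches):
# build the (row, col) coordinate of every patch in the window,
# matching the [sqrt(L), sqrt(L), 2] shape described above.
side = 3  # sqrt(L)
rows = torch.arange(side).view(side, 1).repeat(1, side)  # row index of each patch
cols = torch.arange(side).view(1, side).repeat(side, 1)  # col index of each patch
pos = torch.stack([rows, cols], dim=2)

print(pos.shape)   # torch.Size([3, 3, 2])
print(pos[1, 2])   # the patch in row 1, col 2 -> tensor([1, 2])
```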

2. diff

  • Input: the absolute position coordinates.
    Compute the difference between each coordinate and every other coordinate, including itself.
    See diff.csv (excerpted).
  • [L, L, 2]
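A tiny worked example of the diff step, assuming a 2×2 window (L = 4) instead of the 8×8 one used below:

```python
import torch

# Pairwise coordinate differences for an assumed 2x2 window (L = 4):
# broadcasting pos against itself gives every (row, col) difference,
# including each patch with itself (the all-zero diagonal).
pos = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]])  # (L, 2)
diff = pos.view(4, 1, 2) - pos.view(1, 4, 2)          # (L, L, 2)

print(diff.shape)    # torch.Size([4, 4, 2])
print(diff[0, 3])    # patch (0,0) minus patch (1,1) -> tensor([-1, -1])
print(diff[2, 2])    # a patch minus itself -> tensor([0, 0])
```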

3. r, c

  • Input: the row-coordinate differences and the column-coordinate differences of diff, fed in separately.
  • Algorithm Description:
    Wherever the absolute value exceeds alpha = 1.9, replace the value with a new integer in the closed interval [-3, 3], keeping the sign unchanged; values whose absolute value is at most 1.9 are simply rounded. This compresses the original range [-7, 7] to [-3, 3]. The logic is just judgment and assignment: traverse every element of the matrix, test whether its absolute value exceeds 1.9, and if so map it to an integer in [-3, 3]. Afterwards, add 3 to shift the range to [0, 6], and multiply r by 7.
    For example, suppose some position holds -6 (any value in [-7, 7] works). Its absolute value exceeds 1.9, so -6 is fed into the mapping f and comes out as -3; that position thus changes from -6 to -3. f combines scaling, translation, truncation, and rounding.
    See r.csv and c.csv (one for row coordinates, one for column coordinates).
  • Both are [L,L]
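To see what the piecewise map f actually does, here is a scalar sketch with the same alpha/beta/gamma values as the implementation at the bottom; the full table of outputs for offsets -7..7 is easy to verify by hand:

```python
import torch
import math

# Scalar sketch of the piecewise map f with the same parameters as the
# implementation below (alpha=1.9, beta=3.8, gamma=15.2): offsets with
# |x| <= alpha are rounded and kept; larger offsets are compressed
# logarithmically, so the range [-7, 7] collapses to [-3, 3].
def f(x, alpha=1.9, beta=3.8, gamma=15.2):
    t = torch.tensor(float(x))
    if t.abs() <= alpha:
        return int(t.round())
    y = alpha + torch.log(t.abs() / alpha) / math.log(gamma / alpha) * (beta - alpha)
    return int(torch.sign(t) * y.round().clip(max=beta))

mapping = [f(v) for v in range(-7, 8)]
print(mapping)  # [-3, -3, -3, -3, -2, -2, -1, 0, 1, 2, 2, 3, 3, 3, 3]
```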

4. rp_bucket

  • Input: r, c
  • Algorithm Description:
    Add r (already scaled by 7) and c; every entry becomes a bucket index in [0, 48].
  • [L,L]
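Combining steps 3 and 4: after the +3 shift, scaling the row offset by 7 makes r + c a row-major index into a 7×7 grid, so every bucketed (row, col) pair gets a unique id in [0, 48]. A sketch (`bucket_id` is a hypothetical helper name, not from the original code):

```python
# Hypothetical helper illustrating why r is scaled by 7 before adding c:
# (dr + 3) * 7 + (dc + 3) is row-major indexing into a 7x7 grid, so each
# bucketed (dr, dc) pair in [-3, 3] x [-3, 3] gets a unique id in [0, 48].
def bucket_id(dr, dc):
    return (dr + 3) * 7 + (dc + 3)

print(bucket_id(-3, -3))  # 0  (top-left bucket)
print(bucket_id(0, 0))    # 24 (center bucket: a patch paired with itself)
print(bucket_id(3, 3))    # 48 (bottom-right bucket)
ids = {bucket_id(r, c) for r in range(-3, 4) for c in range(-3, 4)}
print(len(ids))           # 49 distinct buckets
```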

5. num_buckets

  • Input: the bucketing parameters above.
  • Output: 49 (each axis offset is bucketed into [-3, 3], i.e. 7 possible values per axis, and 7 × 7 = 49).

6. self._ctx_rp_bucket_flatten

  • Input: rp_bucket
  • Algorithm Description:
    Scan rp_bucket row by row, adding to row i the i-th term of an arithmetic sequence with first term 0, common difference 49, and 64 terms (i.e. i × 49). Then flatten.
  • [L*L]
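Why the arithmetic sequence? In step 7 the bucket table is flattened over (position, bucket), so row i of rp_bucket must be offset by i × num_buckets to land in position i's own slice. A toy sketch with assumed sizes L = 3 and num_buckets = 4:

```python
import torch

# Toy sketch of step 6 with assumed sizes L = 3, num_buckets = 4: adding
# i * num_buckets to row i means that, once the bucket table is flattened
# over (position, bucket), position i only ever indexes into its own
# slice of the flattened table.
rp_bucket = torch.tensor([[0, 1, 2],
                          [1, 0, 1],
                          [2, 1, 0]])
num_buckets = 4
offset = torch.arange(0, 3 * num_buckets, num_buckets).view(-1, 1)  # [[0], [4], [8]]
flat = (rp_bucket + offset).flatten()

print(flat.tolist())  # [0, 1, 2, 5, 4, 5, 10, 9, 8]
```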

7. lookup_table

  • Input: q, self._ctx_rp_bucket_flatten, num_buckets
  • Calculation:
    A linear layer first projects each head's 32-dim query vectors to num_buckets (49) dims; self._ctx_rp_bucket_flatten then picks 64 values out of each position's 49-dim vector (flattened across positions), so duplicates appear.
  • [batch_size,num_heads,L,L]
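A shape-only walkthrough of step 7 with assumed toy sizes (batch 2, 1 head, L = 4, head_dim 8), random tensors standing in for the learned weight and the real rp_bucket:

```python
import torch

# Shape-only sketch of step 7 with assumed toy sizes. The linear layer
# projects each query from head_dim to num_buckets; indexing with the
# offset bucket ids then expands the 49 scores per position into an
# (L, L) bias, with repeats wherever two patch pairs share a bucket.
B, H, L, D, num_buckets = 2, 1, 4, 8, 49
q = torch.randn(B, H, L, D)
w = torch.randn(H, D, num_buckets)           # stand-in for the learned weight

table = torch.matmul(q.transpose(0, 1).reshape(H, B * L, D), w)
table = table.view(H, B, L, num_buckets).transpose(0, 1)   # (B, H, L, 49)

idx = torch.randint(0, num_buckets, (L, L))                # stand-in rp_bucket
idx = (idx + torch.arange(0, L * num_buckets, num_buckets).view(-1, 1)).flatten()
bias = table.flatten(2)[:, :, idx].view(B, H, L, L)

print(bias.shape)  # torch.Size([2, 1, 4, 4])
```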

Summary (discussion)

The shape of lookup_table is (8, 3, 64, 64), the same as the shape of attn, and the two are subsequently added.
In fact, everything before step 7 is data preparation; q is only used in the seventh step. In essence, the earlier steps just build a sensible index into the output of a linear layer, so the most essential step is the linear layer itself: it weights the channel data of q and changes the vector length, i.e. the channel dimension.

Then why not replace the last step with another linear layer instead of building such a complex index? Recall that each patch corresponds to a set of q/k/v. If we want to take position information into account, the intuitive idea is to apply the position encoding directly to the patch (that was my first thought), but applying it to q, k, or v is also reasonable, because q/k/v are themselves obtained from x through linear layers and are therefore weightings of the data x.

A patch is encoded as a vector, and multiplying a vector by one set of weights is a weighted-sum process: a vector with many elements collapses to a single value. I call this the "annihilation" of weighted summation. To end up with a vector, you need multiple sets of weights. One can say the length of the final vector depends on the number of weight sets and has nothing to do with the length of the data vector; the only requirement is that the data vector and each weight vector have the same length. The number of weight sets is free, chosen according to how long the output vector should be.
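The "annihilation" point can be checked in a few lines: the output length of a matrix-vector product equals the number of weight rows, not the input length.

```python
import torch

# One weight vector collapses a 32-dim patch vector to a single scalar
# ("annihilation"); stacking 49 weight vectors yields a 49-dim output.
# The output length tracks the number of weight rows, not the input length.
x = torch.randn(32)           # one patch encoded as a 32-dim vector
w_one = torch.randn(1, 32)    # one set of weights
w_many = torch.randn(49, 32)  # 49 sets of weights

print((w_one @ x).shape)   # torch.Size([1])
print((w_many @ x).shape)  # torch.Size([49])
```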
As for how exactly this takes location information into account, I don't fully understand. The coordinates are added, subtracted, scaled, and thresholded with no obviously principled justification; the point seems to be simply that position information gets folded in somehow.

Code implementation:

import torch
import csv
import math

# 1. Absolute position
def get_absolute_positions(height, width, dtype, device):
    rows = torch.arange(height, dtype=dtype, device=device).view(
        height, 1).repeat(1, width)
    cols = torch.arange(width, dtype=dtype, device=device).view(
        1, width).repeat(height, 1)
    return torch.stack([rows, cols], 2)
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
pos = get_absolute_positions(8, 8, dtype=torch.long, device=device)

# 2. diff
max_L = 8 * 8
pos1 = pos.view((max_L, 1, 2))
pos2 = pos.view((1, max_L, 2))
diff = pos1 - pos2

# 3. r, c
diff1=diff[:,:,0]
diff2=diff[:,:,1]
# tensor_list1 = diff1.tolist()
# tensor_list2 = diff2.tolist()
# # Save to CSV files
# with open('tensor_data1.csv', 'w', newline='') as csvfile:
#     writer = csv.writer(csvfile)
#     for row in tensor_list1:
#         writer.writerow(row)
#
# with open('tensor_data2.csv', 'w', newline='') as csvfile:
#     writer = csv.writer(csvfile)
#     for row in tensor_list2:
#         writer.writerow(row)
def piecewise_index(relative_position, alpha, beta, gamma, dtype):
    # Map each relative offset to a bucket index: values with |x| <= alpha
    # are kept (rounded); larger values are compressed logarithmically and
    # clipped, so the range [-7, 7] collapses to [-3, 3].
    rp_abs = relative_position.abs()
    mask = rp_abs <= alpha
    not_mask = ~mask
    rp_out = relative_position[not_mask]
    rp_abs_out = rp_abs[not_mask]
    y_out = (torch.sign(rp_out) * (alpha +
                                   torch.log(rp_abs_out / alpha) /
                                   math.log(gamma / alpha) *
                                   (beta - alpha)).round().clip(max=beta)).to(dtype)

    idx = relative_position.clone()
    if idx.dtype in [torch.float32, torch.float64]:
        idx = idx.round().to(dtype)
    idx[not_mask] = y_out
    return idx
r=piecewise_index(diff1,alpha=1.9,beta=3.8,gamma=15.2,dtype=torch.long)
r_list=r.tolist()
# with open('r.csv', 'w', newline='') as csvfile:
#     writer = csv.writer(csvfile)
#     for row in r_list:
#         writer.writerow(row)
c=piecewise_index(diff2,alpha=1.9,beta=3.8,gamma=15.2,dtype=torch.long)
c_list=c.tolist()
# with open('c.csv', 'w', newline='') as csvfile:
#     writer = csv.writer(csvfile)
#     for row in c_list:
#         writer.writerow(row)

r = r + 3  # shift [-3, 3] -> [0, 6]
c = c + 3
r = r * 7  # scale the row index so r + c is a unique id in a 7x7 grid

# 4. rp_bucket
rp_bucket = r + c  # bucket index in [0, 48]

# 5. num_buckets
num_buckets = 49  # 7 x 7 possible (row, col) bucket pairs

# 6. self._ctx_rp_bucket_flatten
# Row i gets offset i * num_buckets, so after flattening, query position i
# indexes into its own slice of the flattened (64, num_buckets) table.
offset = torch.arange(0, 64 * 49, 49, dtype=torch.long, device=pos.device).view(-1, 1)
_ctx_rp_bucket_flatten = (rp_bucket + offset).flatten()

# 7. lookup_table
num_heads = 3
head_dim = 32
# Zeros here purely to trace shapes; in the real module this is a learned weight.
lookup_table_weight = torch.zeros(num_heads, head_dim, num_buckets, device=rp_bucket.device)
batch_size = 8
patches_num_each_window = 64
q = torch.randn(batch_size, num_heads, patches_num_each_window, head_dim, device=rp_bucket.device)
# Per-head linear projection: head_dim -> num_buckets
lookup_table = torch.matmul(
                q.transpose(0, 1).reshape(-1, batch_size * 64, head_dim),
                lookup_table_weight).view(-1, batch_size, 64, num_buckets).transpose(0, 1)
'''
q:      (8, 3, 64, 32) -> (3, 8*64, 32)
weight: (3, 32, 49)
result: (8, 3, 64, 49)
'''
# Flatten the last two dims and gather: (8, 3, 64*49) -> (8, 3, 64, 64)
lookup_table = lookup_table.flatten(2)[:, :, _ctx_rp_bucket_flatten].view(batch_size, -1, 64, 64)