Understanding the mask mechanism in the Transformer

Tips

This article is intended for readers who have already read the Transformer paper and source code carefully. The code used is the PyTorch version, and this post only records my personal understanding during the learning process. If there are any mistakes, please let me know in the comments; corrections are welcome.

Preliminaries

    def forward(self, src_seq, trg_seq):

        src_mask = get_pad_mask(src_seq, self.src_pad_idx)  # attend only over the valid length; the padding positions must be masked
        trg_mask = get_pad_mask(trg_seq, self.trg_pad_idx) & get_subsequent_mask(trg_seq)  # mask the padding positions and the upper triangle

        enc_output, *_ = self.encoder(src_seq, src_mask)
        dec_output, *_ = self.decoder(trg_seq, trg_mask, enc_output, src_mask)  # trg_seq is the target-language sequence
        seq_logit = self.trg_word_prj(dec_output)
        if self.scale_prj:
            seq_logit *= self.d_model ** -0.5

        return seq_logit.view(-1, seq_logit.size(2))

We know that src_seq is the complete source sentence, while trg_seq is the target sequence, which the decoder produces one word at a time. src_mask eliminates the influence of the padding elements (used to pad sentences to a common length) on attention; it corresponds to the function get_pad_mask. Besides this padding mask, trg_mask also uses a mask that eliminates the influence of words that have not yet been decoded, corresponding to the function get_subsequent_mask. Let's look at how these masks work in the encoder and the decoder.
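Before diving in, it helps to recall where a mask is actually consumed: inside scaled dot-product attention, it is applied to the raw scores before the softmax, so masked positions end up with effectively zero weight. Below is a minimal, self-contained sketch of that idea; the function name and comments are illustrative, not copied from the repo:

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v, mask=None):
    # q: [batch, len_q, d_k], k: [batch, len_k, d_k], v: [batch, len_k, d_v]
    scores = torch.matmul(q, k.transpose(-2, -1)) / (k.size(-1) ** 0.5)
    if mask is not None:
        # wherever mask == 0, push the score to a huge negative number
        # so that it becomes ~0 after the softmax
        scores = scores.masked_fill(mask == 0, -1e9)
    attn = F.softmax(scores, dim=-1)
    return torch.matmul(attn, v), attn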

Multi-head attention masking mechanism in the encoder

Assume that src_seq has batch_size=2 and seq_len=3, and that 0 is the padding element that needs to be masked. Taking the masking inside the encoder's attention mechanism as an example, here is a simplified example that operates on the same principle:

import torch

def get_pad_mask(seq, pad_idx):
    return (seq != pad_idx).unsqueeze(-2)

if __name__ == "__main__":
    seq = torch.LongTensor([[1, 2, 0],[3, 4, 5]]) # batch_size=2, seq_len=3, padding_idx=0, torch.Size([2, 3])
    embedding = torch.nn.Embedding(num_embeddings=6, embedding_dim=10, padding_idx=0) # Encode each word, the dimension of each word is 10
    query, key = embedding(seq), embedding(seq) # torch.Size([2, 3, 10]), torch.Size([2, 3, 10])
    att = torch.matmul(query, key.transpose(-2, -1)) # torch.Size([2, 3, 3])
    """
    att:
    tensor([[[6.3899, 0.5517, 0.0000],
             [0.5517, 6.4545, 0.0000],
             [0.0000, 0.0000, 0.0000]],

            [[14.2199, -1.2504, -3.7615],
             [-1.2504, 8.2810, 0.3213],
             [-3.7615, 0.3213, 12.7485]]])
    """
    mask = get_pad_mask(seq, 0) # torch.Size([2, 1, 3])
    """
    mask:
    tensor([[[True, True, False]],

            [[True, True, True]]])
    """
    masked_att = att.masked_fill(mask==0, -1e9) # torch.Size([2, 3, 3])
    """
    masked_att:
    tensor([[[ 6.3899e+00,  5.5172e-01, -1.0000e+09],
             [ 5.5172e-01,  6.4545e+00, -1.0000e+09],
             [ 0.0000e+00,  0.0000e+00, -1.0000e+09]],

            [[ 1.4220e+01, -1.2504e+00, -3.7615e+00],
             [-1.2504e+00,  8.2810e+00,  3.2127e-01],
             [-3.7615e+00,  3.2127e-01,  1.2749e+01]]])
    """

Note the following line:

masked_att = att.masked_fill(mask==0, -1e9) # torch.Size([2, 3, 3])

This line relies on broadcasting: the mask is broadcast from size [2, 1, 3] to [2, 3, 3] so that it lines up element-wise with the attention matrix att. The broadcast mask looks like this:

 """
    mask:
    tensor([[[True, True, False],
             [True, True, False],
             [True, True, False]],

            [[True, True, True],
             [True, True, True],
             [True, True, True]]])
    """

The first multi-head attention masking mechanism in the decoder

import torch.nn as nn
# MultiHeadAttention and PositionwiseFeedForward are defined elsewhere in the source repo (SubLayers.py)

class DecoderLayer(nn.Module):
    ''' Compose with three layers '''

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    def forward(
            self, dec_input, enc_output,
            slf_attn_mask=None, dec_enc_attn_mask=None):
        dec_output, dec_slf_attn = self.slf_attn(
            dec_input, dec_input, dec_input, mask=slf_attn_mask)
        dec_output, dec_enc_attn = self.enc_attn(
            dec_output, enc_output, enc_output, mask=dec_enc_attn_mask)
        dec_output = self.pos_ffn(dec_output)
        return dec_output, dec_slf_attn, dec_enc_attn

The above is the decoder-layer part of the code (Layers.py); self.slf_attn and self.enc_attn correspond to the first and second sublayers respectively. Reading the source code, we find that among the arguments passed into forward, slf_attn_mask corresponds to trg_mask and dec_enc_attn_mask corresponds to src_mask. src_mask is the mask produced by get_pad_mask alone, while trg_mask combines the masks produced by get_pad_mask and get_subsequent_mask.
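For context, here is a minimal sketch of how a decoder might route the two masks into its stack of layers. This is a simplified illustration, not a verbatim copy of the repo's Models.py (embedding, positional encoding, dropout, etc. are omitted):

import torch.nn as nn

class Decoder(nn.Module):
    ''' Simplified sketch: only the mask routing is shown. '''

    def __init__(self, layer_stack):
        super().__init__()
        self.layer_stack = nn.ModuleList(layer_stack)  # a list of DecoderLayer modules

    def forward(self, dec_output, enc_output, trg_mask, src_mask):
        for dec_layer in self.layer_stack:
            # trg_mask (padding + upper triangle) feeds the masked self-attention,
            # src_mask (padding only) feeds the encoder-decoder attention
            dec_output, dec_slf_attn, dec_enc_attn = dec_layer(
                dec_output, enc_output,
                slf_attn_mask=trg_mask, dec_enc_attn_mask=src_mask)
        return dec_output

Now let's simulate the first sublayer with a simple example: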

import torch

def get_pad_mask(seq, pad_idx):
    return (seq != pad_idx).unsqueeze(-2)

def get_subsequent_mask(seq):
    batch_size, seq_len = seq.size()
    mask = 1 - torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8), diagonal=1)
    mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
    return mask

if __name__ == "__main__":
    seq = torch.LongTensor([[1, 2, 0],[3, 4, 5]]) # batch_size=2, seq_len=3, padding_idx=0, torch.Size([2, 3])
    embedding = torch.nn.Embedding(num_embeddings=6, embedding_dim=10, padding_idx=0) # Encode each word, the dimension of each word is 10
    query, key = embedding(seq), embedding(seq) # torch.Size([2, 3, 10]), torch.Size([2, 3, 10])
    att = torch.matmul(query, key.transpose(-2, -1)) # torch.Size([2, 3, 3])
    """
    att:
    tensor([[[6.3899, 0.5517, 0.0000],
             [0.5517, 6.4545, 0.0000],
             [0.0000, 0.0000, 0.0000]],

            [[14.2199, -1.2504, -3.7615],
             [-1.2504, 8.2810, 0.3213],
             [-3.7615, 0.3213, 12.7485]]])
    """
    p_mask = get_pad_mask(seq, 0) # torch.Size([2, 1, 3])
    """
    p_mask:
    tensor([[[True, True, False]],

            [[True, True, True]]])
    """
    s_mask = get_subsequent_mask(seq) # torch.Size([2, 3, 3])
    """
    s_mask:
    tensor([[[1, 0, 0],
             [1, 1, 0],
             [1, 1, 1]],

            [[1, 0, 0],
             [1, 1, 0],
             [1, 1, 1]]])
    """
    mask = p_mask & s_mask # torch.Size([2, 3, 3])
    """
    mask:
    tensor([[[1, 0, 0],
             [1, 1, 0],
             [1, 1, 0]],

            [[1, 0, 0],
             [1, 1, 0],
             [1, 1, 1]]])
    """
    masked_att = att.masked_fill(mask==0, -1e9) # torch.Size([2, 3, 3])
    """
    masked_att:
    tensor([[[ 7.0830e + 00, -1.0000e + 09, -1.0000e + 09],
             [-2.7239e + 00, 1.5791e + 01, -1.0000e + 09],
             [0.0000e + 00, 0.0000e + 00, -1.0000e + 09]],
    
            [[ 1.4356e + 01, -1.0000e + 09, -1.0000e + 09],
             [4.3740e + 00, 9.3937e + 00, -1.0000e + 09],
             [5.1368e + 00, 1.1749e + 00, 5.1988e + 00]]])
    """

Compared with the mask mechanism in the encoder, the only addition is the upper-triangular mask. To be more rigorous, this example should really use a target sequence shorter than the source one (I was a bit lazy here); a variant with a shorter target sequence is sketched below.
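Here is the same experiment with a target sequence that is genuinely shorter than the source (seq_len=2 instead of 3). The token ids are made up purely for illustration:

import torch

def get_pad_mask(seq, pad_idx):
    return (seq != pad_idx).unsqueeze(-2)

def get_subsequent_mask(seq):
    batch_size, seq_len = seq.size()
    mask = 1 - torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8), diagonal=1)
    return mask.unsqueeze(0).expand(batch_size, -1, -1)

if __name__ == "__main__":
    trg_seq = torch.LongTensor([[1, 0], [3, 4]])  # batch_size=2, seq_len=2 (shorter than the source), padding_idx=0
    embedding = torch.nn.Embedding(num_embeddings=6, embedding_dim=10, padding_idx=0)
    query, key = embedding(trg_seq), embedding(trg_seq)                # torch.Size([2, 2, 10])
    att = torch.matmul(query, key.transpose(-2, -1))                   # torch.Size([2, 2, 2])
    mask = get_pad_mask(trg_seq, 0) & get_subsequent_mask(trg_seq)     # torch.Size([2, 2, 2])
    """
    mask:
    tensor([[[1, 0],
             [1, 0]],

            [[1, 0],
             [1, 1]]])
    """
    masked_att = att.masked_fill(mask == 0, -1e9)                      # torch.Size([2, 2, 2])

The first batch element shows both effects at once: its second column is masked because that token is padding, and the first row's second column would have been masked by the upper triangle anyway.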

The second multi-head attention masking mechanism in the decoder

The second sublayer uses the same kind of mask as the encoder; the upper-triangular mask is not used here. Note that Q comes from the decoder, while K and V come from the encoder. Here is an example for this attention:

import torch

def get_pad_mask(seq, pad_idx):
    return (seq != pad_idx).unsqueeze(-2)

if __name__ == "__main__":
    seq_k = torch.LongTensor([[1, 2, 0],[3, 4, 5]]) # batch_size=2, seq_len=3, padding_idx=0 torch.Size([2, 3])
    seq_q = torch.LongTensor([[4, 5], [6, 7]]) # batch_size=2, seq_len=2, padding_idx=0 torch.Size([2, 2])
    embedding_k = torch.nn.Embedding(num_embeddings=6, embedding_dim=10, padding_idx=0)
    embedding_q = torch.nn.Embedding(num_embeddings=8, embedding_dim=10, padding_idx=0)
    query = embedding_q(seq_q)  # torch.Size([2, 2, 10])
    key = embedding_k(seq_k)    # torch.Size([2, 3, 10])
    att = torch.matmul(query, key.transpose(-2, -1)) # torch.Size([2, 2, 3])
    print(att)
    """
    att:
    tensor([[[-2.6561, -3.2418, 0.0000],
             [4.2412, -2.5950, 0.0000]],
    
            [[ 3.1960, 9.1766, 0.6027],
             [-4.0462, -1.4987, -0.4528]]])
    """
    mask = get_pad_mask(seq_k, 0) # torch.Size([2, 1, 3])
    """
    tensor([[[ True, True, False]],

        [[ True, True, True]]])
    """
    att = att.masked_fill(mask==0, -1e9) # torch.Size([2, 2, 3])
    """
    tensor([[[-2.3804e-04, -1.8567e-01, -1.0000e + 09],
         [4.5441e-01, 1.8053e-01, -1.0000e + 09]],

        [[-1.8674e + 00, 2.6307e + 00, 2.6570e + 00],
         [-1.5631e + 00, 2.2473e-02, 3.7925e + 00]]])
    """

The principle is the same as before; the only difference is that the sequence length of Q (from the decoder) differs from that of K (from the encoder).
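To finish the computation, the masked scores would be normalized and used to weight the value vectors, which also come from the encoder side. A short continuation of the snippet above (the value tensor here is illustrative and simply reuses embedding_k, since the original snippet does not build one):

import torch.nn.functional as F

value = embedding_k(seq_k)          # V from the encoder side, torch.Size([2, 3, 10])
attn = F.softmax(att, dim=-1)       # padded source positions receive ~0 weight, torch.Size([2, 2, 3])
output = torch.matmul(attn, value)  # torch.Size([2, 2, 10]): one context vector per decoder position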

Coding ideas of the mask mechanism

Reference link:

Full solution to Mask in NLP (CSDN Blog)
