Understanding KV Cache for Large Model Inference Performance Optimization

0. Introduction

Anyone who optimizes the performance of large models will be familiar with KV Cache, but how well do we really understand this technique? Try to answer the following questions:

  • Which part of the computation in the Self-Attention layer does KV Cache save?

  • Does KV Cache affect the amount of computation in the MLP layer?

  • Does KV Cache affect the volume of data transferred between blocks?

This article dissects the technique and answers these questions.

1. What is KV Cache

KV Cache is a commonly used technique for optimizing large model inference performance. It trades space for time and improves inference speed without affecting computational accuracy at all. There are some blog posts analyzing this technique online, but they can leave you confused or even misled, for example into the mistaken conclusion that this cache works like a database read cache or a CPU cache. I had similar misunderstandings at first. It was only after reading and running the source code line by line that I clearly understood what is cached and how it saves computation.

2. Background

The inference process of a generative model has a distinctive shape. We give it an input text and the model outputs an answer of length N, but this actually takes N inference passes. A GPT-style model outputs only one token per pass; the output token is appended to the input tokens, and the result is used as the input for the next pass. This repeats until a terminator is encountered.

This is the GPT inference process as it is usually understood. In code:

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer


model = GPT2LMHeadModel.from_pretrained("gpt2", torchscript=True).eval()

#tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
in_text = "Lionel Messi is a"
in_tokens = torch.tensor(tokenizer.encode(in_text))

# inference
token_eos = torch.tensor([198])  # newline token, used here as the stop symbol
out_token = None
i = 0
with torch.no_grad():
    while out_token != token_eos:
        logits, _ = model(in_tokens)
        out_token = torch.argmax(logits[-1, :], dim=0, keepdim=True)
        in_tokens = torch.cat((in_tokens, out_token), 0)
        text = tokenizer.decode(in_tokens)
        print(f'step {i} input: {text}', flush=True)
        i += 1

out_text = tokenizer.decode(in_tokens)
print(f'Input: {in_text}')
print(f'Output: {out_text}')

output:

step 0 input: Lionel Messi is a player
step 1 input: Lionel Messi is a player who
step 2 input: Lionel Messi is a player who has
step 3 input: Lionel Messi is a player who has been
step 4 input: Lionel Messi is a player who has been a
step 5 input: Lionel Messi is a player who has been a key
step 6 input: Lionel Messi is a player who has been a key part
step 7 input: Lionel Messi is a player who has been a key part of
step 8 input: Lionel Messi is a player who has been a key part of the
step 9 input: Lionel Messi is a player who has been a key part of the team
step 10 input: Lionel Messi is a player who has been a key part of the team's
step 11 input: Lionel Messi is a player who has been a key part of the team's success
step 12 input: Lionel Messi is a player who has been a key part of the team's success.
step 13 input: Lionel Messi is a player who has been a key part of the team's success.

Input: Lionel Messi is a
Output: Lionel Messi is a player who has been a key part of the team's success.

Can you see the problem with this computation? The input to each inference pass gets longer, so the FLOPs per pass keep increasing. Is there a way to keep the FLOPs per pass roughly constant, or even reduce them? (Foreshadowing: note the words "roughly constant".)

3. Principle

In the inference process above, each step takes a token sequence as input. The Embedding layer converts it into a three-dimensional tensor [b, s, h] (batch size, sequence length, hidden size); after a series of computations, the logits layer maps the result into the vocabulary space, giving an output tensor of shape [b, s, vocab_size].

The output token of the current round is appended to the input tokens to form the input of the next round, and this repeats many times. The input of round i + 1 therefore has only one more token than the input of round i; everything else is the same. So the inference in round i + 1 necessarily contains part of the computation from round i. This is the starting point of KV Cache: cache the reusable results of the current round and read them directly in the next round. It is as simple as that, and there is no cache-miss problem.
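The reuse argument can be checked numerically. The following is a minimal sketch, not the code of any framework: it assumes a single attention head with random weights, uses NumPy rather than the GPT-2 model above, and the names causal_attention and cached_step are made up for illustration. It compares the output for the newest token when everything is recomputed against the output when the K and V of earlier tokens are read from a cache.

```python
import numpy as np

rng = np.random.default_rng(0)
h = 8  # hidden size (illustrative value)
W_q, W_k, W_v = (rng.standard_normal((h, h)) for _ in range(3))

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def causal_attention(x):
    """Full recomputation: x is [s, h]; returns the output for every position."""
    q, k, v = x @ W_q, x @ W_k, x @ W_v
    scores = q @ k.T / np.sqrt(h)
    mask = np.triu(np.ones_like(scores, dtype=bool), k=1)  # hide future positions
    scores = np.where(mask, -np.inf, scores)
    return softmax(scores) @ v

def cached_step(x_new, k_cache, v_cache):
    """Incremental step: x_new is [1, h]; only the new token is projected."""
    q, k, v = x_new @ W_q, x_new @ W_k, x_new @ W_v
    k_cache = np.concatenate([k_cache, k], axis=0)  # append the new key
    v_cache = np.concatenate([v_cache, v], axis=0)  # append the new value
    out = softmax(q @ k_cache.T / np.sqrt(h)) @ v_cache
    return out, k_cache, v_cache

x = rng.standard_normal((5, h))  # embeddings of the 5 tokens seen in round i + 1
prefix = x[:4]                   # round i saw only the first 4 tokens

# Round i: compute and cache K and V for the prefix.
k_cache, v_cache = prefix @ W_k, prefix @ W_v
# Round i + 1 with the cache: feed only the newest token.
out_cached, k_cache, v_cache = cached_step(x[4:5], k_cache, v_cache)
# Round i + 1 without the cache: recompute everything, keep the last position.
out_full = causal_attention(x)[-1:]

print(np.allclose(out_cached, out_full))
```

The two results agree because, under the causal mask, the K and V rows of earlier tokens do not depend on later tokens, so they can be stored once and reused.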

4. Implementation details

At present, all mainstream large model inference frameworks implement KV Cache. Let's see how to use it. The code above can be modified with the following main changes:

  • The past_key_values parameter is added to the inference call; it accumulates the K and V values of every round by appending. The kvcache variable has the form ((k,v), (k,v), …, (k,v)), that is, a tuple of L (k, v) pairs, one per layer, where k and v each have shape [b, n_head, s, head_dims]. From this, the amount of cached data at each round works out to 2bshL elements (with h = n_head × head_dims), where s is the sequence length at the current round. The cache therefore grows linearly with the number of output tokens. Taking GPT3-175B as an example, with the KV cache stored in float16, a sequence length of 100, and batch size 1, the GPU memory occupied by the KV cache is 2×100×12288×96×2 bytes ≈ 472 MB.

  • The token output by each inference pass is used directly as the input of the next round, with no concatenation needed, because the preceding context is already in kvcache.
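The 2bshL estimate in the first bullet can be written as a small helper. This is only a sketch of the arithmetic; the function name kv_cache_bytes and its parameter names are chosen here for illustration and do not come from any library.

```python
def kv_cache_bytes(b, s, h, num_layers, bytes_per_elem=2):
    """Bytes held by the K and V caches: 2 * b * s * h * L elements in total."""
    return 2 * b * s * h * num_layers * bytes_per_elem

# GPT3-175B figures from the text: h = 12288, L = 96, float16 (2 bytes), s = 100, b = 1
size = kv_cache_bytes(b=1, s=100, h=12288, num_layers=96)
print(f"{size} bytes = {size / 1e6:.0f} MB")  # 471859200 bytes = 472 MB

# Each additional token adds a fixed amount, which is why the growth is linear.
per_token = kv_cache_bytes(b=1, s=1, h=12288, num_layers=96)
print(f"{per_token} bytes per token")  # 4718592 bytes per token
```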

Code example:

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer


model = GPT2LMHeadModel.from_pretrained("gpt2", torchscript=True).eval()

#tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
in_text = "Lionel Messi is a"
in_tokens = torch.tensor(tokenizer.encode(in_text))

# inference
token_eos = torch.tensor([198])  # newline token, used here as the stop symbol
out_token = None
kvcache = None
out_text = in_text
i = 0
with torch.no_grad():
    while out_token != token_eos:
        logits, kvcache = model(in_tokens, past_key_values=kvcache)  # pass the cache in via past_key_values
        out_token = torch.argmax(logits[-1, :], dim=0, keepdim=True)
        in_tokens = out_token  # the output token is the next round's input; no concatenation
        text = tokenizer.decode(in_tokens)
        print(f'step {i} input: {text}', flush=True)
        i += 1
        out_text += text

print(f'Input: {in_text}')
print(f'Output: {out_text}')

The code above only shows the change at the call level. The implementation details depend on each framework's underlying code. The Hugging Face transformers library, for example, has a relatively clean implementation. The relevant Attention code in modeling_gpt2.py is as follows:

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        if layer_past is not None:  # once the first token has been output, layer_past is no longer None
            past_key, past_value = layer_past  # take out the previously computed key and value
            key = torch.cat((past_key, key), dim=-2)  # concatenate past_key with the key of the current token
            value = torch.cat((past_value, value), dim=-2)  # concatenate past_value with the value of the current token

        if use_cache is True:
            present = (key, value)
        else:
            present = None

There is related code at the block level too; take a closer look when you have time. As always, reading and running the source code once beats a thousand words of explanation.

In fact, with KV Cache enabled, the inference process can be divided into two stages:

  • Prefill stage: occurs while computing the first output token. The cache is empty at this point, so the key cache and value cache must be computed and saved for every transformer layer; once the token is output, the cache is filled. FLOPs are the same as with KV Cache disabled, there are many GEMM operations, and inference is slow.

  • KV Cache stage: occurs while computing the second through the last output token. The cache now holds values, so each round only needs to read the cache and append the newly computed Key and Value of the current round to it. FLOPs are reduced, GEMM becomes GEMV, and inference is faster than in the first stage; the computation is now memory-bound.
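The difference between the two stages can be made concrete by counting how many token positions are fed through the model at each generation step. The sketch below is a rough proxy for the work done in the dense layers, not a FLOPs model; the function name positions_per_step and the prompt length of 5 are illustrative assumptions.

```python
def positions_per_step(prompt_len, new_tokens, use_cache):
    """Number of token positions fed through the model at each generation step."""
    steps = []
    for i in range(new_tokens):
        if use_cache and i > 0:
            steps.append(1)               # KV Cache stage: only the newest token
        else:
            steps.append(prompt_len + i)  # prefill, or full recomputation without a cache
    return steps

no_cache = positions_per_step(prompt_len=5, new_tokens=4, use_cache=False)
with_cache = positions_per_step(prompt_len=5, new_tokens=4, use_cache=True)
print(no_cache, sum(no_cache))      # [5, 6, 7, 8] 26
print(with_cache, sum(with_cache))  # [5, 1, 1, 1] 8
```

The per-step cost with the cache is only roughly constant: the attention scores for the new token are still computed against every cached key, so that part keeps growing with the sequence length.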

A picture may help here. The figure below shows a Decoder Block, which contains Self-Attention and MLP. The red parts are what KV Cache affects: once KV Cache is turned on, the sequence length s marked in red becomes 1. When batch_size=1, both dense layers in Self-Attention become GEMV operations, and so do the dense layers in the MLP. With this figure you can answer the three questions from the introduction.

[Figure: Decoder Block of GPT]

The following article also gives a quantitative analysis of this topic. It is well written, and I recommend it.

Convoluted Thomas x: Analysis of the parameter count, computation, intermediate activations, and KV cache of the transformer model

5. Summary

KV Cache is an important engineering technique for optimizing Transformer inference performance. All mainstream inference frameworks implement and encapsulate it (for example, the generate function of the transformers library wraps it, so users do not need to pass past_key_values manually), and it is enabled by default (use_cache=True in the config.json file). This article has tried to open up that encapsulation and analyze the internal implementation. I hope it is helpful. If there are any mistakes in the article, please point them out.

Finally, a word about our team's open source project. Adlik is a deep learning inference tool contributed by ZTE. It is supported by the LF AI & Data Foundation (Linux Foundation) and is still being improved. We look forward to your support and attention. We are also doing research and development on cutting-edge technologies in deep learning compilers and large models, and interested readers are welcome to join us.
