Skip to content

大模型推理加速——KV Cache篇

来自:AiGC面试宝典

📝通俗解释:这篇文章主要介绍大模型推理时如何通过缓存技术来加速生成文本的过程。就像我们在做数学题时,如果前面算过的结果后面还要用到,就记下来不用重新算,KV Cache就是这个道理。


一、介绍一下 KV Cache 是啥?

KV Cache 中的 K 和 V 分别指的是 Attention 机制中的 Key 和 Value 的状态值。KV Cache 只出现在 Transformer 结构的自回归 Decoder 中,像 BERT 就没有 KV Cache。KV Cache 的存在是为了避免 Scaled Dot-Product Attention 过程中的重复计算。

该技术可以在不影响任何计算精度的前提下,通过空间换时间的思想,提高推理性能

📝通俗解释:可以把KV Cache想象成一个小抄本。在生成句子时,每次预测下一个字都需要用到前面所有字的"特征信息"(K和V),如果没有小抄,每次都要重新计算前面所有字的特征,很浪费时间。有了小抄,把之前计算过的特征存起来,下次直接查表就行了,省时但占地方。


二、为什么要进行 KV Cache?

2.1 不使用 KV Cache 场景

给定“天气”,模型会逐个预测剩下的字,假设接下来预测的两个字为“真好”。

注意:下面的示例图只给出了和 KV Cache 相关的细节。

  1. 第一步会预测“真”

![大模型推理加速流程图](图片描述:这是一个展示自注意力机制(Self-Attention)计算过程的流程图。

  1. 输入层:左侧有一个输入矩阵 $X$ (2x4),包含两行数据,分别对应“天”和“气”的向量表示。
  2. 权重矩阵:中间列出了三个权重矩阵 $W_Q$ (4x4), $W_K$ (4x4), $W_V$ (4x4)。
  3. Q, K, V 计算
    • $X$ 与 $W_Q$ 相乘得到 $Q$ 矩阵 (2x4)。
    • $X$ 与 $W_K$ 相乘得到 $K$ 矩阵 (2x4)。
    • $X$ 与 $W_V$ 相乘得到 $V$ 矩阵 (2x4)。
  4. Attention 计算
    • $Q$ 与 $K^T$ 相乘得到 $Q@K^T$ (2x2)。
    • 进行 scale 操作(除以 $\sqrt{d_k}$,图中显示除以2)。
    • 应用 mask(图中显示右上角元素被 mask 为 $-\infty$)。
    • 经过 softmax 函数得到 attention scores (2x2)。
    • scores 与 $V$ 相乘得到 output (2x4)。
  5. 输出:output 经过 feedforward 层,最终输出预测结果(图中显示为“真”)。)

下面是上图计算流程的代码实现:

python
import torch
import torch.nn.functional as F

X = torch.tensor(
    [[0.1, 0.3, 0.2, -0.1],
     [0.2, -0.1, 0.4, 0.5]]
)
W_Q = torch.tensor(
    [[0.3, -0.3, -0.1, 0.2],
     [0.2, 0.4, 0.1, 0.3],
     [0.1, 0.2, 0.3, 0.5],
     [-0.3, 0.3, 0.4, -0.5]]
)
W_K = torch.tensor(
    [[0.1, 0.8, 0.2, -0.3],
     [0.6, 0.5, -0.3, 0.1],
     [-0.4, 0.3, 0.7, 0.2],
     [0.6, -0.1, 0.2, 0.3]]
)
W_V = torch.tensor(
    [[-0.7, 0.5, -0.9, 0.1],
     [0.1, 0.8, 0.4, 0.3],
     [0.4, 0.2, -0.4, 0.5],
     [0.1, 0.2, 0.1, -0.4]]
)

Q = torch.matmul(X, W_Q)
K = torch.matmul(X, W_K)
V = torch.matmul(X, W_V)
scores = torch.matmul(Q, K.T) / 2
scores += torch.tensor(
    [[0, float('-inf')],
     [0, 0],
    ])
scores = F.softmax(scores, dim=-1)
output = torch.matmul(scores, V)

output 再经过 feedforward 等步骤最终得到预测的 token "真";

  1. 第二步会将"真"拼接到"天气"的后面,即新的输入为"天气真",再预测"好"

Attention计算流程图 image_description

(注:图片展示了一个Attention机制的计算流程。左侧输入矩阵 $X$ 包含“天”、“气”、“真”三个token的向量。分别通过 $W_Q, W_K, W_V$ 矩阵乘法得到 $Q, K, V$。中间部分展示了 $Q$ 与 $K^T$ 相乘,除以 scale,进行 mask 操作(下三角掩码,填充 $-inf$),经过 softmax 得到注意力权重,最后与 $V$ 相乘得到输出。右侧展示了输出经过 feedforward 层得到最终结果。)

下面是上图计算流程的代码实现:

python
X = torch.tensor(
    [[0.1, 0.3, 0.2, -0.1],
     [0.2, -0.1, 0.4, 0.5],
     [0.4, 0.2, 0.3, -0.1]]
)
Q = torch.matmul(X, W_Q)
K = torch.matmul(X, W_K)
V = torch.matmul(X, W_V)
scores = torch.matmul(Q, K.T) / 2
scores += torch.tensor(
    [[0, float('-inf'), float('-inf')],
     [0, 0, float('-inf')],
     [0, 0, 0]]
)
scores = F.softmax(scores, dim=-1)
output = torch.matmul(scores, V)

同样的,output 再经过 feedforward 等步骤最终得到预测的 token "好";

📝通俗解释:可以看到,第二步预测"好"的时候,实际上把"天气真"三个字都输入模型重新计算了一遍。但问题是:第一步已经算过"天气"的K和V了,第二步为什么要重新算?所以KV Cache的思路就是——不算重复的账,把之前算过的K和V存起来。

2.2 使用 KV Cache 场景

观察上面的计算过程,可以看到,在第二步的预测中,"好"的预测只和"真"以及完整的 $K, V$ 有关。

于是,KV Cache 的想法就很直观了,缓存上一轮的 $K, V$,即可达到减少计算,提速的效果。从第二步开始时,只需输入当前位置的 token,得到当前位置对应的 $K_{cur}, V_{cur}$,再拼接上一步缓存的 $K_{last}, V_{last}$ 得到完整的 $K, V$,即可完成下一个 token 的预测。下图是在上图的基础上只保留和预测"好"相关的数据:

![图表描述:这是一个展示Transformer模型中Attention机制计算流程的示意图。左侧展示了输入矩阵 $X$ 分别与权重矩阵 $W_Q$、$W_K$、$W_V$ 进行矩阵乘法($X@W_Q$ 等),生成 $Q$、$K$、$V$ 矩阵。中间部分展示了 $Q$ 与 $K$ 的转置进行乘法运算($Q@K^T$),经过 scale 缩放和 softmax 激活函数处理,得到注意力分数,再与 $V$ 相乘(score@V)。右侧展示了结果经过 feedforward 前馈神经网络层,最终输出结果矩阵。图中用不同颜色的方块代表矩阵中的具体数值。]

即当前轮输出 token 与输入 tokens 拼接,并作为下一轮的输入 tokens,反复多次。可以看出第 i+1 轮输入数据只比第 i 轮输入数据新增了一个 token,其他全部相同!

因此第 i+1 轮推理时必然包含了第 i 轮的部分计算。KV Cache 的出发点就在这里,缓存当前轮可重复利用的计算结果,下一轮计算时直接读取缓存结果,就是这么简单,不存在什么 Cache Miss 问题。

📝通俗解释:就像做菜时,第一道菜炒完了,锅里的调料和半成品还在,你想做第二道菜时,直接用现成的就行了,不用每次都从洗菜切菜开始。KV Cache就是那个"锅",把之前算好的中间结果存着,下次直接用。


三、说一下 KV Cache 在大模型中的应用?

3.1 KV Cache 在 Llama 推理流程中应用?

代码来自 https://github.com/facebookresearch/llama

LLaMA 类是对模型和 tokenizer 的封装,只实现了 generate 方法,这个方法主要接受 prompt 列表。

python
class LLaMA:
    def __init__(self, model: Transformer, tokenizer: Tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def generate(
        self,
        prompts: List[str],
        max_gen_len: int,
        temperature: float = 0.8,
        top_p: float = 0.95,
    ) -> List[str]:
        pass

下面是一个调用 LLaMA generate 方法的示例:

python
prompts = [
    "天气",
    "你好",
]
generator = LLaMA(model, tokenizer)
results = generator.generate(prompts)

对应的 generate 方法的具体实现:

python
class LLaMA:
    def generate(
        self,
        prompts: List[str],
        max_gen_len: int,
        temperature: float = 0.8,
        top_p: float = 0.95,
    ) -> List[str]:
        bsz = len(prompts)
        params = self.model.params
        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

        # step 1: 将 prompt 处理 (tokenizer.encode) 成 prompt_tokens
        prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
        min_prompt_size = min([len(t) for t in prompt_tokens])
        max_prompt_size = max([len(t) for t in prompt_tokens])
        total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)

        # step 2: 构造一个大小为 (bsz, total_len) 且初始值为 tokenizer.pad_id 的张量 tokens
        tokens = torch.full((bsz, total_len), self.tokenizer.pad_id).cuda().long()

        # step 3: 将 prompt_tokens 赋值给对应位置的 tokens
        for k, t in enumerate(prompt_tokens):
            tokens[k, : len(t)] = torch.tensor(t).long()

        # step 4: 循环预测下一个 token: 只有第一次预测时会将前面所有的 token (例如"天气") 输入给模型,从第二次预测开始只将当前的 token (例如"真") 输入给模型
        input_text_mask = tokens != self.tokenizer.pad_id
        start_pos = min_prompt_size
        prev_pos = 0
        for cur_pos in range(start_pos, total_len):
            # 调用 模型 forward 函数
            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
            if temperature > 0:
                probs = torch.softmax(logits / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits, dim=-1)
            next_token = next_token.reshape(-1)
            # only replace token if prompt has already been generated
            next_token = torch.where(
                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
            )
            tokens[:, cur_pos] = next_token
            prev_pos = cur_pos

        # step 5: 预测结束后,将 token 转成 (tokenizer.decode) 字符串
        decoded = []
        for i, t in enumerate(tokens.tolist()):
            # cut to max gen len
            t = t[: len(prompt_tokens[i]) + max_gen_len]
            # cut to eos tok if any
            try:
                t = t[: t.index(self.tokenizer.eos_id)]
            except ValueError:
                pass
            decoded.append(self.tokenizer.decode(t))
        return decoded

下面介绍模型的 forward 函数:

python
class Transformer(nn.Module):

    @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, start_pos: int):
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

        mask = None
        # step 1: 对 seqlen 进行判断,只有第一次预测下一个 token 时才会初始化 mask(输入的长度大于1),因为第二次开始每次只会输入当前位置的 token
        if seqlen > 1:
            mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device)
            mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)

        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)
        h = self.norm(h)
        output = self.output(h[:, -1, :])  # only compute last logits
        return output.float()

之所以可以只输入当前位置的 token 就可以预测下一个 token,是因为缓存了 K, V。 具体实现可以看一下 Attention 类:

python
class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        # 此处省略了和 KV cache 无关的初始化代码
        self.cache_k = torch.zeros(
            (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
        ).cuda()
        self.cache_v = torch.zeros(
            (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
        ).cuda()

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        bsz, seqlen, _ = x.shape
        # step 1: 计算当前位置 token 对应的 xq, xk, xv
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        # step 2: 将 xk, xv 缓存到对应的 cache_k, cache_v 中
        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        # step 3: 使用 xk 与前面所有的 k 计算 score,再与 v 进行计算
        xq = xq.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, slen, cache_len + slen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
        output = output.transpose(
            1, 2
        ).contiguous().view(bsz, seqlen, -1)

        return self.wo(output)

📝通俗解释:这段代码展示了LLaMA模型中实际使用KV Cache的过程。cache_k和cache_v是两个大数组,用来存储之前所有token的K和V值。每次生成新token时,只需要计算当前token的K和V,然后和之前缓存的值拼在一起做注意力计算。这样就避免了重复计算之前所有token的K和V。


四、KV Cache 优点?

  1. 避免重复计算:避免每次采样 token 时重新计算键值向量,利用预先计算好的 K 值和 V 值,可以节省大量计算时间。
  2. 提升推理速度:通过空间换时间的思想,显著提升自回归生成过程中的推理效率。
  3. 不影响精度:缓存的是精确的 K 和 V 值,不会引入额外的计算误差。

📝通俗解释:KV Cache最大的好处就是"快"。比如生成一段100字的文字,原来可能需要计算100×100=10000次注意力,现在有了缓存,可能只需要计算100+99+98+...+1=5050次,速度提升将近一倍,而且结果完全一样。


五、 KV Cache 缺点?

  1. 占用显存:需要存储所有历史 token 的 K 和 V 值,随着序列长度增加,显存占用会不断增长。
  2. 显存压力大:在长序列生成或多 batch 推理时,KV Cache 可能成为显存瓶颈。
  3. 管理复杂:需要动态管理缓存的分配和释放,增加了系统复杂度。

📝通俗解释:KV Cache的缺点就是"占地方"。生成100个字的句子需要存储前面99个字的特征,生成1000个字就要存前面999个字的特征。如果同时生成很多句子,或者句子很长,显存很快就满了,就像你的手机内存不够用一样。


六、 KV Cache 优化策略?

尽可能减少推理过程中 KV 键值对的重复计算,实现 KV Cache 的优化。目前减少 KV Cache 的手段有许多,比如 PageAttention、PQA(分页注意力)、MQA(多查询注意力)等,另外 FlashAttention 可以通过硬件内存使用的优化,提升推理性能。

📝通俗解释:研究人员想了很多办法来解决KV Cache"占地方"的问题:有的想办法让存储更高效(像PageAttention),有的想办法减少需要存储的数据量(像MQA),有的想办法让读写更快(像FlashAttention)。

6.1 PageAttention 显存优化

  • 动机:在缓存中,这些 KV Cache 都很大,并且大小是动态变化的,难以预测。已有的系统中,由于显存碎片和过度预留,浪费了 60%-80% 的显存。
  • 解决方法:作为 VLLM 核心技术,PageAttention 通过对显存碎片化问题进行处理,以达到减少显存占用,提高 KV Cache 可使用的显存空间,提升推理性能。
  • 优化策略
  1. PageAttention 借助 OS 系统中虚拟内存和分页的思想。可以实现在不连续的空间存储连续的 KV 键值。

    图表描述: 左侧有一个 "Query vector" 指向单词 "for"。从 "for" 发出三条箭头,分别指向右侧表格中的不同行(Block)。 右侧表格标题为 "Key and value vectors",表格内容如下:

    Block 1computerscientistandmathematician
    Block 2renownedfor
    Block 0AlanTuringisa
  2. 因为所有键值都是分布存储的,需要通过分页管理彼此的关系。序列的连续逻辑块通过 Block Table 映射到非连续物理块。

  3. 同一个 prompt 生成多个输出序列,可以共享计算过程中的 Attention 键值,实现 Copy-on-Write 机制,即只有需要修改的时候才会复制,从而大大降低显存占用。

图表描述: 该图展示了 Seq A 和 Seq B 共享同一个 Physical KV Cache Blocks 的过程。

  • 中间表格 (Physical KV Cache Blocks):存储了实际的 KV 数据。第一行填入了 "The", "future", "of", "artificial";第二行填入了 "intelligence", "is"。
  • 左侧表格 (Seq A - Logical KV Cache Blocks):逻辑块包含 "The", "future", "of", "artificial", "intelligence", "is"。箭头指向中间的 Physical Blocks,表示映射关系。
  • 右侧表格 (Seq B - Logical KV Cache Blocks):逻辑块同样包含 "The", "future", "of", "artificial", "intelligence", "is"。箭头也指向中间的 Physical Blocks,表示两个序列共享了相同的物理内存块来存储 prompt 的 KV 值。

📝通俗解释:PageAttention的灵感来自电脑操作系统的虚拟内存技术。就像操作系统可以把一个程序需要的内存分散存储在物理内存的不同位置,PageAttention也可以把KV Cache分散存放在显存各处,然后用一张"地址表"把它们串起来。这样就不需要提前预留一大块连续显存,显存利用率大大提高。另外,多个相似的输出序列可以共享同一份KV数据,就像两个人同时写句子,都可以用同一本词典。

6.2 MHA、GQA、MQA 优化技术

  • MHA(Multi-Head Attention):标准的多头注意力机制,h 个 Query、Key 和 Value 矩阵。
  • MQA(Multi-Query Attention):让所有的头之间共享同一份 Key 和 Value 矩阵,每个头只单独保留了一份 Query 参数,从而大大减少 Key 和 Value 矩阵的参数量。
  • GQA(Grouped-Query Attention):将查询头分成 N 组,每个组共享一个 Key 和 Value 矩阵。

![图表:展示了 Multi-head, Grouped-query, Multi-query 三种注意力机制的结构对比。左侧 Multi-head 中 Queries, Keys, Values 数量一致且一一对应;中间 Grouped-query 中多个 Queries 分组共享少量的 Keys 和 Values;右侧 Multi-query 中所有 Queries 共享同一组 Keys 和 Values。]

GQA 以及 MQA 都可以实现一定程度的 Key Value 共享,从而可以使模型体积减小,GQA 是 MQA 和 MHA 的折中方案。

这两种技术的加速原理是:

  1. 减少了数据的读取
  2. 减少了推理过程中的 KV Cache

📝通俗解释:可以把这三种技术想象成不同规模的"共享办公室":

  • MHA就像每人一个独立办公室,资源充足但成本高
  • MQA像所有人都共用一个会议室和资料室,最省钱但要排队
  • GQA则像把员工分成几组,每组共用一个会议室,平衡了效率和成本

关键是:Key和Value需要存储到显存里,Query不需要。所以减少Key和Value的复制次数,就能大大节省显存。

需要注意的是 GQA 和 MQA 需要在模型训练的时候开启,按照相应的模式生成模型。

6.3 FlashAttention 优化技术

FlashAttention 推理加速技术是利用 GPU 硬件非均匀的存储器层次结构实现内存节省和推理加速,意思是通过合理的应用 GPU 显存实现 IO 的优化,从而提升资源利用率,提高性能。

![图表:内存层级金字塔图 (Memory Hierarchy with Bandwidth & Memory Size)。从上到下依次为:

  1. GPU SRAM: SRAM: 19 TB/s (20 MB)
  2. GPU HBM: HBM: 1.5 TB/s (40 GB)
  3. Main Memory (CPU DRAM): DRAM: 12.8 GB/s (>1 TB) ]

计算速度越快的硬件往往越昂贵且体积越小,FlashAttention 的核心原理是尽可能合理地应用 SRAM 内存计算资源。

A100 GPU 有 40-80GB 的高带宽内存(HBM),带宽为 1.5-2.0 TB/s,而每 108 个流处理器有 192KB 的 SRAM,带宽估计在 19TB/s 左右。也就是说,存在一种优化方案是利用 SRAM 远快于 HBM 的性能优势,将密集计算尽可能放在 SRAM,减少与 HBM 的反复通信,实现整体的 IO 效率最大化。比如可以将矩阵计算过程、softmax 函数尽可能在 SRAM 中处理并保留中间结果,全部计算完成后再写回 HBM,这样就可以减少 HBM 的写入写出频次,从而提升整体的计算性能。如何有效分割矩阵的计算过程,涉及到 FlashAttention 的核心计算逻辑 Tiling 算法。

📝通俗解释:FlashAttention的思路就像这样:如果你要处理一堆文件,最快的方法是把常用文件放在手边(SRAM),而不是每次都去远处的大仓库(HBM)翻找。SRAM虽然很小但速度极快,HBM虽然很大但速度慢很多。FlashAttention就是把计算过程拆分成小块,在快速的SRAM里算完再一次性存回HBM,减少了反复跑腿的时间。


七、GPT 模型单次 inference 输入生成下一个 token,为什么会产生 kv-cache?

因为 GPT 每次 inference 只能生成一个 token,所以要输出一句话,会进行多次 inference,直到遇到终止 token。这样每次计算输出的 token 的概率分布时,都需要把之前生成的 token 重新计算 key 和 value。第 i+1 次输入的 token 只比第 i 次输入 token 多一个新的 token,因此第 i+1 轮推理就包含了第 i 次推理的很多计算。对于第 i 次推理,只需要再计算对应的 $k_{i+1}$ 和 $v_{i+1}$ 即可,所以可以把之前的 k 和 v 缓存下来,不需要每次都计算之前的 k 和 v,所以就有了 KV Cache。这是一种空间换时间的方法。

📝通俗解释:就像你说一句话,每次只能说一个字。说完"天"要说"气",说完"天气"要说"真"...每次说新字的时候,你都需要回想之前说过的所有字的意思。如果每次都重新理解之前所有字,太累了。不如拿个小本本把之前每个字的"理解"记下来(这就是KV Cache),下次直接查小本本就行了。


八、为什么 kv cache 会造成显存很大?

因为随着推理输出的序列变长,需要缓存的 KV Cache 也越多。以 GPT3 为例,GPT3 模型占用显存大小约为 350GB。假设批次大小 $b=64$,输入序列长度 $s=512$,输出序列长度 $n=32$,则 KV Cache 占用显存为 164GB,大约是模型参数显存的 0.5 倍。

📝通俗解释:想象一下,你每说一个字,都要在一本大字典里记录这个字的所有"特征信息"。说一句100字的话,就要记录100个字的特征。如果同时让64个人每人说一句100字的话,需要记录6400个字的特征。这本"字典"占用的空间是非常大的,可能比模型本身还要大!


总结

优化技术核心思想主要优势
KV Cache缓存K和V避免重复计算,提升速度
PageAttention分页管理显存减少显存碎片,提高利用率
MQA/GQA共享K和V减少显存占用和模型参数量
FlashAttention利用SRAM加速减少IO开销,提升计算效率

📝通俗解释:这些技术就像给大模型推理"提速"和"省空间"的各种工具。有的是换算法(KV Cache),有的是换存储方式(PageAttention),有的是精简数据(MQA/GQA),有的是优化硬件使用(FlashAttention)。实际应用中往往会组合使用多种技术来达到最佳效果。

基于 MIT 许可发布