现代 AI 的“记忆碎片”:KV Cache 深度解析

在讨论 LLM 推理优化时,PagedAttention 和 Speculative Decoding 经常被提及,但它们共同作用的核心对象只有一个:KV Cache(Key-Value Cache)。如果你想理解为什么大模型推理如此吃显存,以及为什么推理速度会随着上下文增加而变慢,KV Cache 是唯一的入口。

🔥

现代 AI 的“记忆碎片”:KV Cache 深度解析

在讨论 LLM 推理优化时,PagedAttention 和 Speculative Decoding 经常被提及,但它们共同作用的核心对象只有一个:KV Cache(Key-Value Cache)。如果你想理解为什么大模型推理如此吃显存,以及为什么推理速度会随着上下文增加而变慢,KV Cache 是唯一的入口。

什么是 KV Cache?

LLM 的核心是 Transformer 架构。在生成文本时,模型采用的是自回归(Autoregressive)模式:每生成一个新 token,都需要将之前所有生成的 token 作为输入重新喂给模型。

在 Transformer 的注意力机制(Attention)中,每个 token 都会产生三个向量:Query (Q)、Key (K) 和 Value (V)。
- Query:当前 token “想要寻找什么”。
- Key:历史 token “能提供什么”。
- Value:历史 token “实际包含的内容”。

计算过程是:$\text{Attention}(Q, K, V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})V$。

关键点在于:对于已经生成过的历史 token,它们的 $K$ 和 $V$ 向量在后续的生成步骤中是完全不变的。如果每次生成新 token 都要重新计算一遍所有历史 token 的 $K$ 和 $V$,那么计算量将随序列长度呈平方级增长 $\mathcal{O}(n^2)$。

为了避免这种重复计算,我们将每一步产生的 $K$ 和 $V$ 存储在显存中,这就是 KV Cache。这样,每一步只需要计算当前新 token 的 $Q, K, V$,然后直接从缓存中读取之前的 $K, V$ 进行矩阵乘法即可。计算复杂度因此降为 $\mathcal{O}(n)$。

KV Cache 的代价:显存黑洞

虽然 KV Cache 极大地提升了速度,但它带来了巨大的显存压力。

1. 计算公式

一个模型的 KV Cache 大小取决于:
- $\text{batch_size}$ (并发数)
- $\text{seq_len}$ (序列长度)
- $\text{num_layers}$ (层数)
- $\text{num_heads}$ (注意力头数)
- $\text{head_dim}$ (每个头的维度)
- $\text{precision}$ (精度,如 FP16 为 2 bytes)

公式为:$\text{Memory} = 2 \times \text{batch_size} \times \text{seq_len} \times \text{num_layers} \times \text{num_heads} \times \text{head_dim} \times \text{precision}$
(乘以 2 是因为要同时存储 Key 和 Value)。

2. 实战量化

以 Llama-3-8B 为例(假设 FP16):
- 层数: 32, 头数: 32, 每个头维度: 128.
- 单个 token 在单层产生的 KV 大小 = $2 \times 32 \times 128 \times 2\text{ bytes} = 16\text{ KB}$。
- 全模型单 token KV 大小 = $32\text{ layers} \times 16\text{ KB} = 512\text{ KB}$。

看起来不多?但如果 batch_size=32seq_len=4096
$32 \times 4096 \times 512\text{ KB} \approx 67\text{ GB}$。

这意味着即使模型权重本身只占约 15GB,为了支持高并发的长文本推理,你可能需要一张 A100 (80GB) 或更多显卡,仅仅是为了存放这些“记忆碎片”。

如何优化 KV Cache?

面对显存压力,工业界演进出了三种主流方案:

MQA 与 GQA(结构优化)

传统的 Multi-Head Attention (MHA) 每个 Query 头对应一个 Key/Value 头。
- MQA (Multi-Query Attention):所有 Query 头共享一对 KV 头。显存占用直接降低到原来的 $1/\text{num_heads}$,但精度有损。
- GQA (Grouped-Query Attention):折中方案。将 Query 分组,每组共享一对 KV 头(如 Llama-3 使用)。在保持性能的同时显著降低显存占用。

PagedAttention(内存管理优化)

传统的 KV Cache 要求连续内存空间,导致严重的内存碎片化(Internal Fragmentation)。vLLM 推出的 PagedAttention 将 KV Cache 分页存储(类似操作系统的虚拟内存),允许非连续存储且动态按需分配,将显存利用率提升至接近 100%。

量化(精度优化)

将 KV Cache 从 FP16 量化到 INT8 或 FP8,甚至 INT4。这可以直接将显存占用减半或更多,而对模型生成质量的影响在可接受范围内。

总结

KV Cache 是 LLM 推理的“空间换时间”典范。它解决了自回归生成的冗余计算问题,但也成为了限制吞吐量和上下文长度的最大瓶颈。从 GQA 到 PagedAttention 再到量化,AI 系统工程的演进本质上就是在与这块巨大的“记忆碎片”做斗争。

留言区

欢迎分享你的想法!

发表留言

0/500

加载留言中…