KV Cache

下面默认讨论单层因果 self-attention。设当前已经生成了 \(t-1\) 个 token,它们对应的 hidden states 为:

\[ X_{1:t-1} \in \mathbb{R}^{(t-1) \times d} \]

对于第 \(t\) 个 token,记其 hidden state 为 \(x_t \in \mathbb{R}^{1 \times d}\)

什么是 KV Cache

在 attention 中,

\[ Q = XW_Q,\quad K = XW_K,\quad V = XW_V \]

其中 \(W_Q, W_K, W_V \in \mathbb{R}^{d \times d}\)

在自回归生成时,第 \(t\) 步只需要用当前 token 的 query 去和历史 token 的 key、value 交互:

\[ q_t = x_t W_Q,\quad K_{1:t} = X_{1:t} W_K,\quad V_{1:t} = X_{1:t} W_V \]

所谓 KV cache,就是把历史 token 的

\[ K_{1:t-1},\quad V_{1:t-1} \]

保存下来。这样在第 \(t\) 步就不必重新计算前 \(t-1\) 个 token 的 key 和 value,只需要计算当前 token 的

\[ k_t = x_t W_K,\quad v_t = x_t W_V \]

然后拼接到缓存后面即可。

为什么需要 KV Cache

如果不使用 KV cache,那么在第 \(t\) 步生成时,需要把整个前缀 \(X_{1:t}\) 再做一遍投影:

\[ K_{1:t} = X_{1:t}W_K,\quad V_{1:t} = X_{1:t}W_V \]

这意味着历史 token 的 key 和 value 会被反复重复计算。

如果使用 KV cache,那么第 \(t\) 步只需要新增:

\[ k_t = x_t W_K,\quad v_t = x_t W_V \]

历史部分 \(K_{1:t-1},V_{1:t-1}\) 可以直接复用。因此,KV cache 的核心作用是:

  • 避免在 decode 阶段重复计算历史 token 的 \(K,V\)
  • 把自回归生成从“反复处理整个前缀”变成“只处理当前 token,再读取历史缓存”

复杂度视角

设当前上下文长度为 \(L\)

不使用 KV Cache

\(t\) 步需要重新计算长度为 \(L\) 的整段前缀,因此单层计算复杂度仍接近 prefill:

\[ O(L d^2 + L^2 d) \]

使用 KV Cache

\(t\) 步只需要:

  • 计算当前 token 的 \(Q,K,V\) 投影,复杂度为 \(O(d^2)\)
  • 用当前 query 与历史 \(L\) 个 key/value 交互,复杂度为 \(O(L d)\)

因此单步 decode 复杂度变为:

\[ O(d^2 + L d) \]

这就是 KV cache 在长文本生成中非常重要的原因。

KV Cache 的代价

KV cache 并不是免费得到的。为了避免重复计算,需要额外保存每一层历史 token 的 key 和 value。

单层缓存大小约为:

\[ O(L d) \]

若模型共有 \(N\) 层,则总缓存空间约为:

\[ O(N L d) \]

因此,KV cache 用线性增长的显存换取了更快的 decode 速度

小结

KV cache 就是在自回归生成时,把历史 token 的 key 和 value 保存下来,后续 token 直接复用这些结果,而不再重复计算整个前缀。它不会改变 prefill 的复杂度,但会显著降低 decode 阶段的重复计算,是大语言模型高效推理的关键机制之一。