KV Cache
下面默认讨论单层因果 self-attention。设当前已经生成了 \(t-1\) 个 token,它们对应的 hidden states 为:
\[ X_{1:t-1} \in \mathbb{R}^{(t-1) \times d} \]
对于第 \(t\) 个 token,记其 hidden state 为 \(x_t \in \mathbb{R}^{1 \times d}\)。
什么是 KV Cache
在 attention 中,
\[ Q = XW_Q,\quad K = XW_K,\quad V = XW_V \]
其中 \(W_Q, W_K, W_V \in \mathbb{R}^{d \times d}\)。
在自回归生成时,第 \(t\) 步只需要用当前 token 的 query 去和历史 token 的 key、value 交互:
\[ q_t = x_t W_Q,\quad K_{1:t} = X_{1:t} W_K,\quad V_{1:t} = X_{1:t} W_V \]
所谓 KV cache,就是把历史 token 的
\[ K_{1:t-1},\quad V_{1:t-1} \]
保存下来。这样在第 \(t\) 步就不必重新计算前 \(t-1\) 个 token 的 key 和 value,只需要计算当前 token 的
\[ k_t = x_t W_K,\quad v_t = x_t W_V \]
然后拼接到缓存后面即可。
为什么需要 KV Cache
如果不使用 KV cache,那么在第 \(t\) 步生成时,需要把整个前缀 \(X_{1:t}\) 再做一遍投影:
\[ K_{1:t} = X_{1:t}W_K,\quad V_{1:t} = X_{1:t}W_V \]
这意味着历史 token 的 key 和 value 会被反复重复计算。
如果使用 KV cache,那么第 \(t\) 步只需要新增:
\[ k_t = x_t W_K,\quad v_t = x_t W_V \]
历史部分 \(K_{1:t-1},V_{1:t-1}\) 可以直接复用。因此,KV cache 的核心作用是:
- 避免在 decode 阶段重复计算历史 token 的 \(K,V\)
- 把自回归生成从“反复处理整个前缀”变成“只处理当前 token,再读取历史缓存”
复杂度视角
设当前上下文长度为 \(L\)。
不使用 KV Cache
第 \(t\) 步需要重新计算长度为 \(L\) 的整段前缀,因此单层计算复杂度仍接近 prefill:
\[ O(L d^2 + L^2 d) \]
使用 KV Cache
第 \(t\) 步只需要:
- 计算当前 token 的 \(Q,K,V\) 投影,复杂度为 \(O(d^2)\)
- 用当前 query 与历史 \(L\) 个 key/value 交互,复杂度为 \(O(L d)\)
因此单步 decode 复杂度变为:
\[ O(d^2 + L d) \]
这就是 KV cache 在长文本生成中非常重要的原因。
KV Cache 的代价
KV cache 并不是免费得到的。为了避免重复计算,需要额外保存每一层历史 token 的 key 和 value。
单层缓存大小约为:
\[ O(L d) \]
若模型共有 \(N\) 层,则总缓存空间约为:
\[ O(N L d) \]
因此,KV cache 用线性增长的显存换取了更快的 decode 速度。
小结
KV cache 就是在自回归生成时,把历史 token 的 key 和 value 保存下来,后续 token 直接复用这些结果,而不再重复计算整个前缀。它不会改变 prefill 的复杂度,但会显著降低 decode 阶段的重复计算,是大语言模型高效推理的关键机制之一。