Transformer 复杂度分析

下面默认讨论单层 Transformer block,记序列长度为 \(L\),hidden size 为 \(d\),head 数为 \(h\),每个 head 的维度为 \(d_h = d / h\)。若模型共有 \(N\) 层,则总复杂度再乘一个 \(N\)

训练计算复杂度

设输入为 \(X \in \mathbb{R}^{L \times d}\)

  • \(Q,K,V\) 投影和输出投影的复杂度都是 \(O(L d^2)\):因为它们都对应一个 \((L \times d)\) 矩阵与一个 \((d \times d)\) 矩阵的乘法,例如 \(Q = XW_Q\),其计算量为 \(O(L d^2)\)
  • 每个 head 计算 \(QK^T\) 的复杂度为 \(O(L^2 d_h)\),所有 head 合起来为 \(O(L^2 d)\)
  • attention 权重与 \(V\) 相乘的复杂度也是 \(O(L^2 d)\)

因此 self-attention 的主要计算复杂度为:

\[ O(L d^2 + L^2 d) \]

若再计入 FFN,则通常还会增加一个 \(O(L d^2)\) 量级的项。

训练内存复杂度

标准 attention 在训练时需要保存 attention score matrix:

\[ QK^T \in \mathbb{R}^{L \times L} \]

因此单层的主要激活内存复杂度为:

\[ O(L^2) \]

更准确地写,多头情况下是 \(O(h L^2)\);通常把 \(h\) 视为常数后简写为 \(O(L^2)\)

推理复杂度

推理通常分为 prefilldecode 两个阶段。

Prefill

对整段 prompt 做一次完整 attention,因此单层计算复杂度仍为:

\[ O(L d^2 + L^2 d) \]

Decode

若使用 KV cache,生成下一个 token 时:

  • 当前 token 的投影复杂度为 \(O(d^2)\)
  • query 与历史 key/value 交互的复杂度为 \(O(L d)\)

因此单步 decode 复杂度为:

\[ O(d^2 + L d) \]

单层 KV cache 的空间复杂度为:

\[ O(L d) \]

若模型共有 \(N\) 层,则总 KV cache 空间复杂度约为:

\[ O(N L d) \]

小结

  • 训练 / prefill 计算复杂度:\(O(L d^2 + L^2 d)\)
  • 标准训练内存复杂度:\(O(L^2)\)
  • 带 KV cache 的单步 decode 复杂度:\(O(d^2 + L d)\)
  • 单层 KV cache 空间复杂度:\(O(L d)\)

其中,FlashAttention 主要优化的是训练时的内存访问与显存占用,但不会改变 self-attention 的渐近计算复杂度 \(O(L^2 d)\)