Transformer 复杂度分析
下面默认讨论单层 Transformer block,记序列长度为 \(L\),hidden size 为 \(d\),head 数为 \(h\),每个 head 的维度为 \(d_h = d / h\)。若模型共有 \(N\) 层,则总复杂度再乘一个 \(N\)。
训练计算复杂度
设输入为 \(X \in \mathbb{R}^{L \times d}\)。
- \(Q,K,V\) 投影和输出投影的复杂度都是 \(O(L d^2)\):因为它们都对应一个 \((L \times d)\) 矩阵与一个 \((d \times d)\) 矩阵的乘法,例如 \(Q = XW_Q\),其计算量为 \(O(L d^2)\)
- 每个 head 计算 \(QK^T\) 的复杂度为 \(O(L^2 d_h)\),所有 head 合起来为 \(O(L^2 d)\)
- attention 权重与 \(V\) 相乘的复杂度也是 \(O(L^2 d)\)
因此 self-attention 的主要计算复杂度为:
\[ O(L d^2 + L^2 d) \]
若再计入 FFN,则通常还会增加一个 \(O(L d^2)\) 量级的项。
训练内存复杂度
标准 attention 在训练时需要保存 attention score matrix:
\[ QK^T \in \mathbb{R}^{L \times L} \]
因此单层的主要激活内存复杂度为:
\[ O(L^2) \]
更准确地写,多头情况下是 \(O(h L^2)\);通常把 \(h\) 视为常数后简写为 \(O(L^2)\)。
推理复杂度
推理通常分为 prefill 和 decode 两个阶段。
Prefill
对整段 prompt 做一次完整 attention,因此单层计算复杂度仍为:
\[ O(L d^2 + L^2 d) \]
Decode
若使用 KV cache,生成下一个 token 时:
- 当前 token 的投影复杂度为 \(O(d^2)\)
- query 与历史 key/value 交互的复杂度为 \(O(L d)\)
因此单步 decode 复杂度为:
\[ O(d^2 + L d) \]
单层 KV cache 的空间复杂度为:
\[ O(L d) \]
若模型共有 \(N\) 层,则总 KV cache 空间复杂度约为:
\[ O(N L d) \]
小结
- 训练 / prefill 计算复杂度:\(O(L d^2 + L^2 d)\)
- 标准训练内存复杂度:\(O(L^2)\)
- 带 KV cache 的单步 decode 复杂度:\(O(d^2 + L d)\)
- 单层 KV cache 空间复杂度:\(O(L d)\)
其中,FlashAttention 主要优化的是训练时的内存访问与显存占用,但不会改变 self-attention 的渐近计算复杂度 \(O(L^2 d)\)。