Prefill vs. Decode

Prefill 阶段做了什么

prefill = 对整个 prompt 做一次完整的 transformer forward。
包括了:

  • 对所有输入 token 一起做 embedding
  • 过每一层的 attention
  • 过每一层的 FFN
  • 同时把每层对应的 K/V 写入 KV cache
  • 最后根据最后位置的 hidden state 过 lm head,得到 next-token logits
  • 采样 / greedy / top-k / top-p,选出第一个输出 token

这个过程由于是有了所有的输入参数的计算,所以其可以在输入启动阶段就做所有启动处理,结果是 compute-bound.

Decode 阶段做了什么

对每一个新的 token (这里来看类似于一个单词的 向量表示), 来计算他们预测的下一个单词输出,这里应该会经过多个阶段,分别是

  1. linear 阶段,也就是从 token X 处理得到 Q, K, V
  2. attention 阶段,也就是通过 QKV, softmax 等计算得到 attention 得分
  3. FFN 阶段,也就是两个 linear forwarding layer + 一个 非线性层

这些称为一个 iteration, 然后别的一些论文研究了怎么做到 per-iteration batching 也就是在调度上实现精细化调度来支持不同 iteration 之间的独立运算组合 (高效可插拔设计),具体会在 Orca, vLLM 这些论文看到

这个过程如果是 naive 实现的版本是每次只能处理一个 token (因为所有的 token 都从因果关系上依赖于前面的每一个 token), 所以会比较慢,但是理想情况下我们会扩大 batch size 来让多个不同请求同时并发处理 token 预测,这样互不干扰,但是结果就是 memory-bound 因为会受到当前 batch size 的影响以及我们整个 gpu 的 mem capacity 的容量影响

现有的两种调度系统

Prefill-prioritizing Decode-prioritizing
Pros better throughput low TBT, not affect ongoing
Cons high latency, generation stall low throughput
Example Orca, vLLM FasterTransformer
Feature
有请求到达的时候就直接放入 prefill 阶段开始执行,只要 gpu 上面显存还有资源,因此可能会导致类似抢占现在正在推理的 decoding 阶段请求 新的请求必须等上一轮的 batch 中的所有请求都被处理完成了之后再开始执行下一轮的请求

generation stall 现象

用于定量描述推理性能的两个指标:

  • TTFT (Time-to-First-Token) 表示从发送请求到第一个 token 生成的时间
  • TBT (Time-between-Tokens) 表示连续生成的 token 之间的时间

不过也有从 utilization 角度描述性能的参数工具:

  • Model FLOP’s Utilization (对应 CPU bound)
  • Model Bandwidth Utilization (对应 Memory bound)

在上文中提到了 prefill-prioritizing 的问题,就是会倾向于让新来的请求直接下场抢断 decoding phase work, 并且由于我们的 request token 尺寸无法知道,所以我们很难知道这里的计算时长,如果很长的话会导致一段时间内 decode 抢不到活,出现所谓的 generation stall

类似的现象在 Orca 的 pipeline stall/bubbles 中也会提到;

各自阶段的负载对比

![[sarathi_serve_phase_cost.png]]

也就是两个结论:

  1. prefill 阶段受到 batch size 并行影响不大因为本来已经被 saturated 了,增加并行度没有更多并行运算,且 prefill 天生运算效率和运算量都比 decode 大
  2. linear 部分占据了大部分的时间,这里 linear 指的是从 token X 变成 QKV 矩阵的过程,也就是类似 Q:=XWQQ := XW_Q 这样的计算过程; 但是由于 decode 阶段的 gpu 利用率太低了,所以我们并行多个 decode 请求的时候会发现这里的 linear time 都不大

不过现在可能会有一个小问题: 在同一个 gpu 上同时处理不同请求 (不同上下文等) 的计算的时候除了 token 不一样,其他的 QKV 也不同吧?如何高效管理这个问题呢?这就是另一篇论文 Orca 讨论的内容了;

Sarathi-Serve System

核心的两个设计理念是 chunked-prefillstall-free