Taming Throughput-Latency Tradeoff in LLM Inference with Sarathi-Serve
Prefill vs. Decode
Prefill 阶段做了什么
prefill = 对整个 prompt 做一次完整的 transformer forward。
包括了:
- 对所有输入 token 一起做 embedding
- 过每一层的 attention
- 过每一层的 FFN
- 同时把每层对应的 K/V 写入 KV cache
- 最后根据最后位置的 hidden state 过 lm head,得到 next-token logits
- 采样 / greedy / top-k / top-p,选出第一个输出 token
这个过程由于是有了所有的输入参数的计算,所以其可以在输入启动阶段就做所有启动处理,结果是 compute-bound.
Decode 阶段做了什么
对每一个新的 token (这里来看类似于一个单词的 向量表示), 来计算他们预测的下一个单词输出,这里应该会经过多个阶段,分别是
- linear 阶段,也就是从 token X 处理得到 Q, K, V
- attention 阶段,也就是通过 QKV, softmax 等计算得到 attention 得分
- FFN 阶段,也就是两个 linear forwarding layer + 一个 非线性层
这些称为一个 iteration, 然后别的一些论文研究了怎么做到 per-iteration batching 也就是在调度上实现精细化调度来支持不同 iteration 之间的独立运算组合 (高效可插拔设计),具体会在 Orca, vLLM 这些论文看到
这个过程如果是 naive 实现的版本是每次只能处理一个 token (因为所有的 token 都从因果关系上依赖于前面的每一个 token), 所以会比较慢,但是理想情况下我们会扩大 batch size 来让多个不同请求同时并发处理 token 预测,这样互不干扰,但是结果就是 memory-bound 因为会受到当前 batch size 的影响以及我们整个 gpu 的 mem capacity 的容量影响
现有的两种调度系统
| Prefill-prioritizing | Decode-prioritizing | |
|---|---|---|
| Pros | better throughput | low TBT, not affect ongoing |
| Cons | high latency, generation stall | low throughput |
| Example | Orca, vLLM | FasterTransformer |
| Feature |
有请求到达的时候就直接放入 prefill 阶段开始执行,只要 gpu 上面显存还有资源,因此可能会导致类似抢占现在正在推理的 decoding 阶段请求 | 新的请求必须等上一轮的 batch 中的所有请求都被处理完成了之后再开始执行下一轮的请求 |
generation stall 现象
用于定量描述推理性能的两个指标:
- TTFT (Time-to-First-Token) 表示从发送请求到第一个 token 生成的时间
- TBT (Time-between-Tokens) 表示连续生成的 token 之间的时间
不过也有从 utilization 角度描述性能的参数工具:
- Model FLOP’s Utilization (对应 CPU bound)
- Model Bandwidth Utilization (对应 Memory bound)
在上文中提到了 prefill-prioritizing 的问题,就是会倾向于让新来的请求直接下场抢断 decoding phase work, 并且由于我们的 request token 尺寸无法知道,所以我们很难知道这里的计算时长,如果很长的话会导致一段时间内 decode 抢不到活,出现所谓的 generation stall
类似的现象在 Orca 的 pipeline stall/bubbles 中也会提到;
各自阶段的负载对比
![[sarathi_serve_phase_cost.png]]
也就是两个结论:
- prefill 阶段受到 batch size 并行影响不大因为本来已经被 saturated 了,增加并行度没有更多并行运算,且 prefill 天生运算效率和运算量都比 decode 大
- linear 部分占据了大部分的时间,这里 linear 指的是从 token X 变成 QKV 矩阵的过程,也就是类似 这样的计算过程; 但是由于 decode 阶段的 gpu 利用率太低了,所以我们并行多个 decode 请求的时候会发现这里的 linear time 都不大
不过现在可能会有一个小问题: 在同一个 gpu 上同时处理不同请求 (不同上下文等) 的计算的时候除了 token 不一样,其他的 QKV 也不同吧?如何高效管理这个问题呢?这就是另一篇论文 Orca 讨论的内容了;
Sarathi-Serve System
核心的两个设计理念是 chunked-prefill 和 stall-free
