[LG]《Fast Inference via Hierarchical Speculative Decoding》A Globerson, H Kaplan, Y Mansour, C Mohri... [Google Research & Tel Aviv University] (2025)
层次化推测解码:加速大语言模型推理的新方法
在Transformer语言模型中,推理过程是自回归的,每生成一个token都需要完整的前向传播,导致延迟与输出长度成正比,计算资源利用率低下。传统推测解码(Speculative Decoding)通过小型草稿模型(drafter)预先生成多个token,再由大型目标模型(target model)并行验证,能显著降低延迟,同时保证输出分布与目标模型一致——最坏情况下每轮至少验证1个token,最好情况下验证所有草稿加1个。
然而,单一草稿模型存在权衡:更小的模型更快但准确率低,更大的模型可靠但延迟高。最新研究(如LayerSkip模型的早期退出)虽优化了Pareto前沿,但仍局限于单drafter。本文提出Hierarchical Speculative Decoding (HSD)算法,将多个drafter堆叠成层次结构:仅最底层drafter自回归生成token,上层模型依次验证下层提案,直至目标模型最终验证。这种设计最大化并行验证(利用硬件加载开销),最小化自回归生成成本,并确保添加模型时延迟可进一步降低。
HSD算法核心原理
HSD的核心是递归验证机制(详见Algorithm 1)。给定模型集{M0, M1, ..., MK}(MK为目标模型),每个模型Mi有推理成本ci和成对接受率αi,j(基于总变差距离计算,衡量分布相似度)。参数T = {T0, ..., TK-1}控制每个层缓冲区大小:
- 最底层M0自回归生成T0个token。
- 中间层Mi请求下层生成批次,验证后累积至Ti个已验证token,若不足则递归请求更多。
- 顶层MK验证所有下层提案,使用拒绝采样规则(详见Algorithm 2):对每个draft token xt,若qc(xt) ≤ pc(xt)则接受;否则以pc(xt)/qc(xt)概率接受,或从修正分布p'c采样替换。拒绝率αc = 1 - Σ (pc(x) - qc(x))^2 / 2,确保输出分布精确匹配MK。
正确性证明(Theorem 2.1):通过归纳,HSD输出严格遵循目标模型分布,无论层次如何。
延迟分析与优化
假设接受独立同分布(IID,实证验证合理),HSD每token预期延迟为:
L = Σ_{i=0}^K ci - Π_{j=i}^K R(α_{j-1,j}, j)
其中R函数捕捉每层预期轮次:R(α, n) = γ(α, T_{n-1}, T_n)(中间层预期递归轮次,由经验估计)或T0(底层)或(1-α)/(1-α^{T_{K-1}+1})(顶层预期生成token数)。
关键洞见:多drafter可优于单drafter。例如,表1展示配置中,模型数从1增至6,加速从1×升至3.08×,延迟降至10.61(成本c1~c6递增,α矩阵满足三角不等式)。这启发我们:层次化能放大累积接受率,平衡成本-准确 tradeoff,尤其在早期退出模型中。
优化挑战:2^K种子集组合,指数爆炸。HSD问题定义为min L(σ, T),通过归约至广义最短路径(GSP)问题求解(Theorem 3.7):构建图G(顶点包括(MK)、(Mi,j)参数选择及自环L),边乘子μ和成本c编码接受率与延迟(详见表2)。GSP在O(m n^2 log n)时间内解(Oldham, 2001),HSD总复杂度O(T^4 K^4 log(TK)),远优于穷举。即使T_max=15,K=80,优化仅需数小时CPU。
此优化深度思考:并非所有模型都优(冗余或低α可能有害),GSP自动选最佳子集σ和T;实证显示,预训练早期退出(如LayerSkip)比后训练头(如Gemma2)获益更大,凸显训练策略的重要性。
实证结果
在CNN-DM和XSUM数据集上评估开源模型(NVIDIA H100 GPU,batch=1):
- LayerSkip系列(7B/13B/70B,32/40/80层,预训早期退出):HSD选[7,9,32]/[2,5]等,加速1.76×(vs.单drafter1.62×,vs.自回归1×),秒/token降至0.0102。70B模型加速1.77×,证明大模型中层次化潜力。
- Gemma2-9B(42层,后训LM头):CNN-DM上1.06×,XSUM上1.15×(vs.单drafter1.03×/1.08×)。内存开销线性(每层一头),但整体fit单GPU。
HSD vs. 单drafter加速高达1.2×,假设验证(IID接受、恒定成本)误差