本文的第一作者羅琪竣、第二作者李夢琦為香港中文大學(xué)(深圳)計算機(jī)科學(xué)博士生,本文在上海交通大學(xué)趙磊老師、香港中文大學(xué)(深圳)李肖老師的指導(dǎo)下完成。
長序列訓(xùn)練對于模型的長序列推理等能力至關(guān)重要。隨著序列長度增加,訓(xùn)練所需儲存的激活值快速增加,占據(jù)訓(xùn)練的大部分內(nèi)存。即便使用梯度檢查點(gradient checkpointing)方法,激活值依然占據(jù)大量內(nèi)存,限制訓(xùn)練所能使用的序列長度。
來自港中文(深圳)和上海交通大學(xué)的團(tuán)隊提出StreamBP算法。通過對鏈?zhǔn)椒▌t進(jìn)行線性分解和分步計算,StreamBP 將大語言模型訓(xùn)練所需的激活值內(nèi)存(logits 和 layer activation)降低至梯度檢查點(gradient checkpointing)的 20% 左右。
- 論文標(biāo)題:StreamBP: Memory-Efficient Exact Backpropagation for Long Sequence Training of LLMs
- 論文:https://arxiv.org/abs/2506.03077
- 代碼:https://github.com/Ledzy/StreamBP
在相同內(nèi)存限制下,StreamBP 最大序列長度為梯度檢查點的 2.8-5.5 倍。在相同序列長度下,StreamBP 的速度和梯度檢查點接近甚至更快。StreamBP 適用于 SFT、GRPO、PPO 和 DPO 等常見 LLM 目標(biāo)函數(shù)。代碼已開源,可集成至現(xiàn)有訓(xùn)練代碼。
StreamBP 所需儲存的激活值和注意力掩碼(橙色)大幅低于梯度檢查點(橙色 + 白色部分)。
對于 lmhead 層,當(dāng)以 SFT 或 GRPO 為目標(biāo)函數(shù)時,觀察到不同位置的 logits 對于目標(biāo)函數(shù)的影響相互獨立。因此,StreamBP 從序列維度分塊,每次計算單塊損失函數(shù)的梯度,從而只需儲存單塊 logits 和 logits 梯度。
圖:StreamBP for SFT
圖:StreamBP for GRPO
對于 DPO,由于非線性 sigmoid 函數(shù)的存在,每個位置的 logits 對于目標(biāo)函數(shù)的影響并不獨立。StreamBP 利用 logits 梯度在序列維度的獨立性,分塊進(jìn)行梯度計算。
圖:StreamBP for DPO
實驗結(jié)果
我們在單張 A800-80GB GPU 上測試了不同大小的模型,StreamBP 的最大 BP 序列長度為標(biāo)準(zhǔn) BP 的 23-36 倍,梯度檢查點的 2.5-5.5 倍。
圖:不同序列長度下的 BP 峰值內(nèi)存
在現(xiàn)有 Transformers 框架下,StreamBP 的實現(xiàn)可避免計算掩碼部分的 pre-attention score(見論文 3.2.2 部分),在長序列訓(xùn)練下相較于梯度檢查點實現(xiàn)了加速。
通過使用 StreamBP,不同目標(biāo)函數(shù)下最大的序列長度得到了大幅提升。在同樣的序列長度下,StreamBP 允許更大的批處理大小以加速訓(xùn)練。
表:Qwen 3-4B 單個樣本 BP 時間,序列長度為 9000。
在 Deepspeed ZeRO 分布式訓(xùn)練模式下,Distributed StreamBP 比梯度檢查點的最大可訓(xùn)練序列長度提升了5—5.6倍。
特別聲明:以上內(nèi)容(如有圖片或視頻亦包括在內(nèi))為自媒體平臺“網(wǎng)易號”用戶上傳并發(fā)布,本平臺僅提供信息存儲服務(wù)。
Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.