![]()
本文的第一作者羅琪竣、第二作者李夢琦為香港中文大學(深圳)計算機科學博士生,本文在上海交通大學趙磊老師、香港中文大學(深圳)李肖老師的指導下完成。
長序列訓練對于模型的長序列推理等能力至關重要。隨著序列長度增加,訓練所需儲存的激活值快速增加,占據訓練的大部分內存。即便使用梯度檢查點(gradient checkpointing)方法,激活值依然占據大量內存,限制訓練所能使用的序列長度。
來自港中文(深圳)和上海交通大學的團隊提出StreamBP算法。通過對鏈式法則進行線性分解和分步計算,StreamBP 將大語言模型訓練所需的激活值內存(logits 和 layer activation)降低至梯度檢查點(gradient checkpointing)的 20% 左右。
![]()
- 論文標題:StreamBP: Memory-Efficient Exact Backpropagation for Long Sequence Training of LLMs
- 論文:https://arxiv.org/abs/2506.03077
- 代碼:https://github.com/Ledzy/StreamBP
在相同內存限制下,StreamBP 最大序列長度為梯度檢查點的 2.8-5.5 倍。在相同序列長度下,StreamBP 的速度和梯度檢查點接近甚至更快。StreamBP 適用于 SFT、GRPO、PPO 和 DPO 等常見 LLM 目標函數。代碼已開源,可集成至現有訓練代碼。
![]()
![]()
![]()
StreamBP 所需儲存的激活值和注意力掩碼(橙色)大幅低于梯度檢查點(橙色 + 白色部分)。
對于 lmhead 層,當以 SFT 或 GRPO 為目標函數時,觀察到不同位置的 logits 對于目標函數的影響相互獨立。因此,StreamBP 從序列維度分塊,每次計算單塊損失函數的梯度,從而只需儲存單塊 logits 和 logits 梯度。
![]()
圖:StreamBP for SFT
![]()
圖:StreamBP for GRPO
對于 DPO,由于非線性 sigmoid 函數的存在,每個位置的 logits 對于目標函數的影響并不獨立。StreamBP 利用 logits 梯度在序列維度的獨立性,分塊進行梯度計算。
![]()
圖:StreamBP for DPO
實驗結果
我們在單張 A800-80GB GPU 上測試了不同大小的模型,StreamBP 的最大 BP 序列長度為標準 BP 的 23-36 倍,梯度檢查點的 2.5-5.5 倍。
![]()
圖:不同序列長度下的 BP 峰值內存
在現有 Transformers 框架下,StreamBP 的實現可避免計算掩碼部分的 pre-attention score(見論文 3.2.2 部分),在長序列訓練下相較于梯度檢查點實現了加速。
![]()
通過使用 StreamBP,不同目標函數下最大的序列長度得到了大幅提升。在同樣的序列長度下,StreamBP 允許更大的批處理大小以加速訓練。
![]()
![]()
表:Qwen 3-4B 單個樣本 BP 時間,序列長度為 9000。
在 Deepspeed ZeRO 分布式訓練模式下,Distributed StreamBP 比梯度檢查點的最大可訓練序列長度提升了5—5.6倍。
![]()
特別聲明:以上內容(如有圖片或視頻亦包括在內)為自媒體平臺“網易號”用戶上傳并發布,本平臺僅提供信息存儲服務。
Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.