Enhancing linear attention with residual learning
利用殘差學習增強線性注意力
https://arxiv.org/pdf/2509.25223v1
![]()
摘要
線性注意力為自注意力機制提供了一種線性時間復雜度的替代方案,但往往難以捕捉長距離模式。我們通過"預測-校正"的視角重新審視線性注意力,發現主流變體都可以被表示為歷史預測與單令牌校正的組合,這造成了表達能力瓶頸。為解決這一瓶頸,我們提出了殘差線性注意力(RLA),這是一個為線性注意力配備顯式殘差擬合機制的框架。RLA 維護一個輔助循環狀態,用于學習隨時間累積殘差誤差并校正基礎預測。我們進一步實例化了一個 delta 規則版本——殘差 Delta 網絡(RDN),結合了自適應門控和殘差裁剪以增強校正控制和穩定性。我們的實現利用了高度優化的線性注意力核函數,并保持線性的時間和內存復雜度。在語言建模和召回密集型評估中,RLA 和 RDN 始終優于各自的基線模型及其他現代線性注意力方法,在保持線性擴展性的同時縮小了與標準 Transformer 的差距。
1 引言
Transformer(Vaswani 等人,2017)架構已成為大型語言模型的標準。然而,其自注意力機制的二次時間復雜度仍然是一個關鍵瓶頸,限制了其在長序列上的應用(Li 等人,2024)。線性注意力最近作為標準自注意力的高效替代方案涌現,直接解決了其過高的二次復雜度問題。通過將注意力計算重構為循環過程,這些模型實現了線性時間的訓練和推理,使其非常適合處理長序列。RetNet(Sun 等人,2023)和 Mamba(Gu & Dao,2023;Dao & Gu,2024)等架構已展現出具有競爭力的性能。GLA(Yang 等人,2023)和 DeltaNet(Yang 等人,2024b)等方法通過引入數據依賴的門控和狀態更新規則來管理單一狀態矩陣內的信息流,進一步改進了性能。
現代線性注意力方法可以被統一為學習從鍵到值的直接映射(Sun 等人,2024),這一過程類似于測試時訓練。例如,delta 更新規則(Schlag 等人,2021)可以從二次損失目標的單步在線梯度下降推導得出。這一視角開辟了若干改進途徑,包括探索不同的在線學習損失函數以推導新的更新規則(Schlag 等人,2021;Yang 等人,2024b)、采用更復雜的映射函數,或修改在線梯度更新機制(von Oswald 等人,2025;Siems 等人,2025)。例如,TTT-MLP(Sun 等人,2024)和 Titans(Behrouz 等人,2024)等近期工作利用多層感知機(MLP)作為深度記憶模塊來實現更強大的映射。然而,這種方法犧牲了模型的線性循環特性,從而使并行訓練變得復雜。
基于這一視角,我們對注意力輸出提供了一種新的解釋。我們證明,主流線性注意力模型的輸出可以分解為來自歷史狀態的基礎分量和僅源自當前令牌的校正項(見第 2.3 節)。依賴單一令牌來執行這種系統性校正造成了瓶頸,損害了模型的表達能力。為解決這些問題,我們引入了殘差線性注意力,這是一個用顯式殘差擬合機制增強線性注意力模型的框架。我們的方法不依賴單一令牌進行校正,而是采用輔助狀態矩陣來顯式建模和校正基礎線性注意力的系統性預測誤差。最終輸出是基礎預測與該學習誤差校正的組合。我們的方法可以推廣為適用于各種線性注意力方法的統一框架,為構建更強大的序列模型提供了一種強大而高效的策略。
在現有線性注意力方法的基礎上,我們提出了兩種增強殘差擬合的變體:殘差線性注意力(RLA)和殘差 Delta 網絡(RDN)。我們在一系列基準測試上評估了它們,包括語言建模和召回密集型任務。我們的結果表明,這些模型優于各自的基線模型和其他現代線性注意力方法,而我們的消融分析證實了框架內每個關鍵設計選擇的重要性。
2 預備知識
2.1 作為循環模型的線性注意力
Softmax 注意力機制相對于序列長度表現出二次計算復雜度,在處理長序列時構成了顯著的瓶頸。線性注意力(Katharopoulos 等人,2020)架構通過移除 softmax 函數來解決這一問題,從而允許對計算進行重新排序。
![]()
![]()
這種循環形式在推理過程中保持每步恒定的時間和內存復雜度,并通過分塊并行算法實現高效訓練(Yang 等人,2023)。此外,門控機制的使用催生了更多變體的發展,如 RetNet(Sun 等人,2023)、Lightning Attention(Qin 等人,2024a)和 Mamba-2(Dao & Gu,2024)。
2.2 在線學習視角
![]()
![]()
這種形式化使 Delta Net(Yang 等人,2024b;Schlag 等人,2021)等模型能夠實現細粒度的記憶控制。門控 Delta Net(Yang 等人,2024a)進一步通過在學習過程中引入權重衰減來增強這一方法。
2.3 分解為預測與校正
![]()
![]()
![]()
![]()
基于預測-校正的視角,我們引入了一個殘差擬合框架來增強線性注意力。我們的框架通過顯式擬合超出當前令牌的上下文信息,學習一個更具表達力的校正項。
3 方法
本節介紹我們提出的方法,該方法通過殘差擬合過程來增強線性注意力。我們首先描述支撐我們方法的基礎殘差學習框架。接下來,我們引入自適應校正因子以增強建模能力,并引入裁剪方法來穩定殘差擬合過程。最后,我們展示我們方法的兩種最終變體。
3.1 顯式殘差擬合
![]()
![]()
利用第 2 節中線性注意力的在線學習視角,我們對輔助狀態應用類似的更新規則。這產生了以下循環過程:
![]()
![]()
3.2 自適應門控與校正因子
![]()
![]()
![]()
![]()
這種形式化使用衰減因子和校正因子來分別對來自基礎狀態和輔助狀態的檢索進行動態門控。
3.3 歸一化與殘差裁剪
為確保計算穩定性,我們引入兩種機制。首先,我們對查詢和鍵向量應用 L2 歸一化以提高數值穩定性。其次,我們通過裁剪殘差來解決輔助狀態中的潛在不穩定性:
![]()
這確保了誤差校正狀態保持穩定的學習軌跡,即使基礎模型產生瞬態的、較大的預測誤差。該裁剪方法的詳細推導見附錄 B。
3.4 最終形式化
殘差擬合原理是一種通用技術,可以與各種線性注意力主干網絡集成。通過將我們的殘差機制應用于標準加法更新規則和 delta 更新規則,我們推導出兩種強大的變體。這導出了我們的最終模型:
![]()
![]()
![]()
4 實驗
4.1 實驗設置
實現 為了最大化效率,我們在 Triton(Tillet 等人,2019)中實現了自定義注意力核函數,基于 flash-linear-attention 庫(Yang & Zhang,2024)構建。我們利用了這樣一個事實:我們的狀態更新規則與線性注意力的相同,只需對其核函數進行微小修改:我們將其增強為返回注意力結果和中間殘差。這種設計允許在所有殘差擬合階段重用相同的高度優化核函數,確保高吞吐量。
![]()
4.2 主要結果
核函數效率 我們將我們的核函數運行時間與線性注意力基線和 FlashAttention(Dao 等人,2022;Dao,2023)進行基準測試,如圖 2 所示。盡管殘差擬合過程增加了計算開銷,但我們方法的運行時間隨序列長度線性擴展。這使其在較長序列上顯著快于二次擴展的 FlashAttention。關于吞吐量,我們的方法與其他線性注意力機制一樣,保持幾乎恒定的高吞吐量。相反,計算受限的 FlashAttention 的吞吐量隨序列長度增加而迅速下降。
![]()
語言建模與常識推理 我們在 WikiText(Merity 等人,2016)困惑度以及一系列評估推理和常識理解的基準測試上評估 RLA 和 RDN。推理任務包括 ARC-Easy、ARC-Challenge(Clark 等人,2018)、PIQA(Bisk 等人,2020)和 MMLU(Hendrycks 等人,2020),而常識理解則在 HellaSwag(Zellers 等人,2019)、Winogrande(Sakaguchi 等人,2021)、SocialIQA(Sap 等人,2019)和 LAMBADA(Paperno 等人,2016)上進行評估。我們的主要結果總結于表 2,顯示我們提出的殘差學習變體 RLA 和 RDN 在困惑度上相對于各自的基線 sGLA 和 GDN 取得了一致的改進。此外,我們的模型在多個基準測試上優于其他領先的線性注意力方法,并提供與標準 Transformer 相當的性能。
![]()
召回密集型任務 為了評估記憶容量,我們在 Arora 等人(2024)的召回密集型任務上對我們的模型進行基準測試。此外,我們還直接使用"大海撈針"任務(NIAH)(gkamradt,2023)評估模型的檢索能力,該任務需要檢索插入在長文檔不同深度的鍵值對。這些基準測試對線性注意力模型具有挑戰性,因為它們的有限狀態空間造成了信息瓶頸,如表 3 所示。結果表明,我們提出的 RLA 和 RDN 始終優于其相應的基線,在 DROP 和 FDA 基準測試上取得了特別顯著的收益。此外,它們在 NIAH 任務上大幅優于其他模型,突顯了增強的信息召回能力。
4.3 消融研究
在本節中,我們進行一系列消融研究以驗證關鍵組件的貢獻。我們首先量化我們學習的殘差擬合方法相對于預定義校正的優勢。接下來,我們研究使用專用校正因子的重要性,然后分析將基礎預測與校正相結合的門控機制的必要性。最后,我們檢查歸一化和殘差裁剪的效果。
![]()
如表 4 所示,缺乏顯式殘差擬合的變體表現不如我們的完整方法。盡管該消融變體在某些基準測試上保持競爭力,但它在訓練集和評估集上的困惑度都顯著增加。這種性能下降延伸到專業領域,在 GSM8k(Cobbe 等人,2021)和 HumanEval(Chen 等人,2021)的困惑度測量中,其數學和代碼能力顯著退化。這證明了輔助狀態在累積過去殘差以有效細化模型輸出方面的關鍵作用。
![]()
專用校正因子 我們通過將我們的完整模型與 γ 綁定到更新因子 β 的變體進行比較,分析使用專用校正因子 γ 的益處。在圖 3a 中,具有獨立 γ 的模型始終實現更低的評估損失,其中 RDN 變體顯示出更大的改進。這一趨勢延伸到下游性能,如圖 3b 的結果所示,該圖還顯示專用校正因子在多個基準測試上帶來性能提升。值得注意的是,我們的基礎架構(不需要額外的 γ)仍然比基線線性注意力方法有顯著改進。
![]()
![]()
![]()
歸一化與殘差裁剪 最后,我們研究歸一化和殘差裁剪的重要性。我們通過對 RLA 移除歸一化和裁剪來進行消融研究。如圖 4 所示,兩個組件對穩定訓練都至關重要;移除它們會導致無界激活和性能退化。相比之下,RDN 模型對殘差裁剪很大程度上不敏感。這種魯棒性歸因于其 delta 規則更新的固有穩定性,即使沒有殘差裁剪也能保持一致的損失曲線(圖 4b)。
![]()
5 相關工作
序列建模歷史上由循環神經網絡(RNN)(Lipton 等人,2015)主導,包括長短期記憶網絡(LSTM)(Hochreiter & Schmidhuber,1997)和門控循環單元(GRU)(Cho 等人,2014)等變體。雖然有效,但其固有的順序性質阻礙了訓練并行化。Transformer 架構(Vaswani 等人,2017)克服了這一限制,成為序列建模的事實標準。然而,其自注意力機制具有相對于序列長度的二次計算復雜度,對長上下文應用構成了顯著瓶頸。
為解決這些挑戰,近期研究重新審視了線性 RNN,將其作為高效 Transformer 替代方案的基礎。通過將序列處理形式化為線性循環,這些模型實現了可并行化訓練和線性時間推理。該領域的早期探索,如 S4(Gu 等人,2021)、LRU(Orvieto 等人,2023)和 RetNet(Sun 等人,2023),利用了結構化狀態轉移矩陣。通過引入數據依賴的動態特性,后續實現了性能飛躍。Mamba(Gu & Dao,2023;Dao & Gu,2024)、HGRN(Qin 等人,2023;2024b)和門控線性注意力(Yang 等人,2023)等模型利用輸入依賴的門控來動態控制狀態轉移,從而增強其表達能力。
更先進的方法引入了 delta 學習規則,將狀態更新從簡單的門控衰減重新框架為細粒度的記憶校正。這種方法以 DeltaNet(Yang 等人,2024b;Schlag 等人,2021)和門控 DeltaNet(Yang 等人,2024a)為代表,實現了更精確的動態記憶修改。該機制可以從在線學習視角理解,其中狀態更新被框架為優化過程,如 TTT(Sun 等人,2024)所探索的。這一觀點啟發了進一步的工作,旨在發現和改進序列模型內的內在學習算法(von Oswald 等人,2023;2025)。
同期研究聚焦于增加狀態轉移的表達能力。例如,RWKV-7(Peng 等人,2025)采用對角加低秩結構,而 DeltaProduct(Siems 等人,2025)通過每令牌執行多步更新來推廣 DeltaNet。為進一步提升容量,近期架構如 Titans(Behrouz 等人,2024)和 Miras(Behrouz 等人,2025)引入了非線性深度記憶,用 MLP 對狀態進行參數化。
6 結論
在本文中,我們介紹了殘差線性注意力,這是一個通過顯式殘差擬合過程來增強線性注意力模型的框架。我們的方法利用輔助狀態來校正基礎模型的預測誤差,從而構建更魯棒和準確的上下文表示。該框架具有高度適應性,可應用于各種線性注意力方法。我們的實驗證明了這種多功能性,顯示我們的方法始終優于各自的基線。雖然這種改進以擬合過程的額外計算為代價,但平衡這一權衡為未來的研究提供了一個有前景的方向。
原文鏈接:https://arxiv.org/pdf/2509.25223v1
特別聲明:以上內容(如有圖片或視頻亦包括在內)為自媒體平臺“網易號”用戶上傳并發布,本平臺僅提供信息存儲服務。
Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.