大家好,我是Ai學習的老章
推薦一篇迎合文章,From:Horace He 與 Thinking Machines 的其他成員
![]()
可重復性是科學進步的基石。然而,要從大型語言模型中獲得可重復的結果卻出奇地困難。
例如,你可能會發現多次向 ChatGPT 提出同一個問題會得到不同的結果。這本身并不奇怪,因為從語言模型獲取結果涉及“采樣”,這是一個將語言模型的輸出轉換為概率分布并依概率選擇 token 的過程。
更令人驚訝的可能是,即使我們將溫度調到 0這意味著 LLM 總是選擇概率最高的 token,這被稱為貪婪采樣。(理論上使采樣變得確定),LLM API 在實際中仍然不具備確定性。即便在你自己的硬件上使用 vLLM 或 SGLang 這類 OSS 推理庫運行推理,采樣依舊不是確定性的。
但為什么 LLM 推理引擎不具備確定性?一種常見假設是,浮點非結合性與并發執行共同作用,導致非確定性,具體取決于哪個并發核心先完成。我們將此稱為 LLM 推理非確定性的“并發 + 浮點”假設。例如,最近的一篇 arXiv 預印本寫道:
GPU 中的浮點運算表現出非結合性,意味著 (a+b)+c≠a+(b+c)(a+b)+c=a+(b+c) 由于有限精度和舍入誤差。這一特性直接影響 Transformer 架構中注意力分數和 logits 的計算,其中跨多個線程的并行操作可能因執行順序不同而產生不同結果。
你也可以在其他地方看到“并發 + 浮點”這一假設被反復提及,比如這里(“存在速度權衡,為了讓端點足夠快,我們使用 GPU,而 GPU 會進行并行 [非確定性] 計算。任何現代 GPU 上的神經網絡計算都會受到這些影響。”),或者這里(“由于 GPU 高度并行化,每次執行時加法或乘法的順序可能不同,這會導致輸出出現微小差異。”)。
雖然這一假設并非完全錯誤,但它并沒有揭示全貌。例如,即使在 GPU 上,對同一數據反復執行相同的矩陣乘法,結果也會始終位級相等。我們確實在使用浮點數,我們的 GPU 也確實擁有大量并發。為什么在這個測試中我們看不到非確定性?
A = torch.randn(2048, 2048, device='cuda', dtype=torch.bfloat16) B = torch.randn(2048, 2048, device='cuda', dtype=torch.bfloat16) ref = torch.mm(A, B) for _ in range(1000): assert (torch.mm(A, B) - ref).abs().max().item() == 0要理解 LLM 推理非確定性的真正原因,我們必須看得更深。
不幸的是,就連定義“LLM 推理是確定性的”到底意味著什么都很困難。可能令人困惑的是,以下所有說法同時成立:
某些 GPU 上的 kernel 是非確定性的。
然而,語言模型前向傳播中使用的所有 kernel 都是確定性的。
此外,LLM 推理服務器(如 vLLM)的前向傳播也可以被認為是確定性的。
然而,從任何使用推理服務器的人的角度來看,結果都是非確定性的。
在這篇文章中,我們將解釋“并發 + 浮點”假設為何偏離靶心,揭開 LLM 推理非確定性的真正元兇,并說明如何擊敗非確定性,在 LLM 推理中獲得真正可復現的結果。
原罪:浮點運算的非結合性
在討論非確定性之前,先解釋為何會出現數值差異是有益的。畢竟,我們通常將機器學習模型視為遵循交換律或結合律等結構規則的數學函數。難道不應該存在一個“數學上正確”的結果,由我們的機器學習庫提供給我們嗎?
罪魁禍首是浮點數的非結合性。也就是說,對于浮點數:
(a+b)+c≠a+(b+c)(a+b)+c=a+(b+c)
(0.1 + 1e20) - 1e20 >>> 0 0.1 + (1e20 - 1e20) >>> 0.1諷刺的是,正是打破結合律才讓浮點數變得有用。
浮點數之所以有用,是因為它們允許“動態”的精度水平。為了便于解釋,我們將使用十進制(而非二進制),其中浮點數的格式為 mantissa?10exponentmantissa?10exponent 。我們還將使用 3 位尾數和 1 位指數。
例如,對于值 3450,我們可以精確地表示為 3.45?1033.45?103 。我們也可以表示更小的值,如 0.486,表示為 4.86?10?14.86?10?1 。通過這種方式,浮點數讓我們既能表示非常小的值,也能表示非常大的值。在科學領域,我們可能會說浮點數讓我們保持恒定的“有效數字”位數。
如果將兩個具有相同指數的浮點數相加,看起來類似于整數加法。例如,123( 1.23?1021.23?102 )+ 456( 4.56?1024.56?102 )結果為 579( 5.79?1025.79?102 )。
但當我們將兩個指數不同的浮點數相加時,比如 1230 和 23.4,會發生什么?此時精確結果為 1253.4。然而,我們一次只能保留 3 位有效數字。因此,浮點加法會舍去最后兩位,得到 1.25?1031.25?103 (即 1250)。
![]()
我們需要 3 位有效數字來表示 1230,也需要 3 位有效數字來表示 23.4。然而,將這兩個數相加后,結果需要 5 位有效數字才能精確表示(1253.4)。我們的浮點格式只能把末尾的 34 截掉。某種意義上,我們實際上先把原來的 23.4 四舍五入成了 20.0,然后再相加。
但此時,我們已經破壞了信息。請注意,每當我們把兩個“尺度”不同(即指數不同)的浮點數相加時,都可能發生這種情況。而指數不同的浮點數相加在實際中非常常見。事實上,如果我們能保證永遠不需要不同的指數,那干脆用整數就好了!
換句話說,每當我們以不同順序相加浮點數時,就可能得到完全不同的結果。舉個極端的例子,僅因求和順序不同,這個數組就可能產生 102 種不同的結果。
import random vals = [1e-10, 1e-5, 1e-2, 1] vals = vals + [-v for v in vals] results = [] random.seed(42) for _ in range(10000): random.shuffle(vals) results.append(sum(vals)) results = sorted(set(results)) print(f"There are {len(results)} unique results: {results}") # Output: # There are 102 unique results: [-8.326672684688674e-17, -7.45931094670027e-17, ..., 8.326672684688674e-17]盡管這是輸出不一致的根本原因,但它并未直接解釋非確定性的來源。它無法幫助我們理解浮點值為何會以不同順序相加、何時發生這種情況,以及如何避免。
答案在于內核是如何實現的。
為什么內核不總是以相同順序相加數字?
如上所述,對于內核為何以不同順序相加數字,一種常見解釋是“并發 + 浮點”假說。該假說指出,如果并發線程完成的順序是非確定性的,并且累加順序依賴于并發線程完成的順序(例如使用原子加法),那么我們的累加順序也將是非確定性的。
令人困惑的是,盡管這可能導致非確定性內核,但在 LLM 推理的非確定性中,并發(以及原子加法)最終卻完全無關!為了解釋真正的罪魁禍首是什么,我們首先來理解為什么現代 GPU 內核很少需要原子加法。
什么時候才需要原子加法?
通常,GPU 會在許多“核心”(即 SM)上并發地啟動一個程序。由于這些核心之間沒有內在的同步機制,當它們需要相互通信時就會帶來挑戰。例如,如果所有核心都必須累加到同一個元素,你可以使用“原子加”(有時稱為“fetch-and-add”)。原子加是“非確定性的”——結果累加的順序完全取決于哪個核心最先完成。
具體地說,假設你正在用 100 個核心歸約一個 100 元素的向量(例如torch.sum())。雖然你可以并行加載全部 100 個元素,但最終必須歸約到單個元素。實現這一點的一種方法是使用某種“原子加”原語,硬件保證所有加法都會被處理,但不保證順序。
![]()
原子加確保每個核心的貢獻都會體現在最終和中。然而,它并不保證這些貢獻將以何種順序被累加。順序完全取決于哪個核心最先完成,這是一種非確定性屬性。因此,多次執行同一個并行程序可能會產生非確定性的輸出。
這通常就是人們所說的“非確定性”——你用完全相同的輸入兩次執行同一個內核,卻得到了不同的結果。這被稱為“運行間非確定性”,即你用完全相同的依賴項兩次運行同一個 Python 腳本,卻得到了不同的結果。
盡管并發原子加法確實會讓內核變得非確定性,但絕大多數內核并不需要原子加法。事實上,在 LLM 的典型前向傳播中,通常連一個原子加法都不會出現。
考慮到并行化歸約操作可以從原子加法中受益,這一點可能會令人驚訝。原子加法最終不被需要主要有兩個原因。
通常沿著“批次”維度已經有足夠的并行度,因此我們無需沿著歸約維度進行并行化。例如,假設我們不是歸約單個 100 維向量,而是并行歸約 500 個向量。在這種情況下,我們可以在每個核心上歸約一個完整的向量,并讓每個核心處理不同的向量。
隨著時間的推移,大多數神經網絡庫都采用了多種策略,在不影響性能的前提下實現確定性。例如,我們可以進行“拆分”(或樹形)歸約,將 100 個元素的歸約拆分為五個 20 個元素的歸約(從而實現五路并行)。然后,為了合并剩下的五個元素,我們可以執行一次單獨的“清理”歸約(這部分不再并行,但元素數量很少,開銷極低),或者使用信號量(確保每個并發線程塊按確定順序累加)。信號量策略的描述可在此處找到。
由于這兩個因素,在絕大多數神經網絡運算中,避免原子加操作帶來的性能損失可以忽略不計。
仍有少數常見運算在避免原子加時會帶來顯著的性能損失。例如,PyTorch 中的scatter_add(a[b] += c)。然而,在 LLMs 中唯一常用的就是 FlashAttention 的反向傳播。有趣的事實:你知道嗎?廣泛使用的 Triton 版 FlashAttention 反向實現,在算法上與 Tri Dao 的 FlashAttention-2 論文并不相同?標準的 Triton 實現會在反向傳播中額外重新計算,從而避免原子操作,但代價是 FLOPs 增加 40%!
然而,LLM 的前向傳播中沒有任何需要原子加法的操作。因此,LLM 的前向傳播實際上是“運行到運行確定性的”。
![]()
從推理服務器的角度來看,它是確定性的。給定完全相同的用戶請求,它總會給出相同的確定性輸出。
維基百科寫道:“確定性算法是指給定特定輸入時,總會產生相同輸出的算法。”而在這個場景下,給定完全相同的輸入(即推理服務器正在處理的完全相同的請求),前向傳播總會產生完全相同的輸出。
然而,前向傳播本身“確定”并不足以保證包含它的整個系統也是確定的。例如,如果我們的請求輸出依賴于并行的用戶請求(例如 batch-norm)呢?由于每個單獨請求都無法預知并行請求會是什么,從它們的角度看,我們的整體 LLM 推理也是非確定的!
事實證明,我們的請求輸出確實依賴于并行的用戶請求。并不是因為我們以某種方式在批次之間泄露信息——而是我們的前向傳播缺乏“批次不變性”,導致我們的請求輸出依賴于前向傳播的批次大小。
批次不變性與“確定性”
為了解釋 batch invariance,讓我們簡化系統,只看 matmul。你可以假設所有 matmul 實現都是“運行間確定”的。這并不完全正確,但大多數常見的 matmul 實現確實具有這一特性。然而,它們并不是“batch 不變”的。換句話說,當 batch size 改變時,batch 中的每個元素都可能得到不同的結果。
從數學角度來看,這是一個相當不尋常的特性。矩陣乘法在 batch 的每個元素上應該是“獨立”的——batch 中的其他元素或 batch 的大小都不應影響 batch 中某個特定元素的計算結果。
然而,正如我們憑經驗觀察到的那樣,事實并非如此。
import torch torch.set_default_device('cuda') B = 2048 D = 4096 a = torch.linspace(-1000, 1000, B*D).reshape(B, D) b = torch.linspace(-1000, 1000, D*D).reshape(D, D) # Doing a matrix vector multiplication by taking # the first element of the batch out1 = torch.mm(a[:1], b) # Doing a matrix matrix multiplication and then taking # the first element of the batch out2 = torch.mm(a, b)[:1] print((out1 - out2).abs().max()) # tensor(1669.2500, device='cuda:0')請注意,這是“運行間確定性”。如果你多次運行該腳本,它會確定性地返回相同的結果。它并非“硬件/軟件版本不變”——你的 GPU/PyTorch 版本可能會返回不同的值,但它應該確定性地返回相同的值。
然而,當一個非批次不變的內核被用作更大推理系統的一部分時,系統就可能變得非確定性。當你向推理端點發出查詢時,服務器當前的負載量從用戶角度來看實際上是“非確定性”的。負載決定了內核運行的批次大小,從而改變了每個單獨請求的最終結果!
![]()
盡管推理服務器本身可以被認為是“確定性的”,但對單個用戶而言情況卻不同。從單個用戶的角度來看,其他并發用戶并不是系統的“輸入”,而是系統的一種非確定性屬性。這使得 LLM 推理在每個用戶看來都是“非確定性”的。
如果你將某個內核不具備不變性的屬性(例如 batch-size)與該屬性的非確定性(例如服務器當前負載)組合在一起,就會得到一個非確定性系統。
換句話說,幾乎所有 LLM 推理端點之所以非確定,根本原因就是負載(進而導致 batch-size)本身在不可預測地變化!這種非確定性并非 GPU 獨有——無論是 CPU 還是 TPU 提供的 LLM 推理端點,同樣會受這一非確定性來源的影響。
因此,若想在我們的推理服務器中避免非確定性,就必須在內核層面實現 batch 不變性。為了弄清如何做到這一點,我們先來看看為什么內核一開始就不具備 batch 不變性。
我們如何讓內核具備 batch 不變性?
為了讓 transformer 實現對 batch 不敏感,我們必須讓每個 kernel 都對 batch 不敏感。幸運的是,我們可以假設所有逐點運算都是對 batch 不敏感的。盡管對于 PyTorch 中的所有 kernel 來說確實如此,但這并非必然成立。例如,CPU 上的一些 kernel 實現會在數組的某些部分使用向量化 intrinsic,而在其他部分使用非向量化 intrinsic,而這些 intrinsic 的數值結果并不總是逐位一致。因此,我們只需關注涉及歸約的 3 種操作——RMSNorm、矩陣乘法和注意力。與并行相關的歸約不在本文討論范圍內,但同樣的原則適用。一個可能有用的信息是:在 Blackwell 以及使用 CUDA 12.8+ 的 Hopper 上,NVLink-Sharp 的 in-switch 歸約是確定性的。和許多事情一樣,這些信息可以在 NCCL 的 GitHub issues 中找到。
方便的是,這些也按難度遞增的順序排列。每一項都需要額外考慮,才能在合理性能下實現批次不變性。我們先從 RMSNorm 說起。
批次不變的 RMSNorm![]()
數據并行的 RMSNorm 理想情況下,我們希望并行策略中核心之間無需通信。一種實現方法是將每個批次元素分配給單獨的核心,從而保證所有歸約操作完全在一個核心內完成。這就是所謂的“數據并行”策略,因為我們只是沿著無需通信的維度進行并行。在此示例中,我們有四行和四個核心,正好占滿所有核心。
RMSNorm 的實現如下:
# x: [batch_size, hidden_dim] # weight: [hidden_dim] def rms_norm(x, weight): return x * torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True)) * weight批不變性的要求是:無論內核的批大小如何,每個元素的歸約順序都必須固定。請注意,這并不意味著我們必須始終使用相同的歸約策略。例如,如果我們改變要歸約的元素數量,即使歸約策略發生變化,我們仍然可以保持批不變性。Quack 博客文章中有一些很好的示例,展示了可以使用的各種歸約策略的層次結構(例如線程歸約、warp 歸約、block 歸約、cluster 歸約)。
因此,只有當我們的批大小影響歸約策略時,我們才會破壞批不變性。
讓我們來看看 RMSNorm 的標準并行策略。一般而言,并行算法通過最小化跨核心的通信來獲得收益。為了本次討論的目的,你可以假設當我們提到“cores”時,我們指的是 SMs。更具體地說,這里重要的屬性是:我們內核啟動的 threadblock 數量大于 SMs 的數量。因此,我們可以從一個簡單的策略開始:將每個批次元素分配給單個核心,如上圖所示。
增大批次大小不會影響我們的歸約策略;如果批次大小為 200 就能為內核提供足夠的并行度,那么批次大小為 2000 時肯定也能提供足夠的并行度。
![]()
更大批次的數據并行 RMSNorm 將數據并行策略擴展到更大批次非常簡單——不再讓每個核心處理一行,而是讓每個核心按順序處理不同的行。這保持了批次不變性,因為每個批次元素的歸約策略保持不變。
另一方面,減小批大小也會帶來挑戰。由于我們將每個批元素分配給一個核心,當批大小減小時,最終會出現核心數量多于批元素的情況,導致部分核心閑置。
遇到這種情況,優秀的內核工程師會采用上一節提到的解決方案(原子加法或拆分歸約),以保持高并行度,從而維持良好性能。然而,這會改變歸約策略,導致該內核不再具備批不變性。
![]()
拆分歸約的 RMSNorm 如果批大小較小,我們的數據并行策略可能無法提供足夠的并行度來充分利用所有核心。此時,將歸約操作“拆分”到多個核心上執行可能更高效,從而充分利用 GPU。但這會失去批不變性,因為我們不再以相同順序歸約每個元素。
最簡單的解決方案是直接忽略這些情況。這并非完全不合理——小批大小意味著內核本身執行速度較快,因此性能下降可能不會造成災難性后果。
如果我們被迫優化這一用例,一種方法是始終采用一種即便在極小批量下也具備足夠并行度的歸約策略。這種策略在較大批量時會產生過量并行,但能在整個尺寸范圍內都獲得尚可(而非峰值)的性能。
批不變矩陣乘法![]()
數據并行 Matmul 與 RMSNorm 類似,matmul 的標準并行策略是“數據并行”,將整個規約操作保留在一個核心內。最直觀的做法是把輸出張量拆分成 2D 瓦片,并將每塊瓦片分配給不同的核心。每個核心隨后計算屬于該瓦片的點積,再次在單個核心內完成全部規約。
與 RMSNorm 不同的是,圍繞算術強度以及充分利用 Tensor Core 的額外約束,迫使我們在高效 matmul 內核中拆分 2D 瓦片,而不是單個輸出元素。
本質上,你可以把矩陣乘法看作一個逐點操作后再進行規約。于是,如果我們通過將輸出分塊來并行化矩陣乘法,就得到了一種類似的“數據并行”內核策略,使每次規約都保留在單個核心內。
與 RMSNorm 類似,我們的“批”維度(M 和 N)也可能變得過小,從而被迫沿著歸約維度(K)切分。盡管有兩個“批”維度,矩陣乘法仍需要每個核心承擔更多“工作量”,才能有效利用 Tensor Core。例如,對于 [1024, K] × [K, 1024] 的矩陣乘法,若采用標準的 2D 瓦片大小 [128, 128],數據并行策略只能將其拆分到 64 個核心,不足以讓 GPU 飽和。
在矩陣乘法中沿著歸約維度切分被稱為 Split-K Matmul。與 RMSNorm 一樣,這種策略會破壞批不變性。
另一種有趣的矩陣乘法并行策略是 stream-k。stream-k 的有趣之處在于,它比典型的矩陣乘法具有更少的“不變性”。如前所述,大多數矩陣乘法庫并非 batch-invariant,但至少可以稱為 batch-position-invariant(即改變 batch 中元素的位置不會影響數值結果)。然而,stream-k 連 batch-position-invariant 都不是!其核心洞見是:通過為不同的輸出 tile 以不同方式沿 k 維度切分,可以獲得更均衡的負載,但利用這一點會使我們的 kernel 也不再具備 batch-position-invariant 特性。
![]()
Split-K 矩陣乘法 如果我們的 batch 維度非常小,可能無法提供足夠的并行度,此時就需要使用 split-k 矩陣乘法。在這個例子中,我們將每個規約操作拆分到兩個核心上,這兩個核心分別累加,最后再合并結果。然而,把每個規約拆分到兩個核心,仍讓我們能夠充分利用八個核心。
矩陣乘法還有一個額外的復雜性——張量核心指令。對于歸約操作,我們可以一次只處理一行,而高效的矩陣乘法內核必須一次處理整個“瓦片”。
每條張量核心指令(例如wgmma.mma_async.sync.aligned.m64n128k16)內部可能采用不同的歸約順序。選擇不同張量核心指令的一個原因可能是批次非常小。例如,如果我們使用一條對長度為 256 的瓦片進行運算的張量核心 PTX 指令,而批次大小只有 32,那么幾乎浪費了所有算力!當批次大小為 1 時,最快的內核通常完全不使用張量核心。
![]()
填充的 Tensor-Core 指令 如果 batch size 太小,我們可能會遇到連一個 2D tile 都無法放入輸出的情況。此時,最有效的方法是切換到更小的 tensor-core 指令,或者干脆不用 tensor-core!然而,這兩種選擇都會使我們的 kernel 無法保持 batch 不變性。
因此,確保 matmul 的 batch 不變性最簡單的方法是:編譯一個 kernel 配置,并在所有形狀下都使用它。雖然會損失一些性能,但在 LLM 推理中這通常不會帶來災難性后果。特別是,split-k 在 M 和 N 都很小時才最需要,而幸運的是,在我們的場景里 N(即模型維度)通常非常大!
![]()
盡管實現了批次不變性,與 cuBLAS 相比我們只損失了約 20% 的性能。請注意,這也不是一個經過優化的 Triton 內核(例如沒有使用 TMA)。然而,性能中的一些模式可以說明我們的批次不變需求在何處導致性能下降。首先,在極小的批次規模下,由于指令過大且并行度不足,我們損失了大量性能。其次,隨著批次規模增加,會出現一種“拼圖”模式,這是由量化效應(包括 tile 和 wave)引起的,通常通過改變 tile 大小可以緩解。你可以在這里了解更多關于這些量化效應的信息。
批次不變注意力![]()
FlashAttention2 策略 我們沿著 Q 并行化,同時沿著 K/V 進行歸約。這意味著我們的整個歸約可以保持在單個核心內,使其成為另一種數據并行策略。
在為矩陣乘法實現批次不變性之后,注意力機制又引入了兩個額外的難題——恰如其分,因為它包含兩個矩陣乘法。
與 RMSNorm 和 matmul 僅沿特征維度進行歸約不同,我們現在同時沿特征維度和序列維度進行歸約。
由于上述原因,注意力機制必須處理各種影響序列處理方式的推理優化(分塊預填充、前綴緩存等)。
因此,為了在 LLM 推理中實現確定性,我們的數值計算必須不受以下兩個因素影響:一次處理多少請求,以及推理引擎如何對每個請求進行切片。
讓我們首先回顧 FlashAttention2 首次引入的標準注意力并行策略。與 RMSNorm 和 Matmul 類似,默認策略是“數據并行”策略。由于我們沿著 key/value 張量進行歸約,數據并行策略只能沿著 query 張量進行并行化。
例如,根據推理引擎的選擇,一個序列可能會被分塊處理(如分塊預填充),也可能一次性處理(如果預填充未被拆分)。為了實現“批處理不變性”,必須確保某個 token 的規約順序不依賴于其序列中同時被處理的其他 token 數量。如果你將 KV 緩存中的 K/V 值與當前正在處理的 token 的 K/V 值分開規約(如 vLLM 的 Triton attention kernel 所做的那樣),就無法實現這一點。例如,在處理序列中的第 1000 個查詢 token 時,無論 KV 緩存中有 0 個 token(預填充)還是 999 個 token(解碼),其規約順序都必須完全一致。![]()
帶 KV 緩存的 FlashAttention 之所以把 KV 緩存與當前 KV 值分開顯式處理會破壞批不變性,原因有些微妙,與“邊界條件”有關。具體來說,假設塊大小為 32,而當前 KV 緩存中有 80 個元素。我們再計算 48 個尚未緩存的元素。此時,需要 3 個塊(2 個完整塊 + 1 個掩碼塊)來計算 “P cache”,再需要 2 個塊(1 個完整塊 + 1 個掩碼塊)來計算 “P”。因此總共需要 5 個塊來完成歸約,而我們總共只有 4 個塊(即 128 個元素)需要計算,這必然會改變歸約順序。
例如,如果 KV 緩存為空,我們一次性處理 128 個元素,那么這兩種情況必須得到完全相同的數值,才能保證 attention 的“批不變性”。
為解決此問題,我們只需在 attention kernel 之前更新 KV 緩存和頁表,確保無論處理多少 token,鍵和值的布局始終一致。
有了這些額外細節(以及上一節提到的所有內容,如一致的 tile 大小),我們就能實現一個不受 batch 影響的 attention 實現!
然而,這里有一個顯著的問題。與矩陣乘法不同,我們在 LLM 推理中看到的 attention 形狀通常確實需要一個 split-reduction 內核,通常稱為 Split-KV 或 FlashDecoding。這是因為如果我們不沿著 reduction 維度并行化,就只能沿著 batch 維度、head 維度和“query 長度”維度并行化。在 attention 的 decode 階段,query 長度非常小,因此除非 batch size 非常大,否則我們通常無法充分利用 GPU。
不幸的是,這次不能像對待 RMSNorm 和 Matmul 那樣輕易忽略這種情況。例如,如果你有一個非常長的 KV cache,即使只處理一個請求,attention 內核也可能需要很長時間。![]()
固定 # Split-KV 策略(即 FlashDecode) 如果查詢長度變得非常小(如在解碼階段),我們可能會遇到內核中幾乎沒有任何并行性的情況。此時,我們需要再次沿歸約維度——這次是 KV 維度——進行切分。沿 KV 維度切分的典型策略是:先確定需要多少并行度,然后將 KV 維度均勻劃分。例如,如果 KV 長度為 1000 且需要 4 個分片,每個核心將處理 250 個元素。
不幸的是,這也破壞了批不變性,因為我們的精確歸約策略取決于在任何給定請求中我們要處理序列中的多少個查詢 token。
此外,常用于注意力的 split-reduction 策略也對批不變性提出了挑戰。例如,FlashInfer 的“平衡調度算法”會選擇仍能飽和所有 GPU 核心的最大 split-size,從而使歸約策略不再是“批不變”的。然而,與 RMSNorm/Matmul 不同,僅固定一個與 batch size 無關的 split 數量是不夠的。
相反,為了實現批不變性,我們必須采用“固定 split-size”策略。換句話說,我們不再固定 split 的數量,而是固定每個 split 的大小,從而得到可變的 split 數量。這樣,無論處理多少 token,我們都能保證始終執行完全相同的歸約順序。這需要對 FlexAttention 內部做一些修改,這些修改尚未包含在我們的代碼發布中。我們將在不久的將來將其上游!![]()
固定大小 Split-KV 策略 此策略與前一種策略的唯一區別在于,我們的拆分現在是“固定大小”的。例如,如果 KV 長度為 1000,我們不再將其拆分為四個等長 250 的片段,而是拆分為三個固定大小 256 的片段和一個 232 的片段。
這使得我們能夠保持批次不變性,因為我們的歸約策略不再依賴于我們一次性處理的查詢 token 數量!
實現
我們通過利用 vLLM 的 FlexAttention 后端以及 torch.Library,在 vLLM 之上實現了確定性推理的演示。借助 torch.Library,我們能夠以非侵入式的方式替換掉大多數相關的 PyTorch 算子。你可以在 thinking-machines-lab/batch-invariant-ops 找到“批不變”內核庫,以及以“確定性”模式運行 vLLM 的示例。
實驗 補全結果有多不確定?
我們使用Qwen/Qwen3-235B-A22B-Instruct-2507,在溫度 0 下對提示“告訴我關于理查德·費曼的事”(非思考模式)采樣 1000 次補全,每次生成 1000 個 token。令人驚訝的是,我們得到了 80 種不同的補全結果,其中最常見的一種出現了 78 次。
觀察這些補全結果出現差異的位置,我們發現前 102 個 token 實際上完全相同!第一次出現分歧是在第 103 個 token。所有補全都生成了序列“Feynman was born on May 11, 1918, in”,然而其中 992 個補全繼續生成“Queens, New York”,而另外 8 個補全則生成“New York City”。
另一方面,當我們啟用批不變內核時,我們的 1000 個補全結果完全一致。這正是我們從采樣器數學上期望的結果,但如果沒有批不變內核,我們就無法獲得確定性的輸出。
性能
我們尚未對批不變內核的性能進行顯著優化。不過,讓我們運行一些實驗來驗證性能是否仍然可用。
我們將使用一塊 GPU 啟動一個 API 服務器,運行 Qwen-3-8B,并請求 1000 條序列,輸出長度在 90 到 110 之間。
配置
時間(秒)
vLLM 默認
26
未優化的確定性 vLLM
55
+ 改進的注意力內核
42
大部分性能下降源于 vLLM 中的 FlexAttention 集成尚未經過深度優化。盡管如此,性能表現并不算災難級。
真正的 on-policy RL
正如研究人員所指出的,訓練與推理之間數值上的差異,會隱式地將我們的 on-policy RL 變成 off-policy RL。
當然,如果連兩次完全相同的推理請求都無法得到按位一致的結果,就更不可能在訓練與推理之間實現按位一致。而確定性推理使我們能夠進一步改造訓練棧,從而在采樣與訓練之間獲得按位一致的結果,最終實現真正的 on-policy RL。
我們在 Bigmath 上使用 RLVR 設置進行實驗,RL 策略從 Qwen 2.5-VL instruct 8B 初始化,最大 rollout 長度為 4096。
如果我們訓練時不進行 off-policy 校正(即不使用重要性加權),獎勵會在訓練中途崩潰;而加入 off-policy 校正項后,訓練可以順利進行。然而,如果我們的采樣器和訓練器在比特級別完全一致,我們就完全處于 on-policy(即 KL 散度為 0),同樣可以順利訓練。
我們還可以繪制采樣器與訓練器之間 logprobs 的 KL 散度,三條曲線表現出明顯不同的行為。使用重要性加權時,KL 散度保持在約 0.001,偶爾出現峰值。然而,不使用重要性加權時,KL 散度最終會在獎勵崩潰的同一時間點出現飆升。當然,在運行“True On-Policy RL”時,KL 散度始終為 0,表明訓練策略與采樣策略之間沒有差異。
![]()
請注意,未使用重要性加權的運行在第 318 步左右出現了顯著的損失尖峰,同時 logprobs 的 KL 散度也相應飆升。相比之下,無論是采用 off-policy 修正還是“True On-Policy”運行,RL 都能平穩繼續。藍色線顯示的“True On-Policy”并非 bug——它只是 0 處的一條平直線。
結論
現代軟件系統包含多層抽象。在機器學習中,當我們遇到非確定性和細微數值差異時,往往會傾向于掩蓋它們。畢竟,我們的系統已經是“概率性”的,再多一點非確定性又何妨?把失敗單元測試的 atol/rtol 調高一點又有什么問題?訓練器和采樣器之間 logprobs 的差異大概不是真正的 bug,對吧?
我們拒絕這種失敗主義。只需稍加努力,我們就能理解非確定性的根本原因,甚至解決它們!我們希望這篇博客文章能為社區提供如何消除推理系統中非確定性的扎實理解,并激勵更多人全面掌握自己的系統。
引用
@article{he2025nondeterminism, author = {Horace He and Thinking Machines Lab}, title = {Defeating Nondeterminism in LLM Inference}, journal = {Thinking Machines Lab: Connectionism}, year = {2025}, note = {https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/}, doi = {10.64434/tml.20250910} }特別聲明:以上內容(如有圖片或視頻亦包括在內)為自媒體平臺“網易號”用戶上傳并發布,本平臺僅提供信息存儲服務。
Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.