![]()
https://medium.com/@prathamgrover777/kv-caching-attention-optimization-from-o-n%C2%B2-to-o-n-8b605f0d4072
我們見過 LLM 如何逐字逐句地敲出上千字的回答,仿佛“邊想邊說”。表面順滑,背后卻低效得驚人。
在生成第 t 步時,模型必須確保下一個詞與之前所有內容保持一致:你的提示、已生成的部分、系統指令,乃至任何隱藏上下文。實際上,模型會重新通過所有 Transformer 層重建之前全部的隱藏狀態,并再次計算 Query、Key 和 Value,即便前面的 token 絲毫未變。這種重復計算逐層、逐頭進行,沒有任何復用。如圖 1 所示,由于之前步驟的結果未被復用,每個 token 的計算量隨序列長度持續增加。
![]()
圖 1. 解碼階段每個 token 的計算成本(樸素 vs KV 緩存)。
為什么這種方式擴展性這么差?想象你在寫一段文字,每添加一個新句子之前,都要從頭把整篇文檔重新讀一遍。然后寫下一個句子時,再從頭讀一遍。如此反復。這就是樸素解碼循環在做的事情。
現在把它放進 transformer:
L= 層數
H= 每層注意力頭數
n= 當前序列長度
因此,每個 token 的計算量不僅與序列長度 n 成正比,內部還要再乘以 L × H。這就是成本飆升的原因。
現在,把這個問題放到現代 LLMs 的規模下:
幾十層網絡層層堆疊
每一層有多個注意力頭,各自“角度”不同地回顧歷史。
長提示(上千條),還要生成一長段回答。
更致命的是,模型對過去的 Key 和 Value 毫無記憶;因此在第 1000 個 token 時,它得把 token 1 到 999 全部重新算一遍。這就導致解碼的時間復雜度是 **O(n2)**。
結果就是巨大的冗余。
![]()
圖 2. 無 KV 緩存的樸素解碼
如圖 2 所示,模型不會記住過去的 K/V 投影,而是對每個新 token 都重新計算它們。 當 t = 2 時 → 重新計算 token **1**當 t = 3 時 → 重新計算 token 1 和 **2**對于 t = 4 → 重新計算 tokens **1, 2, 3**對于 t = n → 重新計算 tokens 1 到 n?1
在每一層內部,對于每個 token,它執行:
K = X · W?
V = X · W?
Q · K?(與所有先前 key 的點積)
對所有先前 Value 的加權和。
這些結果從未被復用,因此到第 n 個 token 時,你已經重復了(n–1) × L × H次!這意味著計算量持續增長, 永遠不會穩定 。
硬件瓶頸
在解決它之前,我們必須先理解瓶頸所在:搬運數據代價高昂。GPU 做數學運算極快,但真正的成本往往在于把正確的數據在正確的時間送到正確的位置。
舉個簡單的例子,把它想象成廚房里的廚師:
GPU 核心就是那位廚師,切菜和烹飪的速度快得驚人。
VRAM(GPU 自有內存)就像緊挨著廚師的小備餐臺,空間有限,但伸手就能拿到上面的東西。
系統 RAM則是走廊盡頭的大儲藏室,空間充足,可每次都得停下、走過去再把食材搬回來。
用 ML 的話說:
GPU 核心每秒能執行數萬億次FLOPs,真正的瓶頸并不是算力。
但內存帶寬(將 K/V 張量搬運到高帶寬內存)是有限的。
并且注意力機制需要反復讀取這些張量,這會把帶寬壓垮。
這就是為什么序列變長時生成會變慢——不是因為計算變復雜,而是因為 GPU等數據的時間比真正計算的時間還多。
模型的權重和對話的歷史 token都必須放在顯存里才能快速處理。但在樸素的解碼中,每生成一個新 token,廚師(GPU)就得折回儲藏室(內存)再拿一遍同樣的食材(歷史 K/V)。儲藏室與灶臺之間的路越來越擠,走路時間越來越多,真正炒菜的時間越來越少。
那種對舊數據持續、重復的抓取,就是我們所說的帶寬之痛。“帶寬”指的是這條通路的容量,“痛”則源于把剛才已經存在的數據再次塞進去所造成的嚴重擁堵。帶著這個概念,我們來看看到底在哪個環節,這一過程在規模擴大時會崩潰。更具體地說,注意力機制會變成受內存限制,而非計算限制。你的 GPU 空轉著,等待從顯存取數,盡管它的算力足以瞬間完成運算。
樸素的注意力機制在何處因規模而崩潰?
延遲悄然上升:早期的 token 反應迅捷,但隨著對話變長,每一個新 token 的生成時間都比前一個更久。你正拖著一段越來越長的歷史前行。這正是注意力計算復雜度帶來的直接后果——它是二次方,即 **O(n2)**,其中n為序列長度。
帶寬之痛,而非算力之痛:現代 GPU 的數學運算飛快,但數據在內存之間搬來搬去才是“征稅員”;反復把整個“過去”拖過總線,會把帶寬壓垮。你更多時間是在等數據,而不是在計算。
推理崩潰:在線上,你不是給一個人生成,而是同時給成千上萬人生成。如果每條 token 流都重新處理自己的完整歷史,系統立刻垮掉,成本飆升。
此刻,一個自然的問題浮現:**“等等……我們為什么每次都重復同樣的投影?”****KV 緩存**就是你拒絕重讀過去的瞬間。簡單說,KV 緩存就是:一旦某 token 在某層里的 Key 和 Value 向量算完,我們把它存進 GPU 內存,而不是直接扔掉。
在注意力機制中,每個 token 被轉換為兩個緊湊向量:Key(K)和Value(V),它們描述_該 token 應如何與后續 token 交互_ 。生成方式是將 token 的嵌入(x)通過該層的權重矩陣:
Key = X @ W?
Value = X @ W?
這并非只計算一次,而是對模型的每一層、每一個注意力頭,都各自計算一套 K 和 V。
舉例:若模型有 32 層、每層 32 個頭,則每個 token 要計算 32 × 32 = 1024 組 K/V 投影。
訣竅在這里:
在推理過程中,模型的權重(W?、W?)不會改變。
一旦計算完成,token 的嵌入向量X也不再改變。
這意味著它的Key和Value向量是確定性的。一旦算出,它們就像被刻在石頭上一樣,在整個序列中不會變化。
那為何每一步都要重新計算它們呢?
于是,不再 :每生成一個新 token → 為所有舊 token 重新計算 K/V。 而是: 一次性算出 K/V → 存起來 → 后續所有 token 直接復用。
這就是KV 緩存。
底層到底發生了什么變化?當模型生成第 t 個 token 時,常規的注意力操作會這樣執行:
![]()
在樸素解碼中,每當時間步 t 增加,模型會重新計算所有層、所有頭之前的 K 和 V 向量,只為再次把它們代入這個方程。
啟用 KV 緩存后,方程本身_并未_改變,但K 和 V 的來源變了。
我們停止重新計算K?…K??? 和 V?…V???
取而代之的是,我們從 GPU 內存中一個名為KV Cache的張量里直接讀取它們
因此,同樣的公式變為:
![]()
模型不再在每一步重新計算 K?…K??? 和 V?…V???,而是直接從 GPU 內存中讀取。在 token t 時唯一需要的新工作是計算 K? 和 V?,并將它們追加到緩存中。
KV 緩存在 GPU 內部到底是什么樣子?
它既不是列表,也不是 Python 字典。在實際的 LLM 實現中(如 vLLM、TensorRT-LLM、Hugging Face),緩存以張量形式存儲在 GPU 顯存中,維度固定。
![]()
num_layers= 模型中 Transformer 塊的總數(例如 LLaMA-7B 為 32)
num_heads= 每層注意力頭數(例如 32)
seq_len= 當前已見的 token 數量
head_dim= 每個注意力頭的維度(例如 64 或 128)
每當生成一個新 token:
→ 我們計算K?和V?
→ 我們沿著seq_len維度將它們追加到末尾
→ 其他所有內容保持不變
我們為什么不緩存 Query(Q)?
因為Q(Query 向量)與K和V在本質上不同。
Key 和 Value 代表 之前 token 的記憶 。
它們一旦計算完成就不會再改變。但 Query 只依賴于_當前正在生成的 token_,而非過去的。
Q?用來提問:“鑒于我已看到的全部(所有已緩存的K/V),下一個 token 應該是什么?”
所以:
K 和 V = 記憶 → 緩存一次,反復使用
Q = 按步驟、臨時生成 → 無需存儲
如果我們緩存了Q,就永遠不會再用到它,因為它只在當前時間步使用。緩存它只會白白浪費內存,毫無收益。
可視化差異
為了真正理解 KV 緩存的神奇之處,讓我們跟隨動畫,看看我們的 LLM 如何生成短語“I Love cats”。我們將重點觀察模型如何處理這些 token,以預測序列中的下一個詞。(GIF 可能加載較慢——稍等片刻)
![]()
圖 3. 有無 KV 緩存的對比:過去的 K/V 是重新計算 vs. 復用。 1. 無緩存
原始而浪費的做法。每一步都必須從頭重新處理全部內容。
步驟 1:預測 “Love”(歷史:“I”)
模型接收第一個 token “I”。
它計算其 Key(記為K?)及其 Value(記為V?)。
它還計算其 Query(Q?)。
它執行注意力計算(Q?關注K?),以預測下一個詞:“Love”。
然后,它會丟棄掉K?和V?,所有這些計算成果都被浪費。
在張量層面,這意味著 GPU 剛剛計算出形狀為
[num_heads, head_dim]的矩陣K?和V?,卻立即將它們丟棄。當模型處理下一個 token 時,會毫無必要地重新構建這些相同的矩陣。
步驟 2:預測“cats”(歷史:“I Love”)
模型現在需要處理新 token“Love”,但它對“I”沒有任何記憶。
它必須重新為“I”計算 Key 和 Value(生成K’?和V’?)。
它還會為“Love”計算 Key 和 Value(生成K?和V?),并為新 token“Love”計算 Query(Q?)。
它執行注意力計算(Q?同時查看K’?和K?)來預測“cats”。
隨后,它把K’?、V’?、K?和V?全部丟棄。
你看出規律了嗎?為了預測第三個詞,我們不得不把第一個詞的所有計算重新做一遍。
這樣,模型只需計算每個 token 的 Key 和 Value 一次 ,并將其保存下來。
步驟 1:預測 “Love”(歷史:“I”)
模型接收第一個 token “I”。
它計算自己的 Key(K?)和 Value(V?)。
執行注意力計算以預測 “Love”。
關鍵是,它將K?和V?存入一塊特殊內存:KV 緩存。
緩存現在包含:{(K?, V?)}。
這意味著 GPU 現在為每一層、每一個注意力頭都保存著一個小張量,代表該詞元對所有未來注意力查詢的貢獻。無需重新計算,只需查表即可。
步驟 2:預測“cats”(歷史:“I Love”)
模型忽略 “I” 詞元。
它知道它已經存在于緩存中。
它只處理新的 token “Love”。
它只為 “Love” 計算 Key(K?)和 Value(V?)。
它將這對新的鍵值追加到緩存中。
緩存現在包含:_{(K?, V?), (K?, V?)}_ 它為“Love”計算查詢(Q?),并通過查看整個緩存(K?和K?)進行注意力計算。
它預測“cats”。沒有任何內容被丟棄。
此時工作量恒定且最小。這就是每一步的工作量變為線性(**O(n)**)的方式。
模型依舊會“回顧”歷史,只是它通過查詢一張簡單的查找表(緩存)來完成,而無需從頭重新計算。注意力計算仍會將Q?與所有過去的鍵進行比較,但由于過去的 K/V 已從緩存中取出,重算成本降至 **O(1)**。每一步的工作現在僅僅是將一個新查詢與已有鍵進行比較。
這種優化是現代推理服務器(如 vLLM、TensorRT-LLM 或 Hugging Face 的transformers庫)能夠實現實時文本生成的原因。沒有它,每生成幾百個 token,延遲就會翻倍,讓聊天模型根本無法使用。
老實說,這并不像發明一種新算法,你只是不再做重復勞動。這就是 KV 緩存的作用:過去不再重新計算,而是被記住。
權衡:用內存換計算![]()
圖 4. KV 緩存作為時空權衡
KV 緩存并非“免費的午餐”,而是一種典型的時空權衡:我們消除了冗余計算(時間),但必須把緩存塞進 GPU 顯存(空間),而且這塊緩存會變得非常非常龐大!
KV 緩存的大小由以下幾個因素決定:
對于一個上下文窗口長達 32,000 token 的大模型,這塊緩存就能吃掉幾十 GB 的寶貴 GPU 顯存。如果服務器還要同時服務這么多并發用戶(大批量),內存需求就會成為主要瓶頸,直接限制系統容量。
極簡代碼示例(PyTorch 樸素實現 vs KV 緩存)
到目前為止,我們只_討論_了 KV 緩存。現在讓我們通過代碼來_親眼看看它是如何工作的_ 。
下面的示例清楚地展示了樸素解碼如何一次又一次地重復計算所有內容,以及 KV 緩存如何通過存儲先前計算好的鍵(K)和值(V)來避免這種浪費。
import torch
import torch.nn as nn
# A single multi-head attention layer
attn = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
# Dummy input sequence
tokens = torch.randn(1, 5, 512) # [batch, seq_len, embedding_dim]
# -------------------------------
# 1. Naive decoding (no caching)
# Recomputes attention over full history at every step
# -------------------------------
for t in range(1, tokens.size(1)):
x = tokens[:, :t, :] # tokens from 1 to t
out, _ = attn(x, x, x) # recompute Q,K,V for all past tokens again
# -------------------------------
# 2. KV Caching (compute K/V once → reuse forever)
# -------------------------------
past_k, past_v = None, None
for t in range(tokens.size(1)):
x = tokens[:, t:t+1, :] # only the new token
# Project to Q, K, V (like attention does internally)
q = attn.in_proj_q(x)
k = attn.in_proj_k(x)
v = attn.in_proj_v(x)
# Save (or append) K/V into cache
past_k = k if past_k isNoneelse torch.cat([past_k, k], dim=1)
past_v = v if past_v isNoneelse torch.cat([past_v, v], dim=1)
# Attention now only compares new query with cached keys
attn_scores = torch.matmul(q, past_k.transpose(-1, -2)) / (k.size(-1) ** 0.5)
attn_probs = attn_scores.softmax(dim=-1)
output = torch.matmul(attn_probs, past_v)
故事基本上就是這樣。
KV 緩存并沒有讓注意力機制變得更智能,它只是讓它不再愚蠢。與其在每個 token 上重新計算過去的內容,我們只需計算一次并記住它。這就是為什么即使上下文變長,生成速度依然很快。
特別聲明:以上內容(如有圖片或視頻亦包括在內)為自媒體平臺“網易號”用戶上傳并發布,本平臺僅提供信息存儲服務。
Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.