Multi-Head Latent Attention (MLA)

**Summary:** Multi-Head Latent Attention (MLA) is an attention mechanism used in DeepSeek-V2/V3 and Kimi K2.x models. It compresses the Key-Value (KV) cache by projecting keys and values into a shared, low-dimensional latent space. Instead of separate K and V vectors for every attention head, it caches a single latent vector per token, plus one small positional key. This shrinks the KV cache by more than an order of magnitude relative to standard multi-head attention, with minimal quality loss; the DeepSeek-V2 paper reports a 93.3% KV cache reduction compared with DeepSeek 67B. The design also absorbs the up-projection matrices into the query projection, so keys never have to be explicitly decompressed to compute attention scores. This significantly improves memory efficiency during the decoding phase.

Compressing the KV cache via low-rank projections: the attention mechanism behind DeepSeek-V2/V3 and Kimi K2.x.

Multi-Head Latent Attention (MLA) is the attention variant that replaces standard Multi-Head Attention (MHA) in DeepSeek-V2, DeepSeek-V3, and Kimi K2.x models. Instead of caching full KV pairs per head, MLA projects them into a low-dimensional latent space, compressing the KV cache by well over 10x with minimal quality loss.

**Standard MHA.** For input $\mathbf{X} \in \mathbb{R}^{n \times d}$, MHA computes per-head projections:

$$
\mathbf{Q}^h = \mathbf{X}\mathbf{W}_Q^h, \qquad \mathbf{K}^h = \mathbf{X}\mathbf{W}_K^h, \qquad \mathbf{V}^h = \mathbf{X}\mathbf{W}_V^h,
$$

where $\mathbf{W}_Q^h \in \mathbb{R}^{d \times d_k}$, $\mathbf{W}_K^h \in \mathbb{R}^{d \times d_k}$, and $\mathbf{W}_V^h \in \mathbb{R}^{d \times d_v}$. With $d_v = d_k$, the KV cache costs $2 \times n_h \times d_k$ elements per token per layer.

**MLA.** MLA replaces the per-head KV projections with a shared low-rank latent compression.

**Compression (KV → latent):**

$$
\mathbf{c}^{KV} = \mathbf{X}\mathbf{W}_{DKV},
$$

where $\mathbf{W}_{DKV} \in \mathbb{R}^{d \times d_c}$ is the down-projection matrix and $d_c \ll n_h \times d_k$.

**Decompression (latent → KV):**

$$
\mathbf{K}^h = \mathbf{c}^{KV}\mathbf{W}_{UK}^h, \qquad \mathbf{V}^h = \mathbf{c}^{KV}\mathbf{W}_{UV}^h,
$$

where $\mathbf{W}_{UK}^h \in \mathbb{R}^{d_c \times d_k}$ and $\mathbf{W}_{UV}^h \in \mathbb{R}^{d_c \times d_v}$ are up-projection matrices.
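The compression and decompression steps can be checked with a toy example. The sketch below uses PyTorch in float64 with small, arbitrarily chosen dimensions; the variable names are illustrative and not taken from any library. It caches only $\mathbf{c}^{KV}$, rebuilds per-head keys and values from it, and confirms that MLA's keys are exactly those of an MHA layer whose effective key projection $\mathbf{W}_{DKV}\mathbf{W}_{UK}^h$ is constrained to rank at most $d_c$ across all heads.

```python
import torch

torch.manual_seed(0)

# Toy sizes (assumed for illustration): d_c is well below n_h * d_k = 128
n, d = 6, 256              # tokens, model width
n_h, d_k, d_c = 4, 32, 48  # heads, head dim (d_v = d_k), latent dim

X = torch.randn(n, d, dtype=torch.float64)
W_dkv = torch.randn(d, d_c, dtype=torch.float64) / d**0.5          # W_DKV
W_uk = torch.randn(n_h, d_c, d_k, dtype=torch.float64) / d_c**0.5  # W_UK^h for all heads
W_uv = torch.randn(n_h, d_c, d_k, dtype=torch.float64) / d_c**0.5  # W_UV^h for all heads

# Compression: the only tensor MLA keeps in the KV cache
c_kv = X @ W_dkv                               # (n, d_c)

# Decompression: per-head K and V rebuilt from the latent on demand
K = torch.einsum("nc,hck->hnk", c_kv, W_uk)    # (n_h, n, d_k)
V = torch.einsum("nc,hck->hnk", c_kv, W_uv)    # (n_h, n, d_k)

# Equivalent MHA view: effective key projection W_K^h = W_DKV @ W_UK^h
W_k_eff = W_dkv @ W_uk                         # (n_h, d, d_k), broadcast over heads
K_mha = X @ W_k_eff                            # (n_h, n, d_k)

# Stacking all heads' key projections gives a (d, n_h * d_k) matrix of rank <= d_c
W_k_stacked = W_k_eff.permute(1, 0, 2).reshape(d, n_h * d_k)
rank = torch.linalg.matrix_rank(W_k_stacked).item()

mha_cache_elems = 2 * n_h * d_k * n            # K and V for every head
mla_cache_elems = c_kv.numel()                 # one latent vector per token
print(f"max |K - K_mha| = {(K - K_mha).abs().max().item():.2e}")
print(f"rank of stacked key projection: {rank} (d_c = {d_c})")
print(f"cache elements: MHA {mha_cache_elems}, MLA {mla_cache_elems}")
```

The rank check is the useful intuition here. MLA is not an approximation applied after training. It is MHA whose combined key and value projections are factorized through a $d_c$-dimensional bottleneck and learned that way from the start.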
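**KV cache per token:** Only $\mathbf{c}^{KV} \in \mathbb{R}^{d_c}$ is stored, a single vector of dimension $d_c$. For a model with $n_h$ heads and head dimension $d_k$, the latent alone compresses the cache by a factor of

$$
\frac{2\, n_h\, d_k}{d_c}.
$$

In DeepSeek-V3, $n_h = 128$, $d_k = 128$, and $d_c = 512$:

$$
\frac{2 \times 128 \times 128}{512} = \frac{32{,}768}{512} = 64\times
$$

This becomes about $57\times$ once the decoupled RoPE key described below is added to the cache.

**Query compression:** MLA also compresses queries for training efficiency:

$$
\mathbf{c}^{Q} = \mathbf{X}\mathbf{W}_{DQ}, \qquad \mathbf{Q}^h = \mathbf{c}^{Q}\mathbf{W}_{UQ}^h,
$$

where $\mathbf{W}_{DQ} \in \mathbb{R}^{d \times d_c'}$ and $\mathbf{W}_{UQ}^h \in \mathbb{R}^{d_c' \times d_k}$ (DeepSeek-V3 uses $d_c' = 1536$). This doesn't affect the KV cache, but it reduces activation memory during training.

**Why RoPE needs special handling:** Applying RoPE directly to the decompressed queries and keys would place a position-dependent rotation between $\mathbf{W}_{UQ}^h$ and $\mathbf{W}_{UK}^h$. That rotation breaks the weight absorption described next. MLA therefore keeps $\mathbf{c}^{KV}$ position-agnostic and carries positional information in a separate, narrow set of $d_r \ll d_k$ dimensions (the decoupled RoPE key below).

**Weight absorption:** The critical insight in MLA is that the up-projection matrix $\mathbf{W}_{UK}^h$ can be absorbed into the query projection during attention computation. For query token $t$ and key token $s$, the content part of the attention score is $\mathbf{q}^h_t {\mathbf{k}^h_s}^\top$. Substituting the decompressed forms gives:

$$
\mathbf{q}^h_t {\mathbf{k}^h_s}^\top
= \left(\mathbf{c}^{Q}_t \mathbf{W}_{UQ}^h\right)\left(\mathbf{c}^{KV}_s \mathbf{W}_{UK}^h\right)^\top
= \mathbf{c}^{Q}_t\, \mathbf{W}_{UQ}^h {\mathbf{W}_{UK}^h}^\top\, {\mathbf{c}^{KV}_s}^\top .
$$

If we define $\mathbf{W}_{absorbed}^h = \mathbf{W}_{UQ}^h {\mathbf{W}_{UK}^h}^\top \in \mathbb{R}^{d_c' \times d_c}$, then:

$$
\mathbf{q}^h_t {\mathbf{k}^h_s}^\top = \mathbf{c}^{Q}_t\, \mathbf{W}_{absorbed}^h\, {\mathbf{c}^{KV}_s}^\top .
$$

The attention score can therefore be computed directly from the latent representations, without explicitly decompressing $\mathbf{K}$. The value side works the same way. Because $\sum_s a_{ts} \mathbf{v}^h_s = \left(\sum_s a_{ts} \mathbf{c}^{KV}_s\right)\mathbf{W}_{UV}^h$, the softmax weights can be applied to the cached latents first. $\mathbf{W}_{UV}^h$ can then be folded into the output projection $\mathbf{W}_O$, which is how the DeepSeek-V2 paper absorbs it.

**Practical implication:** During decoding, attention can run entirely on the cached latents; the full $\mathbf{K}$ and $\mathbf{V}$ matrices never need to be materialized.

**Decoupled RoPE:** RoPE requires position-dependent keys, which conflicts with caching a position-agnostic latent. MLA solves this with a decoupled key shared across heads, plus per-head RoPE queries:

$$
\mathbf{k}^{R}_t = \mathrm{RoPE}\!\left(\mathbf{x}_t \mathbf{W}_{KR}\right), \qquad
\mathbf{q}^{R,h}_t = \mathrm{RoPE}\!\left(\mathbf{c}^{Q}_t \mathbf{W}_{QR}^h\right),
$$

where $\mathbf{W}_{KR} \in \mathbb{R}^{d \times d_r}$ and $\mathbf{W}_{QR}^h \in \mathbb{R}^{d_c' \times d_r}$, with $d_r \ll d_k$ (DeepSeek-V3 uses $d_r = 64$). The attention score becomes:

$$
\text{score}^h_{t,s} = \frac{\mathbf{q}^{h}_t {\mathbf{k}^{h}_s}^\top + \mathbf{q}^{R,h}_t {\mathbf{k}^{R}_s}^\top}{\sqrt{d_k + d_r}} .
$$

**Practical implication:** The KV cache stores both the latent $\mathbf{c}^{KV}$ and the already-rotated decoupled RoPE key $\mathbf{k}^{R}$.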
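The absorption identity can be verified numerically. This sketch uses PyTorch in float64 with one head and toy sizes. It deliberately ignores the decoupled RoPE dimensions, which are handled separately. It computes the scores two ways: by materializing $\mathbf{K}$, and through $\mathbf{W}_{absorbed}^h$. It then shows that the value side can likewise be computed by weighting the latents before a single up-projection. All names are illustrative.

```python
import math

import torch

torch.manual_seed(0)

# Toy sizes (assumed): d_c' (query latent), d_c (KV latent), head dim, cached tokens
d_cq, d_c, d_k, n_ctx = 64, 48, 32, 10

c_q = torch.randn(d_cq, dtype=torch.float64)                   # c^Q_t for the current token
c_kv = torch.randn(n_ctx, d_c, dtype=torch.float64)            # cached latents c^KV_s
W_uq = torch.randn(d_cq, d_k, dtype=torch.float64) / d_cq**0.5  # W_UQ^h for one head
W_uk = torch.randn(d_c, d_k, dtype=torch.float64) / d_c**0.5    # W_UK^h
W_uv = torch.randn(d_c, d_k, dtype=torch.float64) / d_c**0.5    # W_UV^h (d_v = d_k)

# Explicit path: decompress the query and every cached key
q = c_q @ W_uq                      # (d_k,)
K = c_kv @ W_uk                     # (n_ctx, d_k), materialized
scores_explicit = K @ q             # (n_ctx,)

# Absorbed path: W_absorbed is computed once per head; K is never built
W_absorbed = W_uq @ W_uk.T          # (d_c', d_c)
scores_absorbed = c_kv @ (c_q @ W_absorbed)  # (n_ctx,)

# Value side: weight the latents first, then up-project once
attn = torch.softmax(scores_absorbed / math.sqrt(d_k), dim=-1)
out_explicit = attn @ (c_kv @ W_uv)  # materializes V: (n_ctx, d_k)
out_absorbed = (attn @ c_kv) @ W_uv  # weighted latent (d_c,), then one projection

print(torch.allclose(scores_explicit, scores_absorbed))
print(torch.allclose(out_explicit, out_absorbed))
```

Storing $\mathbf{W}_{absorbed}^h$ is only one option. Applying the factors in sequence ($\mathbf{q}^h$ first, then multiplication by ${\mathbf{W}_{UK}^h}^\top$) gives the same result and maps the query into the $d_c$-dimensional latent space. Either way, each cached token costs only a $d_c$-dimensional dot product.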
**Total cache per token:** $d_c + d_r$ elements per layer. In DeepSeek-V3 that is $512 + 64 = 576$, versus $32{,}768$ for MHA with the same head configuration, about $57\times$ smaller.

**Comparison with GQA:** GQA reduces the cache by sharing KV heads across groups of query heads. MLA reduces it more aggressively by projecting to a shared latent; by the DeepSeek-V2 paper's accounting, MLA's cache equals that of GQA with only 2.25 groups. The quality difference is small. The DeepSeek-V2 ablations report MLA matching or outperforming MHA, plausibly because the learned up-projection matrices can reconstruct head-specific information from the shared latent.

**Serving implications:** MLA dramatically changes the memory-versus-compute tradeoff in serving.

**Memory-bound decoding phase:** With MHA, long contexts exhaust GPU HBM because of the KV cache. MLA's compression allows:

**Compute-bound prefill phase:** MLA adds decompression overhead, but this is amortized:

This is where it gets interesting for Siraj's EAGLE-3 work:

**Draft model constraints:**

**Verification with MLA:**

**vLLM implementation challenge:** vLLM's PagedAttention was designed for MHA/GQA. MLA requires:

**Reference implementation (PyTorch):**

```python
import math
from typing import Optional

import torch
import torch.nn as nn


class MultiHeadLatentAttention(nn.Module):
    """
    MLA attention layer following the DeepSeek-V2/V3 (and Kimi K2.x) architecture.

    Key features:
    - Low-rank KV compression: cache only the c_KV latent plus one shared RoPE key per token
    - Decoupled RoPE for position-aware attention
    - Explicit decompression for readability (production kernels use weight absorption instead)

    Simplifications: no RMSNorm on the latents, and the whole cache is decompressed every step.
    """

    def __init__(
        self,
        d_model: int = 4096,
        n_heads: int = 128,
        d_k: int = 128,
        d_v: int = 128,
        d_c: int = 512,         # KV latent dimension (compression target)
        d_c_prime: int = 1536,  # Query latent dimension
        d_r: int = 64,          # Decoupled RoPE dimension (must be even)
        max_seq_len: int = 8192,
        rope_base: float = 10000.0,
    ):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_k
        self.d_v = d_v
        self.d_c = d_c
        self.d_c_prime = d_c_prime
        self.d_r = d_r
        self.max_seq_len = max_seq_len

        # === Down-projections (compression) ===
        self.w_dkv = nn.Linear(d_model, d_c, bias=False)       # KV latent
        self.w_dq = nn.Linear(d_model, d_c_prime, bias=False)  # Q latent

        # === Up-projections (decompression) ===
        # KV up-projections: latent -> per-head K and V
        self.w_uk = nn.Linear(d_c, n_heads * d_k, bias=False)
        self.w_uv = nn.Linear(d_c, n_heads * d_v, bias=False)
        # Q up-projection: latent -> per-head Q
        self.w_uq = nn.Linear(d_c_prime, n_heads * d_k, bias=False)

        # === Decoupled RoPE projections ===
        # One RoPE key shared by all heads, projected from the hidden state and cached
        self.w_kr = nn.Linear(d_model, d_r, bias=False)
        # Per-head RoPE queries, projected from the query latent
        self.w_qr = nn.Linear(d_c_prime, n_heads * d_r, bias=False)

        # === Output projection ===
        self.w_o = nn.Linear(n_heads * d_v, d_model, bias=False)

        # RoPE inverse frequencies
        inv_freq = 1.0 / (rope_base ** (torch.arange(0, d_r, 2).float() / d_r))
        self.register_buffer("inv_freq", inv_freq)

    def _apply_rope(self, x: torch.Tensor, start_pos: int) -> torch.Tensor:
        """Rotate x of shape (batch, seq, heads, d_r) for positions start_pos .. start_pos + seq - 1."""
        t = torch.arange(start_pos, start_pos + x.shape[1], device=x.device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)           # (seq, d_r // 2)
        cos = freqs.cos().unsqueeze(0).unsqueeze(2)     # (1, seq, 1, d_r // 2)
        sin = freqs.sin().unsqueeze(0).unsqueeze(2)
        x1, x2 = x[..., ::2], x[..., 1::2]              # interleaved pairs
        return torch.stack((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1).flatten(-2)

    def forward(
        self,
        x: torch.Tensor,
        kv_cache: Optional[torch.Tensor] = None,
        start_pos: int = 0,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: Input tensor (batch, seq_len, d_model)
            kv_cache: Cached [c_KV | k_rope] from previous tokens (batch, cache_len, d_c + d_r);
                cache_len must equal start_pos
            start_pos: Position of the first token in x (used for RoPE and the causal mask)

        Returns:
            output: (batch, seq_len, d_model)
            new_kv_cache: Updated cache (batch, cache_len + seq_len, d_c + d_r)
        """
        B, S, _ = x.shape

        # === Step 1: Compress to latent space ===
        c_kv = self.w_dkv(x)  # (B, S, d_c): THIS is what gets cached
        c_q = self.w_dq(x)    # (B, S, d_c')

        # === Step 2: Decoupled RoPE ===
        # Shared key is rotated once at its own positions and cached already rotated
        k_rope = self._apply_rope(self.w_kr(x).unsqueeze(2), start_pos).squeeze(2)  # (B, S, d_r)
        q_rope = self._apply_rope(self.w_qr(c_q).view(B, S, self.n_heads, self.d_r), start_pos)

        # === Step 3: KV cache management ===
        new_entry = torch.cat([c_kv, k_rope], dim=-1)  # (B, S, d_c + d_r)
        new_kv_cache = new_entry if kv_cache is None else torch.cat([kv_cache, new_entry], dim=1)
        c_all, k_rope_all = new_kv_cache.split([self.d_c, self.d_r], dim=-1)
        L = new_kv_cache.shape[1]

        # === Step 4: Decompress for attention computation ===
        q_content = self.w_uq(c_q).view(B, S, self.n_heads, self.d_k)
        k_content = self.w_uk(c_all).view(B, L, self.n_heads, self.d_k)
        v = self.w_uv(c_all).view(B, L, self.n_heads, self.d_v)

        # Concatenate content + RoPE parts; the shared RoPE key is broadcast to every head
        q = torch.cat([q_content, q_rope], dim=-1)  # (B, S, n_heads, d_k + d_r)
        k = torch.cat([k_content, k_rope_all.unsqueeze(2).expand(-1, -1, self.n_heads, -1)], dim=-1)

        # === Step 5: Causal attention ===
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)  # (B, n_heads, seq, dim)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k + self.d_r)
        q_pos = torch.arange(start_pos, start_pos + S, device=x.device)
        k_pos = torch.arange(L, device=x.device)
        scores = scores.masked_fill(k_pos[None, :] > q_pos[:, None], float("-inf"))
        attn_output = torch.matmul(torch.softmax(scores, dim=-1), v)  # (B, n_heads, S, d_v)

        # === Step 6: Output projection ===
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, S, -1)
        return self.w_o(attn_output), new_kv_cache


# === Example: Compare MLA vs MHA cache sizes ===
def compare_cache_sizes():
    """Demonstrate the KV cache savings of MLA over MHA (DeepSeek-V3 head configuration)."""
    n_heads = 128
    d_k = 128
    d_c = 512              # DeepSeek-V3 KV latent dim
    d_r = 64               # Decoupled RoPE key dim (shared across heads)
    seq_len = 65536        # 64K context
    bytes_per_element = 2  # FP16/BF16
    n_layers = 60

    # MHA: cache K and V for all heads
    mha_cache_per_token = 2 * n_heads * d_k  # K + V
    mha_total = mha_cache_per_token * seq_len * bytes_per_element / 1024**3

    # MLA: cache only c_KV plus the single shared RoPE key
    mla_cache_per_token = d_c + d_r
    mla_total = mla_cache_per_token * seq_len * bytes_per_element / 1024**3

    print(f"MHA KV cache (64K ctx): {mha_total:.2f} GB per layer")
    print(f"MLA KV cache (64K ctx): {mla_total:.2f} GB per layer")
    print(f"Compression ratio: {mha_cache_per_token / mla_cache_per_token:.1f}x")
    print(f"\nFor {n_layers} layers:")
    print(f"  MHA: {mha_total * n_layers:.1f} GB")
    print(f"  MLA: {mla_total * n_layers:.1f} GB")
    print(f"  Savings: {(mha_total - mla_total) * n_layers:.1f} GB")


if __name__ == "__main__":
    # Test MLA forward pass with a small configuration
    mla = MultiHeadLatentAttention(
        d_model=4096, n_heads=8, d_k=64, d_v=64, d_c=128, d_c_prime=256, d_r=32,
    )
    x = torch.randn(2, 10, 4096)  # batch=2, seq=10
    output, cache = mla(x)
    print(f"Output shape: {tuple(output.shape)}")  # (2, 10, 4096)
    print(f"Cache shape: {tuple(cache.shape)}")    # (2, 10, 160): d_c + d_r per token

    # Autoregressive step
    x2 = torch.randn(2, 1, 4096)
    output2, cache2 = mla(x2, kv_cache=cache, start_pos=10)
    print(f"Output2 shape: {tuple(output2.shape)}")  # (2, 1, 4096)
    print(f"Cache2 shape: {tuple(cache2.shape)}")    # (2, 11, 160): grew by 1

    print("\n--- Cache Comparison ---")
    compare_cache_sizes()
```

**References**

- *DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model.* DeepSeek-AI (Liu et al.), 2024. arXiv:2405.04434. The original MLA paper, introducing the latent compression and decoupled RoPE strategy.
- *DeepSeek-V3 Technical Report.* DeepSeek-AI, 2024. arXiv:2412.19437. Scales MLA to a 671B-parameter MoE with auxiliary-loss-free load balancing. It also details the multi-token prediction (MTP) objective, whose modules can be repurposed as draft heads for speculative decoding.
- *vLLM MLA implementation.* github.com/vllm-project/vllm. Production MLA kernels with weight absorption and FlashAttention integration.
- *FlashInfer MLA attention.* github.com/flashinfer-ai/flashinfer. Custom CUDA kernels for MLA that support both prefill and decode phases with a batched latent cache.