{"slug": "multi-head-latent-attention-mla", "title": "Multi-Head Latent Attention (MLA)", "summary": "**Summary:** Multi-Head Latent Attention (MLA) is an attention mechanism used in DeepSeek-V2/V3 and Kimi K2.x models that compresses the Key-Value (KV) cache by projecting full KV pairs into a shared, low-dimensional latent space. This achieves a 5-10x reduction in KV cache size with minimal quality loss by storing only a single latent vector per token instead of separate KV pairs for each attention head. The design also absorbs up-projection matrices into query projections to avoid explicit key decompression during attention score computation, significantly improving memory efficiency during the decoding phase.", "body_md": "Compressing KV cache via low-rank projections — the attention mechanism behind DeepSeek-V2/V3 and Kimi K2.x\nMulti-Head Latent Attention (MLA) is the attention variant that replaces standard Multi-Head Attention (MHA) in DeepSeek-V2, DeepSeek-V3, and Kimi K2.x models. 
Instead of caching full KV pairs per head, MLA projects them into a low-dimensional latent space, achieving 5-10x KV cache compression with minimal quality loss.\nFor input $\\mathbf{X} \\in \\mathbb{R}^{n \\times d}$, MHA computes per-head projections:\n$$\\mathbf{Q}_h = \\mathbf{X}\\mathbf{W}_Q^{(h)}, \\quad \\mathbf{K}_h = \\mathbf{X}\\mathbf{W}_K^{(h)}, \\quad \\mathbf{V}_h = \\mathbf{X}\\mathbf{W}_V^{(h)}$$\nwhere $\\mathbf{W}_Q^{(h)} \\in \\mathbb{R}^{d \\times d_k}$, $\\mathbf{W}_K^{(h)} \\in \\mathbb{R}^{d \\times d_k}$, $\\mathbf{W}_V^{(h)} \\in \\mathbb{R}^{d \\times d_v}$.\nKV cache size per token: $2 \\times n_h \\times d_k$ elements (taking $d_v = d_k$).\nMLA replaces the per-head KV projections with a shared low-rank latent compression:\nCompression (KV → Latent):\n$$\\mathbf{c}^{KV} = \\mathbf{X}\\mathbf{W}_{DKV}$$\nwhere $\\mathbf{W}_{DKV} \\in \\mathbb{R}^{d \\times d_c}$ is the down-projection matrix and $d_c \\ll n_h \\times d_k$.\nDecompression (Latent → KV):\n$$\\mathbf{K}_h = \\mathbf{c}^{KV}\\mathbf{W}_{UK}^{(h)}, \\quad \\mathbf{V}_h = \\mathbf{c}^{KV}\\mathbf{W}_{UV}^{(h)}$$\nwhere $\\mathbf{W}_{UK}^{(h)} \\in \\mathbb{R}^{d_c \\times d_k}$ and $\\mathbf{W}_{UV}^{(h)} \\in \\mathbb{R}^{d_c \\times d_v}$ are up-projection matrices.\nKV cache per token: Only $\\mathbf{c}^{KV} \\in \\mathbb{R}^{d_c}$ is stored, a single vector of dimension $d_c$.\nFor a model with $n_h$ heads and head dimension $d_k$, MHA caches $2 n_h d_k$ elements per token while the MLA latent takes $d_c$.\nIn DeepSeek-V3, with $n_h = 128$, $d_k = 128$, $d_c = 512$: MHA would cache $2 \\times 128 \\times 128 = 32768$ elements per token, against 512 for the latent alone (before the decoupled RoPE key is counted).\nMLA also compresses queries for training efficiency:\n$$\\mathbf{c}^{Q} = \\mathbf{X}\\mathbf{W}_{DQ}, \\quad \\mathbf{Q}_h = \\mathbf{c}^{Q}\\mathbf{W}_{UQ}^{(h)}$$\nwhere $\\mathbf{W}_{DQ} \\in \\mathbb{R}^{d \\times d_c'}$ and $\\mathbf{W}_{UQ}^{(h)} \\in \\mathbb{R}^{d_c' \\times d_k}$.\nThis doesn't affect the KV cache but reduces the activation memory during training.\nIn standard attention, RoPE is applied to the full queries and keys. To keep the KV cache small, MLA instead applies RoPE to a separate, narrow key projection:\n$$\\hat{\\mathbf{K}}_h = \\mathrm{RoPE}(\\mathbf{c}^{KV}\\mathbf{W}_{KR}^{(h)})$$\nwhere $\\mathbf{W}_{KR}^{(h)} \\in \\mathbb{R}^{d_c \\times d_r}$ with $d_r \\ll d_k$ is a narrow projection that carries positional information. 
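The cached representation remains $\\mathbf{c}^{KV}$ (position-agnostic), and the RoPE key $\\hat{\\mathbf{K}}_h$ is recomputed at attention time from the cached latent.\nThe critical insight in MLA is that the up-projection matrices $\\mathbf{W}_{UK}^{(h)}$ can be absorbed into the query projection during attention computation:\n$$\\mathrm{score}_h = \\frac{\\mathbf{Q}_h \\mathbf{K}_h^T}{\\sqrt{d_k}}$$\nSubstituting the decompressed forms, with $\\mathbf{c}^{Q}$ the compressed query latent and $\\mathbf{W}_{UQ}^{(h)}$ its up-projection:\n$$\\mathbf{Q}_h \\mathbf{K}_h^T = (\\mathbf{c}^{Q}\\mathbf{W}_{UQ}^{(h)})(\\mathbf{c}^{KV}\\mathbf{W}_{UK}^{(h)})^T = \\mathbf{c}^{Q}\\,\\mathbf{W}_{UQ}^{(h)}{\\mathbf{W}_{UK}^{(h)}}^T\\,{\\mathbf{c}^{KV}}^T$$\nIf we define $\\mathbf{W}_{absorbed}^{(h)} = \\mathbf{W}_{UQ}^{(h)} {\\mathbf{W}_{UK}^{(h)}}^T \\in \\mathbb{R}^{d_c' \\times d_c}$, then:\n$$\\mathbf{Q}_h \\mathbf{K}_h^T = \\mathbf{c}^{Q}\\,\\mathbf{W}_{absorbed}^{(h)}\\,{\\mathbf{c}^{KV}}^T$$\nThis means the attention score can be computed directly from the latent representations, avoiding explicit decompression of K for the score computation. V is still needed for the output, although the DeepSeek-V2 paper notes that $\\mathbf{W}_{UV}^{(h)}$ can likewise be absorbed into the output projection.\nPractical implication: During decoding, we can compute attention scores without materializing the full K matrix. Without that second absorption, only V needs decompression after softmax.\nRoPE requires position-dependent keys, which conflicts with caching a position-agnostic latent: a position-dependent rotation sitting between $\\mathbf{W}_{UQ}^{(h)}$ and ${\\mathbf{W}_{UK}^{(h)}}^T$ would also block the absorption above. MLA solves this with a decoupled key:\n$$\\mathbf{K}_h^{rope} = \\mathrm{RoPE}(\\mathbf{c}^{KV}\\mathbf{W}_{KR}^{(h)}), \\quad \\mathbf{Q}_h^{rope} = \\mathrm{RoPE}(\\mathbf{c}^{Q}\\mathbf{W}_{QR}^{(h)})$$\nHere $\\mathbf{K}_h^{rope}$ is the $\\hat{\\mathbf{K}}_h$ introduced earlier and $\\mathbf{W}_{QR}^{(h)}$ is the query-side counterpart of $\\mathbf{W}_{KR}^{(h)}$ (`w_qr` in the code below).\nThe attention score becomes:\n$$\\mathrm{score}_h = \\frac{\\mathbf{c}^{Q}\\,\\mathbf{W}_{absorbed}^{(h)}\\,{\\mathbf{c}^{KV}}^T + \\mathbf{Q}_h^{rope}{\\mathbf{K}_h^{rope}}^T}{\\sqrt{d_k + d_r}}$$\nPractical implication: If the rope keys are cached rather than recomputed from the latent, the KV cache stores both $\\mathbf{c}^{KV}$ (latent) and $\\mathbf{K}_h^{rope}$ (decoupled rope key). Total cache per token: $d_c + n_h \\times d_r$. In the DeepSeek-V2 paper the decoupled RoPE key is a single vector shared by all heads and is projected from the hidden state rather than from the latent, which brings this down to $d_c + d_r$.\nGQA reduces cache by sharing KV heads across query groups. MLA reduces cache more aggressively by projecting to a shared latent. The quality difference is minimal because the up-projection matrices are learned and can reconstruct head-specific information.\nMLA dramatically changes the memory-vs-compute tradeoff in serving:\nMemory-bound decoding phase: With MHA, long contexts exhaust GPU HBM due to KV cache. 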
MLA's compression allows:\nCompute-bound prefill phase: MLA adds decompression overhead, but this is amortized:\nThis is where it gets interesting for Siraj's EAGLE-3 work:\nDraft model constraints:\nVerification with MLA:\nvLLM implementation challenge: vLLM's PagedAttention was designed for MHA/GQA. MLA requires:\nimport math\n\nimport torch\nimport torch.nn as nn\n\n\nclass MultiHeadLatentAttention(nn.Module):\n    \"\"\"\n    MLA attention layer in the style of DeepSeek-V2/V3 and Kimi K2.x.\n    Key features:\n    - Low-rank KV compression (cache only c_KV latent vector)\n    - Decoupled RoPE for position-aware attention\n    - Explicit K/V decompression for readability (weight absorption is not implemented here)\n    \"\"\"\n\n    def __init__(\n        self,\n        d_model: int = 4096,\n        n_heads: int = 128,\n        d_k: int = 128,\n        d_v: int = 128,\n        d_c: int = 512,  # KV latent dimension (compression target)\n        d_c_prime: int = 1536,  # Query latent dimension\n        d_r: int = 64,  # Decoupled RoPE key dimension per head\n        max_seq_len: int = 8192,  # Unused in this sketch\n        rope_base: float = 10000.0,\n    ):\n        super().__init__()\n        self.d_model = d_model\n        self.n_heads = n_heads\n        self.d_k = d_k\n        self.d_v = d_v\n        self.d_c = d_c\n        self.d_c_prime = d_c_prime\n        self.d_r = d_r\n        # === Down-projections (compression) ===\n        self.w_dkv = nn.Linear(d_model, d_c, bias=False)  # KV latent\n        self.w_dq = nn.Linear(d_model, d_c_prime, bias=False)  # Q latent\n        # === Up-projections (decompression) ===\n        # KV up-projections: latent -> per-head K and V\n        self.w_uk = nn.Linear(d_c, n_heads * d_k, bias=False)\n        self.w_uv = nn.Linear(d_c, n_heads * d_v, bias=False)\n        # Q up-projection: latent -> per-head Q\n        self.w_uq = nn.Linear(d_c_prime, n_heads * d_k, bias=False)\n        # === Decoupled RoPE projections ===\n        self.w_kr = nn.Linear(d_c, n_heads * d_r, bias=False)  # Rope key from latent\n        self.w_qr = nn.Linear(d_c_prime, n_heads * d_r, bias=False)  # Rope query from latent\n        # === Output projection ===\n        self.w_o = nn.Linear(n_heads * d_v, d_model, bias=False)\n        # RoPE frequencies\n        inv_freq = 1.0 / (rope_base ** (torch.arange(0, d_r, 2).float() / d_r))\n        self.register_buffer('inv_freq', inv_freq)\n\n    def _apply_rope(self, x: torch.Tensor, end_pos: int) -> torch.Tensor:\n        \"\"\"Apply RoPE to x of shape [batch, seq, n_heads, d_r], using positions end_pos - seq .. end_pos - 1.\"\"\"\n        seq = x.shape[1]\n        t = torch.arange(end_pos - seq, end_pos, device=x.device, dtype=self.inv_freq.dtype)\n        freqs = torch.outer(t, self.inv_freq)  # [seq, d_r//2]\n        cos = freqs.cos().unsqueeze(0).unsqueeze(2)  # [1, seq, 1, d_r//2]\n        sin = freqs.sin().unsqueeze(0).unsqueeze(2)\n        x1, x2 = x[..., ::2], x[..., 1::2]\n        rotated = torch.stack([\n            x1 * cos - x2 * sin,\n            x1 * sin + x2 * cos,\n        ], dim=-1).flatten(-2)\n        return rotated\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        kv_cache: torch.Tensor = None,\n        start_pos: int = 0,\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Args:\n            x: Input tensor [batch, seq_len, d_model]\n            kv_cache: Cached c_KV from previous tokens [batch, cache_len, d_c]\n            start_pos: Position offset for RoPE (must equal cache_len)\n        Returns:\n            output: [batch, seq_len, d_model]\n            new_kv_cache: Updated cache [batch, cache_len + seq_len, d_c]\n        \"\"\"\n        B, S, _ = x.shape\n        # === Step 1: Compress to latent space ===\n        c_kv = self.w_dkv(x)  # [B, S, d_c]: THIS is what gets cached\n        c_q = self.w_dq(x)  # [B, S, d_c']\n        # === Step 2: Decompress for attention computation ===\n        # K, V up-projection from latent\n        k_content = self.w_uk(c_kv)  # [B, S, n_heads * d_k]\n        v = self.w_uv(c_kv)  # [B, S, n_heads * d_v]\n        q_content = self.w_uq(c_q)  # [B, S, n_heads * d_k]\n        # Reshape to multi-head format\n        q_content = q_content.view(B, S, self.n_heads, self.d_k)\n        k_content = k_content.view(B, S, self.n_heads, self.d_k)\n        v = v.view(B, S, self.n_heads, self.d_v)\n        # === Step 3: Decoupled RoPE ===\n        # Project to rope-specific dimensions and apply RoPE\n        k_rope = self.w_kr(c_kv).view(B, S, self.n_heads, self.d_r)\n        q_rope = self.w_qr(c_q).view(B, S, self.n_heads, self.d_r)\n        # New tokens occupy positions start_pos .. start_pos + S - 1\n        k_rope = self._apply_rope(k_rope, start_pos + S)\n        q_rope = self._apply_rope(q_rope, start_pos + S)\n        # Concatenate content + rope for full key and query\n        q = torch.cat([q_content, q_rope], dim=-1)  # [B, S, n_heads, d_k + d_r]\n        k = torch.cat([k_content, k_rope], dim=-1)  # [B, S, n_heads, d_k + d_r]\n        # === Step 4: KV cache management ===\n        if kv_cache is not None:\n            # Append new latent to cache\n            new_kv_cache = torch.cat([kv_cache, c_kv], dim=1)\n            # Decompress full cache for attention\n            k_cache = self.w_uk(kv_cache).view(B, -1, self.n_heads, self.d_k)\n            k_cache_rope = self._apply_rope(\n                self.w_kr(kv_cache).view(B, -1, self.n_heads, self.d_r),\n                start_pos,  # cache holds positions 0..start_pos-1\n            )\n            k = torch.cat([\n                torch.cat([k_cache, k_cache_rope], dim=-1),\n                k,\n            ], dim=1)\n            v_cache = self.w_uv(kv_cache).view(B, -1, self.n_heads, self.d_v)\n            v = torch.cat([v_cache, v], dim=1)\n        else:\n            new_kv_cache = c_kv\n        # === Step 5: Compute attention ===\n        # Transpose for attention: [B, n_heads, seq, dim]\n        q = q.transpose(1, 2)\n        k = k.transpose(1, 2)\n        v = v.transpose(1, 2)\n        d_attn = self.d_k + self.d_r\n        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_attn)\n        # Causal mask: query i may attend to keys 0 .. (T - S + i)\n        T = k.size(-2)\n        causal = torch.tril(torch.ones(S, T, device=x.device), diagonal=T - S).bool()\n        attn_weights = attn_weights.masked_fill(~causal, float('-inf'))\n        attn_weights = torch.softmax(attn_weights, dim=-1)\n        attn_output = torch.matmul(attn_weights, v)  # [B, n_heads, S, d_v]\n        # === Step 6: Output projection ===\n        attn_output = attn_output.transpose(1, 2).contiguous().view(B, S, -1)\n        output = self.w_o(attn_output)\n        return output, new_kv_cache\n\n\n# === Example: Compare MLA vs MHA cache sizes ===\ndef compare_cache_sizes():\n    \"\"\"Demonstrate the KV cache savings of MLA over MHA.\"\"\"\n    n_heads = 128\n    d_k = 128\n    d_c = 512  # DeepSeek-V3 latent dim\n    d_r = 64  # Decoupled rope dim\n    seq_len = 65536  # 64K context\n    bytes_per_element = 2  # FP16\n    # MHA: cache K and V for all heads\n    mha_cache_per_token = 2 * n_heads * d_k  # K + V\n    mha_total = mha_cache_per_token * seq_len * bytes_per_element / (1024**3)\n    # MLA: cache only c_KV + decoupled rope keys\n    mla_cache_per_token = d_c + n_heads * d_r  # latent + rope keys\n    mla_total = mla_cache_per_token * seq_len * bytes_per_element / (1024**3)\n    print(f\"MHA KV cache (64K ctx): {mha_total:.2f} GB per layer\")\n    print(f\"MLA KV cache (64K ctx): {mla_total:.2f} GB per layer\")\n    print(f\"Compression ratio: {mha_cache_per_token / mla_cache_per_token:.1f}x\")\n    print(\"\\nFor 60 layers:\")\n    print(f\" MHA: {mha_total * 60:.1f} GB\")\n    print(f\" MLA: {mla_total * 60:.1f} GB\")\n    print(f\" Savings: {(mha_total - mla_total) * 60:.1f} GB\")\n\n\nif __name__ == \"__main__\":\n    # Test MLA forward pass\n    mla = MultiHeadLatentAttention(\n        d_model=4096, n_heads=8, d_k=64, d_v=64,\n        d_c=128, d_c_prime=256, d_r=32,\n    )\n    x = torch.randn(2, 10, 4096)  # batch=2, seq=10\n    output, cache = mla(x)\n    print(f\"Output shape: {output.shape}\")  # [2, 10, 4096]\n    print(f\"Cache shape: {cache.shape}\")  # [2, 10, 128], only d_c!\n    # Autoregressive step\n    x2 = torch.randn(2, 1, 4096)\n    output2, cache2 = mla(x2, kv_cache=cache, start_pos=10)\n    print(f\"Output2 shape: {output2.shape}\")  # [2, 1, 4096]\n    print(f\"Cache2 shape: {cache2.shape}\")  # [2, 11, 128], grew by 1\n    print(\"\\n--- Cache Comparison ---\")\n    compare_cache_sizes()\nDeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model — Liu et al., 2024. arxiv:2405.04434 — Original MLA paper introducing the latent compression and decoupled RoPE strategy.\nDeepSeek-V3 Technical Report — DeepSeek-AI, 2024. arxiv:2412.19437 — Scales MLA to 671B MoE with auxiliary-loss-free routing. 
Details the multi-token prediction (MTP) objective, whose modules the report notes can be repurposed for speculative decoding, the same role EAGLE-style draft heads play.\nvLLM MLA Implementation — github.com/vllm-project/vllm — Production MLA kernel with weight absorption and FlashAttention integration.\nFlashInfer MLA Attention — github.com/flashinfer-ai/flashinfer — Custom CUDA kernels for MLA that support both prefill and decode phases with batched latent cache.", "url": "https://wpnews.pro/news/multi-head-latent-attention-mla", "canonical_source": "https://dev.to/sirajuddin-shaik/multi-head-latent-attention-mla-ahn", "published_at": "2026-05-23 13:14:23+00:00", "updated_at": "2026-05-23 13:34:19.771894+00:00", "lang": "en", "topics": ["large-language-models", "artificial-intelligence", "machine-learning", "research"], "entities": ["DeepSeek-V2", "DeepSeek-V3", "Kimi K2.x", "Multi-Head Latent Attention", "Multi-Head Attention"], "alternates": {"html": "https://wpnews.pro/news/multi-head-latent-attention-mla", "markdown": "https://wpnews.pro/news/multi-head-latent-attention-mla.md", "text": "https://wpnews.pro/news/multi-head-latent-attention-mla.txt", "jsonld": "https://wpnews.pro/news/multi-head-latent-attention-mla.jsonld"}}