# Multi-Head Latent Attention (MLA)

> Source: <https://dev.to/sirajuddin-shaik/multi-head-latent-attention-mla-ahn>
> Published: 2026-05-23 13:14:23+00:00

Compressing KV cache via low-rank projections: the attention mechanism behind DeepSeek-V2/V3 and Kimi K2.x

Multi-Head Latent Attention (MLA) is the attention variant that replaces standard Multi-Head Attention (MHA) in DeepSeek-V2, DeepSeek-V3, and Kimi K2.x models. Instead of caching full KV pairs per head, MLA projects them into a low-dimensional latent space, achieving 5-10x KV cache compression with minimal quality loss.

For input $\mathbf{X} \in \mathbb{R}^{n \times d}$, MHA computes per-head projections:

$$\mathbf{Q}_h = \mathbf{X}\mathbf{W}_Q^{(h)}, \quad \mathbf{K}_h = \mathbf{X}\mathbf{W}_K^{(h)}, \quad \mathbf{V}_h = \mathbf{X}\mathbf{W}_V^{(h)}$$

where $\mathbf{W}_Q^{(h)} \in \mathbb{R}^{d \times d_k}$, $\mathbf{W}_K^{(h)} \in \mathbb{R}^{d \times d_k}$, $\mathbf{W}_V^{(h)} \in \mathbb{R}^{d \times d_v}$.

KV cache size per token: $2 \times n_h \times d_k$ elements (taking $d_v = d_k$).

MLA replaces the per-head KV projections with a shared low-rank latent compression.

Compression (KV → Latent):

$$\mathbf{c}^{KV} = \mathbf{X}\mathbf{W}_{DKV}$$

where $\mathbf{W}_{DKV} \in \mathbb{R}^{d \times d_c}$ is the down-projection matrix and $d_c \ll n_h \times d_k$.

Decompression (Latent → KV):

$$\mathbf{K}_h = \mathbf{c}^{KV}\mathbf{W}_{UK}^{(h)}, \quad \mathbf{V}_h = \mathbf{c}^{KV}\mathbf{W}_{UV}^{(h)}$$

where $\mathbf{W}_{UK}^{(h)} \in \mathbb{R}^{d_c \times d_k}$ and $\mathbf{W}_{UV}^{(h)} \in \mathbb{R}^{d_c \times d_v}$ are up-projection matrices.

KV cache per token: only $\mathbf{c}^{KV} \in \mathbb{R}^{d_c}$ is stored, a single vector of dimension $d_c$.

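
As a rough illustration of the compression and decompression steps, the following NumPy sketch uses small, arbitrary dimensions and random weights. All names and sizes are assumptions made for this example, not values from any released model. It shows that only the `d_c`-wide latent has to be kept per token, while per-head keys and values are rebuilt from it on demand.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative sizes only (assumed for this sketch).
n, d = 6, 32              # tokens, model width
n_h, d_k, d_v = 4, 8, 8   # heads, key dim, value dim
d_c = 5                   # latent width, chosen so that d_c << n_h * d_k

X = rng.standard_normal((n, d))
W_dkv = rng.standard_normal((d, d_c))          # shared down-projection
W_uk = rng.standard_normal((n_h, d_c, d_k))    # per-head key up-projections
W_uv = rng.standard_normal((n_h, d_c, d_v))    # per-head value up-projections

# Compression: this latent is the only thing stored in the KV cache.
c_kv = X @ W_dkv                               # [n, d_c]

# Decompression: rebuild per-head K and V from the latent when needed.
K = np.einsum("nc,hck->hnk", c_kv, W_uk)       # [n_h, n, d_k]
V = np.einsum("nc,hcv->hnv", c_kv, W_uv)       # [n_h, n, d_v]

mha_elems_per_token = n_h * d_k + n_h * d_v    # full K and V for every head
mla_elems_per_token = c_kv.shape[1]            # latent only
print(K.shape, V.shape, mha_elems_per_token, mla_elems_per_token)
```

Because every head's keys are a linear function of the same `d_c`-dimensional latent, each `K[h]` has rank at most `d_c`; that is the price paid for the smaller cache.
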
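For a model with $n_h$ heads and head dimension $d_k$, MHA caches $2 \times n_h \times d_k$ elements per token per layer, while MLA caches only the $d_c$-dimensional latent.

In DeepSeek-V3, $n_h = 128$, $d_k = 128$, and $d_c = 512$, so MHA would cache $2 \times 128 \times 128 = 32{,}768$ elements per token per layer against 512 for the MLA latent: a 64x reduction before accounting for the decoupled RoPE key.
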
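MLA also compresses queries for training efficiency:

$$\mathbf{c}^{Q} = \mathbf{X}\mathbf{W}_{DQ}, \quad \mathbf{Q}_h = \mathbf{c}^{Q}\mathbf{W}_{UQ}^{(h)}$$

where $\mathbf{W}_{DQ} \in \mathbb{R}^{d \times d_c'}$ and $\mathbf{W}_{UQ}^{(h)} \in \mathbb{R}^{d_c' \times d_k}$. This doesn't affect the KV cache but reduces the activation memory during training.
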
RoPE is normally applied to the full per-head queries and keys. To keep the KV cache small, MLA instead applies RoPE to a separate, narrow key projection:

$$\hat{\mathbf{K}}_h = \mathrm{RoPE}\left(\mathbf{c}^{KV}\mathbf{W}_{KR}^{(h)}\right)$$

where $\mathbf{W}_{KR}^{(h)} \in \mathbb{R}^{d_c \times d_r}$ with $d_r \ll d_k$ is a narrow projection that carries positional information. In this formulation the cached representation remains $\mathbf{c}^{KV}$ (position-agnostic), and the RoPE key $\hat{\mathbf{K}}_h$ is recomputed at attention time from the cached latent.

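The critical insight in MLA is that the up-projection matrices $\mathbf{W}_{UK}^{(h)}$ can be absorbed into the query projection during attention computation. The unscaled content score for head $h$ is:

$$\mathbf{S}_h = \mathbf{Q}_h\mathbf{K}_h^T$$

Substituting the decompressed forms, with $\mathbf{c}^{Q} \in \mathbb{R}^{n \times d_c'}$ the compressed query latent:

$$\mathbf{S}_h = \left(\mathbf{c}^{Q}\mathbf{W}_{UQ}^{(h)}\right)\left(\mathbf{c}^{KV}\mathbf{W}_{UK}^{(h)}\right)^T = \mathbf{c}^{Q}\,\mathbf{W}_{UQ}^{(h)}{\mathbf{W}_{UK}^{(h)}}^T\,{\mathbf{c}^{KV}}^T$$

If we define $\mathbf{W}_{absorbed}^{(h)} = \mathbf{W}_{UQ}^{(h)}{\mathbf{W}_{UK}^{(h)}}^T \in \mathbb{R}^{d_c' \times d_c}$, then:

$$\mathbf{S}_h = \mathbf{c}^{Q}\,\mathbf{W}_{absorbed}^{(h)}\,{\mathbf{c}^{KV}}^T$$

This means the content part of the attention score can be computed directly from the latent representations, without explicitly decompressing K. The value side admits the same trick: the attention weights can be applied to the cached latents first, with $\mathbf{W}_{UV}^{(h)}$ applied afterwards or folded into the output projection.

Practical implication: During decoding, attention scores can be computed without materializing the full K matrix, and the V up-projection is applied once to the attention-weighted latent rather than to every cached token.
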
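
The absorption identity is easy to check numerically. The sketch below is a minimal single-head NumPy example with arbitrary sizes and random weights; every variable name is an assumption made for illustration. It compares the naive path (decompress K, then take dot products) with the absorbed path, and also checks that the value up-projection can be applied after the attention weights have been applied to the latents.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative sizes (assumed for this sketch).
n_q, n_kv = 1, 7                  # one decoding query, seven cached tokens
d_c_q, d_c, d_k, d_v = 6, 5, 8, 8

c_q = rng.standard_normal((n_q, d_c_q))     # compressed query latent
c_kv = rng.standard_normal((n_kv, d_c))     # cached KV latents
W_uq = rng.standard_normal((d_c_q, d_k))    # query up-projection (one head)
W_uk = rng.standard_normal((d_c, d_k))      # key up-projection (one head)
W_uv = rng.standard_normal((d_c, d_v))      # value up-projection (one head)

# Naive path: decompress K for every cached token, then take dot products.
scores_naive = (c_q @ W_uq) @ (c_kv @ W_uk).T

# Absorbed path: precompute W_uq @ W_uk^T once and never materialize K.
W_absorbed = W_uq @ W_uk.T                  # [d_c_q, d_c]
scores_absorbed = c_q @ W_absorbed @ c_kv.T


def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


# Value side: weight the latents first, up-project once afterwards.
attn = softmax(scores_absorbed / np.sqrt(d_k))
out_naive = attn @ (c_kv @ W_uv)
out_latent_first = (attn @ c_kv) @ W_uv

print(np.allclose(scores_naive, scores_absorbed))
print(np.allclose(out_naive, out_latent_first))
```

Both checks rely only on associativity of matrix multiplication, which is why the identity holds for the content term but not once a position-dependent rotation sits between the two projections.
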
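RoPE requires position-dependent keys, which conflicts with caching a position-agnostic latent: a position-dependent rotation sitting between $\mathbf{W}_{UQ}^{(h)}$ and ${\mathbf{W}_{UK}^{(h)}}^T$ would prevent the two from being merged into one fixed matrix. MLA solves this with a decoupled key that carries RoPE separately from the content key:

$$\mathbf{Q}_h^{rope} = \mathrm{RoPE}\left(\mathbf{c}^{Q}\mathbf{W}_{QR}^{(h)}\right), \quad \mathbf{K}_h^{rope} = \mathrm{RoPE}\left(\mathbf{c}^{KV}\mathbf{W}_{KR}^{(h)}\right)$$

where $\mathbf{W}_{QR}^{(h)} \in \mathbb{R}^{d_c' \times d_r}$ is the matching query-side projection. The attention score becomes:

$$\mathbf{S}_h = \frac{\mathbf{Q}_h\mathbf{K}_h^T + \mathbf{Q}_h^{rope}{\mathbf{K}_h^{rope}}^T}{\sqrt{d_k + d_r}}$$

Practical implication: An implementation can either recompute $\mathbf{K}_h^{rope}$ from the cached latent at attention time, or cache it alongside $\mathbf{c}^{KV}$ to avoid the recomputation. With per-head RoPE keys, caching both costs $d_c + n_h \times d_r$ elements per token. DeepSeek-V2 and DeepSeek-V3 instead share a single RoPE key across all heads, so their cache is $d_c + d_r$ elements per token (576 for $d_c = 512$, $d_r = 64$).
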
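
To make the decoupled score concrete, here is a small NumPy sketch for a single query/key pair. The `rope` helper, the sizes, and the random vectors are all assumptions for illustration; the helper uses the common interleaved-pair rotation. It checks two properties: the dot product of the concatenated vectors splits into a content term plus a RoPE term, and the RoPE term depends only on the relative offset between query and key positions.

```python
import numpy as np


def rope(x, pos, base=10000.0):
    """Rotate interleaved feature pairs of a 1-D vector x by position-dependent angles."""
    d_r = x.shape[-1]
    inv_freq = 1.0 / (base ** (np.arange(0, d_r, 2) / d_r))
    ang = pos * inv_freq
    cos, sin = np.cos(ang), np.sin(ang)
    x1, x2 = x[0::2], x[1::2]
    out = np.empty_like(x)
    out[0::2] = x1 * cos - x2 * sin
    out[1::2] = x1 * sin + x2 * cos
    return out


rng = np.random.default_rng(0)
d_k, d_r = 8, 4   # illustrative sizes (assumed)

q_content, k_content = rng.standard_normal(d_k), rng.standard_normal(d_k)
q_r, k_r = rng.standard_normal(d_r), rng.standard_normal(d_r)


def score(q_pos, k_pos):
    """Scaled score of [content; rope] query against [content; rope] key."""
    q = np.concatenate([q_content, rope(q_r, q_pos)])
    k = np.concatenate([k_content, rope(k_r, k_pos)])
    return q @ k / np.sqrt(d_k + d_r)


# The concatenated dot product splits into a content term plus a RoPE term.
split = (q_content @ k_content + rope(q_r, 9) @ rope(k_r, 4)) / np.sqrt(d_k + d_r)
print(np.isclose(score(9, 4), split))

# Only the relative offset matters: positions (9, 4) and (105, 100) agree.
print(np.isclose(score(9, 4), score(105, 100)))
```

The content term never sees a position, which is what lets it be computed from the position-agnostic latent.
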
GQA reduces cache by sharing KV heads across query groups. MLA reduces cache more aggressively by projecting to a shared latent. The quality difference is minimal because the up-projection matrices are learned and can reconstruct head-specific information.
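
For a sense of scale, the following plain-Python sketch compares per-token, per-layer cache sizes for the three schemes. The helper name is hypothetical, the GQA configuration with 8 KV groups is an arbitrary example rather than any particular model, and the MLA entry assumes a single decoupled RoPE key shared across heads, as in DeepSeek-V2.

```python
def kv_cache_elems_per_token(kind, n_h, d_k, n_kv_groups=None, d_c=None, d_r=0):
    """Per-token, per-layer KV cache size in elements (assumes d_v == d_k)."""
    if kind == "mha":
        return 2 * n_h * d_k            # K and V for every head
    if kind == "gqa":
        return 2 * n_kv_groups * d_k    # K and V for every KV group
    if kind == "mla":
        return d_c + d_r                # latent plus one shared RoPE key
    raise ValueError(f"unknown attention kind: {kind}")


n_h, d_k = 128, 128
sizes = {
    "mha": kv_cache_elems_per_token("mha", n_h, d_k),
    "gqa (8 KV groups)": kv_cache_elems_per_token("gqa", n_h, d_k, n_kv_groups=8),
    "mla": kv_cache_elems_per_token("mla", n_h, d_k, d_c=512, d_r=64),
}
for name, elems in sizes.items():
    print(f"{name:>18}: {elems:6d} elements/token/layer")
```

Unlike GQA, where grouped query heads see identical keys and values, every MLA head still gets its own keys and values through its own up-projection.
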

MLA dramatically changes the memory-vs-compute tradeoff in serving:

Memory-bound decoding phase: With MHA, long contexts exhaust GPU HBM due to KV cache. MLA's compression allows:

Compute-bound prefill phase: MLA adds decompression overhead, but this is amortized:

This is where it gets interesting for Siraj's EAGLE-3 work:

Draft model constraints:

Verification with MLA:

vLLM implementation challenge: vLLM's PagedAttention was designed for MHA/GQA. MLA requires:

```python
import math
from typing import Optional

import torch
import torch.nn as nn


class MultiHeadLatentAttention(nn.Module):
    """
    Simplified MLA attention layer in the style of DeepSeek-V2/V3 and Kimi K2.x.

    Key features:
    - Low-rank KV compression (cache only the c_KV latent vector)
    - Decoupled RoPE for position-aware attention

    For readability, K and V are decompressed explicitly rather than using
    weight absorption, and per-head RoPE keys are recomputed from the latent.
    """

    def __init__(
        self,
        d_model: int = 4096,
        n_heads: int = 128,
        d_k: int = 128,
        d_v: int = 128,
        d_c: int = 512,          # KV latent dimension (compression target)
        d_c_prime: int = 1536,   # Query latent dimension
        d_r: int = 64,           # Decoupled RoPE key dimension per head
        max_seq_len: int = 8192,
        rope_base: float = 10000.0,
    ):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_k
        self.d_v = d_v
        self.d_c = d_c
        self.d_c_prime = d_c_prime
        self.d_r = d_r

        # === Down-projections (compression) ===
        self.w_dkv = nn.Linear(d_model, d_c, bias=False)       # KV latent
        self.w_dq = nn.Linear(d_model, d_c_prime, bias=False)  # Q latent

        # === Up-projections (decompression) ===
        # KV up-projections: latent -> per-head K and V
        self.w_uk = nn.Linear(d_c, n_heads * d_k, bias=False)
        self.w_uv = nn.Linear(d_c, n_heads * d_v, bias=False)
        # Q up-projection: latent -> per-head Q
        self.w_uq = nn.Linear(d_c_prime, n_heads * d_k, bias=False)

        # === Decoupled RoPE projections ===
        self.w_kr = nn.Linear(d_c, n_heads * d_r, bias=False)        # RoPE key from latent
        self.w_qr = nn.Linear(d_c_prime, n_heads * d_r, bias=False)  # RoPE query from latent

        # === Output projection ===
        self.w_o = nn.Linear(n_heads * d_v, d_model, bias=False)

        # RoPE frequencies
        inv_freq = 1.0 / (rope_base ** (torch.arange(0, d_r, 2).float() / d_r))
        self.register_buffer('inv_freq', inv_freq)

    def _apply_rope(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
        """Apply RoPE to x of shape [batch, seq, n_heads, d_r].

        Token i of x is treated as sitting at absolute position start_pos + i.
        """
        seq_len = x.shape[1]
        t = torch.arange(
            start_pos, start_pos + seq_len,
            device=x.device, dtype=self.inv_freq.dtype,
        )
        freqs = torch.outer(t, self.inv_freq)        # [seq, d_r//2]
        cos = freqs.cos().unsqueeze(0).unsqueeze(2)  # [1, seq, 1, d_r//2]
        sin = freqs.sin().unsqueeze(0).unsqueeze(2)
        x1, x2 = x[..., ::2], x[..., 1::2]
        rotated = torch.stack([
            x1 * cos - x2 * sin,
            x1 * sin + x2 * cos,
        ], dim=-1).flatten(-2)
        return rotated

    def forward(
        self,
        x: torch.Tensor,
        kv_cache: Optional[torch.Tensor] = None,
        start_pos: int = 0,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: Input tensor [batch, seq_len, d_model]
            kv_cache: Cached c_KV from previous tokens [batch, cache_len, d_c]
            start_pos: Position of the first token in x (equal to cache_len)
        Returns:
            output: [batch, seq_len, d_model]
            new_kv_cache: Updated cache [batch, cache_len + seq_len, d_c]
        """
        B, S, _ = x.shape

        # === Step 1: Compress to latent space ===
        c_kv = self.w_dkv(x)  # [B, S, d_c]; THIS is what gets cached
        c_q = self.w_dq(x)    # [B, S, d_c']

        # === Step 2: Decompress for attention computation ===
        # K, V up-projection from latent
        k_content = self.w_uk(c_kv)  # [B, S, n_heads * d_k]
        v = self.w_uv(c_kv)          # [B, S, n_heads * d_v]
        q_content = self.w_uq(c_q)   # [B, S, n_heads * d_k]

        # Reshape to multi-head format
        q_content = q_content.view(B, S, self.n_heads, self.d_k)
        k_content = k_content.view(B, S, self.n_heads, self.d_k)
        v = v.view(B, S, self.n_heads, self.d_v)

        # === Step 3: Decoupled RoPE ===
        # Project to rope-specific dimensions and rotate at positions
        # start_pos .. start_pos + S - 1
        k_rope = self.w_kr(c_kv).view(B, S, self.n_heads, self.d_r)
        q_rope = self.w_qr(c_q).view(B, S, self.n_heads, self.d_r)
        k_rope = self._apply_rope(k_rope, start_pos)
        q_rope = self._apply_rope(q_rope, start_pos)

        # Concatenate content + rope for full key and query
        q = torch.cat([q_content, q_rope], dim=-1)  # [B, S, n_heads, d_k + d_r]
        k = torch.cat([k_content, k_rope], dim=-1)  # [B, S, n_heads, d_k + d_r]

        # === Step 4: KV cache management ===
        if kv_cache is not None:
            # Append new latent to cache
            new_kv_cache = torch.cat([kv_cache, c_kv], dim=1)
            # Decompress full cache for attention
            k_cache = self.w_uk(kv_cache).view(B, -1, self.n_heads, self.d_k)
            k_cache_rope = self._apply_rope(
                self.w_kr(kv_cache).view(B, -1, self.n_heads, self.d_r),
                0,  # cached tokens occupy positions 0..cache_len-1
            )
            k = torch.cat([
                torch.cat([k_cache, k_cache_rope], dim=-1),
                k,
            ], dim=1)
            v_cache = self.w_uv(kv_cache).view(B, -1, self.n_heads, self.d_v)
            v = torch.cat([v_cache, v], dim=1)
        else:
            new_kv_cache = c_kv

        # === Step 5: Compute attention ===
        # Transpose for attention: [B, n_heads, seq, dim]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        d_attn = self.d_k + self.d_r
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_attn)

        # Causal mask: query i (absolute position T - S + i) sees keys 0..T - S + i
        T = k.shape[2]
        causal = torch.ones(S, T, dtype=torch.bool, device=x.device).tril(diagonal=T - S)
        attn_weights = attn_weights.masked_fill(~causal, float("-inf"))

        attn_weights = torch.softmax(attn_weights, dim=-1)
        attn_output = torch.matmul(attn_weights, v)  # [B, n_heads, S, d_v]

        # === Step 6: Output projection ===
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, S, -1)
        output = self.w_o(attn_output)
        return output, new_kv_cache


# === Example: Compare MLA vs MHA cache sizes ===
def compare_cache_sizes():
    """Demonstrate the KV cache savings of MLA over MHA."""
    n_heads = 128
    d_k = 128
    d_c = 512        # DeepSeek-V3 latent dim
    d_r = 64         # Decoupled rope dim
    seq_len = 65536  # 64K context
    bytes_per_element = 2  # FP16

    # MHA: cache K and V for all heads
    mha_cache_per_token = 2 * n_heads * d_k  # K + V
    mha_total = mha_cache_per_token * seq_len * bytes_per_element / (1024**3)

    # MLA: cache c_KV plus one decoupled RoPE key shared across heads, as in
    # DeepSeek-V2/V3. Caching per-head RoPE keys would cost
    # d_c + n_heads * d_r instead.
    mla_cache_per_token = d_c + d_r
    mla_total = mla_cache_per_token * seq_len * bytes_per_element / (1024**3)

    print(f"MHA KV cache (64K ctx): {mha_total:.2f} GB per layer")
    print(f"MLA KV cache (64K ctx): {mla_total:.2f} GB per layer")
    print(f"Compression ratio: {mha_cache_per_token / mla_cache_per_token:.1f}x")
    print(f"\nFor 60 layers:")
    print(f"  MHA: {mha_total * 60:.1f} GB")
    print(f"  MLA: {mla_total * 60:.1f} GB")
    print(f"  Savings: {(mha_total - mla_total) * 60:.1f} GB")


if __name__ == "__main__":
    # Test MLA forward pass
    mla = MultiHeadLatentAttention(
        d_model=4096, n_heads=8, d_k=64, d_v=64,
        d_c=128, d_c_prime=256, d_r=32,
    )
    x = torch.randn(2, 10, 4096)  # batch=2, seq=10
    output, cache = mla(x)
    print(f"Output shape: {output.shape}")  # [2, 10, 4096]
    print(f"Cache shape: {cache.shape}")    # [2, 10, 128]; only d_c!

    # Autoregressive step
    x2 = torch.randn(2, 1, 4096)
    output2, cache2 = mla(x2, kv_cache=cache, start_pos=10)
    print(f"Output2 shape: {output2.shape}")  # [2, 1, 4096]
    print(f"Cache2 shape: {cache2.shape}")    # [2, 11, 128]; grew by 1

    print("\n--- Cache Comparison ---")
    compare_cache_sizes()
```

- DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model. Liu et al., 2024. arXiv:2405.04434. Original MLA paper introducing the latent compression and decoupled RoPE strategy.
- DeepSeek-V3 Technical Report. DeepSeek-AI, 2024. arXiv:2412.19437. Scales MLA to a 671B MoE with auxiliary-loss-free load balancing. Details the multi-token prediction (MTP) objective, whose modules can be repurposed for speculative decoding.
- vLLM MLA Implementation. github.com/vllm-project/vllm. Production MLA kernel with weight absorption and FlashAttention integration.
- FlashInfer MLA Attention. github.com/flashinfer-ai/flashinfer. Custom CUDA kernels for MLA that support both prefill and decode phases with batched latent cache.
