Exploring Speculative Decoding: From Concept to Implementation

Speculative decoding optimizes LLM inference by using a cheap draft model to predict multiple tokens, which are then verified in a single forward pass of the target model, reducing memory-bandwidth bottlenecks and improving GPU utilization. The technique leverages the KV cache and can yield significant efficiency gains when draft tokens are accurate.

In this post, we explore speculative decoding through a concrete vLLM-focused implementation, covering draft models, EAGLE, MTP, and the tradeoffs involved.

Intro

In this post, I’ll discuss speculative decoding, a technique used to optimize LLM inference. It’s one of those things that, when you first learn about it, just clicks in an “oh yeah, that makes sense” way.

But first, I want to motivate why LLM inference optimization matters. LLMs generate responses to user queries. We train large models once, often at 8- or 9-figure costs, but serving them is what happens millions of times. Running a model requires some serious hardware, and saying GPUs are expensive and in limited supply is almost an understatement. A small efficiency gain means very large savings over time.

Refresher on LLM Inference Basics

Modern GPUs are impressive beasts. But they have their quirks. A GPU can run hundreds of trillions of operations per second, yet it can only move a few trillion bytes per second from GPU memory to the compute units. That means unless we run hundreds of operations per byte (which we don’t in LLMs), we’re essentially memory-bandwidth bound.

When we batch, we perform more ops per set of weights in X @ W: the larger the batch, the more we reuse the weights W we brought from memory, and the better we utilize the GPU.

LLM inference is autoregressive. If we have some input tokens t_1 .. t_n, the model gives us a logits vector of size vocab_size for each token in the sequence. We use the logits of the last token t_n to sample the next token t_{n+1}.

From KV Cache to Speculative Decoding

Each token goes through multiple layers, and each layer has a few standard blocks: normalization, MLP, and most notably the transformer’s attention block. In every block except attention, a token does not care about other tokens. In attention, token t_i needs to know about tokens t_1 .. t_{i-1}. Specifically, it needs access to the keys and values for those tokens at that particular layer.

One key optimization that LLM inference engines bring to the table is that, instead of recalculating those K and V tensors for every new token, we store them and only run the calculation for the new tokens. That is the famous KV cache. As new tokens go through the model, their key and value vectors are added to the cache.

The sampled token t_{n+1} is added to the input, and its logits are used to sample t_{n+2}, and so on. During the first model run, we calculate n logits vectors in the output even though we only need the last one for generation. And the difference between calculating one or a few candidate tokens in the same forward pass is often much smaller than running several decode steps one after the other.

For each forward pass that produces a new token, we need to load all the model’s weights, and we need to reload them for the next pass, and so on. Since memory bandwidth is the limiting factor, we spend most of the time waiting for weights to be loaded.

That is where speculative decoding comes in. If we can guess a few likely tokens ahead of time and feed them to the model, we can verify them in one pass. If they are correct, we get those tokens almost for free.

Regular:  t_1 .. t_n -> t_{n+1}
Spec dec: t_1 .. t_n -> t_{n+1}, t_{n+2}, t_{n+3}, ...
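To put rough numbers on the "almost for free" claim, here is a back-of-the-envelope, roofline-style estimate. Everything in it is an assumption made for illustration: a hypothetical 7B-parameter model in 16-bit weights, a hypothetical GPU with 400 TFLOP/s of compute and 2 TB/s of memory bandwidth, about 2 FLOPs per weight per token, and no accounting for attention or KV cache traffic. The function name `forward_time_s` is made up for this sketch.

```python
def forward_time_s(num_tokens, params=7e9, bytes_per_param=2,
                   peak_flops=400e12, mem_bandwidth=2e12):
    """Roofline-style lower bound on one forward pass over `num_tokens` tokens.

    All defaults are illustrative assumptions, not measurements.
    """
    # Every weight has to be read from GPU memory once per forward pass.
    memory_time = params * bytes_per_param / mem_bandwidth
    # Roughly 2 FLOPs (multiply + add) per weight for each token processed.
    compute_time = 2 * params * num_tokens / peak_flops
    # The pass cannot finish before the slower of the two.
    return max(memory_time, compute_time)


one_token = forward_time_s(1)          # ordinary decode step
verify_four = forward_time_s(4)        # one pass that scores 4 candidate positions
four_steps = 4 * forward_time_s(1)     # four ordinary decode steps in a row

print(f"one decode step        : {one_token * 1e3:.2f} ms")
print(f"one pass over 4 tokens : {verify_four * 1e3:.2f} ms")
print(f"four decode steps      : {four_steps * 1e3:.2f} ms")
```

Under these assumptions the pass over four tokens is still limited by weight loading, so it costs the same as a single decode step, while four sequential steps cost four times as much. The estimate only turns compute-bound once a pass handles a couple of hundred tokens, which is the "hundreds of operations per byte" threshold mentioned earlier.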
Of course, this only works if the guessed tokens, which we call draft tokens, are usually correct and the guessing process is much cheaper than a full forward pass on the target model. If not, we might as well just use the large original model, which we refer to as the target.

There are many techniques to generate these draft tokens: n-gram, EAGLE, MTP, and others. But the idea is the same. One forward pass is normally used to sample one token. If we can run a cheaper draft process and predict a few extra tokens, we can reduce the number of expensive target-model steps. If we do things right, we can also preserve the target model’s distribution exactly, as if there were no draft model at all.

Pseudo code

```python
def propose(tokens, num_speculative_tokens):
    draft_tokens = []
    for _ in range(num_speculative_tokens):
        # this must be very fast compared to the target model
        logits = generate_draft_token(tokens)
        token = sample(logits)
        draft_tokens.append(token)
        tokens.append(token)
    return tokens, draft_tokens
```

Once we have the draft tokens, we verify them:

```python
def sample_verify(tokens, num_drafts):
    # run target model forward, this is the true distribution
    logits = target_model(tokens)
    # logits[j] predicts tokens[j + 1], so the logits that judge the drafts
    # start one position before the first draft token
    for i, logit in enumerate(logits[-num_drafts - 1:-1]):
        target_token = sample(logit)
        draft_token = tokens[-num_drafts + i]
        if target_token == draft_token:
            accept(draft_token)
        else:
            reject(draft_token)
            break
```

Guaranteeing The Original Distribution

The important part is that speculative decoding doesn’t sacrifice correctness for speed. The final output follows the same distribution as the target model. This is mathematically provable if we follow this algorithm:

1. Sample tokens using the draft model via probability distribution q(x). This step should be way faster than regular decode.
2. The target model computes the true distribution p(x). This verifies multiple tokens cheaply in one pass.
3. Accept the sampled token with probability min(1, p(x)/q(x)). Two cases:
   a. If p(x) >= q(x), we accept with probability 1. The target likes the token at least as much as the draft, so we keep it.
   b. If p(x) < q(x), we accept with probability p(x)/q(x), which is less than 1.
4. If rejected, sample a correction token from max(0, p(x) - q(x)), normalized so it sums to 1.

Step 4 is key: it’s where we account for the discrepancy between the draft and the target. By resampling from p - q, we cover the tokens the draft was overlooking.

Step 3b is also important. If the draft thinks token 2 has probability 0.4 but the target thinks it’s 0.2, then p/q = 0.2/0.4 = 0.5, so we accept it half the time. This makes sense: the draft is overconfident, and we correct for that via the ratio p/q. If p >= q, the target likes the token at least as much as the draft does, so we just accept it.

Speculators

There are many ways to get these draft tokens. The criteria we are optimizing for are being as close as possible to the target model and being faster.

Smaller Draft Model

The most intuitive one is probably using a smaller model. Imagine you have a 470B model and you rely on a 7B model from the same family. For complex tokens, the small model will not perform well, but for repetitive and easy stuff, it should be decent. For example:

Q: How can we solve special relativity?
A: To solve special ...

The small model should easily guess that we will repeat part of the question and provide appropriate draft tokens. The rest of the answer is harder, but if we get the easy parts right, we still come out ahead.

We can come up with different techniques to generate draft tokens. It is just a function that takes the existing sequence and tries to predict K draft tokens. In vLLM, this is implemented as a pluggable speculator. The vLLM snippets in the rest of this post are trimmed and keep only the skeleton and the key data flow.
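Before turning to the vLLM code, here is a small standalone sketch of the accept/reject rule from the "Guaranteeing The Original Distribution" section. It is not taken from vLLM: the function names, the three-token vocabulary, and the probabilities are all made up for illustration, and it handles a single draft token rather than a whole draft sequence. It samples a draft token from q, applies steps 3 and 4, and checks empirically that the tokens that come out follow p.

```python
import random


def verify_draft_token(draft_token, p, q, rng):
    """Accept or replace one draft token.

    p and q are the target and draft probabilities over the same vocabulary.
    Returns (token, accepted).
    """
    # q[draft_token] > 0 because the draft model actually sampled this token.
    accept_prob = min(1.0, p[draft_token] / q[draft_token])
    if rng.random() < accept_prob:
        return draft_token, True
    # Rejected: resample from the residual max(0, p - q).
    # random.choices normalizes the weights for us.
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    token = rng.choices(range(len(p)), weights=residual, k=1)[0]
    return token, False


def speculative_sample(p, q, rng):
    """Draw one draft token from q, then verify it against p."""
    draft_token = rng.choices(range(len(q)), weights=q, k=1)[0]
    return verify_draft_token(draft_token, p, q, rng)


def empirical_distribution(p, q, num_samples=100_000, seed=0):
    """Frequency of each output token over many independent trials."""
    rng = random.Random(seed)
    counts = [0] * len(p)
    for _ in range(num_samples):
        token, _accepted = speculative_sample(p, q, rng)
        counts[token] += 1
    return [c / num_samples for c in counts]


p = [0.2, 0.5, 0.3]  # target distribution (made-up numbers)
q = [0.4, 0.4, 0.2]  # draft distribution, overconfident on token 0
freqs = empirical_distribution(p, q)
print("target   :", p)
print("observed :", [round(f, 3) for f in freqs])
```

The observed frequencies should land close to p even though every candidate token came from q. Token 0 mirrors the 0.4 versus 0.2 example: it is proposed 40% of the time, kept half of those times, and never produced by the residual, which leaves it at roughly 20%.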
The shared proposer base (https://github.com/vllm-project/vllm/blob/6bdabbad5bce747865fd3a249658518a4269cc22/vllm/v1/spec_decode/llm_base_proposer.py#L55) looks like this:

```python
class SpecDecodeBaseProposer:
    def __init__(self, vllm_config, device, pass_hidden_states_to_model, runner=None):
        ...

    @torch.inference_mode()
    def propose(
        self,
        target_token_ids,
        target_positions,
        target_hidden_states,
        next_token_ids,
        token_indices_to_sample,
        common_attn_metadata,
        sampling_metadata,
        mm_embed_inputs=None,
        num_rejected_tokens_gpu=None,
        slot_mappings=None,
    ):
        # Take in the current token and extra data that might be used
        # depending on the proposer, then generate draft tokens
        # and their draft logits/probs.
        ...
```

N-gram

Things tend to repeat themselves. What goes around comes around. This is a general principle in computing that underlies caching: if we see some data, we are likely to use it again. That is locality.

Following this principle, we can look at the last N tokens in our sequence and search for a previous occurrence. Our draft tokens are then the K tokens that came after that previous occurrence. From the example above, the suffix of the sequence is “solve special”, its previous occurrence is a few words earlier, and what comes after it is “relativity?”. We guess that as the likely draft tokens, and we end up being right.

N-gram is a very simple and cheap technique to run. We can even run it on the CPU. Its simplicity also means that it is often wrong in practice, but it can be quite useful for text with repetitive patterns, and code is the perfect example.

```python
class NgramProposer:
    def __init__(self, vllm_config):
        # Draft length and match window come from speculative_config.
        self.min_n = vllm_config.speculative_config.prompt_lookup_min
        self.max_n = vllm_config.speculative_config.prompt_lookup_max
        self.k = vllm_config.speculative_config.num_speculative_tokens
        self.max_model_len = vllm_config.model_config.max_model_len

    def propose(self, sampled_token_ids, num_tokens_no_spec, token_ids_cpu, slot_mappings=None):
        # Only speculate for requests that actually sampled a token.
        valid_requests = [
            i
            for i, sampled_ids in enumerate(sampled_token_ids)
            if sampled_ids and num_tokens_no_spec[i] < self.max_model_len
        ]
        return self.batch_propose(
            len(sampled_token_ids), valid_requests, num_tokens_no_spec, token_ids_cpu
        )

    def batch_propose(...):  # arguments trimmed
        for i in prange(len(valid_ngram_requests)):
            idx = valid_ngram_requests[i]
            num_tokens = num_tokens_no_spec[idx]
            context_token_ids = token_ids_cpu[idx, :num_tokens]
            drafter_output = _find_longest_matched_ngram_and_propose_tokens(
                origin_tokens=context_token_ids,
                min_ngram=min_n,
                max_ngram=max_n,
                max_model_len=max_model_len,
                k=k,
            )
            valid_ngram_num_drafts[idx] = drafter_output.shape[0]
            if len(drafter_output):
                valid_ngram_draft[idx, : drafter_output.shape[0]] = drafter_output


def _find_longest_matched_ngram_and_propose_tokens(origin_tokens, min_ngram, max_ngram, max_model_len, k):
    # Uses the Knuth-Morris-Pratt (KMP) algorithm to match the longest pattern.
    # This video explains it neatly: https://www.youtube.com/watch?v=JoF0Z7nVSrA
    # If you're not familiar with it, it's worth a watch.
    ...
```

EAGLE

Although the small draft model looks good in theory, in practice it is still lacking because these are fundamentally two different models that learn different things. An interesting technique is EAGLE, which has several iterations: EAGLE1, EAGLE2, EAGLE3, and most recently EAGLE3.1.

The key idea is that the target model is already doing most of the heavy lifting and has all the information needed to predict the next tokens. Some of that information lives in the intermediate hidden states.
So instead of relying only on the token embedding, EAGLE uses the embedding plus hidden states from the target model as input to a lightweight draft network. That draft network then predicts the next token.

In vLLM’s EAGLE-3 setup, the target model produces hidden states at selected layers. Those states are concatenated and projected through a fully connected layer, then passed through lightweight decoder layers and an LM head to produce draft logits. The draft model is still autoregressive, but it relies on the target model’s hidden states.

Essentially, we add a new decode layer that takes as input all the tokens, including the last sampled one, together with their hidden states, then uses that information to predict the next draft token. For the first draft token, we do not have a target-model hidden state for the new token yet, so the draft model has to rely on the hidden states it already has. That is one reason why the hidden state design matters.

The main difference between earlier EAGLE versions and EAGLE3 is that the earlier versions focused on the last hidden layer, while EAGLE3 uses multiple layers, usually spanning early, middle, and late stages, to capture a broader view of the model’s reasoning.

The diagram above from the EAGLE3 paper illustrates the idea. We had “How can” and we just predicted “I” using the target model. For each predicted token i, the hidden state that led to it comes from i-1. So the hidden state of “How” led to “can”, and the hidden state of “can” led to “I”. In the draft model, we combine the hidden state of i-1 with the token embedding of i.

The hidden states from low, middle, and high layers are used. For each token, those hidden state vectors are concatenated and combined using a learned fully connected block that is trained to pick and combine the relevant information across the different stages.
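The shifted pairing and the fusion step described above can be sketched with toy vectors in plain Python. This is only an illustration of the data flow, under several assumptions: the hidden size is 4, all vectors and the fully connected weights are random stand-ins for learned values, the hidden state and embedding are combined by simple concatenation, and the helper names (`fuse_hidden_states`, `build_draft_inputs`) are invented for this sketch rather than taken from vLLM or the EAGLE code.

```python
import random

HIDDEN = 4  # toy hidden size; real models use thousands


def rand_vec(rng, n):
    """Random vector standing in for an embedding or hidden state."""
    return [rng.uniform(-1.0, 1.0) for _ in range(n)]


def matvec(weight, vec):
    """Multiply a (rows x cols) weight matrix by a vector of length cols."""
    return [sum(w * x for w, x in zip(row, vec)) for row in weight]


def fuse_hidden_states(low, mid, high, fc_weight):
    """Concatenate three layer outputs for one token and project to HIDDEN."""
    return matvec(fc_weight, low + mid + high)  # list + list is concatenation


def build_draft_inputs(token_embeds, fused_states):
    """Pair the fused hidden state of token i-1 with the embedding of token i."""
    return [fused_states[i - 1] + token_embeds[i] for i in range(1, len(token_embeds))]


rng = random.Random(0)
tokens = ["How", "can", "I"]  # "I" was just sampled by the target model
token_embeds = [rand_vec(rng, HIDDEN) for _ in tokens]

# The target has hidden states only for tokens it has processed: "How" and "can".
aux = [tuple(rand_vec(rng, HIDDEN) for _ in range(3)) for _ in tokens[:-1]]

# Stand-in for the learned fully connected block: (3 * HIDDEN) -> HIDDEN.
fc_weight = [rand_vec(rng, 3 * HIDDEN) for _ in range(HIDDEN)]

fused = [fuse_hidden_states(low, mid, high, fc_weight) for low, mid, high in aux]
draft_inputs = build_draft_inputs(token_embeds, fused)

for (prev, cur), vec in zip(zip(tokens, tokens[1:]), draft_inputs):
    print(f"hidden({prev!r}) + embed({cur!r}) -> input of length {len(vec)}")
```

With three tokens there are two draft inputs: the fused state of “How” paired with the embedding of “can”, and the fused state of “can” paired with the embedding of “I”. The second one is what the draft network would use to predict the first draft token.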
```python
# vllm/v1/worker/gpu/spec_decode/eagle/speculator.py:405-567
@torch.inference_mode()
def propose(
    self,
    input_batch: InputBatch,
    attn_metadata: dict[str, Any],
    slot_mappings: dict[str, torch.Tensor],
    last_hidden_states: torch.Tensor,  # [num_tokens, H], target's final layer
    aux_hidden_states: list[torch.Tensor] | None,  # EAGLE-3: 3 × [num_tokens, H]
    num_sampled: torch.Tensor,  # [num_reqs], accepted count from prev iter
    num_rejected: torch.Tensor,  # [num_reqs], rejected count from prev iter
    last_sampled: torch.Tensor,  # [max_num_reqs], last accepted token/request
    next_prefill_tokens: torch.Tensor,  # [max_num_reqs], for chunked prefills
    temperature: torch.Tensor,
    seeds: torch.Tensor,
    ...
) -> torch.Tensor:
    num_tokens = input_batch.num_tokens_after_padding
    num_reqs = input_batch.num_reqs
    max_query_len = input_batch.num_scheduled_tokens.max()

    # STEP 1: FC FUSION (EAGLE-3 only)
    if aux_hidden_states:
        assert self.method == "eagle3"
        hidden_states = self.model.combine_hidden_states(
            torch.cat(aux_hidden_states, dim=-1)
        )
    else:
        # EAGLE-1/2: use final hidden states directly (no fusion)
        hidden_states = last_hidden_states

    # STEP 2: PREPARE EAGLE INPUTS (Triton kernel)
    prepare_eagle_inputs(
        self.input_buffers,
        input_batch,
        self.last_token_indices,
        num_sampled,
        num_rejected,
        last_sampled,
        next_prefill_tokens,
        self.max_num_reqs,
    )

    # STEP 3: PREFILL, GENERATE DRAFT TOKEN 0
    self.prefill(
        num_reqs,
        prefill_batch_desc.num_tokens,
        attn_metadata,
        slot_mappings,
        num_tokens_across_dp=num_tokens_across_dp,
        cudagraph_runtime_mode=prefill_batch_desc.cg_mode,
        mm_inputs=mm_inputs,
    )

    # STEP 4: PREPARE DECODE, TRANSITION TO AUTOREGRESSIVE MODE
    prepare_eagle_decode(
        self.draft_tokens[:num_reqs, 0],
        input_batch.seq_lens,
        num_rejected,
        self.input_buffers,
        self.max_model_len,
        self.max_num_reqs,
    )

    # STEP 5: DECODE LOOP, GENERATE DRAFT TOKENS 1..K-1
    self.generate_draft(
        num_reqs,
        decode_batch_desc.num_tokens,
        attn_metadata_updated,
        slot_mappings_updated,
        num_tokens_across_dp=num_tokens_across_dp,
        cudagraph_runtime_mode=decode_batch_desc.cg_mode,
    )
    return self.draft_tokens[:num_reqs]  # [num_reqs, K]


# generating draft tokens is still auto-regressive, hence the for loop
def generate_draft(self, num_reqs, num_tokens_padded, attn_metadata, slot_mappings, ...):
    pos = self.input_buffers.positions[:num_reqs]
    query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
    idx_mapping = self.idx_mapping[:num_reqs]

    # ITERATE THROUGH DRAFT POSITIONS 1, 2, ..., K-1
    for step in range(1, self.num_speculative_steps):
        # EAGLE forward: 1 token per request (decode mode).
        # Uses hidden states from the previous step + embed(prev draft token) as input.
        last_hidden_states, hidden_states = self.run_model(
            num_tokens_padded, attn_metadata, slot_mappings, ...
        )
        last_hidden_states = last_hidden_states[:num_reqs]
        hidden_states = hidden_states[:num_reqs]

        # We have the final output of the EAGLE model.
        # We compute logits, then sample the draft tokens.
        logits = self.model.compute_logits(last_hidden_states)
        draft_tokens = self._sample_draft(logits, idx_mapping, pos, step=step)
        self.draft_tokens[:num_reqs, step] = draft_tokens

        # UPDATE STATE FOR NEXT STEP (unless this is the final step)
        if step < self.num_speculative_steps - 1:
            ...
            update_eagle_inputs(
                draft_tokens,
                hidden_states,
                self.input_buffers,
                self.hidden_states,
                self.max_model_len,
                ...
            )
```

There is one subtlety worth calling out because it is quite interesting. When generating draft tokens beyond the first one (tokens 2 through K), we use hidden states from the draft model itself. In the diagram above (steps 2 and 3 on the right), this corresponds to using a_i and a_do from the draft model rather than g_i and g_do from the target model, which contain the true hidden states.
When a draft token is verified and accepted, the corresponding true hidden states from the target model are then passed back to the draft model. At that point, the draft “prefill” step (step 3 in the propose method) recomputes and repopulates the draft KV cache using this corrected information, because the “prefill” uses the same slot/attention metadata as the target’s.

EAGLE models are trained separately and are their own models. We can find EAGLE3 models for a variety of open models, for example here: https://huggingface.co/collections/RedHatAI/speculator-models. vLLM also has a project for training draft models such as EAGLE, called speculators (https://github.com/vllm-project/speculators), which integrates seamlessly with vLLM.

MTP

EAGLE is a draft model that adds an extra decode path so we can efficiently predict extra draft tokens. Could we merge a similar extra layer into the target model itself and make it part of the model? That is what MTP is.

Some models, such as DeepSeek-family models, include an extra multi-token prediction layer near the end of the network. When a new token is sampled, its embedding plus the hidden state of the last layer are passed to the MTP layer to predict the token that comes right after it. The same LM head and embedding are reused.

This is very similar to EAGLE, except that the model was trained with MTP from the start; it’s even part of the loss function. We can have more than one MTP layer to generate more draft tokens, or we can reuse the MTP layer to predict extra draft tokens, although accuracy will probably drop.
```
Given context: "The quick brown fox jumps over the"

Target model forward pass:
  h = model("The quick brown fox jumps over the")
  token_1 = sample(lm_head(h)) = "lazy"
  Store: h (hidden state at position "the")

# we predicted "lazy" using the hidden state at position "the"
# we use both in MTP layer 0
MTP Layer 0:
  Input: embed("lazy") ⊕ h
  Output: h_mtp0
  token_2 = argmax(lm_head(h_mtp0)) = "dog"

# Same logic as above
MTP Layer 1 (or Layer 0 reused):
  Input: embed("dog") ⊕ h_mtp0
  Output: h_mtp1
  token_3 = argmax(lm_head(h_mtp1)) = "and"

Draft: ["lazy", "dog", "and"]
```

If we venture to HuggingFace and look at the DeepSeek-V4 weights (https://huggingface.co/deepseek-ai/DeepSeek-V4-Pro?show_file_info=model.safetensors.index.json), we can observe the single MTP.0 layer sitting all by itself after the other 61 regular layers.

Let’s look at the code. There’s nothing out of the ordinary. We take the input embeddings and the hidden states that led to them, normalize both so they have similar magnitudes, then concatenate and project them so they can be fed into a regular decode layer. The generated draft token is then verified by being run through all the model layers.

```python
class DeepSeekMultiTokenPredictorLayer(nn.Module):
    def __init__(self, vllm_config, prefix):
        super().__init__()
        # MTP reuses the model's own embedding + hidden-state path.
        config = vllm_config.speculative_config.draft_model_config.hf_config
        self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False)
        self.shared_head = SharedHead(config=config, prefix=prefix, quant_config=...)
        self.mtp_block = DeepseekV2DecoderLayer(
            vllm_config, prefix, config=config, topk_indices_buffer=...
        )

    def forward(self, input_ids, positions, previous_hidden_states,
                inputs_embeds=None, spec_step_index=0):
        assert inputs_embeds is not None
        # Position 0 is masked out because MTP only needs the shifted context.
        inputs_embeds = torch.where(positions.unsqueeze(-1) == 0, 0, inputs_embeds)
        inputs_embeds = self.enorm(inputs_embeds)
        previous_hidden_states = self.hnorm(previous_hidden_states)
        # Fuse the current embedding with the previous hidden state.
        hidden_states = self.eh_proj(
            torch.cat([inputs_embeds, previous_hidden_states], dim=-1)
        )
        # One extra decoder block turns that fused state into the hidden state
        # that compute_logits later maps to draft logits.
        hidden_states, residual = self.mtp_block(
            positions=positions, hidden_states=hidden_states, residual=None
        )
        return residual + hidden_states


class DeepSeekMultiTokenPredictor(nn.Module):
    def __init__(self, vllm_config, prefix=""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.mtp_start_layer_idx = config.num_hidden_layers
        self.num_mtp_layers = config.num_nextn_predict_layers
        self.layers = nn.ModuleDict({...})
        self.embed_tokens = VocabParallelEmbedding(...)
        self.logits_processor = LogitsProcessor(config.vocab_size)

    def forward(self, input_ids, positions, previous_hidden_states,
                inputs_embeds=None, spec_step_idx=0):
        # Cycle through the layers if the number of draft tokens
        # is larger than the number of MTP layers.
        current_step_idx = spec_step_idx % self.num_mtp_layers
        return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
            input_ids, positions, previous_hidden_states, inputs_embeds, current_step_idx
        )

    def compute_logits(self, hidden_states, spec_step_idx=0):
        mtp_layer = self.layers[
            str(self.mtp_start_layer_idx + spec_step_idx % self.num_mtp_layers)
        ]
        # notice the "shared head"
        return self.logits_processor(
            mtp_layer.shared_head.head, mtp_layer.shared_head(hidden_states)
        )
```

CC BY 4.0 (https://creativecommons.org/licenses/by/4.0/) by the author.