cd /news/large-language-models/exploring-speculative-decoding-from-… · home topics large-language-models article
[ARTICLE · art-41936] src=cefboud.com ↗ pub= topic=large-language-models verified=true sentiment=· neutral

Exploring Speculative Decoding: From Concept to Implementation

Speculative decoding optimizes LLM inference by using a cheap draft model to predict multiple tokens, which are then verified in a single forward pass of the target model, reducing memory-bandwidth bottlenecks and improving GPU utilization. The technique leverages the KV cache and can yield significant efficiency gains when draft tokens are accurate.

read16 min views10 publishedMay 31, 2026
Exploring Speculative Decoding: From Concept to Implementation
Image: Cefboud (auto-discovered)

In this post, we explore speculative decoding through a concrete vLLM-focused implementation, covering draft models, EAGLE, MTP, and the tradeoffs involved.

Intro #

In this post, I’ll discuss speculative decoding, a technique used to optimize LLM inference. It’s one those things that when you first learn about it, it somehow just clicks in an “oh yeah! that makes sense” way.

But first, I want to motivate why LLM inference optimzation matters.

LLMs generate responses to user queries. We train large models once, often at 8- or 9-figure costs, but serving them is what happens millions of times. Running a model requires some serious hardware, and saying GPUs are expensive and are in limited supply is almost an understatement. A small efficiency gain means very large savings over time.

Refresher on LLM Inference Basics #

Modern GPUs are impressive beasts. But they have their quirks. A GPU can run hundreds of trillions of operations per second, yet it can only move a few trillion bytes from GPU memory to the compute units.

LLM inference is autoregressive. If we have some input tokens [t_1 .. t_n]

, the model gives us a logits vector of size vocab_size

for each token in the sequence. We use the logits of the last tokent_n

to sample the the next token t_{n+1}

.

That means unless we run hundreds of operations per byte — which we don’t in LLMs — we’re essentially memory-bandwidth bound.

When we batch, we perform more ops per set of weights (X @ W): the larger the batch, the more we reuse the weights (W) we brought from memory, and the better we utilize the GPU.

From KV Cache to Speculative Decoding #

Each token goes through multiple layers, and each layer has a few standard blocks: normalization, MLP, and most notably the transformer’s attention block. In every block before attention, a token does not care about other tokens. In attention, token t_i

needs to know about tokens t_0 .. t_{i-1}

. Specifically, it needs access to the keys and values for those tokens at that particular layer.

One key optimization that LLM inference engines bring to the table is that, instead of recalculating those K and V tensors for every new token, we store them and only run the calculation for the new tokens. That is the famous KV cache. As new tokens go through the model, their key and value vectors are added to the cache.

The sampled token t_{n+1}

is added to the input, and its logits are used to sample t_{n+2}

and so on. During the first model run, we calculate n

logits vectors in the output even though we only need the last one for generation. And the difference between calculating 1 or a few candidate tokens in the same forward pass is often much smaller than running several decode steps one after the other.

For each run or forward pass that produces a new token, we need to load all the models weights and we need to reload them for the next one and so on. And since the memory bandwidth is the limiting factor, we’re waiting for weights to be loaded most of the time.

That is where speculative decoding comes in. If we can guess a few likely tokens ahead of time and feed them to the model, we can verify them in one pass. If they are correct, we get those tokens almost for free. f Regular:

[t_1 .. t_n] => t_{n+1}

Spec dec:

[t_1 .. t_n] => t_{n+1}, t_{n+2}, t_{n+3}, ...

Of course, this only works if the guessed tokens, which we call draft tokens, are usually correct and the guessing process is much cheaper than a full forward pass on the target model. If not, we might as well just use the large original model, which we refer to as the target.

There are many techniques to generate these draft tokens: n-gram, EAGLE, MTP, and others.

But the idea is the same. One forward pass is normally used to sample one token. If we can run a cheaper draft process and predict a few extra tokens, we can reduce the number of expensive target-model steps. If we do things right, we can also preserve the target model’s distribution exactly, as if there were no draft model at all.

Pseudo code #

1
2
3
4
5
6
7
8
9
10
python
def propose(tokens, num_speculative_tokens):
    draft_tokens = []

    for _ in range(num_speculative_tokens):
        logits = generate_draft_token(tokens)  # this must be very fast compared to the target model
        token = sample(logits)
        draft_tokens.append(token)
        tokens.append(token)

    return tokens, draft_tokens

Once we have the draft tokens, we verify them:

1
2
3
4
5
6
7
8
9
10
11
12
python
def sample_verify(tokens, num_drafts):
    logits = target_model(tokens)  # run target model forward, this is the true distribution

    for i, logit in enumerate(logits[-num_drafts:]):
        target_token = sample(logit)
        draft_token = tokens[-num_drafts + i]

        if target_token == draft_token:
            accept(draft_token)
        else:
            reject(draft_token)
            break

Guaranteeing The Original Distribution #

The important part is that speculative decoding doesn’t sacrifice correctness for speed. The final output follows the same distribution as the target model.

This is mathematically provable if we follow this algorithm:

1
2
3
4
5
6
7
8
1. Sample tokens using the draft model via probability distribution q(x)
2. Target model computes true distribution p(x) 
3. Accept the sampled token with probability `min(1, p(x)/q(x))`. Two cases:
   a. If p(x) >= q(x), we accept with probability 1. the target likes the token at least as much as the draft, so we keep it.
   b. If p(x) < q(x), we accept with probability p(x)/q(x), which is less than 1.
4. If rejected, sample a correction token from max(0, p(x)-q(x))

Step 4 is key: it’s where we account for the discrepancy between the draft and target. By resampling from p-q, we cover the tokens the draft was overlooking.

Step 3b is also important. If the draft thinks token 2 has probability 0.4 but the target thinks it’s 0.2, then p/q = 0.2/0.4 = 0.5, so we accept it half the time. This makes sense: the draft is overconfident, and we correct for that via the ratio p/q.

If p > q, the target likes the token more than the draft does, so we just accept it.

Speculators #

There are many ways to get these draft tokens. The criteria we are optimizing for are being as close as possible to the target model and being faster.

Smaller Draft Model

The most intuitive one is probably using a smaller model. Imagine you have a 470B model and you rely on a 7B model from the same family.

For complex tokens, the small model will not perform well, but for repetitive and easy stuff, it should be decent. For example:

Q: How can we solve special relativity? A: To solve special ...

The small model should easily guess that we will repeat part of the question and provide appropriate draft tokens. The rest of the answer is harder, but if we get the easy parts right, we still come out ahead.

We can come up with different techniques to generate draft tokens. It is just a function that takes the existing sequence and tries to predict K

draft tokens. In vLLM, this is implemented as a pluggable speculator.

The snippets below are trimmed from vLLM and keep only the skeleton and the key data flow.

The shared proposer base looks like this:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
python
class SpecDecodeBaseProposer:
    def __init__(self, vllm_config, device, pass_hidden_states_to_model, runner=None):

    @torch.inference_mode()
    def propose(
        self,
        target_token_ids,
        target_positions,
        target_hidden_states,
        next_token_ids,
        token_indices_to_sample,
        common_attn_metadata,
        sampling_metadata,
        mm_embed_inputs=None,
        num_rejected_tokens_gpu=None,
        slot_mappings=None,
    ):
        ...

N-gram

Things tend to repeat themselves. What goes around comes around. This is a general principle in computing that underlies caching: if we see some data, we are likely to use it again. That is locality. Following this principle, we can look at the last N

tokens in our sequence and search for a previous occurrence. Our draft tokens are then the K

tokens that came after that previous occurrence.

From the example above, the suffix of the sequence is “solve special”, its previous occurrence is a few words earlier, and what comes after it is “relativity?”. We guess that as the likely draft tokens, and we end up being right.

N-gram is a very simple and cheap technique to run. We can even run it on the CPU. Its simplicity also means that it is often wrong in practice, but it can be quite useful for text with repetitive patterns, and code is the perfect example.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
python
class NgramProposer:
    def __init__(self, vllm_config):
        self.min_n = vllm_config.speculative_config.prompt_lookup_min
        self.max_n = vllm_config.speculative_config.prompt_lookup_max
        self.k = vllm_config.speculative_config.num_speculative_tokens
        self.max_model_len = vllm_config.model_config.max_model_len

    def propose(self, sampled_token_ids, num_tokens_no_spec, token_ids_cpu, slot_mappings=None):
        valid_requests = [i for i, sampled_ids in enumerate(sampled_token_ids)
                          if sampled_ids and num_tokens_no_spec[i] < self.max_model_len]
        return self.batch_propose(len(sampled_token_ids), valid_requests, num_tokens_no_spec, token_ids_cpu)

    def batch_propose(...):
        for i in prange(len(valid_ngram_requests)):
            idx = valid_ngram_requests[i]
            num_tokens = num_tokens_no_spec[idx]
            context_token_ids = token_ids_cpu[idx, :num_tokens]
            drafter_output = _find_longest_matched_ngram_and_propose_tokens(
                origin_tokens=context_token_ids,
                min_ngram=min_n,
                max_ngram=max_n,
                max_model_len=max_model_len,
                k=k,
            )

            valid_ngram_num_drafts[idx] = drafter_output.shape[0]
            if len(drafter_output):
                valid_ngram_draft[idx, : drafter_output.shape[0]] = drafter_output

def _find_longest_matched_ngram_and_propose_tokens(origin_tokens, min_ngram, max_ngram, max_model_len, k):

EAGLE

Although the small draft model looks good in theory, in practice it is still lacking because these are fundamentally two different models that learn different things.

An interesting technique is EAGLE, which has several iterations: EAGLE1, EAGLE2, EAGLE3, and most recently EAGLE3.1. The key idea is that the target model is already doing most of the heavy lifting and has all the information needed to predict the next tokens. Some of that information lives in the intermediate hidden states.

So instead of relying only on the token embedding, EAGLE uses the embedding plus hidden states from the target model as input to a lightweight draft network. That draft network then predicts the next token.

In vLLM’s EAGLE-3 setup, the target model produces hidden states at selected layers. Those states are concatenated and projected through a fully connected layer, then passed through lightweight decoder layers and an LM head to produce draft logits. The draft model is still autoregressive, but it relies on the target model’s hidden states.

Essentially, we add a new decode layer that takes as input all the tokens including the last sampled one and their hidden layers, then uses that information to predict the next draft token.

For the first draft token, we do not have a target-model hidden state for the new token yet, so the draft model has to rely on the hidden states it already has. That is one reason why the hidden state design matters.

The main difference between earlier EAGLE versions and EAGLE3 is that the earlier versions focused on the last hidden layer, while EAGLE3 uses multiple layers, usually spanning early, middle, and late stages, to capture a broader view of the model’s reasoning.

The diagram above from the EAGLE3 paper illustrates the idea. We had “How can” and we just predicted “I” using the target model. For each predicted token i

, the hidden state that led to it comes from i-1

. So “How”’s hidden state led to “can”, and “can”’s hidden state led to “I”. In the draft model, we combine the hidden state of i-1

with the token embedding of i

.

The hidden states from low, middle, and high layers are used. For each token, those hidden state vectors are concatenated and combined using a learned fully connected block that is trained to pick and combine the relevant information across the different stages.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99

@torch.inference_mode()
def propose(
    self,
    input_batch: InputBatch,
    attn_metadata: dict[str, Any],
    slot_mappings: dict[str, torch.Tensor],
    last_hidden_states: torch.Tensor,     # [num_tokens, H] — target's final layer
    aux_hidden_states: list[torch.Tensor] | None,  # EAGLE-3: 3 × [num_tokens, H]
    num_sampled: torch.Tensor,            # [num_reqs] — accepted count from prev iter
    num_rejected: torch.Tensor,           # [num_reqs] — rejected count from prev iter
    last_sampled: torch.Tensor,           # [max_num_reqs] — last accepted token/request
    next_prefill_tokens: torch.Tensor,    # [max_num_reqs] — for chunked prefills
    temperature: torch.Tensor,
    seeds: torch.Tensor,
    ...
) -> torch.Tensor:
    num_tokens = input_batch.num_tokens_after_padding
    num_reqs = input_batch.num_reqs
    max_query_len = input_batch.num_scheduled_tokens.max()

    
    
    if aux_hidden_states:
        assert self.method == "eagle3"
        hidden_states = self.model.combine_hidden_states(
            torch.cat(aux_hidden_states, dim=-1)
        )
    else:
        hidden_states = last_hidden_states
    
    prepare_eagle_inputs(
        self.input_buffers, input_batch, self.last_token_indices,
        num_sampled, num_rejected, last_sampled, next_prefill_tokens,
        self.max_num_reqs,
    )
    
    self.prefill(
        num_reqs, prefill_batch_desc.num_tokens,
        attn_metadata, slot_mappings,
        num_tokens_across_dp=num_tokens_across_dp,
        cudagraph_runtime_mode=prefill_batch_desc.cg_mode,
        mm_inputs=mm_inputs,
    )

    
    prepare_eagle_decode(
        self.draft_tokens[:num_reqs, 0], input_batch.seq_lens,
        num_rejected, self.input_buffers, self.max_model_len, self.max_num_reqs,
    )

    
    self.generate_draft(
        num_reqs, decode_batch_desc.num_tokens,
        attn_metadata_updated, slot_mappings_updated,
        num_tokens_across_dp=num_tokens_across_dp,
        cudagraph_runtime_mode=decode_batch_desc.cg_mode,
    )

    return self.draft_tokens[:num_reqs]  # [num_reqs, K]

## generating draft tokens is still auto-regressive hence the for loop
def generate_draft(self, num_reqs, num_tokens_padded, attn_metadata, slot_mappings, ...):
    pos = self.input_buffers.positions[:num_reqs]
    query_start_loc = self.input_buffers.query_start_loc[:num_reqs + 1]
    idx_mapping = self.idx_mapping[:num_reqs]

    for step in range(1, self.num_speculative_steps):
        last_hidden_states, hidden_states = self.run_model(
            num_tokens_padded, attn_metadata, slot_mappings, ...
        )
        last_hidden_states = last_hidden_states[:num_reqs]
        hidden_states = hidden_states[:num_reqs]

        logits = self.model.compute_logits(last_hidden_states)
        draft_tokens = self._sample_draft(logits, idx_mapping, pos, step=step)
        self.draft_tokens[:num_reqs, step] = draft_tokens

        if step < self.num_speculative_steps - 1:
            update_eagle_inputs(
                draft_tokens, hidden_states,
                self.input_buffers, self.hidden_states, self.max_model_len,
            )

There is one subtlety worth calling out because it is quite interesting. When generating draft tokens beyond the first one (tokens 2 through k), we use hidden states from the draft model itself. In the diagram above (steps 2 and 3 on the right), this corresponds to using a_i

and a_do

from the draft model rather than g_i

and g_do

from the target model, which contain the true hidden states.

When a draft token is verified and accepted, the corresponding true hidden states from the target model are then passed back to the draft model. At that point, the draft “prefill” step (step 3 in the propose method) recomputes and repopulates the draft KV cache using this corrected information because the “prefill” uses the same slot/attention metadata as the target’s.

EAGLE models are trained separately and are their own models. We can find EAGLE3 models for a variety of open models, for example here. vLLM also has a project to train draft models, such as EAGLE, called speculators, which integrates seamlessly with vLLM.

MTP

EAGLE is a draft model that adds an extra decode path so we can efficiently predict extra draft tokens.

Could we merge a similar extra layer into the target model itself and make it part of the model? That is what MTP is. Some models, such as DeepSeek-family models, include an extra multi-token prediction layer near the end of the network. When a new token is sampled, its embedding plus the hidden state of the last layer are passed to the MTP layer to predict the token that comes right after it. The same LM head and embedding are reused.

This is very similar to EAGLE, except that the model was trained with MTP from the start, it’s even part of the loss function. We can have more than one MTP layer to generate more draft tokens, or we can reuse the MTP layer to predict extra draft tokens, although accuracy will probably drop.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
Given context: "The quick brown fox jumps over the"

Target model forward pass:
  h = model("The quick brown fox jumps over the")
  token_1 = sample(lm_head(h)) = "lazy"
  Store: h (hidden state at position "the")

MTP Layer 0:
  Input: embed("lazy") ⊕ h   
  Output: h_mtp0
  token_2 = argmax(lm_head(h_mtp0)) = "dog"

MTP Layer 1 (or Layer 0 reused):
  Input: embed("dog") ⊕ h_mtp0
  Output: h_mtp1  
  token_3 = argmax(lm_head(h_mtp1)) = "and"

Draft: ["lazy", "dog", "and"]

If we venture to HuggingFace and look at DeepSeek-v4 weights, we can observe the single MTP.0 layer sitting all by itself after the other 61 regular layers:

Let’s look at the code. There’s nothing out of the ordinary. Take the input emebeddings and the hidden states that led to them, normalize so we can bring them to similar magnitudes then we concatenate and project them so they can be fed into a regular decode layer.

The generated draft token is then verified by being run through all the model layers.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
python
class DeepSeekMultiTokenPredictorLayer(nn.Module):
    def __init__(self, vllm_config, prefix):
        config = vllm_config.speculative_config.draft_model_config.hf_config
        self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False)
        self.shared_head = SharedHead(config=config, prefix=prefix, quant_config=...)
        self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix, config=config, topk_indices_buffer=...)

    def forward(self, input_ids, positions, previous_hidden_states, inputs_embeds=None, spec_step_index=0):
        assert inputs_embeds is not None
        inputs_embeds = torch.where(positions.unsqueeze(-1) == 0, 0, inputs_embeds)
        inputs_embeds = self.enorm(inputs_embeds)
        previous_hidden_states = self.hnorm(previous_hidden_states)
        hidden_states = self.eh_proj(torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
        hidden_states, residual = self.mtp_block(positions=positions, hidden_states=hidden_states, residual=None)
        return residual + hidden_states

class DeepSeekMultiTokenPredictor(nn.Module):
    def __init__(self, vllm_config, prefix=""):
        config = vllm_config.model_config.hf_config
        self.mtp_start_layer_idx = config.num_hidden_layers
        self.num_mtp_layers = config.num_nextn_predict_layers
        self.layers = nn.ModuleDict({...})
        self.embed_tokens = VocabParallelEmbedding(...)
        self.logits_processor = LogitsProcessor(config.vocab_size)

    def forward(self, input_ids, positions, previous_hidden_states, inputs_embeds=None, spec_step_idx=0):
        current_step_idx = spec_step_idx % self.num_mtp_layers # cycle throught the layer if num of draft tokens is larger than num of mtp layers
        return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
            input_ids, positions, previous_hidden_states, inputs_embeds, current_step_idx
        )

    def compute_logits(self, hidden_states, spec_step_idx=0):
        mtp_layer = self.layers[str(self.mtp_start_layer_idx + (spec_step_idx % self.num_mtp_layers))]
        return self.logits_processor(mtp_layer.shared_head.head, mtp_layer.shared_head(hidden_states))

CC BY 4.0by the author.

── more in #large-language-models 4 stories · sorted by recency
── more on @vllm 3 stories trending now
sponsored brought to you by zahid.host 4,200+ EU-deployed projects
reading about agents? ship yours in a single git push.

Run your AI side-project on zahid.host

EU-based hosting, git-push deploys, automatic HTTPS, no cold starts. Free tier with a custom domain — perfect for shipping the agent you just read about.

$git push zahid main
Live at https://your-agent.zahid.host
Get free account → Pricing
from €0/mo · no card required
LIVE [news/exploring-speculativ…] indexed:0 read:16min 2026-05-31 ·