Speculative Decoding: Faster LLM Inference (Python)
Build a speculative decoding simulator in Python. Learn the draft-verify algorithm, measure acceptance rates, and understand when it speeds up LLM inference.
This post has interactive code — click ‘Run’ or press Ctrl+Enter on any code block to execute it directly in your browser.
Simulate the draft-then-verify trick that makes LLM inference 2-3x faster — with pure Python and NumPy.
Your large language model generates one token at a time. Each token costs a full forward pass through billions of parameters. That’s the bottleneck.
What if a tiny model could guess the next five tokens faster? And what if the big model could check all five guesses in a single pass?
That’s speculative decoding. It’s the same idea as having an intern draft an email and your manager approve it in one glance — instead of the manager typing every word. The output is identical to what the big model would produce alone. But you get there faster.
In this article, you’ll build a speculative decoding simulator from scratch. No PyTorch, no Hugging Face. Just NumPy probability distributions that show how the draft-verify loop works. You’ll measure acceptance rates and learn when speculative decoding helps versus when it backfires.
Prerequisites
- Python version: 3.9+
- Required libraries: NumPy (1.24+)
- Install: pip install numpy
- Time to complete: 25-30 minutes
What Is Speculative Decoding?
Speculative decoding pairs two models: a small draft model and a large target model. The draft model is fast but less accurate. The target model is slow but authoritative.
Here’s the core loop in three steps:
- The draft model proposes \(\gamma\) tokens (typically 3-7).
- The target model verifies all \(\gamma\) tokens in one forward pass.
- Accepted tokens go to the output. The first rejected token gets resampled from the target model's distribution.
The key guarantee: the final output distribution is mathematically identical to what the target model would produce alone. You don’t sacrifice quality. You trade a small model’s cheap forward passes for fewer expensive ones.
Why does this work? Because the target model can verify \(\gamma\) tokens in roughly the same time it takes to generate one. Matrix multiplication on GPUs is parallel — checking five tokens isn’t five times slower than checking one.
Key Insight: Speculative decoding doesn’t change what the model outputs — it changes how fast you get there. The verification step guarantees the same distribution as standard autoregressive decoding.
Let’s make this concrete. We’ll simulate both models as probability distributions over a small vocabulary and watch the algorithm run step by step.
import numpy as np
np.random.seed(42)
# Small vocabulary for demonstration
VOCAB = ['the', 'cat', 'sat', 'on', 'mat', 'dog', 'ran', 'big', 'a', 'red']
VOCAB_SIZE = len(VOCAB)
def token_name(token_id):
"""Convert token ID to readable word."""
return VOCAB[token_id]
def tokens_to_text(token_ids):
"""Convert list of token IDs to readable text."""
return ' '.join(VOCAB[t] for t in token_ids)
print(f"Vocabulary size: {VOCAB_SIZE}")
print(f"Vocabulary: {VOCAB}")
Output:
Vocabulary size: 10
Vocabulary: ['the', 'cat', 'sat', 'on', 'mat', 'dog', 'ran', 'big', 'a', 'red']
We’re using 10 words to keep things visual. In production, vocabulary sizes are 32K-128K tokens. The algorithm works the same way regardless.
How the Draft Model Proposes Tokens
The draft model is small and fast. Think of it as a lightweight n-gram model or a tiny neural network. It produces a probability distribution over the vocabulary for each position.
In our simulator, we’ll represent each model as a function that returns a probability distribution given a context. The draft model’s distribution is “close but not perfect” compared to the target model.
We’ll build two mock models. The draft_model produces somewhat noisy probabilities. The target_model produces the “true” distribution. The closer these two distributions are, the more tokens get accepted — and the bigger the speedup.
def target_model(context):
"""Simulate the large target model's next-token distribution.
Returns a probability distribution over VOCAB_SIZE tokens.
Context determines the distribution (simplified).
"""
# Use context to seed a deterministic distribution
ctx_hash = sum(context[-3:]) if context else 0
rng = np.random.RandomState(ctx_hash + 100)
# Create a peaked distribution (confident model)
logits = rng.randn(VOCAB_SIZE)
logits[ctx_hash % VOCAB_SIZE] += 2.0 # Boost one token
# Softmax
exp_logits = np.exp(logits - np.max(logits))
return exp_logits / exp_logits.sum()
def draft_model(context):
"""Simulate a smaller draft model.
Similar to target but noisier — less confident,
slightly different preferences.
"""
ctx_hash = sum(context[-3:]) if context else 0
rng = np.random.RandomState(ctx_hash + 200)
# Flatter distribution (less confident)
logits = rng.randn(VOCAB_SIZE) * 0.8
logits[ctx_hash % VOCAB_SIZE] += 1.5 # Weaker boost
# Softmax
exp_logits = np.exp(logits - np.max(logits))
return exp_logits / exp_logits.sum()
# Compare the two distributions for the same context
context = [0, 1] # "the cat"
p_target = target_model(context)
q_draft = draft_model(context)
print("Target model distribution:")
for i, (name, prob) in enumerate(zip(VOCAB, p_target)):
bar = '#' * int(prob * 50)
print(f" {name:>4s}: {prob:.3f} {bar}")
print("\nDraft model distribution:")
for i, (name, prob) in enumerate(zip(VOCAB, q_draft)):
bar = '#' * int(prob * 50)
print(f" {name:>4s}: {prob:.3f} {bar}")
Output:
Target model distribution:
the: 0.374 ##################
cat: 0.345 #################
sat: 0.062 ###
on: 0.041 ##
mat: 0.048 ##
dog: 0.018
ran: 0.011
big: 0.046 ##
a: 0.003
red: 0.052 ##
Draft model distribution:
the: 0.101 #####
cat: 0.237 ###########
sat: 0.208 ##########
on: 0.111 #####
mat: 0.044 ##
dog: 0.111 #####
ran: 0.045 ##
big: 0.055 ##
a: 0.044 ##
red: 0.043 ##
Notice the difference? The target model is more confident — it puts 37% probability on “the” and 35% on “cat.” The draft model spreads probability more evenly. That mismatch drives the acceptance rate, which we’ll measure next.
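That mismatch can be quantified. For a given context, the expected acceptance rate has a closed form: it is the overlap \(\sum_x \min(p(x), q(x))\) between the two distributions, since each draft token \(x\) is sampled with probability \(q(x)\) and accepted with probability \(\min(1, p(x)/q(x))\). A quick sketch, plugging in the rounded values printed above:

```python
import numpy as np

# Rounded distributions from the output above (toy models, values approximate)
p = np.array([0.374, 0.345, 0.062, 0.041, 0.048, 0.018, 0.011, 0.046, 0.003, 0.052])
q = np.array([0.101, 0.237, 0.208, 0.111, 0.044, 0.111, 0.045, 0.055, 0.044, 0.043])
p, q = p / p.sum(), q / q.sum()  # renormalize after rounding

# E[accept] = sum_x q(x) * min(1, p(x)/q(x)) = sum_x min(p(x), q(x))
overlap = np.minimum(p, q).sum()
print(f"Expected acceptance rate: {overlap:.2f}")  # ~0.61
```

So for this context, roughly 60% of draft tokens should survive verification.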
The Verification Algorithm — Accept or Reject
This is the heart of speculative decoding. The draft model proposes a token. The target model checks whether it “agrees.” But it’s not a simple yes/no — it’s probabilistic.
For each draft token \(x\) sampled from the draft distribution \(q(x)\), the target model accepts it with probability:
\[\alpha(x) = \min\left(1, \frac{p(x)}{q(x)}\right)\]
Where:
- \(p(x)\) = target model's probability for token \(x\)
- \(q(x)\) = draft model's probability for token \(x\)
If \(p(x) \geq q(x)\), the draft model underestimated this token — accept it always. If \(p(x) < q(x)\), the draft model overestimated it — accept with probability \(p(x)/q(x)\).
When a token is rejected, we resample from a corrected distribution:
\[p'(x) = \text{norm}\left(\max(0, p(x) - q(x))\right)\]
This correction ensures the final output matches the target distribution exactly. It's the mathematical trick that makes speculative decoding lossless.
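Before coding it, we can verify the losslessness claim numerically. The probability that the accept-or-resample procedure emits token \(x\) is \(q(x)\min(1, p(x)/q(x)) + P(\text{reject})\,p'(x)\), and this should equal \(p(x)\) exactly for any pair of distributions. A minimal sketch with random distributions:

```python
import numpy as np

rng = np.random.default_rng(0)
p = rng.random(6)
p /= p.sum()  # arbitrary target distribution
q = rng.random(6)
q /= q.sum()  # arbitrary draft distribution

accept = np.minimum(1.0, p / q)          # per-token acceptance probability
p_reject = 1.0 - (q * accept).sum()      # total rejection probability
residual = np.maximum(0.0, p - q)
corrected = residual / residual.sum()    # resampling distribution p'

# Distribution actually emitted by accept-or-resample
emitted = q * accept + p_reject * corrected
print(np.allclose(emitted, p))  # True — lossless for any p, q
```

The identity holds because \(q(x)\min(1, p/q) = \min(p, q)\) and the rejection mass exactly equals the residual mass \(\sum_x \max(0, p(x) - q(x))\).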
Let's code the single-token verification step. The function verify_token takes a draft token, both distributions, and returns whether to accept it. If rejected, it returns a resampled token from the corrected distribution.
def verify_token(draft_token, p_target, q_draft):
"""Verify a single draft token against the target distribution.
Returns (accepted: bool, final_token: int, acceptance_prob: float).
"""
p_x = p_target[draft_token]
q_x = q_draft[draft_token]
# Acceptance probability
accept_prob = min(1.0, p_x / q_x)
# Accept or reject
u = np.random.random()
if u < accept_prob:
return True, draft_token, accept_prob
# Rejected — resample from corrected distribution
residual = np.maximum(0, p_target - q_draft)
residual_sum = residual.sum()
if residual_sum > 0:
corrected = residual / residual_sum
else:
corrected = p_target # Fallback to target
resampled = np.random.choice(VOCAB_SIZE, p=corrected)
return False, resampled, accept_prob
That's 20 lines for the core algorithm. The min(1, p/q) ratio is the entire verification mechanism. Let's test it on a single token.
np.random.seed(42)
context = [0, 1] # "the cat"
p_target = target_model(context)
q_draft = draft_model(context)
# Draft model samples a token
draft_token = np.random.choice(VOCAB_SIZE, p=q_draft)
print(f"Draft proposes: '{token_name(draft_token)}'")
print(f" p_target('{token_name(draft_token)}') = {p_target[draft_token]:.3f}")
print(f" q_draft('{token_name(draft_token)}') = {q_draft[draft_token]:.3f}")
accepted, final_token, accept_prob = verify_token(
draft_token, p_target, q_draft
)
print(f" Acceptance probability: {accept_prob:.3f}")
print(f" {'ACCEPTED' if accepted else 'REJECTED'}")
print(f" Final token: '{token_name(final_token)}'")
Output:
Draft proposes: 'sat'
p_target('sat') = 0.062
q_draft('sat') = 0.208
Acceptance probability: 0.297
REJECTED
Final token: 'cat'
The draft guessed "sat." The target model gives it only 6.2% probability while the draft gives it 20.8%. The ratio \(p/q\) is just 0.297 — a low acceptance probability. This time it was rejected, and the algorithm resampled "cat" from the corrected distribution.
Tip: When the draft model underestimates a token's probability (\(q(x) < p(x)\)), acceptance is guaranteed. The ratio \(p(x)/q(x)\) exceeds 1, and \(\min(1, p/q) = 1\). The draft model only gets penalized for tokens it's overconfident about.
Building the Full Draft-Verify Loop
A single token verification is nice, but speculative decoding proposes \(\gamma\) tokens at once. The algorithm processes them left to right. The first rejection stops the chain — all subsequent draft tokens are discarded.
Here's why: each token depends on the previous context. If token 3 is wrong, tokens 4 and 5 were generated from the wrong context. They're invalid regardless of their individual quality.
The function speculative_decode_step runs one full round: draft \(\gamma\) tokens, verify each in sequence, and return the accepted tokens plus one bonus token from the target model.
def speculative_decode_step(context, gamma=5):
"""Run one speculative decoding round.
1. Draft model proposes gamma tokens.
2. Target model verifies each in sequence.
3. Return accepted tokens + 1 bonus token.
Returns (new_tokens, stats_dict).
"""
draft_tokens = []
draft_distributions = []
current_context = list(context)
# Step 1: Draft model proposes gamma tokens
for _ in range(gamma):
q = draft_model(current_context)
draft_distributions.append(q)
token = np.random.choice(VOCAB_SIZE, p=q)
draft_tokens.append(token)
current_context.append(token)
# Step 2: Target model verifies each token
accepted_tokens = []
verify_context = list(context)
for i in range(gamma):
p = target_model(verify_context)
q = draft_distributions[i]
accepted, final_token, accept_prob = verify_token(
draft_tokens[i], p, q
)
if accepted:
accepted_tokens.append(final_token)
verify_context.append(final_token)
else:
# Rejected: add the resampled token
accepted_tokens.append(final_token)
break
else:
# All gamma tokens accepted — bonus token from target
p = target_model(verify_context)
bonus = np.random.choice(VOCAB_SIZE, p=p)
accepted_tokens.append(bonus)
n_accepted = len(accepted_tokens)
n_drafted = gamma
stats = {
'drafted': n_drafted,
'accepted': n_accepted,
'draft_tokens': [token_name(t) for t in draft_tokens],
'final_tokens': [token_name(t) for t in accepted_tokens],
'acceptance_rate': (n_accepted - 1) / n_drafted
}
return accepted_tokens, stats
Let's run a single round and inspect the results.
np.random.seed(42)
context = [0] # Start with "the"
new_tokens, stats = speculative_decode_step(context, gamma=5)
print(f"Starting context: '{tokens_to_text(context)}'")
print(f"Draft proposed {stats['drafted']} tokens: {stats['draft_tokens']}")
print(f"Final tokens ({stats['accepted']} total): {stats['final_tokens']}")
print(f"Acceptance rate: {stats['acceptance_rate']:.1%}")
print(f"Generated text: '{tokens_to_text(context + new_tokens)}'")
Output:
Starting context: 'the'
Draft proposed 5 tokens: ['cat', 'a', 'red', 'big', 'on']
Final tokens (3 total): ['cat', 'a', 'cat']
Acceptance rate: 40.0%
Generated text: 'the cat a cat'
The draft model proposed 5 tokens. Two were accepted, the third was rejected and resampled to "cat." Tokens 4 and 5 were discarded because the chain broke at position 3. We still produced 3 tokens in what would be one target-model forward pass.
Measuring Speedup — When Does It Help?
Here's the question that matters: how much faster is speculative decoding compared to standard autoregressive generation?
The theoretical speedup depends on two factors:
- Acceptance rate (\(\alpha\)): How often the target model agrees with the draft.
- Cost ratio (\(c\)): How cheap the draft model is compared to the target. If the draft model takes 10ms and the target takes 100ms, \(c = 0.1\).
The expected speedup formula is:
\[\text{Speedup} = \frac{1 - \alpha^{\gamma+1}}{1 - \alpha} \cdot \frac{1}{c \cdot \gamma + 1}\]
Where:
- \(\alpha\) = average acceptance rate per token
- \(\gamma\) = number of draft tokens (speculation length)
- \(c\) = cost ratio (draft time / target time)
When \(\alpha\) is high and \(c\) is low, you win big. When \(\alpha\) is low or \(c\) is high, you actually slow down.
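To make the formula concrete, here is a quick evaluation at a hypothetical operating point (\(\alpha = 0.8\), \(\gamma = 5\), \(c = 0.1\) are assumed values, not measurements):

```python
alpha, gamma, c = 0.8, 5, 0.1   # hypothetical operating point

expected_tokens = (1 - alpha**(gamma + 1)) / (1 - alpha)  # ≈ 3.69 tokens/round
cost_per_round = c * gamma + 1                            # 1.5 target-equivalents
speedup = expected_tokens / cost_per_round
print(f"Speedup: {speedup:.2f}x")  # Speedup: 2.46x
```

Each round yields about 3.7 tokens for the cost of 1.5 target-equivalent calls, hence roughly 2.5x.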
Let's build a function that simulates many rounds of both standard and speculative decoding, then compares their total cost.
def standard_decode(context, num_tokens):
"""Standard autoregressive decoding — one token at a time."""
generated = list(context)
target_calls = 0
for _ in range(num_tokens):
p = target_model(generated)
token = np.random.choice(VOCAB_SIZE, p=p)
generated.append(token)
target_calls += 1
return generated[len(context):], target_calls
def speculative_decode(context, num_tokens, gamma=5):
"""Full speculative decoding for num_tokens tokens."""
generated = list(context)
target_calls = 0
draft_calls = 0
total_accepted = 0
total_drafted = 0
rounds = 0
while len(generated) - len(context) < num_tokens:
new_tokens, stats = speculative_decode_step(generated, gamma)
generated.extend(new_tokens)
target_calls += 1 # One target forward pass per round
draft_calls += gamma
total_accepted += stats['accepted']
total_drafted += stats['drafted']
rounds += 1
# Trim to exact length
result = generated[len(context):len(context) + num_tokens]
return result, {
'target_calls': target_calls,
'draft_calls': draft_calls,
'total_accepted': total_accepted,
'total_drafted': total_drafted,
'rounds': rounds,
'avg_acceptance': total_accepted / total_drafted if total_drafted > 0 else 0
}
Now let's compare the two approaches on generating 50 tokens. The cost metric is the number of target model forward passes — that's the expensive operation.
np.random.seed(42)
context = [0, 1, 2] # "the cat sat"
# Standard decoding
std_tokens, std_calls = standard_decode(context, 50)
# Speculative decoding with different gamma values
print("=" * 60)
print(f"Generating 50 tokens from context: '{tokens_to_text(context)}'")
print("=" * 60)
print(f"\nStandard decoding:")
print(f" Target model calls: {std_calls}")
print(f" Cost ratio baseline: 1.00x")
for gamma in [3, 5, 7]:
np.random.seed(42)
spec_tokens, spec_stats = speculative_decode(context, 50, gamma=gamma)
speedup = std_calls / spec_stats['target_calls']
print(f"\nSpeculative (gamma={gamma}):")
print(f" Target model calls: {spec_stats['target_calls']}")
print(f" Draft model calls: {spec_stats['draft_calls']}")
print(f" Rounds: {spec_stats['rounds']}")
print(f" Avg acceptance: {spec_stats['avg_acceptance']:.1%}")
print(f" Target call speedup: {speedup:.2f}x")
Output:
============================================================
Generating 50 tokens from context: 'the cat sat'
============================================================
Standard decoding:
Target model calls: 50
Cost ratio baseline: 1.00x
Speculative (gamma=3):
Target model calls: 24
Draft model calls: 72
Rounds: 24
Avg acceptance: 70.8%
Target call speedup: 2.08x
Speculative (gamma=5):
Target model calls: 23
Draft model calls: 115
Rounds: 23
Avg acceptance: 47.0%
Target call speedup: 2.17x
Speculative (gamma=7):
Target model calls: 17
Draft model calls: 119
Rounds: 17
Avg acceptance: 46.2%
Target call speedup: 2.94x
With gamma=7, we cut target model calls from 50 down to 17. That's a 2.94x reduction. But there's a catch — we also made 119 draft model calls. The real-world speedup depends on how cheap those draft calls are.
Warning: The target call reduction is NOT the same as wall-clock speedup. If your draft model costs 20% of the target model (c=0.2), the gamma=7 run above costs 17 + 0.2 × 119 ≈ 41 target-equivalents for 50 tokens, which is only about a 1.2x wall-clock speedup, not 2.94x. If your draft model costs 50% of the target (c=0.5), you're actually slower than standard decoding.
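To see how this plays out, we can reprice the gamma=7 run above (17 target calls, 119 draft calls, 50 tokens) under a few assumed cost ratios:

```python
# Counts from the gamma=7 run above; the cost ratios are assumptions
std_calls, target_calls, draft_calls = 50, 17, 119

for c in (0.05, 0.1, 0.2, 0.5):
    spec_cost = target_calls + c * draft_calls  # in target-call units
    print(f"c={c:<4}: {std_calls / spec_cost:.2f}x")
```

At c=0.05 the full 2x+ speedup survives; at c=0.5 the draft calls alone cost more than standard decoding saves.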
Acceptance Rate Deep Dive — What Makes It High or Low?
Acceptance rate is everything. It determines whether speculative decoding saves time or wastes it. Let's explore what drives it.
The acceptance rate depends on how close the draft model's distribution \(q(x)\) is to the target model's distribution \(p(x)\). When they agree, almost every draft token passes. When they diverge, rejections pile up.
We'll measure acceptance rates across 1000 verification rounds for different levels of "agreement" between models. We'll control agreement by adjusting how similar the draft model's logits are to the target's.
def measure_acceptance_rate(agreement_level, num_trials=1000):
"""Measure acceptance rate for a given model agreement level.
agreement_level: 0.0 (random) to 1.0 (identical models).
"""
accepted_count = 0
total_count = 0
for trial in range(num_trials):
rng = np.random.RandomState(trial)
# Target distribution
logits_target = rng.randn(VOCAB_SIZE)
logits_target[trial % VOCAB_SIZE] += 2.0
p = np.exp(logits_target - np.max(logits_target))
p = p / p.sum()
# Draft distribution: blend of target and random
logits_noise = rng.randn(VOCAB_SIZE) * 0.5
logits_draft = agreement_level * logits_target + (1 - agreement_level) * logits_noise
q = np.exp(logits_draft - np.max(logits_draft))
q = q / q.sum()
# Sample from draft and check acceptance
draft_token = rng.choice(VOCAB_SIZE, p=q)
accept_prob = min(1.0, p[draft_token] / q[draft_token])
if rng.random() < accept_prob:
accepted_count += 1
total_count += 1
return accepted_count / total_count
# Measure across agreement levels
agreement_levels = [0.0, 0.2, 0.4, 0.6, 0.8, 0.9, 0.95, 1.0]
rates = []
print("Agreement Level -> Acceptance Rate")
print("-" * 40)
for level in agreement_levels:
rate = measure_acceptance_rate(level)
rates.append(rate)
bar = '#' * int(rate * 40)
print(f" {level:.2f} -> {rate:.1%} {bar}")
Output:
Agreement Level -> Acceptance Rate
----------------------------------------
0.00 -> 52.9% #####################
0.20 -> 62.9% #########################
0.40 -> 71.8% ############################
0.60 -> 81.8% ################################
0.80 -> 91.7% ####################################
0.90 -> 96.4% ######################################
0.95 -> 98.4% #######################################
1.00 -> 100.0% ########################################
When models agree perfectly (1.0), every token is accepted. At moderate agreement (0.6), about 82% pass. Even with random draft distributions (0.0), you still get 53% — because sometimes the random guess happens to align.
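These acceptance rates cap the achievable speedup: the expected tokens per round, \((1 - \alpha^{\gamma+1})/(1 - \alpha)\), approaches \(1/(1 - \alpha)\) as \(\gamma\) grows. A quick check at \(\alpha = 0.8\):

```python
alpha = 0.8
for gamma in (2, 5, 20, 100):
    e = (1 - alpha**(gamma + 1)) / (1 - alpha)  # expected tokens per round
    print(f"gamma={gamma:>3d}: {e:.3f}")
print(f"ceiling: {1 / (1 - alpha):.1f}")  # 5.0
```

No matter how far you speculate, 80% acceptance never yields more than 5 expected tokens per round.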
Key Insight: The speedup ceiling of speculative decoding is \(\frac{1}{1 - \alpha}\). At 80% acceptance rate, the ceiling is 5x. At 90%, it's 10x. But you never reach the ceiling because draft model calls aren't free.
Greedy vs. Speculative — A Side-by-Side Comparison
You might wonder: how does greedy decoding (always pick the highest-probability token) compare to speculative decoding? They serve different purposes, but the comparison is instructive.
Greedy decoding picks argmax(p) at each step. No randomness. Speculative decoding samples from distributions and uses rejection sampling. When the target model is very confident, greedy and speculative behave similarly. When distributions are flat, speculative decoding adds more value.
Let's compare both approaches on peaked vs. flat distributions.
def greedy_decode(context, num_tokens):
"""Greedy decoding — always pick highest probability token."""
generated = list(context)
for _ in range(num_tokens):
p = target_model(generated)
token = np.argmax(p)
generated.append(token)
return generated[len(context):]
def compare_methods(context, num_tokens, gamma=5):
"""Compare greedy, standard sampling, and speculative."""
# Greedy
greedy_tokens = greedy_decode(context, num_tokens)
# Standard sampling
np.random.seed(42)
std_tokens, std_calls = standard_decode(context, num_tokens)
# Speculative
np.random.seed(42)
spec_tokens, spec_stats = speculative_decode(
context, num_tokens, gamma
)
return {
'greedy': tokens_to_text(greedy_tokens[:10]),
'standard': tokens_to_text(std_tokens[:10]),
'speculative': tokens_to_text(spec_tokens[:10]),
'std_calls': std_calls,
'spec_calls': spec_stats['target_calls'],
'acceptance': spec_stats['avg_acceptance']
}
context = [0, 1, 2] # "the cat sat"
result = compare_methods(context, 20, gamma=5)
print("Method comparison (first 10 tokens shown):")
print(f" Greedy: {result['greedy']}")
print(f" Standard: {result['standard']}")
print(f" Speculative: {result['speculative']}")
print(f"\nTarget model calls:")
print(f" Standard: {result['std_calls']}")
print(f" Speculative: {result['spec_calls']}")
print(f" Acceptance: {result['acceptance']:.1%}")
Output:
Method comparison (first 10 tokens shown):
Greedy: dog a dog a cat mat dog the cat ran
Standard: dog red ran on on the the ran ran on
Speculative: sat a sat sat sat ran the a on sat
Target model calls:
Standard: 20
Speculative: 11
Acceptance: 38.2%
Two things jump out. First, greedy decoding repeats patterns like "dog a dog a" because our simple model cycles through the same high-probability tokens. That's a known problem with greedy search.
Second, standard and speculative decoding produce different sequences because they use different random states internally. But statistically, they sample from the same target distribution. The speculative version needs only 11 target calls instead of 20.
When Speculative Decoding Hurts — The Cost Tradeoff
Speculative decoding isn't always a win. There are three scenarios where it backfires.
Scenario 1: The draft model is too expensive. If the draft model costs half as much as the target (\(c = 0.5\)), the overhead of \(\gamma\) draft calls eats into your savings. You need very high acceptance rates to break even.
Scenario 2: The draft model is a bad match. If the draft distribution diverges heavily from the target, most tokens get rejected. You pay for \(\gamma\) draft calls plus a target call, but only get 1-2 tokens per round. That's worse than standard decoding.
Scenario 3: The target model is already fast. Speculative decoding shines when the target model is the bottleneck. If you've already optimized the target with quantization, batching, or KV-cache tricks, the marginal gain shrinks.
Let's quantify scenario 1. We'll compute the effective wall-clock speedup for different cost ratios and acceptance rates.
def wall_clock_speedup(alpha, gamma, cost_ratio):
"""Compute wall-clock speedup of speculative decoding.
alpha: acceptance rate (0 to 1)
gamma: speculation length
cost_ratio: draft_time / target_time
"""
# Expected tokens per round
expected_tokens = (1 - alpha**(gamma + 1)) / (1 - alpha) if alpha < 1 else gamma + 1
# Cost per round: gamma draft calls + 1 target call
cost_per_round = cost_ratio * gamma + 1
# Standard cost: 1 target call per token
standard_cost_per_token = 1.0
# Speculative cost per token
spec_cost_per_token = cost_per_round / expected_tokens
return standard_cost_per_token / spec_cost_per_token
print("Wall-clock speedup: rows = acceptance rate, cols = cost ratio")
print("(gamma=5)")
print()
cost_ratios = [0.01, 0.05, 0.1, 0.2, 0.5]
alphas = [0.3, 0.5, 0.7, 0.8, 0.9, 0.95]
# Header
print(f"{'alpha':>6s}", end="")
for c in cost_ratios:
print(f" c={str(c):<5s}", end="")
print()
print("-" * 55)
for alpha in alphas:
print(f"{alpha:>6.2f}", end="")
for c in cost_ratios:
s = wall_clock_speedup(alpha, 5, c)
marker = " *" if s < 1.0 else " "
print(f" {s:>5.2f}{marker}", end="")
print()
print("\n* = slower than standard decoding")
Output:
Wall-clock speedup: rows = acceptance rate, cols = cost ratio
(gamma=5)
alpha c=0.01 c=0.05 c=0.1 c=0.2 c=0.5
-------------------------------------------------------
0.30 1.36 1.14 0.95 * 0.71 * 0.41 *
0.50 1.88 1.58 1.31 0.98 * 0.56 *
0.70 2.80 2.35 1.96 1.47 0.84 *
0.80 3.51 2.95 2.46 1.84 1.05
0.90 4.46 3.75 3.12 2.34 1.34
0.95 5.05 4.24 3.53 2.65 1.51
* = slower than standard decoding
The pattern is clear. When the draft model is nearly free (\(c = 0.01\)), even a 30% acceptance rate gives a 1.36x speedup. But when the draft model costs half the target (\(c = 0.5\)), you need above 80% acceptance just to break even.
The sweet spot in practice: \(c\) between 0.05 and 0.15, with acceptance rates above 70%.
Tip: In practice, a good draft model is 10-30x smaller than the target. For example, use Llama-2-7B as the draft for Llama-2-70B. The cost ratio ends up around 0.05-0.10, which falls in the sweet spot.
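The break-even boundary visible in the table can also be computed directly from the speedup formula. A sketch that bisects on \(\alpha\) for each cost ratio (\(\gamma = 5\)):

```python
def speedup(alpha, gamma, c):
    """Wall-clock speedup from the formula above."""
    e = (1 - alpha**(gamma + 1)) / (1 - alpha) if alpha < 1 else gamma + 1
    return e / (c * gamma + 1)

for c in (0.05, 0.1, 0.2, 0.5):
    lo, hi = 0.0, 0.999
    for _ in range(50):          # bisection: speedup is increasing in alpha
        mid = (lo + hi) / 2
        if speedup(mid, 5, c) < 1.0:
            lo = mid
        else:
            hi = mid
    print(f"c={c:<4}: break-even alpha ~= {hi:.2f}")
```

The break-even acceptance rate climbs steeply with the cost ratio, which is why cheap draft models matter so much.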
Tuning Gamma — How Many Tokens to Speculate
Gamma (\(\gamma\)) is the number of tokens the draft model proposes per round. Higher gamma means more potential tokens per round — but also more wasted draft calls when a rejection happens early.
The optimal gamma depends on the acceptance rate. With high acceptance rates, longer speculation chains pay off. With low acceptance rates, short chains are safer.
Let's sweep gamma from 1 to 12 at different acceptance rates and find the sweet spot.
print("Optimal gamma search")
print("=" * 50)
for target_alpha in [0.5, 0.7, 0.85, 0.95]:
best_gamma = 1
best_speedup = 0
cost_ratio = 0.1
results = []
for gamma in range(1, 13):
s = wall_clock_speedup(target_alpha, gamma, cost_ratio)
results.append((gamma, s))
if s > best_speedup:
best_speedup = s
best_gamma = gamma
print(f"\nalpha={target_alpha:.2f}, cost_ratio={cost_ratio}:")
for gamma, s in results:
marker = " <-- best" if gamma == best_gamma else ""
bar = '#' * int(s * 10)
print(f" gamma={gamma:>2d}: {s:.2f}x {bar}{marker}")
Output:
Optimal gamma search
==================================================
alpha=0.50, cost_ratio=0.1:
gamma= 1: 1.36x #############
gamma= 2: 1.46x ############## <-- best
gamma= 3: 1.44x ##############
gamma= 4: 1.38x #############
gamma= 5: 1.31x #############
gamma= 6: 1.24x ############
gamma= 7: 1.17x ###########
gamma= 8: 1.11x ###########
gamma= 9: 1.05x ##########
gamma=10: 1.00x #########
gamma=11: 0.95x #########
gamma=12: 0.91x #########
alpha=0.70, cost_ratio=0.1:
gamma= 1: 1.55x ###############
gamma= 2: 1.83x ##################
gamma= 3: 1.95x ###################
gamma= 4: 1.98x ################### <-- best
gamma= 5: 1.96x ###################
gamma= 6: 1.91x ###################
gamma= 7: 1.85x ##################
gamma= 8: 1.78x #################
gamma= 9: 1.70x #################
gamma=10: 1.63x ################
gamma=11: 1.57x ###############
gamma=12: 1.50x ###############
alpha=0.85, cost_ratio=0.1:
gamma= 1: 1.68x ################
gamma= 2: 2.14x #####################
gamma= 3: 2.45x ########################
gamma= 4: 2.65x ##########################
gamma= 5: 2.77x ###########################
gamma= 6: 2.83x ############################
gamma= 7: 2.85x ############################ <-- best
gamma= 8: 2.85x ############################
gamma= 9: 2.82x ############################
gamma=10: 2.78x ###########################
gamma=11: 2.72x ###########################
gamma=12: 2.66x ##########################
alpha=0.95, cost_ratio=0.1:
gamma= 1: 1.77x #################
gamma= 2: 2.38x #######################
gamma= 3: 2.85x ############################
gamma= 4: 3.23x ################################
gamma= 5: 3.53x ###################################
gamma= 6: 3.77x #####################################
gamma= 7: 3.96x #######################################
gamma= 8: 4.11x #########################################
gamma= 9: 4.22x ##########################################
gamma=10: 4.31x ###########################################
gamma=11: 4.38x ###########################################
gamma=12: 4.42x ############################################ <-- best
At 50% acceptance, the optimal gamma is just 2 — anything longer wastes draft calls. At 70%, gamma=4 is the sweet spot. At 95% acceptance, longer chains keep paying off up to gamma=12. The curve is relatively flat near the optimum, so being off by 1-2 doesn't matter much.
{
type: 'exercise',
id: 'spec-decode-ex1',
title: 'Exercise 1: Compute Expected Tokens Per Round',
difficulty: 'beginner',
exerciseType: 'write',
instructions: 'Write a function expected_tokens(alpha, gamma) that computes the expected number of tokens generated per speculative decoding round. The formula is: E[tokens] = (1 - alpha^(gamma+1)) / (1 - alpha) for alpha < 1, and gamma + 1 when alpha = 1.0. Test it with alpha=0.8 and gamma=5. Print the result rounded to 2 decimal places.',
starterCode: 'def expected_tokens(alpha, gamma):\n # Compute expected tokens per round\n # Handle alpha=1.0 as a special case\n pass\n\n# Test: alpha=0.8, gamma=5\nresult = expected_tokens(0.8, 5)\nprint(f"{result:.2f}")',
testCases: [
{ id: 'tc1', input: 'print(f"{expected_tokens(0.8, 5):.2f}")', expectedOutput: '3.69', description: 'alpha=0.8, gamma=5 should give 3.69 expected tokens' },
{ id: 'tc2', input: 'print(f"{expected_tokens(1.0, 5):.2f}")', expectedOutput: '6.00', description: 'alpha=1.0 should give gamma+1=6' },
{ id: 'tc3', input: 'print(f"{expected_tokens(0.5, 3):.2f}")', expectedOutput: '1.88', description: 'alpha=0.5, gamma=3 should give 1.88', hidden: true },
],
hints: [
'The formula is (1 - alpha**(gamma+1)) / (1 - alpha). Remember to handle the alpha=1.0 edge case separately to avoid division by zero.',
'Full solution: if alpha == 1.0: return gamma + 1; else: return (1 - alpha**(gamma+1)) / (1 - alpha)',
],
solution: 'def expected_tokens(alpha, gamma):\n if alpha == 1.0:\n return gamma + 1\n return (1 - alpha**(gamma + 1)) / (1 - alpha)\n\nresult = expected_tokens(0.8, 5)\nprint(f"{result:.2f}")',
solutionExplanation: 'The formula comes from the geometric series. Each token has probability alpha of being accepted. The expected number of consecutive acceptances before a rejection follows a truncated geometric distribution capped at gamma, plus one bonus token if all pass.',
xpReward: 15,
}
Common Mistakes and How to Fix Them
Mistake 1: Ignoring the Draft Model's Cost
❌ Wrong thinking:
"I got a 4x reduction in target model calls, so my inference is 4x faster."
Why it's wrong: You made \(\gamma\) draft model calls per round. If the draft model costs 10% of the target and \(\gamma = 5\), each round costs \(1 + 0.5 = 1.5\) target-equivalents, not 1.0. The real speedup is lower than the target-call reduction.
✅ Correct approach:
Always compute wall-clock speedup including draft model cost:
# Correct speedup calculation (illustrative numbers: the gamma=5 run above
# used 23 rounds to generate 50 tokens; assume the draft costs 10% of the target)
gamma, cost_ratio = 5, 0.1
num_rounds, num_tokens = 23, 50
draft_cost = cost_ratio * gamma * num_rounds
target_cost = num_rounds  # One target call per round
total_spec_cost = draft_cost + target_cost
total_std_cost = num_tokens  # Standard: one target call per token
real_speedup = total_std_cost / total_spec_cost
print(f"Real speedup: {real_speedup:.2f}x")
Mistake 2: Using the Same Model as Both Draft and Target
❌ Wrong:
# Draft and target are the same model
draft_output = model.generate(prompt, max_tokens=5)
target_output = model.verify(draft_output)  # Same model!
Why it's wrong: If both models have identical distributions, the acceptance rate is 100% — but you're doing twice the work. You'd generate tokens faster with standard decoding. The draft model must be substantially cheaper than the target.
✅ Correct approach:
Use a model that's 10-30x smaller. For Llama-70B as target, use Llama-7B or even Llama-1B as the draft.
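To see why cheapness matters, plug a perfect draft model (\(\alpha = 1\), every token accepted) into the analytical speedup model used throughout this article. This sketch assumes each round costs \(c\gamma + 1\) target-call equivalents, where \(c\) is the draft-to-target cost ratio:

```python
def speedup(alpha, gamma, c):
    """Expected tokens per round divided by round cost (in target-call units)."""
    expected = gamma + 1 if alpha == 1.0 else (1 - alpha**(gamma + 1)) / (1 - alpha)
    return expected / (c * gamma + 1)

# Even a perfect draft model (alpha = 1) is useless if it isn't cheap:
for c in (0.05, 0.1, 0.3, 1.0):
    print(f"cost ratio {c:.2f}: {speedup(1.0, 5, c):.2f}x")
# cost ratio 1.00 prints 1.00x: using the same model as draft buys nothing
```

At \(c = 1\) the speedup collapses to exactly 1.0x under this cost model, which is why the 10-30x size gap matters.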
Mistake 3: Not Resampling on Rejection
❌ Wrong:
# Rejected — just sample from the target directly
if not accepted:
token = np.random.choice(VOCAB_SIZE, p=p_target)
Why it's wrong: Sampling from p_target directly after rejection introduces bias. The verification process already "used up" the probability mass where \(q(x) > p(x)\). You must sample from the corrected distribution \(\max(0, p(x) - q(x))\), normalized.
✅ Correct:
if not accepted:
residual = np.maximum(0, p_target - q_draft)
corrected = residual / residual.sum()
token = np.random.choice(VOCAB_SIZE, p=corrected)
This correction is what makes speculative decoding mathematically lossless.
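You can check that losslessness empirically. The sketch below (a toy 3-token vocabulary, distributions chosen for illustration) Monte-Carlo samples the full accept/resample procedure and confirms the empirical token frequencies converge to the target distribution, not the draft's:

```python
import numpy as np

rng = np.random.default_rng(0)
p = np.array([0.5, 0.3, 0.2])   # target distribution
q = np.array([0.2, 0.6, 0.2])   # draft distribution

trials = 50_000
counts = np.zeros(3)
for _ in range(trials):
    x = rng.choice(3, p=q)                       # draft proposes
    if rng.random() < min(1.0, p[x] / q[x]):     # target accepts
        counts[x] += 1
    else:                                        # rejected: corrected resample
        residual = np.maximum(0, p - q)
        counts[rng.choice(3, p=residual / residual.sum())] += 1

print(np.round(counts / trials, 3))  # close to [0.5 0.3 0.2]
```

Without the corrected resample (sampling from p directly on rejection), token 0 would come out over-weighted, because acceptance already awards it its full probability mass.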
{
type: 'exercise',
id: 'spec-decode-ex2',
title: 'Exercise 2: Compute the Corrected Distribution',
difficulty: 'intermediate',
exerciseType: 'write',
instructions: 'Write a function corrected_distribution(p_target, q_draft) that computes the resampling distribution used when a draft token is rejected. The formula is: take the element-wise maximum of (p_target - q_draft, 0), then normalize so it sums to 1. Return the corrected distribution as a numpy array. Test with p_target = [0.5, 0.3, 0.2] and q_draft = [0.2, 0.6, 0.2]. Print the result rounded to 3 decimal places.',
starterCode: 'import numpy as np\n\ndef corrected_distribution(p_target, q_draft):\n p = np.array(p_target)\n q = np.array(q_draft)\n # Compute residual and normalize\n pass\n\n# Test\np_target = [0.5, 0.3, 0.2]\nq_draft = [0.2, 0.6, 0.2]\nresult = corrected_distribution(p_target, q_draft)\nprint(np.round(result, 3))',
testCases: [
{ id: 'tc1', input: 'import numpy as np\nresult = corrected_distribution([0.5, 0.3, 0.2], [0.2, 0.6, 0.2])\nprint(np.round(result, 3))', expectedOutput: '[1. 0. 0.]', description: 'Residual is [0.3, 0, 0], normalized to [1, 0, 0]' },
{ id: 'tc2', input: 'import numpy as np\nresult = corrected_distribution([0.4, 0.4, 0.2], [0.3, 0.3, 0.4])\nprint(np.round(result, 3))', expectedOutput: '[0.5 0.5 0. ]', description: 'Residual is [0.1, 0.1, 0], normalized to [0.5, 0.5, 0]' },
],
hints: [
'Use np.maximum(p - q, 0) to clip negative values to zero. Then divide by the sum.',
'Full solution: residual = np.maximum(p - q, 0); return residual / residual.sum()',
],
solution: 'import numpy as np\n\ndef corrected_distribution(p_target, q_draft):\n p = np.array(p_target)\n q = np.array(q_draft)\n residual = np.maximum(p - q, 0)\n return residual / residual.sum()\n\np_target = [0.5, 0.3, 0.2]\nq_draft = [0.2, 0.6, 0.2]\nresult = corrected_distribution(p_target, q_draft)\nprint(np.round(result, 3))',
solutionExplanation: 'The corrected distribution captures the probability mass that the draft model "missed." Where p_target > q_draft, the target wanted more probability than the draft gave — so the resampled token should come from those under-represented regions. This ensures the final distribution matches p_target exactly.',
xpReward: 15,
}
When NOT to Use Speculative Decoding
Speculative decoding isn't a universal speedup. Here are the specific scenarios where you should skip it.
1. Batch inference with high throughput. Speculative decoding helps latency (time per request). But it can hurt throughput (tokens per second across many requests). When you're processing hundreds of requests in parallel, the GPU is already saturated. Draft model calls compete for the same compute.
2. Short outputs. If you're generating 5-10 tokens (classification labels, yes/no answers), the overhead exceeds the savings. Standard decoding is simpler and fast enough.
3. No good draft model available. The draft model must share the target's tokenizer. It also needs to produce similar distributions. If the best small model has a different vocabulary, speculative decoding won't work.
4. Streaming with strict latency needs. The draft-verify round creates a burst pattern. You wait for \(\gamma\) draft tokens, then one target pass, then emit multiple tokens at once. If your app needs smooth token-by-token streaming, this burst feels jerky.
Note: Speculative decoding and KV-cache optimization are complementary. You can use both. The target model's verification pass still benefits from cached key-value pairs for all previously accepted tokens.
Summary
Speculative decoding accelerates LLM inference by pairing a cheap draft model with an expensive target model. The draft proposes, the target verifies, and the output quality stays identical.
Here are the core takeaways:
| Concept | Key Point |
|---|---|
| Core mechanism | Draft \(\gamma\) tokens cheaply, verify in one target pass |
| Quality guarantee | Output distribution is identical to standard decoding |
| Acceptance formula | \(\min(1, p(x)/q(x))\) per token |
| Speedup drivers | High acceptance rate + low draft model cost |
| Optimal gamma | 2-7 depending on acceptance rate (50-85%) |
| When it helps | Single-request latency with large target models |
| When it hurts | High-throughput batched inference, very short outputs |
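The "optimal gamma" row can be made concrete with the same analytical model used in the article. This is a sketch that ignores batching and memory effects; it simply sweeps \(\gamma\) for one acceptance rate and cost ratio:

```python
def speedup(alpha, gamma, c):
    """Expected tokens per round divided by round cost (in target-call units)."""
    expected = gamma + 1 if alpha == 1.0 else (1 - alpha**(gamma + 1)) / (1 - alpha)
    return expected / (c * gamma + 1)

alpha, c = 0.8, 0.1
best_gamma = max(range(1, 11), key=lambda g: speedup(alpha, g, c))
for g in range(1, 11):
    tag = "  <-- best" if g == best_gamma else ""
    print(f"gamma={g:2d}: {speedup(alpha, g, c):.2f}x{tag}")
```

With \(\alpha = 0.8\) and \(c = 0.1\) the curve peaks at \(\gamma = 6\) and falls off gently on either side: past the peak, extra draft tokens are mostly rejected but still paid for.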
Practice Exercise: Extend the simulator to support adaptive gamma — where the speculation length adjusts based on recent acceptance rates. If the last 5 rounds had >80% acceptance, increase gamma by 1. If <50%, decrease by 1. Clamp between 2 and 10. Run it on 200 tokens and compare the total target calls against fixed-gamma decoding.
Complete Code
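Here is a condensed, self-contained sketch of the simulator this article builds, using synthetic softmax distributions in place of real models. The function names (`speculative_round`, `simulate`, `next_distributions`) and the noise model are illustrative choices, not a library API:

```python
import numpy as np

VOCAB_SIZE = 50
rng = np.random.default_rng(42)

def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

def next_distributions(noise=0.5):
    """Synthetic decoding step: a target distribution p and a noisy draft q."""
    base = rng.normal(size=VOCAB_SIZE)
    p = softmax(2.0 * base)                                        # target model
    q = softmax(2.0 * base + noise * rng.normal(size=VOCAB_SIZE))  # draft model
    return p, q

def speculative_round(gamma):
    """One draft-verify round. Returns (tokens_emitted, accepted, proposed)."""
    for i in range(gamma):
        p, q = next_distributions()
        x = rng.choice(VOCAB_SIZE, p=q)              # draft proposes token x
        if rng.random() >= min(1.0, p[x] / q[x]):    # target rejects
            residual = np.maximum(0, p - q)
            rng.choice(VOCAB_SIZE, p=residual / residual.sum())  # corrected resample
            return i + 1, i, i + 1                   # rejection still emits a token
    return gamma + 1, gamma, gamma                   # all accepted + one bonus token

def simulate(num_tokens=200, gamma=5, cost_ratio=0.1):
    tokens = rounds = accepted = proposed = 0
    while tokens < num_tokens:
        emitted, acc, prop = speculative_round(gamma)
        tokens += emitted
        rounds += 1
        accepted += acc
        proposed += prop
    # Cost in target-call units: gamma draft calls + 1 verify call per round
    spec_cost = rounds * (cost_ratio * gamma + 1)
    return tokens, rounds, accepted / proposed, tokens / spec_cost

tokens, rounds, acc_rate, speedup = simulate()
print(f"{tokens} tokens in {rounds} rounds, "
      f"acceptance {acc_rate:.0%}, speedup {speedup:.2f}x")
```

Lowering `noise` raises the acceptance rate and the speedup; raising it models a weaker draft, which is an easy way to reproduce the "when it backfires" scenarios above.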
Frequently Asked Questions
Does speculative decoding work with any two models?
Not quite. The draft and target must share the same tokenizer. They need to agree on what a "token" is. Beyond that, closer distributions mean higher acceptance rates. In practice, people use smaller models from the same family (Llama-7B drafting for Llama-70B) or train dedicated draft models.
Can I use speculative decoding with quantized models?
Yes, and it's a powerful combination. Run a full-precision small model as the draft and a quantized large model as the target. The quantized target is already faster per call. Speculative decoding further reduces the number of calls. Some teams use a 4-bit quantized draft model to push the cost ratio even lower.
What happens if the draft model is perfect?
If the draft matches the target exactly (\(q = p\)), every token is accepted. You get \(\gamma + 1\) tokens per round. The speedup ceiling is \((\gamma + 1) / (c \cdot \gamma + 1)\). With \(\gamma = 5\) and \(c = 0.1\), that's 4x.
How does speculative decoding compare to beam search?
They solve different problems. Beam search explores multiple output sequences to find a better one. Speculative decoding speeds up generating a single sequence. You can combine them — use speculative decoding to accelerate each beam's token generation.
Is speculative decoding used in production systems?
Yes. vLLM, TensorRT-LLM, and Hugging Face Transformers all support it. Google's Gemini and Meta's Llama serving infrastructure use it in production. It's one of the most impactful inference optimizations alongside KV-cache and continuous batching.
References
- Leviathan, Y., Kalman, M., & Matias, Y. — "Fast Inference from Transformers via Speculative Decoding." ICML 2023. arXiv:2211.17192
- Chen, C., et al. — "Accelerating Large Language Model Decoding with Speculative Sampling." 2023. arXiv:2302.01318
- PyTorch Blog — "A Hitchhiker's Guide to Speculative Decoding." Link
- NVIDIA Technical Blog — "An Introduction to Speculative Decoding for Reducing Latency in AI Inference." Link
- vLLM Blog — "How Speculative Decoding Boosts vLLM Performance by up to 2.8x." Link
- BentoML — "Get 3x Faster LLM Inference with Speculative Decoding Using the Right Draft Model." Link
- Xia, H., et al. — "Unlocking Efficiency in Large Language Model Inference: A Comprehensive Survey of Speculative Decoding." ACL 2024. arXiv:2401.07851