KV Cache Explained: Build a Cache Manager in Python
Learn how KV caching works in LLMs, calculate VRAM usage for real models, and build a PagedAttention-style cache manager with token eviction in pure Python.
Every token your LLM makes forces it to redo the same math on all past tokens — unless you cache the results. Here’s how KV caching works, why it eats your GPU memory, and how to build a cache manager that fights back.
This post has interactive code — click ‘Run’ or press Ctrl+Enter on any code block to execute it directly in your browser.
You’ve likely seen something odd. A 7B model loads fine into your 24 GB GPU. But feed it a long prompt — 8,000 tokens — and VRAM use spikes by several gigabytes. The weights stayed the same. So what ate all that extra space?
It’s the KV cache — the number one memory hog when running LLMs. And most people don’t really get how it works or how to keep it in check.
This guide builds your grasp from the ground up. You’ll figure out how much memory the KV cache eats for real models like Llama 2 and Llama 3. Then you’ll build a working KV cache manager in pure Python — one that grows, shrinks, and drops tokens as needed. Finally, you’ll put together a simple version of PagedAttention, the trick vLLM uses to serve LLMs well at scale.
No GPU needed. It all runs in your browser with pure Python and NumPy.
Prerequisites
- Python version: 3.9+
- Required libraries: NumPy 1.24+
- Install: `pip install numpy`
- Time to complete: 25-30 minutes
Why Does the KV Cache Exist?
LLMs make text one token at a time. Each new token has to “attend” to every past token. That means working out the keys and values for the full history — over and over.
Without caching, making the 100th token means doing key-value math for all 99 before it. Token 101? Redo all 100. A huge waste.
The fix? Save the keys and values you’ve already worked out. When token 101 shows up, just look up the saved results for tokens 1-100 and only do the math for the new one.
That saved pile of keys and values is the KV cache.
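Before counting the work saved, it’s worth convincing yourself that caching changes nothing about the result. The sketch below uses toy dimensions and random weight matrices (not a real model): it projects K and V for a tiny sequence twice — once from scratch, once row by row the way a cache would — and checks that the newest token’s attention output is identical either way.

```python
import numpy as np

np.random.seed(0)

d = 8  # tiny head dimension for the demo
Wq = np.random.randn(d, d)
Wk = np.random.randn(d, d)
Wv = np.random.randn(d, d)

tokens = np.random.randn(5, d)  # hidden states for 5 tokens

# Full recompute: project every token's K and V from scratch
K_full = tokens @ Wk
V_full = tokens @ Wv

# Cached: build K and V one token at a time, reusing past rows
K_rows, V_rows = [], []
for t in range(tokens.shape[0]):
    K_rows.append(tokens[t] @ Wk)  # only the new token's projection
    V_rows.append(tokens[t] @ Wv)
K_cached = np.stack(K_rows)
V_cached = np.stack(V_rows)

def attend(q, K, V):
    """Single-query softmax attention over all cached keys/values."""
    scores = K @ q / np.sqrt(d)
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    return weights @ V

# Attention output for the newest token matches either way
q = tokens[-1] @ Wq
out_full = attend(q, K_full, V_full)
out_cached = attend(q, K_cached, V_cached)
print(np.allclose(out_full, out_cached))  # True
```

Same output, a fraction of the projection work — caching only skips math that would have produced identical numbers anyway.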
How big is the gap? The next code block has two functions. attention_no_cache does it the hard way — it redoes keys and values for every past token at each step. attention_with_cache uses caching — it only works out the new token’s pair. Both count total “work units” for a 512-token run. Watch the ratio.
import numpy as np
np.random.seed(42)
def attention_no_cache(sequence_length, d_model=64):
"""Each new token recomputes ALL previous keys and values."""
total_computations = 0
for token_pos in range(sequence_length):
computations = (token_pos + 1) * 2 * d_model
total_computations += computations
return total_computations
def attention_with_cache(sequence_length, d_model=64):
"""Each new token computes ONLY its own K and V."""
total_computations = 0
for token_pos in range(sequence_length):
computations = 1 * 2 * d_model
total_computations += computations
return total_computations
seq_len = 512
no_cache = attention_no_cache(seq_len)
with_cache = attention_with_cache(seq_len)
print(f"Without KV cache: {no_cache:,} computation units")
print(f"With KV cache: {with_cache:,} computation units")
print(f"Savings: {no_cache / with_cache:.0f}x fewer computations")
Output:
Without KV cache: 16,809,984 computation units
With KV cache: 65,536 computation units
Savings: 256x fewer computations
A roughly 256x cut for a 512-token run (the exact ratio is 256.5). And the gains grow with length — double the tokens and the ratio roughly doubles too.
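You can check that scaling claim without the loop. With caching, step t costs one unit of K/V work instead of t, so the savings ratio for an n-token run collapses to a closed form:

```python
# Ratio = (sum of 1..n) / n = (n + 1) / 2 — it grows linearly with length.
for n in (512, 1024, 2048, 4096):
    print(f"{n:>5} tokens -> {(n + 1) / 2:.1f}x fewer K/V computations")
```

Doubling the sequence length roughly doubles the savings, which is why caching matters most for long contexts.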
Key Insight: The KV cache swaps memory for speed. You hold on to past keys and values so each new token only needs one fresh step. The price: that saved data fills up GPU memory, and it grows with every token.
How Much Memory Does the KV Cache Need?
This part surprises everyone. The KV cache isn’t some small lookup. With big models and long prompts, the cache can take more room than the model weights themselves.
Let me walk you through the math, step by step:
- Per-token storage: Two vectors per layer — one key, one value.
- Vector size: `num_kv_heads x head_dim` numbers each.
- Scale by depth: Repeat for every layer in the model.
- Data type: 2 bytes per number in FP16 mode.
- Scale by usage: Multiply by token count and batch size.
The full formula:
KV Cache Memory = 2 x num_layers x num_kv_heads x head_dim x bytes_per_param x seq_length x batch_size
The next function plugs in real specs from Llama 2 and Llama 3. We loop across models and token counts to build a table. Watch how GQA (fewer KV heads) changes the numbers in a big way.
def calc_kv_cache_memory(num_layers, num_kv_heads, head_dim, seq_length,
bytes_per_param=2, batch_size=1):
"""Calculate KV cache memory in bytes.
Args:
num_layers: Number of transformer layers
num_kv_heads: KV heads (differs from query heads in GQA)
head_dim: Dimension of each attention head
seq_length: Number of tokens
bytes_per_param: 2 for FP16/BF16, 4 for FP32
batch_size: Concurrent sequences
"""
per_token = 2 * num_layers * num_kv_heads * head_dim * bytes_per_param
total = per_token * seq_length * batch_size
return total
models = {
"Llama-2 7B": {"num_layers": 32, "num_kv_heads": 32, "head_dim": 128},
"Llama-2 13B": {"num_layers": 40, "num_kv_heads": 40, "head_dim": 128},
"Llama-2 70B": {"num_layers": 80, "num_kv_heads": 8, "head_dim": 128},
"Llama-3 8B": {"num_layers": 32, "num_kv_heads": 8, "head_dim": 128},
"Llama-3 70B": {"num_layers": 80, "num_kv_heads": 8, "head_dim": 128},
}
seq_lengths = [1024, 4096, 8192, 32768]
print(f"{'Model':<16} {'Seq Len':>8} {'KV Cache (MB)':>14} {'KV Cache (GB)':>14}")
print("-" * 56)
for name, arch in models.items():
for seq_len in seq_lengths:
mem = calc_kv_cache_memory(**arch, seq_length=seq_len)
mb = mem / (1024 ** 2)
gb = mem / (1024 ** 3)
print(f"{name:<16} {seq_len:>8,} {mb:>13.1f} {gb:>13.2f}")
print()
Output:
Model Seq Len KV Cache (MB) KV Cache (GB)
--------------------------------------------------------
Llama-2 7B 1,024 512.0 0.50
Llama-2 7B 4,096 2,048.0 2.00
Llama-2 7B 8,192 4,096.0 4.00
Llama-2 7B 32,768 16,384.0 16.00
Llama-2 13B 1,024 800.0 0.78
Llama-2 13B 4,096 3,200.0 3.12
Llama-2 13B 8,192 6,400.0 6.25
Llama-2 13B 32,768 25,600.0 25.00
Llama-2 70B 1,024 320.0 0.31
Llama-2 70B 4,096 1,280.0 1.25
Llama-2 70B 8,192 2,560.0 2.50
Llama-2 70B 32,768 10,240.0 10.00
Llama-3 8B 1,024 128.0 0.12
Llama-3 8B 4,096 512.0 0.50
Llama-3 8B 8,192 1,024.0 1.00
Llama-3 8B 32,768 4,096.0 4.00
Llama-3 70B 1,024 320.0 0.31
Llama-3 70B 4,096 1,280.0 1.25
Llama-3 70B 8,192 2,560.0 2.50
Llama-3 70B 32,768 10,240.0 10.00
Check out Llama-2 7B at 32K tokens: 16 GB for the cache alone. Its weights only need about 14 GB in FP16. The cache is bigger than the model.
Now here’s what’s really worth noting. Llama-2 70B — ten times the parameters of Llama-2 7B — needs a smaller cache than the 7B model, and exactly the same cache as Llama-3 70B. Why? Both 70B models run Grouped Query Attention (GQA) with only 8 KV heads. In GQA, groups of query heads share one set of keys and values instead of each head having its own. Fewer KV heads means a far smaller cache.
Tip: Always check `num_kv_heads` when picking a model to deploy. It tells you more about memory than the parameter count. Llama-2 7B runs 32 KV heads. Llama-3 8B runs just 8 — that’s a 4x cut in cache size. This is how today’s large models manage long prompts without running out of room.
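A quick sanity check on that 4x claim, plugging both models’ specs from the table above into the per-token formula:

```python
# KV cache scales linearly with num_kv_heads, everything else being equal.
def cache_per_token_kb(num_layers, num_kv_heads, head_dim, bytes_per_param=2):
    """Per-token KV cache in KB: 2 (K and V) x layers x heads x dim x bytes."""
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_param / 1024

llama2_7b = cache_per_token_kb(num_layers=32, num_kv_heads=32, head_dim=128)
llama3_8b = cache_per_token_kb(num_layers=32, num_kv_heads=8, head_dim=128)
print(f"Llama-2 7B: {llama2_7b:.0f} KB/token")  # 512 KB
print(f"Llama-3 8B: {llama3_8b:.0f} KB/token")  # 128 KB
print(f"GQA reduction: {llama2_7b / llama3_8b:.0f}x")
```

Same layer count, same head dimension — the 4x drop comes entirely from cutting KV heads from 32 to 8.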
Watching the KV Cache Grow Token by Token
Each time your model makes a token, it adds one key vector and one value vector per layer to the cache. It starts empty and grows by a fixed amount each step.
How much per token? The code below tracks memory as we make tokens one by one for a Llama-2 7B-class model. It works out the per-token cost and prints key points. You’ll see each token adds exactly 512 KB.
def simulate_cache_growth(num_layers, num_kv_heads, head_dim,
max_tokens=100, bytes_per_param=2):
"""Track KV cache memory as tokens generate one by one."""
per_token_bytes = 2 * num_layers * num_kv_heads * head_dim * bytes_per_param
token_positions = list(range(1, max_tokens + 1))
memory_mb = [pos * per_token_bytes / (1024 ** 2) for pos in token_positions]
return token_positions, memory_mb, per_token_bytes
positions, memory, per_token = simulate_cache_growth(
num_layers=32, num_kv_heads=32, head_dim=128, max_tokens=100
)
print(f"Memory per token: {per_token:,} bytes ({per_token/1024:.1f} KB)")
print(f"\nToken-by-token growth (Llama-2 7B):")
print(f"{'Tokens':>8} {'Cache Size (MB)':>16}")
print("-" * 26)
for i in [0, 9, 24, 49, 74, 99]:
print(f"{positions[i]:>8} {memory[i]:>15.1f}")
Output:
Memory per token: 524,288 bytes (512.0 KB)
Token-by-token growth (Llama-2 7B):
Tokens Cache Size (MB)
--------------------------
1 0.5
10 5.0
25 12.5
50 25.0
75 37.5
100 50.0
That half-meg per token piles up quickly. Produce 2,000 tokens and you’ve burned through a full gigabyte — for just one user.
Warning: GPU memory is fixed, but the KV cache keeps getting bigger. A model that fits fine with short prompts can crash on longer ones. My rule of thumb: hold back 30-50% of GPU memory beyond the model weights for the cache and other runtime use.
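A rough way to apply that rule of thumb — the card size, weight footprint, and overhead reserve below are hypothetical numbers for a 24 GB GPU, combined with the 512 KB/token figure for a Llama-2 7B-class model from above:

```python
# Back-of-envelope headroom check: how many cached tokens fit in the VRAM
# left over after model weights and a runtime-overhead reserve?
def tokens_that_fit(free_vram_gb, kv_bytes_per_token=524_288):
    """Number of tokens whose KV cache fits in the given free VRAM."""
    return int(free_vram_gb * 1024**3 // kv_bytes_per_token)

# Hypothetical: 24 GB card, ~14 GB of FP16 weights, ~2 GB runtime reserve
free_gb = 24 - 14 - 2
print(f"{free_gb} GB free -> room for {tokens_that_fit(free_gb):,} cached tokens")
```

That total is shared across all concurrent users, which is why a card that handles one long chat fine can still fall over under load.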
Building a KV Cache Manager from Scratch
Time to build something real. A KV cache manager does three things: set aside space for new runs, add new token data as it’s made, and drop old tokens when memory fills up.
I like to think of it as running a parking lot. You mark spots (set aside), cars pull in one by one (add), and when every spot is taken, someone has to leave (drop).
Our manager stores NumPy arrays for key and value data. The `__init__` method makes two 4D arrays shaped `(num_layers, max_seq_length, num_kv_heads, head_dim)` — one for keys, one for values. All start at zero.
class KVCacheManager:
"""Manages KV cache memory for a single sequence."""
def __init__(self, num_layers, num_kv_heads, head_dim, max_seq_length):
self.num_layers = num_layers
self.num_kv_heads = num_kv_heads
self.head_dim = head_dim
self.max_seq_length = max_seq_length
# Shape: (num_layers, max_seq_length, num_kv_heads, head_dim)
self.key_cache = np.zeros(
(num_layers, max_seq_length, num_kv_heads, head_dim),
dtype=np.float16
)
self.value_cache = np.zeros(
(num_layers, max_seq_length, num_kv_heads, head_dim),
dtype=np.float16
)
self.current_length = 0
Now the three core methods. append writes one token’s data into the next open slot. evict_oldest drops the oldest tokens by shifting the rest forward. get_cache returns only the filled part so the model can use it.
def append(self, new_keys, new_values):
"""Add one token's KV states to the cache."""
if self.current_length >= self.max_seq_length:
raise RuntimeError(
f"Cache full: {self.current_length}/{self.max_seq_length}. "
"Call evict_oldest() first."
)
pos = self.current_length
self.key_cache[:, pos, :, :] = new_keys[:, 0, :, :]
self.value_cache[:, pos, :, :] = new_values[:, 0, :, :]
self.current_length += 1
def evict_oldest(self, num_tokens):
"""Remove the oldest tokens by shifting remaining ones forward."""
if num_tokens >= self.current_length:
self.current_length = 0
return
remaining = self.current_length - num_tokens
self.key_cache[:, :remaining, :, :] = \
self.key_cache[:, num_tokens:self.current_length, :, :]
self.value_cache[:, :remaining, :, :] = \
self.value_cache[:, num_tokens:self.current_length, :, :]
self.current_length = remaining
def get_cache(self):
"""Return the active portion of the cache."""
return (
self.key_cache[:, :self.current_length, :, :],
self.value_cache[:, :self.current_length, :, :]
)
def memory_used_bytes(self):
"""Memory consumed by active cache entries."""
return (self.current_length * 2 * self.num_layers
* self.num_kv_heads * self.head_dim * 2)
def utilization(self):
return self.current_length / self.max_seq_length
def __repr__(self):
used_mb = self.memory_used_bytes() / (1024 ** 2)
total_bytes = (self.max_seq_length * 2 * self.num_layers
* self.num_kv_heads * self.head_dim * 2)
total_mb = total_bytes / (1024 ** 2)
return (
f"KVCacheManager(tokens={self.current_length}/{self.max_seq_length}, "
f"memory={used_mb:.2f}/{total_mb:.2f} MB, "
f"utilization={self.utilization():.1%})"
)
Let’s test with a small model: 4 layers, 4 KV heads, head dim 32, max 20 tokens. We’ll make 10 tokens, drop 5, then make 5 more. Watch usage rise, fall, and rise again.
cache = KVCacheManager(
num_layers=4, num_kv_heads=4, head_dim=32, max_seq_length=20
)
print(f"Empty cache: {cache}")
for i in range(10):
new_k = np.random.randn(4, 1, 4, 32).astype(np.float16)
new_v = np.random.randn(4, 1, 4, 32).astype(np.float16)
cache.append(new_k, new_v)
print(f"After 10 tokens: {cache}")
keys, values = cache.get_cache()
print(f"Cache shapes — keys: {keys.shape}, values: {values.shape}")
cache.evict_oldest(5)
print(f"After evicting 5: {cache}")
for i in range(5):
new_k = np.random.randn(4, 1, 4, 32).astype(np.float16)
new_v = np.random.randn(4, 1, 4, 32).astype(np.float16)
cache.append(new_k, new_v)
print(f"After 5 more tokens: {cache}")
Output:
Empty cache: KVCacheManager(tokens=0/20, memory=0.00/0.04 MB, utilization=0.0%)
After 10 tokens: KVCacheManager(tokens=10/20, memory=0.02/0.04 MB, utilization=50.0%)
Cache shapes — keys: (4, 10, 4, 32), values: (4, 10, 4, 32)
After evicting 5: KVCacheManager(tokens=5/20, memory=0.01/0.04 MB, utilization=25.0%)
After 5 more tokens: KVCacheManager(tokens=10/20, memory=0.02/0.04 MB, utilization=50.0%)
The memory shows as tiny because our model is small on purpose. With Llama-2 7B sizes (32 layers, 32 heads, dim 128), each token would eat 512 KB.
Why Simple KV Cache Breaks Down in Production
The manager above works fine for one request at a time. But what about a server with dozens of users at once? That’s where things fall apart.
You might ask: “Can’t I just make one cache per request?” Sure. But each one grabs a big block of memory for the worst case. Here’s why that wastes so much.
Internal waste. You set aside 2,048 token slots per request. Most requests use 100-300 tokens. Those 1,700+ empty slots are dead VRAM you could use for more users.
External waste. Requests end at different times. When a long request wraps up, it leaves a big gap. Shorter requests can’t reuse it because it’s one solid chunk.
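To put a number on that internal waste, here’s a back-of-envelope sketch with made-up (but typical) request lengths:

```python
# Internal fragmentation under flat allocation: every request reserves
# max_seq_length slots up front, regardless of how many it actually uses.
max_seq_length = 2048
request_lengths = [120, 250, 180, 300, 90, 210]  # hypothetical short requests

reserved = len(request_lengths) * max_seq_length
used = sum(request_lengths)
print(f"Slots reserved: {reserved:,}")
print(f"Slots used:     {used:,}")
print(f"Wasted:         {reserved - used:,} ({1 - used / reserved:.0%})")
```

With short, chat-sized requests, well over 90% of the reserved slots can sit empty — dead VRAM that can’t serve anyone else.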
Sound like something you’ve heard before? This is the same problem that operating systems solved decades ago with virtual memory and paging.
Key Insight: PagedAttention takes the virtual memory trick from your OS and brings it to KV caches. The cache breaks into small, fixed-size “pages.” Those pages can live anywhere in GPU memory. A page table links them — the same way your OS maps addresses to physical RAM.
Building a PagedAttention KV Cache Manager
Here’s the plan. Rather than one huge array per request, we’ll run a pool of small, fixed-size blocks. Each block holds keys and values for a few tokens (say 4 or 16). Requests grab blocks as they need them. When done, they put blocks back in the shared pool.
Three parts work together:
- BlockAllocator — the memory pool (gives out blocks, takes them back)
- PagedKVCache — per-request page tables that map token slots to blocks
- The flow — grab blocks on demand, return them on finish
Let’s start with the BlockAllocator. It makes a pool of blocks and tracks which ones are free. allocate pops a free block, free puts one back, and write_token fills a slot inside a block. It knows nothing about requests — just blocks.
class BlockAllocator:
"""Manages a pool of fixed-size memory blocks.
Like an OS page allocator: hands out free blocks, reclaims them.
"""
def __init__(self, num_blocks, block_size, num_layers,
num_kv_heads, head_dim):
self.num_blocks = num_blocks
self.block_size = block_size # tokens per block
self.num_layers = num_layers
self.num_kv_heads = num_kv_heads
self.head_dim = head_dim
# Physical memory pool
self.key_pool = np.zeros(
(num_blocks, num_layers, block_size, num_kv_heads, head_dim),
dtype=np.float16
)
self.value_pool = np.zeros(
(num_blocks, num_layers, block_size, num_kv_heads, head_dim),
dtype=np.float16
)
self.free_blocks = list(range(num_blocks))
self.used_blocks = set()
def allocate(self):
"""Pop one block from the free list."""
if not self.free_blocks:
raise MemoryError("No free blocks available!")
block_id = self.free_blocks.pop(0)
self.used_blocks.add(block_id)
return block_id
def free(self, block_id):
"""Return a block to the pool and zero it out."""
if block_id in self.used_blocks:
self.used_blocks.remove(block_id)
self.free_blocks.append(block_id)
self.key_pool[block_id] = 0
self.value_pool[block_id] = 0
def num_free(self):
return len(self.free_blocks)
def utilization(self):
return len(self.used_blocks) / self.num_blocks
def write_token(self, block_id, slot_index,
layer_keys, layer_values):
"""Write one token's KV data into a block slot."""
self.key_pool[block_id, :, slot_index, :, :] = layer_keys
self.value_pool[block_id, :, slot_index, :, :] = layer_values
Now the PagedKVCache sits on top. Each request gets a page table — a list of (block_id, filled_slots) pairs. When a token comes in, it checks if the block has room. If the block is full, it grabs a new one. When a request ends, all its blocks go back to the pool.
class PagedKVCache:
"""KV cache with paged memory allocation.
Each sequence gets a page table mapping logical positions
to physical blocks. Blocks are allocated on demand.
"""
def __init__(self, allocator):
self.allocator = allocator
self.block_size = allocator.block_size
self.page_tables = {}
self.seq_lengths = {}
def register_sequence(self, seq_id):
"""Start tracking a new sequence."""
self.page_tables[seq_id] = []
self.seq_lengths[seq_id] = 0
def append_token(self, seq_id, token_keys, token_values):
"""Append one token. Grabs a new block if current one is full."""
page_table = self.page_tables[seq_id]
need_new_block = (
len(page_table) == 0 or
page_table[-1][1] >= self.block_size
)
if need_new_block:
block_id = self.allocator.allocate()
page_table.append([block_id, 0])
current_block_id, filled = page_table[-1]
self.allocator.write_token(
current_block_id, filled, token_keys, token_values
)
page_table[-1][1] = filled + 1
self.seq_lengths[seq_id] = self.seq_lengths.get(seq_id, 0) + 1
def free_sequence(self, seq_id):
"""Free all blocks when a request completes."""
for block_id, _ in self.page_tables[seq_id]:
self.allocator.free(block_id)
del self.page_tables[seq_id]
del self.seq_lengths[seq_id]
def status(self):
"""Print pool utilization and per-sequence details."""
lines = []
lines.append(
f"Block pool: {self.allocator.num_free()} free / "
f"{self.allocator.num_blocks} total "
f"({self.allocator.utilization():.1%} used)"
)
for seq_id in self.page_tables:
num_blocks = len(self.page_tables[seq_id])
tokens = self.seq_lengths[seq_id]
wasted = num_blocks * self.block_size - tokens
lines.append(
f" Seq {seq_id}: {tokens} tokens, "
f"{num_blocks} blocks, {wasted} wasted slots"
)
return "\n".join(lines)
Let’s test with three users sharing a pool of 20 blocks (4 tokens each). This mimics a real server where many requests come in at once. User A makes 10 tokens, User B makes 6, and User C makes 3.
allocator = BlockAllocator(
num_blocks=20, block_size=4,
num_layers=2, num_kv_heads=2, head_dim=16
)
cache = PagedKVCache(allocator)
cache.register_sequence("user_A")
cache.register_sequence("user_B")
cache.register_sequence("user_C")
for _ in range(10):
k = np.random.randn(2, 2, 16).astype(np.float16)
v = np.random.randn(2, 2, 16).astype(np.float16)
cache.append_token("user_A", k, v)
for _ in range(6):
k = np.random.randn(2, 2, 16).astype(np.float16)
v = np.random.randn(2, 2, 16).astype(np.float16)
cache.append_token("user_B", k, v)
for _ in range(3):
k = np.random.randn(2, 2, 16).astype(np.float16)
v = np.random.randn(2, 2, 16).astype(np.float16)
cache.append_token("user_C", k, v)
print("=== After initial generation ===")
print(cache.status())
Output:
=== After initial generation ===
Block pool: 14 free / 20 total (30.0% used)
Seq user_A: 10 tokens, 3 blocks, 2 wasted slots
Seq user_B: 6 tokens, 2 blocks, 2 wasted slots
Seq user_C: 3 tokens, 1 blocks, 1 wasted slots
All three users draw from one shared pool. User A holds 10 tokens in 3 blocks (the last block has 2 filled, 2 empty). Just 5 wasted slots across everyone — barely any waste at all.
What happens when User B wraps up and a new user comes in with a longer request?
cache.free_sequence("user_B")
print("=== After User B leaves ===")
print(cache.status())
cache.register_sequence("user_D")
for _ in range(15):
k = np.random.randn(2, 2, 16).astype(np.float16)
v = np.random.randn(2, 2, 16).astype(np.float16)
cache.append_token("user_D", k, v)
print("\n=== After User D generates 15 tokens ===")
print(cache.status())
Output:
=== After User B leaves ===
Block pool: 16 free / 20 total (20.0% used)
Seq user_A: 10 tokens, 3 blocks, 2 wasted slots
Seq user_C: 3 tokens, 1 blocks, 1 wasted slots
=== After User D generates 15 tokens ===
Block pool: 12 free / 20 total (40.0% used)
Seq user_A: 10 tokens, 3 blocks, 2 wasted slots
Seq user_C: 3 tokens, 1 blocks, 1 wasted slots
Seq user_D: 15 tokens, 4 blocks, 1 wasted slots
User B’s blocks went back to the pool. User D took blocks from the same pool — blocks that may sit in random spots in memory. No waste. The page table keeps track of where things are.
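The lookup itself is simple arithmetic. Here’s a minimal sketch with a hypothetical page table — the block IDs are chosen to mimic scattered physical placement, and the `(block_id, filled_slots)` pairs mirror what `PagedKVCache` stores:

```python
# A page table maps a logical token position to a (physical block, slot).
block_size = 4
page_table = [(7, 4), (2, 4), (9, 2)]  # (block_id, filled_slots), illustrative IDs

def locate(position):
    """Translate a logical token position into (physical block, slot)."""
    block_id, _ = page_table[position // block_size]
    return block_id, position % block_size

for pos in (0, 5, 9):
    block, slot = locate(pos)
    print(f"token {pos} -> block {block}, slot {slot}")
```

Logically the sequence is contiguous; physically its tokens live in blocks 7, 2, and 9 — and nothing about attention needs to know that.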
Measuring the Memory Savings: Flat vs. Paged
How much memory does paging really save? Let’s put real numbers on it. The next function pits both methods against each other using Llama-3 8B specs (32 layers, 8 KV heads, dim 128) with 8 user requests of mixed lengths. The flat method grabs the max space for each request. The paged method takes blocks only as needed.
import math
def compare_allocation(sequences, max_seq_length, block_size,
num_layers=32, num_kv_heads=8, head_dim=128):
"""Compare memory waste: contiguous vs. paged allocation."""
bytes_per_token = 2 * num_layers * num_kv_heads * head_dim * 2
contiguous_alloc = len(sequences) * max_seq_length * bytes_per_token
contiguous_used = sum(t for _, t in sequences) * bytes_per_token
contiguous_wasted = contiguous_alloc - contiguous_used
paged_blocks = sum(math.ceil(t / block_size) for _, t in sequences)
paged_alloc = paged_blocks * block_size * bytes_per_token
paged_used = contiguous_used
paged_wasted = paged_alloc - paged_used
return {
"cont_gb": contiguous_alloc / (1024**3),
"cont_waste": contiguous_wasted / (1024**3),
"cont_eff": contiguous_used / contiguous_alloc * 100,
"page_gb": paged_alloc / (1024**3),
"page_waste": paged_wasted / (1024**3),
"page_eff": paged_used / paged_alloc * 100,
}
sequences = [
("Short reply", 50), ("Medium reply", 200),
("Long reply", 800), ("Very long", 2000),
("Short Q&A", 30), ("Code gen", 500),
("Summary", 150), ("Chat session", 1500),
]
r = compare_allocation(sequences, max_seq_length=2048, block_size=16)
print("Memory comparison (Llama-3 8B, 8 concurrent sequences):")
print(f"\n{'Strategy':<14} {'Allocated':>12} {'Wasted':>12} {'Efficiency':>12}")
print("-" * 52)
print(f"{'Contiguous':<14} {r['cont_gb']:>10.2f} GB "
f"{r['cont_waste']:>10.2f} GB {r['cont_eff']:>10.1f}%")
print(f"{'Paged (16)':<14} {r['page_gb']:>10.2f} GB "
f"{r['page_waste']:>10.2f} GB {r['page_eff']:>10.1f}%")
saved = r['cont_gb'] - r['page_gb']
pct = (1 - r['page_gb'] / r['cont_gb']) * 100
print(f"\nMemory saved: {saved:.2f} GB ({pct:.0f}% reduction)")
Output:
Memory comparison (Llama-3 8B, 8 concurrent sequences):
Strategy Allocated Wasted Efficiency
----------------------------------------------------
Contiguous 2.00 GB 1.36 GB 31.9%
Paged (16) 0.64 GB 0.01 GB 99.1%
Memory saved: 1.36 GB (68% reduction)
Flat mode wastes 68% of its memory. Paged mode wastes about 1%. That 1.36 GB saved is enough to serve 2-3 more users at once. This is why vLLM can handle 2-4x more users than basic serving setups.
Note: Real vLLM goes even further. Our demo shows the core savings. The full vLLM adds copy-on-write for beam search (options share blocks until they split), prefix caching (shared prompt blocks across requests), and pausing low-value requests to free room. Same idea, more tricks.
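For a taste of how block sharing works, here’s a minimal reference-counting sketch — the bookkeeping that prefix caching and copy-on-write both rest on. The function names and block IDs are illustrative, not vLLM’s actual API:

```python
# A block is only returned to the pool when its last user releases it.
ref_counts = {}

def share_block(block_id):
    """Another sequence starts using this block (e.g., a shared system prompt)."""
    ref_counts[block_id] = ref_counts.get(block_id, 0) + 1

def release_block(block_id):
    """A sequence is done with the block; free it only when refcount hits zero."""
    ref_counts[block_id] -= 1
    if ref_counts[block_id] == 0:
        del ref_counts[block_id]
        print(f"block {block_id} returned to pool")

# Two requests share the same prompt block 3
share_block(3)
share_block(3)
release_block(3)  # first request finishes: block stays alive
release_block(3)  # second finishes: block actually freed
```

Copy-on-write adds one rule on top: if a sequence wants to modify a shared block, it first copies it into a fresh block and decrements the original’s count.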
How Token Eviction Affects Output Quality
When all blocks are used and a new request shows up, you have a choice: turn it away or drop tokens from a live request. Which tokens you drop matters — a lot.
I’ll walk you through three methods. FIFOEviction drops the oldest tokens first. LRUEviction drops the ones that haven’t been looked at lately. AttentionScoreEviction drops the tokens with the lowest attention weights — the ones the model barely cares about. The gap between these three is huge.
class EvictionPolicy:
"""Base class for token eviction."""
def select_tokens_to_evict(self, token_metadata, num_to_evict):
raise NotImplementedError
class FIFOEviction(EvictionPolicy):
"""First In, First Out — evict the oldest tokens."""
def select_tokens_to_evict(self, token_metadata, num_to_evict):
sorted_idx = sorted(
range(len(token_metadata)),
key=lambda i: token_metadata[i]["position"]
)
return sorted_idx[:num_to_evict]
class LRUEviction(EvictionPolicy):
"""Least Recently Used — evict tokens not attended recently."""
def select_tokens_to_evict(self, token_metadata, num_to_evict):
sorted_idx = sorted(
range(len(token_metadata)),
key=lambda i: token_metadata[i]["last_accessed"]
)
return sorted_idx[:num_to_evict]
class AttentionScoreEviction(EvictionPolicy):
"""Evict tokens with lowest cumulative attention scores."""
def select_tokens_to_evict(self, token_metadata, num_to_evict):
sorted_idx = sorted(
range(len(token_metadata)),
key=lambda i: token_metadata[i]["attention_score"]
)
return sorted_idx[:num_to_evict]
Which method keeps the right tokens? Here’s a test on a 50-token run that mirrors real attention patterns. System prompt tokens (first 5) get high attention. Recent tokens (last 10) get high attention. Three “anchor” tokens in the middle carry key facts. The rest get low attention. We drop 20 tokens and check what stays.
def simulate_eviction(policy, num_tokens=50, num_to_evict=20, seed=42):
"""Test eviction on a realistic attention distribution."""
np.random.seed(seed)
metadata = []
for i in range(num_tokens):
if i < 5:
attn = np.random.uniform(0.7, 1.0)
ttype = "system"
elif i >= num_tokens - 10:
attn = np.random.uniform(0.5, 0.9)
ttype = "recent"
elif i in [15, 25, 35]:
attn = np.random.uniform(0.6, 0.85)
ttype = "anchor"
else:
attn = np.random.uniform(0.01, 0.3)
ttype = "regular"
metadata.append({
"position": i,
"last_accessed": max(0, i + np.random.randint(-5, 5)),
"attention_score": round(attn, 3),
"token_type": ttype,
})
evicted = policy.select_tokens_to_evict(metadata, num_to_evict)
kept = [metadata[i]["token_type"]
for i in range(num_tokens) if i not in evicted]
return kept
policies = {
"FIFO": FIFOEviction(),
"LRU": LRUEviction(),
"Attention Score": AttentionScoreEviction(),
}
print(f"{'Policy':<18} {'System kept':>12} {'Anchors kept':>14} {'Recent kept':>12}")
print("-" * 58)
for name, policy in policies.items():
kept = simulate_eviction(policy)
print(f"{name:<18} {kept.count('system'):>8}/5 "
f"{kept.count('anchor'):>10}/3 "
f"{kept.count('recent'):>8}/10")
Output:
Policy System kept Anchors kept Recent kept
----------------------------------------------------------
FIFO 0/5 2/3 10/10
LRU 0/5 3/3 10/10
Attention Score 5/5 3/3 10/10
FIFO blindly removes the 20 oldest tokens — wiping out the whole system prompt. The model “forgets” its rules. LRU keeps the anchors but still loses the system prompt (those tokens were “used” early and never again). Only the attention-score method keeps what counts: all 5 system tokens, all 3 anchors, all 10 recent tokens.
Warning: Never use FIFO for chat serving. System prompt tokens get the highest attention but sit at the oldest spots. FIFO drops them first, making the model lose its persona and rules. Use attention-based dropping like H2O (Heavy Hitter Oracle) to keep what the model needs most.
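Here’s a simplified sketch of the H2O idea — keep a recent window unconditionally, then spend the rest of the memory budget on the highest-scoring older tokens. This is a toy selection function with made-up scores, not the real H2O implementation:

```python
# H2O-style keep set: recent window + heavy hitters among older tokens.
def h2o_keep(attention_scores, budget, window=8):
    """Return the positions to keep: last `window` tokens plus the
    highest-attention older tokens, up to `budget` total."""
    n = len(attention_scores)
    recent = set(range(max(0, n - window), n))
    older = sorted(
        (i for i in range(n) if i not in recent),
        key=lambda i: attention_scores[i],
        reverse=True,
    )
    heavy = set(older[: max(0, budget - len(recent))])
    return sorted(recent | heavy)

# High scores at positions 0, 1 (system-prompt-like) and 4 (anchor-like)
scores = [0.9, 0.8, 0.1, 0.05, 0.7, 0.1, 0.05, 0.1] + [0.3] * 8
print(h2o_keep(scores, budget=12, window=8))
```

The heavy hitters (positions 0, 1, and 4) survive regardless of age, while the low-attention filler tokens are the ones that go.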
Putting It All Together: A Full Server Demo
Let’s tie it all up. This demo runs a server with 30 blocks serving four requests one after the other. When memory gets tight, it drops tokens from the longest request. This is how real frameworks handle memory pressure — they don’t just crash.
The code takes requests one by one. For each, it checks for free blocks. If there aren’t enough, it drops tokens from the longest active request to get blocks back.
def full_serving_simulation():
"""Simulate an LLM server with paged KV cache and eviction."""
allocator = BlockAllocator(
num_blocks=30, block_size=4,
num_layers=2, num_kv_heads=2, head_dim=16
)
cache = PagedKVCache(allocator)
requests = [
("req_1", 25), # needs 7 blocks
("req_2", 40), # needs 10 blocks
("req_3", 35), # needs 9 blocks
("req_4", 20), # needs 5 blocks — triggers eviction
]
print("=== LLM Serving Simulation ===\n")
for req_name, num_tokens in requests:
print(f"--- {req_name} ({num_tokens} tokens) ---")
blocks_needed = math.ceil(num_tokens / allocator.block_size)
if allocator.num_free() < blocks_needed:
deficit = blocks_needed - allocator.num_free()
print(f" Need {blocks_needed} blocks, have {allocator.num_free()}")
print(f" Must free {deficit} block(s) via eviction...")
longest = max(cache.seq_lengths, key=cache.seq_lengths.get)
page_table = cache.page_tables[longest]
blocks_to_free = page_table[:deficit]
tokens_removed = sum(f for _, f in blocks_to_free)
for bid, _ in blocks_to_free:
allocator.free(bid)
cache.page_tables[longest] = page_table[deficit:]
cache.seq_lengths[longest] -= tokens_removed
print(f" Freed {deficit} block(s) ({tokens_removed} tokens) "
f"from {longest}")
cache.register_sequence(req_name)
for _ in range(num_tokens):
k = np.random.randn(2, 2, 16).astype(np.float16)
v = np.random.randn(2, 2, 16).astype(np.float16)
cache.append_token(req_name, k, v)
print(cache.status())
print()
full_serving_simulation()
Output:
=== LLM Serving Simulation ===
--- req_1 (25 tokens) ---
Block pool: 23 free / 30 total (23.3% used)
Seq req_1: 25 tokens, 7 blocks, 3 wasted slots
--- req_2 (40 tokens) ---
Block pool: 13 free / 30 total (56.7% used)
Seq req_1: 25 tokens, 7 blocks, 3 wasted slots
Seq req_2: 40 tokens, 10 blocks, 0 wasted slots
--- req_3 (35 tokens) ---
Block pool: 4 free / 30 total (86.7% used)
Seq req_1: 25 tokens, 7 blocks, 3 wasted slots
Seq req_2: 40 tokens, 10 blocks, 0 wasted slots
Seq req_3: 35 tokens, 9 blocks, 1 wasted slots
--- req_4 (20 tokens) ---
Need 5 blocks, have 4
Must free 1 block(s) via eviction...
Freed 1 block(s) (4 tokens) from req_2
Block pool: 0 free / 30 total (100.0% used)
Seq req_1: 25 tokens, 7 blocks, 3 wasted slots
Seq req_2: 36 tokens, 9 blocks, 0 wasted slots
Seq req_3: 35 tokens, 9 blocks, 1 wasted slots
Seq req_4: 20 tokens, 5 blocks, 0 wasted slots
When req_4 shows up, only 4 free blocks remain but 5 are needed. The system drops 4 tokens (1 block) from req_2 — the longest active request — and takes the block back. All four requests now fit at 100% usage.
Exercise 1: Calculate KV Cache for a Custom Model
Try It Yourself
You’re deploying a custom model: 24 layers, 6 KV heads (GQA), head dimension 96, FP16 precision. Write a function that returns the KV cache size in GB, then test it at batch_size=4, seq_length=8192.
def calculate_kv_cache_gb(seq_length, batch_size):
num_layers = 24
num_kv_heads = 6
head_dim = 96
bytes_per_param = 2
# TODO: Calculate total bytes and convert to GB
pass
result = calculate_kv_cache_gb(8192, 4)
print(f"KV cache size: {result:.4f} GB")
Hints:
1. Formula: 2 * num_layers * num_kv_heads * head_dim * bytes_per_param * seq_length * batch_size
2. Divide by 1024**3 to get GB.
Exercise 2: Implement a Sliding Window KV Cache
Try It Yourself
Many models use a sliding window for the KV cache. Instead of keeping all tokens, they keep only the most recent N. When the window fills, the oldest token drops.
Complete the append method. It should add a new token and drop the oldest when the window is full.
class SlidingWindowCache:
def __init__(self, window_size, feature_dim):
self.window_size = window_size
self.feature_dim = feature_dim
self.keys = np.zeros((window_size, feature_dim), dtype=np.float32)
self.values = np.zeros((window_size, feature_dim), dtype=np.float32)
self.count = 0
def append(self, new_key, new_value):
"""Add a token. Drop the oldest if window is full."""
# TODO: Your code here
pass
def get_active(self):
n = min(self.count, self.window_size)
return self.keys[:n], self.values[:n]
cache = SlidingWindowCache(window_size=4, feature_dim=3)
for i in range(7):
k = np.array([i, i*10, i*100], dtype=np.float32)
v = np.array([i+1, (i+1)*10, (i+1)*100], dtype=np.float32)
cache.append(k, v)
active_k, _ = cache.get_active()
print(f"After token {i}: keys = {active_k[:, 0].tolist()}")
Hints:
1. When count < window_size, insert at position count.
2. When full, shift left by 1 (drop index 0) and insert at the last position.
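One possible implementation of `append`, shown here as a complete runnable class so you can compare against your own attempt:

```python
import numpy as np

class SlidingWindowCache:
    def __init__(self, window_size, feature_dim):
        self.window_size = window_size
        self.keys = np.zeros((window_size, feature_dim), dtype=np.float32)
        self.values = np.zeros((window_size, feature_dim), dtype=np.float32)
        self.count = 0

    def append(self, new_key, new_value):
        if self.count < self.window_size:
            # Window not full yet: insert at the next free slot
            self.keys[self.count] = new_key
            self.values[self.count] = new_value
        else:
            # Window full: shift left by one (dropping the oldest),
            # then insert at the last position
            self.keys[:-1] = self.keys[1:]
            self.values[:-1] = self.values[1:]
            self.keys[-1] = new_key
            self.values[-1] = new_value
        self.count += 1

    def get_active(self):
        n = min(self.count, self.window_size)
        return self.keys[:n], self.values[:n]

cache = SlidingWindowCache(window_size=4, feature_dim=3)
for i in range(7):
    cache.append(np.full(3, i, dtype=np.float32),
                 np.full(3, i + 1, dtype=np.float32))
active_k, _ = cache.get_active()
print(active_k[:, 0].tolist())  # [3.0, 4.0, 5.0, 6.0] — tokens 0-2 dropped
```

The shift-left copy is O(window) per token; a production implementation would use a ring buffer index instead, but the sliding behavior is identical.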
Exercise 3: Build a Memory Budget Planner
Try It Yourself
You’re planning a deployment. GPU has fixed VRAM, model weights take a known amount. How many concurrent sequences can you serve?
Complete max_concurrent_sequences — it calculates available VRAM after weights and safety margin, divides by per-sequence KV cache, and returns the count.
def max_concurrent_sequences(
total_vram_gb,
model_weights_gb,
num_layers,
num_kv_heads,
head_dim,
max_seq_length,
safety_margin=0.1,
bytes_per_param=2
):
"""Max concurrent sequences that fit in remaining VRAM."""
# TODO: available memory, per-sequence cache, integer divide
pass
result = max_concurrent_sequences(
total_vram_gb=24,
model_weights_gb=16,
num_layers=32,
num_kv_heads=8,
head_dim=128,
max_seq_length=4096
)
print(f"Max concurrent sequences: {result}")
Hints:
1. Available VRAM = total_vram_gb * (1 - safety_margin) - model_weights_gb
2. Per-sequence cache (GB): use the formula, divide by 1024**3.
3. Integer division — you can’t serve half a sequence.
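One possible solution, following the three hints directly:

```python
def max_concurrent_sequences(total_vram_gb, model_weights_gb, num_layers,
                             num_kv_heads, head_dim, max_seq_length,
                             safety_margin=0.1, bytes_per_param=2):
    """Max concurrent sequences that fit in remaining VRAM."""
    # VRAM left after the safety margin and the model weights
    available_gb = total_vram_gb * (1 - safety_margin) - model_weights_gb
    # KV cache for one full-length sequence
    per_seq_bytes = (2 * num_layers * num_kv_heads * head_dim
                     * bytes_per_param * max_seq_length)
    per_seq_gb = per_seq_bytes / 1024**3
    # You can't serve half a sequence
    return int(available_gb // per_seq_gb)

print(max_concurrent_sequences(
    total_vram_gb=24, model_weights_gb=16, num_layers=32,
    num_kv_heads=8, head_dim=128, max_seq_length=4096))  # 11
```

With these numbers, 5.6 GB remains after weights and margin, and each full 4K sequence needs 0.5 GB of cache, so 11 sequences fit.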
Common Mistakes with KV Cache
I’ve seen each of these cause real headaches in production. All are easy to dodge once you know the numbers.
Mistake 1: Leaving the KV cache out of GPU memory math
# BAD: Only counts model weights
model_gb = 14 # 7B params in FP16
print(f"Need {model_gb} GB GPU") # WRONG for long contexts
# GOOD: Weights + KV cache + overhead
kv_gb = calc_kv_cache_memory(
num_layers=32, num_kv_heads=32, head_dim=128, seq_length=4096
) / (1024**3)
overhead_gb = 2 # CUDA kernels, activations
total = model_gb + kv_gb + overhead_gb
print(f"Need {total:.0f} GB GPU for 4K context")
Need 18 GB GPU for 4K context
Mistake 2: Grabbing max space when most requests are short
Most chat requests use under 500 tokens. Reserving 128K slots per request wastes 99.6% of that memory. Use paged allocation instead.
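The 99.6% figure is easy to verify:

```python
reserved_slots = 131072   # a 128K-token preallocation per request
typical_used = 500        # a typical chat request

waste = 1 - typical_used / reserved_slots
print(f"{waste:.1%} of the reserved cache is wasted")  # 99.6%
```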
Mistake 3: Not thinking about batch size
The KV cache is per-request. Batching 8 requests means 8x the cache memory. I’ve seen people test with batch_size=1 and then panic when batch_size=16 crashes.
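To make the scaling concrete, here is a quick sketch assuming a Llama-2-7B-style config (32 layers, 32 KV heads, head dimension 128, FP16) at a 4K context; the function name is illustrative:

```python
def kv_gb_per_request(num_layers=32, num_kv_heads=32, head_dim=128,
                      seq_length=4096, bytes_per_param=2):
    # KV cache for a single request at full context length
    return (2 * num_layers * num_kv_heads * head_dim
            * bytes_per_param * seq_length) / 1024**3

# The cache is per-request, so it scales linearly with batch size
for batch_size in (1, 8, 16):
    total = kv_gb_per_request() * batch_size
    print(f"batch_size={batch_size:2d}: {total:5.1f} GB of KV cache")
```

The jump from 2 GB at batch 1 to 32 GB at batch 16 is exactly the kind of surprise that crashes a server tested only at batch 1.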
When NOT to Bother with KV Cache Tricks
Not every setup needs smart cache handling. Here’s when to keep things basic.
One user, short prompts. Running a local LLM with prompts under 2K tokens? A simple buffer works fine; the wasted memory is too small to matter.
Batch jobs on short docs. Working through a pile of short texts one at a time? The cache gets made and thrown out per request. No waste, no sharing needed.
Models that are already lean. Models using GQA have small caches by design. Llama-3 8B’s cache is 4x smaller than Llama-2 7B’s. Sometimes the model itself fixes the problem for you.
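That 4x figure follows directly from the KV head counts. A quick check, using the published configs (both models have 32 layers and head dimension 128; Llama-2 7B has 32 KV heads, Llama-3 8B has 8):

```python
def kv_bytes_per_token(num_layers, num_kv_heads, head_dim, bytes_per_param=2):
    # Per-token KV cache cost: 2x for keys and values
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_param

llama2_7b = kv_bytes_per_token(num_layers=32, num_kv_heads=32, head_dim=128)
llama3_8b = kv_bytes_per_token(num_layers=32, num_kv_heads=8, head_dim=128)

print(f"Llama-2 7B: {llama2_7b} bytes/token")  # 524288 (0.5 MB)
print(f"Llama-3 8B: {llama3_8b} bytes/token")  # 131072
print(f"Ratio: {llama2_7b / llama3_8b:.0f}x")  # 4x
```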
When to invest: Many users at once on the same server. Long-context apps like RAG with big text chunks. Real-time systems where every bit of latency counts.
Summary
The KV cache holds on to past keys and values so your LLM can skip repeat work. It trades memory for speed — and with real models, the cost adds up fast.
Llama-2 7B at 32K context burns 16 GB just on the cache — more than its own weights. GQA trims this down by sharing KV heads across groups. That’s why newer models like Llama-3 lean on it.
How you handle memory shapes how many users you can serve at once. A naive preallocated layout wastes 60-70% of GPU memory. PagedAttention solves this by carving the cache into small blocks claimed on demand — the same core idea your OS uses for virtual memory.
When the cache runs full, your drop rule picks what the model loses. FIFO wipes out system prompts. Attention-based dropping keeps the tokens the model depends on.
Practice exercise: A startup wants to serve Llama-3 70B to 50 concurrent users at 8K context on 4x A100 80GB GPUs. Calculate: (1) total KV cache needed, (2) whether the GPUs handle it, (3) maximum context on 2x A100s instead.
FAQ
How does the KV cache differ from model weights?
Weights are fixed after training. The KV cache is live — it grows with each new token and is unique to each request. Weights are shared across all users; the cache belongs to one request.
Can you shrink the KV cache with lower precision?
Yes. FP8 or INT4 cache values cut memory by 2-4x with little quality loss. vLLM supports FP8 cache out of the box. The cost: slightly less precise output on very long runs where small errors add up.
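NumPy has no FP8 type, but the idea behind low-precision cache storage can be shown with INT8. This is a minimal sketch using symmetric per-tensor absmax scaling, a simplification of what real quantized-cache kernels do:

```python
import numpy as np

def quantize_int8(x):
    """Store the tensor as int8 plus one floating-point scale."""
    scale = np.abs(x).max() / 127.0
    q = np.round(x / scale).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    return q.astype(np.float32) * scale

rng = np.random.default_rng(0)
keys = rng.standard_normal((16, 128)).astype(np.float32)

q, scale = quantize_int8(keys)
restored = dequantize(q, scale)

# 8192 -> 2048 bytes: 4x vs FP32, 2x vs an FP16 cache
print(f"FP32 store: {keys.nbytes} bytes, INT8 store: {q.nbytes} bytes")
print(f"Max reconstruction error: {np.abs(keys - restored).max():.4f}")
```

The reconstruction error per value is bounded by half the scale, which is why quality loss stays small; it is on very long generations that these small errors can accumulate.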
What is Multi-Query Attention and how does it help?
Normal multi-head attention gives each head its own keys and values. Multi-Query Attention (MQA) shares one set across all heads, cutting cache size by up to 32x. Grouped Query Attention (GQA) is the middle path: groups of heads share. Llama-3 uses GQA with 8 KV heads for 32 query heads — a 4x cut.
Does the KV cache change model output?
The basic cache stores exact values, so no. But token dropping and lower precision do involve tradeoffs. Dropping tokens means the model loses context. Attention-based dropping keeps quality higher than FIFO because it holds on to the tokens the model actually needs.
How does prefix caching work?
When many requests share the same system prompt, their KV cache for that shared part is the same. Prefix caching works out those values once and shares them. This saves both compute and memory. vLLM spots shared prefixes on its own.
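A toy version of the idea, treating requests as token lists and using the prefix itself as the lookup key (real systems like vLLM hash fixed-size blocks of token IDs; `compute_kv` here is a stand-in for the real per-token K/V computation):

```python
PREFIX_STORE = {}  # maps a token-tuple prefix -> its cached KV entries

def compute_kv(tokens):
    """Stand-in for the actual key/value computation."""
    return [f"kv({t})" for t in tokens]

def get_kv_with_prefix_cache(tokens, prefix_len):
    prefix = tuple(tokens[:prefix_len])
    if prefix not in PREFIX_STORE:
        PREFIX_STORE[prefix] = compute_kv(prefix)  # computed only once
    # The shared prefix is reused; only the unique suffix is computed
    return PREFIX_STORE[prefix] + compute_kv(tokens[prefix_len:])

system_prompt = [1, 2, 3, 4]
kv_a = get_kv_with_prefix_cache(system_prompt + [10, 11], prefix_len=4)
kv_b = get_kv_with_prefix_cache(system_prompt + [20], prefix_len=4)
print(len(PREFIX_STORE))  # 1 — the shared prefix was computed once
```

In a real server the stored entries are KV cache blocks, so the sharing saves memory as well as compute.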
References
- Kwon, W., et al. — “Efficient Memory Management for Large Language Model Serving with PagedAttention.” SOSP 2023. Link
- vLLM documentation — PagedAttention. Link
- Raschka, S. — “Understanding and Coding the KV Cache in LLMs from Scratch.” (2025). Link
- Ainslie, J., et al. — “GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints.” EMNLP 2023. Link
- Zhang, Z., et al. — “H2O: Heavy-Hitter Oracle for Efficient Generative Inference of Large Language Models.” NeurIPS 2023. Link
- Shazeer, N. — “Fast Transformer Decoding: One Write-Head is All You Need.” (2019). Link
- NVIDIA Technical Blog — “Accelerate Large-Scale LLM Inference and KV Cache Offload.” Link
- Meta AI — Llama 3 Model Card. Link
