MHA vs GQA vs MQA: Attention & KV Cache Guide
Build MHA, GQA, and MQA attention from scratch in NumPy. Calculate KV cache VRAM for Llama 3 70B, Mistral 7B, and any model with a reusable calculator.
This post has interactive code — click ‘Run’ or press Ctrl+Enter on any code block to execute it directly in your browser.
Build all three attention types from scratch in NumPy and figure out exactly how much GPU memory the KV cache eats for real models like Llama 3 70B and Mistral 7B.
Llama 3 70B has 64 query heads but only 8 key-value heads. Why not keep all 64? Because the KV cache would eat your entire GPU alive.
Every token you generate stores a key vector and a value vector for every layer. That adds up. With standard attention, a 70B model at 128K context needs hundreds of gigabytes just for this cache. Grouped-Query Attention and Multi-Query Attention fix the problem by sharing KV heads across queries. The savings are huge.
This article builds all three variants from scratch. Pure NumPy. You’ll see how each one works, then use a calculator to check the VRAM cost for any model config you want.
Before we build anything, here’s how the three variants fit together.
Standard Multi-Head Attention gives every query head its own key and value head. Full power. But you cache K and V for every head at every layer. For a 70B model with 64 heads across 80 layers, that’s enormous.
Multi-Query Attention goes the other way. All query heads share one K head and one V head. The cache shrinks by a factor equal to the head count. But quality drops — one KV pair can’t learn all the patterns that 64 pairs could.
Grouped-Query Attention splits the difference. You group the query heads. Each group shares one KV pair. Llama 3 uses 8 groups for 64 queries. You get most of MHA’s quality with most of MQA’s savings.
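Before any real attention code, the sharing scheme itself fits in a few lines. This toy sketch (head counts chosen just for illustration) prints which KV head serves each query head under the three schemes:

```python
# Toy sketch (not real attention): which KV head serves each query head.
n_heads = 8

for name, n_kv_heads in [("MHA", 8), ("GQA", 4), ("MQA", 1)]:
    queries_per_kv = n_heads // n_kv_heads
    mapping = [q // queries_per_kv for q in range(n_heads)]
    print(f"{name} ({n_kv_heads} KV heads): {mapping}")
```

MHA maps every query head to its own KV head, MQA maps all eight to KV head 0, and GQA groups them two per KV head. The real implementations below use exactly this indexing.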
We’ll code each variant, compare their caches side by side, and then run a full VRAM calculator on real configs. All code runs in your browser with Pyodide.
Prerequisites
- Python version: 3.9+
- Required libraries: NumPy (1.24+)
- Install: `pip install numpy`
- Time to complete: 25-30 minutes
What Is Multi-Head Attention?
Multi-Head Attention (MHA) is the original design from the “Attention Is All You Need” paper. Every head gets its own Query, Key, and Value weights. With \(h\) heads, you have \(h\) sets of Q, K, V matrices.
Here’s a concrete picture. A model with d_model = 64 and n_heads = 4 splits each token into 4 slices of size 16. Every head runs attention on its own slice. Every head stores its own K and V in the cache.
The formula for a single head:
\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]
Where:
– \(Q\) = query matrix (what each token looks for)
– \(K\) = key matrix (what each token offers)
– \(V\) = value matrix (what each token carries)
– \(d_k\) = head size (scaling stops the dot products from blowing up)
The function below takes a token sequence, projects it into Q, K, V for each head, runs scaled dot-product attention per head, and joins the results. We use np.random.seed(42) so every run gives the same numbers. The softmax uses the max-subtraction trick for number safety.
```python
import numpy as np

np.random.seed(42)

def scaled_dot_product_attention(Q, K, V):
    """Single-head scaled dot-product attention."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights

def multi_head_attention(X, W_q, W_k, W_v, W_o, n_heads):
    """Multi-Head Attention: each head has its own Q, K, V."""
    seq_len, d_model = X.shape
    d_head = d_model // n_heads
    Q = X @ W_q  # (seq_len, d_model)
    K = X @ W_k
    V = X @ W_v
    # Reshape into heads: (n_heads, seq_len, d_head)
    Q = Q.reshape(seq_len, n_heads, d_head).transpose(1, 0, 2)
    K = K.reshape(seq_len, n_heads, d_head).transpose(1, 0, 2)
    V = V.reshape(seq_len, n_heads, d_head).transpose(1, 0, 2)
    outputs = []
    all_weights = []
    for i in range(n_heads):
        out, w = scaled_dot_product_attention(Q[i], K[i], V[i])
        outputs.append(out)
        all_weights.append(w)
    concat = np.concatenate(outputs, axis=-1)  # (seq_len, d_model)
    result = concat @ W_o
    return result, all_weights, K, V
```
K and V each end up with shape (n_heads, seq_len, d_head). Every head stores its own key and value vectors. That’s the root of the memory problem — and what GQA and MQA fix.
Let’s feed it 5 tokens with 4 heads. Watch the cache shapes.
```python
seq_len = 5
d_model = 64
n_heads = 4
d_head = d_model // n_heads  # 16

X = np.random.randn(seq_len, d_model)
W_q = np.random.randn(d_model, d_model) * 0.1
W_k = np.random.randn(d_model, d_model) * 0.1
W_v = np.random.randn(d_model, d_model) * 0.1
W_o = np.random.randn(d_model, d_model) * 0.1

output, attn_weights, K_cache, V_cache = multi_head_attention(
    X, W_q, W_k, W_v, W_o, n_heads
)

print(f"Input shape: {X.shape}")
print(f"Output shape: {output.shape}")
print(f"K cache shape: {K_cache.shape}")
print(f"V cache shape: {V_cache.shape}")
print(f"KV cache per layer: {K_cache.nbytes + V_cache.nbytes} bytes")
```
Output:
```python
Input shape: (5, 64)
Output shape: (5, 64)
K cache shape: (4, 5, 16)
V cache shape: (4, 5, 16)
KV cache per layer: 5120 bytes
```
The K and V caches hold n_heads * seq_len * d_head values each. That’s 4 * 5 * 16 = 320 floats per cache. Two caches, 8 bytes per float64 = 5,120 bytes. Tiny here, but it blows up at real scale.
Key Insight: In standard MHA, KV cache scales with the head count. This is the bottleneck. GQA and MQA fix it by cutting the number of KV heads while keeping all the query heads.
Quick check — can you predict what happens to the cache if we double the head count to 8 while keeping d_head fixed at 16 (so d_model grows to 128)? The cache doubles, because MHA gives each head its own K and V. That’s the linear scaling that breaks things at production scale, since real models fix d_head (usually 128) and scale the head count.
What Is Multi-Query Attention?
What if all those query heads shared a single set of keys and values?
That’s Multi-Query Attention. Noam Shazeer proposed it in 2019. You still have 4 heads asking 4 different questions (Q). But they all look at the same answers (K, V). One K, one V — shared across every head.
The cache drops from n_heads * seq_len * d_head to seq_len * d_head. For a model with 64 heads, that’s 64x smaller. The price? One KV pair can’t learn patterns as well as 64 could. Quality takes a hit.
Here’s the code. The key change: W_k_shared and W_v_shared project to d_head instead of d_model. That gives you one K and one V, not one per head.
```python
def multi_query_attention(X, W_q, W_k_shared, W_v_shared, W_o, n_heads):
    """Multi-Query Attention: all heads share one K and one V."""
    seq_len, d_model = X.shape
    d_head = d_model // n_heads
    Q = X @ W_q         # full projection -- all heads
    K = X @ W_k_shared  # single head only
    V = X @ W_v_shared  # single head only
    Q = Q.reshape(seq_len, n_heads, d_head).transpose(1, 0, 2)
    # K and V stay flat -- shared by every head
    outputs = []
    all_weights = []
    for i in range(n_heads):
        out, w = scaled_dot_product_attention(Q[i], K, V)
        outputs.append(out)
        all_weights.append(w)
    concat = np.concatenate(outputs, axis=-1)
    result = concat @ W_o
    return result, all_weights, K, V
```
See how K and V don’t get reshaped into heads? They stay as (seq_len, d_head) — one flat matrix that every query head uses.
Let’s compare the cache sizes.
```python
W_k_shared = np.random.randn(d_model, d_head) * 0.1
W_v_shared = np.random.randn(d_model, d_head) * 0.1

output_mqa, weights_mqa, K_mqa, V_mqa = multi_query_attention(
    X, W_q, W_k_shared, W_v_shared, W_o, n_heads
)

print(f"MQA output shape: {output_mqa.shape}")
print(f"MQA K cache shape: {K_mqa.shape}")
print(f"MQA V cache shape: {V_mqa.shape}")
print(f"MQA KV cache: {K_mqa.nbytes + V_mqa.nbytes} bytes")
print(f"MHA KV cache: {K_cache.nbytes + V_cache.nbytes} bytes")
print(f"Reduction: {(K_cache.nbytes + V_cache.nbytes) / (K_mqa.nbytes + V_mqa.nbytes):.0f}x")
```
Output:
```python
MQA output shape: (5, 64)
MQA K cache shape: (5, 16)
MQA V cache shape: (5, 16)
MQA KV cache: 1280 bytes
MHA KV cache: 5120 bytes
Reduction: 4x
```
A 4x reduction with 4 heads. Scale to Llama 3’s 64 heads, and you’d get 64x. That’s the gap between one GPU and a whole rack.
Warning: MQA cuts quality. One KV head can’t do what 64 could. For small models, the drop is easy to see. That’s why most real models use GQA now — it keeps the quality while still saving most of the memory.
What Is Grouped-Query Attention?
Here’s where it gets practical. What if you don’t go all the way to one KV head, but don’t keep all of them either?
GQA picks a number of KV groups. Each group of query heads shares one K and one V. Ainslie et al. showed in 2023 that GQA matches MHA quality while getting close to MQA speed. It won the industry.
Llama 3 70B: 64 query heads, 8 KV heads. Eight queries share each KV head. The cache is 8x smaller than MHA. Mistral 7B: 32 query heads, 8 KV heads. Four queries per group. Both use GQA.
The code groups query heads and points each group to its shared KV head. The n_kv_heads knob controls how many groups you have.
```python
def grouped_query_attention(X, W_q, W_k_gqa, W_v_gqa, W_o, n_heads, n_kv_heads):
    """Grouped-Query Attention: groups of query heads share KV heads."""
    seq_len, d_model = X.shape
    d_head = d_model // n_heads
    queries_per_group = n_heads // n_kv_heads
    Q = X @ W_q
    K = X @ W_k_gqa  # projects to n_kv_heads * d_head
    V = X @ W_v_gqa
    Q = Q.reshape(seq_len, n_heads, d_head).transpose(1, 0, 2)
    K = K.reshape(seq_len, n_kv_heads, d_head).transpose(1, 0, 2)
    V = V.reshape(seq_len, n_kv_heads, d_head).transpose(1, 0, 2)
    outputs = []
    all_weights = []
    for i in range(n_heads):
        kv_idx = i // queries_per_group
        out, w = scaled_dot_product_attention(Q[i], K[kv_idx], V[kv_idx])
        outputs.append(out)
        all_weights.append(w)
    concat = np.concatenate(outputs, axis=-1)
    result = concat @ W_o
    return result, all_weights, K, V
```
The key line: kv_idx = i // queries_per_group. Queries 0-1 share KV head 0. Queries 2-3 share KV head 1. Each KV head serves a cluster.
Let’s run it with 4 query heads and 2 KV heads.
```python
n_kv_heads = 2
kv_dim = n_kv_heads * d_head  # 2 * 16 = 32

W_k_gqa = np.random.randn(d_model, kv_dim) * 0.1
W_v_gqa = np.random.randn(d_model, kv_dim) * 0.1

output_gqa, weights_gqa, K_gqa, V_gqa = grouped_query_attention(
    X, W_q, W_k_gqa, W_v_gqa, W_o, n_heads, n_kv_heads
)

print(f"GQA output shape: {output_gqa.shape}")
print(f"GQA K cache shape: {K_gqa.shape}")
print(f"GQA V cache shape: {V_gqa.shape}")
print(f"GQA KV cache: {K_gqa.nbytes + V_gqa.nbytes} bytes")
print(f"MHA KV cache: {K_cache.nbytes + V_cache.nbytes} bytes")
print(f"MQA KV cache: {K_mqa.nbytes + V_mqa.nbytes} bytes")
print(f"GQA vs MHA: {(K_cache.nbytes + V_cache.nbytes) / (K_gqa.nbytes + V_gqa.nbytes):.1f}x smaller")
```
Output:
```python
GQA output shape: (5, 64)
GQA K cache shape: (2, 5, 16)
GQA V cache shape: (2, 5, 16)
GQA KV cache: 2560 bytes
MHA KV cache: 5120 bytes
MQA KV cache: 1280 bytes
GQA vs MHA: 2.0x smaller
```
Right in the middle. Scale this to Llama 3’s 8 KV heads out of 64 total, and you get an 8x cut.
Key Insight: GQA is the general case. Set `n_kv_heads = n_heads` and you get MHA. Set it to 1 and you get MQA. Any value in between is GQA. One knob controls the whole spectrum.
Side-by-Side: What Changes, What Stays
Before the VRAM math, let’s put all three next to each other. Same input, same 4 query heads — only the KV head count changes.
| Property | MHA | GQA (2 groups) | MQA |
|---|---|---|---|
| Query heads | 4 | 4 | 4 |
| KV heads | 4 | 2 | 1 |
| Queries per KV head | 1 | 2 | 4 |
| K cache shape | (4, 5, 16) | (2, 5, 16) | (5, 16) |
| KV cache bytes | 5,120 | 2,560 | 1,280 |
| Cache reduction | 1x | 2x | 4x |
More KV heads = more memory, more quality. Fewer = less memory, faster decoding, but some quality loss.
Predict the output: If you used GQA with n_kv_heads = 4 on our toy model (which also has n_heads = 4), what would the K cache shape be? Think about it before reading on.
Answer: (4, 5, 16) — same as MHA. Because with 4 KV heads and 4 query heads, every query gets its own KV head. GQA with n_kv_heads = n_heads IS MHA.
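You can reproduce the table's byte counts from cache arithmetic alone, no attention code required. A standalone check using the same toy dimensions (seq_len 5, d_head 16, float64 at 8 bytes):

```python
# Cache bytes = 2 tensors (K and V) x KV heads x tokens x head dim x bytes/float
seq_len, d_head, bytes_per_float = 5, 16, 8

for name, n_kv in [("MHA", 4), ("GQA", 2), ("MQA", 1)]:
    cache = 2 * n_kv * seq_len * d_head * bytes_per_float
    print(f"{name}: {cache} bytes")
```

This prints 5120, 2560, and 1280, matching the KV cache row of the table. Only the KV head count changes between the three lines.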
Let’s check the attention patterns from head 0 across all three variants.
```python
print("=== Attention weights for head 0 ===\n")
print("MHA head 0 (own KV):")
print(np.round(attn_weights[0], 3))
print("\nMQA head 0 (shared KV):")
print(np.round(weights_mqa[0], 3))
print("\nGQA head 0 (group-shared KV):")
print(np.round(weights_gqa[0], 3))
```
The patterns differ because each variant has different K and V. MHA’s head 0 has its own KV pair. MQA shares one pair across all heads. GQA shares a pair with just one other head.
Tip: Picking a variant: Under 7B parameters, MHA works fine. Above 13B, GQA is standard. MQA is mostly historical now — GQA gives nearly the same speed with much less quality loss.
The KV Cache Formula
How much GPU memory does the cache cost? Here’s the formula.
\[\text{KV Cache} = 2 \times n_{layers} \times n_{kv\_heads} \times d_{head} \times seq_{len} \times bytes_{per\_param}\]
Where:
– \(2\) = one K and one V (same size)
– \(n_{layers}\) = transformer layers (80 for Llama 3 70B)
– \(n_{kv\_heads}\) = KV heads per layer (depends on the variant)
– \(d_{head}\) = size of each head (128 for most new models)
– \(seq_{len}\) = how many tokens you’ve seen so far
– \(bytes_{per\_param}\) = 2 for FP16/BF16, 4 for FP32
In code:
```python
def kv_cache_bytes(n_layers, n_kv_heads, d_head, seq_len, bytes_per_param=2):
    """KV cache size in bytes."""
    return 2 * n_layers * n_kv_heads * d_head * seq_len * bytes_per_param

def bytes_to_gb(b):
    return b / (1024 ** 3)
```
Two lines. That’s the whole calculator. Let’s point it at real models.
KV Cache for Real Model Configs
Here’s where it gets real. The code below checks four models at four context lengths. The configs come from each model’s published specs.
Notice how Llama 3 70B uses only 8 KV heads despite having 64 query heads. That 8x cut is what makes long context possible.
```python
models = {
    "Llama 3 70B (GQA)": {
        "n_layers": 80, "n_kv_heads": 8, "d_head": 128
    },
    "Llama 3 8B (GQA)": {
        "n_layers": 32, "n_kv_heads": 8, "d_head": 128
    },
    "Mistral 7B (GQA)": {
        "n_layers": 32, "n_kv_heads": 8, "d_head": 128
    },
    "GPT-3 175B (MHA)": {
        "n_layers": 96, "n_kv_heads": 96, "d_head": 128
    },
}

seq_lengths = [1024, 4096, 8192, 131072]

print(f"{'Model':<25} ", end="")
for s in seq_lengths:
    label = f"{s//1024}K tok"
    print(f"{label:>10}", end="")
print()
print("-" * 67)

for name, cfg in models.items():
    print(f"{name:<25} ", end="")
    for s in seq_lengths:
        cache = kv_cache_bytes(
            cfg["n_layers"], cfg["n_kv_heads"], cfg["d_head"], s
        )
        gb = bytes_to_gb(cache)
        print(f"{gb:>9.2f}G", end="")
    print()
```
Output:
```python
Model                         1K tok    4K tok    8K tok  128K tok
-------------------------------------------------------------------
Llama 3 70B (GQA)              0.31G     1.25G     2.50G    40.00G
Llama 3 8B (GQA)               0.12G     0.50G     1.00G    16.00G
Mistral 7B (GQA)               0.12G     0.50G     1.00G    16.00G
GPT-3 175B (MHA)               4.50G    18.00G    36.00G   576.00G
```
Look at GPT-3 at 128K. That’s 576 GB for the KV cache alone — more than the model’s own FP16 weights. You’d need a cluster of GPUs just for caching. Llama 3 70B handles the same context with 40 GB. Still big, but it fits on a high-end GPU. The 8x cut from GQA makes all the difference.
Warning: These numbers are per-request. If you serve 8 users at 8K context each, multiply by 8. For Llama 3 70B, that’s 8 * 2.5 = 20 GB of KV caches on top of ~130 GB of model weights.
What If Llama 3 70B Used MHA or MQA Instead?
This is the experiment that makes GQA click. Same model. Same 80 layers. Same 128-dim heads. We only change the KV head count.
```python
llama_configs = {
    "Llama 3 70B — MHA (64 KV heads)": {"n_kv_heads": 64},
    "Llama 3 70B — GQA (8 KV heads)": {"n_kv_heads": 8},
    "Llama 3 70B — MQA (1 KV head)": {"n_kv_heads": 1},
}

seq = 8192
print(f"KV cache at {seq:,} tokens (FP16):\n")
for name, cfg in llama_configs.items():
    cache = kv_cache_bytes(80, cfg["n_kv_heads"], 128, seq)
    gb = bytes_to_gb(cache)
    print(f"  {name}: {gb:.2f} GB")

mha_cache = kv_cache_bytes(80, 64, 128, seq)
gqa_cache = kv_cache_bytes(80, 8, 128, seq)
mqa_cache = kv_cache_bytes(80, 1, 128, seq)
print(f"\n  MHA -> GQA saves: {bytes_to_gb(mha_cache - gqa_cache):.2f} GB ({mha_cache/gqa_cache:.0f}x)")
print(f"  MHA -> MQA saves: {bytes_to_gb(mha_cache - mqa_cache):.2f} GB ({mha_cache/mqa_cache:.0f}x)")
```
Output:
```python
KV cache at 8,192 tokens (FP16):

  Llama 3 70B — MHA (64 KV heads): 20.00 GB
  Llama 3 70B — GQA (8 KV heads): 2.50 GB
  Llama 3 70B — MQA (1 KV head): 0.31 GB

  MHA -> GQA saves: 17.50 GB (8x)
  MHA -> MQA saves: 19.69 GB (64x)
```
With MHA, the cache alone would burn 20 GB at 8K context. GQA brings it down to 2.5 GB. That 17.5 GB gap is the margin between “fits on an A100” and “doesn’t.”
MQA would push it to 0.31 GB. Tiny. But Meta picked GQA because the quality drop with MQA was too steep for a 70B model. Eight KV heads keep enough power for the model to do its job.
{
type: 'exercise',
id: 'kv-cache-calc',
title: 'Exercise 1: Calculate KV Cache for a Custom Model',
difficulty: 'beginner',
exerciseType: 'write',
instructions: 'A new model has 48 layers, 32 query heads, 4 KV heads (GQA), and d_head = 128. Calculate the KV cache in GB for 4096 tokens at FP16. Print the result rounded to 2 decimal places.',
starterCode: '# Model config\nn_layers = 48\nn_kv_heads = 4\nd_head = 128\nseq_len = 4096\nbytes_per_param = 2 # FP16\n\n# Calculate KV cache in bytes, then convert to GB\nkv_bytes = # your code here\nkv_gb = # your code here\nprint(f"{kv_gb:.2f}")',
testCases: [
{ id: 'tc1', input: '', expectedOutput: '0.38', description: 'KV cache should be 0.375 GB (prints 0.38)' },
],
hints: [
'Formula: 2 * n_layers * n_kv_heads * d_head * seq_len * bytes_per_param',
'kv_bytes = 2 * 48 * 4 * 128 * 4096 * 2, then divide by 1024**3 for GB',
],
solution: 'n_layers = 48\nn_kv_heads = 4\nd_head = 128\nseq_len = 4096\nbytes_per_param = 2\n\nkv_bytes = 2 * n_layers * n_kv_heads * d_head * seq_len * bytes_per_param\nkv_gb = kv_bytes / (1024 ** 3)\nprint(f"{kv_gb:.2f}")',
solutionExplanation: 'Multiply 2 (K + V) by layers, KV heads, head dim, seq length, and bytes per float. Divide by 1024^3 for GB. Result: 0.375 GB, which prints as 0.38.',
xpReward: 15,
}
Full VRAM Calculator
Here’s a function that takes any model config and prints a full report. It shows the KV cache total, per-layer cost, and how it compares to model weights. Change the numbers and rerun it for any model you want.
```python
def kv_cache_report(
    model_name, n_params_b, n_layers, n_q_heads,
    n_kv_heads, d_head, seq_len, dtype_bytes=2
):
    """Full KV cache and VRAM report for any model."""
    kv = kv_cache_bytes(n_layers, n_kv_heads, d_head, seq_len, dtype_bytes)
    kv_gb = bytes_to_gb(kv)
    weights_gb = (n_params_b * 1e9 * dtype_bytes) / (1024**3)
    per_layer = kv_cache_bytes(1, n_kv_heads, d_head, seq_len, dtype_bytes)
    per_layer_mb = per_layer / (1024**2)
    kv_mha = kv_cache_bytes(n_layers, n_q_heads, d_head, seq_len, dtype_bytes)
    attn_type = "MHA" if n_kv_heads == n_q_heads else (
        "MQA" if n_kv_heads == 1 else "GQA"
    )
    print(f"=== {model_name} ===")
    print(f"  Attention: {attn_type}")
    print(f"  Query / KV heads: {n_q_heads} / {n_kv_heads}")
    print(f"  Queries per KV: {n_q_heads // n_kv_heads}")
    print(f"  Layers: {n_layers}")
    print(f"  Head dim: {d_head}")
    print(f"  Seq length: {seq_len:,}")
    print(f"  Precision: {'FP16' if dtype_bytes == 2 else 'FP32'}")
    print("  ---")
    print(f"  KV cache total: {kv_gb:.2f} GB")
    print(f"  KV cache per layer: {per_layer_mb:.1f} MB")
    print(f"  Model weights: {weights_gb:.1f} GB")
    print(f"  Total for serving: {weights_gb + kv_gb:.1f} GB")
    print(f"  KV as % of total: {kv_gb / (weights_gb + kv_gb) * 100:.1f}%")
    if n_kv_heads != n_q_heads:
        mha_gb = bytes_to_gb(kv_mha)
        print(f"  If MHA instead: {mha_gb:.2f} GB ({n_q_heads // n_kv_heads}x more)")
    print()
```
Let’s check three real models at 8K context.
```python
kv_cache_report("Llama 3.1 70B", 70, 80, 64, 8, 128, 8192)
kv_cache_report("Mistral 7B", 7.3, 32, 32, 8, 128, 8192)
kv_cache_report("Llama 3.1 8B", 8, 32, 32, 8, 128, 8192)
```
Output:
```python
=== Llama 3.1 70B ===
  Attention: GQA
  Query / KV heads: 64 / 8
  Queries per KV: 8
  Layers: 80
  Head dim: 128
  Seq length: 8,192
  Precision: FP16
  ---
  KV cache total: 2.50 GB
  KV cache per layer: 32.0 MB
  Model weights: 130.4 GB
  Total for serving: 132.9 GB
  KV as % of total: 1.9%
  If MHA instead: 20.00 GB (8x more)

=== Mistral 7B ===
  Attention: GQA
  Query / KV heads: 32 / 8
  Queries per KV: 4
  Layers: 32
  Head dim: 128
  Seq length: 8,192
  Precision: FP16
  ---
  KV cache total: 1.00 GB
  KV cache per layer: 32.0 MB
  Model weights: 13.6 GB
  Total for serving: 14.6 GB
  KV as % of total: 6.9%
  If MHA instead: 4.00 GB (4x more)

=== Llama 3.1 8B ===
  Attention: GQA
  Query / KV heads: 32 / 8
  Queries per KV: 4
  Layers: 32
  Head dim: 128
  Seq length: 8,192
  Precision: FP16
  ---
  KV cache total: 1.00 GB
  KV cache per layer: 32.0 MB
  Model weights: 14.9 GB
  Total for serving: 15.9 GB
  KV as % of total: 6.3%
  If MHA instead: 4.00 GB (4x more)
```
For Llama 3.1 70B at 8K, the KV cache is about 2% of the total VRAM. Sounds small. But at 128K tokens, the cache jumps to 40 GB — then it’s 23% of the total. Context length changes everything.
Tip: Batch serving math: multiply the KV cache by your batch size. Serving 16 users at 8K on Llama 3.1 70B costs 16 * 2.5 = 40 GB of KV caches. Add the 130 GB of weights and you need multiple GPUs.
How Context Length Shifts the Balance
At short contexts, model weights dominate. The KV cache is a rounding error. But as context grows, the cache climbs and can rival the weights.
Here’s the progression for Llama 3 70B.
```python
print("Llama 3 70B — KV cache growth vs model weights (FP16)\n")
print(f"{'Seq Length':>12} {'KV Cache':>10} {'Weights':>10} {'KV %':>8}")
print("-" * 44)

weights_size = (70e9 * 2) / (1024**3)  # ~130.4 GB
for s in [1024, 4096, 16384, 32768, 65536, 131072]:
    kv = bytes_to_gb(kv_cache_bytes(80, 8, 128, s))
    pct = kv / (kv + weights_size) * 100
    print(f"{s:>12,} {kv:>9.1f}G {weights_size:>9.1f}G {pct:>7.1f}%")
```
Output:
```python
Llama 3 70B — KV cache growth vs model weights (FP16)

  Seq Length   KV Cache    Weights     KV %
--------------------------------------------
       1,024       0.3G     130.4G     0.2%
       4,096       1.2G     130.4G     0.9%
      16,384       5.0G     130.4G     3.7%
      32,768      10.0G     130.4G     7.1%
      65,536      20.0G     130.4G    13.3%
     131,072      40.0G     130.4G    23.5%
```
At 128K tokens the cache is nearly a quarter of the total. Without GQA, it’d be 320 GB — more than twice the model weights. At long contexts, KV cache tricks aren’t optional. They’re survival.
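That 320 GB figure drops straight out of the formula with 64 KV heads. Here's a standalone check (the formula is restated so this snippet runs on its own):

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_head, seq_len, bytes_per_param=2):
    # Same formula as above: K and V, per layer, per KV head, per token
    return 2 * n_layers * n_kv_heads * d_head * seq_len * bytes_per_param

gib = 1024 ** 3
mha = kv_cache_bytes(80, 64, 128, 131072) / gib  # hypothetical MHA variant
gqa = kv_cache_bytes(80, 8, 128, 131072) / gib   # actual Llama 3 70B config
print(f"128K context: MHA {mha:.0f} GB vs GQA {gqa:.0f} GB")
```

Same 80 layers, same 128-dim heads, same 128K tokens. Only the KV head count separates 320 GB from 40 GB.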
[UNDER THE HOOD]
Why does the formula use n_kv_heads and not n_q_heads? During inference, the model generates one token at a time. Each new token adds its K and V to the cache. But Q doesn’t get cached — it’s only needed for the current token. The cached K and V tensors have shape (batch, n_kv_heads, seq_len, d_head). The query Q is computed fresh at each step, broadcast across the KV heads. That’s why only the KV head count matters for cache size.
{
type: 'exercise',
id: 'batch-serving-calc',
title: 'Exercise 2: Batch Serving Memory Budget',
difficulty: 'intermediate',
exerciseType: 'write',
instructions: 'You have an A100 80 GB GPU. Llama 3 8B weights take 14.9 GB in FP16. Each request uses 4096 tokens. Calculate the max number of concurrent requests that fit. Ignore activation memory. Print just the integer.',
starterCode: 'gpu_vram_gb = 80\nweights_gb = 14.9\nn_layers = 32\nn_kv_heads = 8\nd_head = 128\nseq_len = 4096\n\nkv_per_request = # your code here\navailable = gpu_vram_gb - weights_gb\nmax_requests = # your code here\nprint(max_requests)',
testCases: [
{ id: 'tc1', input: '', expectedOutput: '130', description: 'Should fit 130 concurrent requests' },
],
hints: [
'kv_per_request = 2 * 32 * 8 * 128 * 4096 * 2 / (1024**3) = 0.5 GB',
'available = 65.1. max_requests = int(65.1 / 0.5) = 130',
],
solution: 'gpu_vram_gb = 80\nweights_gb = 14.9\nn_layers = 32\nn_kv_heads = 8\nd_head = 128\nseq_len = 4096\n\nkv_per_request = 2 * n_layers * n_kv_heads * d_head * seq_len * 2 / (1024**3)\navailable = gpu_vram_gb - weights_gb\nmax_requests = int(available / kv_per_request)\nprint(max_requests)',
solutionExplanation: 'Each request caches K and V across 32 layers, 8 KV heads, 128 dims, 4096 tokens in FP16. That is 0.5 GB. With 65.1 GB free, you fit 130 requests.',
xpReward: 20,
}
Common Mistakes and How to Fix Them
Mistake 1: Mixing up query heads and KV heads
“Llama 3 has 64 attention heads.” That’s the query count. The KV count is different.
Wrong:
```python
# Uses query head count (64) for KV cache
kv_wrong = 2 * 80 * 64 * 128 * 8192 * 2
print(f"Wrong KV cache: {kv_wrong / (1024**3):.2f} GB")
```
Output:
```python
Wrong KV cache: 20.00 GB
```
Correct:
```python
# Uses KV head count (8) -- the actual number cached
kv_right = 2 * 80 * 8 * 128 * 8192 * 2
print(f"Correct KV cache: {kv_right / (1024**3):.2f} GB")
```
Output:
```python
Correct KV cache: 2.50 GB
```
That’s an 8x error. Always check the model card for num_key_value_heads, not just num_attention_heads.
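If you pull configs from Hugging Face, the KV head count lives under num_key_value_heads in config.json; configs for pure-MHA models may omit the key, in which case it equals the query head count. A defensive lookup (the config dict here is a hand-typed illustration, not fetched from anywhere):

```python
# Hand-typed excerpt of a Llama-3-70B-style config.json (illustrative values)
config = {"num_hidden_layers": 80, "num_attention_heads": 64,
          "num_key_value_heads": 8, "head_dim": 128}

# Fall back to the query head count when the KV key is absent (MHA models)
n_kv_heads = config.get("num_key_value_heads", config["num_attention_heads"])
print(f"KV heads to use in the formula: {n_kv_heads}")
```

The fallback is what saves you from the 8x mistake in reverse: a config without the key really does cache one KV pair per query head.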
Mistake 2: Forgetting the factor of 2
The cache stores both keys AND values. They’re the same size.
Wrong:
```python
# Only counts K, misses V
k_only = 80 * 8 * 128 * 8192 * 2
print(f"K-only: {k_only / (1024**3):.2f} GB")
```
Output:
```python
K-only: 1.25 GB
```
The real answer is 2.5 GB. Always multiply by 2 for both K and V.
Mistake 3: Ignoring batch size
A single-request calc looks fine. But in production, you cache for every user at once.
The trap: “Llama 3 70B KV cache is only 2.5 GB. That’s nothing!”
Reality: 32 concurrent users at 8K = 32 * 2.5 = 80 GB. More than one A100’s whole VRAM.
When NOT to Use Each Variant
Not every case needs KV cache tricks.
Stick with MHA when:
– Your model is under 7B. The cache fits on one GPU easily.
– You’re training, not serving. Training doesn’t cache K and V step by step.
– Quality matters more than memory and you’ve got VRAM to spare.
Use GQA when:
– Your model is 7B+ and you’re serving inference. This is most real cases.
– You need long context (32K+ tokens). Without GQA, the cache explodes.
– You want the best quality-to-memory ratio. It’s the default in 2026.
Consider MQA when:
– You’re building a very latency-sensitive system and can take some quality loss.
– The model is small enough that one KV head won’t hurt much.
– You need peak throughput on minimal hardware.
Note: GQA won the industry. Llama 3, Mistral, Gemma, Qwen 2, and Falcon all use it. MQA was the research spark, but GQA took the crown by giving nearly the same speed with much better quality.
{
type: 'exercise',
id: 'gqa-from-scratch',
title: 'Exercise 3: GQA as MQA',
difficulty: 'intermediate',
exerciseType: 'write',
instructions: 'Run the grouped_query_attention function with n_kv_heads=1 (MQA mode). Print the K cache shape.',
starterCode: 'n_kv_heads_mqa = 1\nkv_dim_mqa = n_kv_heads_mqa * d_head\n\nW_k_mqa = np.random.randn(d_model, kv_dim_mqa) * 0.1\nW_v_mqa = np.random.randn(d_model, kv_dim_mqa) * 0.1\n\n_, _, K_test, V_test = grouped_query_attention(\n X, W_q, W_k_mqa, W_v_mqa, W_o, n_heads, n_kv_heads_mqa\n)\nprint(K_test.shape)',
testCases: [
{ id: 'tc1', input: '', expectedOutput: '(1, 5, 16)', description: 'K cache has 1 KV head' },
],
hints: [
'GQA with n_kv_heads=1 acts like MQA — one KV head for all queries.',
'Shape: (n_kv_heads, seq_len, d_head) = (1, 5, 16)',
],
solution: 'n_kv_heads_mqa = 1\nkv_dim_mqa = n_kv_heads_mqa * d_head\n\nW_k_mqa = np.random.randn(d_model, kv_dim_mqa) * 0.1\nW_v_mqa = np.random.randn(d_model, kv_dim_mqa) * 0.1\n\n_, _, K_test, V_test = grouped_query_attention(\n X, W_q, W_k_mqa, W_v_mqa, W_o, n_heads, n_kv_heads_mqa\n)\nprint(K_test.shape)',
solutionExplanation: 'With n_kv_heads=1, all 4 query heads share one KV head. The K cache shape is (1, 5, 16). GQA with 1 KV head is MQA.',
xpReward: 15,
}
Summary
You’ve built all three attention variants and checked their VRAM costs on real models.
The core idea: MHA, GQA, and MQA all run scaled dot-product attention. They differ only in how many KV heads they keep. Fewer heads = smaller cache, faster inference — but a quality tradeoff.
Quick recap:
– MHA: every query head gets its own KV head. Max quality, max memory.
– MQA: all heads share one KV head. Min memory, lower quality.
– GQA: heads are grouped, each group shares a KV head. Best tradeoff. Industry standard.
– The formula: 2 * layers * kv_heads * head_dim * seq_len * bytes.
– At long contexts, KV cache can rival model weights.
Practice Exercise
Llama 3.1 405B has 126 layers, 128 query heads, 8 KV heads, 128 head dim. How much KV cache at 32K in FP16? How many 32K requests fit on 8x A100 80 GB GPUs (640 GB total) if weights take ~754 GB?
Complete Code
Frequently Asked Questions
Does GQA affect training speed or only inference?
GQA mostly helps at inference time. During training, you don’t cache K and V token by token — you process all tokens at once. GQA slightly cuts the parameter count (fewer K and V matrices), which gives a small training speedup. But the big memory wins only show up during autoregressive generation, where the cache builds up as you produce tokens.
Can you turn an MHA model into GQA after training?
Yes. The GQA paper showed a method called “uptrained GQA.” You take a trained MHA model, average the K and V heads within each target group, then fine-tune for a short time. Quality bounces back fast.
```python
# Concept: averaging KV heads to convert MHA -> GQA
n_original = 64
n_target = 8
group_size = n_original // n_target
print(f"Each GQA group averages {group_size} original heads")
```
Output:
```python
Each GQA group averages 8 original heads
```
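The averaging itself is one reshape and a mean over each group. A sketch on random stand-in weights with scaled-down dimensions (not a real checkpoint):

```python
import numpy as np

n_original, n_target = 64, 8
d_model, d_head = 512, 64  # scaled down so the demo runs instantly
group_size = n_original // n_target  # 8

# Stand-in for 64 trained per-head K projections, each (d_model, d_head)
W_k_heads = np.random.randn(n_original, d_model, d_head)

# Mean-pool every group of 8 heads into one GQA KV head, then fine-tune
W_k_gqa = W_k_heads.reshape(n_target, group_size, d_model, d_head).mean(axis=1)
print(W_k_gqa.shape)  # (8, 512, 64)
```

The same pooling is applied to the V projections; the Q and output projections are left untouched, since only the KV side shrinks.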
How does KV cache quantization compare to fewer KV heads?
They attack different axes. GQA cuts the number of vectors. Quantization cuts the size of each vector (FP16 -> INT8 halves memory). You can stack them. GQA with 8 KV heads plus INT8 quantization uses 16x less cache than full MHA with FP16.
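The 16x figure is just the two reductions multiplied. A back-of-envelope check:

```python
kv_head_cut = 64 / 8   # MHA -> GQA, Llama 3 70B head counts
precision_cut = 2 / 1  # FP16 (2 bytes/value) -> INT8 (1 byte/value)
print(f"Combined KV cache reduction: {kv_head_cut * precision_cut:.0f}x")
```

Because the two techniques multiply rather than add, stacking them is the usual production setup.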
What is Multi-Head Latent Attention (MLA)?
MLA is a newer trick from DeepSeek. Instead of caching full K and V vectors, it stores a compressed version and rebuilds K and V on the fly. This can shrink the cache even more than GQA. DeepSeek-V2 uses it. It’s promising, but GQA stays the most common choice as of early 2026.
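To get a feel for the idea, compare cache widths per token per layer. The latent size below is a made-up illustration, not DeepSeek's actual number:

```python
n_kv_heads, d_head = 8, 128          # GQA numbers used throughout this article
gqa_width = 2 * n_kv_heads * d_head  # K + V values cached per token per layer
latent_width = 512                   # hypothetical compressed latent size
print(f"GQA: {gqa_width} values/token/layer, MLA-style: {latent_width}")
print(f"That latent would be {gqa_width // latent_width}x narrower")
```

The price is extra compute at decode time: the model has to project the latent back up to full K and V before attending.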
References
- Vaswani, A. et al. — “Attention Is All You Need.” NeurIPS 2017. arXiv:1706.03762
- Shazeer, N. — “Fast Transformer Decoding: One Write-Head is All You Need.” 2019. arXiv:1911.02150
- Ainslie, J. et al. — “GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints.” EMNLP 2023. arXiv:2305.13245
- Meta AI — Llama 3 Model Card. Hugging Face
- Mistral AI — Mistral 7B Technical Report. mistral.ai
- NVIDIA — “Mastering LLM Techniques: Inference Optimization.” NVIDIA Blog
- Brenndoerfer, M. — “KV Cache Memory: Calculating GPU Requirements.” mbrenndoerfer.com
- Pope, R. et al. — “Efficiently Scaling Transformer Inference.” MLSys 2023. arXiv:2211.05102
