# proactive-cache / tests/test_proactive_cache.py
# Committed by skhavin in b786614: "feat: initial release of proactive-cache v0.1.0"
"""Unit tests for proactive_cache."""
import pytest
import numpy as np
import torch
from proactive_cache.eviction import score_tokens, select_indices, prune_kv_cache, evict
from proactive_cache.prototypes import build_prototypes, save_prototypes, load_prototypes
from proactive_cache.utils import to_tuple_kv, to_dynamic_cache
# ── Fixtures ──────────────────────────────────────────────────────────────────
def make_dummy_patterns(num_docs=5, num_layers=2, num_heads=4, seq_len=64, seed=None):
    """Create synthetic attention patterns for testing.

    Each document is a dict mapping ``(layer, head)`` to a float32 vector of
    length ``seq_len`` whose entries are non-negative and normalized to sum
    to 1, i.e. a synthetic attention distribution over positions.

    Args:
        num_docs: Number of documents (dicts) to generate.
        num_layers: Number of layers per document.
        num_heads: Number of heads per layer.
        seq_len: Length of each attention vector.
        seed: Optional seed for a private ``numpy.random.Generator`` so the
            data is reproducible. When ``None`` (the default) the global
            NumPy RNG is used, so callers that seed ``np.random`` globally
            keep exactly their previous behavior.

    Returns:
        A list of ``num_docs`` dicts keyed by ``(layer, head)`` tuples.
    """
    # The legacy module-level ``np.random.random`` draws from the same global
    # stream as ``np.random.rand``, so the unseeded path is unchanged.
    rng = np.random if seed is None else np.random.default_rng(seed)
    patterns = []
    for _ in range(num_docs):
        doc = {}
        for layer in range(num_layers):
            for head in range(num_heads):
                arr = rng.random(seq_len).astype(np.float32)
                arr /= arr.sum()
                doc[(layer, head)] = arr
        patterns.append(doc)
    return patterns
def make_dummy_kv_cache(num_layers=2, num_heads=4, seq_len=64, head_dim=32, device="cpu"):
    """Build a random KV cache as a tuple of per-layer ``(key, value)`` pairs.

    Both tensors in every pair have shape ``(1, num_heads, seq_len, head_dim)``
    and are sampled from a standard normal distribution on ``device``.
    """
    shape = (1, num_heads, seq_len, head_dim)
    layers = []
    for _layer in range(num_layers):
        # Sample the key before the value so RNG consumption order is stable.
        key = torch.randn(shape, device=device)
        value = torch.randn(shape, device=device)
        layers.append((key, value))
    return tuple(layers)
# ── Eviction tests ────────────────────────────────────────────────────────────
class TestScoreTokens:
    """Checks for ``score_tokens``, both prototype-free and prototype-guided."""

    @staticmethod
    def _top_k(scores, k):
        """Return indices of the ``k`` highest-scoring positions (ascending by score)."""
        return np.argsort(scores)[-k:]

    def test_returns_correct_shape(self):
        """Scoring without prototypes produces one score per position."""
        result = score_tokens(None, seq_len=128, budget=64)
        assert result.shape == (128,)

    def test_token_zero_has_highest_score(self):
        """Position 0 (the attention sink) is among the top-``budget`` tokens.

        Note: despite its name, this test checks top-k membership only, not
        that position 0 holds the single largest score.
        """
        budget = 64
        kept = self._top_k(score_tokens(None, seq_len=128, budget=budget), budget)
        # Sink boost means position 0 is always kept
        assert 0 in kept, "Token 0 (attention sink) must always be selected"

    def test_recency_tokens_kept(self):
        """The final position survives top-``budget`` selection."""
        seq_len, budget = 128, 64
        kept = self._top_k(score_tokens(None, seq_len=seq_len, budget=budget), budget)
        last = seq_len - 1
        assert last in kept, "Most recent token must always be kept"

    def test_with_prototypes(self):
        """Prototype-guided scores have the expected shape and no NaN/inf."""
        protos = build_prototypes(
            make_dummy_patterns(seq_len=64), n_clusters=2, max_seq_len=64
        )
        result = score_tokens(protos, seq_len=64, budget=32)
        assert result.shape == (64,)
        assert np.all(np.isfinite(result)), "Scores must be finite"

    def test_budget_proportional_recency(self):
        """Smoke test: scoring runs for two budgets over the same length.

        The intent is that a larger budget widens the recency window, but only
        the output shapes are asserted here.
        """
        seq_len = 512
        small = score_tokens(None, seq_len=seq_len, budget=128)
        large = score_tokens(None, seq_len=seq_len, budget=256)
        assert small.shape == large.shape == (seq_len,)
class TestSelectIndices:
    """Checks for ``select_indices``."""

    def test_returns_sorted(self):
        """Selected indices are ascending, unique and within ``[0, len(scores))``."""
        # Seeded so a failure is reproducible.
        scores = np.random.default_rng(0).random(100)
        # ``list(...)`` keeps the comparison valid even if an ndarray is returned
        # (``ndarray == list`` is elementwise and cannot be used in ``assert``).
        idx = list(select_indices(scores, budget=20))
        assert idx == sorted(idx), "Indices must be in ascending order"
        assert len(set(idx)) == len(idx), "Indices must be unique"
        assert all(0 <= i < len(scores) for i in idx), "Indices must be in range"

    def test_correct_count(self):
        """Exactly ``budget`` indices are returned when budget < len(scores)."""
        scores = np.random.default_rng(1).random(100)
        idx = select_indices(scores, budget=30)
        assert len(idx) == 30

    def test_budget_larger_than_seq(self):
        """A budget larger than the sequence keeps every position."""
        scores = np.random.default_rng(2).random(10)
        idx = list(select_indices(scores, budget=50))
        assert len(idx) == 10  # clipped to seq_len
        assert idx == list(range(10)), "All positions must be kept when budget >= seq_len"
class TestPruneKVCache:
    """Checks for ``prune_kv_cache`` and the under-budget path of ``evict``."""

    def test_prunes_to_budget(self):
        """Every layer keeps exactly the requested positions along dim 2."""
        kv = make_dummy_kv_cache(num_layers=3, num_heads=4, seq_len=128)
        indices = list(range(0, 64, 2))  # 32 indices
        pruned = prune_kv_cache(kv, indices, device=torch.device("cpu"))
        pruned_tuple = to_tuple_kv(pruned)
        assert pruned_tuple[0][0].shape[2] == 32, "Pruned KV must have budget tokens"
        assert len(pruned_tuple) == len(kv)
        # The surviving entries must be exactly the selected slices of the input
        # (indices are ascending, so order is unambiguous).
        index = torch.tensor(indices)
        for (k, v), (pk, pv) in zip(kv, pruned_tuple):
            assert torch.equal(pk, k.index_select(2, index))
            assert torch.equal(pv, v.index_select(2, index))

    def test_all_layers_pruned(self):
        """Pruning applies to keys and values of every layer, not just the first."""
        num_layers = 4
        kv = make_dummy_kv_cache(num_layers=num_layers, seq_len=100)
        indices = list(range(50))
        pruned_tuple = to_tuple_kv(prune_kv_cache(kv, indices, torch.device("cpu")))
        assert len(pruned_tuple) == num_layers
        for layer, (k, v) in enumerate(pruned_tuple):
            assert k.shape[2] == len(indices), f"Layer {layer} keys were not pruned"
            assert v.shape[2] == len(indices), f"Layer {layer} values were not pruned"

    def test_no_prune_when_under_budget(self):
        """``evict`` leaves the cache untouched when seq_len <= budget."""
        kv = make_dummy_kv_cache(seq_len=32)
        result = evict(kv, budget=64, prototypes=None, seq_len=32, device=torch.device("cpu"))
        result_tuple = to_tuple_kv(result)
        # Should return unchanged (seq_len <= budget)
        assert result_tuple[0][0].shape[2] == 32
        assert len(result_tuple) == len(kv)
        for (k, v), (rk, rv) in zip(kv, result_tuple):
            assert torch.equal(rk, k) and torch.equal(rv, v)
# ── Prototype tests ───────────────────────────────────────────────────────────
class TestPrototypes:
    """Checks for building, saving and loading attention prototypes."""

    def test_build_returns_dict(self):
        """``build_prototypes`` returns a non-empty dict."""
        patterns = make_dummy_patterns()
        protos = build_prototypes(patterns, n_clusters=2, max_seq_len=64)
        assert isinstance(protos, dict)
        assert len(protos) > 0

    def test_centroid_shapes(self):
        """Each entry's centroids have shape ``(n_clusters, max_seq_len)``."""
        patterns = make_dummy_patterns(num_layers=2, num_heads=4, seq_len=64)
        protos = build_prototypes(patterns, n_clusters=3, max_seq_len=64)
        # Without this guard an empty result would make the loop below pass vacuously.
        assert protos, "build_prototypes returned no prototypes"
        for val in protos.values():
            centroids = val["centroids"]
            assert centroids.shape == (3, 64), f"Wrong centroid shape: {centroids.shape}"

    def test_save_load_roundtrip(self, tmp_path):
        """Saving then loading preserves keys and centroid values."""
        patterns = make_dummy_patterns()
        protos = build_prototypes(patterns, n_clusters=2, max_seq_len=64)
        path = str(tmp_path / "test_protos.pkl")
        save_prototypes(protos, path)
        loaded = load_prototypes(path)
        assert set(loaded.keys()) == set(protos.keys())
        # A roundtrip must preserve the data, not just the key set.
        for key, val in protos.items():
            np.testing.assert_allclose(
                np.asarray(loaded[key]["centroids"]),
                np.asarray(val["centroids"]),
                rtol=1e-6,
            )

    def test_load_missing_raises(self, tmp_path):
        """Loading a nonexistent path raises ``FileNotFoundError``."""
        with pytest.raises(FileNotFoundError):
            load_prototypes(str(tmp_path / "does_not_exist.pkl"))

    def test_empty_patterns_raises(self):
        """Building from an empty pattern list raises ``ValueError``."""
        with pytest.raises(ValueError):
            build_prototypes([], n_clusters=2)
# ── Utils tests ───────────────────────────────────────────────────────────────
class TestUtils:
    """Checks for KV-cache format conversion helpers."""

    def test_to_tuple_kv_from_tuple(self):
        """A tuple-format cache passes through with one tuple per layer."""
        kv = make_dummy_kv_cache(num_layers=2)
        result = to_tuple_kv(kv)
        assert len(result) == 2
        assert isinstance(result[0], tuple)

    def test_to_dynamic_cache_roundtrip(self):
        """tuple -> dynamic cache -> tuple preserves layers, shapes and contents."""
        kv = make_dummy_kv_cache(num_layers=2, seq_len=32)
        kv_tuple = to_tuple_kv(kv)
        dynamic = to_dynamic_cache(kv_tuple)
        back = to_tuple_kv(dynamic)
        # Shapes should be preserved
        assert back[0][0].shape == kv_tuple[0][0].shape
        # A true roundtrip also preserves every layer's keys and values.
        assert len(back) == len(kv_tuple)
        for (k, v), (bk, bv) in zip(kv_tuple, back):
            assert bk.shape == k.shape and bv.shape == v.shape
            assert torch.equal(bk, k) and torch.equal(bv, v)