Spaces:
Sleeping
Sleeping
| """Unit tests for proactive_cache.""" | |
| import pytest | |
| import numpy as np | |
| import torch | |
| from proactive_cache.eviction import score_tokens, select_indices, prune_kv_cache, evict | |
| from proactive_cache.prototypes import build_prototypes, save_prototypes, load_prototypes | |
| from proactive_cache.utils import to_tuple_kv, to_dynamic_cache | |
| # ── Fixtures ────────────────────────────────────────────────────────────────── | |
def make_dummy_patterns(num_docs=5, num_layers=2, num_heads=4, seq_len=64, seed=None):
    """Create synthetic attention patterns for testing.

    Args:
        num_docs: Number of documents (entries in the returned list).
        num_layers: Number of layer indices per document.
        num_heads: Number of head indices per layer.
        seq_len: Length of each generated pattern vector.
        seed: Optional integer. When given, values come from a private
            ``np.random.default_rng(seed)`` so the fixture is reproducible.
            When ``None`` (the default) the global NumPy random state is
            used, exactly as before this parameter existed.

    Returns:
        A list of ``num_docs`` dicts. Each dict maps ``(layer, head)`` to a
        1-D ``float32`` array of length ``seq_len`` that has been divided by
        its own sum.
    """
    # Both the ``np.random`` module and a ``Generator`` expose ``random(size)``,
    # so a single call site serves the seeded and the unseeded case.
    rng = np.random if seed is None else np.random.default_rng(seed)

    patterns = []
    for _ in range(num_docs):
        doc = {}
        for layer in range(num_layers):
            for head in range(num_heads):
                arr = rng.random(seq_len).astype(np.float32)
                arr /= arr.sum()  # normalise so the pattern sums to 1
                doc[(layer, head)] = arr
        patterns.append(doc)
    return patterns
def make_dummy_kv_cache(num_layers=2, num_heads=4, seq_len=64, head_dim=32, device="cpu", seed=None):
    """Create a synthetic KV cache tuple.

    Args:
        num_layers: Number of layer entries in the returned tuple.
        num_heads: Size of axis 1 of every tensor.
        seq_len: Size of axis 2 of every tensor.
        head_dim: Size of axis 3 of every tensor.
        device: Device the tensors are created on.
        seed: Optional integer. When given, a private ``torch.Generator``
            seeded with it drives ``torch.randn`` so the fixture is
            reproducible. When ``None`` (the default) the global torch RNG is
            used, exactly as before this parameter existed.

    Returns:
        A tuple of ``num_layers`` pairs; each pair holds two independent
        random tensors of shape ``(1, num_heads, seq_len, head_dim)``.
    """
    generator = None
    if seed is not None:
        # The generator has to live on the same device as the tensors it fills.
        generator = torch.Generator(device=device)
        generator.manual_seed(seed)

    shape = (1, num_heads, seq_len, head_dim)
    return tuple(
        (
            torch.randn(shape, generator=generator, device=device),
            torch.randn(shape, generator=generator, device=device),
        )
        for _ in range(num_layers)
    )
| # ── Eviction tests ──────────────────────────────────────────────────────────── | |
class TestScoreTokens:
    """Checks on the per-position score array returned by ``score_tokens``."""

    def test_returns_correct_shape(self):
        """One score is produced for every sequence position."""
        n_positions = 128
        result = score_tokens(None, seq_len=n_positions, budget=64)
        assert result.shape == (n_positions,)

    def test_token_zero_has_highest_score(self):
        """Position 0 ranks among the ``budget`` best-scoring positions."""
        keep = 64
        result = score_tokens(None, seq_len=128, budget=keep)
        # Sink boost means position 0 is always kept
        ranked = np.argsort(result)
        survivors = ranked[-keep:]
        assert 0 in survivors, "Token 0 (attention sink) must always be selected"

    def test_recency_tokens_kept(self):
        """The final position ranks among the ``budget`` best-scoring positions."""
        n_positions = 128
        keep = 64
        result = score_tokens(None, seq_len=n_positions, budget=keep)
        survivors = np.argsort(result)[-keep:]
        # Last few tokens should be in top-k
        last_position = n_positions - 1
        assert last_position in survivors, "Most recent token must always be kept"

    def test_with_prototypes(self):
        """Scoring with freshly built prototypes yields finite scores of the right length."""
        n_positions = 64
        prototypes = build_prototypes(
            make_dummy_patterns(seq_len=n_positions),
            n_clusters=2,
            max_seq_len=n_positions,
        )
        result = score_tokens(prototypes, seq_len=n_positions, budget=32)
        assert result.shape == (n_positions,)
        assert np.all(np.isfinite(result)), "Scores must be finite"

    def test_budget_proportional_recency(self):
        """Two different budgets both run and return arrays of the same shape."""
        # Larger budget → larger recency window (proportional)
        small_budget_scores, large_budget_scores = (
            score_tokens(None, seq_len=512, budget=b) for b in (128, 256)
        )
        # More positions should be elevated in the larger-budget result
        # (just check both run without error)
        assert small_budget_scores.shape == large_budget_scores.shape == (512,)
class TestSelectIndices:
    """Checks on ``select_indices``: ordering and number of returned indices."""

    def test_returns_sorted(self):
        """Returned indices are in ascending order."""
        scores = np.random.rand(100)
        idx = select_indices(scores, budget=20)
        # Compare as plain lists: if ``select_indices`` returns an ndarray,
        # ``idx == sorted(idx)`` is element-wise and ``assert`` on it raises
        # "truth value of an array is ambiguous" instead of testing the order.
        assert list(idx) == sorted(idx), "Indices must be in ascending order"

    def test_correct_count(self):
        """Exactly ``budget`` indices come back when budget < number of scores."""
        scores = np.random.rand(100)
        idx = select_indices(scores, budget=30)
        assert len(idx) == 30

    def test_budget_larger_than_seq(self):
        """A budget above the number of scores returns one index per score."""
        scores = np.random.rand(10)
        idx = select_indices(scores, budget=50)
        assert len(idx) == 10  # clipped to seq_len
class TestPruneKVCache:
    """Checks on ``prune_kv_cache`` and the under-budget path of ``evict``."""

    def test_prunes_to_budget(self):
        """Axis 2 of the first tensor shrinks to the number of kept indices."""
        kv = make_dummy_kv_cache(num_layers=3, num_heads=4, seq_len=128)
        indices = list(range(0, 64, 2))  # 32 indices
        pruned = prune_kv_cache(kv, indices, device=torch.device("cpu"))
        pruned_tuple = to_tuple_kv(pruned)
        assert pruned_tuple[0][0].shape[2] == len(indices), "Pruned KV must have budget tokens"

    def test_all_layers_pruned(self):
        """Every layer survives and both of its tensors are cut to the kept indices."""
        num_layers = 4
        kv = make_dummy_kv_cache(num_layers=num_layers, seq_len=100)
        indices = list(range(50))
        pruned_tuple = to_tuple_kv(prune_kv_cache(kv, indices, torch.device("cpu")))
        assert len(pruned_tuple) == num_layers
        # The layer count alone would also pass if only some layers were
        # pruned, so check the sequence axis (axis 2) of both tensors per layer.
        for layer_idx, layer in enumerate(pruned_tuple):
            assert layer[0].shape[2] == len(indices), f"Layer {layer_idx}: first tensor not pruned"
            assert layer[1].shape[2] == len(indices), f"Layer {layer_idx}: second tensor not pruned"

    def test_no_prune_when_under_budget(self):
        """``evict`` leaves the sequence length alone when seq_len <= budget."""
        kv = make_dummy_kv_cache(seq_len=32)
        result = evict(kv, budget=64, prototypes=None, seq_len=32, device=torch.device("cpu"))
        # Should return unchanged (seq_len <= budget)
        assert to_tuple_kv(result)[0][0].shape[2] == 32
| # ── Prototype tests ─────────────────────────────────────────────────────────── | |
class TestPrototypes:
    """Checks on building, saving and loading prototypes."""

    def test_build_returns_dict(self):
        """``build_prototypes`` returns a non-empty dict."""
        patterns = make_dummy_patterns()
        protos = build_prototypes(patterns, n_clusters=2, max_seq_len=64)
        assert isinstance(protos, dict)
        assert len(protos) > 0

    def test_centroid_shapes(self):
        """Every entry's ``"centroids"`` array is ``(n_clusters, max_seq_len)``."""
        n_clusters, max_seq_len = 3, 64
        patterns = make_dummy_patterns(num_layers=2, num_heads=4, seq_len=max_seq_len)
        protos = build_prototypes(patterns, n_clusters=n_clusters, max_seq_len=max_seq_len)
        # Only the values are inspected, so iterate them directly.
        for val in protos.values():
            centroids = val["centroids"]
            assert centroids.shape == (n_clusters, max_seq_len), f"Wrong centroid shape: {centroids.shape}"

    def test_save_load_roundtrip(self, tmp_path):
        """Saving then loading preserves the keys and the centroid values."""
        patterns = make_dummy_patterns()
        protos = build_prototypes(patterns, n_clusters=2, max_seq_len=64)
        path = str(tmp_path / "test_protos.pkl")
        save_prototypes(protos, path)
        loaded = load_prototypes(path)
        assert set(loaded.keys()) == set(protos.keys())
        # Matching keys alone would not catch corrupted or dropped payloads,
        # so compare the stored centroids as well.
        for key, original in protos.items():
            assert np.array_equal(
                loaded[key]["centroids"], original["centroids"]
            ), f"Centroids changed across save/load for {key}"

    def test_load_missing_raises(self, tmp_path):
        """Loading a path that does not exist raises ``FileNotFoundError``."""
        with pytest.raises(FileNotFoundError):
            load_prototypes(str(tmp_path / "does_not_exist.pkl"))

    def test_empty_patterns_raises(self):
        """Building from an empty pattern list raises ``ValueError``."""
        with pytest.raises(ValueError):
            build_prototypes([], n_clusters=2)
| # ── Utils tests ─────────────────────────────────────────────────────────────── | |
class TestUtils:
    """Checks on the cache-conversion helpers ``to_tuple_kv`` and ``to_dynamic_cache``."""

    def test_to_tuple_kv_from_tuple(self):
        """A two-layer tuple cache converts to two entries, the first being a tuple."""
        converted = to_tuple_kv(make_dummy_kv_cache(num_layers=2))
        assert len(converted) == 2
        assert isinstance(converted[0], tuple)

    def test_to_dynamic_cache_roundtrip(self):
        """Tuple -> dynamic cache -> tuple keeps the shape of the first tensor."""
        as_tuple = to_tuple_kv(make_dummy_kv_cache(num_layers=2, seq_len=32))
        restored = to_tuple_kv(to_dynamic_cache(as_tuple))
        # Shapes should be preserved
        before = as_tuple[0][0]
        after = restored[0][0]
        assert after.shape == before.shape