Source code for turboquant.kv_cache

"""TurboQuant-compressed KV cache for ESML's inference engine.

Stores attention key/value vectors as compressed TQBlocks instead of raw
float tensors, achieving 4-6x memory reduction during inference.

This plugs into :class:`esml.engine.ESMLEngine` to provide KV-cache
compression during actual transformer attention computation.

References
----------
* TurboQuant: Zandieh et al. (2026). ICLR 2026. arXiv:2504.19874
* QJL: Zandieh et al. (2025). AAAI 2025. arXiv:2406.03482
"""

from __future__ import annotations

from dataclasses import dataclass

import numpy as np
from numpy.typing import NDArray

from .quant import TQBlock, turboquant_mse, turboquant_mse_decode

F32 = NDArray[np.float32]
F64 = NDArray[np.float64]


@dataclass
class CacheStats:
    """Memory statistics for a TurboQuantKVCache."""

    compressed_bytes: int = 0
    uncompressed_bytes: int = 0
    n_tokens: int = 0
    n_layers: int = 0

    @property
    def compression_ratio(self) -> float:
        if self.compressed_bytes == 0:
            return 0.0
        return self.uncompressed_bytes / self.compressed_bytes

    @property
    def savings_mb(self) -> float:
        return (self.uncompressed_bytes - self.compressed_bytes) / (1024 * 1024)

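# Worked example for the properties above (illustrative numbers only): with
# ``bits=3`` against the FP16 baseline used by ``TurboQuantKVCache.stats``,
# the idealized ratio is 16 bits / 3 bits ~= 5.33x; any per-block metadata
# counted in ``TQBlock.total_bits`` would pull the observed ratio toward the
# 4-6x range quoted in the module docstring. For instance:
#
#     CacheStats(compressed_bytes=100, uncompressed_bytes=512).compression_ratio
#     # -> 5.12
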
class TurboQuantKVCache:
    """KV cache that stores keys and values as TurboQuant-compressed blocks.

    Each key/value vector is quantized via :func:`turboquant.quant.turboquant_mse`
    on ``append()`` and decompressed on ``get_keys()`` / ``get_values()``.

    Parameters
    ----------
    n_layers : int
        Number of transformer layers.
    head_dim : int
        Dimension per attention head (must be a power of 2).
    bits : int
        TurboQuant quantization bits (2, 3, or 4).
    rotation_seed : int
        Shared rotation seed for reproducibility.

    Examples
    --------
    >>> cache = TurboQuantKVCache(n_layers=32, head_dim=128, bits=3)
    >>> k = np.random.randn(128)
    >>> v = np.random.randn(128)
    >>> cache.append(layer=0, k_vec=k, v_vec=v)
    >>> keys = cache.get_keys(0)  # (1, 128) decompressed
    >>> values = cache.get_values(0)
    >>> cache.stats.compression_ratio
    5.1
    """

    def __init__(
        self,
        n_layers: int,
        head_dim: int,
        bits: int = 3,
        rotation_seed: int = 42,
    ):
        self.n_layers = n_layers
        self.head_dim = head_dim
        self.bits = bits
        self.rotation_seed = rotation_seed
        # Per-layer lists of TQBlocks
        self._k_cache: list[list[TQBlock]] = [[] for _ in range(n_layers)]
        self._v_cache: list[list[TQBlock]] = [[] for _ in range(n_layers)]

    def append(self, layer: int, k_vec: F64, v_vec: F64) -> None:
        """Compress and cache a new key/value pair for one token.

        Parameters
        ----------
        layer : int
            Transformer layer index.
        k_vec : ndarray of shape (head_dim,)
            Key vector (will be quantized).
        v_vec : ndarray of shape (head_dim,)
            Value vector (will be quantized).
        """
        k_block = turboquant_mse(
            k_vec.astype(np.float64),
            bits=self.bits,
            rotation_seed=self.rotation_seed,
        )
        v_block = turboquant_mse(
            v_vec.astype(np.float64),
            bits=self.bits,
            rotation_seed=self.rotation_seed,
        )
        self._k_cache[layer].append(k_block)
        self._v_cache[layer].append(v_block)

    def get_keys(self, layer: int) -> F64:
        """Decompress all cached keys for a layer.

        Returns
        -------
        ndarray of shape (seq_len, head_dim)
        """
        if not self._k_cache[layer]:
            return np.zeros((0, self.head_dim))
        return np.stack([turboquant_mse_decode(b) for b in self._k_cache[layer]])

    def get_values(self, layer: int) -> F64:
        """Decompress all cached values for a layer.

        Returns
        -------
        ndarray of shape (seq_len, head_dim)
        """
        if not self._v_cache[layer]:
            return np.zeros((0, self.head_dim))
        return np.stack([turboquant_mse_decode(b) for b in self._v_cache[layer]])

    @property
    def seq_len(self) -> int:
        """Number of tokens currently cached (from layer 0)."""
        return len(self._k_cache[0]) if self._k_cache else 0

    def clear(self) -> None:
        """Clear all cached blocks."""
        for layer_k, layer_v in zip(self._k_cache, self._v_cache, strict=False):
            layer_k.clear()
            layer_v.clear()

    @property
    def stats(self) -> CacheStats:
        """Compute memory statistics."""
        compressed = 0
        uncompressed = 0
        n_tokens = self.seq_len
        for layer in range(self.n_layers):
            for block in self._k_cache[layer]:
                compressed += block.total_bits // 8
                uncompressed += block.d * 2  # FP16 = 2 bytes per element
            for block in self._v_cache[layer]:
                compressed += block.total_bits // 8
                uncompressed += block.d * 2
        return CacheStats(
            compressed_bytes=compressed,
            uncompressed_bytes=uncompressed,
            n_tokens=n_tokens,
            n_layers=self.n_layers,
        )

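# Illustrative sketch, not part of the library API: how the decompressed cache
# could feed a single-head attention step during decoding. The function name
# and the ``q_vec`` argument are assumptions for illustration; ``ESMLEngine``
# is expected to perform the equivalent computation internally.
def _example_attention_step(cache: TurboQuantKVCache, layer: int, q_vec: F64) -> F64:
    """Compute softmax(q K^T / sqrt(d)) V against the dequantized K/V of one layer."""
    keys = cache.get_keys(layer)      # (seq_len, head_dim), decompressed
    values = cache.get_values(layer)  # (seq_len, head_dim), decompressed
    if keys.shape[0] == 0:
        return np.zeros(cache.head_dim)
    scores = keys @ q_vec / np.sqrt(cache.head_dim)  # (seq_len,)
    weights = np.exp(scores - scores.max())          # numerically stable softmax
    weights /= weights.sum()
    return weights @ values                          # (head_dim,)
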
class UncompressedKVCache:
    """Baseline uncompressed KV cache for comparison benchmarks."""

    def __init__(self, n_layers: int, head_dim: int):
        self.n_layers = n_layers
        self.head_dim = head_dim
        self._k_cache: list[list[F64]] = [[] for _ in range(n_layers)]
        self._v_cache: list[list[F64]] = [[] for _ in range(n_layers)]

    def append(self, layer: int, k_vec: F64, v_vec: F64) -> None:
        self._k_cache[layer].append(k_vec.copy())
        self._v_cache[layer].append(v_vec.copy())

    def get_keys(self, layer: int) -> F64:
        if not self._k_cache[layer]:
            return np.zeros((0, self.head_dim))
        return np.stack(self._k_cache[layer])

    def get_values(self, layer: int) -> F64:
        if not self._v_cache[layer]:
            return np.zeros((0, self.head_dim))
        return np.stack(self._v_cache[layer])

    @property
    def seq_len(self) -> int:
        return len(self._k_cache[0]) if self._k_cache else 0

    def clear(self) -> None:
        for layer_k, layer_v in zip(self._k_cache, self._v_cache, strict=False):
            layer_k.clear()
            layer_v.clear()

    @property
    def memory_bytes(self) -> int:
        total = 0
        for layer in range(self.n_layers):
            total += len(self._k_cache[layer]) * self.head_dim * 8  # float64
            total += len(self._v_cache[layer]) * self.head_dim * 8
        return total
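
# Illustrative memory comparison, not a calibrated benchmark: the sizes below
# (layers, head_dim, token count) are assumptions chosen for demonstration.
# Compares ``TurboQuantKVCache`` against the float64 ``UncompressedKVCache``
# baseline defined above.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n_layers, head_dim, n_tokens = 4, 128, 256

    tq_cache = TurboQuantKVCache(n_layers=n_layers, head_dim=head_dim, bits=3)
    raw_cache = UncompressedKVCache(n_layers=n_layers, head_dim=head_dim)
    for _ in range(n_tokens):
        for layer in range(n_layers):
            k = rng.standard_normal(head_dim)
            v = rng.standard_normal(head_dim)
            tq_cache.append(layer, k, v)
            raw_cache.append(layer, k, v)

    stats = tq_cache.stats
    print(f"compressed KV cache: {stats.compressed_bytes / 1024:.1f} KiB")
    print(f"FP16 baseline:       {stats.uncompressed_bytes / 1024:.1f} KiB "
          f"({stats.compression_ratio:.2f}x)")
    print(f"float64 cache:       {raw_cache.memory_bytes / 1024:.1f} KiB")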