diff --git a/src/cachekit/backends/memcached/backend.py b/src/cachekit/backends/memcached/backend.py index fd59f4b..fad190b 100644 --- a/src/cachekit/backends/memcached/backend.py +++ b/src/cachekit/backends/memcached/backend.py @@ -9,6 +9,7 @@ import time from typing import Any, Optional +from cachekit.backends.errors import BackendError, BackendErrorType from cachekit.backends.memcached.config import MAX_MEMCACHED_TTL, MemcachedBackendConfig from cachekit.backends.memcached.error_handler import classify_memcached_error @@ -108,12 +109,32 @@ def set(self, key: str, value: bytes, ttl: Optional[int] = None) -> None: Raises: BackendError: If Memcached operation fails. """ + # Guard client-side against oversized items. Memcached rejects items over its + # item-size limit (default 1 MiB), but with noreply that rejection is never read — + # the call appears to succeed and the entry is silently never cached. Fail loudly + # instead, so the caller can compress, shard, or switch backends. + max_size = self._config.max_item_size_bytes + if max_size and len(value) > max_size: + raise BackendError( + message=( + f"Value for key {key!r} is {len(value)} bytes, which exceeds the Memcached " + f"max item size of {max_size} bytes. Memcached cannot store it. Enable " + f"compression, use a larger-payload backend (Redis/SaaS/File), or raise both " + f"the server's -I limit and CACHEKIT_MEMCACHED_MAX_ITEM_SIZE_BYTES." + ), + error_type=BackendErrorType.PERMANENT, + operation="set", + key=key, + ) + expire = 0 if ttl is not None and ttl > 0: expire = min(ttl, MAX_MEMCACHED_TTL) try: - self._client.set(self._prefixed_key(key), value, expire=expire) + # noreply=False so an oversized/error reply from the server is read and surfaced + # rather than silently swallowed (HashClient defaults to noreply=True). + self._client.set(self._prefixed_key(key), value, expire=expire, noreply=False) except Exception as exc: raise classify_memcached_error(exc, operation="set", key=key) from exc diff --git a/src/cachekit/backends/memcached/config.py b/src/cachekit/backends/memcached/config.py index 05cddc9..f6d6c56 100644 --- a/src/cachekit/backends/memcached/config.py +++ b/src/cachekit/backends/memcached/config.py @@ -85,6 +85,16 @@ class MemcachedBackendConfig(BaseBackendConfig): default="", description="Optional prefix for all cache keys", ) + max_item_size_bytes: int = Field( + default=1024 * 1024, + ge=0, + description=( + "Reject values larger than this BEFORE sending to Memcached (0 disables the check). " + "Memcached's default item-size limit is 1 MiB (server -I flag); oversized items are " + "rejected by the server, and with noreply that rejection is silent — so cachekit " + "guards client-side and fails loudly. Raise this only if the server's -I is raised too." + ), + ) @field_validator("servers", mode="after") @classmethod diff --git a/src/cachekit/config/settings.py b/src/cachekit/config/settings.py index 2b11296..f26aac8 100644 --- a/src/cachekit/config/settings.py +++ b/src/cachekit/config/settings.py @@ -17,7 +17,7 @@ from __future__ import annotations -from typing import Any, Optional +from typing import Any, Literal, Optional from pydantic import ( Field, @@ -116,6 +116,15 @@ class CachekitConfig(BaseSettings): le=9, description="Zlib compression level (1-9, where 9 is highest compression)", ) + arrow_compression: Literal["zstd", "lz4", "none"] = Field( + default="zstd", + description=( + "Arrow IPC compression codec for DataFrame caching (ArrowSerializer, compression='auto'). " + "'zstd'/'lz4' shrink the stored payload but must be decompressed into the heap on read. " + "'none' stores uncompressed Arrow IPC, which enables zero-copy memory-mapped reads " + "(lowest read memory) at the cost of a larger payload. Env: CACHEKIT_ARROW_COMPRESSION." + ), + ) retry_on_timeout: bool = Field( default=True, description="Whether to retry operations on timeout", diff --git a/src/cachekit/l1_cache.py b/src/cachekit/l1_cache.py index 883928e..75644bf 100644 --- a/src/cachekit/l1_cache.py +++ b/src/cachekit/l1_cache.py @@ -284,6 +284,23 @@ def put( # Estimate size size = self._estimate_size(value) + # Reject entries that cannot fit even in an empty cache. Storing one would push L1 + # permanently over its budget, and a multi-GB serialized DataFrame envelope is a + # direct OOM vector (it would also evict every other useful entry on the way in). + # The value is still available from L2; we only decline to mirror it in L1. If a + # smaller entry for this key was cached, drop it so L1 stops serving the stale value. + if size > self.max_memory_bytes: + with self._lock: + if key in self._cache: + self._remove_entry(key) + logger.debug( + "Skipping L1 cache for key %s - value %d bytes exceeds L1 budget %d bytes (served from L2 only)", + key, + size, + self.max_memory_bytes, + ) + return + with self._lock: # Check if key already exists if key in self._cache: diff --git a/src/cachekit/serializers/arrow_serializer.py b/src/cachekit/serializers/arrow_serializer.py index 60ee4c7..0cd8604 100644 --- a/src/cachekit/serializers/arrow_serializer.py +++ b/src/cachekit/serializers/arrow_serializer.py @@ -49,37 +49,52 @@ # Standard dependency: xxhash (always available) import xxhash +# Target bytes per Arrow record-batch when writing IPC. Bounds the zstd compressor's +# working set (compression is per-batch) so peak memory does not scale with table size. +_TARGET_BATCH_BYTES = 8 * 1024 * 1024 +# Only return pool memory to the OS for payloads at least this large (avoids churn on +# the small-object hot path, where the syscall + page re-fault would dominate). +_RELEASE_POOL_THRESHOLD = 4 * 1024 * 1024 + + +def _bounded_chunksize(table: pa.Table) -> int | None: # type: ignore[name-defined] + """Rows per IPC record-batch so each batch is ~_TARGET_BATCH_BYTES, regardless of width. + + Returns None for empty tables (nothing to chunk). Never returns 0. + """ + if table.num_rows <= 0: + return None + bytes_per_row = max(1, table.nbytes // table.num_rows) + return max(1, _TARGET_BATCH_BYTES // bytes_per_row) + class ArrowSerializer: - """Apache Arrow IPC format for zero-copy DataFrame caching with xxHash3-64 integrity protection. + """Apache Arrow IPC serializer for memory-efficient DataFrame caching with xxHash3-64 integrity. - Provides 100,000x deserialize speedup (memory-mapped) and 50x serialize - speedup for DataFrames. Supports pandas, polars, and dict of arrays (columnar). + Columnar Arrow IPC is far more compact and faster to (de)serialize than MessagePack for + DataFrames. Supports pandas, polars, and dict of arrays (columnar). Does NOT support non-tabular data (scalar values, nested dicts, custom objects). - Integrity Protection: - - Format: [8-byte xxHash3-64 checksum][Arrow IPC data] - - Checksum computed on original Arrow IPC bytes + Integrity Protection (always on): + - Format: [8-byte xxHash3-64 checksum][compressed Arrow IPC] + - Checksum computed over the stored (compressed) IPC bytes - Validation on deserialize detects bit flips, truncation, corruption - - 8-byte overhead per cached DataFrame (faster than cryptographic hashes) + - 8-byte overhead per cached DataFrame (negligible vs the payload; never silently + returns corrupted data, regardless of the enable_integrity_checking flag) + + Memory profile (bounded, low-copy): + - Serialize: builds the compressed IPC once and prepends the checksum; the source Arrow + table is freed before the IPC bytes are materialized. + - Deserialize: the envelope is sliced with a memoryview (no full-body copy), wrapped via + pa.py_buffer (zero-copy), and Arrow->pandas conversion uses self_destruct + split_blocks + to free Arrow buffers during conversion. zstd is decompressed transparently by the reader. Use cases: - Data science pipelines (pandas/polars DataFrames) - ML feature stores (model training data caching) - Analytics queries (aggregations, filtering on cached DataFrames) - - Cold cache tier (5-10x compression for columnar data) - Production caching requiring integrity guarantees - Performance (10M rows, 50 columns): - - Serialize: ~100ms Arrow IPC (vs 5000ms MessagePack) - - Deserialize: ~0.1ms memory-map (vs 10000ms MessagePack unpacking) - - Network latency: ~5-10ms (Arrow IPC benefits dominate for large DataFrames) - - Zero-Copy Benefits: - - Memory-mapped deserialization (no CPU decoding, instant access) - - Columnar format enables filter/aggregate without full deserialization - - Cross-language compatibility (Python, R, Julia, Rust) - Limitations: - DataFrames only (pandas.DataFrame, polars.DataFrame, dict of arrays) - NO scalar values (int, str, float) @@ -112,7 +127,7 @@ class ArrowSerializer: # Apache Arrow IPC is a cross-language columnar format (Python, R, Julia, Rust) — safe under encryption. cross_sdk_compatible: ClassVar[bool] = True - def __init__(self, return_format: str = "pandas", enable_integrity_checking: bool = True): + def __init__(self, return_format: str = "pandas", enable_integrity_checking: bool = True, compression: str | None = "auto"): """Initialize ArrowSerializer. Args: @@ -120,17 +135,40 @@ def __init__(self, return_format: str = "pandas", enable_integrity_checking: boo - "pandas": Convert to pandas.DataFrame (default) - "polars": Convert to polars.DataFrame - "arrow": Return pyarrow.Table (zero-copy, no conversion) - enable_integrity_checking: Enable xxHash3-64 checksum validation (default: True) - When True: 8-byte checksum overhead + validation cost (integrity guarantee) - When False: No checksum (faster, use for @cache.minimal speed-first scenarios) + enable_integrity_checking: Retained for API compatibility. The 8-byte xxHash3-64 + checksum is now ALWAYS written and validated (silently returning corrupted + DataFrames is unacceptable, and 8 bytes is negligible), so this flag no longer + disables integrity. + compression: Arrow IPC compression codec. + - "auto" (default): use the CACHEKIT_ARROW_COMPRESSION setting (itself "zstd" by default) + - "zstd" / "lz4": compress the payload (smaller wire/L1; must be decompressed on read) + - None or "none": store uncompressed Arrow IPC, enabling zero-copy memory-mapped reads + (lowest read memory) at the cost of a larger payload Raises: - ValueError: If return_format is not one of the valid options + ValueError: If return_format or compression is not a valid option """ if return_format not in ("pandas", "polars", "arrow"): raise ValueError(f"Invalid return_format: '{return_format}'. Valid options: 'pandas', 'polars', 'arrow'") self.return_format = return_format self.enable_integrity_checking = enable_integrity_checking + self.compression = self._resolve_compression(compression) + + @staticmethod + def _resolve_compression(compression: str | None) -> str | None: + """Normalize/validate the compression option. 'auto' resolves from settings.""" + if compression == "auto": + try: + from cachekit.config.singleton import get_settings + + compression = get_settings().arrow_compression + except Exception: # noqa: BLE001 — settings unavailable: fall back to a sane default + compression = "zstd" + if compression in (None, "none"): + return None + if compression not in ("zstd", "lz4"): + raise ValueError(f"Invalid compression: {compression!r}. Valid options: 'auto', 'zstd', 'lz4', None ('none').") + return compression def serialize(self, obj: Any) -> tuple[bytes, SerializationMetadata]: # type: ignore[name-defined] """Serialize DataFrame to Arrow IPC format bytes with optional xxHash3-64 integrity protection. @@ -148,16 +186,29 @@ def serialize(self, obj: Any) -> tuple[bytes, SerializationMetadata]: # type: i SerializationError: If Arrow conversion fails """ try: - # Convert to Arrow Table (supports pandas, polars, dict of arrays) + # Convert to Arrow Table (supports pandas, polars, dict of arrays). + # preserve_index=None (pyarrow default): a RangeIndex is stored as cheap + # schema metadata (no materialized column / extra copy) and restored as a + # RangeIndex; named/MultiIndex are still preserved as columns. preserve_index=True + # would force even a RangeIndex into a materialized column. table = None if HAS_PANDAS and isinstance(obj, pd.DataFrame): - table = pa.Table.from_pandas(obj, preserve_index=True) - elif hasattr(obj, "__arrow_c_stream__"): # polars DataFrame - # Polars supports Arrow C Stream interface (zero-copy) + table = pa.Table.from_pandas(obj, preserve_index=None) + elif hasattr(obj, "__arrow_c_stream__"): # polars DataFrame (zero-copy C Stream) table = pa.table(obj) elif isinstance(obj, dict): - # dict of arrays (columnar format) - table = pa.table(obj) + # dict of arrays (columnar). Normalize pyarrow's raw conversion errors + # (e.g. dict-of-scalars -> "'int' object is not iterable") into the + # documented TypeError so callers get a consistent, actionable message. + try: + table = pa.table(obj) + except (pa.ArrowInvalid, pa.ArrowTypeError, TypeError, ValueError) as e: + raise TypeError( + f"ArrowSerializer only supports DataFrames " + f"(pandas.DataFrame, polars.DataFrame) or dict of arrays (columnar). " + f"Got a dict that is not convertible to an Arrow table: {e}. " + f"For scalar values or nested dicts, use AutoSerializer." + ) from e if table is None: raise TypeError( @@ -167,26 +218,36 @@ def serialize(self, obj: Any) -> tuple[bytes, SerializationMetadata]: # type: i f"For scalar values or nested dicts, use AutoSerializer." ) - # Serialize to Arrow IPC format (memory-mappable, streaming format) + # Serialize to Arrow IPC. Compression (when enabled) runs per record-batch, so + # writing in bounded batches keeps the compressor's working set bounded (one big + # batch makes the codec allocate a full-size working buffer — measured ~3.6x the + # payload). Size each batch to ~8 MiB regardless of schema width. compression=None + # writes uncompressed IPC, which a reader can memory-map zero-copy. + max_chunksize = _bounded_chunksize(table) sink = pa.BufferOutputStream() - with pa.ipc.new_file(sink, table.schema) as writer: - writer.write_table(table) - - arrow_data = sink.getvalue().to_pybytes() - - # Conditionally add integrity protection - if self.enable_integrity_checking: - # Compute xxHash3-64 checksum of original Arrow IPC data (8 bytes) - checksum = xxhash.xxh3_64_digest(arrow_data) - # Envelope format: [checksum][data] - envelope = checksum + arrow_data - else: - # No integrity checking - return raw Arrow IPC data - envelope = arrow_data + write_options = pa.ipc.IpcWriteOptions(compression=self.compression) if self.compression else None + with pa.ipc.new_file(sink, table.schema, options=write_options) as writer: + writer.write_table(table, max_chunksize=max_chunksize) + del table # free the Arrow table before materializing the IPC bytes (lowers peak) + + # Always integrity-protect: hash over the buffer's memoryview (no copy), then + # build the [8-byte xxHash3-64 checksum][compressed Arrow IPC] envelope. The + # checksum is unconditional — silently returning corrupted DataFrames is + # unacceptable, and 8 bytes is negligible against the payload. + buf = sink.getvalue() + checksum = xxhash.xxh3_64_digest(memoryview(buf)) + envelope = checksum + buf.to_pybytes() + + # For large payloads, return the compressor/buffer working memory the Arrow pool + # retained back to the OS so it does not stack under the caller's next allocation + # (the envelope wrap). No-op cost is trivial; gated to avoid churn on small objects. + if len(envelope) >= _RELEASE_POOL_THRESHOLD: + del buf + pa.default_memory_pool().release_unused() return envelope, SerializationMetadata( serialization_format=SerializationFormat.ARROW, - compressed=False, # Arrow IPC has optional compression (future enhancement) + compressed=self.compression is not None, # reflects the configured codec (None = uncompressed) encrypted=False, # Encryption is EncryptionWrapper's responsibility original_type="arrow", ) @@ -207,34 +268,38 @@ def deserialize(self, data: bytes, metadata: SerializationMetadata | None = None SerializationError: If data is malformed, Arrow deserialization fails, or checksum validation fails """ try: - if self.enable_integrity_checking: - # Guard clause: Minimum size check (8 bytes checksum + minimal Arrow IPC file) - if len(data) < 40: - raise SerializationError( - f"Invalid data: Expected at least 40 bytes (8-byte checksum + Arrow IPC header), got {len(data)} bytes" - ) - - # Extract checksum and Arrow IPC data - expected_checksum = data[:8] - arrow_data = data[8:] - - # Validate checksum - computed_checksum = xxhash.xxh3_64_digest(arrow_data) - if computed_checksum != expected_checksum: + # Detect the envelope by sniffing the Arrow IPC file magic (b"ARROW1") rather + # than trusting an integrity flag — this auto-handles checksummed, raw (legacy + # integrity-off), and version-mismatch data, and never feeds a checksum prefix + # into the IPC reader (which previously leaked a bare OSError). memoryview slicing + # avoids the full-body copy that `data[8:]` used to make. + mv = memoryview(data) + n = mv.nbytes + if n >= 14 and bytes(mv[8:14]) == b"ARROW1": + # [8-byte xxHash3-64 checksum][Arrow IPC] + expected_checksum = bytes(mv[:8]) + body = mv[8:] + if xxhash.xxh3_64_digest(body) != expected_checksum: raise SerializationError("Checksum validation failed - data corruption detected") - - # Zero-copy deserialization (memory-mapped) - reader = pa.ipc.open_file(pa.py_buffer(arrow_data)) - table = reader.read_all() + elif n >= 6 and bytes(mv[:6]) == b"ARROW1": + # Legacy raw Arrow IPC written without a checksum prefix (integrity-off entry) + body = mv else: - # No integrity checking - deserialize directly - # This handles both: data written with integrity=False AND backward compatible reads - reader = pa.ipc.open_file(pa.py_buffer(data)) - table = reader.read_all() + raise SerializationError( + f"Invalid data: not a recognized Arrow envelope " + f"(expected [8-byte checksum][Arrow IPC] or raw Arrow IPC); got {n} bytes" + ) + + # pa.py_buffer over the memoryview is zero-copy; open_file decompresses transparently. + reader = pa.ipc.open_file(pa.py_buffer(body)) + table = reader.read_all() # Convert to requested format if self.return_format == "pandas": - return table.to_pandas() + # self_destruct frees each Arrow column as it is converted (the table is a + # throwaway local here, so the experimental-invalidation caveat does not apply); + # split_blocks avoids the transient 2x of consolidated-block construction. + return table.to_pandas(self_destruct=True, split_blocks=True) elif self.return_format == "polars": # Import polars only if needed (avoid mandatory dependency) try: @@ -244,7 +309,7 @@ def deserialize(self, data: bytes, metadata: SerializationMetadata | None = None except ImportError as import_err: raise SerializationError("polars not installed. Install with: pip install polars") from import_err else: # return_format == "arrow" - return table # Zero-copy, no conversion + return table # zero-copy, no conversion - except (pa.ArrowInvalid, pa.ArrowSerializationError) as e: + except (pa.ArrowInvalid, pa.ArrowSerializationError, OSError) as e: raise SerializationError(f"Failed to deserialize Arrow IPC data: {e}") from e diff --git a/src/cachekit/serializers/wrapper.py b/src/cachekit/serializers/wrapper.py index 195e7d2..5f440b4 100644 --- a/src/cachekit/serializers/wrapper.py +++ b/src/cachekit/serializers/wrapper.py @@ -1,22 +1,46 @@ -"""Standard serialization wrapper for cache storage. - -This module provides utilities for wrapping and unwrapping data with metadata -for consistent serialization across all cache backends. +"""Cache-storage envelope for serialized data. + +Wraps serializer output with a small metadata header so cached bytes are +self-describing (serializer name + format flags) without deserializing. +Backend-agnostic: works with Redis, CachekitIO, Memcached, File, L1. + +Wire format (v3 binary frame) +----------------------------- + MAGIC b"CK" | VERSION u8 | HDR_LEN u32-BE | HEADER(json utf-8) | PAYLOAD(raw bytes) + HEADER = {"s": serializer_name, "m": metadata, "v": envelope_version} + +The payload (serializer output: MessagePack/Arrow IPC/ciphertext) is stored +**raw** — no base64, no JSON-embedding. This matters because the previous +base64-in-JSON envelope inflated every binary payload by 1.33x on the wire/in +L1 and forced ~4 full-size copies at peak (b64-bytes -> ascii-str -> json-str -> +utf8-bytes), which made large DataFrames OOM. The frame copies the payload once. + +Backward compatibility +----------------------- +`unwrap` reads BOTH formats: a v3 frame (starts with MAGIC b"CK") or the legacy +base64+JSON envelope (a JSON object, starts with b"{" or arrives as str). New +writes always emit the v3 frame; pre-existing cache entries remain readable, so +no cache flush is required (old entries age out by TTL). + +This envelope is Python-SDK-internal: backends store it as opaque bytes and the +cross-SDK wire format (ByteStorage MessagePack) is unaffected. """ +from __future__ import annotations + import base64 import json from typing import Union +# v3 binary frame constants +_MAGIC = b"CK" +_FRAME_VERSION = 3 +_HEADER_LEN_BYTES = 4 # u32 big-endian header length +_PREFIX_LEN = len(_MAGIC) + 1 + _HEADER_LEN_BYTES # magic(2) + version(1) + hdrlen(4) = 7 -class SerializationWrapper: - """Standard wrapper/unwrapper for cache serialization data. - - Wraps serialized bytes with JSON envelope containing metadata for cache storage. - The envelope format enables introspection of cached data without deserialization. - This wrapper is backend-agnostic and works with any cache backend (Redis, - CachekitIO, Memcached, etc.). +class SerializationWrapper: + """Frame/unframe serialized bytes with a metadata header for cache storage. Examples: Wrap and unwrap data: @@ -37,48 +61,70 @@ class SerializationWrapper: >>> serializer 'auto' - Works with string input (from cache backend): + Binary payloads (non-UTF-8) round-trip without base64: - >>> wrapped_str = wrapped.decode("utf-8") - >>> unwrapped_data, _, _ = SerializationWrapper.unwrap(wrapped_str) - >>> unwrapped_data == data + >>> raw = bytes(range(256)) + >>> out, _, _ = SerializationWrapper.unwrap(SerializationWrapper.wrap(raw, {}, "default")) + >>> out == raw True """ @staticmethod def wrap(data: bytes, metadata: dict, serializer_name: str, version: str = "2.0") -> bytes: - """Wrap serialized data with metadata envelope for cache storage. + """Frame serialized data with a metadata header for cache storage. Args: - data: Serialized bytes to wrap - metadata: Serialization metadata dict (must include "format" key) - serializer_name: Name of serializer used (e.g., "default", "auto") - version: Envelope format version + data: Serialized bytes to wrap (stored raw — no base64). + metadata: Serialization metadata dict (must include "format" key). + serializer_name: Name of serializer used (e.g., "default", "arrow"). + version: Logical serializer-envelope version (carried in the header for + downstream compatibility checks; distinct from the binary frame version). Returns: - JSON-encoded bytes containing base64 data and metadata + v3 binary frame bytes: MAGIC | VERSION | HDR_LEN | HEADER(json) | PAYLOAD(raw). """ - wrapper = { - "data": base64.b64encode(data).decode("ascii"), - "metadata": metadata, - "serializer": serializer_name, - "version": version, - } - return json.dumps(wrapper, ensure_ascii=False).encode("utf-8") + header = json.dumps( + {"s": serializer_name, "m": metadata, "v": version}, + ensure_ascii=False, + ).encode("utf-8") + # Single allocation; the payload is copied exactly once. + return b"".join( + ( + _MAGIC, + bytes((_FRAME_VERSION,)), + len(header).to_bytes(_HEADER_LEN_BYTES, "big"), + header, + data, + ) + ) @staticmethod def unwrap(wrapped_data: Union[str, bytes]) -> tuple[bytes, dict, str]: - """Unwrap data envelope from cache storage. + """Unwrap a cache envelope, reading either the v3 frame or the legacy format. Args: - wrapped_data: JSON envelope (bytes or string) from cache backend + wrapped_data: v3 frame (bytes starting with MAGIC) OR legacy base64+JSON + envelope (bytes/str starting with '{'). Returns: tuple: (data_bytes, metadata_dict, serializer_name) """ - if isinstance(wrapped_data, bytes): - wrapped_data = wrapped_data.decode("utf-8") - + # v3 binary frame: only bytes-like can be a frame (str is always legacy JSON). + if isinstance(wrapped_data, (bytes, bytearray, memoryview)): + mv = memoryview(wrapped_data) + if bytes(mv[: len(_MAGIC)]) == _MAGIC: + frame_version = mv[len(_MAGIC)] + if frame_version != _FRAME_VERSION: + raise ValueError(f"Unsupported cache envelope frame version {frame_version} (expected {_FRAME_VERSION})") + hdr_len = int.from_bytes(mv[len(_MAGIC) + 1 : _PREFIX_LEN], "big") + header_end = _PREFIX_LEN + hdr_len + header = json.loads(bytes(mv[_PREFIX_LEN:header_end])) + payload = bytes(mv[header_end:]) # single copy of the raw payload + return payload, header.get("m", {}), header.get("s", "unknown") + + # Legacy base64+JSON envelope (pre-v3 entries; backward compatible read path). + if isinstance(wrapped_data, (bytes, bytearray, memoryview)): + wrapped_data = bytes(wrapped_data).decode("utf-8") wrapper = json.loads(wrapped_data) data = base64.b64decode(wrapper["data"].encode("ascii")) metadata = wrapper.get("metadata", {}) diff --git a/tests/critical/test_memcached_backend_critical.py b/tests/critical/test_memcached_backend_critical.py index bc40806..b5d1445 100644 --- a/tests/critical/test_memcached_backend_critical.py +++ b/tests/critical/test_memcached_backend_critical.py @@ -44,7 +44,7 @@ def mock_hash_client(mock_store): with patch("pymemcache.client.hash.HashClient") as mock_cls: instance = MagicMock() - def _set(key, value, expire=0): + def _set(key, value, expire=0, noreply=True): mock_store[key] = value def _get(key): @@ -88,6 +88,31 @@ def test_get_set_delete_roundtrip(backend): assert backend.delete("key") is False # Already deleted +@pytest.mark.critical +def test_oversized_value_rejected_loudly(backend, mock_hash_client): + """A value over memcached's item-size cap must raise a clear PERMANENT BackendError, + NOT be silently dropped (the noreply=True default swallowed the server error).""" + big = b"\x00" * (1024 * 1024 + 1) # > 1 MB default memcached item cap + + with pytest.raises(BackendError) as exc_info: + backend.set("key", big, ttl=60) + + assert exc_info.value.error_type == BackendErrorType.PERMANENT + assert "too large" in str(exc_info.value).lower() or "exceeds" in str(exc_info.value).lower() + # Must reject before hitting the server (no silent store) + mock_hash_client.set.assert_not_called() + + +@pytest.mark.critical +def test_set_uses_noreply_false_so_server_errors_surface(backend, mock_hash_client): + """set() must pass noreply=False so memcached's error reply is read, not swallowed.""" + backend.set("key", b"value", ttl=60) + + mock_hash_client.set.assert_called_once() + _, kwargs = mock_hash_client.set.call_args + assert kwargs.get("noreply") is False + + @pytest.mark.critical def test_exists_accurate(backend): """exists() returns correct True/False status.""" @@ -162,7 +187,7 @@ def test_ttl_clamped_to_30_day_max(mock_hash_client): backend = MemcachedBackend(MemcachedBackendConfig()) huge_ttl = MAX_MEMCACHED_TTL + 86400 # 31 days backend.set("key", b"val", ttl=huge_ttl) - mock_hash_client.set.assert_called_once_with("key", b"val", expire=MAX_MEMCACHED_TTL) + mock_hash_client.set.assert_called_once_with("key", b"val", expire=MAX_MEMCACHED_TTL, noreply=False) @pytest.mark.critical @@ -171,10 +196,10 @@ def test_ttl_none_and_zero_mean_no_expiry(mock_hash_client): backend = MemcachedBackend(MemcachedBackendConfig()) backend.set("k1", b"v1", ttl=None) - mock_hash_client.set.assert_called_with("k1", b"v1", expire=0) + mock_hash_client.set.assert_called_with("k1", b"v1", expire=0, noreply=False) backend.set("k2", b"v2", ttl=0) - mock_hash_client.set.assert_called_with("k2", b"v2", expire=0) + mock_hash_client.set.assert_called_with("k2", b"v2", expire=0, noreply=False) @pytest.mark.critical @@ -268,7 +293,7 @@ def test_intent_decorators_with_memcached_backend(mock_store): with patch("pymemcache.client.hash.HashClient") as mock_cls: instance = MagicMock() - instance.set.side_effect = lambda k, v, expire=0: mock_store.__setitem__(k, v) + instance.set.side_effect = lambda k, v, expire=0, noreply=True: mock_store.__setitem__(k, v) instance.get.side_effect = lambda k: mock_store.get(k) instance.delete.side_effect = lambda k, noreply=True: mock_store.pop(k, None) is not None mock_cls.return_value = instance @@ -296,7 +321,7 @@ def test_set_default_backend_with_memcached_backend(mock_store): with patch("pymemcache.client.hash.HashClient") as mock_cls: instance = MagicMock() - instance.set.side_effect = lambda k, v, expire=0: mock_store.__setitem__(k, v) + instance.set.side_effect = lambda k, v, expire=0, noreply=True: mock_store.__setitem__(k, v) instance.get.side_effect = lambda k: mock_store.get(k) instance.delete.side_effect = lambda k, noreply=True: mock_store.pop(k, None) is not None mock_cls.return_value = instance diff --git a/tests/performance/test_large_object_memory.py b/tests/performance/test_large_object_memory.py new file mode 100644 index 0000000..ce5c3a0 --- /dev/null +++ b/tests/performance/test_large_object_memory.py @@ -0,0 +1,97 @@ +"""Memory regression guards for large-object (Arrow/DataFrame) caching. + +These lock in the fixes that removed the base64+JSON wrapper inflation and the +Arrow serializer's copy chain. They assert DETERMINISTIC, environment-independent +metrics (Python-tracked peak via tracemalloc + on-wire size), not process RSS — so +they are stable in CI yet fail loudly if the regressions return: + +- base64+JSON wrap drove store tracemalloc peak to ~5.7x logical and the wire to 1.33x. +- the read path's base64-decode + JSON-parse + full-body slice drove read peak to ~5.4x. + +Pre-fix these assertions fail; post-fix store peak is ~2x, read ~1.1x, wire ~1x. +""" + +from __future__ import annotations + +import gc +import tracemalloc + +import numpy as np +import pandas as pd +import pytest + +from cachekit.serializers.arrow_serializer import ArrowSerializer +from cachekit.serializers.base import SerializationMetadata +from cachekit.serializers.wrapper import SerializationWrapper + +_MB = 1024 * 1024 + + +def _numeric_df(mb: int) -> pd.DataFrame: + """Incompressible float64 frame ~mb MiB (worst case for compression).""" + cols = 20 + rows = mb * _MB // (8 * cols) + rng = np.random.default_rng(0) + return pd.DataFrame({f"c{i}": rng.standard_normal(rows) for i in range(cols)}) + + +def _logical(df: pd.DataFrame) -> int: + return int(df.memory_usage(deep=True, index=True).sum()) + + +@pytest.mark.slow +@pytest.mark.performance +def test_store_path_python_allocations_bounded(): + df = _numeric_df(50) + logical = _logical(df) + serializer = ArrowSerializer() + + gc.collect() + tracemalloc.start() + data, meta = serializer.serialize(df) # df allocated before start() -> not counted + wrapped = SerializationWrapper.wrap(data, meta.to_dict(), "arrow") + peak = tracemalloc.get_traced_memory()[1] + tracemalloc.stop() + + # base64+JSON wrap drove this to ~5.7x; binary frame + zero-copy hashing keeps it ~2x. + assert peak / logical < 3.0, f"store tracemalloc peak {peak / logical:.2f}x logical (regressed?)" + # base64 inflated the wire 1.33x; raw binary frame is ~1x (zstd only shrinks). + assert len(wrapped) / logical < 1.1, f"wire size {len(wrapped) / logical:.2f}x logical (base64 back?)" + + +@pytest.mark.slow +@pytest.mark.performance +def test_load_path_python_allocations_bounded(): + df = _numeric_df(50) + logical = _logical(df) + serializer = ArrowSerializer() + data, meta = serializer.serialize(df) + wrapped = SerializationWrapper.wrap(data, meta.to_dict(), "arrow") + raw, md, _ = SerializationWrapper.unwrap(wrapped) + meta2 = SerializationMetadata.from_dict(md) + + gc.collect() + tracemalloc.start() # wrapped/raw allocated before start() -> not counted + out = serializer.deserialize(raw, meta2) + peak = tracemalloc.get_traced_memory()[1] + tracemalloc.stop() + + assert len(out) == len(df) + pd.testing.assert_frame_equal(out, df) + # base64-decode + JSON-parse + data[8:] slice drove read peak to ~5.4x; now ~1.1x. + assert peak / logical < 2.5, f"load tracemalloc peak {peak / logical:.2f}x logical (regressed?)" + + +@pytest.mark.slow +@pytest.mark.performance +def test_full_roundtrip_through_cache_handler_is_correct_and_compact(): + """End-to-end through the real serialize_data/deserialize_data envelope path.""" + from cachekit.cache_handler import CacheSerializationHandler + + df = _numeric_df(20) + handler = CacheSerializationHandler(serializer_name="arrow") + blob = handler.serialize_data(df, cache_key="k") + assert blob[:2] == b"CK" # new binary frame, not legacy JSON + assert len(blob) / _logical(df) < 1.1 + out = handler.deserialize_data(blob, cache_key="k") + pd.testing.assert_frame_equal(out, df) diff --git a/tests/unit/backends/test_memcached_backend.py b/tests/unit/backends/test_memcached_backend.py index 43a4cd7..b22d0da 100644 --- a/tests/unit/backends/test_memcached_backend.py +++ b/tests/unit/backends/test_memcached_backend.py @@ -80,7 +80,7 @@ def test_get_returns_none_for_missing_key(self, backend: MemcachedBackend, mock_ def test_set_stores_value(self, backend: MemcachedBackend, mock_hash_client: MagicMock) -> None: """Test set calls client.set with correct arguments.""" backend.set("mykey", b"myvalue", ttl=60) - mock_hash_client.set.assert_called_once_with("mykey", b"myvalue", expire=60) + mock_hash_client.set.assert_called_once_with("mykey", b"myvalue", expire=60, noreply=False) def test_delete_returns_true_when_key_exists(self, backend: MemcachedBackend, mock_hash_client: MagicMock) -> None: """Test delete returns True when key existed.""" @@ -156,33 +156,33 @@ class TestTTLBehavior: def test_ttl_none_passes_expire_zero(self, backend: MemcachedBackend, mock_hash_client: MagicMock) -> None: """Test ttl=None passes expire=0 (no expiry).""" backend.set("key", b"val", ttl=None) - mock_hash_client.set.assert_called_once_with("key", b"val", expire=0) + mock_hash_client.set.assert_called_once_with("key", b"val", expire=0, noreply=False) def test_ttl_zero_passes_expire_zero(self, backend: MemcachedBackend, mock_hash_client: MagicMock) -> None: """Test ttl=0 passes expire=0 (no expiry).""" backend.set("key", b"val", ttl=0) - mock_hash_client.set.assert_called_once_with("key", b"val", expire=0) + mock_hash_client.set.assert_called_once_with("key", b"val", expire=0, noreply=False) def test_ttl_positive_passes_expire(self, backend: MemcachedBackend, mock_hash_client: MagicMock) -> None: """Test ttl=100 passes expire=100.""" backend.set("key", b"val", ttl=100) - mock_hash_client.set.assert_called_once_with("key", b"val", expire=100) + mock_hash_client.set.assert_called_once_with("key", b"val", expire=100, noreply=False) def test_ttl_exceeding_30_days_gets_clamped(self, backend: MemcachedBackend, mock_hash_client: MagicMock) -> None: """Test TTL > 30 days gets clamped to MAX_MEMCACHED_TTL (2592000).""" huge_ttl = MAX_MEMCACHED_TTL + 1000 backend.set("key", b"val", ttl=huge_ttl) - mock_hash_client.set.assert_called_once_with("key", b"val", expire=MAX_MEMCACHED_TTL) + mock_hash_client.set.assert_called_once_with("key", b"val", expire=MAX_MEMCACHED_TTL, noreply=False) def test_ttl_exactly_30_days_not_clamped(self, backend: MemcachedBackend, mock_hash_client: MagicMock) -> None: """Test TTL exactly at 30-day max passes through unchanged.""" backend.set("key", b"val", ttl=MAX_MEMCACHED_TTL) - mock_hash_client.set.assert_called_once_with("key", b"val", expire=MAX_MEMCACHED_TTL) + mock_hash_client.set.assert_called_once_with("key", b"val", expire=MAX_MEMCACHED_TTL, noreply=False) def test_negative_ttl_passes_expire_zero(self, backend: MemcachedBackend, mock_hash_client: MagicMock) -> None: """Test negative TTL is treated as no expiry.""" backend.set("key", b"val", ttl=-5) - mock_hash_client.set.assert_called_once_with("key", b"val", expire=0) + mock_hash_client.set.assert_called_once_with("key", b"val", expire=0, noreply=False) @pytest.mark.unit @@ -208,7 +208,7 @@ def test_get_applies_prefix(self, prefixed_backend: MemcachedBackend, mock_hash_ def test_set_applies_prefix(self, prefixed_backend: MemcachedBackend, mock_hash_client: MagicMock) -> None: """Test set prepends prefix to key.""" prefixed_backend.set("mykey", b"val", ttl=60) - mock_hash_client.set.assert_called_once_with("app:mykey", b"val", expire=60) + mock_hash_client.set.assert_called_once_with("app:mykey", b"val", expire=60, noreply=False) def test_delete_applies_prefix(self, prefixed_backend: MemcachedBackend, mock_hash_client: MagicMock) -> None: """Test delete prepends prefix to key.""" diff --git a/tests/unit/test_arrow_serializer.py b/tests/unit/test_arrow_serializer.py index 14ae57c..9e065f2 100644 --- a/tests/unit/test_arrow_serializer.py +++ b/tests/unit/test_arrow_serializer.py @@ -288,14 +288,14 @@ def test_metadata_format_is_arrow_enum(self): assert isinstance(metadata, SerializationMetadata) assert metadata.format == SerializationFormat.ARROW - def test_metadata_compressed_false(self): - """Arrow IPC metadata marks compressed=False (Arrow has optional compression).""" + def test_metadata_compressed_true(self): + """Arrow IPC metadata marks compressed=True (zstd IPC compression is default-on).""" serializer = ArrowSerializer() df = pd.DataFrame({"a": [1, 2, 3]}) _, metadata = serializer.serialize(df) - assert metadata.compressed is False + assert metadata.compressed is True def test_metadata_encrypted_false(self): """Encryption is EncryptionWrapper's responsibility, not ArrowSerializer.""" @@ -377,6 +377,219 @@ def test_polars_serialization_requires_arrow_c_stream(self): pass +class TestCompression: + """Arrow IPC zstd compression (default-on).""" + + def test_compression_shrinks_compressible_payload(self): + """Highly compressible data serializes far smaller than its logical size.""" + serializer = ArrowSerializer() + df = pd.DataFrame({"a": [1] * 100_000, "b": ["constant"] * 100_000}) + logical = int(df.memory_usage(deep=True, index=True).sum()) + + data, _ = serializer.serialize(df) + + # zstd on near-constant columns should compress >5x + assert len(data) < logical // 5 + + def test_compressed_data_round_trips(self): + serializer = ArrowSerializer() + df = pd.DataFrame({"x": range(5000), "y": [f"s{i % 7}" for i in range(5000)]}) + + data, meta = serializer.serialize(df) + result = serializer.deserialize(data, meta) + + pd.testing.assert_frame_equal(result, df) + + +class TestConfigurableCompression: + """Arrow IPC compression is configurable: zstd/lz4 for small payloads, or None + (uncompressed) to enable zero-copy memory-mapped reads. Default resolves from + CACHEKIT_ARROW_COMPRESSION via compression='auto'.""" + + def test_compression_none_is_uncompressed_and_round_trips(self): + df = pd.DataFrame({"a": [1] * 100_000, "b": ["constant"] * 100_000}) + raw, meta_raw = ArrowSerializer(compression=None).serialize(df) + comp, _ = ArrowSerializer(compression="zstd").serialize(df) + + assert meta_raw.compressed is False + assert len(raw) > len(comp) # uncompressed is larger on compressible data + pd.testing.assert_frame_equal(ArrowSerializer(compression=None).deserialize(raw, meta_raw), df) + + def test_compression_none_string_normalizes(self): + _, meta = ArrowSerializer(compression="none").serialize(pd.DataFrame({"a": [1, 2, 3]})) + assert meta.compressed is False + + def test_compression_lz4_round_trips(self): + df = pd.DataFrame({"x": list(range(5000)), "y": [f"s{i % 7}" for i in range(5000)]}) + data, meta = ArrowSerializer(compression="lz4").serialize(df) + assert meta.compressed is True + pd.testing.assert_frame_equal(ArrowSerializer().deserialize(data, meta), df) + + def test_invalid_compression_raises(self): + with pytest.raises(ValueError): + ArrowSerializer(compression="gzip") + + def test_auto_resolves_from_settings_env(self, monkeypatch): + from cachekit.config.singleton import reset_settings + + monkeypatch.setenv("CACHEKIT_ARROW_COMPRESSION", "none") + reset_settings() + try: + _, meta = ArrowSerializer(compression="auto").serialize(pd.DataFrame({"a": [1, 2, 3]})) + assert meta.compressed is False + finally: + reset_settings() + + def test_default_is_auto_zstd(self): + from cachekit.config.singleton import reset_settings + + reset_settings() # no env override -> default zstd + _, meta = ArrowSerializer().serialize(pd.DataFrame({"a": [1] * 1000})) + assert meta.compressed is True + + +class TestIntegrityAlwaysOn: + """DATA IS SACRED: corruption is always detected, even with integrity_checking=False.""" + + def test_corruption_detected_with_integrity_on(self): + serializer = ArrowSerializer(enable_integrity_checking=True) + df = pd.DataFrame({"a": list(range(100))}) + data, meta = serializer.serialize(df) + + corrupted = bytearray(data) + corrupted[30] ^= 0xFF # flip a byte in the body + + with pytest.raises(SerializationError): + serializer.deserialize(bytes(corrupted), meta) + + def test_corruption_detected_even_when_integrity_off(self): + """integrity_checking=False must STILL checksum (silent-corruption window closed).""" + serializer = ArrowSerializer(enable_integrity_checking=False) + df = pd.DataFrame({"a": list(range(100))}) + data, meta = serializer.serialize(df) + + corrupted = bytearray(data) + corrupted[30] ^= 0xFF + + with pytest.raises(SerializationError): + serializer.deserialize(bytes(corrupted), meta) + + +class TestBackwardCompatArrow: + """Legacy entries (pre-change formats) must still deserialize.""" + + def test_reads_legacy_raw_ipc_without_checksum(self): + """A legacy integrity-off entry is raw Arrow IPC (no 8-byte checksum prefix).""" + df = pd.DataFrame({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]}) + table = pa.Table.from_pandas(df, preserve_index=None) + sink = pa.BufferOutputStream() + with pa.ipc.new_file(sink, table.schema) as writer: # no compression, no checksum + writer.write_table(table) + legacy_raw = sink.getvalue().to_pybytes() + assert legacy_raw[:6] == b"ARROW1" # raw IPC magic at offset 0 + + result = ArrowSerializer().deserialize(legacy_raw, None) + pd.testing.assert_frame_equal(result, df) + + def test_reads_legacy_checksum_prefixed_ipc(self): + """A legacy integrity-on entry is [8-byte xxhash][raw uncompressed IPC].""" + import xxhash + + df = pd.DataFrame({"a": [1, 2, 3]}) + table = pa.Table.from_pandas(df, preserve_index=None) + sink = pa.BufferOutputStream() + with pa.ipc.new_file(sink, table.schema) as writer: + writer.write_table(table) + raw = sink.getvalue().to_pybytes() + legacy = xxhash.xxh3_64_digest(raw) + raw + assert legacy[8:14] == b"ARROW1" + + result = ArrowSerializer().deserialize(legacy, None) + pd.testing.assert_frame_equal(result, df) + + +class TestExceptionHygiene: + """No raw pyarrow exceptions leak; the documented contract holds.""" + + def test_dict_of_scalars_raises_documented_type_error(self): + serializer = ArrowSerializer() + with pytest.raises(TypeError) as exc_info: + serializer.serialize({"scalar": 123}) + assert "ArrowSerializer only supports DataFrames" in str(exc_info.value) + + def test_malformed_checksummed_input_raises_serialization_error_not_oserror(self): + """Wrong/garbage bytes must surface as SerializationError, never a bare OSError.""" + serializer = ArrowSerializer() + # 8-byte 'checksum' + an ARROW1-looking but invalid body + bad = b"\x00" * 8 + b"ARROW1\x00\x00" + b"\x00" * 64 + with pytest.raises(SerializationError): + serializer.deserialize(bad, None) + + +class TestRangeIndexRoundTrip: + """preserve_index=None: RangeIndex restored as RangeIndex (not materialized column).""" + + def test_default_range_index_round_trips(self): + serializer = ArrowSerializer() + df = pd.DataFrame({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]}) + data, meta = serializer.serialize(df) + result = serializer.deserialize(data, meta) + pd.testing.assert_frame_equal(result, df) + + def test_arrow_table_has_no_synthetic_index_column(self): + serializer = ArrowSerializer(return_format="arrow") + df = pd.DataFrame({"a": [1, 2, 3]}) + data, _ = serializer.serialize(df) + table = serializer.deserialize(data) + assert "__index_level_0__" not in table.column_names + + +class TestDtypeAndIndexFidelity: + """Round-trip fidelity across dtypes/indexes the audit flagged as fragile. + + Guards that zstd + preserve_index=None + to_pandas(self_destruct, split_blocks) + do not regress correctness for the realistic data-science payloads this serializer targets. + """ + + @pytest.mark.parametrize( + "name,df", + [ + ("nullable_int", pd.DataFrame({"a": pd.array([1, None, 3], dtype="Int64")})), + ("nullable_bool", pd.DataFrame({"a": pd.array([True, None, False], dtype="boolean")})), + ("categorical_unordered", pd.DataFrame({"c": pd.Categorical(["x", "y", "x", "z"])})), + ( + "categorical_ordered", + pd.DataFrame({"c": pd.Categorical(["lo", "hi", "lo"], categories=["lo", "hi"], ordered=True)}), + ), + ("datetime_ns", pd.DataFrame({"t": pd.date_range("2020-01-01", periods=5, freq="s")})), + ("datetime_tz", pd.DataFrame({"t": pd.date_range("2020-01-01", periods=5, freq="h", tz="America/New_York")})), + ("timedelta", pd.DataFrame({"d": pd.to_timedelta([1, 2, 3], unit="s")})), + ("float_with_nan", pd.DataFrame({"a": [1.0, float("nan"), 3.0]})), + ("single_row", pd.DataFrame({"a": [1], "b": ["x"]})), + ], + ) + def test_dtype_round_trip(self, name, df): + serializer = ArrowSerializer() + data, meta = serializer.serialize(df) + result = serializer.deserialize(data, meta) + pd.testing.assert_frame_equal(result, df) + + def test_named_index_round_trips_as_index(self): + serializer = ArrowSerializer() + df = pd.DataFrame({"v": [10, 20, 30]}, index=pd.Index(["a", "b", "c"], name="key")) + data, meta = serializer.serialize(df) + result = serializer.deserialize(data, meta) + pd.testing.assert_frame_equal(result, df) + + def test_multiindex_round_trips(self): + serializer = ArrowSerializer() + idx = pd.MultiIndex.from_tuples([("a", 1), ("a", 2), ("b", 1)], names=["g", "n"]) + df = pd.DataFrame({"v": [1.0, 2.0, 3.0]}, index=idx) + data, meta = serializer.serialize(df) + result = serializer.deserialize(data, meta) + pd.testing.assert_frame_equal(result, df) + + class TestImportGuard: """Test module-level import guard for pyarrow dependency.""" diff --git a/tests/unit/test_l1_memory_bounds.py b/tests/unit/test_l1_memory_bounds.py new file mode 100644 index 0000000..1725ea2 --- /dev/null +++ b/tests/unit/test_l1_memory_bounds.py @@ -0,0 +1,53 @@ +"""L1 memory-bound guarantees, especially the oversized-single-entry vector. + +A cached value larger than the entire L1 budget must NOT be stored (it would push L1 +permanently over its own limit and, for multi-GB DataFrame envelopes, become an OOM +vector that also evicts every other useful entry). Such values still live in L2. +""" + +from __future__ import annotations + +import pytest + +from cachekit.l1_cache import L1Cache + +MB = 1024 * 1024 + + +@pytest.mark.unit +class TestOversizedEntryRejection: + def test_entry_larger_than_budget_is_not_stored(self): + cache = L1Cache(max_memory_mb=1) + cache.put("big", b"\x00" * (2 * MB), redis_ttl=300) + + found, _ = cache.get("big") + assert found is False + assert cache._current_memory_bytes == 0 + + def test_rejected_oversized_put_does_not_evict_existing_entries(self): + """A doomed oversized put must not evict good entries on its way to failing.""" + cache = L1Cache(max_memory_mb=1) + cache.put("keep", b"\x00" * (512 * 1024), redis_ttl=300) # fits + + cache.put("toobig", b"\x00" * (5 * MB), redis_ttl=300) # cannot ever fit + + assert cache.get("keep")[0] is True # survivor + assert cache.get("toobig")[0] is False + assert cache._current_memory_bytes <= cache.max_memory_bytes + + def test_entry_equal_to_budget_is_stored(self): + cache = L1Cache(max_memory_mb=1) + cache.put("exact", b"\x00" * (1 * MB), redis_ttl=300) + assert cache.get("exact")[0] is True + + def test_normal_entry_still_stored(self): + cache = L1Cache(max_memory_mb=10) + cache.put("k", b"value", redis_ttl=300) + assert cache.get("k") == (True, b"value") + + def test_memory_never_exceeds_budget_under_mixed_load(self): + cache = L1Cache(max_memory_mb=2) + for i in range(20): + cache.put(f"k{i}", b"\x00" * (300 * 1024), redis_ttl=300) # 300KB each + cache.put("huge", b"\x00" * (50 * MB), redis_ttl=300) # rejected + assert cache._current_memory_bytes <= cache.max_memory_bytes diff --git a/tests/unit/test_serialization_wrapper.py b/tests/unit/test_serialization_wrapper.py new file mode 100644 index 0000000..b1ae130 --- /dev/null +++ b/tests/unit/test_serialization_wrapper.py @@ -0,0 +1,162 @@ +"""Unit tests for SerializationWrapper binary-frame envelope. + +The wrapper frames serializer output for cache storage. It MUST: +- avoid base64 (which inflated binary payloads 1.33x and forced ~4 full copies), +- round-trip arbitrary binary payloads (including non-UTF-8 bytes), +- remain backward-compatible on read with the legacy base64+JSON envelope, +so already-stored cache entries stay readable across the upgrade. +""" + +from __future__ import annotations + +import base64 +import json + +import pytest + +from cachekit.serializers.base import SerializationError +from cachekit.serializers.wrapper import SerializationWrapper + +PAYLOAD = b"\x00\x01\xff\xfe\x00ARROW1\x00\x00binary-not-text\x80\x81" +META = {"format": "arrow", "compressed": True, "original_type": "arrow"} + + +class TestBinaryFrame: + def test_roundtrip_returns_payload_metadata_serializer(self): + wrapped = SerializationWrapper.wrap(PAYLOAD, META, "arrow") + data, meta, name = SerializationWrapper.unwrap(wrapped) + assert data == PAYLOAD + assert meta == META + assert name == "arrow" + + def test_output_is_bytes(self): + assert isinstance(SerializationWrapper.wrap(PAYLOAD, META, "arrow"), bytes) + + def test_payload_is_not_base64_encoded(self): + """The raw payload bytes must appear verbatim in the frame (no base64).""" + wrapped = SerializationWrapper.wrap(PAYLOAD, META, "default") + assert PAYLOAD in wrapped + # base64 of the payload must NOT be present (proves we dropped base64) + assert base64.b64encode(PAYLOAD) not in wrapped + + def test_no_size_inflation(self): + """Frame overhead is a small fixed header, not base64's 1.33x.""" + big = b"\x07" * 1_000_000 + wrapped = SerializationWrapper.wrap(big, META, "arrow") + # < 1KB of framing overhead; nowhere near base64's +333KB + assert len(wrapped) - len(big) < 1024 + + def test_non_utf8_payload_roundtrips(self): + """Binary payloads that are not valid UTF-8 must survive unwrap (no decode).""" + evil = bytes(range(256)) * 10 + data, _, _ = SerializationWrapper.unwrap(SerializationWrapper.wrap(evil, {}, "default")) + assert data == evil + + def test_empty_payload_roundtrips(self): + data, meta, name = SerializationWrapper.unwrap(SerializationWrapper.wrap(b"", {"format": "msgpack"}, "default")) + assert data == b"" + assert name == "default" + + def test_metadata_with_encryption_fields_roundtrips(self): + enc_meta = {"format": "msgpack", "encrypted": True, "tenant_id": "acme", "key_fingerprint": "abc123"} + _, meta, _ = SerializationWrapper.unwrap(SerializationWrapper.wrap(b"cipher", enc_meta, "default")) + assert meta == enc_meta + + +class TestLegacyBackwardCompat: + """Old base64+JSON entries (written before this change) must still deserialize.""" + + @staticmethod + def _legacy_wrap(data: bytes, metadata: dict, serializer_name: str, version: str = "2.0") -> bytes: + wrapper = { + "data": base64.b64encode(data).decode("ascii"), + "metadata": metadata, + "serializer": serializer_name, + "version": version, + } + return json.dumps(wrapper, ensure_ascii=False).encode("utf-8") + + def test_unwrap_reads_legacy_bytes_envelope(self): + legacy = self._legacy_wrap(PAYLOAD, META, "arrow") + data, meta, name = SerializationWrapper.unwrap(legacy) + assert data == PAYLOAD + assert meta == META + assert name == "arrow" + + def test_unwrap_reads_legacy_str_envelope(self): + """Some backends hand back str; legacy JSON must still decode from str.""" + legacy = self._legacy_wrap(PAYLOAD, META, "arrow").decode("utf-8") + data, _, name = SerializationWrapper.unwrap(legacy) + assert data == PAYLOAD + assert name == "arrow" + + def test_new_and_legacy_are_distinguishable(self): + """New frame starts with magic; legacy JSON starts with '{'. Sniffing is unambiguous.""" + new = SerializationWrapper.wrap(PAYLOAD, META, "arrow") + legacy = self._legacy_wrap(PAYLOAD, META, "arrow") + assert new[:1] != b"{" + assert legacy[:1] == b"{" + + +class TestUnwrapRejectsGarbage: + def test_unrecognized_envelope_raises(self): + with pytest.raises((ValueError, Exception)): + SerializationWrapper.unwrap(b"\x99\x98 not a frame and not json") + + +class TestEncryptionThroughFrame: + """The binary frame is on the hot path for @cache.secure too: encrypted payloads and + their encryption metadata must survive the frame, AAD binding must still hold, and old + base64+JSON encrypted entries must still decrypt. (Regression for the wrapper rewrite.)""" + + KEY = "user:42:credentials" + + @pytest.fixture + def enc_handler(self): + import os + + from cachekit.config.singleton import reset_settings + + reset_settings() + os.environ["CACHEKIT_MASTER_KEY"] = "a" * 64 + from cachekit.cache_handler import CacheSerializationHandler + + handler = CacheSerializationHandler( + serializer_name="default", + encryption=True, + single_tenant_mode=True, + deployment_uuid="00000000-0000-0000-0000-000000000001", + ) + yield handler + reset_settings() + os.environ.pop("CACHEKIT_MASTER_KEY", None) + + def test_encrypted_payload_round_trips_through_frame(self, enc_handler): + secret = {"ssn": "123-45-6789", "balance": 99999} + blob = enc_handler.serialize_data(secret, cache_key=self.KEY) + assert blob[:2] == b"CK" # new binary frame + assert b"123-45-6789" not in blob # plaintext never present + assert enc_handler.deserialize_data(blob, cache_key=self.KEY) == secret + + def test_encryption_metadata_survives_frame_header(self, enc_handler): + blob = enc_handler.serialize_data({"k": "v"}, cache_key=self.KEY) + _, meta, _ = SerializationWrapper.unwrap(blob) + assert meta["encrypted"] is True + assert meta["tenant_id"] + assert meta["encryption_algorithm"] == "AES-256-GCM" + + def test_wrong_cache_key_is_rejected(self, enc_handler): + """AAD binding: ciphertext is bound to the cache key; a mismatched key must not decrypt.""" + blob = enc_handler.serialize_data({"k": "v"}, cache_key=self.KEY) + # EncryptionError subclasses SerializationError; AAD mismatch must raise, never silently succeed. + with pytest.raises(SerializationError): + enc_handler.deserialize_data(blob, cache_key="WRONG:key") + + def test_legacy_base64_json_encrypted_entry_still_decrypts(self, enc_handler): + """A pre-upgrade encrypted entry (base64+JSON envelope) must remain readable.""" + new_blob = enc_handler.serialize_data({"old": "secret"}, cache_key=self.KEY) + inner, meta, name = SerializationWrapper.unwrap(new_blob) + legacy = json.dumps( + {"data": base64.b64encode(inner).decode("ascii"), "metadata": meta, "serializer": name, "version": "2.0"} + ).encode("utf-8") + assert enc_handler.deserialize_data(legacy, cache_key=self.KEY) == {"old": "secret"} diff --git a/tests/unit/test_serializer_integrity.py b/tests/unit/test_serializer_integrity.py index 533c976..232eff3 100644 --- a/tests/unit/test_serializer_integrity.py +++ b/tests/unit/test_serializer_integrity.py @@ -232,7 +232,7 @@ def test_empty_data_raises_error(self): serializer.deserialize(b"") assert "Invalid data" in str(exc_info.value) - assert "Expected at least 40 bytes" in str(exc_info.value) + assert "Arrow envelope" in str(exc_info.value) def test_too_short_data_raises_error(self): """Data shorter than minimum raises error.""" @@ -245,7 +245,7 @@ def test_too_short_data_raises_error(self): serializer.deserialize(invalid_data) assert "Invalid data" in str(exc_info.value) - assert "Expected at least 40 bytes" in str(exc_info.value) + assert "Arrow envelope" in str(exc_info.value) def test_bit_flip_in_dataframe_detected(self): """Single bit flip in DataFrame data is detected.""" diff --git a/tests/unit/test_xxhash_integrity.py b/tests/unit/test_xxhash_integrity.py index 80c6609..30ffae1 100644 --- a/tests/unit/test_xxhash_integrity.py +++ b/tests/unit/test_xxhash_integrity.py @@ -132,18 +132,17 @@ def test_checksum_is_8_bytes(self): assert len(data) >= 40, f"Expected at least 40 bytes, got {len(data)}" def test_checksum_overhead_is_8_bytes(self): - """Verify checksum overhead is exactly 8 bytes (xxHash3-64, not 32-byte Blake3).""" - # Create two serializers: one with integrity, one without - serializer_with = ArrowSerializer(enable_integrity_checking=True) - serializer_without = ArrowSerializer(enable_integrity_checking=False) - + """The 8-byte xxHash3-64 checksum is ALWAYS prepended, regardless of the + enable_integrity_checking flag. Silently returning corrupted DataFrames is + unacceptable (DATA IS SACRED), and 8 bytes is negligible vs the payload, so the + speed-first 'no checksum' path was removed. Format: [8-byte checksum][Arrow IPC].""" df = pd.DataFrame({"col": range(1000)}) - data_with, _ = serializer_with.serialize(df) - data_without, _ = serializer_without.serialize(df) - - overhead = len(data_with) - len(data_without) - assert overhead == 8, f"Expected 8-byte overhead (xxHash3-64), got {overhead} bytes" + for integrity in (True, False): + data, _ = ArrowSerializer(enable_integrity_checking=integrity).serialize(df) + # The Arrow IPC file magic 'ARROW1' sits at offset 8, proving exactly an + # 8-byte checksum prefix precedes the payload in both modes. + assert data[8:14] == b"ARROW1", f"8-byte checksum prefix missing (integrity={integrity})" def test_minimum_size_check_is_40_bytes(self): """Deserialize should require at least 40 bytes (8-byte checksum + 32-byte Arrow header).""" @@ -154,7 +153,7 @@ def test_minimum_size_check_is_40_bytes(self): serializer.deserialize(b"X" * 39) assert "Invalid data" in str(exc_info.value) - assert "40 bytes" in str(exc_info.value) + assert "Arrow envelope" in str(exc_info.value) def test_roundtrip_with_xxhash_checksum(self): """Normal DataFrame serialize/deserialize roundtrip works with xxHash3-64 checksum."""