diff --git a/src/cachekit/backends/redis/backend.py b/src/cachekit/backends/redis/backend.py index b11b9f9..8367060 100644 --- a/src/cachekit/backends/redis/backend.py +++ b/src/cachekit/backends/redis/backend.py @@ -107,15 +107,11 @@ def get(self, key: str) -> Optional[bytes]: try: client = self._get_client() value = client.get(key) - # Redis client with decode_responses=True returns str, need bytes - # But get() with binary data returns bytes if decode fails - # For safety, encode if we got str - if value is not None: - if isinstance(value, str): - return value.encode("utf-8") - if isinstance(value, bytes): - return value - return None + # The pool is configured with decode_responses=False (see redis/client.py), + # so Redis returns raw bytes, or None for a missing key. Cached payloads are + # binary (LZ4/Arrow/AES ciphertext) and must never be UTF-8 decoded. The + # isinstance check enforces the bytes|None contract without any str coercion. + return value if isinstance(value, bytes) else None except Exception as e: raise BackendError( message=f"Redis GET failed: {e}", diff --git a/src/cachekit/backends/redis/client.py b/src/cachekit/backends/redis/client.py index b76bb7c..daa4411 100644 --- a/src/cachekit/backends/redis/client.py +++ b/src/cachekit/backends/redis/client.py @@ -78,7 +78,7 @@ def get_redis_client() -> redis.Redis: # Use URL-based connection _pool_instance = redis.ConnectionPool.from_url( redis_config.redis_url, - decode_responses=True, + decode_responses=False, # cached payloads are raw bytes (LZ4/Arrow/AES) — never UTF-8 decode max_connections=redis_config.connection_pool_size, ) @@ -120,7 +120,7 @@ async def get_async_redis_client() -> redis_async.Redis: # Use URL-based connection _async_pool_instance = redis_async.ConnectionPool.from_url( redis_config.redis_url, - decode_responses=True, + decode_responses=False, # cached payloads are raw bytes (LZ4/Arrow/AES) — never UTF-8 decode max_connections=redis_config.connection_pool_size, ) diff --git a/tests/integration/test_redis_backend.py b/tests/integration/test_redis_backend.py index 1f34420..d776f32 100644 --- a/tests/integration/test_redis_backend.py +++ b/tests/integration/test_redis_backend.py @@ -18,6 +18,8 @@ from cachekit.backends.base import BackendError, BaseBackend from cachekit.backends.redis import RedisBackend +from ..utils.redis_test_helpers import RedisIsolationMixin + @pytest.mark.unit class TestRedisBackendInitialization: @@ -139,14 +141,15 @@ def test_get_missing_key_returns_none(self): assert result is None @patch.dict("os.environ", {"REDIS_URL": "redis://localhost:6379"}, clear=True) - def test_get_handles_string_response(self): - """get() should handle string response from decode_responses=True.""" + def test_get_returns_raw_bytes_unchanged(self): + """get() returns Redis bytes unchanged — the pool uses decode_responses=False, + so binary payloads (incl. non-UTF-8) pass through without decode/coercion (#154).""" with patch("cachekit.backends.redis.backend.DIContainer") as mock_container_class: mock_container_instance = Mock() mock_container_class.return_value = mock_container_instance mock_client = Mock() - # Redis with decode_responses=True returns str - mock_client.get.return_value = "string_value" + # decode_responses=False -> redis-py returns raw bytes, including non-UTF-8. + mock_client.get.return_value = b"\x82\xa3val\xff\xfe" mock_provider = Mock() mock_provider.get_sync_client.return_value = mock_client mock_container_instance.get.return_value = mock_provider @@ -154,8 +157,7 @@ def test_get_handles_string_response(self): backend = RedisBackend() result = backend.get("test:key") - # Should convert str to bytes - assert result == b"string_value" + assert result == b"\x82\xa3val\xff\xfe" assert isinstance(result, bytes) @patch.dict("os.environ", {"REDIS_URL": "redis://localhost:6379"}, clear=True) @@ -777,3 +779,37 @@ def test_reset_global_pool_resets_async_lock(self): finally: loop.close() asyncio.set_event_loop(None) + + +@pytest.mark.integration +class TestRedisBinaryRoundtrip(RedisIsolationMixin): + """Regression for #154: the shared Redis pool must return raw bytes. + + Non-UTF-8 payloads (Rust ByteStorage LZ4, Arrow IPC, AES-256-GCM ciphertext) + must round-trip intact. Previously the pool used decode_responses=True, so + redis-py ran value.decode('utf-8', 'strict') on every GET and raised + UnicodeDecodeError on any binary payload. + """ + + def test_non_utf8_payload_roundtrips_through_shared_pool(self, monkeypatch): + import cachekit.backends.redis.client as rc + from cachekit.config.singleton import reset_settings + + kw = self.redis_client.connection_pool.connection_kwargs + # pytest-redis uses a unix socket by default; support both transports. + if "path" in kw: + url = f"unix://{kw['path']}?db={kw.get('db', 0)}" + else: + url = f"redis://{kw['host']}:{kw['port']}/{kw.get('db', 0)}" + monkeypatch.setenv("CACHEKIT_REDIS_URL", url) + rc._pool_instance = None + reset_settings() + try: + client = rc.get_cached_redis_client() + # Valid MessagePack-ish bytes that are NOT valid UTF-8 (0x82, 0xff). + evil = b"\x82\xa3ssn\xa3123\x00\xff\xfe" + client.set("ck:154:bin", evil) + assert client.get("ck:154:bin") == evil + finally: + rc._pool_instance = None + reset_settings() diff --git a/tests/unit/backends/test_redis_backend.py b/tests/unit/backends/test_redis_backend.py new file mode 100644 index 0000000..3a72490 --- /dev/null +++ b/tests/unit/backends/test_redis_backend.py @@ -0,0 +1,93 @@ +"""Unit tests for the Redis connection-pool configuration and RedisBackend.get() contract. + +These are mocked (no real Redis) and live under tests/unit/ so they run on pull +requests — unlike the real-Redis suite in tests/integration/test_redis_backend.py, +which CI only runs on push-to-main. + +Regression coverage for #154: the shared pools must use decode_responses=False so +binary payloads (LZ4 / Arrow IPC / AES-256-GCM ciphertext) are never UTF-8 decoded, +and RedisBackend.get() must return those raw bytes (or None) without coercion. +""" + +from __future__ import annotations + +from unittest.mock import Mock, patch + +import pytest + +from cachekit.backends.redis import RedisBackend + + +@pytest.mark.unit +class TestRedisPoolDecodeResponses: + """The shared sync and async pools must be created with decode_responses=False.""" + + @staticmethod + def _reset(monkeypatch): + import cachekit.backends.redis.client as rc + from cachekit.config.singleton import reset_settings + + monkeypatch.setenv("CACHEKIT_REDIS_URL", "redis://localhost:6379") + rc._pool_instance = None + rc._async_pool_instance = None + reset_settings() + return rc + + def test_sync_pool_uses_decode_responses_false(self, monkeypatch): + rc = self._reset(monkeypatch) + from cachekit.config.singleton import reset_settings + + with patch("redis.ConnectionPool.from_url") as mock_from_url, patch("redis.Redis"): + try: + rc.get_cached_redis_client() + finally: + rc._pool_instance = None + reset_settings() + assert mock_from_url.call_args.kwargs["decode_responses"] is False + + async def test_async_pool_uses_decode_responses_false(self, monkeypatch): + rc = self._reset(monkeypatch) + from cachekit.config.singleton import reset_settings + + with patch("redis.asyncio.ConnectionPool.from_url") as mock_from_url, patch("redis.asyncio.Redis"): + try: + await rc.get_async_redis_client() + finally: + rc._async_pool_instance = None + reset_settings() + assert mock_from_url.call_args.kwargs["decode_responses"] is False + + +@pytest.mark.unit +class TestRedisBackendGetContract: + """get() returns raw bytes (or None) — never str, never UTF-8 decoded. + + Uses explicit client_provider injection (no DIContainer / env patching) so the + tests are independent of REDIS_URL vs CACHEKIT_REDIS_URL alias resolution. + """ + + @staticmethod + def _backend_returning(value): + from cachekit.backends.provider import CacheClientProvider + + mock_client = Mock() + mock_client.get.return_value = value + provider = Mock(spec=CacheClientProvider) + provider.get_sync_client.return_value = mock_client + return RedisBackend("redis://localhost:6379", client_provider=provider) + + def test_get_returns_non_utf8_bytes_unchanged(self): + backend = self._backend_returning(b"\x82\xa3val\xff\xfe") + result = backend.get("k") + assert result == b"\x82\xa3val\xff\xfe" + assert isinstance(result, bytes) + + def test_get_returns_none_for_missing_key(self): + backend = self._backend_returning(None) + assert backend.get("missing") is None + + def test_get_returns_none_for_non_bytes_response(self): + # decode_responses=False means this never happens in practice, but the + # bytes|None narrowing guard must hold defensively (no str coercion). + backend = self._backend_returning("unexpected-str") + assert backend.get("k") is None