Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 5 additions & 9 deletions src/cachekit/backends/redis/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,15 +107,11 @@ def get(self, key: str) -> Optional[bytes]:
try:
client = self._get_client()
value = client.get(key)
# Redis client with decode_responses=True returns str, need bytes
# But get() with binary data returns bytes if decode fails
# For safety, encode if we got str
if value is not None:
if isinstance(value, str):
return value.encode("utf-8")
if isinstance(value, bytes):
return value
return None
# The pool is configured with decode_responses=False (see redis/client.py),
# so Redis returns raw bytes, or None for a missing key. Cached payloads are
# binary (LZ4/Arrow/AES ciphertext) and must never be UTF-8 decoded. The
# isinstance check enforces the bytes|None contract without any str coercion.
return value if isinstance(value, bytes) else None
except Exception as e:
raise BackendError(
message=f"Redis GET failed: {e}",
Expand Down
4 changes: 2 additions & 2 deletions src/cachekit/backends/redis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def get_redis_client() -> redis.Redis:
# Use URL-based connection
_pool_instance = redis.ConnectionPool.from_url(
redis_config.redis_url,
decode_responses=True,
decode_responses=False, # cached payloads are raw bytes (LZ4/Arrow/AES) — never UTF-8 decode
max_connections=redis_config.connection_pool_size,
)

Expand Down Expand Up @@ -120,7 +120,7 @@ async def get_async_redis_client() -> redis_async.Redis:
# Use URL-based connection
_async_pool_instance = redis_async.ConnectionPool.from_url(
redis_config.redis_url,
decode_responses=True,
decode_responses=False, # cached payloads are raw bytes (LZ4/Arrow/AES) — never UTF-8 decode
max_connections=redis_config.connection_pool_size,
)

Expand Down
48 changes: 42 additions & 6 deletions tests/integration/test_redis_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from cachekit.backends.base import BackendError, BaseBackend
from cachekit.backends.redis import RedisBackend

from ..utils.redis_test_helpers import RedisIsolationMixin


@pytest.mark.unit
class TestRedisBackendInitialization:
Expand Down Expand Up @@ -139,23 +141,23 @@ def test_get_missing_key_returns_none(self):
assert result is None

@patch.dict("os.environ", {"REDIS_URL": "redis://localhost:6379"}, clear=True)
def test_get_handles_string_response(self):
"""get() should handle string response from decode_responses=True."""
def test_get_returns_raw_bytes_unchanged(self):
"""get() returns Redis bytes unchanged — the pool uses decode_responses=False,
so binary payloads (incl. non-UTF-8) pass through without decode/coercion (#154)."""
with patch("cachekit.backends.redis.backend.DIContainer") as mock_container_class:
mock_container_instance = Mock()
mock_container_class.return_value = mock_container_instance
mock_client = Mock()
# Redis with decode_responses=True returns str
mock_client.get.return_value = "string_value"
# decode_responses=False -> redis-py returns raw bytes, including non-UTF-8.
mock_client.get.return_value = b"\x82\xa3val\xff\xfe"
mock_provider = Mock()
mock_provider.get_sync_client.return_value = mock_client
mock_container_instance.get.return_value = mock_provider

backend = RedisBackend()
result = backend.get("test:key")

# Should convert str to bytes
assert result == b"string_value"
assert result == b"\x82\xa3val\xff\xfe"
assert isinstance(result, bytes)

@patch.dict("os.environ", {"REDIS_URL": "redis://localhost:6379"}, clear=True)
Expand Down Expand Up @@ -777,3 +779,37 @@ def test_reset_global_pool_resets_async_lock(self):
finally:
loop.close()
asyncio.set_event_loop(None)


@pytest.mark.integration
class TestRedisBinaryRoundtrip(RedisIsolationMixin):
"""Regression for #154: the shared Redis pool must return raw bytes.

Non-UTF-8 payloads (Rust ByteStorage LZ4, Arrow IPC, AES-256-GCM ciphertext)
must round-trip intact. Previously the pool used decode_responses=True, so
redis-py ran value.decode('utf-8', 'strict') on every GET and raised
UnicodeDecodeError on any binary payload.
"""

def test_non_utf8_payload_roundtrips_through_shared_pool(self, monkeypatch):
import cachekit.backends.redis.client as rc
from cachekit.config.singleton import reset_settings

kw = self.redis_client.connection_pool.connection_kwargs
# pytest-redis uses a unix socket by default; support both transports.
if "path" in kw:
url = f"unix://{kw['path']}?db={kw.get('db', 0)}"
else:
url = f"redis://{kw['host']}:{kw['port']}/{kw.get('db', 0)}"
monkeypatch.setenv("CACHEKIT_REDIS_URL", url)
rc._pool_instance = None
reset_settings()
try:
client = rc.get_cached_redis_client()
# Valid MessagePack-ish bytes that are NOT valid UTF-8 (0x82, 0xff).
evil = b"\x82\xa3ssn\xa3123\x00\xff\xfe"
client.set("ck:154:bin", evil)
assert client.get("ck:154:bin") == evil
finally:
rc._pool_instance = None
reset_settings()
93 changes: 93 additions & 0 deletions tests/unit/backends/test_redis_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""Unit tests for the Redis connection-pool configuration and RedisBackend.get() contract.

These are mocked (no real Redis) and live under tests/unit/ so they run on pull
requests — unlike the real-Redis suite in tests/integration/test_redis_backend.py,
which CI only runs on push-to-main.

Regression coverage for #154: the shared pools must use decode_responses=False so
binary payloads (LZ4 / Arrow IPC / AES-256-GCM ciphertext) are never UTF-8 decoded,
and RedisBackend.get() must return those raw bytes (or None) without coercion.
"""

from __future__ import annotations

from unittest.mock import Mock, patch

import pytest

from cachekit.backends.redis import RedisBackend


@pytest.mark.unit
class TestRedisPoolDecodeResponses:
"""The shared sync and async pools must be created with decode_responses=False."""

@staticmethod
def _reset(monkeypatch):
import cachekit.backends.redis.client as rc
from cachekit.config.singleton import reset_settings

monkeypatch.setenv("CACHEKIT_REDIS_URL", "redis://localhost:6379")
rc._pool_instance = None
rc._async_pool_instance = None
reset_settings()
return rc

def test_sync_pool_uses_decode_responses_false(self, monkeypatch):
rc = self._reset(monkeypatch)
from cachekit.config.singleton import reset_settings

with patch("redis.ConnectionPool.from_url") as mock_from_url, patch("redis.Redis"):
try:
rc.get_cached_redis_client()
finally:
rc._pool_instance = None
reset_settings()
assert mock_from_url.call_args.kwargs["decode_responses"] is False

async def test_async_pool_uses_decode_responses_false(self, monkeypatch):
rc = self._reset(monkeypatch)
from cachekit.config.singleton import reset_settings

with patch("redis.asyncio.ConnectionPool.from_url") as mock_from_url, patch("redis.asyncio.Redis"):
try:
await rc.get_async_redis_client()
finally:
rc._async_pool_instance = None
reset_settings()
assert mock_from_url.call_args.kwargs["decode_responses"] is False


@pytest.mark.unit
class TestRedisBackendGetContract:
"""get() returns raw bytes (or None) — never str, never UTF-8 decoded.

Uses explicit client_provider injection (no DIContainer / env patching) so the
tests are independent of REDIS_URL vs CACHEKIT_REDIS_URL alias resolution.
"""

@staticmethod
def _backend_returning(value):
from cachekit.backends.provider import CacheClientProvider

mock_client = Mock()
mock_client.get.return_value = value
provider = Mock(spec=CacheClientProvider)
provider.get_sync_client.return_value = mock_client
return RedisBackend("redis://localhost:6379", client_provider=provider)

def test_get_returns_non_utf8_bytes_unchanged(self):
backend = self._backend_returning(b"\x82\xa3val\xff\xfe")
result = backend.get("k")
assert result == b"\x82\xa3val\xff\xfe"
assert isinstance(result, bytes)

def test_get_returns_none_for_missing_key(self):
backend = self._backend_returning(None)
assert backend.get("missing") is None

def test_get_returns_none_for_non_bytes_response(self):
# decode_responses=False means this never happens in practice, but the
# bytes|None narrowing guard must hold defensively (no str coercion).
backend = self._backend_returning("unexpected-str")
assert backend.get("k") is None
Loading