diff --git a/.sampo/changesets/async-stream-context-manager.md b/.sampo/changesets/async-stream-context-manager.md new file mode 100644 index 00000000..e2cbf340 --- /dev/null +++ b/.sampo/changesets/async-stream-context-manager.md @@ -0,0 +1,5 @@ +--- +pypi/posthog: patch +--- + +Fix async streaming responses from the AI wrappers (OpenAI, Anthropic, Gemini) so they support `async with` as well as `async for`. Previously, consuming a stream via `async with` (e.g. with pydantic-ai) raised `TypeError: 'async_generator' object does not support the asynchronous context manager protocol`. diff --git a/posthog/ai/anthropic/anthropic_async.py b/posthog/ai/anthropic/anthropic_async.py index 9b02e35c..df098955 100644 --- a/posthog/ai/anthropic/anthropic_async.py +++ b/posthog/ai/anthropic/anthropic_async.py @@ -11,6 +11,7 @@ from typing import Any, Dict, List, Optional from posthog import setup +from posthog.ai.stream import AsyncStreamWrapper from posthog.ai.types import StreamingContentBlock, TokenUsage, ToolInProgress from posthog.ai.utils import ( call_llm_and_track_usage_async, @@ -225,7 +226,7 @@ async def generator(): stop_reason=stop_reason, ) - return generator() + return AsyncStreamWrapper(generator(), stream=response) async def _capture_streaming_event( self, diff --git a/posthog/ai/gemini/gemini_async.py b/posthog/ai/gemini/gemini_async.py index cd2b962f..07ba3f02 100644 --- a/posthog/ai/gemini/gemini_async.py +++ b/posthog/ai/gemini/gemini_async.py @@ -3,6 +3,7 @@ import uuid from typing import Any, Dict, Optional +from posthog.ai.stream import AsyncStreamWrapper from posthog.ai.types import TokenUsage, StreamingEventData from posthog.ai.utils import merge_system_prompt @@ -354,7 +355,7 @@ async def async_generator(): stop_reason=stop_reason, ) - return async_generator() + return AsyncStreamWrapper(async_generator(), stream=response) def _capture_streaming_event( self, diff --git a/posthog/ai/openai/openai_async.py b/posthog/ai/openai/openai_async.py index 
T = TypeVar("T")


class AsyncStreamWrapper(Generic[T]):
    """Add the async context manager protocol to a PostHog streaming generator.

    PostHog's AI wrappers return an async generator that yields provider
    chunks and captures the tracking event in its ``finally`` block. A bare
    async generator only supports ``async for``, so ``async with response:``
    (used by pydantic-ai) raised a ``TypeError``. This class wraps the
    tracking generator so that both protocols work and, when the original
    provider stream is supplied, closes that stream on exit and proxies
    public attribute access (e.g. ``.response``) to it.

    Args:
        generator: The tracking async generator to iterate and close.
        stream: Optional provider stream the generator reads from. When
            given, it is closed after the generator and used as the target
            for proxied attribute lookups.
    """

    # asend/athrow belong to the generator; provider streams do not expose
    # them. "aclose" stays listed for backward compatibility, although the
    # explicit ``aclose`` method below now takes precedence over __getattr__.
    _GENERATOR_METHODS = ("aclose", "asend", "athrow")

    def __init__(
        self,
        generator: AsyncGenerator[T, None],
        stream: Optional[Any] = None,
    ) -> None:
        self._generator = generator
        self._stream = stream

    def __aiter__(self) -> "AsyncStreamWrapper[T]":
        return self

    async def __anext__(self) -> T:
        return await self._generator.__anext__()

    async def __aenter__(self) -> "AsyncStreamWrapper[T]":
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
        """Close the generator and provider stream; never suppress exceptions."""
        await self.aclose()
        return False

    async def aclose(self) -> None:
        """Close the tracking generator, then the provider stream.

        The generator is closed first so its ``finally`` block runs (that is
        where the event is captured), even on early exit. The ``try/finally``
        guarantees the provider stream is still closed if closing the
        generator raises; in that case the generator's exception propagates.
        """
        try:
            await self._generator.aclose()
        finally:
            await self._close_stream()

    async def _close_stream(self) -> None:
        """Close the provider stream via ``aclose()`` or ``close()`` if present."""
        # Local import: keeps this class self-contained without touching the
        # module's import block.
        import inspect

        if self._stream is None:
            return
        close = getattr(self._stream, "aclose", None) or getattr(
            self._stream, "close", None
        )
        if close is None or not callable(close):
            return
        result = close()
        # Some streams expose a synchronous close(); awaiting its plain return
        # value would raise TypeError, so only await real awaitables.
        if inspect.isawaitable(result):
            await result

    def __getattr__(self, name: str) -> Any:
        # Only called when normal lookup fails. Refusing private names also
        # prevents infinite recursion if ``_generator``/``_stream`` are not
        # set yet (e.g. before __init__ has run).
        if name.startswith("_"):
            raise AttributeError(name)
        if name in self._GENERATOR_METHODS:
            return getattr(self._generator, name)
        # Proxy public attributes (e.g. ``.response``) to the provider stream,
        # falling back to the generator when no stream was supplied.
        target = self._stream if self._stream is not None else self._generator
        return getattr(target, name)
async def test_async_client_streaming_supports_async_with(
    mock_client, mock_google_genai_client
):
    """Regression test for #393: the Gemini stream works under `async with`.

    Streams a single mocked chunk through ``generate_content_stream``,
    consumes it inside ``async with``, and checks that exactly one
    ``$ai_generation`` event is captured for the gemini provider.
    """

    async def fake_stream():
        # Token counts attached to the only chunk the fake provider yields.
        usage_metadata = MagicMock()
        usage_metadata.prompt_token_count = 5
        usage_metadata.candidates_token_count = 3
        usage_metadata.cached_content_token_count = 0
        usage_metadata.thoughts_token_count = 0

        only_chunk = MagicMock()
        only_chunk.text = "Hi"
        only_chunk.usage_metadata = usage_metadata
        yield only_chunk

    mock_google_genai_client.aio.models.generate_content_stream = AsyncMock(
        return_value=fake_stream()
    )
    client = AsyncClient(api_key="test-key", posthog_client=mock_client)

    response = await client.models.generate_content_stream(
        model="gemini-2.0-flash",
        contents=["Hi"],
        posthog_distinct_id="test-id",
    )

    async with response as stream:
        chunks = [chunk async for chunk in stream]

    assert len(chunks) == 1
    assert mock_client.capture.call_count == 1

    captured = mock_client.capture.call_args[1]
    assert captured["event"] == "$ai_generation"
    assert captured["properties"]["$ai_provider"] == "gemini"
@pytest.mark.asyncio
async def test_async_responses_streaming_supports_async_with(mock_client):
    """Regression test for #393: a streamed Responses API call works under
    `async with` and still captures exactly one event."""
    from unittest.mock import MagicMock

    delta = MagicMock()
    delta.type = "response.text.delta"
    delta.text = "hello"

    async def fake_create(self, **kwargs):
        # Stand-in for the provider call: hand back a closable fake stream.
        return RecordingAsyncStream([delta])

    with patch("openai.resources.responses.AsyncResponses.create", new=fake_create):
        client = AsyncOpenAI(api_key="test-key", posthog_client=mock_client)

        response = await client.responses.create(
            model="gpt-4o-mini",
            input=[{"role": "user", "content": "Hi"}],
            stream=True,
            posthog_distinct_id="test-id",
        )

        received = []
        async with response as stream:
            async for item in stream:
                received.append(item)

    assert received == [delta]
    assert mock_client.capture.call_count == 1
@pytest.mark.asyncio
@pytest.mark.parametrize("consume_all", [False, True])
async def test_finally_block_runs_on_exit(consume_all):
    """The wrapped generator's ``finally`` runs exactly once on context exit,
    whether the stream is drained or abandoned after the first item."""
    events = []

    async def numbers():
        try:
            for value in (1, 2, 3):
                yield value
        finally:
            events.append("done")

    stop_early = not consume_all
    async with AsyncStreamWrapper(numbers()) as stream:
        async for value in stream:
            if stop_early and value == 1:
                break

    assert events == ["done"]
class RecordingAsyncStream:
    """Mock provider async stream that is iterable and records when closed.

    Intended as a stand-in for the provider SDK streams: it supports
    ``async for`` and exposes an async ``close()`` plus a ``response``
    attribute, so tests can assert both iteration and that the underlying
    stream is closed on context exit.

    Args:
        items: Iterable of values to yield, in order. It is copied, so the
            caller's collection is never mutated.
    """

    def __init__(self, items):
        # Local import keeps this helper self-contained.
        from collections import deque

        # A deque gives O(1) pops from the front; list.pop(0) is O(n) per
        # item, which makes draining the stream quadratic.
        self._items = deque(items)
        self.closed = False
        self.response = "provider-response"

    def __aiter__(self):
        return self

    async def __anext__(self):
        if not self._items:
            raise StopAsyncIteration
        return self._items.popleft()

    async def close(self):
        """Record that the stream was closed."""
        self.closed = True