Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .sampo/changesets/async-stream-context-manager.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
pypi/posthog: patch
---

Fix async streaming responses from the AI wrappers (OpenAI, Anthropic, Gemini) so they support `async with` as well as `async for`. Previously, consuming a stream via `async with` (e.g. with pydantic-ai) raised `TypeError: 'async_generator' object does not support the asynchronous context manager protocol`.
3 changes: 2 additions & 1 deletion posthog/ai/anthropic/anthropic_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import Any, Dict, List, Optional

from posthog import setup
from posthog.ai.stream import AsyncStreamWrapper
from posthog.ai.types import StreamingContentBlock, TokenUsage, ToolInProgress
from posthog.ai.utils import (
call_llm_and_track_usage_async,
Expand Down Expand Up @@ -225,7 +226,7 @@ async def generator():
stop_reason=stop_reason,
)

return generator()
return AsyncStreamWrapper(generator(), stream=response)

async def _capture_streaming_event(
self,
Expand Down
3 changes: 2 additions & 1 deletion posthog/ai/gemini/gemini_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import uuid
from typing import Any, Dict, Optional

from posthog.ai.stream import AsyncStreamWrapper
from posthog.ai.types import TokenUsage, StreamingEventData
from posthog.ai.utils import merge_system_prompt

Expand Down Expand Up @@ -354,7 +355,7 @@ async def async_generator():
stop_reason=stop_reason,
)

return async_generator()
return AsyncStreamWrapper(async_generator(), stream=response)

def _capture_streaming_event(
self,
Expand Down
5 changes: 3 additions & 2 deletions posthog/ai/openai/openai_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import uuid
from typing import Any, Dict, List, Optional

from posthog.ai.stream import AsyncStreamWrapper
from posthog.ai.types import TokenUsage

try:
Expand Down Expand Up @@ -221,7 +222,7 @@ async def async_generator():
stop_reason=stop_reason,
)

return async_generator()
return AsyncStreamWrapper(async_generator(), stream=response)

async def _capture_streaming_event(
self,
Expand Down Expand Up @@ -515,7 +516,7 @@ async def async_generator():
stop_reason=stop_reason,
)

return async_generator()
return AsyncStreamWrapper(async_generator(), stream=response)

async def _capture_streaming_event(
self,
Expand Down
62 changes: 62 additions & 0 deletions posthog/ai/stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""Shared async streaming utilities for PostHog AI wrappers."""

import inspect
from typing import Any, AsyncGenerator, Generic, Optional, TypeVar

T = TypeVar("T")


class AsyncStreamWrapper(Generic[T]):
    """Adds the async context manager protocol to a PostHog streaming generator.

    The OpenAI and Anthropic SDK streams support both ``async for`` and
    ``async with``. PostHog's wrappers returned a bare async generator, which
    only supports ``async for``, so ``async with response:`` (used by
    pydantic-ai) raised a TypeError. This wraps the tracking generator and,
    when given the original provider stream, closes it and proxies attribute
    access (e.g. ``.response``) to it.

    Args:
        generator: The tracking async generator that yields stream items.
        stream: Optional original provider stream. When present it is closed
            on ``__aexit__`` and receives proxied public attribute lookups.
    """

    # aclose/asend/athrow belong to the generator; provider streams expose
    # close(), not these. Forwarding aclose() keeps it firing the event.
    _GENERATOR_METHODS = ("aclose", "asend", "athrow")

    def __init__(
        self,
        generator: AsyncGenerator[T, None],
        stream: Optional[Any] = None,
    ) -> None:
        self._generator = generator
        self._stream = stream

    def __aiter__(self) -> "AsyncStreamWrapper[T]":
        return self

    async def __anext__(self) -> T:
        return await self._generator.__anext__()

    async def __aenter__(self) -> "AsyncStreamWrapper[T]":
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
        """Close the generator, then the provider stream.

        Returns:
            Always ``False``, so an exception raised inside the ``async with``
            body is never suppressed.
        """
        # Close the generator first so its `finally` captures the event, even on
        # early exit. try/finally still closes the provider stream if that raises.
        try:
            await self._generator.aclose()
        finally:
            await self._close_stream()

        return False

    async def _close_stream(self) -> None:
        """Close the provider stream, if any, via ``aclose()`` or ``close()``."""
        stream = self._stream
        if stream is None:
            return
        close = getattr(stream, "aclose", None) or getattr(stream, "close", None)
        if close is None:
            return
        outcome = close()
        # A stream may expose a synchronous close(); awaiting its plain return
        # value would raise TypeError, so only await genuine awaitables.
        if inspect.isawaitable(outcome):
            await outcome

    def __getattr__(self, name: str) -> Any:
        # Proxy only public attributes (e.g. `.response`) to the provider stream.
        if name.startswith("_"):
            raise AttributeError(name)
        if name in self._GENERATOR_METHODS:
            return getattr(self._generator, name)
        target = self._stream if self._stream is not None else self._generator
        return getattr(target, name)
72 changes: 72 additions & 0 deletions posthog/test/ai/anthropic/test_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from anthropic.types import Message, Usage

from posthog.ai.anthropic import Anthropic, AsyncAnthropic
from posthog.test.ai.utils import RecordingAsyncStream

ANTHROPIC_AVAILABLE = True
except ImportError:
Expand Down Expand Up @@ -1421,3 +1422,74 @@ def test_integration_stop_reason(mock_client):
assert props["$ai_stop_reason"] in ("end_turn", "max_tokens")
assert props["$ai_provider"] == "anthropic"
assert props["$ai_input_tokens"] > 0


def _anthropic_stream_events():
    """Return three mock Anthropic stream events.

    The sequence is ``message_start``, a ``content_block_delta`` with the text
    "Hi", and a ``message_delta`` carrying usage of 10 input and 5 output
    tokens with both cache counters at zero.
    """
    usage_event = MockStreamEvent("message_delta")
    usage_event.usage = MockUsage(
        input_tokens=10,
        output_tokens=5,
        cache_read_input_tokens=0,
        cache_creation_input_tokens=0,
    )
    start_event = MockStreamEvent("message_start")
    text_event = MockStreamEvent("content_block_delta", text="Hi")
    return [start_event, text_event, usage_event]


@pytest.mark.asyncio
async def test_async_messages_create_streaming_supports_async_with(mock_client):
    """Regression test for #393: the stream returned by
    ``messages.create(stream=True)`` can be entered with ``async with``."""

    async def fake_create(**kwargs):
        return RecordingAsyncStream(_anthropic_stream_events())

    create_patch = patch(
        "anthropic.resources.messages.AsyncMessages.create",
        side_effect=fake_create,
    )
    request = dict(
        model="claude-3-opus-20240229",
        messages=[{"role": "user", "content": "Foo"}],
        stream=True,
        max_tokens=1,
    )

    with create_patch:
        client = AsyncAnthropic(posthog_client=mock_client)
        response = await client.messages.create(**request)

        # Consume every event while inside the async context manager.
        events = []
        async with response as stream:
            async for event in stream:
                events.append(event)

    assert len(events) == 3
    assert mock_client.capture.call_count == 1


@pytest.mark.asyncio
async def test_async_messages_streaming_early_exit_closes_provider_stream(mock_client):
    """Leaving the stream after the first event must still close the
    underlying Anthropic stream and produce exactly one capture call."""
    provider_stream = RecordingAsyncStream(_anthropic_stream_events())

    async def fake_create(**kwargs):
        return provider_stream

    create_patch = patch(
        "anthropic.resources.messages.AsyncMessages.create",
        side_effect=fake_create,
    )
    request = dict(
        model="claude-3-opus-20240229",
        messages=[{"role": "user", "content": "Foo"}],
        stream=True,
        max_tokens=1,
    )

    with create_patch:
        client = AsyncAnthropic(posthog_client=mock_client)
        response = await client.messages.create(**request)

        # Stop after the first event; the context manager exit does the cleanup.
        async with response as stream:
            async for _ in stream:
                break

    assert provider_stream.closed is True
    assert mock_client.capture.call_count == 1
40 changes: 40 additions & 0 deletions posthog/test/ai/gemini/test_gemini_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -1110,3 +1110,43 @@ async def test_async_embed_content_integration_batch(mock_client):

assert response.embeddings is not None
assert len(response.embeddings) == len(inputs)


async def test_async_client_streaming_supports_async_with(
    mock_client, mock_google_genai_client
):
    """Regression test for #393: generate_content_stream must support `async with`."""

    async def single_chunk_stream():
        # One chunk with text "Hi" and mocked token usage counters.
        usage = MagicMock(
            prompt_token_count=5,
            candidates_token_count=3,
            cached_content_token_count=0,
            thoughts_token_count=0,
        )
        yield MagicMock(text="Hi", usage_metadata=usage)

    stream_mock = AsyncMock(return_value=single_chunk_stream())
    mock_google_genai_client.aio.models.generate_content_stream = stream_mock

    client = AsyncClient(api_key="test-key", posthog_client=mock_client)
    response = await client.models.generate_content_stream(
        model="gemini-2.0-flash",
        contents=["Hi"],
        posthog_distinct_id="test-id",
    )

    async with response as stream:
        chunks = [chunk async for chunk in stream]

    assert len(chunks) == 1
    assert mock_client.capture.call_count == 1

    captured = mock_client.capture.call_args[1]
    assert captured["event"] == "$ai_generation"
    assert captured["properties"]["$ai_provider"] == "gemini"
98 changes: 98 additions & 0 deletions posthog/test/ai/openai/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from posthog.ai.openai import OpenAI
from posthog.ai.openai.openai_async import AsyncOpenAI
from posthog.ai.openai.wrapper_utils import reset_fallback_warnings
from posthog.test.ai.utils import RecordingAsyncStream

OPENAI_AVAILABLE = True
except ImportError:
Expand Down Expand Up @@ -2352,3 +2353,100 @@ def test_integration_stop_reason(mock_client):
assert props["$ai_stop_reason"] in ("stop", "length")
assert props["$ai_provider"] == "openai"
assert props["$ai_input_tokens"] > 0


@pytest.mark.asyncio
async def test_async_chat_streaming_supports_async_with(
    mock_client, streaming_tool_call_chunks
):
    """Regression test for #393: a chat completions stream (``stream=True``)
    can be consumed through ``async with``, which pydantic-ai relies on."""

    async def fake_create(self, **kwargs):
        return RecordingAsyncStream(streaming_tool_call_chunks)

    create_patch = patch(
        "openai.resources.chat.completions.AsyncCompletions.create", new=fake_create
    )
    request = dict(
        model="gpt-4",
        messages=[{"role": "user", "content": "Hi"}],
        stream=True,
        posthog_distinct_id="test-id",
    )

    with create_patch:
        client = AsyncOpenAI(api_key="test-key", posthog_client=mock_client)
        response = await client.chat.completions.create(**request)

        async with response as stream:
            chunks = [chunk async for chunk in stream]

        # Every provider chunk passes through unchanged, and one event is captured.
        assert chunks == streaming_tool_call_chunks
        assert mock_client.capture.call_count == 1

        captured = mock_client.capture.call_args[1]
        properties = captured["properties"]
        assert captured["event"] == "$ai_generation"
        assert properties["$ai_provider"] == "openai"
        assert properties["$ai_model"] == "gpt-4"


@pytest.mark.asyncio
async def test_async_responses_streaming_supports_async_with(mock_client):
    """Regression test for #393: a responses stream (``stream=True``) can be
    consumed through ``async with``."""
    from unittest.mock import MagicMock

    delta = MagicMock()
    delta.type = "response.text.delta"
    delta.text = "hello"

    async def fake_create(self, **kwargs):
        return RecordingAsyncStream([delta])

    create_patch = patch(
        "openai.resources.responses.AsyncResponses.create", new=fake_create
    )
    request = dict(
        model="gpt-4o-mini",
        input=[{"role": "user", "content": "Hi"}],
        stream=True,
        posthog_distinct_id="test-id",
    )

    with create_patch:
        client = AsyncOpenAI(api_key="test-key", posthog_client=mock_client)
        response = await client.responses.create(**request)

        received = []
        async with response as stream:
            async for item in stream:
                received.append(item)

        assert received == [delta]
        assert mock_client.capture.call_count == 1


@pytest.mark.asyncio
async def test_async_chat_streaming_early_exit_closes_provider_stream(
    mock_client, streaming_tool_call_chunks
):
    """Stopping after the first chunk must close the underlying provider
    stream (release the HTTP connection) and still capture the event."""
    provider_stream = RecordingAsyncStream(streaming_tool_call_chunks)

    async def fake_create(self, **kwargs):
        return provider_stream

    create_patch = patch(
        "openai.resources.chat.completions.AsyncCompletions.create", new=fake_create
    )
    request = dict(
        model="gpt-4",
        messages=[{"role": "user", "content": "Hi"}],
        stream=True,
        posthog_distinct_id="test-id",
    )

    with create_patch:
        client = AsyncOpenAI(api_key="test-key", posthog_client=mock_client)
        response = await client.chat.completions.create(**request)

        # Abandon the stream after one chunk; cleanup happens on context exit.
        async with response as stream:
            async for _ in stream:
                break

        assert provider_stream.closed is True
        assert mock_client.capture.call_count == 1
Loading
Loading