From 28fcf5d160006d342b779c7acbfaf8a45b4f954b Mon Sep 17 00:00:00 2001 From: Antawari Date: Wed, 27 May 2026 16:51:49 -0600 Subject: [PATCH] onboard: per-instance asyncio.Lock + typed lifecycle exceptions on ConversationEngine MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two defects rooted at the same class: 1. handle_answer awaits emit() (which awaits broadcast() → asyncio.gather) before incrementing _turn. A second WS message arriving while the first handle_answer is suspended inside the first emit reads the same stale _turn — the analyzer fires twice for the same question, _turn double- advances, and the question-emission sequence can skip Q2 or Q3. With three answers required for completion, a double-click can leave the engine permanently broken for that session. 2. Lifecycle violations (handle_answer before start, handle_answer after completion) raised bare RuntimeError. The WS handler's `except Exception` swallowed them without producing a typed error frame the browser could surface to the user. ## Fix src/bonfire/onboard/conversation.py: - Added `_lock: asyncio.Lock = field(default_factory=asyncio.Lock)` to ConversationEngine. Each instance gets its own Lock — `default_factory` runs per instance, not per class definition, so unrelated conversations don't serialize against each other. - handle_answer now wraps its full body in `async with self._lock:` — the await points inside the body can't be raced by a sibling call. - Defined two typed exceptions, both subclassing RuntimeError for backward-compat with existing bare-`except RuntimeError` catchers upstream in the WS handler: * ConversationNotStarted (handle_answer called before start) * ConversationAlreadyComplete (handle_answer called after Q3 answered) - Updated `__all__` to export the new exception types. ## Tests tests/unit/test_onboard_conversation_concurrency.py (new · Knight RED): Seven tests across three classes: - TestLockPresence (2): asyncio.Lock attribute exists; per-instance not shared. - TestHandleAnswerAcquiresLock (2): * handle_answer blocks when the lock is held externally (asserts via asyncio.wait_for timeout). * two concurrent handle_answer calls emit questions in strict Q1→Q2→Q3 order (pre-fix this races and skips questions). - TestTypedLifecycleExceptions (3): * handle_answer before start raises ConversationNotStarted. * handle_answer after completion raises ConversationAlreadyComplete. * both typed exceptions subclass RuntimeError (back-compat guard). ## Out of scope (filed for follow-up PR) WS handler integration — catching ConversationNotStarted / ConversationAlreadyComplete specifically in src/bonfire/onboard/server.py and emitting a typed error frame to the browser. Deferred to avoid file overlap with the in-flight front-door hardening PR which also touches server.py. ## Verification pytest tests/unit/test_onboard_conversation_concurrency.py 7 passed (Knight RED → GREEN verified) pytest tests/unit/test_onboard_server.py (regression) 19 passed ruff check + format on changed files: clean Co-Authored-By: Claude Opus 4.7 (1M context) --- src/bonfire/onboard/conversation.py | 130 ++++++---- .../test_onboard_conversation_concurrency.py | 228 ++++++++++++++++++ 2 files changed, 314 insertions(+), 44 deletions(-) create mode 100644 tests/unit/test_onboard_conversation_concurrency.py diff --git a/src/bonfire/onboard/conversation.py b/src/bonfire/onboard/conversation.py index 45bf602..d5f374b 100644 --- a/src/bonfire/onboard/conversation.py +++ b/src/bonfire/onboard/conversation.py @@ -10,6 +10,7 @@ from __future__ import annotations +import asyncio import re from dataclasses import dataclass, field from typing import TYPE_CHECKING @@ -23,7 +24,26 @@ if TYPE_CHECKING: from collections.abc import Awaitable, Callable -__all__ = ["ConversationEngine"] +__all__ = [ + "ConversationAlreadyComplete", + "ConversationEngine", + "ConversationNotStarted", +] + + +# --------------------------------------------------------------------------- +# Typed lifecycle exceptions (subclass RuntimeError for back-compat with +# the pre-fix bare-RuntimeError catch sites — existing `except RuntimeError` +# blocks continue to catch these unchanged). +# --------------------------------------------------------------------------- + + +class ConversationNotStarted(RuntimeError): + """Raised when ``handle_answer`` is called before ``start()``.""" + + +class ConversationAlreadyComplete(RuntimeError): + """Raised when ``handle_answer`` is called after all 3 questions are answered.""" # --------------------------------------------------------------------------- @@ -371,10 +391,20 @@ def _analyze_q3(text: str) -> tuple[str, dict[str, str]]: @dataclass class ConversationEngine: - """Scripted 3-question conversation for profiling.""" + """Scripted 3-question conversation for profiling. + + ``handle_answer`` acquires ``_lock`` (an ``asyncio.Lock``) for its full + body so back-to-back WS messages that arrive while a prior ``emit(...)`` + is suspended in ``broadcast(...)`` cannot interleave reads of ``_turn`` + and corrupt the question-emission sequence. + """ _turn: int = 0 # 0=not started, 1-3=waiting for answer to Q1-Q3 _profile: dict[str, str] = field(default_factory=dict) + # Per-instance asyncio.Lock — serializes handle_answer calls so the + # await on emit() inside the handler can't be raced by a second call + # that lands on the same WS connection while the first is suspended. + _lock: asyncio.Lock = field(default_factory=asyncio.Lock) @property def is_complete(self) -> bool: @@ -400,54 +430,66 @@ async def handle_answer( text: str, emit: Callable[[FrontDoorMessage], Awaitable[None]], ) -> None: - """Process answer: analyze, reflect, ask next or finish.""" - if self._turn == 0: - msg = "Cannot handle answer before start() has been called." - raise RuntimeError(msg) - if self._turn > 3: - msg = "Conversation is already complete." - raise RuntimeError(msg) - - question_index = self._turn - 1 # 0-based - - # Short answer detection - stripped = text.strip() - word_count = len(stripped.split()) if stripped else 0 - - if word_count < _SHORT_THRESHOLD: - reflection_text = _BRIEF_REFLECTION - profile_update: dict[str, str] = {} - else: - analyzer = _ANALYZERS[question_index] - reflection_text, profile_update = analyzer(stripped) - - # Emit reflection - await emit( - FalcorMessage( - text=reflection_text, - subtype="reflection", + """Process answer: analyze, reflect, ask next or finish. + + Acquires ``self._lock`` for the full body — a second concurrent + call blocks until the first releases, so the ``await emit(...)`` + suspension points inside the body can't be raced by a sibling call + that reads stale ``_turn`` and double-advances past the same question. + + Lifecycle violations raise typed exceptions (``ConversationNotStarted``, + ``ConversationAlreadyComplete``) that subclass ``RuntimeError`` so + existing bare-``RuntimeError`` catchers in the WS handler continue + to work while typed handlers can catch the specific lifecycle case. + """ + async with self._lock: + if self._turn == 0: + msg = "Cannot handle answer before start() has been called." + raise ConversationNotStarted(msg) + if self._turn > 3: + msg = "Conversation is already complete." + raise ConversationAlreadyComplete(msg) + + question_index = self._turn - 1 # 0-based + + # Short answer detection + stripped = text.strip() + word_count = len(stripped.split()) if stripped else 0 + + if word_count < _SHORT_THRESHOLD: + reflection_text = _BRIEF_REFLECTION + profile_update: dict[str, str] = {} + else: + analyzer = _ANALYZERS[question_index] + reflection_text, profile_update = analyzer(stripped) + + # Emit reflection + await emit( + FalcorMessage( + text=reflection_text, + subtype="reflection", + ) ) - ) - # Accumulate profile - for k, v in profile_update.items(): - self._profile[k] = v + # Accumulate profile + for k, v in profile_update.items(): + self._profile[k] = v - # Advance turn - self._turn += 1 + # Advance turn + self._turn += 1 - # Ask next question if not done - if self._turn <= 3: - await emit( - FalcorMessage( - text=_QUESTIONS[self._turn - 1], - subtype="question", + # Ask next question if not done + if self._turn <= 3: + await emit( + FalcorMessage( + text=_QUESTIONS[self._turn - 1], + subtype="question", + ) ) - ) - # If complete, ensure all expected keys have defaults - if self._turn > 3: - self._ensure_complete_profile() + # If complete, ensure all expected keys have defaults + if self._turn > 3: + self._ensure_complete_profile() def _ensure_complete_profile(self) -> None: """Fill in any missing profile keys with sensible defaults.""" diff --git a/tests/unit/test_onboard_conversation_concurrency.py b/tests/unit/test_onboard_conversation_concurrency.py new file mode 100644 index 0000000..a6ca56c --- /dev/null +++ b/tests/unit/test_onboard_conversation_concurrency.py @@ -0,0 +1,228 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2026 BonfireAI + +"""Knight RED tests — ConversationEngine concurrency safety + typed lifecycle exceptions. + +The pre-fix ConversationEngine had no per-handler lock and raised bare +``RuntimeError`` for lifecycle violations. Two defects rooted at the same +class: + +1. **Back-to-back ``handle_answer`` calls corrupt ``_turn``.** ``handle_answer`` + awaits ``emit(...)`` (which awaits ``broadcast(...)`` on the server) BEFORE + incrementing ``_turn``. If a second call lands while the first is suspended + inside the first emit, both calls read the same stale ``_turn``, the analyzer + fires twice for the same question, and the question-emission sequence + skips a question. With three answers required for completion, a single + double-click can push the engine into ``is_complete=True`` after Q2, then + subsequent legitimate answers raise the bare ``RuntimeError`` and the WS + handler silently logs without recovery — the engine is permanently broken + for that session. + +2. **Lifecycle violations raise bare ``RuntimeError``.** ``handle_answer`` + before ``start()`` and ``handle_answer`` after completion both raise + ``RuntimeError("...")``. The WS handler's ``except Exception`` swallows + them without producing a typed error frame the browser can show. The + ticket calls out a typed exception that the WS handler can catch + specifically and respond to with a typed error frame. + +This Knight pins: + +- ``ConversationEngine`` exposes an ``_lock`` attribute (``asyncio.Lock``). +- ``handle_answer`` acquires the lock for its full body — a second call + blocks until the first releases. +- Calling ``handle_answer`` before ``start()`` raises ``ConversationNotStarted`` + (subclass of ``RuntimeError`` for backward-compat with existing catchers). +- Calling ``handle_answer`` after completion raises ``ConversationAlreadyComplete`` + (subclass of ``RuntimeError``). +- The legacy bare-``RuntimeError`` catch path still works (typed exceptions + subclass ``RuntimeError`` so existing ``except RuntimeError`` blocks + upstream in the WS handler still catch them). + +Out of scope (filed for follow-up PR to avoid file overlap with the +in-flight front-door hardening PR): + +- WS handler integration — catching ``ConversationNotStarted`` / + ``ConversationAlreadyComplete`` specifically and emitting a typed + error frame to the browser. Touches ``server.py`` + ``flow.py``. +""" + +from __future__ import annotations + +import asyncio + +import pytest + +from bonfire.onboard.conversation import ( + ConversationAlreadyComplete, + ConversationEngine, + ConversationNotStarted, +) +from bonfire.onboard.protocol import FrontDoorMessage + + +async def _noop_emit(_msg: FrontDoorMessage) -> None: + """Emit callback that does nothing — for tests not asserting emission shape.""" + + +# --------------------------------------------------------------------------- +# Lock presence + acquire semantics +# --------------------------------------------------------------------------- + + +class TestLockPresence: + """``ConversationEngine`` must have an ``asyncio.Lock`` instance attribute.""" + + def test_engine_has_lock_attribute(self) -> None: + engine = ConversationEngine() + assert isinstance(engine._lock, asyncio.Lock), ( + "ConversationEngine must expose an asyncio.Lock as _lock; " + "the per-handler lock is the concurrency-safety contract" + ) + + def test_lock_is_per_instance_not_shared(self) -> None: + """Two engines have independent locks (defaults aren't shared).""" + e1 = ConversationEngine() + e2 = ConversationEngine() + assert e1._lock is not e2._lock, ( + "Each ConversationEngine instance must have its own asyncio.Lock — " + "shared default-factory output between instances would serialize " + "unrelated conversations" + ) + + +class TestHandleAnswerAcquiresLock: + """``handle_answer`` must block while the lock is held externally.""" + + async def test_handle_answer_blocks_when_lock_held(self) -> None: + """If something else holds ``engine._lock``, ``handle_answer`` waits.""" + engine = ConversationEngine() + await engine.start(_noop_emit) + assert engine._turn == 1 + + # Hold the lock externally; handle_answer should block. Use a short + # timeout to assert blocking behavior without hanging the test. + async with engine._lock: + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for( + engine.handle_answer( + "a sufficiently long answer to trigger the analyzer path", + _noop_emit, + ), + timeout=0.3, + ) + + # After the external hold releases, the same call should succeed. + # (Construct a fresh call since the previous one was cancelled by + # the timeout.) + await engine.handle_answer( + "a sufficiently long answer to trigger the analyzer path", + _noop_emit, + ) + assert engine._turn == 2, ( + "handle_answer should resume after the lock releases and advance _turn" + ) + + async def test_back_to_back_handle_answer_calls_serialize(self) -> None: + """Two concurrent ``handle_answer`` calls fire questions in order, not interleaved. + + Pre-fix: both calls read the same stale ``_turn``, the question-emission + sequence races, and questions can be skipped. Post-fix: the lock + serializes the call bodies, so emissions land in their natural Q1→Q2→Q3 + order. + """ + engine = ConversationEngine() + emitted: list[FrontDoorMessage] = [] + + async def recording_emit(msg: FrontDoorMessage) -> None: + # Force a yield to the event loop on every emit, mirroring the + # broadcast()→asyncio.gather() suspension shape the ticket cites. + await asyncio.sleep(0) + emitted.append(msg) + + await engine.start(recording_emit) + + # Fire two answers concurrently — without the lock, the call bodies + # interleave at the first await emit() and the question-emission + # sequence races. + await asyncio.gather( + engine.handle_answer( + "first answer with enough words to trigger the analyzer", + recording_emit, + ), + engine.handle_answer( + "second answer with enough words to trigger the analyzer", + recording_emit, + ), + ) + + # _turn should have advanced exactly twice (Q1 → Q2 → Q3 waiting). + assert engine._turn == 3, ( + f"After two answers, _turn should be 3; got {engine._turn} " + "(race condition: both calls saw stale _turn or double-incremented)" + ) + + # Extract just the question-shaped emissions (start + each handle_answer + # emits a reflection + the next question; we only assert on questions + # to keep the test resilient to reflection-text variations). + question_texts = [ + m.text # type: ignore[attr-defined] + for m in emitted + if getattr(m, "subtype", None) == "question" + ] + + # Pre-fix race: questions can be emitted out of order or skipped. + # Post-fix: Q1 (from start), Q2 (from call A), Q3 (from call B) — in + # strict ascending order. + assert len(question_texts) == 3, ( + f"Expected 3 questions emitted; got {len(question_texts)}: {question_texts}" + ) + # The questions are unique by content; assert ascending-position order + # by checking that no question is emitted before a higher-indexed one. + from bonfire.onboard.conversation import _QUESTIONS + + expected_order = list(_QUESTIONS[:3]) + assert question_texts == expected_order, ( + f"Questions emitted out of order under concurrent handle_answer: " + f"got {question_texts!r}, expected {expected_order!r}" + ) + + +# --------------------------------------------------------------------------- +# Typed lifecycle exceptions +# --------------------------------------------------------------------------- + + +class TestTypedLifecycleExceptions: + """``handle_answer`` raises typed exceptions, not bare ``RuntimeError``.""" + + async def test_handle_answer_before_start_raises_typed_not_started(self) -> None: + """Calling ``handle_answer`` before ``start()`` raises ``ConversationNotStarted``.""" + engine = ConversationEngine() + with pytest.raises(ConversationNotStarted): + await engine.handle_answer("anything", _noop_emit) + + async def test_handle_answer_after_complete_raises_typed_already_complete( + self, + ) -> None: + """Calling ``handle_answer`` after all 3 answers raises ``ConversationAlreadyComplete``.""" + engine = ConversationEngine() + await engine.start(_noop_emit) + # Provide three answers to drive the engine to completion. + for i in range(3): + await engine.handle_answer( + f"answer {i} with enough words to satisfy the analyzer", + _noop_emit, + ) + assert engine.is_complete is True + with pytest.raises(ConversationAlreadyComplete): + await engine.handle_answer("one too many", _noop_emit) + + async def test_typed_exceptions_subclass_runtimeerror_for_backcompat(self) -> None: + """Typed exceptions must subclass ``RuntimeError`` so existing catchers still work. + + The WS handler's ``except Exception`` already catches these, but any + upstream code that specifically did ``except RuntimeError`` (the + pre-fix exception type) must continue to work without modification. + """ + assert issubclass(ConversationNotStarted, RuntimeError) + assert issubclass(ConversationAlreadyComplete, RuntimeError)