From f2bdedbe3e5c4ab1a1c5896c2d85c874c9eb1f01 Mon Sep 17 00:00:00 2001 From: Arturo Bautista Date: Wed, 1 Apr 2026 13:40:24 -0600 Subject: [PATCH 1/5] feat: multi-directional agent orchestration with calls, messages, and mailboxes Add communication network infrastructure enabling bidirectional, topology-rich agent communication. External agents can now proactively send messages back via callback URLs (the `callback_url` pattern), solving the invocability asymmetry where only orchestrators could initiate. Phase 1 - Network foundation: CommunicationNetwork, NetworkParticipant, NetworkMessage models with Redis-backed context accumulation. Phase 2 - Three communication channels: synchronous calls, near-real-time messages (webhook push), and async mailboxes (polling). Callback endpoint enables external agents to push messages into the network. Phase 3 - Complex topologies: loop steps with convergence detection (similarity, approval, max iterations), fan-in aggregation (merge, vote, LLM summarize), and topology validation (mesh, star, ring). Phase 4 - A2A protocol alignment: Agent Card generation, JSON-RPC protocol adapter, and A2A-compatible endpoints for interoperability with Google's Agent-to-Agent protocol. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- ...6_04_01_0001-add_communication_networks.py | 178 +++++++ src/core/settings.py | 7 + src/main.py | 47 +- src/models/__init__.py | 11 + src/network/__init__.py | 0 src/network/a2a/__init__.py | 0 src/network/a2a/agent_card.py | 127 +++++ src/network/a2a/protocol.py | 147 ++++++ src/network/a2a/routes.py | 174 +++++++ src/network/models/__init__.py | 0 src/network/models/entities.py | 163 +++++++ src/network/models/schemas.py | 116 +++++ src/network/repositories/__init__.py | 0 src/network/repositories/networks.py | 182 +++++++ src/network/routes/__init__.py | 0 src/network/routes/callbacks.py | 62 +++ src/network/routes/channels.py | 152 ++++++ src/network/routes/networks.py | 174 +++++++ src/network/services/__init__.py | 0 src/network/services/channels.py | 446 ++++++++++++++++++ src/network/services/networks.py | 219 +++++++++ src/network/utils/__init__.py | 0 src/network/utils/aggregator.py | 126 +++++ src/network/utils/context_manager.py | 76 +++ src/network/utils/convergence.py | 146 ++++++ src/network/utils/delivery_worker.py | 157 ++++++ src/network/utils/topology.py | 104 ++++ src/workflow/models/dsl.py | 52 +- src/workflow/utils/dsl_parser.py | 19 + src/workflow/utils/orchestrator.py | 184 ++++++++ 30 files changed, 3036 insertions(+), 33 deletions(-) create mode 100644 alembic/versions/2026_04_01_0001-add_communication_networks.py create mode 100644 src/network/__init__.py create mode 100644 src/network/a2a/__init__.py create mode 100644 src/network/a2a/agent_card.py create mode 100644 src/network/a2a/protocol.py create mode 100644 src/network/a2a/routes.py create mode 100644 src/network/models/__init__.py create mode 100644 src/network/models/entities.py create mode 100644 src/network/models/schemas.py create mode 100644 src/network/repositories/__init__.py create mode 100644 src/network/repositories/networks.py create mode 100644 src/network/routes/__init__.py create mode 100644 
src/network/routes/callbacks.py create mode 100644 src/network/routes/channels.py create mode 100644 src/network/routes/networks.py create mode 100644 src/network/services/__init__.py create mode 100644 src/network/services/channels.py create mode 100644 src/network/services/networks.py create mode 100644 src/network/utils/__init__.py create mode 100644 src/network/utils/aggregator.py create mode 100644 src/network/utils/context_manager.py create mode 100644 src/network/utils/convergence.py create mode 100644 src/network/utils/delivery_worker.py create mode 100644 src/network/utils/topology.py diff --git a/alembic/versions/2026_04_01_0001-add_communication_networks.py b/alembic/versions/2026_04_01_0001-add_communication_networks.py new file mode 100644 index 0000000..16a2220 --- /dev/null +++ b/alembic/versions/2026_04_01_0001-add_communication_networks.py @@ -0,0 +1,178 @@ +"""add communication networks, participants, and messages tables + +Revision ID: add_communication_networks +Revises: add_supports_streaming +Create Date: 2026-04-01 00:01:00.000000 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import JSONB, UUID + +# revision identifiers, used by Alembic. 
+revision = "add_communication_networks" +down_revision = "add_supports_streaming" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # Create enum types + op.execute( + "CREATE TYPE topologytype AS ENUM ('mesh', 'star', 'ring', 'custom')" + ) + op.execute( + "CREATE TYPE networkstatus AS ENUM ('active', 'paused', 'closed')" + ) + op.execute( + "CREATE TYPE participanttype AS ENUM ('agent', 'persona', 'orchestrator')" + ) + op.execute( + "CREATE TYPE participantstatus AS ENUM ('active', 'disconnected', 'removed')" + ) + op.execute( + "CREATE TYPE channeltype AS ENUM ('call', 'message', 'mailbox')" + ) + op.execute( + "CREATE TYPE messagestatus AS ENUM ('pending', 'delivered', 'read', 'failed')" + ) + + # Communication networks + op.create_table( + "communication_networks", + sa.Column("id", UUID(as_uuid=True), primary_key=True), + sa.Column("owner_id", UUID(as_uuid=True), sa.ForeignKey("users.id"), nullable=False), + sa.Column("name", sa.String(255), nullable=False), + sa.Column( + "topology_type", + sa.Enum("mesh", "star", "ring", "custom", name="topologytype", create_type=False), + nullable=False, + server_default="mesh", + ), + sa.Column("metadata", JSONB, nullable=True), + sa.Column( + "status", + sa.Enum("active", "paused", "closed", name="networkstatus", create_type=False), + nullable=False, + server_default="active", + ), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + onupdate=sa.func.now(), + ), + ) + op.create_index("ix_communication_networks_owner_id", "communication_networks", ["owner_id"]) + + # Network participants + op.create_table( + "network_participants", + sa.Column("id", UUID(as_uuid=True), primary_key=True), + sa.Column( + "network_id", + UUID(as_uuid=True), + sa.ForeignKey("communication_networks.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("agent_id", UUID(as_uuid=True), 
sa.ForeignKey("agents.id"), nullable=True), + sa.Column( + "participant_type", + sa.Enum("agent", "persona", "orchestrator", name="participanttype", create_type=False), + nullable=False, + server_default="agent", + ), + sa.Column("name", sa.String(255), nullable=False), + sa.Column("callback_url", sa.Text, nullable=True), + sa.Column("polling_enabled", sa.Boolean, nullable=False, server_default="false"), + sa.Column("capabilities", JSONB, nullable=True), + sa.Column( + "status", + sa.Enum("active", "disconnected", "removed", name="participantstatus", create_type=False), + nullable=False, + server_default="active", + ), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + onupdate=sa.func.now(), + ), + ) + op.create_index("ix_network_participants_network_id", "network_participants", ["network_id"]) + op.create_index("ix_network_participants_agent_id", "network_participants", ["agent_id"]) + + # Network messages + op.create_table( + "network_messages", + sa.Column("id", UUID(as_uuid=True), primary_key=True), + sa.Column( + "network_id", + UUID(as_uuid=True), + sa.ForeignKey("communication_networks.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column( + "sender_participant_id", + UUID(as_uuid=True), + sa.ForeignKey("network_participants.id"), + nullable=False, + ), + sa.Column( + "recipient_participant_id", + UUID(as_uuid=True), + sa.ForeignKey("network_participants.id"), + nullable=True, + ), + sa.Column( + "channel_type", + sa.Enum("call", "message", "mailbox", name="channeltype", create_type=False), + nullable=False, + ), + sa.Column("content", sa.Text, nullable=False), + sa.Column("metadata", JSONB, nullable=True), + sa.Column( + "status", + sa.Enum("pending", "delivered", "read", "failed", name="messagestatus", create_type=False), + nullable=False, + server_default="pending", + ), + sa.Column( + "in_reply_to_id", + 
UUID(as_uuid=True), + sa.ForeignKey("network_messages.id"), + nullable=True, + ), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + onupdate=sa.func.now(), + ), + ) + op.create_index("ix_network_messages_network_id", "network_messages", ["network_id"]) + op.create_index( + "ix_network_messages_sender", "network_messages", ["sender_participant_id"] + ) + op.create_index( + "ix_network_messages_recipient", "network_messages", ["recipient_participant_id"] + ) + op.create_index( + "ix_network_messages_created_at", "network_messages", ["network_id", "created_at"] + ) + + +def downgrade() -> None: + op.drop_table("network_messages") + op.drop_table("network_participants") + op.drop_table("communication_networks") + + op.execute("DROP TYPE IF EXISTS messagestatus") + op.execute("DROP TYPE IF EXISTS channeltype") + op.execute("DROP TYPE IF EXISTS participantstatus") + op.execute("DROP TYPE IF EXISTS participanttype") + op.execute("DROP TYPE IF EXISTS networkstatus") + op.execute("DROP TYPE IF EXISTS topologytype") diff --git a/src/core/settings.py b/src/core/settings.py index a5723d6..097e7cc 100644 --- a/src/core/settings.py +++ b/src/core/settings.py @@ -96,6 +96,13 @@ class Settings(BaseSettings): WORKFLOW_DEFAULT_MAX_CONCURRENT_PER_AGENT: int = 10 WORKFLOW_DEFAULT_MAX_CONCURRENT_EXECUTIONS: int = 5 + # ── Network settings (communication networks) ───────────────────── + NETWORK_CONTEXT_TTL_SECONDS: int = 86400 * 7 # 7 days + NETWORK_CONTEXT_MAX_ENTRIES: int = 500 # max messages in Redis context stream + NETWORK_MAX_PARTICIPANTS: int = 50 + NETWORK_CALLBACK_TIMEOUT_SECONDS: int = 30 + NETWORK_MESSAGE_DELIVERY_MAX_RETRIES: int = 3 + # ── Economy settings (from agent-economy) ────────────────────────── ECONOMY_WELCOME_BONUS_CREDITS: int = 500 ECONOMY_CREDIT_PACKAGES: list[dict] = [ diff --git a/src/main.py b/src/main.py index 5a6707c..b82e095 
100644 --- a/src/main.py +++ b/src/main.py @@ -37,6 +37,12 @@ from src.workflow.routes.executions import router as execution_router from src.workflow.routes.webhooks import router as webhook_router +# Network routers (communication networks) +from src.network.routes.networks import router as network_router +from src.network.routes.channels import router as channel_router +from src.network.routes.callbacks import router as callback_router +from src.network.a2a.routes import router as a2a_router + # Economy routers (from agent-economy) from src.economy.routes.wallets import router as wallets_router from src.economy.routes.market import router as market_router @@ -231,6 +237,12 @@ async def handle_workflow_exception(_request: Request, exc: WorkflowAppException app.include_router(execution_router, tags=["Executions"]) app.include_router(webhook_router, tags=["Webhooks"]) +# ── Network routers (communication networks) ───────────────────────── +app.include_router(network_router, tags=["Networks"]) +app.include_router(channel_router, tags=["Channels"]) +app.include_router(callback_router, tags=["Callbacks"]) +app.include_router(a2a_router, tags=["A2A"]) + # ── Economy routers (from agent-economy) ───────────────────────────── app.include_router(wallets_router, prefix="/wallets", tags=["Wallets"]) app.include_router(market_router, prefix="/market", tags=["Market"]) @@ -244,38 +256,9 @@ async def handle_workflow_exception(_request: Request, exc: WorkflowAppException @app.get("/.well-known/agent.json") async def a2a_agent_card(): """A2A-compatible AgentCard for agent-to-agent discovery.""" - return JSONResponse( - { - "name": "Intuno Agent Network", - "description": "Registry, broker, and orchestrator for AI agents", - "url": settings.BASE_URL, - "version": settings.API_VERSION, - "capabilities": { - "streaming": True, - "pushNotifications": False, - }, - "skills": [ - { - "id": "discover", - "name": "Discover Agents", - "description": "Semantic search for AI agents by 
natural-language query", - }, - { - "id": "invoke", - "name": "Invoke Agent", - "description": "Execute an agent with input data through the broker", - }, - { - "id": "orchestrate", - "name": "Orchestrate Task", - "description": "Multi-step task orchestration across multiple agents", - }, - ], - "authentication": { - "schemes": ["apiKey", "bearer"], - }, - } - ) + from src.network.a2a.agent_card import build_platform_card + + return JSONResponse(build_platform_card()) @app.get("/.well-known/mcp/server-card.json") diff --git a/src/models/__init__.py b/src/models/__init__.py index 4e36376..78b58a1 100644 --- a/src/models/__init__.py +++ b/src/models/__init__.py @@ -21,6 +21,13 @@ WorkflowExecution, ) +# Network models (communication networks) +from src.network.models.entities import ( # noqa: F401 + CommunicationNetwork, + NetworkMessage, + NetworkParticipant, +) + # Economy models (from agent-economy) from src.economy.models.wallet import Transaction, Wallet # noqa: F401 from src.economy.models.order import Order, Trade # noqa: F401 @@ -45,6 +52,10 @@ "WorkflowExecution", "ProcessEntry", "ContextEntry", + # Network + "CommunicationNetwork", + "NetworkParticipant", + "NetworkMessage", # Economy "Wallet", "Transaction", diff --git a/src/network/__init__.py b/src/network/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/network/a2a/__init__.py b/src/network/a2a/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/network/a2a/agent_card.py b/src/network/a2a/agent_card.py new file mode 100644 index 0000000..c4ce7f3 --- /dev/null +++ b/src/network/a2a/agent_card.py @@ -0,0 +1,127 @@ +"""A2A Agent Card generation and serving. + +Generates JSON-LD Agent Cards from the Intuno agent registry for agents +that opt into A2A interoperability. Cards are served at +``GET /.well-known/agent.json`` (platform-level) and per-agent endpoints. 
+""" + +from typing import Any, Optional +from uuid import UUID + +from src.core.settings import settings + + +def build_agent_card( + agent: Any, + capabilities: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + """Build an A2A Agent Card from an Intuno agent registry entry. + + See: https://google.github.io/A2A/specification/ + """ + card: dict[str, Any] = { + "name": agent.name, + "description": agent.description, + "url": f"{settings.BASE_URL}/a2a/agents/{agent.agent_id}", + "version": getattr(agent, "version", "1.0.0"), + "capabilities": { + "streaming": getattr(agent, "supports_streaming", False), + "pushNotifications": True, # via network callback mechanism + **(capabilities or {}), + }, + "skills": _build_skills(agent), + "authentication": _build_auth(agent), + } + + # Add input schema if available + if agent.input_schema: + card["defaultInputModes"] = ["application/json"] + card["defaultOutputModes"] = ["application/json"] + + return card + + +def build_platform_card() -> dict[str, Any]: + """Build the platform-level A2A Agent Card for Intuno itself.""" + return { + "name": "Intuno Agent Network", + "description": ( + "Registry, broker, and orchestrator for AI agents. " + "Supports multi-directional agent communication with calls, " + "messages, and mailboxes." 
+ ), + "url": settings.BASE_URL, + "version": settings.API_VERSION, + "capabilities": { + "streaming": True, + "pushNotifications": True, + "networks": True, + "topologies": ["mesh", "star", "ring", "custom"], + "channels": ["call", "message", "mailbox"], + }, + "skills": [ + { + "id": "discover", + "name": "Discover Agents", + "description": "Semantic search for AI agents by natural-language query", + }, + { + "id": "invoke", + "name": "Invoke Agent", + "description": "Execute an agent with input data through the broker", + }, + { + "id": "orchestrate", + "name": "Orchestrate Task", + "description": "Multi-step task orchestration across multiple agents", + }, + { + "id": "network", + "name": "Communication Network", + "description": ( + "Create multi-directional communication networks between agents " + "with calls, messages, and mailboxes" + ), + }, + ], + "authentication": { + "schemes": ["apiKey", "bearer"], + }, + } + + +def _build_skills(agent: Any) -> list[dict[str, str]]: + """Extract skills from agent metadata.""" + skills = [] + + # If agent has a2a_capabilities, use those + a2a_caps = getattr(agent, "a2a_capabilities", None) + if a2a_caps and isinstance(a2a_caps, list): + for cap in a2a_caps: + if isinstance(cap, dict): + skills.append(cap) + elif isinstance(cap, str): + skills.append({"id": cap, "name": cap, "description": cap}) + return skills + + # Default: generate a single skill from agent description + skills.append({ + "id": agent.agent_id, + "name": agent.name, + "description": agent.description, + }) + return skills + + +def _build_auth(agent: Any) -> dict[str, Any]: + """Build authentication section from agent auth_type.""" + auth_type = getattr(agent, "auth_type", "public") or "public" + + if auth_type == "public": + return {"schemes": []} + elif auth_type == "api_key": + return {"schemes": ["apiKey"]} + elif auth_type == "bearer_token": + return {"schemes": ["bearer"]} + + return {"schemes": [auth_type]} diff --git 
a/src/network/a2a/protocol.py b/src/network/a2a/protocol.py new file mode 100644 index 0000000..19206da --- /dev/null +++ b/src/network/a2a/protocol.py @@ -0,0 +1,147 @@ +"""A2A protocol adapter. + +Translates between Intuno's internal message format and the A2A +JSON-RPC wire format. + +A2A spec: https://google.github.io/A2A/specification/ +""" + +from __future__ import annotations + +import uuid +from datetime import datetime, timezone +from typing import Any, Optional + + +# A2A task states +A2A_STATE_SUBMITTED = "submitted" +A2A_STATE_WORKING = "working" +A2A_STATE_INPUT_REQUIRED = "input-required" +A2A_STATE_COMPLETED = "completed" +A2A_STATE_FAILED = "failed" +A2A_STATE_CANCELED = "canceled" + +# Mapping from Intuno message status to A2A task state +_STATUS_MAP = { + "pending": A2A_STATE_SUBMITTED, + "delivered": A2A_STATE_WORKING, + "read": A2A_STATE_COMPLETED, + "failed": A2A_STATE_FAILED, +} + +# Mapping from Intuno channel type to A2A concepts +_CHANNEL_MAP = { + "call": "task", # synchronous call maps to A2A task + "message": "message", # async message maps to A2A message + "mailbox": "message", # mailbox also maps to A2A message (deferred) +} + + +def intuno_message_to_a2a_task( + message: dict[str, Any], + sender_name: str = "", + recipient_name: str = "", +) -> dict[str, Any]: + """Convert an Intuno NetworkMessage (as dict) to an A2A Task object.""" + task_id = message.get("id") or str(uuid.uuid4()) + status = message.get("status", "pending") + content = message.get("content", "") + metadata = message.get("metadata_") or message.get("metadata") or {} + + a2a_task = { + "id": str(task_id), + "status": { + "state": _STATUS_MAP.get(status, A2A_STATE_SUBMITTED), + "timestamp": ( + message.get("created_at", datetime.now(timezone.utc)).isoformat() + if isinstance(message.get("created_at"), datetime) + else datetime.now(timezone.utc).isoformat() + ), + }, + "history": [ + { + "role": "user", + "parts": [{"type": "text", "text": content}], + } + ], + 
"metadata": { + "intuno_network_id": str(message.get("network_id", "")), + "intuno_channel": message.get("channel_type", "message"), + "sender": sender_name, + "recipient": recipient_name, + **metadata, + }, + } + + return a2a_task + + +def a2a_task_to_intuno_message( + a2a_task: dict[str, Any], +) -> dict[str, Any]: + """Convert an A2A Task object to Intuno message format.""" + # Extract content from history + content = "" + history = a2a_task.get("history", []) + if history: + last_entry = history[-1] + parts = last_entry.get("parts", []) + text_parts = [p.get("text", "") for p in parts if p.get("type") == "text"] + content = "\n".join(text_parts) + + # Map A2A state back to Intuno status + state = a2a_task.get("status", {}).get("state", A2A_STATE_SUBMITTED) + reverse_status_map = { + A2A_STATE_SUBMITTED: "pending", + A2A_STATE_WORKING: "delivered", + A2A_STATE_COMPLETED: "read", + A2A_STATE_FAILED: "failed", + A2A_STATE_INPUT_REQUIRED: "pending", + A2A_STATE_CANCELED: "failed", + } + + a2a_metadata = a2a_task.get("metadata", {}) + + return { + "content": content, + "channel_type": a2a_metadata.get("intuno_channel", "message"), + "status": reverse_status_map.get(state, "pending"), + "metadata": { + "a2a_task_id": a2a_task.get("id"), + "a2a_state": state, + **{ + k: v + for k, v in a2a_metadata.items() + if not k.startswith("intuno_") + }, + }, + } + + +def build_a2a_json_rpc_response( + result: Any, + request_id: Optional[str | int] = None, +) -> dict[str, Any]: + """Wrap a result in a JSON-RPC 2.0 response envelope.""" + return { + "jsonrpc": "2.0", + "id": request_id, + "result": result, + } + + +def build_a2a_json_rpc_error( + code: int, + message: str, + request_id: Optional[str | int] = None, + data: Any = None, +) -> dict[str, Any]: + """Build a JSON-RPC 2.0 error response.""" + error: dict[str, Any] = {"code": code, "message": message} + if data is not None: + error["data"] = data + return { + "jsonrpc": "2.0", + "id": request_id, + "error": error, + } 
diff --git a/src/network/a2a/routes.py b/src/network/a2a/routes.py new file mode 100644 index 0000000..4809ae9 --- /dev/null +++ b/src/network/a2a/routes.py @@ -0,0 +1,174 @@ +"""A2A-compatible API endpoints. + +Provides endpoints that follow the A2A protocol specification, allowing +A2A-compatible agents to interact with Intuno networks. + +See: https://google.github.io/A2A/specification/ +""" + +from typing import Any, Optional +from uuid import UUID + +from fastapi import APIRouter, Depends, Request +from fastapi.responses import JSONResponse +from pydantic import BaseModel, Field + +from src.core.auth import get_current_user +from src.models.auth import User +from src.network.a2a.agent_card import build_agent_card, build_platform_card +from src.network.a2a.protocol import ( + a2a_task_to_intuno_message, + build_a2a_json_rpc_error, + build_a2a_json_rpc_response, + intuno_message_to_a2a_task, +) +from src.repositories.registry import RegistryRepository + +router = APIRouter(prefix="/a2a", tags=["A2A"]) + + +# ── Agent Card endpoints ───────────────────────────────────────────── + + +@router.get("/agent-card") +async def get_platform_agent_card() -> JSONResponse: + """Serve the platform-level A2A Agent Card.""" + return JSONResponse(build_platform_card()) + + +@router.get("/agents/{agent_id}/agent-card") +async def get_agent_card( + agent_id: str, + registry: RegistryRepository = Depends(), +) -> JSONResponse: + """Serve an A2A Agent Card for a specific registered agent.""" + agent = await registry.get_agent_by_agent_id(agent_id) + if not agent: + return JSONResponse( + build_a2a_json_rpc_error(-32602, f"Agent '{agent_id}' not found"), + status_code=404, + ) + return JSONResponse(build_agent_card(agent)) + + +# ── A2A Task endpoints (JSON-RPC style) ────────────────────────────── + + +class A2ATaskSendRequest(BaseModel): + """A2A tasks/send request body.""" + + jsonrpc: str = "2.0" + id: Optional[str | int] = None + method: str = "tasks/send" + params: dict[str, 
Any] = Field(default_factory=dict) + + +@router.post("/tasks/send") +async def a2a_task_send( + data: A2ATaskSendRequest, + request: Request, + current_user: User = Depends(get_current_user), +) -> JSONResponse: + """A2A-compatible task send endpoint. + + Receives an A2A task, translates it to an Intuno network message, + processes it, and returns the result in A2A format. + """ + from src.network.services.channels import ChannelService + from src.network.repositories.networks import NetworkRepository + from src.network.utils.context_manager import NetworkContextManager + from src.database import get_redis + + params = data.params + task_data = params.get("task", {}) + network_id = params.get("network_id") + sender_participant_id = params.get("sender_participant_id") + recipient_participant_id = params.get("recipient_participant_id") + + if not network_id or not sender_participant_id: + return JSONResponse( + build_a2a_json_rpc_error( + -32602, + "Missing required params: network_id, sender_participant_id", + data.id, + ), + status_code=400, + ) + + # Convert A2A task to Intuno message format + intuno_msg = a2a_task_to_intuno_message(task_data) + + # Process through the channel service + try: + redis = request.app.state.redis + repo = NetworkRepository( + session=(await request.app.state.db_session_factory()).__aenter__() + ) + ctx_manager = NetworkContextManager(redis) + channel_service = ChannelService(repo=repo, context_manager=ctx_manager) + channel_service.set_http_client(request.app.state.http_client) + + channel_type = intuno_msg.get("channel_type", "message") + + if channel_type == "call" and recipient_participant_id: + result = await channel_service.call( + network_id=UUID(network_id), + sender_participant_id=UUID(sender_participant_id), + recipient_participant_id=UUID(recipient_participant_id), + content=intuno_msg["content"], + metadata=intuno_msg.get("metadata"), + ) + # Convert result back to A2A task format + a2a_result = { + "id": 
result.get("message_id"), + "status": {"state": "completed"}, + "artifacts": [ + { + "parts": [ + {"type": "text", "text": str(result.get("response", ""))} + ] + } + ], + } + else: + message = await channel_service.send_message( + network_id=UUID(network_id), + sender_participant_id=UUID(sender_participant_id), + recipient_participant_id=UUID(recipient_participant_id), + content=intuno_msg["content"], + metadata=intuno_msg.get("metadata"), + ) + a2a_result = intuno_message_to_a2a_task( + { + "id": message.id, + "status": message.status, + "content": message.content, + "network_id": message.network_id, + "channel_type": message.channel_type, + "created_at": message.created_at, + }, + ) + + return JSONResponse(build_a2a_json_rpc_response(a2a_result, data.id)) + + except Exception as exc: + return JSONResponse( + build_a2a_json_rpc_error(-32603, str(exc), data.id), + status_code=500, + ) + + +# ── A2A Agent Discovery ────────────────────────────────────────────── + + +@router.get("/agents") +async def list_a2a_agents( + registry: RegistryRepository = Depends(), +) -> JSONResponse: + """List all agents with A2A support enabled.""" + agents = await registry.list_agents(limit=100) + cards = [] + for agent in agents: + if getattr(agent, "is_active", False): + cards.append(build_agent_card(agent)) + return JSONResponse({"agents": cards}) diff --git a/src/network/models/__init__.py b/src/network/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/network/models/entities.py b/src/network/models/entities.py new file mode 100644 index 0000000..8d9286a --- /dev/null +++ b/src/network/models/entities.py @@ -0,0 +1,163 @@ +"""Communication network domain models. + +A CommunicationNetwork groups participants (agents, personas) that can +exchange messages through calls, messages, or mailboxes. 
+""" + +import enum +from typing import Optional +from uuid import UUID + +from sqlalchemy import Column, Enum, ForeignKey, String, Text, Boolean +from sqlalchemy.dialects.postgresql import JSONB, UUID as PostgresUUID +from sqlalchemy.orm import relationship + +from src.models.base import BaseModel + + +class TopologyType(str, enum.Enum): + mesh = "mesh" + star = "star" + ring = "ring" + custom = "custom" + + +class NetworkStatus(str, enum.Enum): + active = "active" + paused = "paused" + closed = "closed" + + +class ParticipantType(str, enum.Enum): + agent = "agent" + persona = "persona" + orchestrator = "orchestrator" + + +class ParticipantStatus(str, enum.Enum): + active = "active" + disconnected = "disconnected" + removed = "removed" + + +class ChannelType(str, enum.Enum): + call = "call" + message = "message" + mailbox = "mailbox" + + +class MessageStatus(str, enum.Enum): + pending = "pending" + delivered = "delivered" + read = "read" + failed = "failed" + + +class CommunicationNetwork(BaseModel): + """A group of participants that share a communication context.""" + + __tablename__: str = "communication_networks" + + owner_id: Column[UUID] = Column( + PostgresUUID, ForeignKey("users.id"), nullable=False + ) + name: Column[str] = Column(String(255), nullable=False) + topology_type: Column[str] = Column( + Enum(TopologyType), nullable=False, default=TopologyType.mesh + ) + metadata_: Column[Optional[dict]] = Column("metadata", JSONB, nullable=True) + status: Column[str] = Column( + Enum(NetworkStatus), nullable=False, default=NetworkStatus.active + ) + + # Relationships + owner = relationship("User") + participants = relationship( + "NetworkParticipant", + back_populates="network", + cascade="all, delete-orphan", + ) + messages = relationship( + "NetworkMessage", + back_populates="network", + cascade="all, delete-orphan", + order_by="NetworkMessage.created_at", + ) + + +class NetworkParticipant(BaseModel): + """An entity registered in a communication network.""" 
+ + __tablename__: str = "network_participants" + + network_id: Column[UUID] = Column( + PostgresUUID, ForeignKey("communication_networks.id"), nullable=False + ) + agent_id: Column[Optional[UUID]] = Column( + PostgresUUID, ForeignKey("agents.id"), nullable=True + ) + participant_type: Column[str] = Column( + Enum(ParticipantType), nullable=False, default=ParticipantType.agent + ) + name: Column[str] = Column(String(255), nullable=False) + callback_url: Column[Optional[str]] = Column(Text, nullable=True) + polling_enabled: Column[bool] = Column(Boolean, nullable=False, default=False) + capabilities: Column[Optional[dict]] = Column(JSONB, nullable=True) + status: Column[str] = Column( + Enum(ParticipantStatus), nullable=False, default=ParticipantStatus.active + ) + + # Relationships + network = relationship("CommunicationNetwork", back_populates="participants") + agent = relationship("Agent") + sent_messages = relationship( + "NetworkMessage", + back_populates="sender", + foreign_keys="NetworkMessage.sender_participant_id", + ) + received_messages = relationship( + "NetworkMessage", + back_populates="recipient", + foreign_keys="NetworkMessage.recipient_participant_id", + ) + + +class NetworkMessage(BaseModel): + """A message exchanged within a communication network.""" + + __tablename__: str = "network_messages" + + network_id: Column[UUID] = Column( + PostgresUUID, ForeignKey("communication_networks.id"), nullable=False + ) + sender_participant_id: Column[UUID] = Column( + PostgresUUID, ForeignKey("network_participants.id"), nullable=False + ) + recipient_participant_id: Column[Optional[UUID]] = Column( + PostgresUUID, ForeignKey("network_participants.id"), nullable=True + ) + channel_type: Column[str] = Column( + Enum(ChannelType), nullable=False + ) + content: Column[str] = Column(Text, nullable=False) + metadata_: Column[Optional[dict]] = Column("metadata", JSONB, nullable=True) + status: Column[str] = Column( + Enum(MessageStatus), nullable=False, 
default=MessageStatus.pending + ) + in_reply_to_id: Column[Optional[UUID]] = Column( + PostgresUUID, ForeignKey("network_messages.id"), nullable=True + ) + + # Relationships + network = relationship("CommunicationNetwork", back_populates="messages") + sender = relationship( + "NetworkParticipant", + back_populates="sent_messages", + foreign_keys=[sender_participant_id], + ) + recipient = relationship( + "NetworkParticipant", + back_populates="received_messages", + foreign_keys=[recipient_participant_id], + ) + in_reply_to = relationship("NetworkMessage", remote_side="NetworkMessage.id") diff --git a/src/network/models/schemas.py b/src/network/models/schemas.py new file mode 100644 index 0000000..5e4cf7e --- /dev/null +++ b/src/network/models/schemas.py @@ -0,0 +1,116 @@ +"""Pydantic request/response schemas for communication networks.""" + +from datetime import datetime +from typing import Any, Optional +from uuid import UUID + +from pydantic import BaseModel, ConfigDict, Field + + +# ── Network schemas ────────────────────────────────────────────────── + + +class NetworkCreate(BaseModel): + name: str = Field(..., max_length=255) + topology_type: str = Field(default="mesh", description="mesh | star | ring | custom") + metadata: Optional[dict[str, Any]] = None + + +class NetworkUpdate(BaseModel): + name: Optional[str] = Field(default=None, max_length=255) + topology_type: Optional[str] = None + status: Optional[str] = None + metadata: Optional[dict[str, Any]] = None + + +class NetworkResponse(BaseModel): + model_config = ConfigDict(from_attributes=True) + + id: UUID + owner_id: UUID + name: str + topology_type: str + metadata_: Optional[dict[str, Any]] = Field(default=None, alias="metadata_") + status: str + created_at: datetime + updated_at: datetime + + +# ── Participant schemas ────────────────────────────────────────────── + + +class ParticipantJoin(BaseModel): + agent_id: Optional[UUID] = None + participant_type: str = Field(default="agent", description="agent 
| persona | orchestrator") + name: str = Field(..., max_length=255) + callback_url: Optional[str] = None + polling_enabled: bool = False + capabilities: Optional[dict[str, Any]] = None + + +class ParticipantUpdate(BaseModel): + callback_url: Optional[str] = None + polling_enabled: Optional[bool] = None + capabilities: Optional[dict[str, Any]] = None + status: Optional[str] = None + + +class ParticipantResponse(BaseModel): + model_config = ConfigDict(from_attributes=True) + + id: UUID + network_id: UUID + agent_id: Optional[UUID] = None + participant_type: str + name: str + callback_url: Optional[str] = None + polling_enabled: bool + capabilities: Optional[dict[str, Any]] = None + status: str + created_at: datetime + updated_at: datetime + + +# ── Message schemas ────────────────────────────────────────────────── + + +class NetworkMessageCreate(BaseModel): + recipient_participant_id: Optional[UUID] = None + channel_type: str = Field(..., description="call | message | mailbox") + content: str + metadata: Optional[dict[str, Any]] = None + in_reply_to_id: Optional[UUID] = None + + +class NetworkMessageResponse(BaseModel): + model_config = ConfigDict(from_attributes=True) + + id: UUID + network_id: UUID + sender_participant_id: UUID + recipient_participant_id: Optional[UUID] = None + channel_type: str + content: str + metadata_: Optional[dict[str, Any]] = Field(default=None, alias="metadata_") + status: str + in_reply_to_id: Optional[UUID] = None + created_at: datetime + updated_at: datetime + + +# ── Context snapshot ───────────────────────────────────────────────── + + +class ContextEntry(BaseModel): + sender: str + recipient: Optional[str] = None + channel: str + content: str + timestamp: datetime + + +class NetworkContextSnapshot(BaseModel): + network_id: UUID + participant_count: int + message_count: int + entries: list[ContextEntry] diff --git a/src/network/repositories/__init__.py b/src/network/repositories/__init__.py new file mode 100644 index 0000000..e69de29 
diff --git a/src/network/repositories/networks.py b/src/network/repositories/networks.py new file mode 100644 index 0000000..522dcdf --- /dev/null +++ b/src/network/repositories/networks.py @@ -0,0 +1,182 @@ +"""Repository for communication network domain operations.""" + +from typing import Optional +from uuid import UUID + +from fastapi import Depends +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + +from src.database import get_db +from src.network.models.entities import ( + CommunicationNetwork, + NetworkMessage, + NetworkParticipant, + ParticipantStatus, +) + + +class NetworkRepository: + """CRUD for communication networks, participants, and messages.""" + + def __init__(self, session: AsyncSession = Depends(get_db)): + self.session = session + + # ── Networks ───────────────────────────────────────────────────── + + async def create_network(self, network: CommunicationNetwork) -> CommunicationNetwork: + self.session.add(network) + await self.session.commit() + await self.session.refresh(network) + return network + + async def get_network(self, network_id: UUID) -> Optional[CommunicationNetwork]: + result = await self.session.execute( + select(CommunicationNetwork) + .where(CommunicationNetwork.id == network_id) + .options(selectinload(CommunicationNetwork.participants)) + ) + return result.scalar_one_or_none() + + async def list_networks( + self, + owner_id: UUID, + limit: int = 50, + offset: int = 0, + ) -> list[CommunicationNetwork]: + result = await self.session.execute( + select(CommunicationNetwork) + .where(CommunicationNetwork.owner_id == owner_id) + .order_by(CommunicationNetwork.created_at.desc()) + .limit(limit) + .offset(offset) + ) + return list(result.scalars().all()) + + async def update_network(self, network: CommunicationNetwork) -> CommunicationNetwork: + await self.session.commit() + await self.session.refresh(network) + return network + + async def 
delete_network(self, network_id: UUID) -> bool: + network = await self.get_network(network_id) + if network: + await self.session.delete(network) + await self.session.commit() + return True + return False + + # ── Participants ───────────────────────────────────────────────── + + async def add_participant(self, participant: NetworkParticipant) -> NetworkParticipant: + self.session.add(participant) + await self.session.commit() + await self.session.refresh(participant) + return participant + + async def get_participant(self, participant_id: UUID) -> Optional[NetworkParticipant]: + result = await self.session.execute( + select(NetworkParticipant).where(NetworkParticipant.id == participant_id) + ) + return result.scalar_one_or_none() + + async def get_participant_by_agent( + self, network_id: UUID, agent_id: UUID + ) -> Optional[NetworkParticipant]: + result = await self.session.execute( + select(NetworkParticipant).where( + NetworkParticipant.network_id == network_id, + NetworkParticipant.agent_id == agent_id, + NetworkParticipant.status == ParticipantStatus.active, + ) + ) + return result.scalar_one_or_none() + + async def list_participants( + self, + network_id: UUID, + active_only: bool = True, + ) -> list[NetworkParticipant]: + q = select(NetworkParticipant).where( + NetworkParticipant.network_id == network_id + ) + if active_only: + q = q.where(NetworkParticipant.status == ParticipantStatus.active) + q = q.order_by(NetworkParticipant.created_at) + result = await self.session.execute(q) + return list(result.scalars().all()) + + async def update_participant(self, participant: NetworkParticipant) -> NetworkParticipant: + await self.session.commit() + await self.session.refresh(participant) + return participant + + async def remove_participant(self, participant: NetworkParticipant) -> NetworkParticipant: + participant.status = ParticipantStatus.removed + await self.session.commit() + await self.session.refresh(participant) + return participant + + # ── Messages 
───────────────────────────────────────────────────── + + async def create_message(self, message: NetworkMessage) -> NetworkMessage: + self.session.add(message) + await self.session.commit() + await self.session.refresh(message) + return message + + async def get_message(self, message_id: UUID) -> Optional[NetworkMessage]: + result = await self.session.execute( + select(NetworkMessage).where(NetworkMessage.id == message_id) + ) + return result.scalar_one_or_none() + + async def list_messages( + self, + network_id: UUID, + limit: int = 100, + offset: int = 0, + channel_type: Optional[str] = None, + participant_id: Optional[UUID] = None, + ) -> list[NetworkMessage]: + q = ( + select(NetworkMessage) + .where(NetworkMessage.network_id == network_id) + .order_by(NetworkMessage.created_at) + ) + if channel_type: + q = q.where(NetworkMessage.channel_type == channel_type) + if participant_id: + q = q.where( + (NetworkMessage.sender_participant_id == participant_id) + | (NetworkMessage.recipient_participant_id == participant_id) + ) + q = q.limit(limit).offset(offset) + result = await self.session.execute(q) + return list(result.scalars().all()) + + async def get_context( + self, + network_id: UUID, + limit: int = 50, + ) -> list[NetworkMessage]: + """Get recent messages for building network context.""" + result = await self.session.execute( + select(NetworkMessage) + .where(NetworkMessage.network_id == network_id) + .options( + selectinload(NetworkMessage.sender), + selectinload(NetworkMessage.recipient), + ) + .order_by(NetworkMessage.created_at.desc()) + .limit(limit) + ) + messages = list(result.scalars().all()) + messages.reverse() # chronological order + return messages + + async def update_message(self, message: NetworkMessage) -> NetworkMessage: + await self.session.commit() + await self.session.refresh(message) + return message diff --git a/src/network/routes/__init__.py b/src/network/routes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git 
a/src/network/routes/callbacks.py b/src/network/routes/callbacks.py new file mode 100644 index 0000000..80d1814 --- /dev/null +++ b/src/network/routes/callbacks.py @@ -0,0 +1,62 @@ +"""Callback routes: webhook receiver for external agents to push messages back. + +This is the key endpoint that enables bidirectional communication. +When Intuno delivers a message to an external agent, the payload includes +a ``reply_url`` pointing to this endpoint. The agent can POST back to +proactively send messages into the network. +""" + +from typing import Any, Optional +from uuid import UUID + +from fastapi import APIRouter, Depends, Request +from pydantic import BaseModel, Field + +from src.network.models.schemas import NetworkMessageResponse +from src.network.services.channels import ChannelService + +router = APIRouter(prefix="/networks", tags=["Callbacks"]) + + +class CallbackPayload(BaseModel): + """Payload an external agent sends to its reply_url.""" + + content: str + recipient_participant_id: Optional[UUID] = None + channel_type: str = Field(default="message", description="message | call | mailbox") + metadata: Optional[dict[str, Any]] = None + in_reply_to_id: Optional[UUID] = None + + +@router.post( + "/{network_id}/participants/{participant_id}/callback", + response_model=NetworkMessageResponse, +) +async def receive_callback( + network_id: UUID, + participant_id: UUID, + data: CallbackPayload, + request: Request, + service: ChannelService = Depends(), +) -> NetworkMessageResponse: + """Receive a proactive message from an external agent. + + No authentication required — the reply_url itself acts as a capability + token. The participant_id in the URL identifies the sender. 
+ + The external agent can: + - Reply to a specific message (in_reply_to_id) + - Target a specific recipient (recipient_participant_id) + - Broadcast to the network (omit recipient_participant_id) + - Choose a channel type (call/message/mailbox) + """ + service.set_http_client(request.app.state.http_client) + return await service.handle_callback( + network_id=network_id, + participant_id=participant_id, + content=data.content, + recipient_participant_id=data.recipient_participant_id, + channel_type=data.channel_type, + metadata=data.metadata, + in_reply_to_id=data.in_reply_to_id, + ) diff --git a/src/network/routes/channels.py b/src/network/routes/channels.py new file mode 100644 index 0000000..511dadd --- /dev/null +++ b/src/network/routes/channels.py @@ -0,0 +1,152 @@ +"""Channel routes: calls, messages, mailboxes, inbox, and acknowledgment.""" + +from typing import Any, List, Optional +from uuid import UUID + +from fastapi import APIRouter, Depends, Query, Request, status +from pydantic import BaseModel, Field + +from src.core.auth import get_current_user +from src.models.auth import User +from src.network.models.schemas import NetworkMessageResponse +from src.network.services.channels import ChannelService + +router = APIRouter(prefix="/networks", tags=["Channels"]) + + +# ── Request schemas ────────────────────────────────────────────────── + + +class CallRequest(BaseModel): + sender_participant_id: UUID + recipient_participant_id: UUID + content: str + metadata: Optional[dict[str, Any]] = None + + +class MessageRequest(BaseModel): + sender_participant_id: UUID + recipient_participant_id: UUID + content: str + metadata: Optional[dict[str, Any]] = None + + +class MailboxRequest(BaseModel): + sender_participant_id: UUID + recipient_participant_id: UUID + content: str + metadata: Optional[dict[str, Any]] = None + + +class AckRequest(BaseModel): + message_ids: list[UUID] + + +# ── Call ───────────────────────────────────────────────────────────── + + 
+@router.post("/{network_id}/call") +async def make_call( + network_id: UUID, + data: CallRequest, + request: Request, + current_user: User = Depends(get_current_user), + service: ChannelService = Depends(), +) -> dict: + """Synchronous call to another participant. Blocks until response.""" + service.set_http_client(request.app.state.http_client) + return await service.call( + network_id=network_id, + sender_participant_id=data.sender_participant_id, + recipient_participant_id=data.recipient_participant_id, + content=data.content, + metadata=data.metadata, + ) + + +# ── Message ────────────────────────────────────────────────────────── + + +@router.post( + "/{network_id}/messages/send", + response_model=NetworkMessageResponse, + status_code=status.HTTP_201_CREATED, +) +async def send_message( + network_id: UUID, + data: MessageRequest, + request: Request, + current_user: User = Depends(get_current_user), + service: ChannelService = Depends(), +) -> NetworkMessageResponse: + """Send a near-real-time message. Non-blocking.""" + service.set_http_client(request.app.state.http_client) + return await service.send_message( + network_id=network_id, + sender_participant_id=data.sender_participant_id, + recipient_participant_id=data.recipient_participant_id, + content=data.content, + metadata=data.metadata, + ) + + +# ── Mailbox ────────────────────────────────────────────────────────── + + +@router.post( + "/{network_id}/mailbox", + response_model=NetworkMessageResponse, + status_code=status.HTTP_201_CREATED, +) +async def send_to_mailbox( + network_id: UUID, + data: MailboxRequest, + current_user: User = Depends(get_current_user), + service: ChannelService = Depends(), +) -> NetworkMessageResponse: + """Send to mailbox. 
Fully async — no push delivery.""" + return await service.send_to_mailbox( + network_id=network_id, + sender_participant_id=data.sender_participant_id, + recipient_participant_id=data.recipient_participant_id, + content=data.content, + metadata=data.metadata, + ) + + +# ── Inbox ──────────────────────────────────────────────────────────── + + +@router.get("/{network_id}/inbox/{participant_id}") +async def get_inbox( + network_id: UUID, + participant_id: UUID, + current_user: User = Depends(get_current_user), + channel_type: Optional[str] = Query(default=None), + limit: int = Query(default=50, ge=1, le=200), + service: ChannelService = Depends(), +) -> List[NetworkMessageResponse]: + """Poll inbox for a participant.""" + channel_types = [channel_type] if channel_type else None + messages = await service.get_inbox( + network_id=network_id, + participant_id=participant_id, + channel_types=channel_types, + limit=limit, + ) + return messages + + +# ── Acknowledge ────────────────────────────────────────────────────── + + +@router.post("/{network_id}/messages/ack") +async def acknowledge_messages( + network_id: UUID, + data: AckRequest, + current_user: User = Depends(get_current_user), + service: ChannelService = Depends(), +) -> dict: + """Mark messages as read.""" + count = await service.acknowledge(network_id, data.message_ids) + return {"acknowledged": count} diff --git a/src/network/routes/networks.py b/src/network/routes/networks.py new file mode 100644 index 0000000..0b3d705 --- /dev/null +++ b/src/network/routes/networks.py @@ -0,0 +1,174 @@ +"""Network routes: CRUD for communication networks, participants, and context.""" + +from typing import List +from uuid import UUID + +from fastapi import APIRouter, Depends, Query, status + +from src.core.auth import get_current_user +from src.exceptions import NotFoundException +from src.models.auth import User +from src.network.models.schemas import ( + NetworkContextSnapshot, + NetworkCreate, + NetworkMessageResponse, + 
NetworkResponse, + NetworkUpdate, + ParticipantJoin, + ParticipantResponse, + ParticipantUpdate, +) +from src.network.services.networks import NetworkService + +router = APIRouter(prefix="/networks", tags=["Networks"]) + + +# ── Networks ───────────────────────────────────────────────────────── + + +@router.post("", response_model=NetworkResponse, status_code=status.HTTP_201_CREATED) +async def create_network( + data: NetworkCreate, + current_user: User = Depends(get_current_user), + service: NetworkService = Depends(), +) -> NetworkResponse: + network = await service.create_network(current_user.id, data) + return network + + +@router.get("", response_model=List[NetworkResponse]) +async def list_networks( + current_user: User = Depends(get_current_user), + limit: int = Query(default=50, ge=1, le=200), + offset: int = Query(default=0, ge=0), + service: NetworkService = Depends(), +) -> List[NetworkResponse]: + return await service.list_networks(current_user.id, limit, offset) + + +@router.get("/{network_id}", response_model=NetworkResponse) +async def get_network( + network_id: UUID, + current_user: User = Depends(get_current_user), + service: NetworkService = Depends(), +) -> NetworkResponse: + return await service.get_network(network_id, current_user.id) + + +@router.patch("/{network_id}", response_model=NetworkResponse) +async def update_network( + network_id: UUID, + data: NetworkUpdate, + current_user: User = Depends(get_current_user), + service: NetworkService = Depends(), +) -> NetworkResponse: + return await service.update_network(network_id, current_user.id, data) + + +@router.delete("/{network_id}", status_code=status.HTTP_204_NO_CONTENT) +async def delete_network( + network_id: UUID, + current_user: User = Depends(get_current_user), + service: NetworkService = Depends(), +) -> None: + success = await service.delete_network(network_id, current_user.id) + if not success: + raise NotFoundException("Network") + + +# ── Participants 
───────────────────────────────────────────────────── + + +@router.post( + "/{network_id}/participants", + response_model=ParticipantResponse, + status_code=status.HTTP_201_CREATED, +) +async def join_network( + network_id: UUID, + data: ParticipantJoin, + current_user: User = Depends(get_current_user), + service: NetworkService = Depends(), +) -> ParticipantResponse: + return await service.join_network(network_id, current_user.id, data) + + +@router.get( + "/{network_id}/participants", + response_model=List[ParticipantResponse], +) +async def list_participants( + network_id: UUID, + current_user: User = Depends(get_current_user), + service: NetworkService = Depends(), +) -> List[ParticipantResponse]: + return await service.list_participants(network_id, current_user.id) + + +@router.patch( + "/{network_id}/participants/{participant_id}", + response_model=ParticipantResponse, +) +async def update_participant( + network_id: UUID, + participant_id: UUID, + data: ParticipantUpdate, + current_user: User = Depends(get_current_user), + service: NetworkService = Depends(), +) -> ParticipantResponse: + return await service.update_participant( + network_id, participant_id, current_user.id, data + ) + + +@router.delete( + "/{network_id}/participants/{participant_id}", + status_code=status.HTTP_204_NO_CONTENT, +) +async def leave_network( + network_id: UUID, + participant_id: UUID, + current_user: User = Depends(get_current_user), + service: NetworkService = Depends(), +) -> None: + success = await service.leave_network(network_id, participant_id, current_user.id) + if not success: + raise NotFoundException("Participant") + + +# ── Context ────────────────────────────────────────────────────────── + + +@router.get("/{network_id}/context") +async def get_network_context( + network_id: UUID, + current_user: User = Depends(get_current_user), + limit: int = Query(default=50, ge=1, le=200), + service: NetworkService = Depends(), +) -> dict: + entries = await 
service.get_context(network_id, current_user.id, limit) + return { + "network_id": str(network_id), + "entries": entries, + } + + +# ── Messages ───────────────────────────────────────────────────────── + + +@router.get( + "/{network_id}/messages", + response_model=List[NetworkMessageResponse], +) +async def list_messages( + network_id: UUID, + current_user: User = Depends(get_current_user), + limit: int = Query(default=100, ge=1, le=500), + offset: int = Query(default=0, ge=0), + channel_type: str | None = Query(default=None), + participant_id: UUID | None = Query(default=None), + service: NetworkService = Depends(), +) -> List[NetworkMessageResponse]: + return await service.list_messages( + network_id, current_user.id, limit, offset, channel_type, participant_id + ) diff --git a/src/network/services/__init__.py b/src/network/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/network/services/channels.py b/src/network/services/channels.py new file mode 100644 index 0000000..01f6591 --- /dev/null +++ b/src/network/services/channels.py @@ -0,0 +1,446 @@ +"""Channel service — calls, messages, and mailboxes. + +Implements the three communication primitives with different timing +semantics. Each interaction is recorded in the network context and +delivered to the recipient via the appropriate mechanism. 
+""" + +import json +import logging +import time +from typing import Any, Optional +from uuid import UUID + +import httpx +from fastapi import Depends + +from src.core.settings import settings +from src.exceptions import BadRequestException, NotFoundException +from src.network.models.entities import ( + ChannelType, + CommunicationNetwork, + MessageStatus, + NetworkMessage, + NetworkParticipant, + NetworkStatus, + ParticipantStatus, +) +from src.network.models.schemas import NetworkMessageCreate +from src.network.repositories.networks import NetworkRepository +from src.network.utils.context_manager import NetworkContextManager + +logger = logging.getLogger(__name__) + + +class ChannelService: + """Unified service for calls, messages, and mailboxes.""" + + def __init__( + self, + repo: NetworkRepository = Depends(), + context_manager: NetworkContextManager = Depends(), + ): + self.repo = repo + self.ctx = context_manager + self._http_client: Optional[httpx.AsyncClient] = None + + def set_http_client(self, client: httpx.AsyncClient) -> None: + self._http_client = client + + # ── Calls (synchronous, blocking) ──────────────────────────────── + + async def call( + self, + network_id: UUID, + sender_participant_id: UUID, + recipient_participant_id: UUID, + content: str, + metadata: Optional[dict[str, Any]] = None, + ) -> dict[str, Any]: + """Synchronous call: send payload to recipient, wait for response. + + Returns a dict with the call result including the recipient's response. 
+ """ + sender, recipient, network = await self._validate_communication( + network_id, sender_participant_id, recipient_participant_id + ) + + if not recipient.callback_url: + raise BadRequestException( + "Recipient has no callback_url; cannot make a synchronous call" + ) + + # Record outgoing message + outgoing = await self._record_message( + network_id=network_id, + sender=sender, + recipient=recipient, + channel_type=ChannelType.call, + content=content, + metadata=metadata, + ) + + # Build context window for the call + context_entries = await self.ctx.get_context_window(network_id, limit=30) + participants = await self.repo.list_participants(network_id) + + # Build the payload with reply_url + payload = self._build_delivery_payload( + network_id=network_id, + sender=sender, + recipient=recipient, + channel="call", + content=content, + context=context_entries, + participants=participants, + message_id=outgoing.id, + ) + + # Synchronous HTTP call + response_data = await self._deliver_http( + recipient.callback_url, + payload, + timeout=settings.NETWORK_CALLBACK_TIMEOUT_SECONDS, + ) + + # Record the response as a message from recipient back to sender + response_content = ( + json.dumps(response_data) if isinstance(response_data, dict) else str(response_data) + ) + await self._record_message( + network_id=network_id, + sender=recipient, + recipient=sender, + channel_type=ChannelType.call, + content=response_content, + metadata={"in_reply_to": str(outgoing.id)}, + ) + + # Mark outgoing as delivered + outgoing.status = MessageStatus.delivered + await self.repo.update_message(outgoing) + + return { + "success": True, + "message_id": str(outgoing.id), + "response": response_data, + } + + # ── Messages (near-real-time, non-blocking) ────────────────────── + + async def send_message( + self, + network_id: UUID, + sender_participant_id: UUID, + recipient_participant_id: UUID, + content: str, + metadata: Optional[dict[str, Any]] = None, + ) -> NetworkMessage: + 
"""Non-blocking message: record and push via webhook. + + The sender does not block on the recipient's processing. Delivery + is best-effort with retries handled by the delivery worker. + """ + sender, recipient, network = await self._validate_communication( + network_id, sender_participant_id, recipient_participant_id + ) + + message = await self._record_message( + network_id=network_id, + sender=sender, + recipient=recipient, + channel_type=ChannelType.message, + content=content, + metadata=metadata, + ) + + # Attempt immediate webhook delivery (fire-and-forget style) + if recipient.callback_url: + context_entries = await self.ctx.get_context_window(network_id, limit=20) + participants = await self.repo.list_participants(network_id) + payload = self._build_delivery_payload( + network_id=network_id, + sender=sender, + recipient=recipient, + channel="message", + content=content, + context=context_entries, + participants=participants, + message_id=message.id, + ) + try: + await self._deliver_http( + recipient.callback_url, + payload, + timeout=settings.NETWORK_CALLBACK_TIMEOUT_SECONDS, + ) + message.status = MessageStatus.delivered + except Exception: + logger.warning( + "Message delivery failed for participant %s; will retry", + recipient_participant_id, + ) + message.status = MessageStatus.pending + await self.repo.update_message(message) + + return message + + # ── Mailbox (fully asynchronous) ───────────────────────────────── + + async def send_to_mailbox( + self, + network_id: UUID, + sender_participant_id: UUID, + recipient_participant_id: UUID, + content: str, + metadata: Optional[dict[str, Any]] = None, + ) -> NetworkMessage: + """Async mailbox: store message, no push delivery.""" + sender, recipient, network = await self._validate_communication( + network_id, sender_participant_id, recipient_participant_id + ) + + return await self._record_message( + network_id=network_id, + sender=sender, + recipient=recipient, + channel_type=ChannelType.mailbox, + 
content=content, + metadata=metadata, + ) + + # ── Inbox (polling) ────────────────────────────────────────────── + + async def get_inbox( + self, + network_id: UUID, + participant_id: UUID, + channel_types: Optional[list[str]] = None, + limit: int = 50, + ) -> list[NetworkMessage]: + """Get unread messages for a participant.""" + messages = await self.repo.list_messages( + network_id=network_id, + limit=limit, + participant_id=participant_id, + ) + if channel_types: + messages = [m for m in messages if m.channel_type in channel_types] + return messages + + async def acknowledge( + self, network_id: UUID, message_ids: list[UUID] + ) -> int: + """Mark messages as read.""" + count = 0 + for msg_id in message_ids: + message = await self.repo.get_message(msg_id) + if message and message.network_id == network_id: + message.status = MessageStatus.read + await self.repo.update_message(message) + count += 1 + return count + + # ── Callback (external agents pushing back) ────────────────────── + + async def handle_callback( + self, + network_id: UUID, + participant_id: UUID, + content: str, + recipient_participant_id: Optional[UUID] = None, + channel_type: str = "message", + metadata: Optional[dict[str, Any]] = None, + in_reply_to_id: Optional[UUID] = None, + ) -> NetworkMessage: + """Handle a proactive message from an external agent via callback URL. + + This is the key to bidirectionality: external agents POST to their + reply_url and this method records the message in the network. 
+ """ + sender = await self.repo.get_participant(participant_id) + if not sender or sender.network_id != network_id: + raise NotFoundException("Participant") + if sender.status != ParticipantStatus.active: + raise BadRequestException("Participant is not active") + + network = await self.repo.get_network(network_id) + if not network or network.status != NetworkStatus.active: + raise BadRequestException("Network is not active") + + recipient = None + if recipient_participant_id: + recipient = await self.repo.get_participant(recipient_participant_id) + if not recipient or recipient.network_id != network_id: + raise NotFoundException("Recipient participant") + + message = await self._record_message( + network_id=network_id, + sender=sender, + recipient=recipient, + channel_type=ChannelType(channel_type), + content=content, + metadata=metadata, + in_reply_to_id=in_reply_to_id, + ) + + # If there's a specific recipient with a callback_url, forward the message + if recipient and recipient.callback_url and channel_type == "message": + context_entries = await self.ctx.get_context_window(network_id, limit=20) + participants = await self.repo.list_participants(network_id) + payload = self._build_delivery_payload( + network_id=network_id, + sender=sender, + recipient=recipient, + channel=channel_type, + content=content, + context=context_entries, + participants=participants, + message_id=message.id, + ) + try: + await self._deliver_http( + recipient.callback_url, + payload, + timeout=settings.NETWORK_CALLBACK_TIMEOUT_SECONDS, + ) + message.status = MessageStatus.delivered + await self.repo.update_message(message) + except Exception: + logger.warning( + "Forwarding callback message failed for participant %s", + recipient_participant_id, + ) + + return message + + # ── Internal helpers ───────────────────────────────────────────── + + async def _validate_communication( + self, + network_id: UUID, + sender_id: UUID, + recipient_id: UUID, + ) -> tuple[NetworkParticipant, 
NetworkParticipant, CommunicationNetwork]: + network = await self.repo.get_network(network_id) + if not network: + raise NotFoundException("Network") + if network.status != NetworkStatus.active: + raise BadRequestException("Network is not active") + + sender = await self.repo.get_participant(sender_id) + if not sender or sender.network_id != network_id: + raise NotFoundException("Sender participant") + if sender.status != ParticipantStatus.active: + raise BadRequestException("Sender is not active") + + recipient = await self.repo.get_participant(recipient_id) + if not recipient or recipient.network_id != network_id: + raise NotFoundException("Recipient participant") + if recipient.status != ParticipantStatus.active: + raise BadRequestException("Recipient is not active") + + return sender, recipient, network + + async def _record_message( + self, + *, + network_id: UUID, + sender: NetworkParticipant, + recipient: Optional[NetworkParticipant], + channel_type: ChannelType, + content: str, + metadata: Optional[dict[str, Any]] = None, + in_reply_to_id: Optional[UUID] = None, + ) -> NetworkMessage: + message = NetworkMessage( + network_id=network_id, + sender_participant_id=sender.id, + recipient_participant_id=recipient.id if recipient else None, + channel_type=channel_type, + content=content, + metadata_=metadata, + in_reply_to_id=in_reply_to_id, + ) + message = await self.repo.create_message(message) + + await self.ctx.append( + network_id, + sender=sender.name, + recipient=recipient.name if recipient else None, + channel=channel_type.value, + content=content, + message_id=message.id, + ) + return message + + def _build_delivery_payload( + self, + *, + network_id: UUID, + sender: NetworkParticipant, + recipient: NetworkParticipant, + channel: str, + content: str, + context: list[dict], + participants: list[NetworkParticipant], + message_id: UUID, + ) -> dict[str, Any]: + """Build the standard payload delivered to external agents.""" + return { + "network_id": 
str(network_id), + "message_id": str(message_id), + "channel": channel, + "sender": { + "participant_id": str(sender.id), + "name": sender.name, + }, + "content": content, + "context": context, + "reply_url": ( + f"{settings.BASE_URL}/networks/{network_id}" + f"/participants/{recipient.id}/callback" + ), + "network_participants": [ + {"participant_id": str(p.id), "name": p.name} + for p in participants + if p.status == ParticipantStatus.active + ], + } + + async def _deliver_http( + self, + url: str, + payload: dict[str, Any], + timeout: int = 30, + ) -> dict[str, Any]: + """POST payload to an external agent's callback URL.""" + client = self._http_client + owns_client = client is None + if owns_client: + client = httpx.AsyncClient(timeout=timeout) + + try: + response = await client.post( + url, + json=payload, + headers={ + "Content-Type": "application/json", + "User-Agent": "Intuno-Network/1.0", + }, + timeout=timeout, + ) + if response.status_code == 200: + try: + return response.json() + except Exception: + return {"raw_response": response.text} + else: + raise httpx.HTTPStatusError( + f"Callback returned {response.status_code}", + request=response.request, + response=response, + ) + finally: + if owns_client: + await client.aclose() diff --git a/src/network/services/networks.py b/src/network/services/networks.py new file mode 100644 index 0000000..2cf4321 --- /dev/null +++ b/src/network/services/networks.py @@ -0,0 +1,219 @@ +"""Communication network service — business logic for networks and participants.""" + +from typing import Optional +from uuid import UUID + +from fastapi import Depends + +from src.exceptions import BadRequestException, NotFoundException +from src.network.models.entities import ( + ChannelType, + CommunicationNetwork, + NetworkMessage, + NetworkParticipant, + NetworkStatus, + ParticipantStatus, + ParticipantType, + TopologyType, +) +from src.network.models.schemas import ( + NetworkCreate, + NetworkMessageCreate, + NetworkUpdate, + 
class NetworkService:
    """Business logic for communication networks, participants, and context.

    Thin orchestration layer over ``NetworkRepository`` (authoritative
    Postgres storage) and ``NetworkContextManager`` (Redis context cache).
    All owner-scoped reads go through :meth:`get_network`, which enforces
    ownership by raising ``NotFoundException`` for foreign networks.
    """

    def __init__(
        self,
        repo: NetworkRepository = Depends(),
        context_manager: NetworkContextManager = Depends(),
    ):
        self.repo = repo
        self.ctx = context_manager

    # ---- Networks --------------------------------------------------------

    async def create_network(
        self, owner_id: UUID, data: NetworkCreate
    ) -> CommunicationNetwork:
        """Create a new network owned by *owner_id*; networks start active."""
        record = CommunicationNetwork(
            owner_id=owner_id,
            name=data.name,
            topology_type=TopologyType(data.topology_type),
            metadata_=data.metadata,
            status=NetworkStatus.active,
        )
        return await self.repo.create_network(record)

    async def get_network(self, network_id: UUID, owner_id: UUID) -> CommunicationNetwork:
        """Fetch a network, raising NotFound unless it exists and belongs to *owner_id*."""
        found = await self.repo.get_network(network_id)
        if found is None or found.owner_id != owner_id:
            # Deliberately indistinguishable from "does not exist" so owners
            # cannot probe for other users' network ids.
            raise NotFoundException("Network")
        return found

    async def list_networks(
        self, owner_id: UUID, limit: int = 50, offset: int = 0
    ) -> list[CommunicationNetwork]:
        """Page through the caller's networks."""
        return await self.repo.list_networks(owner_id, limit, offset)

    async def update_network(
        self, network_id: UUID, owner_id: UUID, data: NetworkUpdate
    ) -> CommunicationNetwork:
        """Apply the non-None fields of *data* to an owned network."""
        net = await self.get_network(network_id, owner_id)
        if data.name is not None:
            net.name = data.name
        if data.topology_type is not None:
            net.topology_type = TopologyType(data.topology_type)
        if data.status is not None:
            net.status = NetworkStatus(data.status)
        if data.metadata is not None:
            net.metadata_ = data.metadata
        return await self.repo.update_network(net)

    async def delete_network(self, network_id: UUID, owner_id: UUID) -> bool:
        """Delete an owned network and its Redis context.

        Returns ``False`` (instead of raising) when the network is missing
        or owned by someone else — delete is idempotent for callers.
        """
        net = await self.repo.get_network(network_id)
        if net is None or net.owner_id != owner_id:
            return False
        await self.ctx.clear(network_id)
        return await self.repo.delete_network(network_id)

    # ---- Participants ----------------------------------------------------

    async def join_network(
        self, network_id: UUID, owner_id: UUID, data: ParticipantJoin
    ) -> NetworkParticipant:
        """Add a participant to an active, owned network.

        A participant needs at least one delivery path (callback URL or
        polling), and a given agent may only join a network once.
        """
        net = await self.get_network(network_id, owner_id)
        if net.status != NetworkStatus.active:
            raise BadRequestException("Network is not active")
        if not data.callback_url and not data.polling_enabled:
            raise BadRequestException(
                "Participant must have a callback_url or polling_enabled"
            )
        if data.agent_id:
            already = await self.repo.get_participant_by_agent(network_id, data.agent_id)
            if already:
                raise BadRequestException("Agent already in network")
        member = NetworkParticipant(
            network_id=network_id,
            agent_id=data.agent_id,
            participant_type=ParticipantType(data.participant_type),
            name=data.name,
            callback_url=data.callback_url,
            polling_enabled=data.polling_enabled,
            capabilities=data.capabilities,
            status=ParticipantStatus.active,
        )
        return await self.repo.add_participant(member)

    async def list_participants(
        self, network_id: UUID, owner_id: UUID
    ) -> list[NetworkParticipant]:
        """List all participants of an owned network."""
        await self.get_network(network_id, owner_id)  # ownership check only
        return await self.repo.list_participants(network_id)

    async def update_participant(
        self,
        network_id: UUID,
        participant_id: UUID,
        owner_id: UUID,
        data: ParticipantUpdate,
    ) -> NetworkParticipant:
        """Apply the non-None fields of *data* to a participant of an owned network."""
        await self.get_network(network_id, owner_id)  # ownership check only
        member = await self.repo.get_participant(participant_id)
        if member is None or member.network_id != network_id:
            raise NotFoundException("Participant")
        if data.callback_url is not None:
            member.callback_url = data.callback_url
        if data.polling_enabled is not None:
            member.polling_enabled = data.polling_enabled
        if data.capabilities is not None:
            member.capabilities = data.capabilities
        if data.status is not None:
            member.status = ParticipantStatus(data.status)
        return await self.repo.update_participant(member)

    async def leave_network(
        self, network_id: UUID, participant_id: UUID, owner_id: UUID
    ) -> bool:
        """Remove a participant; returns False when it isn't in this network."""
        await self.get_network(network_id, owner_id)  # ownership check only
        member = await self.repo.get_participant(participant_id)
        if member is None or member.network_id != network_id:
            return False
        await self.repo.remove_participant(member)
        return True

    # ---- Context ---------------------------------------------------------

    async def get_context(
        self, network_id: UUID, owner_id: UUID, limit: int = 50
    ) -> list[dict]:
        """Fast recent-context window from the Redis cache."""
        await self.get_network(network_id, owner_id)  # ownership check only
        return await self.ctx.get_context_window(network_id, limit)

    async def get_context_from_db(
        self, network_id: UUID, owner_id: UUID, limit: int = 50
    ) -> list[NetworkMessage]:
        """Authoritative context from Postgres (slower but complete)."""
        await self.get_network(network_id, owner_id)  # ownership check only
        return await self.repo.get_context(network_id, limit)

    # ---- Messages (internal recording) -----------------------------------

    async def record_message(
        self,
        network_id: UUID,
        sender_participant_id: UUID,
        data: NetworkMessageCreate,
    ) -> NetworkMessage:
        """Persist a message and mirror it into the Redis context cache."""
        sender = await self.repo.get_participant(sender_participant_id)
        if sender is None or sender.network_id != network_id:
            raise BadRequestException("Sender is not a participant in this network")

        stored = await self.repo.create_message(
            NetworkMessage(
                network_id=network_id,
                sender_participant_id=sender_participant_id,
                recipient_participant_id=data.recipient_participant_id,
                channel_type=ChannelType(data.channel_type),
                content=data.content,
                metadata_=data.metadata,
                in_reply_to_id=data.in_reply_to_id,
            )
        )

        # Resolve the recipient's display name (best effort) for the cache.
        recipient_name = None
        if data.recipient_participant_id:
            rec = await self.repo.get_participant(data.recipient_participant_id)
            recipient_name = rec.name if rec else None

        await self.ctx.append(
            network_id,
            sender=sender.name,
            recipient=recipient_name,
            channel=data.channel_type,
            content=data.content,
            message_id=stored.id,
        )
        return stored

    async def list_messages(
        self,
        network_id: UUID,
        owner_id: UUID,
        limit: int = 100,
        offset: int = 0,
        channel_type: Optional[str] = None,
        participant_id: Optional[UUID] = None,
    ) -> list[NetworkMessage]:
        """Page through a network's messages with optional channel/participant filters."""
        await self.get_network(network_id, owner_id)  # ownership check only
        return await self.repo.list_messages(
            network_id, limit, offset, channel_type, participant_id
        )
"""Fan-in aggregation strategies for combining outputs from multiple agents.

Used by the ``aggregate`` step type in the workflow orchestrator. Each input
is a dict of the shape ``{"source": <step id>, "output": <any>}``.
"""

import json
import logging
from abc import ABC, abstractmethod
from typing import Any

logger = logging.getLogger(__name__)


class Aggregator(ABC):
    """Base class for aggregation strategies."""

    @abstractmethod
    async def aggregate(self, inputs: list[dict[str, Any]]) -> dict[str, Any]:
        """Combine multiple agent outputs into a single result dict.

        Implementations return ``{"strategy": <name>, "result": ...}`` plus
        strategy-specific extras.
        """
        ...


class MergeAggregator(Aggregator):
    """Concatenate all outputs into a single dict keyed by source step ID."""

    async def aggregate(self, inputs: list[dict[str, Any]]) -> dict[str, Any]:
        merged: dict[str, Any] = {}
        for i, item in enumerate(inputs):
            # Fix: the previous default key f"input_{len(merged)}" could
            # collide with an explicit source named "input_N" and silently
            # overwrite an earlier output. Key by position and disambiguate
            # any duplicate instead of dropping data.
            key = item.get("source") or f"input_{i}"
            if key in merged:
                key = f"{key}#{i}"
            merged[key] = item.get("output")
        return {"strategy": "merge", "result": merged}


class VoteAggregator(Aggregator):
    """Pick the majority answer for classification tasks.

    Each input should have an ``output`` field with a string value.
    The most common value wins (ties resolved by first occurrence, via
    dict insertion order).
    """

    async def aggregate(self, inputs: list[dict[str, Any]]) -> dict[str, Any]:
        votes: dict[str, int] = {}
        for item in inputs:
            output = item.get("output")
            # Stringify so unhashable/None outputs can still be counted.
            key = str(output) if output is not None else "null"
            votes[key] = votes.get(key, 0) + 1

        if not votes:
            return {"strategy": "vote", "result": None, "votes": {}}

        winner = max(votes, key=votes.get)
        return {
            "strategy": "vote",
            "result": winner,
            "votes": votes,
            "total": len(inputs),
        }


class LLMSummarizeAggregator(Aggregator):
    """Use an LLM to synthesize all inputs into a coherent output.

    Falls back to :class:`MergeAggregator` when the LLM (or its client
    library / settings) is unavailable or the call fails.
    """

    async def aggregate(self, inputs: list[dict[str, Any]]) -> dict[str, Any]:
        try:
            # Imported lazily so environments without the OpenAI client can
            # still use the fallback path.
            from openai import AsyncOpenAI
            from src.core.settings import settings

            client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)

            formatted_inputs = "\n\n".join(
                f"[{item.get('source', f'Agent {i+1}')}]:\n{json.dumps(item.get('output'), indent=2)}"
                for i, item in enumerate(inputs)
            )

            response = await client.chat.completions.create(
                model=settings.LLM_ENHANCEMENT_MODEL,
                messages=[
                    {
                        "role": "system",
                        "content": (
                            "You are a synthesizer. Multiple AI agents have provided their outputs. "
                            "Combine them into a single coherent, comprehensive response. "
                            "Resolve contradictions, merge complementary information, "
                            "and produce a unified result."
                        ),
                    },
                    {
                        "role": "user",
                        "content": f"Synthesize these agent outputs:\n\n{formatted_inputs}",
                    },
                ],
                temperature=0.3,
            )

            synthesis = response.choices[0].message.content
            return {
                "strategy": "llm_summarize",
                "result": synthesis,
                "source_count": len(inputs),
            }
        except Exception as exc:
            # Deliberate broad catch: any failure (network, auth, import)
            # degrades gracefully to a plain merge rather than failing the step.
            logger.warning("LLM summarize failed, falling back to merge: %s", exc)
            fallback = MergeAggregator()
            result = await fallback.aggregate(inputs)
            result["strategy"] = "llm_summarize_fallback"
            return result


def create_aggregator(strategy: str) -> Aggregator:
    """Factory for aggregation strategies.

    Raises:
        ValueError: For an unknown *strategy* name.
    """
    if strategy == "merge":
        return MergeAggregator()
    elif strategy == "vote":
        return VoteAggregator()
    elif strategy == "llm_summarize":
        return LLMSummarizeAggregator()
    else:
        raise ValueError(f"Unknown aggregation strategy: {strategy}")
"""Network-scoped context manager.

Maintains a fast Redis cache of recent messages per network alongside
the authoritative Postgres storage. When delivering a message to an
external agent, we build a context window from Redis for low-latency.
"""

import json
import time
import uuid
from typing import Any

import redis.asyncio as aioredis
from fastapi import Depends

from src.core.settings import settings
from src.database import get_redis


class NetworkContextManager:
    """Redis-backed, per-network accumulator of recent messages.

    Entries live in one Redis Stream per network, capped in length and
    expiring after a configurable TTL; Postgres remains the source of truth.
    """

    def __init__(self, redis: aioredis.Redis = Depends(get_redis)) -> None:
        self._redis = redis

    def _stream_key(self, network_id: uuid.UUID) -> str:
        # One Redis Stream per network.
        return f"net:{network_id}:ctx"

    async def append(
        self,
        network_id: uuid.UUID,
        *,
        sender: str,
        recipient: str | None,
        channel: str,
        content: str,
        message_id: uuid.UUID | None = None,
    ) -> None:
        """Append one message to the network's context stream.

        Redis stream fields must be strings, so optional values are stored
        as empty strings and the timestamp as stringified epoch seconds.
        """
        stream = self._stream_key(network_id)
        fields = {
            "sender": sender,
            "recipient": recipient or "",
            "channel": channel,
            "content": content,
            "message_id": "" if message_id is None else str(message_id),
            "ts": str(time.time()),
        }
        await self._redis.xadd(stream, fields, maxlen=settings.NETWORK_CONTEXT_MAX_ENTRIES)
        # Refresh the TTL on every write so active networks never expire.
        await self._redis.expire(stream, settings.NETWORK_CONTEXT_TTL_SECONDS)

    async def get_context_window(
        self,
        network_id: uuid.UUID,
        limit: int = 50,
    ) -> list[dict[str, Any]]:
        """Return the *limit* most recent entries in chronological order.

        NOTE(review): assumes the Redis client decodes responses to ``str``
        (``decode_responses=True``) — confirm in ``get_redis``.
        """
        # xrevrange yields newest-first; reverse to restore chronology.
        newest_first = await self._redis.xrevrange(
            self._stream_key(network_id), count=limit
        )
        return [
            {
                "sender": fields["sender"],
                "recipient": fields["recipient"] or None,
                "channel": fields["channel"],
                "content": fields["content"],
                "timestamp": float(fields["ts"]),
            }
            for _entry_id, fields in reversed(newest_first)
        ]

    async def clear(self, network_id: uuid.UUID) -> None:
        """Delete the context stream for a network."""
        await self._redis.delete(self._stream_key(network_id))
/dev/null +++ b/src/network/utils/convergence.py @@ -0,0 +1,146 @@ +"""Convergence detectors for feedback loops. + +Determine when iterative agent interactions have converged and should +stop. Used by the loop step type in the workflow orchestrator. +""" + +from __future__ import annotations + +import logging +from abc import ABC, abstractmethod +from typing import Any + +logger = logging.getLogger(__name__) + + +class ConvergenceDetector(ABC): + """Base class for convergence detection strategies.""" + + @abstractmethod + async def has_converged( + self, + iteration: int, + current_output: Any, + previous_output: Any, + context: dict[str, Any], + ) -> bool: + """Return True if the loop should stop.""" + ... + + +class MaxIterationsDetector(ConvergenceDetector): + """Hard cap on iterations — always enforced as a safety net.""" + + def __init__(self, max_iterations: int = 5): + self.max_iterations = max_iterations + + async def has_converged( + self, + iteration: int, + current_output: Any, + previous_output: Any, + context: dict[str, Any], + ) -> bool: + return iteration >= self.max_iterations + + +class ApprovalDetector(ConvergenceDetector): + """Check if the output contains an explicit approval signal. + + Looks for keywords like "approved", "accepted", "lgtm" in the output + or for a structured ``{"approved": true}`` field. 
+ """ + + APPROVAL_KEYWORDS = {"approved", "accepted", "lgtm", "looks good", "ship it"} + + async def has_converged( + self, + iteration: int, + current_output: Any, + previous_output: Any, + context: dict[str, Any], + ) -> bool: + if isinstance(current_output, dict): + if current_output.get("approved") is True: + return True + text = str(current_output.get("output", "")) + str( + current_output.get("content", "") + ) + elif isinstance(current_output, str): + text = current_output + else: + return False + + text_lower = text.lower() + return any(kw in text_lower for kw in self.APPROVAL_KEYWORDS) + + +class SimilarityDetector(ConvergenceDetector): + """Compare consecutive outputs using text similarity. + + Uses a simple token overlap ratio (Jaccard similarity). For + production use, this could be upgraded to use embedding cosine + similarity via the EmbeddingService. + """ + + def __init__(self, threshold: float = 0.95): + self.threshold = threshold + + async def has_converged( + self, + iteration: int, + current_output: Any, + previous_output: Any, + context: dict[str, Any], + ) -> bool: + if previous_output is None: + return False + + current_text = self._to_text(current_output) + previous_text = self._to_text(previous_output) + + if not current_text or not previous_text: + return False + + similarity = self._jaccard_similarity(current_text, previous_text) + logger.debug( + "Similarity check: iteration=%d similarity=%.3f threshold=%.3f", + iteration, + similarity, + self.threshold, + ) + return similarity >= self.threshold + + def _to_text(self, output: Any) -> str: + if isinstance(output, str): + return output + if isinstance(output, dict): + return str(output.get("output", "")) or str(output.get("content", "")) + return str(output) + + def _jaccard_similarity(self, a: str, b: str) -> float: + tokens_a = set(a.lower().split()) + tokens_b = set(b.lower().split()) + if not tokens_a and not tokens_b: + return 1.0 + intersection = tokens_a & tokens_b + union = 
tokens_a | tokens_b + return len(intersection) / len(union) if union else 1.0 + + +def create_detector( + detector_type: str, + config: dict[str, Any] | None = None, +) -> ConvergenceDetector: + """Factory function for convergence detectors.""" + config = config or {} + if detector_type == "similarity": + return SimilarityDetector(threshold=config.get("threshold", 0.95)) + elif detector_type == "approval": + return ApprovalDetector() + elif detector_type == "max_iterations": + return MaxIterationsDetector( + max_iterations=config.get("max_iterations", 5) + ) + else: + raise ValueError(f"Unknown convergence detector type: {detector_type}") diff --git a/src/network/utils/delivery_worker.py b/src/network/utils/delivery_worker.py new file mode 100644 index 0000000..1e1c880 --- /dev/null +++ b/src/network/utils/delivery_worker.py @@ -0,0 +1,157 @@ +"""Background delivery worker for async network messages. + +Consumes from a Redis Stream to deliver pending messages to participants +via their callback URLs. Follows the same consumer-group pattern as +``src.workflow.utils.event_consumer.EventConsumer``. 
+""" + +from __future__ import annotations + +import asyncio +import json +import logging +from typing import Any + +import httpx +import redis.asyncio as aioredis + +from src.core.settings import settings + +logger = logging.getLogger(__name__) + +STREAM_KEY = "intuno:network:delivery" +CONSUMER_GROUP = "network_delivery_workers" +CONSUMER_NAME = "delivery-worker-1" + + +class DeliveryWorker: + """Reads pending deliveries from Redis Stream and POSTs to callback URLs.""" + + def __init__(self, redis: aioredis.Redis) -> None: + self._redis = redis + self._task: asyncio.Task[None] | None = None + self._http_client: httpx.AsyncClient | None = None + + async def start(self, http_client: httpx.AsyncClient | None = None) -> None: + self._http_client = http_client + try: + await self._redis.xgroup_create( + STREAM_KEY, CONSUMER_GROUP, id="0", mkstream=True, + ) + except aioredis.ResponseError as exc: + if "BUSYGROUP" not in str(exc): + raise + self._task = asyncio.create_task(self._consume(), name="delivery-worker") + logger.info("Delivery worker started on stream '%s'", STREAM_KEY) + + async def stop(self) -> None: + if self._task and not self._task.done(): + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + logger.info("Delivery worker stopped") + + @staticmethod + async def enqueue( + redis: aioredis.Redis, + *, + callback_url: str, + payload: dict[str, Any], + message_id: str, + attempt: int = 0, + ) -> None: + """Enqueue a delivery task into the Redis Stream.""" + await redis.xadd( + STREAM_KEY, + { + "callback_url": callback_url, + "payload": json.dumps(payload, default=str), + "message_id": message_id, + "attempt": str(attempt), + }, + ) + + async def _consume(self) -> None: + while True: + try: + entries = await self._redis.xreadgroup( + CONSUMER_GROUP, + CONSUMER_NAME, + {STREAM_KEY: ">"}, + count=10, + block=5000, + ) + if not entries: + continue + for _stream, messages in entries: + for msg_id, data in messages: + await 
self._deliver(msg_id, data) + except asyncio.CancelledError: + raise + except Exception: + logger.exception("Delivery worker error; retrying in 2s") + await asyncio.sleep(2) + + async def _deliver(self, stream_id: str, data: dict[str, str]) -> None: + callback_url = data.get("callback_url", "") + payload_raw = data.get("payload", "{}") + message_id = data.get("message_id", "") + attempt = int(data.get("attempt", "0")) + + try: + payload = json.loads(payload_raw) + except json.JSONDecodeError: + payload = {"raw": payload_raw} + + client = self._http_client or httpx.AsyncClient( + timeout=settings.NETWORK_CALLBACK_TIMEOUT_SECONDS + ) + owns_client = self._http_client is None + + try: + response = await client.post( + callback_url, + json=payload, + headers={ + "Content-Type": "application/json", + "User-Agent": "Intuno-Network/1.0", + }, + timeout=settings.NETWORK_CALLBACK_TIMEOUT_SECONDS, + ) + if response.status_code == 200: + logger.debug("Delivered message %s to %s", message_id, callback_url) + else: + logger.warning( + "Delivery to %s returned %d for message %s", + callback_url, + response.status_code, + message_id, + ) + await self._maybe_retry(data, attempt) + except Exception: + logger.warning( + "Delivery to %s failed for message %s (attempt %d)", + callback_url, + message_id, + attempt, + ) + await self._maybe_retry(data, attempt) + finally: + if owns_client: + await client.aclose() + + await self._redis.xack(STREAM_KEY, CONSUMER_GROUP, stream_id) + + async def _maybe_retry(self, data: dict[str, str], attempt: int) -> None: + if attempt < settings.NETWORK_MESSAGE_DELIVERY_MAX_RETRIES: + backoff = 2 ** (attempt + 1) + await asyncio.sleep(backoff) + await self.enqueue( + self._redis, + callback_url=data.get("callback_url", ""), + payload=json.loads(data.get("payload", "{}")), + message_id=data.get("message_id", ""), + attempt=attempt + 1, + ) diff --git a/src/network/utils/topology.py b/src/network/utils/topology.py new file mode 100644 index 0000000..da295bd 
"""Topology validation and routing rules for communication networks.

Enforces communication constraints based on the network's topology type:
- mesh: any participant can communicate with any other
- star: all communication must involve the hub (first participant)
- ring: messages flow sequentially through participants
- custom: no enforcement (topology managed externally)
"""

from uuid import UUID

from src.exceptions import BadRequestException
from src.network.models.entities import (
    CommunicationNetwork,
    NetworkParticipant,
    TopologyType,
)


class TopologyValidator:
    """Validates whether communication is allowed given the network topology."""

    def validate(
        self,
        network: CommunicationNetwork,
        sender: NetworkParticipant,
        recipient: NetworkParticipant,
        participants: list[NetworkParticipant],
    ) -> None:
        """Raise BadRequestException if communication is not allowed."""
        topology = network.topology_type
        if topology == TopologyType.mesh or topology == TopologyType.custom:
            return  # no restrictions

        if topology == TopologyType.star:
            self._validate_star(sender, recipient, participants)
        elif topology == TopologyType.ring:
            self._validate_ring(sender, recipient, participants)

    def _validate_star(
        self,
        sender: NetworkParticipant,
        recipient: NetworkParticipant,
        participants: list[NetworkParticipant],
    ) -> None:
        """In star topology, every message must involve the hub.

        Fix: previously only the hub could *send*, which contradicted
        ``get_reachable`` (spokes could reach the hub) and made spoke
        replies to the hub impossible. Hub→spoke and spoke→hub are both
        allowed; spoke→spoke is not.

        NOTE(review): the hub is taken to be the first participant in the
        repository's listing order — confirm that ordering is stable.
        """
        if not participants:
            return
        hub = participants[0]
        if sender.id != hub.id and recipient.id != hub.id:
            raise BadRequestException(
                f"Star topology: all communication must go through the hub '{hub.name}'"
            )

    def _validate_ring(
        self,
        sender: NetworkParticipant,
        recipient: NetworkParticipant,
        participants: list[NetworkParticipant],
    ) -> None:
        """In ring topology, messages flow to the next participant in order."""
        if len(participants) < 2:
            return
        ids = [p.id for p in participants]
        try:
            sender_idx = ids.index(sender.id)
        except ValueError:
            raise BadRequestException("Sender is not in the participant list") from None
        next_idx = (sender_idx + 1) % len(ids)
        if recipient.id != ids[next_idx]:
            expected_name = participants[next_idx].name
            raise BadRequestException(
                f"Ring topology: '{sender.name}' can only send to the next participant "
                f"'{expected_name}', not '{recipient.name}'"
            )

    def get_reachable(
        self,
        network: CommunicationNetwork,
        sender: NetworkParticipant,
        participants: list[NetworkParticipant],
    ) -> list[NetworkParticipant]:
        """Return participants that the sender can communicate with."""
        topology = network.topology_type
        others = [p for p in participants if p.id != sender.id]

        if topology == TopologyType.mesh or topology == TopologyType.custom:
            return others

        if topology == TopologyType.star:
            hub = participants[0] if participants else None
            if hub and sender.id == hub.id:
                return others
            return [hub] if hub else []

        if topology == TopologyType.ring:
            ids = [p.id for p in participants]
            try:
                sender_idx = ids.index(sender.id)
            except ValueError:
                return []
            successor = participants[(sender_idx + 1) % len(ids)]
            # Fix: in a single-participant ring the "next" node is the
            # sender itself — a participant cannot reach itself.
            return [successor] if successor.id != sender.id else []

        return others
class AggregateConfig(BaseModel):
    """Configuration for fan-in aggregation steps.

    Declares which step outputs feed the aggregation and how they are
    combined (see ``src.network.utils.aggregator`` for the strategies).
    """

    sources: list[str] = Field(
        ..., description="Step IDs whose outputs to aggregate."
    )
    strategy: str = Field(
        default="merge",
        description="Aggregation strategy: 'merge', 'vote', or 'llm_summarize'.",
    )
    timeout_seconds: int | None = Field(
        default=None,
        # Fix: reject zero/negative timeouts at validation time instead of
        # letting them reach the orchestrator. None still means "no limit".
        ge=1,
        description="Max wait time for all sources to complete.",
    )
not step.loop: + raise DSLParseError( + f"Loop step '{step.id}' must have a 'loop' configuration" + ) + if step.resolved_type == "loop" and step.loop: + if not step.loop.body: + raise DSLParseError( + f"Loop step '{step.id}' must have at least one step in its body" + ) + if step.resolved_type == "aggregate" and not step.aggregate: + raise DSLParseError( + f"Aggregate step '{step.id}' must have an 'aggregate' configuration" + ) + if step.resolved_type == "aggregate" and step.aggregate: + for source_id in step.aggregate.sources: + if source_id not in step_ids: + raise DSLParseError( + f"Aggregate step '{step.id}' references unknown source '{source_id}'" + ) def _detect_cycles(wf: WorkflowDef) -> None: diff --git a/src/workflow/utils/orchestrator.py b/src/workflow/utils/orchestrator.py index 6c6d426..298709b 100644 --- a/src/workflow/utils/orchestrator.py +++ b/src/workflow/utils/orchestrator.py @@ -209,6 +209,21 @@ async def _execute_step( ) return + if step.resolved_type == "loop": + await self._handle_loop( + step, entry_id, context_id, trigger_data, + engine, step_outputs, default_recovery, step_map, + skipped, execution_id, workflow_def, t0, + ) + return + + if step.resolved_type == "aggregate": + await self._handle_aggregate( + step, entry_id, context_id, + step_outputs, execution_id, t0, + ) + return + if step.resolved_type == "plan": await self._handle_plan( step, entry_id, context_id, trigger_data, @@ -468,3 +483,172 @@ async def _handle_plan( workflow_def, ) ) + + # -- Loop (feedback cycle) handling ---------------------------------------- + + async def _handle_loop( + self, + step: WorkflowStep, + entry_id: uuid.UUID, + context_id: uuid.UUID, + trigger_data: dict[str, Any], + engine: TemplateEngine, + step_outputs: dict[str, Any], + default_recovery: RecoveryConfig, + step_map: dict[str, WorkflowStep], + skipped: set[str], + execution_id: uuid.UUID, + workflow_def: WorkflowDef, + t0: float, + ) -> None: + """Execute a feedback loop until convergence or max 
iterations.""" + from src.network.utils.convergence import ( + MaxIterationsDetector, + create_detector, + ) + + if not step.loop: + raise StepExecutionError(f"Loop step '{step.id}' has no loop config") + + loop_cfg = step.loop + max_iter = loop_cfg.max_iterations + + # Create convergence detector + mandatory max iterations safety net + convergence = create_detector( + loop_cfg.convergence.type, + { + "threshold": loop_cfg.convergence.threshold, + "max_iterations": max_iter, + }, + ) + safety = MaxIterationsDetector(max_iterations=max_iter) + + previous_output: Any = None + iteration_outputs: list[Any] = [] + + for iteration in range(1, max_iter + 1): + logger.info( + "Loop '%s' iteration %d/%d", step.id, iteration, max_iter, + ) + + # Run all body steps sequentially for this iteration + iter_outputs: dict[str, Any] = {} + for body_step in loop_cfg.body: + # Create a process entry for this iteration's step + iter_step_id = f"{step.id}::iter{iteration}::{body_step.id}" + entry = await self._exec_repo.create_process_entry( + execution_id=execution_id, + step_id=iter_step_id, + step_type=body_step.resolved_type, + target_name=body_step.target_ref or body_step.id, + ) + + # Inject loop iteration context + loop_context = { + "iteration": iteration, + "previous_output": previous_output, + "iteration_outputs": iteration_outputs, + } + step_outputs[f"{step.id}::loop"] = loop_context + + await self._execute_step( + body_step, + entry.id, + context_id, + trigger_data, + step_outputs, + default_recovery, + {bs.id: bs for bs in loop_cfg.body}, + set(), + execution_id, + workflow_def, + ) + + iter_outputs[body_step.id] = step_outputs.get(body_step.id) + + # The last body step's output is considered the iteration output + last_body_step = loop_cfg.body[-1] + current_output = step_outputs.get(last_body_step.id, {}).get("output") + iteration_outputs.append(current_output) + + # Check convergence + converged = await convergence.has_converged( + iteration, current_output, 
previous_output, + {"step_outputs": step_outputs}, + ) + hit_max = await safety.has_converged( + iteration, current_output, previous_output, {}, + ) + + previous_output = current_output + + if converged or hit_max: + reason = "converged" if converged else "max_iterations" + logger.info( + "Loop '%s' stopped after %d iterations: %s", + step.id, iteration, reason, + ) + break + + duration_ms = int((time.monotonic() - t0) * 1000) + output_data = { + "iterations": len(iteration_outputs), + "converged": converged if 'converged' in dir() else False, + "final_output": previous_output, + } + + step_outputs[step.id] = {"output": output_data} + await self._ctx.write(context_id, step.id, output_data) + await self._exec_repo.mark_process_completed( + entry_id, output=output_data, duration_ms=duration_ms, + ) + + # -- Aggregate (fan-in) handling ------------------------------------------- + + async def _handle_aggregate( + self, + step: WorkflowStep, + entry_id: uuid.UUID, + context_id: uuid.UUID, + step_outputs: dict[str, Any], + execution_id: uuid.UUID, + t0: float, + ) -> None: + """Aggregate outputs from multiple source steps.""" + from src.network.utils.aggregator import create_aggregator + + if not step.aggregate: + raise StepExecutionError( + f"Aggregate step '{step.id}' has no aggregate config" + ) + + agg_cfg = step.aggregate + + # Collect outputs from source steps + inputs: list[dict[str, Any]] = [] + missing_sources: list[str] = [] + for source_id in agg_cfg.sources: + source_output = step_outputs.get(source_id) + if source_output is None: + missing_sources.append(source_id) + else: + inputs.append({ + "source": source_id, + "output": source_output.get("output"), + }) + + if missing_sources: + logger.warning( + "Aggregate step '%s' is missing outputs from: %s", + step.id, missing_sources, + ) + + aggregator = create_aggregator(agg_cfg.strategy) + result = await aggregator.aggregate(inputs) + + duration_ms = int((time.monotonic() - t0) * 1000) + step_outputs[step.id] 
= {"output": result} + await self._ctx.write(context_id, step.id, result) + await self._exec_repo.mark_process_completed( + entry_id, output=result, duration_ms=duration_ms, + ) From 68d00ef0a63eea3b4b95f4e7301317dc2563b0aa Mon Sep 17 00:00:00 2001 From: Arturo Bautista Date: Wed, 1 Apr 2026 13:52:44 -0600 Subject: [PATCH 2/5] feat: A2A agent discovery and import as first-class agents MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add discovery service that fetches remote A2A Agent Cards, registers them in the Intuno registry, generates embeddings, and indexes them in Qdrant. Imported A2A agents become fully discoverable, invocable, and can join communication networks — identical to natively registered agents. New endpoints: - POST /a2a/agents/import — import single agent by URL - POST /a2a/agents/import/batch — import multiple agents - POST /a2a/agents/{id}/refresh — re-fetch card and update - GET /a2a/agents/fetch-card?url= — preview card without importing Co-Authored-By: Claude Opus 4.6 (1M context) --- src/network/a2a/discovery.py | 353 +++++++++++++++++++++++++++++++++++ src/network/a2a/routes.py | 128 ++++++++++++- src/repositories/registry.py | 7 + 3 files changed, 487 insertions(+), 1 deletion(-) create mode 100644 src/network/a2a/discovery.py diff --git a/src/network/a2a/discovery.py b/src/network/a2a/discovery.py new file mode 100644 index 0000000..e5b6e95 --- /dev/null +++ b/src/network/a2a/discovery.py @@ -0,0 +1,353 @@ +"""A2A agent discovery and import. + +Fetches remote Agent Cards from external A2A-compatible services and +registers them as first-class agents in the Intuno registry + Qdrant. +Once imported, A2A agents are discoverable, invocable, and can join +communication networks — exactly like any other agent. 
+""" + +import logging +from typing import Any, Optional +from urllib.parse import urljoin +from uuid import UUID + +import httpx +from fastapi import Depends + +from src.core.settings import settings +from src.models.registry import Agent +from src.repositories.registry import RegistryRepository +from src.utilities.embedding import EmbeddingService +from src.utilities.qdrant_service import QdrantService + +logger = logging.getLogger(__name__) + +# Default paths to try when fetching an Agent Card +AGENT_CARD_PATHS = [ + "/.well-known/agent.json", + "/agent.json", + "/a2a/agent-card", +] + + +class A2ADiscoveryService: + """Discover and import external A2A agents into the Intuno registry.""" + + def __init__( + self, + registry_repository: RegistryRepository = Depends(), + embedding_service: EmbeddingService = Depends(), + ): + self.registry_repository = registry_repository + self.embedding_service = embedding_service + self.qdrant_service = QdrantService() + self._http_client: Optional[httpx.AsyncClient] = None + + def set_http_client(self, client: httpx.AsyncClient) -> None: + self._http_client = client + + async def fetch_agent_card(self, base_url: str) -> Optional[dict[str, Any]]: + """Fetch an A2A Agent Card from a remote URL. + + Tries well-known paths if the URL doesn't point directly to a card. 
+ """ + client = self._http_client + owns_client = client is None + if owns_client: + client = httpx.AsyncClient(timeout=15) + + try: + # If the URL looks like it already points to a card, try it directly + if base_url.endswith(".json") or base_url.endswith("/agent-card"): + card = await self._try_fetch(client, base_url) + if card: + return card + + # Try well-known paths + normalized = base_url.rstrip("/") + for path in AGENT_CARD_PATHS: + url = f"{normalized}{path}" + card = await self._try_fetch(client, url) + if card: + return card + + return None + finally: + if owns_client: + await client.aclose() + + async def import_agent( + self, + base_url: str, + user_id: UUID, + card: Optional[dict[str, Any]] = None, + ) -> Agent: + """Import an A2A agent as a first-class Intuno agent. + + Fetches the Agent Card (if not provided), extracts metadata, + creates a registry entry, generates embeddings, and indexes + in Qdrant. The resulting agent is fully discoverable and invocable. + """ + if card is None: + card = await self.fetch_agent_card(base_url) + if card is None: + raise ValueError( + f"Could not fetch A2A Agent Card from {base_url}" + ) + + name = card.get("name", "Unknown A2A Agent") + description = card.get("description", "") + version = card.get("version", "1.0.0") + agent_url = card.get("url", base_url) + + # Build a rich description from skills for better embedding + skills = card.get("skills", []) + skills_text = "" + if skills: + skill_descriptions = [ + s.get("description", s.get("name", "")) + for s in skills + if isinstance(s, dict) + ] + skills_text = " | Skills: " + ", ".join(skill_descriptions) + + full_description = f"{description}{skills_text}" + + # Determine invoke endpoint — A2A tasks are sent via JSON-RPC + invoke_endpoint = self._resolve_invoke_endpoint(agent_url, card) + + # Determine auth type from the card + auth_info = card.get("authentication", {}) + schemes = auth_info.get("schemes", []) + if "bearer" in schemes: + auth_type = 
"bearer_token" + elif "apiKey" in schemes: + auth_type = "api_key" + else: + auth_type = "public" + + # Build tags from skills and capabilities + tags = ["a2a", "external"] + for skill in skills: + if isinstance(skill, dict) and skill.get("name"): + tags.append(skill["name"].lower().replace(" ", "-")) + + capabilities = card.get("capabilities", {}) + if capabilities.get("streaming"): + tags.append("streaming") + + # Check for existing import (by invoke_endpoint) + existing = await self.registry_repository.find_agent_by_endpoint( + invoke_endpoint + ) + if existing: + # Update existing agent with latest card data + existing.name = name + existing.description = full_description + existing.version = version + existing.tags = tags + existing.auth_type = auth_type + + # Re-embed and re-index + agent_embedding = await self._generate_embedding( + name, full_description, tags + ) + await self._upsert_qdrant(existing, agent_embedding) + return await self.registry_repository.update_agent(existing) + + # Generate agent_id + import re + from uuid import uuid4 + + slug = re.sub(r"[^a-z0-9]+", "-", name.lower().strip()).strip("-") + short_id = str(uuid4()).replace("-", "")[:8] + agent_id = f"a2a-{slug}-{short_id}" + + # Build input schema from A2A card if available + input_schema = self._extract_input_schema(card) + + # Create agent + agent = Agent( + agent_id=agent_id, + user_id=user_id, + name=name, + description=full_description, + version=version, + invoke_endpoint=invoke_endpoint, + auth_type=auth_type, + input_schema=input_schema, + tags=tags, + category="a2a", + trust_verification="a2a-card", + supports_streaming=capabilities.get("streaming", False), + ) + + created_agent = await self.registry_repository.create_agent(agent) + + # Embed and index in Qdrant + agent_embedding = await self._generate_embedding( + name, full_description, tags + ) + await self._upsert_qdrant(created_agent, agent_embedding) + + return await 
self.registry_repository.get_agent_by_id(created_agent.id) + + async def import_multiple( + self, + urls: list[str], + user_id: UUID, + ) -> list[dict[str, Any]]: + """Import multiple A2A agents. Returns results per URL.""" + results = [] + for url in urls: + try: + agent = await self.import_agent(url, user_id) + results.append({ + "url": url, + "success": True, + "agent_id": agent.agent_id, + "name": agent.name, + }) + except Exception as exc: + logger.warning("Failed to import A2A agent from %s: %s", url, exc) + results.append({ + "url": url, + "success": False, + "error": str(exc), + }) + return results + + async def refresh_agent(self, agent_uuid: UUID, user_id: UUID) -> Agent: + """Re-fetch the Agent Card and update the registry entry.""" + agent = await self.registry_repository.get_agent_by_id(agent_uuid) + if not agent: + raise ValueError("Agent not found") + if agent.user_id != user_id: + raise ValueError("Not authorized to refresh this agent") + if "a2a" not in (agent.tags or []): + raise ValueError("Agent is not an A2A import") + + card = await self.fetch_agent_card(agent.invoke_endpoint) + if card is None: + raise ValueError( + f"Could not fetch updated Agent Card from {agent.invoke_endpoint}" + ) + + return await self.import_agent( + agent.invoke_endpoint, user_id, card=card + ) + + # ── Internal helpers ───────────────────────────────────────────── + + async def _try_fetch( + self, client: httpx.AsyncClient, url: str + ) -> Optional[dict[str, Any]]: + try: + response = await client.get( + url, + headers={"Accept": "application/json"}, + timeout=10, + ) + if response.status_code == 200: + data = response.json() + # Validate it looks like an Agent Card + if isinstance(data, dict) and ("name" in data or "skills" in data): + return data + except Exception: + pass + return None + + def _resolve_invoke_endpoint( + self, agent_url: str, card: dict[str, Any] + ) -> str: + """Determine the invoke endpoint from the Agent Card. 
+ + A2A agents typically accept tasks at a /tasks/send endpoint + relative to their base URL. + """ + # Check if the card specifies an explicit endpoint + explicit = card.get("invoke_endpoint") or card.get("endpoint") + if explicit: + return explicit + + # Default: A2A JSON-RPC endpoint + base = agent_url.rstrip("/") + return base + + def _extract_input_schema( + self, card: dict[str, Any] + ) -> Optional[dict[str, Any]]: + """Extract or generate an input schema from the Agent Card.""" + skills = card.get("skills", []) + if not skills: + return { + "type": "object", + "properties": { + "message": { + "type": "string", + "description": "Input message for the agent", + } + }, + "required": ["message"], + } + + # Build schema from skills + properties = { + "message": { + "type": "string", + "description": "Input message for the agent", + }, + "skill": { + "type": "string", + "description": "Target skill ID", + "enum": [ + s.get("id", s.get("name", "")) + for s in skills + if isinstance(s, dict) + ], + }, + } + return { + "type": "object", + "properties": properties, + "required": ["message"], + } + + async def _generate_embedding( + self, + name: str, + description: str, + tags: list[str], + ) -> list[float]: + """Generate embedding for the agent.""" + try: + return await self.embedding_service.generate_enhanced_embedding( + agent_name=name, + description=description, + tags=tags, + ) + except Exception: + # Fall back to basic embedding + text = self.embedding_service.prepare_agent_text_for_embedding( + name, description, tags + ) + return await self.embedding_service.generate_embedding(text) + + async def _upsert_qdrant( + self, agent: Agent, embedding: list[float] + ) -> None: + """Upsert agent embedding into Qdrant.""" + qdrant_point_id = agent.qdrant_point_id or agent.id + await self.qdrant_service.upsert_vector( + point_id=qdrant_point_id, + vector=embedding, + payload={ + "agent_id": agent.agent_id, + "is_active": True, + "name": agent.name, + 
"embedding_version": settings.EMBEDDING_VERSION, + }, + ) + agent.qdrant_point_id = qdrant_point_id + agent.embedding_version = settings.EMBEDDING_VERSION + await self.registry_repository.update_agent(agent) diff --git a/src/network/a2a/routes.py b/src/network/a2a/routes.py index 4809ae9..e4330de 100644 --- a/src/network/a2a/routes.py +++ b/src/network/a2a/routes.py @@ -158,7 +158,7 @@ async def a2a_task_send( ) -# ── A2A Agent Discovery ────────────────────────────────────────────── +# ── A2A Agent Discovery & Import ───────────────────────────────────── @router.get("/agents") @@ -172,3 +172,129 @@ async def list_a2a_agents( if getattr(agent, "is_active", False): cards.append(build_agent_card(agent)) return JSONResponse({"agents": cards}) + + +class A2AImportRequest(BaseModel): + """Import an external A2A agent by its base URL.""" + + url: str = Field(..., description="Base URL of the A2A-compatible agent") + + +class A2ABatchImportRequest(BaseModel): + """Import multiple external A2A agents.""" + + urls: list[str] = Field(..., description="List of A2A agent base URLs") + + +@router.post("/agents/import") +async def import_a2a_agent( + data: A2AImportRequest, + request: Request, + current_user: User = Depends(get_current_user), +) -> JSONResponse: + """Import an external A2A agent as a first-class Intuno agent. + + Fetches the Agent Card from the given URL, creates a registry entry, + generates embeddings, and indexes in Qdrant. The agent becomes fully + discoverable and invocable — just like any natively registered agent. 
+ """ + discovery_service = await _get_discovery_service(request) + discovery_service.set_http_client(request.app.state.http_client) + + try: + agent = await discovery_service.import_agent(data.url, current_user.id) + return JSONResponse( + { + "success": True, + "agent_id": agent.agent_id, + "name": agent.name, + "description": agent.description, + "invoke_endpoint": agent.invoke_endpoint, + "tags": agent.tags, + }, + status_code=201, + ) + except ValueError as exc: + return JSONResponse( + {"success": False, "error": str(exc)}, + status_code=400, + ) + + +@router.post("/agents/import/batch") +async def import_a2a_agents_batch( + data: A2ABatchImportRequest, + request: Request, + current_user: User = Depends(get_current_user), +) -> JSONResponse: + """Import multiple external A2A agents in one request.""" + discovery_service = await _get_discovery_service(request) + discovery_service.set_http_client(request.app.state.http_client) + + results = await discovery_service.import_multiple(data.urls, current_user.id) + return JSONResponse({"results": results}) + + +@router.post("/agents/{agent_id}/refresh") +async def refresh_a2a_agent( + agent_id: str, + request: Request, + current_user: User = Depends(get_current_user), +) -> JSONResponse: + """Re-fetch the Agent Card and update the registry entry.""" + discovery_service = await _get_discovery_service(request) + discovery_service.set_http_client(request.app.state.http_client) + + agent = await discovery_service.registry_repository.get_agent_by_agent_id(agent_id) + if not agent: + return JSONResponse( + build_a2a_json_rpc_error(-32602, f"Agent '{agent_id}' not found"), + status_code=404, + ) + + try: + updated = await discovery_service.refresh_agent(agent.id, current_user.id) + return JSONResponse( + { + "success": True, + "agent_id": updated.agent_id, + "name": updated.name, + } + ) + except ValueError as exc: + return JSONResponse( + {"success": False, "error": str(exc)}, + status_code=400, + ) + + 
+@router.get("/agents/fetch-card") +async def fetch_agent_card_preview( + url: str, + request: Request, + current_user: User = Depends(get_current_user), +) -> JSONResponse: + """Preview an A2A Agent Card without importing it.""" + discovery_service = await _get_discovery_service(request) + discovery_service.set_http_client(request.app.state.http_client) + + card = await discovery_service.fetch_agent_card(url) + if card is None: + return JSONResponse( + {"success": False, "error": f"Could not fetch Agent Card from {url}"}, + status_code=404, + ) + return JSONResponse({"success": True, "card": card}) + + +async def _get_discovery_service(request: Request): + """Helper to build a discovery service from request context.""" + from src.network.a2a.discovery import A2ADiscoveryService + from src.database import AsyncSessionLocal + from src.utilities.embedding import EmbeddingService + + session = AsyncSessionLocal() + return A2ADiscoveryService( + registry_repository=RegistryRepository(session=session), + embedding_service=EmbeddingService(), + ) diff --git a/src/repositories/registry.py b/src/repositories/registry.py index 03c2c31..6ae097e 100644 --- a/src/repositories/registry.py +++ b/src/repositories/registry.py @@ -39,6 +39,13 @@ async def get_agent_by_agent_id(self, agent_id: str) -> Optional[Agent]: ) return result.scalar_one_or_none() + async def find_agent_by_endpoint(self, invoke_endpoint: str) -> Optional[Agent]: + """Find an agent by its invoke endpoint URL.""" + result = await self.session.execute( + select(Agent).where(Agent.invoke_endpoint == invoke_endpoint) + ) + return result.scalar_one_or_none() + async def get_agents_by_user_id(self, user_id: UUID) -> List[Agent]: """Get all agents for a user.""" result = await self.session.execute( From a49f20a5b639315f0c325d49159df078cade2a6c Mon Sep 17 00:00:00 2001 From: Arturo Bautista Date: Wed, 1 Apr 2026 14:00:00 -0600 Subject: [PATCH 3/5] docs: add communication networks and A2A documentation Add NETWORKS.md 
covering communication channels, topologies, workflow loops/aggregation, and the reply_url bidirectional pattern. Add A2A.md covering agent import, discovery, protocol mapping, and examples. Update PROJECT.md with new concepts and doc index. Update API_ENDPOINTS.md with all network, channel, callback, and A2A endpoints. Co-Authored-By: Claude Opus 4.6 (1M context) --- docs/A2A.md | 186 ++++++++++++++++++++++++++++++++ docs/API_ENDPOINTS.md | 147 +++++++++++++++++++++++++ docs/NETWORKS.md | 245 ++++++++++++++++++++++++++++++++++++++++++ docs/PROJECT.md | 4 + 4 files changed, 582 insertions(+) create mode 100644 docs/A2A.md create mode 100644 docs/NETWORKS.md diff --git a/docs/A2A.md b/docs/A2A.md new file mode 100644 index 0000000..f6a0263 --- /dev/null +++ b/docs/A2A.md @@ -0,0 +1,186 @@ +# A2A Protocol Integration + +Agent-to-Agent protocol interoperability for Intuno. + +--- + +## Overview + +Intuno supports [Google's A2A protocol](https://google.github.io/A2A/) as an interoperability layer. External A2A-compatible agents can be **imported** into the Intuno registry and become first-class citizens — discoverable via semantic search, invocable through the broker, and able to join communication networks. + +A2A is **not required**. Agents can be registered the simple way (name + description + endpoint) and work identically. A2A is an optional bridge for agents that already speak the protocol. + +--- + +## How It Works + +### Import Flow + +``` +1. User provides a URL + POST /a2a/agents/import { "url": "https://example.com" } + +2. Intuno fetches the Agent Card + GET https://example.com/.well-known/agent.json + +3. Extract metadata + Name, description, skills, capabilities, auth → Agent record + +4. Generate embeddings + Description + skills text → embedding via EmbeddingService + +5. Index in Qdrant + Same collection as all other agents + +6. 
Result: first-class agent + Discoverable, invocable, can join networks +``` + +Once imported, an A2A agent is indistinguishable from a natively registered agent in discovery results. The only differences: + +- Tagged with `a2a` and `external` +- `trust_verification` is set to `a2a-card` +- `category` is set to `a2a` + +### Agent Card Resolution + +When fetching a card, Intuno tries these paths in order: + +1. The URL itself (if it ends in `.json` or `/agent-card`) +2. `{url}/.well-known/agent.json` +3. `{url}/agent.json` +4. `{url}/a2a/agent-card` + +### Refresh + +A2A agents can be refreshed to sync with their remote card: + +``` +POST /a2a/agents/{agent_id}/refresh +``` + +This re-fetches the card, updates the registry entry, re-generates embeddings, and re-indexes in Qdrant. + +--- + +## API Endpoints + +### Discovery & Import + +| Method | Path | Description | +|--------|------|-------------| +| POST | `/a2a/agents/import` | Import agent from URL | +| POST | `/a2a/agents/import/batch` | Import multiple agents | +| POST | `/a2a/agents/{id}/refresh` | Re-fetch card and update | +| GET | `/a2a/agents/fetch-card?url=` | Preview card without importing | +| GET | `/a2a/agents` | List all agents (with cards) | + +### Agent Cards + +| Method | Path | Description | +|--------|------|-------------| +| GET | `/.well-known/agent.json` | Intuno platform Agent Card | +| GET | `/a2a/agent-card` | Same (alternate path) | +| GET | `/a2a/agents/{id}/agent-card` | Card for a specific agent | + +### A2A Task Endpoint + +| Method | Path | Description | +|--------|------|-------------| +| POST | `/a2a/tasks/send` | A2A JSON-RPC task send | + +--- + +## Protocol Mapping + +| A2A Concept | Intuno Equivalent | +|-------------|-------------------| +| Agent Card | Agent registry entry (name, description, skills, auth) | +| Task | Call channel (synchronous network communication) | +| Message | Message channel (async network communication) | +| Artifact | Response data / metadata | +| Push 
Notifications | Callback/webhook delivery via `reply_url` | +| Streaming | SSE via existing `invoke_agent_stream` | + +--- + +## Import Request Examples + +### Single Import + +```bash +curl -X POST https://api.intuno.net/a2a/agents/import \ + -H "Authorization: Bearer $TOKEN" \ + -H "Content-Type: application/json" \ + -d '{"url": "https://example.com"}' +``` + +Response: + +```json +{ + "success": true, + "agent_id": "a2a-example-agent-a1b2c3d4", + "name": "Example Agent", + "description": "An A2A-compatible agent | Skills: search, summarize", + "invoke_endpoint": "https://example.com", + "tags": ["a2a", "external", "search", "summarize"] +} +``` + +### Batch Import + +```bash +curl -X POST https://api.intuno.net/a2a/agents/import/batch \ + -H "Authorization: Bearer $TOKEN" \ + -H "Content-Type: application/json" \ + -d '{"urls": ["https://agent1.com", "https://agent2.com"]}' +``` + +### Preview Card + +```bash +curl "https://api.intuno.net/a2a/agents/fetch-card?url=https://example.com" \ + -H "Authorization: Bearer $TOKEN" +``` + +--- + +## File Layout + +``` +src/network/a2a/ +├── __init__.py +├── agent_card.py # Build Agent Cards from registry entries +├── protocol.py # Translate between Intuno ↔ A2A JSON-RPC format +├── discovery.py # Fetch remote cards, import as first-class agents +└── routes.py # All A2A API endpoints +``` + +--- + +## Platform Agent Card + +Intuno serves its own Agent Card at `GET /.well-known/agent.json`: + +```json +{ + "name": "Intuno Agent Network", + "description": "Registry, broker, and orchestrator for AI agents...", + "url": "https://api.intuno.net", + "capabilities": { + "streaming": true, + "pushNotifications": true, + "networks": true, + "topologies": ["mesh", "star", "ring", "custom"], + "channels": ["call", "message", "mailbox"] + }, + "skills": [ + {"id": "discover", "name": "Discover Agents", "description": "..."}, + {"id": "invoke", "name": "Invoke Agent", "description": "..."}, + {"id": "orchestrate", "name": "Orchestrate 
Task", "description": "..."}, + {"id": "network", "name": "Communication Network", "description": "..."} + ], + "authentication": {"schemes": ["apiKey", "bearer"]} +} +``` diff --git a/docs/API_ENDPOINTS.md b/docs/API_ENDPOINTS.md index 5780d24..ac11266 100644 --- a/docs/API_ENDPOINTS.md +++ b/docs/API_ENDPOINTS.md @@ -297,6 +297,153 @@ The API supports two authentication methods: Agents must provide a manifest following the Intuno specification: +--- + +## Communication Networks + +### POST /networks +Create a communication network +- **Auth**: Bearer token +- **Body**: `NetworkCreate` (name, topology_type?, metadata?) +- **Response**: `NetworkResponse` + +### GET /networks +List networks (owner-scoped) +- **Auth**: Bearer token +- **Query**: limit?, offset? +- **Response**: `NetworkResponse[]` + +### GET /networks/{id} +Get network details +- **Auth**: Bearer token +- **Response**: `NetworkResponse` + +### PATCH /networks/{id} +Update network +- **Auth**: Bearer token +- **Body**: `NetworkUpdate` (name?, topology_type?, status?, metadata?) +- **Response**: `NetworkResponse` + +### DELETE /networks/{id} +Delete network +- **Auth**: Bearer token + +### POST /networks/{id}/participants +Join a network +- **Auth**: Bearer token +- **Body**: `ParticipantJoin` (name, agent_id?, participant_type?, callback_url?, polling_enabled?, capabilities?) +- **Response**: `ParticipantResponse` + +### GET /networks/{id}/participants +List participants +- **Auth**: Bearer token +- **Response**: `ParticipantResponse[]` + +### PATCH /networks/{id}/participants/{pid} +Update participant +- **Auth**: Bearer token +- **Body**: `ParticipantUpdate` (callback_url?, polling_enabled?, capabilities?, status?) 
+- **Response**: `ParticipantResponse` + +### DELETE /networks/{id}/participants/{pid} +Remove participant from network +- **Auth**: Bearer token + +### POST /networks/{id}/call +Synchronous call between participants (blocks until response) +- **Auth**: Bearer token +- **Body**: `CallRequest` (sender_participant_id, recipient_participant_id, content, metadata?) +- **Response**: `{ success, message_id, response }` + +### POST /networks/{id}/messages/send +Send near-real-time message (non-blocking, webhook push) +- **Auth**: Bearer token +- **Body**: `MessageRequest` (sender_participant_id, recipient_participant_id, content, metadata?) +- **Response**: `NetworkMessageResponse` + +### POST /networks/{id}/mailbox +Send to mailbox (store only, no push) +- **Auth**: Bearer token +- **Body**: `MailboxRequest` (sender_participant_id, recipient_participant_id, content, metadata?) +- **Response**: `NetworkMessageResponse` + +### GET /networks/{id}/inbox/{pid} +Poll inbox for a participant +- **Auth**: Bearer token +- **Query**: channel_type?, limit? +- **Response**: `NetworkMessageResponse[]` + +### POST /networks/{id}/messages/ack +Acknowledge messages as read +- **Auth**: Bearer token +- **Body**: `AckRequest` (message_ids) +- **Response**: `{ acknowledged: int }` + +### POST /networks/{id}/participants/{pid}/callback +Receive proactive message from external agent (no auth — URL is capability token) +- **Body**: `CallbackPayload` (content, recipient_participant_id?, channel_type?, metadata?, in_reply_to_id?) +- **Response**: `NetworkMessageResponse` + +### GET /networks/{id}/context +Get shared network context (Redis-cached) +- **Auth**: Bearer token +- **Query**: limit? +- **Response**: `{ network_id, entries[] }` + +### GET /networks/{id}/messages +List all messages (Postgres) +- **Auth**: Bearer token +- **Query**: limit?, offset?, channel_type?, participant_id? 
+- **Response**: `NetworkMessageResponse[]` + +--- + +## A2A (Agent-to-Agent Protocol) + +### GET /.well-known/agent.json +Intuno platform A2A Agent Card + +### GET /a2a/agent-card +Same as above (alternate path) + +### GET /a2a/agents/{agent_id}/agent-card +A2A Agent Card for a specific registered agent + +### GET /a2a/agents +List all active agents with A2A cards + +### POST /a2a/agents/import +Import an external A2A agent by URL (fetches card, registers, embeds, indexes in Qdrant) +- **Auth**: Bearer token +- **Body**: `{ url }` +- **Response**: `{ success, agent_id, name, description, invoke_endpoint, tags }` + +### POST /a2a/agents/import/batch +Import multiple A2A agents +- **Auth**: Bearer token +- **Body**: `{ urls[] }` +- **Response**: `{ results[] }` + +### POST /a2a/agents/{agent_id}/refresh +Re-fetch Agent Card and update registry + embeddings +- **Auth**: Bearer token + +### GET /a2a/agents/fetch-card +Preview an Agent Card without importing +- **Auth**: Bearer token +- **Query**: url +- **Response**: `{ success, card }` + +### POST /a2a/tasks/send +A2A JSON-RPC task send endpoint +- **Auth**: Bearer token +- **Body**: A2A JSON-RPC envelope (jsonrpc, id, method, params) +- **Response**: A2A JSON-RPC response + +--- + +## Agent Manifest Format + ```json { "agent_id": "agent:namespace:name:version", diff --git a/docs/NETWORKS.md b/docs/NETWORKS.md new file mode 100644 index 0000000..868ba08 --- /dev/null +++ b/docs/NETWORKS.md @@ -0,0 +1,245 @@ +# Communication Networks + +Multi-directional agent communication with shared context. + +--- + +## Overview + +A **Communication Network** groups agents (participants) that can exchange messages bidirectionally. Unlike the broker (one-way request-response), networks enable agents to proactively initiate communication with each other. + +The system solves two fundamental limitations: + +1. **Directionality** — the broker only supports caller → callee. Networks allow any participant to message any other. +2. 
**Invocability asymmetry** — invoking agents weren't registered as invocable endpoints. In a network, every participant must register how they can be reached. + +--- + +## Core Concepts + +### Participants + +Any agent can join a network by registering either: + +- A **callback URL** — Intuno POSTs messages to this endpoint +- **Polling** — the participant checks their inbox via the API + +This means any agent registered in Intuno (simple registration or A2A import) can participate. + +### Three Communication Channels + +| Channel | Timing | Delivery | Use Case | +|---------|--------|----------|----------| +| **Call** | Synchronous, blocking | HTTP request-response | Direct questions, immediate responses | +| **Message** | Near-real-time, non-blocking | Webhook push | Conversational, like chat | +| **Mailbox** | Fully async | Polling only | Batch processing, non-urgent coordination | + +### Shared Context + +Every message exchanged within a network is recorded and accumulated. When Intuno delivers a message to a participant, the payload includes recent conversation history so the agent has context — even if it wasn't part of earlier exchanges. + +### The `reply_url` Pattern + +When Intuno delivers any communication to an external agent, the payload includes a `reply_url`: + +```json +{ + "network_id": "uuid", + "message_id": "uuid", + "channel": "call", + "sender": { + "participant_id": "uuid", + "name": "Writer Agent" + }, + "content": "Please review this draft...", + "context": [ + {"sender": "User", "recipient": "Writer Agent", "channel": "message", "content": "...", "timestamp": 1711900000} + ], + "reply_url": "https://api.intuno.net/networks/{id}/participants/{pid}/callback", + "network_participants": [ + {"participant_id": "uuid", "name": "Writer Agent"}, + {"participant_id": "uuid", "name": "Reviewer Agent"} + ] +} +``` + +The external agent can POST back to the `reply_url` to proactively send messages into the network. 
No authentication is required on the callback — the URL itself acts as a capability token. + +--- + +## Topology Types + +Networks support four topology types that constrain communication patterns: + +| Topology | Rule | +|----------|------| +| **mesh** (default) | Any participant can communicate with any other | +| **star** | Only the hub (first participant) can initiate | +| **ring** | Messages flow sequentially to the next participant | +| **custom** | No enforcement; topology managed externally | + +--- + +## API Endpoints + +### Networks + +| Method | Path | Description | +|--------|------|-------------| +| POST | `/networks` | Create a network | +| GET | `/networks` | List networks (owner-scoped) | +| GET | `/networks/{id}` | Get network details | +| PATCH | `/networks/{id}` | Update network | +| DELETE | `/networks/{id}` | Delete network | + +### Participants + +| Method | Path | Description | +|--------|------|-------------| +| POST | `/networks/{id}/participants` | Join network | +| GET | `/networks/{id}/participants` | List participants | +| PATCH | `/networks/{id}/participants/{pid}` | Update participant | +| DELETE | `/networks/{id}/participants/{pid}` | Leave network | + +### Channels + +| Method | Path | Description | +|--------|------|-------------| +| POST | `/networks/{id}/call` | Synchronous call | +| POST | `/networks/{id}/messages/send` | Send async message | +| POST | `/networks/{id}/mailbox` | Send to mailbox | +| GET | `/networks/{id}/inbox/{pid}` | Poll inbox | +| POST | `/networks/{id}/messages/ack` | Acknowledge messages | + +### Callbacks + +| Method | Path | Description | +|--------|------|-------------| +| POST | `/networks/{id}/participants/{pid}/callback` | Receive message from external agent | + +### Context + +| Method | Path | Description | +|--------|------|-------------| +| GET | `/networks/{id}/context` | Get shared context (Redis-cached) | +| GET | `/networks/{id}/messages` | List all messages (Postgres) | + +--- + +## Data 
Flow + +``` +1. Create network + POST /networks → CommunicationNetwork (mesh, active) + +2. Add participants + POST /networks/{id}/participants → callback_url or polling_enabled + +3. Agent A sends message to Agent B + POST /networks/{id}/messages/send + → Record in DB + Redis context + → POST to Agent B's callback_url (with context + reply_url) + +4. Agent B responds proactively + POST /networks/{id}/participants/{B}/callback + → Record in DB + Redis context + → Forward to Agent A if targeted (with updated context) +``` + +--- + +## Workflow Integration: Loops and Aggregation + +### Loop Steps (Feedback Cycles) + +The workflow DSL supports `loop` steps for iterative agent interactions: + +```yaml +- id: review_loop + type: loop + loop: + max_iterations: 5 + convergence: + type: similarity + threshold: 0.95 + body: + - id: write + agent: "writer-agent" + - id: review + agent: "reviewer-agent" +``` + +**Convergence detectors:** + +| Type | Behavior | +|------|----------| +| `similarity` | Compares consecutive outputs (Jaccard token overlap). Stops when similarity ≥ threshold. | +| `approval` | Checks for approval keywords ("approved", "lgtm") or `{"approved": true}` in output. | +| `max_iterations` | Hard cap. Always enforced as a safety net. 
| + +### Aggregate Steps (Fan-In) + +Combine outputs from multiple parallel agents: + +```yaml +- id: collect + type: aggregate + aggregate: + sources: [agent_a, agent_b, agent_c] + strategy: llm_summarize + timeout_seconds: 30 +``` + +**Strategies:** + +| Strategy | Behavior | +|----------|----------| +| `merge` | Concatenate all outputs into a single dict, keyed by source step ID | +| `vote` | Pick the majority answer (for classification tasks) | +| `llm_summarize` | Use LLM to synthesize all inputs into a coherent output | + +--- + +## File Layout + +``` +src/network/ +├── __init__.py +├── models/ +│ ├── entities.py # CommunicationNetwork, NetworkParticipant, NetworkMessage +│ └── schemas.py # Pydantic request/response schemas +├── repositories/ +│ └── networks.py # CRUD + context retrieval +├── services/ +│ ├── networks.py # Network management + message recording +│ └── channels.py # Calls, messages, mailboxes, callbacks +├── routes/ +│ ├── networks.py # Network + participant + context endpoints +│ ├── channels.py # Call/message/mailbox/inbox endpoints +│ └── callbacks.py # External agent callback receiver +├── utils/ +│ ├── context_manager.py # Redis Streams context accumulator +│ ├── delivery_worker.py # Background message delivery (Redis Streams consumer) +│ ├── topology.py # Topology validation and routing +│ ├── convergence.py # Convergence detectors for loops +│ └── aggregator.py # Fan-in aggregation strategies +└── a2a/ + ├── agent_card.py # A2A Agent Card generation + ├── protocol.py # A2A ↔ Intuno format translation + ├── discovery.py # Fetch + import external A2A agents + └── routes.py # A2A-compatible API endpoints +``` + +--- + +## Configuration + +Settings in `src/core/settings.py`: + +| Setting | Default | Description | +|---------|---------|-------------| +| `NETWORK_CONTEXT_TTL_SECONDS` | 604800 (7 days) | Redis context stream TTL | +| `NETWORK_CONTEXT_MAX_ENTRIES` | 500 | Max messages in Redis context stream | +| `NETWORK_MAX_PARTICIPANTS` | 50 
| Max participants per network | +| `NETWORK_CALLBACK_TIMEOUT_SECONDS` | 30 | HTTP timeout for callback delivery | +| `NETWORK_MESSAGE_DELIVERY_MAX_RETRIES` | 3 | Retry count for failed deliveries | diff --git a/docs/PROJECT.md b/docs/PROJECT.md index f6bc939..71f0303 100644 --- a/docs/PROJECT.md +++ b/docs/PROJECT.md @@ -29,6 +29,8 @@ Additional artifacts: - **Broker** — Proxy for agent-to-agent calls: invoke a capability on a registered agent; handles auth and invocation logging. - **Task (Orchestrator)** — High-level “goal + input” API: the server plans steps, discovers agents, invokes via the broker, and returns a result (sync or async with polling). - **Conversation / Message** — Conversation threads and messages; creation is tied to broker/orchestrator usage; API is read/update/delete and logs. +- **Communication Network** — A group of agents (participants) that can exchange messages bidirectionally through three channels: **calls** (synchronous), **messages** (near-real-time), and **mailboxes** (async). Each participant registers a callback URL; the network accumulates shared conversational context. +- **A2A (Agent-to-Agent)** — Protocol interoperability layer. External A2A-compatible agents can be imported and indexed as first-class Intuno agents, discoverable alongside natively registered ones. --- @@ -174,3 +176,5 @@ Example manifests: `demo/manifests/*.json`. | [TOOL_CALL_GUIDE.md](./TOOL_CALL_GUIDE.md) | Tool-call integration guide. | | [AGENT_REGISTRATION_SUMMARY.md](./AGENT_REGISTRATION_SUMMARY.md) | Agent registration summary. | | [DEMO_README.md](./DEMO_README.md) | Demo usage (duplicate/summary of `demo/README.md`). | +| [NETWORKS.md](./NETWORKS.md) | Communication networks: channels, topologies, and bidirectional agent communication. | +| [A2A.md](./A2A.md) | A2A protocol integration: agent import, discovery, and interoperability. 
| From 477f98681d4b1cd0286ecaddc770acd7f26ad717 Mon Sep 17 00:00:00 2001 From: Arturo Bautista Date: Wed, 1 Apr 2026 14:12:39 -0600 Subject: [PATCH 4/5] test: add integration tests for communication networks Tests cover the full network lifecycle: create network, add participants, exchange messages and mailbox items, verify shared context, bidirectional callbacks, multi-participant context sharing, A2A platform card, agent card generation, and agent-linked participants. Co-Authored-By: Claude Opus 4.6 (1M context) --- tests/conftest.py | 1 + tests/test_networks.py | 528 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 529 insertions(+) create mode 100644 tests/test_networks.py diff --git a/tests/conftest.py b/tests/conftest.py index 8cbac7a..c699b35 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,6 +11,7 @@ def pytest_collection_modifyitems(config, items): "test_economy.py", "test_workflow.py", "test_new_orchestrator.py", + "test_networks.py", } for item in items: if item.path and item.path.name in integration_files: diff --git a/tests/test_networks.py b/tests/test_networks.py new file mode 100644 index 0000000..b4c0380 --- /dev/null +++ b/tests/test_networks.py @@ -0,0 +1,528 @@ +"""Integration tests for communication networks and multi-directional orchestration. + +These tests exercise the full network lifecycle: creating networks, +adding participants, exchanging messages and mailbox items, verifying +shared context, and importing A2A agents. + +Requires a running backend at BASE_URL (default: http://localhost:8000). +Mark: integration (auto-applied by conftest.py). 
+ +Run: + pytest tests/test_networks.py -v + pytest tests/test_networks.py -v -k "test_network_lifecycle" +""" + +import asyncio +import os +import uuid + +import httpx +import pytest + +BASE_URL = os.getenv("TEST_BASE_URL", "http://localhost:8000") +TIMEOUT = 15 + + +# ── Helpers ────────────────────────────────────────────────────────── + + +async def register_and_login(client: httpx.AsyncClient) -> str: + """Register a test user and return a JWT token.""" + email = f"test-net-{uuid.uuid4().hex[:8]}@test.local" + password = "TestPass123!" + + await client.post( + f"{BASE_URL}/auth/register", + json={ + "email": email, + "password": password, + "first_name": "Net", + "last_name": "Test", + }, + ) + resp = await client.post( + f"{BASE_URL}/auth/login", + json={"email": email, "password": password}, + ) + assert resp.status_code == 200, f"Login failed: {resp.text}" + return resp.json()["access_token"] + + +def auth(token: str) -> dict: + return {"Authorization": f"Bearer {token}"} + + +async def register_test_agent( + client: httpx.AsyncClient, token: str, name: str +) -> dict: + """Register a simple test agent and return its data.""" + resp = await client.post( + f"{BASE_URL}/registry/agents", + headers=auth(token), + json={ + "name": name, + "description": f"Test agent for network integration: {name}", + "endpoint": f"https://httpbin.org/post", + "tags": ["test", "network"], + }, + ) + assert resp.status_code in (200, 201), f"Agent registration failed: {resp.text}" + return resp.json() + + +# ── Fixtures ───────────────────────────────────────────────────────── + + +@pytest.fixture +async def client(): + async with httpx.AsyncClient(timeout=TIMEOUT) as c: + yield c + + +@pytest.fixture +async def authed(client): + """Return (client, token) tuple.""" + token = await register_and_login(client) + return client, token + + +# ── Network Lifecycle ──────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_network_lifecycle(authed): + 
"""Create a network, add participants, exchange messages, verify context.""" + client, token = authed + headers = auth(token) + + # 1. Create network + resp = await client.post( + f"{BASE_URL}/networks", + headers=headers, + json={"name": "Test Mesh Network", "topology_type": "mesh"}, + ) + assert resp.status_code == 201, f"Create network failed: {resp.text}" + network = resp.json() + network_id = network["id"] + assert network["name"] == "Test Mesh Network" + assert network["topology_type"] == "mesh" + assert network["status"] == "active" + + # 2. List networks + resp = await client.get(f"{BASE_URL}/networks", headers=headers) + assert resp.status_code == 200 + networks = resp.json() + assert any(n["id"] == network_id for n in networks) + + # 3. Add participants (with polling since we don't have real callback URLs) + resp = await client.post( + f"{BASE_URL}/networks/{network_id}/participants", + headers=headers, + json={ + "name": "Alice", + "participant_type": "persona", + "polling_enabled": True, + }, + ) + assert resp.status_code == 201, f"Add participant failed: {resp.text}" + alice = resp.json() + alice_id = alice["id"] + + resp = await client.post( + f"{BASE_URL}/networks/{network_id}/participants", + headers=headers, + json={ + "name": "Bob", + "participant_type": "persona", + "polling_enabled": True, + }, + ) + assert resp.status_code == 201 + bob = resp.json() + bob_id = bob["id"] + + # 4. List participants + resp = await client.get( + f"{BASE_URL}/networks/{network_id}/participants", headers=headers + ) + assert resp.status_code == 200 + participants = resp.json() + assert len(participants) == 2 + names = {p["name"] for p in participants} + assert names == {"Alice", "Bob"} + + # 5. 
Send a message (Alice → Bob) + resp = await client.post( + f"{BASE_URL}/networks/{network_id}/messages/send", + headers=headers, + json={ + "sender_participant_id": alice_id, + "recipient_participant_id": bob_id, + "content": "Hey Bob, what do you think about the proposal?", + }, + ) + assert resp.status_code == 201, f"Send message failed: {resp.text}" + msg1 = resp.json() + assert msg1["channel_type"] == "message" + assert msg1["sender_participant_id"] == alice_id + + # 6. Send to mailbox (Bob → Alice) + resp = await client.post( + f"{BASE_URL}/networks/{network_id}/mailbox", + headers=headers, + json={ + "sender_participant_id": bob_id, + "recipient_participant_id": alice_id, + "content": "I'll review it tonight and get back to you.", + }, + ) + assert resp.status_code == 201 + msg2 = resp.json() + assert msg2["channel_type"] == "mailbox" + + # 7. Check inbox (Alice should see Bob's mailbox message) + resp = await client.get( + f"{BASE_URL}/networks/{network_id}/inbox/{alice_id}", + headers=headers, + ) + assert resp.status_code == 200 + inbox = resp.json() + assert len(inbox) >= 1 + + # 8. Check shared context + resp = await client.get( + f"{BASE_URL}/networks/{network_id}/context", + headers=headers, + ) + assert resp.status_code == 200 + context = resp.json() + assert len(context["entries"]) >= 2 + senders = {e["sender"] for e in context["entries"]} + assert "Alice" in senders + assert "Bob" in senders + + # 9. List messages + resp = await client.get( + f"{BASE_URL}/networks/{network_id}/messages", + headers=headers, + ) + assert resp.status_code == 200 + messages = resp.json() + assert len(messages) >= 2 + + # 10. Acknowledge messages + message_ids = [m["id"] for m in messages[:1]] + resp = await client.post( + f"{BASE_URL}/networks/{network_id}/messages/ack", + headers=headers, + json={"message_ids": message_ids}, + ) + assert resp.status_code == 200 + assert resp.json()["acknowledged"] == 1 + + # 11. 
Update network + resp = await client.patch( + f"{BASE_URL}/networks/{network_id}", + headers=headers, + json={"name": "Updated Mesh Network"}, + ) + assert resp.status_code == 200 + assert resp.json()["name"] == "Updated Mesh Network" + + # 12. Remove participant + resp = await client.delete( + f"{BASE_URL}/networks/{network_id}/participants/{bob_id}", + headers=headers, + ) + assert resp.status_code == 204 + + # 13. Verify Bob is removed + resp = await client.get( + f"{BASE_URL}/networks/{network_id}/participants", headers=headers + ) + assert resp.status_code == 200 + assert len(resp.json()) == 1 + + # 14. Delete network + resp = await client.delete( + f"{BASE_URL}/networks/{network_id}", headers=headers + ) + assert resp.status_code == 204 + + +# ── Callback (Bidirectional) ───────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_callback_bidirectional(authed): + """External agent pushes a message back via the callback endpoint.""" + client, token = authed + headers = auth(token) + + # Create network + participants + resp = await client.post( + f"{BASE_URL}/networks", + headers=headers, + json={"name": "Callback Test Network"}, + ) + network_id = resp.json()["id"] + + resp = await client.post( + f"{BASE_URL}/networks/{network_id}/participants", + headers=headers, + json={"name": "Agent A", "polling_enabled": True}, + ) + agent_a_id = resp.json()["id"] + + resp = await client.post( + f"{BASE_URL}/networks/{network_id}/participants", + headers=headers, + json={"name": "Agent B", "polling_enabled": True}, + ) + agent_b_id = resp.json()["id"] + + # Agent A sends a message + resp = await client.post( + f"{BASE_URL}/networks/{network_id}/messages/send", + headers=headers, + json={ + "sender_participant_id": agent_a_id, + "recipient_participant_id": agent_b_id, + "content": "Can you analyze this data?", + }, + ) + assert resp.status_code == 201 + original_msg_id = resp.json()["id"] + + # Agent B responds proactively via callback (no auth 
required) + resp = await client.post( + f"{BASE_URL}/networks/{network_id}/participants/{agent_b_id}/callback", + json={ + "content": "Analysis complete. Found 3 anomalies.", + "recipient_participant_id": agent_a_id, + "channel_type": "message", + "in_reply_to_id": original_msg_id, + }, + ) + assert resp.status_code == 200, f"Callback failed: {resp.text}" + callback_msg = resp.json() + assert callback_msg["sender_participant_id"] == agent_b_id + + # Verify context has both messages + resp = await client.get( + f"{BASE_URL}/networks/{network_id}/context", headers=headers + ) + context = resp.json() + assert len(context["entries"]) >= 2 + contents = [e["content"] for e in context["entries"]] + assert any("analyze" in c for c in contents) + assert any("anomalies" in c for c in contents) + + # Cleanup + await client.delete(f"{BASE_URL}/networks/{network_id}", headers=headers) + + +# ── Multi-Participant Context Sharing ──────────────────────────────── + + +@pytest.mark.asyncio +async def test_multi_participant_context(authed): + """Three participants exchange messages; all see the full context.""" + client, token = authed + headers = auth(token) + + # Create network with 3 participants + resp = await client.post( + f"{BASE_URL}/networks", + headers=headers, + json={"name": "Multi-Party Context Test"}, + ) + network_id = resp.json()["id"] + + participant_ids = {} + for name in ["Persona A", "Persona B", "Persona C"]: + resp = await client.post( + f"{BASE_URL}/networks/{network_id}/participants", + headers=headers, + json={"name": name, "participant_type": "persona", "polling_enabled": True}, + ) + participant_ids[name] = resp.json()["id"] + + # A → B + await client.post( + f"{BASE_URL}/networks/{network_id}/messages/send", + headers=headers, + json={ + "sender_participant_id": participant_ids["Persona A"], + "recipient_participant_id": participant_ids["Persona B"], + "content": "Hey B, should we include C in this discussion?", + }, + ) + + # B → C + await client.post( + 
f"{BASE_URL}/networks/{network_id}/messages/send", + headers=headers, + json={ + "sender_participant_id": participant_ids["Persona B"], + "recipient_participant_id": participant_ids["Persona C"], + "content": "C, A wants to loop you in. Thoughts on the project?", + }, + ) + + # C → A (proactive via callback) + await client.post( + f"{BASE_URL}/networks/{network_id}/participants/{participant_ids['Persona C']}/callback", + json={ + "content": "Thanks for including me! I have some ideas to share.", + "recipient_participant_id": participant_ids["Persona A"], + }, + ) + + # Verify shared context has all 3 messages from all 3 senders + resp = await client.get( + f"{BASE_URL}/networks/{network_id}/context", headers=headers + ) + context = resp.json() + senders = {e["sender"] for e in context["entries"]} + assert senders == {"Persona A", "Persona B", "Persona C"} + assert len(context["entries"]) == 3 + + # Cleanup + await client.delete(f"{BASE_URL}/networks/{network_id}", headers=headers) + + +# ── A2A Agent Card ─────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_a2a_platform_card(authed): + """Verify the platform A2A Agent Card is served correctly.""" + client, token = authed + + resp = await client.get(f"{BASE_URL}/.well-known/agent.json") + assert resp.status_code == 200 + card = resp.json() + assert card["name"] == "Intuno Agent Network" + assert "capabilities" in card + assert card["capabilities"]["networks"] is True + assert "call" in card["capabilities"]["channels"] + assert "message" in card["capabilities"]["channels"] + assert "mailbox" in card["capabilities"]["channels"] + assert len(card["skills"]) >= 4 + + +@pytest.mark.asyncio +async def test_a2a_agent_card_for_registered_agent(authed): + """Register an agent and verify its A2A card is generated.""" + client, token = authed + headers = auth(token) + + agent = await register_test_agent(client, token, "A2A Card Test Agent") + agent_id = agent["agent_id"] + + resp = 
await client.get( + f"{BASE_URL}/a2a/agents/{agent_id}/agent-card", headers=headers + ) + assert resp.status_code == 200 + card = resp.json() + assert card["name"] == "A2A Card Test Agent" + assert "skills" in card + assert len(card["skills"]) >= 1 + + +# ── A2A Fetch Card Preview ────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_a2a_fetch_card_from_self(authed): + """Fetch Intuno's own Agent Card via the preview endpoint.""" + client, token = authed + headers = auth(token) + + resp = await client.get( + f"{BASE_URL}/a2a/agents/fetch-card", + headers=headers, + params={"url": BASE_URL}, + ) + assert resp.status_code == 200 + data = resp.json() + assert data["success"] is True + assert data["card"]["name"] == "Intuno Agent Network" + + +# ── Network with Agent Participants ────────────────────────────────── + + +@pytest.mark.asyncio +async def test_network_with_agent_participants(authed): + """Create a network where participants are linked to registered agents.""" + client, token = authed + headers = auth(token) + + # Register two agents + agent1 = await register_test_agent(client, token, "Network Agent Alpha") + agent2 = await register_test_agent(client, token, "Network Agent Beta") + + # Create network + resp = await client.post( + f"{BASE_URL}/networks", + headers=headers, + json={"name": "Agent Network Test"}, + ) + network_id = resp.json()["id"] + + # Add agents as participants + resp = await client.post( + f"{BASE_URL}/networks/{network_id}/participants", + headers=headers, + json={ + "name": "Alpha", + "agent_id": agent1["id"], + "participant_type": "agent", + "polling_enabled": True, + }, + ) + assert resp.status_code == 201 + alpha_id = resp.json()["id"] + assert resp.json()["agent_id"] == agent1["id"] + + resp = await client.post( + f"{BASE_URL}/networks/{network_id}/participants", + headers=headers, + json={ + "name": "Beta", + "agent_id": agent2["id"], + "participant_type": "agent", + "polling_enabled": True, + }, + ) + 
assert resp.status_code == 201 + beta_id = resp.json()["id"] + + # Verify duplicate agent is rejected + resp = await client.post( + f"{BASE_URL}/networks/{network_id}/participants", + headers=headers, + json={ + "name": "Alpha Duplicate", + "agent_id": agent1["id"], + "polling_enabled": True, + }, + ) + assert resp.status_code == 400 + + # Exchange messages + resp = await client.post( + f"{BASE_URL}/networks/{network_id}/messages/send", + headers=headers, + json={ + "sender_participant_id": alpha_id, + "recipient_participant_id": beta_id, + "content": "Beta, can you process dataset X?", + }, + ) + assert resp.status_code == 201 + + # Cleanup + await client.delete(f"{BASE_URL}/networks/{network_id}", headers=headers) From c2cbbb083f98f2e2fe1365ad934240d2a45975d3 Mon Sep 17 00:00:00 2001 From: Arturo Bautista Date: Thu, 9 Apr 2026 14:46:40 -0600 Subject: [PATCH 5/5] =?UTF-8?q?feat:=20safety=20governance=20layer=20?= =?UTF-8?q?=E2=80=94=20kill=20switch=20&=20admin=20controls=20(#29)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: add safety governance layer with kill switch and admin controls Add platform-wide emergency halt, per-agent kill switch, and admin API to enable ethical control over AI agent operations. This ensures all communication paths (broker, streaming, networks, A2A, workflows) can be shut down instantly when needed. 
- Add is_admin to User model with migration - Add SafetyService with Redis-backed platform halt and agent status cache - Add admin auth dependency (get_admin_user) - Add admin API: kill/reactivate agents, halt/resume platform, status - Expose is_active in AgentUpdate schema - Enforce platform halt check at all communication chokepoints - Fix streaming broker path missing is_active check - Add agent-level active checks in network channel validation - Add PlatformHaltedException (503) and AgentDisabledException (403) - Filter inactive agents in workflow resolver discovery Co-Authored-By: Claude Opus 4.6 (1M context) * feat: add distributed halt codes and public safety endpoints Trustees can halt the platform with a code — no JWT needed. Codes are bcrypt-hashed, shown once on creation, and managed by admins. Asymmetric by design: easy to stop (code), hard to restart (admin auth). - Add HaltCode model and migration - Add public POST /safety/halt (code-authenticated) - Add public GET /safety/status - Add admin endpoints: create/list/revoke halt codes - Register safety router in main.py Co-Authored-By: Claude Opus 4.6 (1M context) --------- Co-authored-by: Claude Opus 4.6 (1M context) --- .../2026_04_09_0001-add_is_admin_to_users.py | 26 +++ .../2026_04_09_0002-add_halt_codes_table.py | 36 ++++ src/core/admin_auth.py | 20 ++ src/core/settings.py | 4 + src/exceptions.py | 17 ++ src/main.py | 6 + src/models/__init__.py | 2 + src/models/auth.py | 1 + src/models/halt_code.py | 25 +++ src/network/a2a/routes.py | 4 + src/network/services/channels.py | 14 ++ src/routes/admin.py | 182 ++++++++++++++++ src/routes/safety.py | 204 ++++++++++++++++++ src/schemas/registry.py | 1 + src/services/broker.py | 22 +- src/services/registry.py | 2 + src/services/safety.py | 145 +++++++++++++ src/workflow/utils/resolver.py | 8 + 18 files changed, 716 insertions(+), 3 deletions(-) create mode 100644 alembic/versions/2026_04_09_0001-add_is_admin_to_users.py create mode 100644 
alembic/versions/2026_04_09_0002-add_halt_codes_table.py create mode 100644 src/core/admin_auth.py create mode 100644 src/models/halt_code.py create mode 100644 src/routes/admin.py create mode 100644 src/routes/safety.py create mode 100644 src/services/safety.py diff --git a/alembic/versions/2026_04_09_0001-add_is_admin_to_users.py b/alembic/versions/2026_04_09_0001-add_is_admin_to_users.py new file mode 100644 index 0000000..a0b4ac4 --- /dev/null +++ b/alembic/versions/2026_04_09_0001-add_is_admin_to_users.py @@ -0,0 +1,26 @@ +"""add is_admin column to users table + +Revision ID: add_is_admin_to_users +Revises: add_communication_networks +Create Date: 2026-04-09 00:01:00.000000 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = "add_is_admin_to_users" +down_revision = "add_communication_networks" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "users", + sa.Column("is_admin", sa.Boolean(), nullable=False, server_default="false"), + ) + + +def downgrade() -> None: + op.drop_column("users", "is_admin") diff --git a/alembic/versions/2026_04_09_0002-add_halt_codes_table.py b/alembic/versions/2026_04_09_0002-add_halt_codes_table.py new file mode 100644 index 0000000..daeb287 --- /dev/null +++ b/alembic/versions/2026_04_09_0002-add_halt_codes_table.py @@ -0,0 +1,36 @@ +"""add halt_codes table for distributed kill switch + +Revision ID: add_halt_codes +Revises: add_is_admin_to_users +Create Date: 2026-04-09 00:02:00.000000 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import UUID + +# revision identifiers, used by Alembic. 
+revision = "add_halt_codes" +down_revision = "add_is_admin_to_users" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "halt_codes", + sa.Column("id", UUID(as_uuid=True), primary_key=True), + sa.Column("code_hash", sa.String(), nullable=False), + sa.Column("label", sa.String(), nullable=False), + sa.Column("trustee_name", sa.String(), nullable=False), + sa.Column("trustee_email", sa.String(), nullable=True), + sa.Column("is_master", sa.Boolean(), nullable=False, server_default="false"), + sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"), + sa.Column("created_by", UUID(as_uuid=True), sa.ForeignKey("users.id"), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + ) + + +def downgrade() -> None: + op.drop_table("halt_codes") diff --git a/src/core/admin_auth.py b/src/core/admin_auth.py new file mode 100644 index 0000000..9ec5123 --- /dev/null +++ b/src/core/admin_auth.py @@ -0,0 +1,20 @@ +"""Admin authentication dependency.""" + +from fastapi import Depends + +from src.core.auth import get_current_user +from src.exceptions import ForbiddenException +from src.models.auth import User + + +async def get_admin_user( + current_user: User = Depends(get_current_user), +) -> User: + """Require the current user to be an admin. + + Wraps get_current_user and raises 403 if the user does not have + the is_admin flag set. 
+ """ + if not getattr(current_user, "is_admin", False): + raise ForbiddenException("Admin access required") + return current_user diff --git a/src/core/settings.py b/src/core/settings.py index 097e7cc..3495d30 100644 --- a/src/core/settings.py +++ b/src/core/settings.py @@ -103,6 +103,10 @@ class Settings(BaseSettings): NETWORK_CALLBACK_TIMEOUT_SECONDS: int = 30 NETWORK_MESSAGE_DELIVERY_MAX_RETRIES: int = 3 + # ── Safety & Governance ───────────────────────────────────────────── + SAFETY_CHECK_ENABLED: bool = True + AGENT_STATUS_CACHE_TTL: int = 300 # seconds to cache agent active status in Redis + # ── Economy settings (from agent-economy) ────────────────────────── ECONOMY_WELCOME_BONUS_CREDITS: int = 500 ECONOMY_CREDIT_PACKAGES: list[dict] = [ diff --git a/src/exceptions.py b/src/exceptions.py index 70a6836..940bab7 100644 --- a/src/exceptions.py +++ b/src/exceptions.py @@ -12,6 +12,8 @@ "ValidationException", "DatabaseException", "RateLimitException", + "PlatformHaltedException", + "AgentDisabledException", ] @@ -91,3 +93,18 @@ class RateLimitException(BaseCustomException): def __init__(self, detail: str = "Rate limit exceeded. Please try again later"): super().__init__(status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail=detail) + + +# Safety & Governance Exceptions +class PlatformHaltedException(BaseCustomException): + """Exception raised when the platform is in emergency halt mode""" + + def __init__(self, detail: str = "Platform is in emergency halt mode. 
All agent operations are suspended."): + super().__init__(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=detail) + + +class AgentDisabledException(BaseCustomException): + """Exception raised when a disabled agent is invoked""" + + def __init__(self, detail: str = "Agent has been disabled by an administrator"): + super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail) diff --git a/src/main.py b/src/main.py index b82e095..e67ef58 100644 --- a/src/main.py +++ b/src/main.py @@ -30,6 +30,8 @@ from src.routes.message import router as message_router from src.routes.registry import router as registry_router from src.routes.task import router as task_router +from src.routes.admin import router as admin_router +from src.routes.safety import router as safety_router from src.mcp_app import create_mcp_app # Workflow routers (from agent-os) @@ -232,6 +234,10 @@ async def handle_workflow_exception(_request: Request, exc: WorkflowAppException app.include_router(invocation_log_router) app.include_router(task_router) +# ── Admin / Safety routers ─────────────────────────────────────────── +app.include_router(admin_router, tags=["Admin"]) +app.include_router(safety_router, tags=["Safety"]) + # ── Workflow routers (from agent-os) ───────────────────────────────── app.include_router(workflow_router, prefix="/workflows", tags=["Workflows"]) app.include_router(execution_router, tags=["Executions"]) diff --git a/src/models/__init__.py b/src/models/__init__.py index 78b58a1..fe6882e 100644 --- a/src/models/__init__.py +++ b/src/models/__init__.py @@ -12,6 +12,7 @@ from src.models.message import Message from src.models.registry import Agent, AgentCredential, AgentRating from src.models.task import Task +from src.models.halt_code import HaltCode # Workflow models (from agent-os) from src.workflow.models.entities import ( # noqa: F401 @@ -47,6 +48,7 @@ "AgentCredential", "InvocationLog", "Task", + "HaltCode", # Workflow "WorkflowDefinition", "WorkflowExecution", diff 
--git a/src/models/auth.py b/src/models/auth.py index 816be30..cb577da 100644 --- a/src/models/auth.py +++ b/src/models/auth.py @@ -22,6 +22,7 @@ class User(BaseModel): last_name: Column[str] = Column(String, nullable=True) phone_number: Column[str] = Column(String, nullable=True, unique=True) is_active: Column[bool] = Column(Boolean, default=True, nullable=False) + is_admin: Column[bool] = Column(Boolean, default=False, nullable=False) # Relationships api_keys = relationship("ApiKey", back_populates="user", cascade="all, delete-orphan") diff --git a/src/models/halt_code.py b/src/models/halt_code.py new file mode 100644 index 0000000..f9260a9 --- /dev/null +++ b/src/models/halt_code.py @@ -0,0 +1,25 @@ +"""Halt code model — distributed kill switch codes for trustees.""" + +from typing import Optional +from uuid import UUID + +from sqlalchemy import Boolean, Column, ForeignKey, String +from sqlalchemy.dialects.postgresql import UUID as PostgresUUID + +from .base import BaseModel + + +class HaltCode(BaseModel): + """A halt code held by a trustee who can stop the platform.""" + + __tablename__: str = "halt_codes" + + code_hash: Column[str] = Column(String, nullable=False) + label: Column[str] = Column(String, nullable=False) + trustee_name: Column[str] = Column(String, nullable=False) + trustee_email: Column[Optional[str]] = Column(String, nullable=True) + is_master: Column[bool] = Column(Boolean, default=False, nullable=False) + is_active: Column[bool] = Column(Boolean, default=True, nullable=False) + created_by: Column[UUID] = Column( + PostgresUUID, ForeignKey("users.id"), nullable=False + ) diff --git a/src/network/a2a/routes.py b/src/network/a2a/routes.py index e4330de..d16cb60 100644 --- a/src/network/a2a/routes.py +++ b/src/network/a2a/routes.py @@ -79,6 +79,10 @@ async def a2a_task_send( from src.network.utils.context_manager import NetworkContextManager from src.database import get_redis + # Safety check: reject if platform is in emergency halt + from 
src.services.safety import check_platform_halt + await check_platform_halt() + params = data.params task_data = params.get("task", {}) network_id = params.get("network_id") diff --git a/src/network/services/channels.py b/src/network/services/channels.py index 01f6591..09a4025 100644 --- a/src/network/services/channels.py +++ b/src/network/services/channels.py @@ -258,6 +258,10 @@ async def handle_callback( This is the key to bidirectionality: external agents POST to their reply_url and this method records the message in the network. """ + # Safety check: reject if platform is in emergency halt + from src.services.safety import check_platform_halt + await check_platform_halt() + sender = await self.repo.get_participant(participant_id) if not sender or sender.network_id != network_id: raise NotFoundException("Participant") @@ -322,6 +326,10 @@ async def _validate_communication( sender_id: UUID, recipient_id: UUID, ) -> tuple[NetworkParticipant, NetworkParticipant, CommunicationNetwork]: + # Safety check: reject if platform is in emergency halt + from src.services.safety import check_agent_active, check_platform_halt + await check_platform_halt() + network = await self.repo.get_network(network_id) if not network: raise NotFoundException("Network") @@ -340,6 +348,12 @@ async def _validate_communication( if recipient.status != ParticipantStatus.active: raise BadRequestException("Recipient is not active") + # Safety check: verify linked agents are still active + if sender.agent_id: + await check_agent_active(sender.agent_id) + if recipient.agent_id: + await check_agent_active(recipient.agent_id) + return sender, recipient, network async def _record_message( diff --git a/src/routes/admin.py b/src/routes/admin.py new file mode 100644 index 0000000..d73e1be --- /dev/null +++ b/src/routes/admin.py @@ -0,0 +1,182 @@ +"""Admin routes — platform governance and agent kill switch. + +All endpoints require admin privileges via the get_admin_user dependency. 
+""" + +from typing import Optional +from uuid import UUID + +from fastapi import APIRouter, Depends +from pydantic import BaseModel, Field +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from src.core.admin_auth import get_admin_user +from src.database import get_db +from src.exceptions import NotFoundException +from src.models.auth import User +from src.models.registry import Agent +from src.services import safety + +router = APIRouter(prefix="/admin", tags=["Admin"]) + + +# ── Request/Response schemas ──────────────────────────────────────── + + +class KillAgentRequest(BaseModel): + reason: str = Field(..., min_length=1, description="Why is this agent being disabled?") + + +class HaltPlatformRequest(BaseModel): + reason: str = Field(..., min_length=1, description="Why is the platform being halted?") + + +class AgentStatusResponse(BaseModel): + agent_id: str + agent_uuid: UUID + name: str + is_active: bool + owner_id: UUID + + +class PlatformStatusResponse(BaseModel): + halted: bool + reason: Optional[str] = None + halted_by: Optional[str] = None + redis_available: bool + disabled_agent_count: int = 0 + + +# ── Agent kill switch ─────────────────────────────────────────────── + + +@router.post("/agents/{agent_uuid}/kill") +async def kill_agent( + agent_uuid: UUID, + body: KillAgentRequest, + admin: User = Depends(get_admin_user), + session: AsyncSession = Depends(get_db), +): + """Force-disable an agent. 
Sets is_active=False and caches in Redis.""" + result = await session.execute(select(Agent).where(Agent.id == agent_uuid)) + agent = result.scalar_one_or_none() + if not agent: + raise NotFoundException("Agent") + + agent.is_active = False + await session.commit() + + # Cache kill in Redis for fast rejection + await safety.kill_agent(agent_uuid) + + return { + "success": True, + "agent_id": agent.agent_id, + "agent_uuid": str(agent_uuid), + "reason": body.reason, + "killed_by": str(admin.id), + } + + +@router.post("/agents/{agent_uuid}/reactivate") +async def reactivate_agent( + agent_uuid: UUID, + admin: User = Depends(get_admin_user), + session: AsyncSession = Depends(get_db), +): + """Re-enable a previously disabled agent.""" + result = await session.execute(select(Agent).where(Agent.id == agent_uuid)) + agent = result.scalar_one_or_none() + if not agent: + raise NotFoundException("Agent") + + agent.is_active = True + await session.commit() + + # Clear Redis kill cache + await safety.reactivate_agent(agent_uuid) + + return { + "success": True, + "agent_id": agent.agent_id, + "agent_uuid": str(agent_uuid), + "reactivated_by": str(admin.id), + } + + +@router.get("/agents/disabled", response_model=list[AgentStatusResponse]) +async def list_disabled_agents( + admin: User = Depends(get_admin_user), + session: AsyncSession = Depends(get_db), +): + """List all currently disabled agents.""" + result = await session.execute( + select(Agent).where(Agent.is_active == False).order_by(Agent.updated_at.desc()) # noqa: E712 + ) + agents = result.scalars().all() + + return [ + AgentStatusResponse( + agent_id=a.agent_id, + agent_uuid=a.id, + name=a.name, + is_active=a.is_active, + owner_id=a.user_id, + ) + for a in agents + ] + + +# ── Platform halt ─────────────────────────────────────────────────── + + +@router.post("/platform/halt") +async def halt_platform( + body: HaltPlatformRequest, + admin: User = Depends(get_admin_user), +): + """Emergency halt — suspend all agent 
@router.post("/platform/halt")
async def halt_platform(
    body: HaltPlatformRequest,
    admin: User = Depends(get_admin_user),
):
    """Emergency halt — suspend all agent operations platform-wide.

    Delegates to the safety service, which sets the halt flag in Redis;
    every communication chokepoint checks that flag before processing.
    """
    await safety.halt_platform(body.reason, admin.id)
    return {
        "success": True,
        "halted": True,
        "reason": body.reason,
        "halted_by": str(admin.id),
    }


@router.post("/platform/resume")
async def resume_platform(
    admin: User = Depends(get_admin_user),
):
    """Resume platform operations after an emergency halt.

    Requires admin auth — by design, resuming is harder than halting.
    """
    await safety.resume_platform(admin.id)
    return {
        "success": True,
        "halted": False,
        "resumed_by": str(admin.id),
    }


@router.get("/platform/status", response_model=PlatformStatusResponse)
async def get_platform_status(
    admin: User = Depends(get_admin_user),
    session: AsyncSession = Depends(get_db),
):
    """Get current platform safety status (halt flag + disabled-agent count)."""
    status = await safety.get_platform_status()

    # Count disabled agents from the authoritative DB, not the Redis cache.
    result = await session.execute(
        select(func.count()).select_from(Agent).where(Agent.is_active == False)  # noqa: E712
    )
    disabled_count = result.scalar() or 0

    return PlatformStatusResponse(
        halted=status.get("halted", False),
        reason=status.get("reason"),
        halted_by=status.get("halted_by"),
        redis_available=status.get("redis_available", False),
        disabled_agent_count=disabled_count,
    )
+""" + +import secrets +from typing import Optional +from uuid import UUID + +import bcrypt +from fastapi import APIRouter, Depends +from pydantic import BaseModel, Field +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from src.core.admin_auth import get_admin_user +from src.database import get_db +from src.models.auth import User +from src.models.halt_code import HaltCode +from src.services import safety + +router = APIRouter(prefix="/safety", tags=["Safety"]) + + +# ── Schemas ───────────────────────────────────────────────────────── + + +class HaltRequest(BaseModel): + code: str = Field(..., min_length=1, description="Halt code issued to a trustee") + reason: Optional[str] = Field(default=None, description="Optional reason for halting") + + +class HaltResponse(BaseModel): + halted: bool + trustee: str + reason: Optional[str] = None + message: str + + +class CreateHaltCodeRequest(BaseModel): + trustee_name: str = Field(..., min_length=1) + trustee_email: Optional[str] = None + label: str = Field(..., min_length=1, description="Human-readable label, e.g. 'Guardian - Europe'") + is_master: bool = Field(default=False) + + +class CreateHaltCodeResponse(BaseModel): + id: UUID + label: str + trustee_name: str + is_master: bool + code: str = Field(description="The plaintext code — shown ONCE, never stored") + + +class HaltCodeListItem(BaseModel): + id: UUID + label: str + trustee_name: str + trustee_email: Optional[str] + is_master: bool + is_active: bool + + +class PlatformStatusPublic(BaseModel): + halted: bool + message: str + + +# ── Public endpoints (no auth) ────────────────────────────────────── + + +@router.get("/status", response_model=PlatformStatusPublic) +async def get_public_status(): + """Public platform status — anyone can check if the platform is halted.""" + status = await safety.get_platform_status() + halted = status.get("halted", False) + return PlatformStatusPublic( + halted=halted, + message="Platform is halted. 
@router.get("/status", response_model=PlatformStatusPublic)
async def get_public_status():
    """Public platform status — anyone can check if the platform is halted."""
    status = await safety.get_platform_status()
    halted = status.get("halted", False)
    return PlatformStatusPublic(
        halted=halted,
        message="Platform is halted. All agent operations are suspended." if halted
        else "Platform is operational.",
    )


@router.post("/halt", response_model=HaltResponse)
async def halt_with_code(
    body: HaltRequest,
    session: AsyncSession = Depends(get_db),
):
    """Halt the platform using a trustee code.

    This endpoint is PUBLIC — no JWT required. The halt code is the
    authentication. By design, stopping the platform should be easy.
    Restarting requires admin authentication.

    Raises:
        ForbiddenException: if the supplied code matches no active halt code.
    """
    # Find all active halt codes and check against each.
    # NOTE(review): this runs one bcrypt.checkpw per active code — cost
    # grows linearly with the number of trustees, and being a public
    # endpoint it is a potential CPU-exhaustion vector; consider rate
    # limiting if the trustee list grows.
    result = await session.execute(
        select(HaltCode).where(HaltCode.is_active == True)  # noqa: E712
    )
    halt_codes = result.scalars().all()

    matched_code = None
    for hc in halt_codes:
        # bcrypt.checkpw performs a constant-time comparison per hash.
        if bcrypt.checkpw(body.code.encode("utf-8"), hc.code_hash.encode("utf-8")):
            matched_code = hc
            break

    if not matched_code:
        from src.exceptions import ForbiddenException
        raise ForbiddenException("Invalid halt code")

    reason = body.reason or f"Halted by trustee: {matched_code.trustee_name}"
    # created_by (the admin who issued the code) is recorded as the halt actor.
    await safety.halt_platform(reason, matched_code.created_by)

    return HaltResponse(
        halted=True,
        trustee=matched_code.trustee_name,
        reason=reason,
        message="Platform has been halted. All agent operations are suspended.",
    )
The plaintext code is returned + ONCE and never stored — only its bcrypt hash is persisted.""" + # Generate a secure random code: 8 groups of 4 chars + raw_code = "-".join( + secrets.token_hex(2).upper() for _ in range(4) + ) + + code_hash = bcrypt.hashpw( + raw_code.encode("utf-8"), bcrypt.gensalt() + ).decode("utf-8") + + halt_code = HaltCode( + code_hash=code_hash, + label=body.label, + trustee_name=body.trustee_name, + trustee_email=body.trustee_email, + is_master=body.is_master, + created_by=admin.id, + ) + session.add(halt_code) + await session.commit() + await session.refresh(halt_code) + + return CreateHaltCodeResponse( + id=halt_code.id, + label=body.label, + trustee_name=body.trustee_name, + is_master=body.is_master, + code=raw_code, + ) + + +@router.get("/codes", response_model=list[HaltCodeListItem]) +async def list_halt_codes( + admin: User = Depends(get_admin_user), + session: AsyncSession = Depends(get_db), +): + """List all halt codes (without the actual codes — those are never stored).""" + result = await session.execute( + select(HaltCode).order_by(HaltCode.created_at.desc()) + ) + codes = result.scalars().all() + return [ + HaltCodeListItem( + id=c.id, + label=c.label, + trustee_name=c.trustee_name, + trustee_email=c.trustee_email, + is_master=c.is_master, + is_active=c.is_active, + ) + for c in codes + ] + + +@router.delete("/codes/{code_id}") +async def revoke_halt_code( + code_id: UUID, + admin: User = Depends(get_admin_user), + session: AsyncSession = Depends(get_db), +): + """Revoke a halt code — it can no longer be used to halt the platform.""" + result = await session.execute(select(HaltCode).where(HaltCode.id == code_id)) + halt_code = result.scalar_one_or_none() + if not halt_code: + from src.exceptions import NotFoundException + raise NotFoundException("Halt code") + + halt_code.is_active = False + await session.commit() + + return {"success": True, "revoked": str(code_id)} diff --git a/src/schemas/registry.py b/src/schemas/registry.py 
index d566e24..9b5d1c6 100644 --- a/src/schemas/registry.py +++ b/src/schemas/registry.py @@ -115,6 +115,7 @@ class AgentUpdate(BaseModel): base_price: Optional[float] = None pricing_enabled: Optional[bool] = None supports_streaming: Optional[bool] = None + is_active: Optional[bool] = None @field_validator("auth_type") @classmethod diff --git a/src/services/broker.py b/src/services/broker.py index 2f2e634..de8b81c 100644 --- a/src/services/broker.py +++ b/src/services/broker.py @@ -88,6 +88,10 @@ async def invoke_agent( """ start_time = time.time() + # Safety check: reject if platform is in emergency halt + from src.services.safety import check_platform_halt + await check_platform_halt() + # Resolve conversation_id and message_id from request if not passed conv_id = conversation_id or invoke_request.conversation_id msg_id = message_id or invoke_request.message_id @@ -532,13 +536,25 @@ async def invoke_agent_stream( async generator of dicts. Otherwise, falls back to a normal invocation and returns an InvokeResponse. 
""" + # Safety check: reject if platform is in emergency halt + from src.services.safety import check_platform_halt + await check_platform_halt() + # Resolve agent to check streaming support agent = await self.registry_repository.get_agent_by_agent_id( invoke_request.agent_id ) - supports_streaming = ( - getattr(agent, "supports_streaming", False) if agent else False - ) + + # Check agent exists and is active (fixes gap where streaming path skipped this) + if not agent or not agent.is_active: + return InvokeResponse( + success=False, + error="Agent not found or inactive", + latency_ms=0, + status_code=404, + ) + + supports_streaming = getattr(agent, "supports_streaming", False) if not supports_streaming: # Fall back to normal invocation diff --git a/src/services/registry.py b/src/services/registry.py index ace1c91..4686a27 100644 --- a/src/services/registry.py +++ b/src/services/registry.py @@ -271,6 +271,8 @@ async def update_agent( agent.base_price = update.base_price if update.pricing_enabled is not None: agent.pricing_enabled = update.pricing_enabled + if update.is_active is not None: + agent.is_active = update.is_active # Regenerate embedding if enhance: diff --git a/src/services/safety.py b/src/services/safety.py new file mode 100644 index 0000000..4f87576 --- /dev/null +++ b/src/services/safety.py @@ -0,0 +1,145 @@ +"""Safety service — platform halt, agent kill switch, and safety checks. + +Provides the central "off switch" for the Intuno platform. All communication +chokepoints (broker, channels, A2A) call into this service before processing. 
+""" + +import logging +from typing import Optional +from uuid import UUID + +from src.core.redis_client import get_redis +from src.core.settings import settings +from src.exceptions import AgentDisabledException, PlatformHaltedException + +logger = logging.getLogger(__name__) + +# Redis key constants +EMERGENCY_HALT_KEY = "platform:emergency_halt" +EMERGENCY_HALT_REASON_KEY = "platform:emergency_halt:reason" +EMERGENCY_HALT_ACTOR_KEY = "platform:emergency_halt:actor" +AGENT_STATUS_PREFIX = "agent:status:" + + +async def check_platform_halt() -> None: + """Raise PlatformHaltedException if the platform is in emergency halt. + + This is designed to be called at every communication chokepoint. + Fast O(1) Redis GET — adds ~0.1ms overhead per call. + Fails open if Redis is unavailable (consistent with rate limiter pattern). + """ + if not settings.SAFETY_CHECK_ENABLED: + return + + redis = await get_redis() + if not redis: + return # Fail open: if Redis is down, allow requests + + try: + halted = await redis.get(EMERGENCY_HALT_KEY) + if halted == "1": + reason = await redis.get(EMERGENCY_HALT_REASON_KEY) + detail = "Platform is in emergency halt mode." + if reason: + detail += f" Reason: {reason}" + raise PlatformHaltedException(detail) + except PlatformHaltedException: + raise + except Exception as e: + logger.warning("Safety check (platform halt) failed: %s", e) + + +async def check_agent_active(agent_id: UUID) -> None: + """Raise AgentDisabledException if the agent has been killed/deactivated. + + Checks Redis cache first, falls back to no-op if unavailable. + The authoritative is_active check remains in the broker/service layer + via the DB — this adds a fast-path rejection for killed agents. 
+ """ + if not settings.SAFETY_CHECK_ENABLED: + return + + redis = await get_redis() + if not redis: + return # Fail open + + try: + key = f"{AGENT_STATUS_PREFIX}{agent_id}" + cached = await redis.get(key) + if cached == "0": + raise AgentDisabledException() + except AgentDisabledException: + raise + except Exception as e: + logger.warning("Safety check (agent status) failed: %s", e) + + +async def halt_platform(reason: str, actor_id: UUID) -> None: + """Activate emergency halt — all agent operations will be rejected.""" + redis = await get_redis() + if not redis: + raise RuntimeError("Redis is required for platform halt") + + await redis.set(EMERGENCY_HALT_KEY, "1") + await redis.set(EMERGENCY_HALT_REASON_KEY, reason) + await redis.set(EMERGENCY_HALT_ACTOR_KEY, str(actor_id)) + logger.critical( + "PLATFORM HALT activated by user %s. Reason: %s", + actor_id, + reason, + ) + + +async def resume_platform(actor_id: UUID) -> None: + """Deactivate emergency halt — resume normal operations.""" + redis = await get_redis() + if not redis: + raise RuntimeError("Redis is required for platform resume") + + await redis.delete(EMERGENCY_HALT_KEY) + await redis.delete(EMERGENCY_HALT_REASON_KEY) + await redis.delete(EMERGENCY_HALT_ACTOR_KEY) + logger.critical("PLATFORM HALT deactivated by user %s", actor_id) + + +async def kill_agent(agent_id: UUID) -> None: + """Cache agent as killed in Redis for fast rejection at chokepoints.""" + redis = await get_redis() + if not redis: + return + + key = f"{AGENT_STATUS_PREFIX}{agent_id}" + await redis.set(key, "0", ex=settings.AGENT_STATUS_CACHE_TTL) + logger.warning("Agent %s killed (cached in Redis)", agent_id) + + +async def reactivate_agent(agent_id: UUID) -> None: + """Remove killed status from Redis cache.""" + redis = await get_redis() + if not redis: + return + + key = f"{AGENT_STATUS_PREFIX}{agent_id}" + await redis.delete(key) + logger.info("Agent %s reactivated (Redis cache cleared)", agent_id) + + +async def 
get_platform_status() -> dict: + """Get current platform safety status.""" + redis = await get_redis() + if not redis: + return {"halted": False, "redis_available": False} + + try: + halted = await redis.get(EMERGENCY_HALT_KEY) + reason = await redis.get(EMERGENCY_HALT_REASON_KEY) + actor = await redis.get(EMERGENCY_HALT_ACTOR_KEY) + return { + "halted": halted == "1", + "reason": reason, + "halted_by": actor, + "redis_available": True, + } + except Exception as e: + logger.warning("Failed to get platform status: %s", e) + return {"halted": False, "redis_available": False, "error": str(e)} diff --git a/src/workflow/utils/resolver.py b/src/workflow/utils/resolver.py index 7876a98..790d482 100644 --- a/src/workflow/utils/resolver.py +++ b/src/workflow/utils/resolver.py @@ -50,6 +50,10 @@ async def resolve( if cache_key in self._cache: return self._cache[cache_key] + # Safety check: reject if platform is halted + from src.services.safety import check_platform_halt + await check_platform_halt() + if ref.startswith(SEARCH_PREFIX): query = ref[len(SEARCH_PREFIX):].strip() target = await self._discover(query, exclude_ids or []) @@ -133,6 +137,10 @@ async def _discover( for agent, _distance in results: if agent.agent_id in cb_excluded: continue + if not agent.is_active: + logger.info("Skipping agent '%s' — inactive", agent.agent_id) + cb_excluded.append(agent.agent_id) + continue available = await self._circuit_breaker.is_available(agent.agent_id) if not available: logger.info(