diff --git a/alembic/versions/2026_04_01_0001-add_communication_networks.py b/alembic/versions/2026_04_01_0001-add_communication_networks.py new file mode 100644 index 0000000..16a2220 --- /dev/null +++ b/alembic/versions/2026_04_01_0001-add_communication_networks.py @@ -0,0 +1,178 @@ +"""add communication networks, participants, and messages tables + +Revision ID: add_communication_networks +Revises: add_supports_streaming +Create Date: 2026-04-01 00:01:00.000000 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import JSONB, UUID + +# revision identifiers, used by Alembic. +revision = "add_communication_networks" +down_revision = "add_supports_streaming" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # Create enum types + op.execute( + "CREATE TYPE topologytype AS ENUM ('mesh', 'star', 'ring', 'custom')" + ) + op.execute( + "CREATE TYPE networkstatus AS ENUM ('active', 'paused', 'closed')" + ) + op.execute( + "CREATE TYPE participanttype AS ENUM ('agent', 'persona', 'orchestrator')" + ) + op.execute( + "CREATE TYPE participantstatus AS ENUM ('active', 'disconnected', 'removed')" + ) + op.execute( + "CREATE TYPE channeltype AS ENUM ('call', 'message', 'mailbox')" + ) + op.execute( + "CREATE TYPE messagestatus AS ENUM ('pending', 'delivered', 'read', 'failed')" + ) + + # Communication networks + op.create_table( + "communication_networks", + sa.Column("id", UUID(as_uuid=True), primary_key=True), + sa.Column("owner_id", UUID(as_uuid=True), sa.ForeignKey("users.id"), nullable=False), + sa.Column("name", sa.String(255), nullable=False), + sa.Column( + "topology_type", + sa.Enum("mesh", "star", "ring", "custom", name="topologytype", create_type=False), + nullable=False, + server_default="mesh", + ), + sa.Column("metadata", JSONB, nullable=True), + sa.Column( + "status", + sa.Enum("active", "paused", "closed", name="networkstatus", create_type=False), + nullable=False, + server_default="active", + ), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + onupdate=sa.func.now(), + ), + ) + op.create_index("ix_communication_networks_owner_id", "communication_networks", ["owner_id"]) + + # Network participants + op.create_table( + "network_participants", + sa.Column("id", UUID(as_uuid=True), primary_key=True), + sa.Column( + "network_id", + UUID(as_uuid=True), + sa.ForeignKey("communication_networks.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("agent_id", UUID(as_uuid=True), sa.ForeignKey("agents.id"), nullable=True), + sa.Column( + "participant_type", + sa.Enum("agent", "persona", "orchestrator", name="participanttype", create_type=False), + nullable=False, + server_default="agent", + ), + sa.Column("name", sa.String(255), nullable=False), + sa.Column("callback_url", sa.Text, nullable=True), + sa.Column("polling_enabled", sa.Boolean, nullable=False, server_default="false"), + sa.Column("capabilities", JSONB, nullable=True), + sa.Column( + "status", + sa.Enum("active", "disconnected", "removed", name="participantstatus", create_type=False), + nullable=False, + server_default="active", + ), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + onupdate=sa.func.now(), + ), + ) + op.create_index("ix_network_participants_network_id", "network_participants", ["network_id"]) + 
op.create_index("ix_network_participants_agent_id", "network_participants", ["agent_id"]) + + # Network messages + op.create_table( + "network_messages", + sa.Column("id", UUID(as_uuid=True), primary_key=True), + sa.Column( + "network_id", + UUID(as_uuid=True), + sa.ForeignKey("communication_networks.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column( + "sender_participant_id", + UUID(as_uuid=True), + sa.ForeignKey("network_participants.id"), + nullable=False, + ), + sa.Column( + "recipient_participant_id", + UUID(as_uuid=True), + sa.ForeignKey("network_participants.id"), + nullable=True, + ), + sa.Column( + "channel_type", + sa.Enum("call", "message", "mailbox", name="channeltype", create_type=False), + nullable=False, + ), + sa.Column("content", sa.Text, nullable=False), + sa.Column("metadata", JSONB, nullable=True), + sa.Column( + "status", + sa.Enum("pending", "delivered", "read", "failed", name="messagestatus", create_type=False), + nullable=False, + server_default="pending", + ), + sa.Column( + "in_reply_to_id", + UUID(as_uuid=True), + sa.ForeignKey("network_messages.id"), + nullable=True, + ), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + onupdate=sa.func.now(), + ), + ) + op.create_index("ix_network_messages_network_id", "network_messages", ["network_id"]) + op.create_index( + "ix_network_messages_sender", "network_messages", ["sender_participant_id"] + ) + op.create_index( + "ix_network_messages_recipient", "network_messages", ["recipient_participant_id"] + ) + op.create_index( + "ix_network_messages_created_at", "network_messages", ["network_id", "created_at"] + ) + + +def downgrade() -> None: + op.drop_table("network_messages") + op.drop_table("network_participants") + op.drop_table("communication_networks") + + op.execute("DROP TYPE IF EXISTS messagestatus") + op.execute("DROP TYPE IF EXISTS channeltype") + op.execute("DROP TYPE IF EXISTS participantstatus") + op.execute("DROP TYPE IF EXISTS participanttype") + op.execute("DROP TYPE IF EXISTS networkstatus") + op.execute("DROP TYPE IF EXISTS topologytype") diff --git a/alembic/versions/2026_04_09_0001-add_is_admin_to_users.py b/alembic/versions/2026_04_09_0001-add_is_admin_to_users.py new file mode 100644 index 0000000..a0b4ac4 --- /dev/null +++ b/alembic/versions/2026_04_09_0001-add_is_admin_to_users.py @@ -0,0 +1,26 @@ +"""add is_admin column to users table + +Revision ID: add_is_admin_to_users +Revises: add_communication_networks +Create Date: 2026-04-09 00:01:00.000000 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. 
+revision = "add_is_admin_to_users" +down_revision = "add_communication_networks" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "users", + sa.Column("is_admin", sa.Boolean(), nullable=False, server_default="false"), + ) + + +def downgrade() -> None: + op.drop_column("users", "is_admin") diff --git a/alembic/versions/2026_04_09_0002-add_halt_codes_table.py b/alembic/versions/2026_04_09_0002-add_halt_codes_table.py new file mode 100644 index 0000000..daeb287 --- /dev/null +++ b/alembic/versions/2026_04_09_0002-add_halt_codes_table.py @@ -0,0 +1,36 @@ +"""add halt_codes table for distributed kill switch + +Revision ID: add_halt_codes +Revises: add_is_admin_to_users +Create Date: 2026-04-09 00:02:00.000000 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import UUID + +# revision identifiers, used by Alembic. +revision = "add_halt_codes" +down_revision = "add_is_admin_to_users" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "halt_codes", + sa.Column("id", UUID(as_uuid=True), primary_key=True), + sa.Column("code_hash", sa.String(), nullable=False), + sa.Column("label", sa.String(), nullable=False), + sa.Column("trustee_name", sa.String(), nullable=False), + sa.Column("trustee_email", sa.String(), nullable=True), + sa.Column("is_master", sa.Boolean(), nullable=False, server_default="false"), + sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"), + sa.Column("created_by", UUID(as_uuid=True), sa.ForeignKey("users.id"), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + ) + + +def downgrade() -> None: + op.drop_table("halt_codes") diff --git a/docs/A2A.md b/docs/A2A.md new file mode 100644 index 0000000..f6a0263 --- /dev/null +++ b/docs/A2A.md @@ -0,0 +1,186 @@ +# A2A Protocol Integration + +Agent-to-Agent protocol interoperability for Intuno. + +--- + +## Overview + +Intuno supports [Google's A2A protocol](https://google.github.io/A2A/) as an interoperability layer. External A2A-compatible agents can be **imported** into the Intuno registry and become first-class citizens — discoverable via semantic search, invocable through the broker, and able to join communication networks. + +A2A is **not required**. Agents can be registered the simple way (name + description + endpoint) and work identically. A2A is an optional bridge for agents that already speak the protocol. + +--- + +## How It Works + +### Import Flow + +``` +1. User provides a URL + POST /a2a/agents/import { "url": "https://example.com" } + +2. Intuno fetches the Agent Card + GET https://example.com/.well-known/agent.json + +3. Extract metadata + Name, description, skills, capabilities, auth → Agent record + +4. Generate embeddings + Description + skills text → embedding via EmbeddingService + +5. Index in Qdrant + Same collection as all other agents + +6. Result: first-class agent + Discoverable, invocable, can join networks +``` + +Once imported, an A2A agent is indistinguishable from a natively registered agent in discovery results. The only differences: + +- Tagged with `a2a` and `external` +- `trust_verification` is set to `a2a-card` +- `category` is set to `a2a` + +### Agent Card Resolution + +When fetching a card, Intuno tries these paths in order: + +1. The URL itself (if it ends in `.json` or `/agent-card`) +2. 
`{url}/.well-known/agent.json` +3. `{url}/agent.json` +4. `{url}/a2a/agent-card` + +### Refresh + +A2A agents can be refreshed to sync with their remote card: + +``` +POST /a2a/agents/{agent_id}/refresh +``` + +This re-fetches the card, updates the registry entry, re-generates embeddings, and re-indexes in Qdrant. + +--- + +## API Endpoints + +### Discovery & Import + +| Method | Path | Description | +|--------|------|-------------| +| POST | `/a2a/agents/import` | Import agent from URL | +| POST | `/a2a/agents/import/batch` | Import multiple agents | +| POST | `/a2a/agents/{id}/refresh` | Re-fetch card and update | +| GET | `/a2a/agents/fetch-card?url=` | Preview card without importing | +| GET | `/a2a/agents` | List all agents (with cards) | + +### Agent Cards + +| Method | Path | Description | +|--------|------|-------------| +| GET | `/.well-known/agent.json` | Intuno platform Agent Card | +| GET | `/a2a/agent-card` | Same (alternate path) | +| GET | `/a2a/agents/{id}/agent-card` | Card for a specific agent | + +### A2A Task Endpoint + +| Method | Path | Description | +|--------|------|-------------| +| POST | `/a2a/tasks/send` | A2A JSON-RPC task send | + +--- + +## Protocol Mapping + +| A2A Concept | Intuno Equivalent | +|-------------|-------------------| +| Agent Card | Agent registry entry (name, description, skills, auth) | +| Task | Call channel (synchronous network communication) | +| Message | Message channel (async network communication) | +| Artifact | Response data / metadata | +| Push Notifications | Callback/webhook delivery via `reply_url` | +| Streaming | SSE via existing `invoke_agent_stream` | + +--- + +## Import Request Examples + +### Single Import + +```bash +curl -X POST https://api.intuno.net/a2a/agents/import \ + -H "Authorization: Bearer $TOKEN" \ + -H "Content-Type: application/json" \ + -d '{"url": "https://example.com"}' +``` + +Response: + +```json +{ + "success": true, + "agent_id": "a2a-example-agent-a1b2c3d4", + "name": "Example Agent", + "description": "An A2A-compatible agent | Skills: search, summarize", + "invoke_endpoint": "https://example.com", + "tags": ["a2a", "external", "search", "summarize"] +} +``` + +### Batch Import + +```bash +curl -X POST https://api.intuno.net/a2a/agents/import/batch \ + -H "Authorization: Bearer $TOKEN" \ + -H "Content-Type: application/json" \ + -d '{"urls": ["https://agent1.com", "https://agent2.com"]}' +``` + +### Preview Card + +```bash +curl "https://api.intuno.net/a2a/agents/fetch-card?url=https://example.com" \ + -H "Authorization: Bearer $TOKEN" +``` + +--- + +## File Layout + +``` +src/network/a2a/ +├── __init__.py +├── agent_card.py # Build Agent Cards from registry entries +├── protocol.py # Translate between Intuno ↔ A2A JSON-RPC format +├── discovery.py # Fetch remote cards, import as first-class agents +└── routes.py # All A2A API endpoints +``` + +--- + +## Platform Agent Card + +Intuno serves its own Agent Card at `GET /.well-known/agent.json`: + +```json +{ + "name": "Intuno Agent Network", + "description": "Registry, broker, and orchestrator for AI agents...", + "url": "https://api.intuno.net", + "capabilities": { + "streaming": true, + "pushNotifications": true, + "networks": true, + "topologies": ["mesh", "star", "ring", "custom"], + "channels": ["call", "message", "mailbox"] + }, + "skills": [ + {"id": "discover", "name": "Discover Agents", "description": "..."}, + {"id": "invoke", "name": "Invoke Agent", "description": "..."}, + {"id": "orchestrate", "name": "Orchestrate Task", "description": 
"..."}, + {"id": "network", "name": "Communication Network", "description": "..."} + ], + "authentication": {"schemes": ["apiKey", "bearer"]} +} +``` diff --git a/docs/API_ENDPOINTS.md b/docs/API_ENDPOINTS.md index 5780d24..ac11266 100644 --- a/docs/API_ENDPOINTS.md +++ b/docs/API_ENDPOINTS.md @@ -297,6 +297,153 @@ The API supports two authentication methods: Agents must provide a manifest following the Intuno specification: +--- + +## Communication Networks + +### POST /networks +Create a communication network +- **Auth**: Bearer token +- **Body**: `NetworkCreate` (name, topology_type?, metadata?) +- **Response**: `NetworkResponse` + +### GET /networks +List networks (owner-scoped) +- **Auth**: Bearer token +- **Query**: limit?, offset? +- **Response**: `NetworkResponse[]` + +### GET /networks/{id} +Get network details +- **Auth**: Bearer token +- **Response**: `NetworkResponse` + +### PATCH /networks/{id} +Update network +- **Auth**: Bearer token +- **Body**: `NetworkUpdate` (name?, topology_type?, status?, metadata?) +- **Response**: `NetworkResponse` + +### DELETE /networks/{id} +Delete network +- **Auth**: Bearer token + +### POST /networks/{id}/participants +Join a network +- **Auth**: Bearer token +- **Body**: `ParticipantJoin` (name, agent_id?, participant_type?, callback_url?, polling_enabled?, capabilities?) +- **Response**: `ParticipantResponse` + +### GET /networks/{id}/participants +List participants +- **Auth**: Bearer token +- **Response**: `ParticipantResponse[]` + +### PATCH /networks/{id}/participants/{pid} +Update participant +- **Auth**: Bearer token +- **Body**: `ParticipantUpdate` (callback_url?, polling_enabled?, capabilities?, status?) +- **Response**: `ParticipantResponse` + +### DELETE /networks/{id}/participants/{pid} +Remove participant from network +- **Auth**: Bearer token + +### POST /networks/{id}/call +Synchronous call between participants (blocks until response) +- **Auth**: Bearer token +- **Body**: `CallRequest` (sender_participant_id, recipient_participant_id, content, metadata?) +- **Response**: `{ success, message_id, response }` + +### POST /networks/{id}/messages/send +Send near-real-time message (non-blocking, webhook push) +- **Auth**: Bearer token +- **Body**: `MessageRequest` (sender_participant_id, recipient_participant_id, content, metadata?) +- **Response**: `NetworkMessageResponse` + +### POST /networks/{id}/mailbox +Send to mailbox (store only, no push) +- **Auth**: Bearer token +- **Body**: `MailboxRequest` (sender_participant_id, recipient_participant_id, content, metadata?) +- **Response**: `NetworkMessageResponse` + +### GET /networks/{id}/inbox/{pid} +Poll inbox for a participant +- **Auth**: Bearer token +- **Query**: channel_type?, limit? +- **Response**: `NetworkMessageResponse[]` + +### POST /networks/{id}/messages/ack +Acknowledge messages as read +- **Auth**: Bearer token +- **Body**: `AckRequest` (message_ids) +- **Response**: `{ acknowledged: int }` + +### POST /networks/{id}/participants/{pid}/callback +Receive proactive message from external agent (no auth — URL is capability token) +- **Body**: `CallbackPayload` (content, recipient_participant_id?, channel_type?, metadata?, in_reply_to_id?) +- **Response**: `NetworkMessageResponse` + +### GET /networks/{id}/context +Get shared network context (Redis-cached) +- **Auth**: Bearer token +- **Query**: limit? 
+- **Response**: `{ network_id, entries[] }` + +### GET /networks/{id}/messages +List all messages (Postgres) +- **Auth**: Bearer token +- **Query**: limit?, offset?, channel_type?, participant_id? +- **Response**: `NetworkMessageResponse[]` + +--- + +## A2A (Agent-to-Agent Protocol) + +### GET /.well-known/agent.json +Intuno platform A2A Agent Card + +### GET /a2a/agent-card +Same as above (alternate path) + +### GET /a2a/agents/{agent_id}/agent-card +A2A Agent Card for a specific registered agent + +### GET /a2a/agents +List all active agents with A2A cards + +### POST /a2a/agents/import +Import an external A2A agent by URL (fetches card, registers, embeds, indexes in Qdrant) +- **Auth**: Bearer token +- **Body**: `{ url }` +- **Response**: `{ success, agent_id, name, description, invoke_endpoint, tags }` + +### POST /a2a/agents/import/batch +Import multiple A2A agents +- **Auth**: Bearer token +- **Body**: `{ urls[] }` +- **Response**: `{ results[] }` + +### POST /a2a/agents/{agent_id}/refresh +Re-fetch Agent Card and update registry + embeddings +- **Auth**: Bearer token + +### GET /a2a/agents/fetch-card +Preview an Agent Card without importing +- **Auth**: Bearer token +- **Query**: url +- **Response**: `{ success, card }` + +### POST /a2a/tasks/send +A2A JSON-RPC task send endpoint +- **Auth**: Bearer token +- **Body**: A2A JSON-RPC envelope (jsonrpc, id, method, params) +- **Response**: A2A JSON-RPC response + +--- + +## Agent Manifest Format + ```json { "agent_id": "agent:namespace:name:version", diff --git a/docs/NETWORKS.md b/docs/NETWORKS.md new file mode 100644 index 0000000..868ba08 --- /dev/null +++ b/docs/NETWORKS.md @@ -0,0 +1,245 @@ +# Communication Networks + +Multi-directional agent communication with shared context. + +--- + +## Overview + +A **Communication Network** groups agents (participants) that can exchange messages bidirectionally. Unlike the broker (one-way request-response), networks enable agents to proactively initiate communication with each other. + +The system solves two fundamental limitations: + +1. **Directionality** — the broker only supports caller → callee. Networks allow any participant to message any other. +2. **Invocability asymmetry** — invoking agents weren't registered as invocable endpoints. In a network, every participant must register how they can be reached. + +--- + +## Core Concepts + +### Participants + +Any agent can join a network by registering either: + +- A **callback URL** — Intuno POSTs messages to this endpoint +- **Polling** — the participant checks their inbox via the API + +This means any agent registered in Intuno (simple registration or A2A import) can participate. + +### Three Communication Channels + +| Channel | Timing | Delivery | Use Case | +|---------|--------|----------|----------| +| **Call** | Synchronous, blocking | HTTP request-response | Direct questions, immediate responses | +| **Message** | Near-real-time, non-blocking | Webhook push | Conversational, like chat | +| **Mailbox** | Fully async | Polling only | Batch processing, non-urgent coordination | + +### Shared Context + +Every message exchanged within a network is recorded and accumulated. When Intuno delivers a message to a participant, the payload includes recent conversation history so the agent has context — even if it wasn't part of earlier exchanges. 
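+
+For agents that poll rather than receive webhook deliveries, the same history can be fetched on demand via `GET /networks/{id}/context`. A minimal sketch of a response, assuming the entry fields mirror the `context` array in the delivery payload shown in the next section (the outer shape `{ network_id, entries[] }` follows the API reference; exact entry field names may vary):
+
+```json
+{
+  "network_id": "uuid",
+  "entries": [
+    {"sender": "User", "recipient": "Writer Agent", "channel": "message", "content": "Draft an intro paragraph", "timestamp": 1711900000},
+    {"sender": "Writer Agent", "recipient": "Reviewer Agent", "channel": "call", "content": "Please review this draft...", "timestamp": 1711900050}
+  ]
+}
+```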
+ +### The `reply_url` Pattern + +When Intuno delivers any communication to an external agent, the payload includes a `reply_url`: + +```json +{ + "network_id": "uuid", + "message_id": "uuid", + "channel": "call", + "sender": { + "participant_id": "uuid", + "name": "Writer Agent" + }, + "content": "Please review this draft...", + "context": [ + {"sender": "User", "recipient": "Writer Agent", "channel": "message", "content": "...", "timestamp": 1711900000} + ], + "reply_url": "https://api.intuno.net/networks/{id}/participants/{pid}/callback", + "network_participants": [ + {"participant_id": "uuid", "name": "Writer Agent"}, + {"participant_id": "uuid", "name": "Reviewer Agent"} + ] +} +``` + +The external agent can POST back to the `reply_url` to proactively send messages into the network. No authentication is required on the callback — the URL itself acts as a capability token. + +--- + +## Topology Types + +Networks support four topology types that constrain communication patterns: + +| Topology | Rule | +|----------|------| +| **mesh** (default) | Any participant can communicate with any other | +| **star** | Only the hub (first participant) can initiate | +| **ring** | Messages flow sequentially to the next participant | +| **custom** | No enforcement; topology managed externally | + +--- + +## API Endpoints + +### Networks + +| Method | Path | Description | +|--------|------|-------------| +| POST | `/networks` | Create a network | +| GET | `/networks` | List networks (owner-scoped) | +| GET | `/networks/{id}` | Get network details | +| PATCH | `/networks/{id}` | Update network | +| DELETE | `/networks/{id}` | Delete network | + +### Participants + +| Method | Path | Description | +|--------|------|-------------| +| POST | `/networks/{id}/participants` | Join network | +| GET | `/networks/{id}/participants` | List participants | +| PATCH | `/networks/{id}/participants/{pid}` | Update participant | +| DELETE | `/networks/{id}/participants/{pid}` | Leave network | + +### Channels + +| Method | Path | Description | +|--------|------|-------------| +| POST | `/networks/{id}/call` | Synchronous call | +| POST | `/networks/{id}/messages/send` | Send async message | +| POST | `/networks/{id}/mailbox` | Send to mailbox | +| GET | `/networks/{id}/inbox/{pid}` | Poll inbox | +| POST | `/networks/{id}/messages/ack` | Acknowledge messages | + +### Callbacks + +| Method | Path | Description | +|--------|------|-------------| +| POST | `/networks/{id}/participants/{pid}/callback` | Receive message from external agent | + +### Context + +| Method | Path | Description | +|--------|------|-------------| +| GET | `/networks/{id}/context` | Get shared context (Redis-cached) | +| GET | `/networks/{id}/messages` | List all messages (Postgres) | + +--- + +## Data Flow + +``` +1. Create network + POST /networks → CommunicationNetwork (mesh, active) + +2. Add participants + POST /networks/{id}/participants → callback_url or polling_enabled + +3. Agent A sends message to Agent B + POST /networks/{id}/messages/send + → Record in DB + Redis context + → POST to Agent B's callback_url (with context + reply_url) + +4. 
Agent B responds proactively + POST /networks/{id}/participants/{B}/callback + → Record in DB + Redis context + → Forward to Agent A if targeted (with updated context) +``` + +--- + +## Workflow Integration: Loops and Aggregation + +### Loop Steps (Feedback Cycles) + +The workflow DSL supports `loop` steps for iterative agent interactions: + +```yaml +- id: review_loop + type: loop + loop: + max_iterations: 5 + convergence: + type: similarity + threshold: 0.95 + body: + - id: write + agent: "writer-agent" + - id: review + agent: "reviewer-agent" +``` + +**Convergence detectors:** + +| Type | Behavior | +|------|----------| +| `similarity` | Compares consecutive outputs (Jaccard token overlap). Stops when similarity ≥ threshold. | +| `approval` | Checks for approval keywords ("approved", "lgtm") or `{"approved": true}` in output. | +| `max_iterations` | Hard cap. Always enforced as a safety net. | + +### Aggregate Steps (Fan-In) + +Combine outputs from multiple parallel agents: + +```yaml +- id: collect + type: aggregate + aggregate: + sources: [agent_a, agent_b, agent_c] + strategy: llm_summarize + timeout_seconds: 30 +``` + +**Strategies:** + +| Strategy | Behavior | +|----------|----------| +| `merge` | Concatenate all outputs into a single dict, keyed by source step ID | +| `vote` | Pick the majority answer (for classification tasks) | +| `llm_summarize` | Use LLM to synthesize all inputs into a coherent output | + +--- + +## File Layout + +``` +src/network/ +├── __init__.py +├── models/ +│ ├── entities.py # CommunicationNetwork, NetworkParticipant, NetworkMessage +│ └── schemas.py # Pydantic request/response schemas +├── repositories/ +│ └── networks.py # CRUD + context retrieval +├── services/ +│ ├── networks.py # Network management + message recording +│ └── channels.py # Calls, messages, mailboxes, callbacks +├── routes/ +│ ├── networks.py # Network + participant + context endpoints +│ ├── channels.py # Call/message/mailbox/inbox endpoints +│ └── callbacks.py # External agent callback receiver +├── utils/ +│ ├── context_manager.py # Redis Streams context accumulator +│ ├── delivery_worker.py # Background message delivery (Redis Streams consumer) +│ ├── topology.py # Topology validation and routing +│ ├── convergence.py # Convergence detectors for loops +│ └── aggregator.py # Fan-in aggregation strategies +└── a2a/ + ├── agent_card.py # A2A Agent Card generation + ├── protocol.py # A2A ↔ Intuno format translation + ├── discovery.py # Fetch + import external A2A agents + └── routes.py # A2A-compatible API endpoints +``` + +--- + +## Configuration + +Settings in `src/core/settings.py`: + +| Setting | Default | Description | +|---------|---------|-------------| +| `NETWORK_CONTEXT_TTL_SECONDS` | 604800 (7 days) | Redis context stream TTL | +| `NETWORK_CONTEXT_MAX_ENTRIES` | 500 | Max messages in Redis context stream | +| `NETWORK_MAX_PARTICIPANTS` | 50 | Max participants per network | +| `NETWORK_CALLBACK_TIMEOUT_SECONDS` | 30 | HTTP timeout for callback delivery | +| `NETWORK_MESSAGE_DELIVERY_MAX_RETRIES` | 3 | Retry count for failed deliveries | diff --git a/docs/PROJECT.md b/docs/PROJECT.md index f6bc939..71f0303 100644 --- a/docs/PROJECT.md +++ b/docs/PROJECT.md @@ -29,6 +29,8 @@ Additional artifacts: - **Broker** — Proxy for agent-to-agent calls: invoke a capability on a registered agent; handles auth and invocation logging. 
- **Task (Orchestrator)** — High-level “goal + input” API: the server plans steps, discovers agents, invokes via the broker, and returns a result (sync or async with polling). - **Conversation / Message** — Conversation threads and messages; creation is tied to broker/orchestrator usage; API is read/update/delete and logs. +- **Communication Network** — A group of agents (participants) that can exchange messages bidirectionally through three channels: **calls** (synchronous), **messages** (near-real-time), and **mailboxes** (async). Each participant registers a callback URL; the network accumulates shared conversational context. +- **A2A (Agent-to-Agent)** — Protocol interoperability layer. External A2A-compatible agents can be imported and indexed as first-class Intuno agents, discoverable alongside natively registered ones. --- @@ -174,3 +176,5 @@ Example manifests: `demo/manifests/*.json`. | [TOOL_CALL_GUIDE.md](./TOOL_CALL_GUIDE.md) | Tool-call integration guide. | | [AGENT_REGISTRATION_SUMMARY.md](./AGENT_REGISTRATION_SUMMARY.md) | Agent registration summary. | | [DEMO_README.md](./DEMO_README.md) | Demo usage (duplicate/summary of `demo/README.md`). | +| [NETWORKS.md](./NETWORKS.md) | Communication networks: channels, topologies, and bidirectional agent communication. | +| [A2A.md](./A2A.md) | A2A protocol integration: agent import, discovery, and interoperability. | diff --git a/src/core/admin_auth.py b/src/core/admin_auth.py new file mode 100644 index 0000000..9ec5123 --- /dev/null +++ b/src/core/admin_auth.py @@ -0,0 +1,20 @@ +"""Admin authentication dependency.""" + +from fastapi import Depends + +from src.core.auth import get_current_user +from src.exceptions import ForbiddenException +from src.models.auth import User + + +async def get_admin_user( + current_user: User = Depends(get_current_user), +) -> User: + """Require the current user to be an admin. + + Wraps get_current_user and raises 403 if the user does not have + the is_admin flag set. + """ + if not getattr(current_user, "is_admin", False): + raise ForbiddenException("Admin access required") + return current_user diff --git a/src/core/settings.py b/src/core/settings.py index a5723d6..3495d30 100644 --- a/src/core/settings.py +++ b/src/core/settings.py @@ -96,6 +96,17 @@ class Settings(BaseSettings): WORKFLOW_DEFAULT_MAX_CONCURRENT_PER_AGENT: int = 10 WORKFLOW_DEFAULT_MAX_CONCURRENT_EXECUTIONS: int = 5 + # ── Network settings (communication networks) ───────────────────── + NETWORK_CONTEXT_TTL_SECONDS: int = 86400 * 7 # 7 days + NETWORK_CONTEXT_MAX_ENTRIES: int = 500 # max messages in Redis context stream + NETWORK_MAX_PARTICIPANTS: int = 50 + NETWORK_CALLBACK_TIMEOUT_SECONDS: int = 30 + NETWORK_MESSAGE_DELIVERY_MAX_RETRIES: int = 3 + + # ── Safety & Governance ───────────────────────────────────────────── + SAFETY_CHECK_ENABLED: bool = True + AGENT_STATUS_CACHE_TTL: int = 300 # seconds to cache agent active status in Redis + # ── Economy settings (from agent-economy) ────────────────────────── ECONOMY_WELCOME_BONUS_CREDITS: int = 500 ECONOMY_CREDIT_PACKAGES: list[dict] = [ diff --git a/src/exceptions.py b/src/exceptions.py index 70a6836..940bab7 100644 --- a/src/exceptions.py +++ b/src/exceptions.py @@ -12,6 +12,8 @@ "ValidationException", "DatabaseException", "RateLimitException", + "PlatformHaltedException", + "AgentDisabledException", ] @@ -91,3 +93,18 @@ class RateLimitException(BaseCustomException): def __init__(self, detail: str = "Rate limit exceeded. 
Please try again later"): super().__init__(status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail=detail) + + +# Safety & Governance Exceptions +class PlatformHaltedException(BaseCustomException): + """Exception raised when the platform is in emergency halt mode""" + + def __init__(self, detail: str = "Platform is in emergency halt mode. All agent operations are suspended."): + super().__init__(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=detail) + + +class AgentDisabledException(BaseCustomException): + """Exception raised when a disabled agent is invoked""" + + def __init__(self, detail: str = "Agent has been disabled by an administrator"): + super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail) diff --git a/src/main.py b/src/main.py index 5a6707c..e67ef58 100644 --- a/src/main.py +++ b/src/main.py @@ -30,6 +30,8 @@ from src.routes.message import router as message_router from src.routes.registry import router as registry_router from src.routes.task import router as task_router +from src.routes.admin import router as admin_router +from src.routes.safety import router as safety_router from src.mcp_app import create_mcp_app # Workflow routers (from agent-os) @@ -37,6 +39,12 @@ from src.workflow.routes.executions import router as execution_router from src.workflow.routes.webhooks import router as webhook_router +# Network routers (communication networks) +from src.network.routes.networks import router as network_router +from src.network.routes.channels import router as channel_router +from src.network.routes.callbacks import router as callback_router +from src.network.a2a.routes import router as a2a_router + # Economy routers (from agent-economy) from src.economy.routes.wallets import router as wallets_router from src.economy.routes.market import router as market_router @@ -226,11 +234,21 @@ async def handle_workflow_exception(_request: Request, exc: WorkflowAppException app.include_router(invocation_log_router) app.include_router(task_router) +# ── Admin / Safety routers ─────────────────────────────────────────── +app.include_router(admin_router, tags=["Admin"]) +app.include_router(safety_router, tags=["Safety"]) + # ── Workflow routers (from agent-os) ───────────────────────────────── app.include_router(workflow_router, prefix="/workflows", tags=["Workflows"]) app.include_router(execution_router, tags=["Executions"]) app.include_router(webhook_router, tags=["Webhooks"]) +# ── Network routers (communication networks) ───────────────────────── +app.include_router(network_router, tags=["Networks"]) +app.include_router(channel_router, tags=["Channels"]) +app.include_router(callback_router, tags=["Callbacks"]) +app.include_router(a2a_router, tags=["A2A"]) + # ── Economy routers (from agent-economy) ───────────────────────────── app.include_router(wallets_router, prefix="/wallets", tags=["Wallets"]) app.include_router(market_router, prefix="/market", tags=["Market"]) @@ -244,38 +262,9 @@ async def handle_workflow_exception(_request: Request, exc: WorkflowAppException @app.get("/.well-known/agent.json") async def a2a_agent_card(): """A2A-compatible AgentCard for agent-to-agent discovery.""" - return JSONResponse( - { - "name": "Intuno Agent Network", - "description": "Registry, broker, and orchestrator for AI agents", - "url": settings.BASE_URL, - "version": settings.API_VERSION, - "capabilities": { - "streaming": True, - "pushNotifications": False, - }, - "skills": [ - { - "id": "discover", - "name": "Discover Agents", - "description": "Semantic search for AI agents by 
natural-language query", - }, - { - "id": "invoke", - "name": "Invoke Agent", - "description": "Execute an agent with input data through the broker", - }, - { - "id": "orchestrate", - "name": "Orchestrate Task", - "description": "Multi-step task orchestration across multiple agents", - }, - ], - "authentication": { - "schemes": ["apiKey", "bearer"], - }, - } - ) + from src.network.a2a.agent_card import build_platform_card + + return JSONResponse(build_platform_card()) @app.get("/.well-known/mcp/server-card.json") diff --git a/src/models/__init__.py b/src/models/__init__.py index 4e36376..fe6882e 100644 --- a/src/models/__init__.py +++ b/src/models/__init__.py @@ -12,6 +12,7 @@ from src.models.message import Message from src.models.registry import Agent, AgentCredential, AgentRating from src.models.task import Task +from src.models.halt_code import HaltCode # Workflow models (from agent-os) from src.workflow.models.entities import ( # noqa: F401 @@ -21,6 +22,13 @@ WorkflowExecution, ) +# Network models (communication networks) +from src.network.models.entities import ( # noqa: F401 + CommunicationNetwork, + NetworkMessage, + NetworkParticipant, +) + # Economy models (from agent-economy) from src.economy.models.wallet import Transaction, Wallet # noqa: F401 from src.economy.models.order import Order, Trade # noqa: F401 @@ -40,11 +48,16 @@ "AgentCredential", "InvocationLog", "Task", + "HaltCode", # Workflow "WorkflowDefinition", "WorkflowExecution", "ProcessEntry", "ContextEntry", + # Network + "CommunicationNetwork", + "NetworkParticipant", + "NetworkMessage", # Economy "Wallet", "Transaction", diff --git a/src/models/auth.py b/src/models/auth.py index 816be30..cb577da 100644 --- a/src/models/auth.py +++ b/src/models/auth.py @@ -22,6 +22,7 @@ class User(BaseModel): last_name: Column[str] = Column(String, nullable=True) phone_number: Column[str] = Column(String, nullable=True, unique=True) is_active: Column[bool] = Column(Boolean, default=True, nullable=False) + is_admin: Column[bool] = Column(Boolean, default=False, nullable=False) # Relationships api_keys = relationship("ApiKey", back_populates="user", cascade="all, delete-orphan") diff --git a/src/models/halt_code.py b/src/models/halt_code.py new file mode 100644 index 0000000..f9260a9 --- /dev/null +++ b/src/models/halt_code.py @@ -0,0 +1,25 @@ +"""Halt code model — distributed kill switch codes for trustees.""" + +from typing import Optional +from uuid import UUID + +from sqlalchemy import Boolean, Column, ForeignKey, String +from sqlalchemy.dialects.postgresql import UUID as PostgresUUID + +from .base import BaseModel + + +class HaltCode(BaseModel): + """A halt code held by a trustee who can stop the platform.""" + + __tablename__: str = "halt_codes" + + code_hash: Column[str] = Column(String, nullable=False) + label: Column[str] = Column(String, nullable=False) + trustee_name: Column[str] = Column(String, nullable=False) + trustee_email: Column[Optional[str]] = Column(String, nullable=True) + is_master: Column[bool] = Column(Boolean, default=False, nullable=False) + is_active: Column[bool] = Column(Boolean, default=True, nullable=False) + created_by: Column[UUID] = Column( + PostgresUUID, ForeignKey("users.id"), nullable=False + ) diff --git a/src/network/__init__.py b/src/network/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/network/a2a/__init__.py b/src/network/a2a/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/network/a2a/agent_card.py b/src/network/a2a/agent_card.py new file mode 
100644 index 0000000..c4ce7f3 --- /dev/null +++ b/src/network/a2a/agent_card.py @@ -0,0 +1,127 @@ +"""A2A Agent Card generation and serving. + +Generates JSON-LD Agent Cards from the Intuno agent registry for agents +that opt into A2A interoperability. Cards are served at +``GET /.well-known/agent.json`` (platform-level) and per-agent endpoints. +""" + +from typing import Any, Optional +from uuid import UUID + +from src.core.settings import settings + + +def build_agent_card( + agent: Any, + capabilities: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + """Build an A2A Agent Card from an Intuno agent registry entry. + + See: https://google.github.io/A2A/specification/ + """ + card: dict[str, Any] = { + "name": agent.name, + "description": agent.description, + "url": f"{settings.BASE_URL}/a2a/agents/{agent.agent_id}", + "version": getattr(agent, "version", "1.0.0"), + "capabilities": { + "streaming": getattr(agent, "supports_streaming", False), + "pushNotifications": True, # via network callback mechanism + **(capabilities or {}), + }, + "skills": _build_skills(agent), + "authentication": _build_auth(agent), + } + + # Add input schema if available + if agent.input_schema: + card["defaultInputModes"] = ["application/json"] + card["defaultOutputModes"] = ["application/json"] + + return card + + +def build_platform_card() -> dict[str, Any]: + """Build the platform-level A2A Agent Card for Intuno itself.""" + return { + "name": "Intuno Agent Network", + "description": ( + "Registry, broker, and orchestrator for AI agents. " + "Supports multi-directional agent communication with calls, " + "messages, and mailboxes." + ), + "url": settings.BASE_URL, + "version": settings.API_VERSION, + "capabilities": { + "streaming": True, + "pushNotifications": True, + "networks": True, + "topologies": ["mesh", "star", "ring", "custom"], + "channels": ["call", "message", "mailbox"], + }, + "skills": [ + { + "id": "discover", + "name": "Discover Agents", + "description": "Semantic search for AI agents by natural-language query", + }, + { + "id": "invoke", + "name": "Invoke Agent", + "description": "Execute an agent with input data through the broker", + }, + { + "id": "orchestrate", + "name": "Orchestrate Task", + "description": "Multi-step task orchestration across multiple agents", + }, + { + "id": "network", + "name": "Communication Network", + "description": ( + "Create multi-directional communication networks between agents " + "with calls, messages, and mailboxes" + ), + }, + ], + "authentication": { + "schemes": ["apiKey", "bearer"], + }, + } + + +def _build_skills(agent: Any) -> list[dict[str, str]]: + """Extract skills from agent metadata.""" + skills = [] + + # If agent has a2a_capabilities, use those + a2a_caps = getattr(agent, "a2a_capabilities", None) + if a2a_caps and isinstance(a2a_caps, list): + for cap in a2a_caps: + if isinstance(cap, dict): + skills.append(cap) + elif isinstance(cap, str): + skills.append({"id": cap, "name": cap, "description": cap}) + return skills + + # Default: generate a single skill from agent description + skills.append({ + "id": agent.agent_id, + "name": agent.name, + "description": agent.description, + }) + return skills + + +def _build_auth(agent: Any) -> dict[str, Any]: + """Build authentication section from agent auth_type.""" + auth_type = getattr(agent, "auth_type", "public") or "public" + + if auth_type == "public": + return {"schemes": []} + elif auth_type == "api_key": + return {"schemes": ["apiKey"]} + elif auth_type == "bearer_token": + return 
{"schemes": ["bearer"]} + + return {"schemes": [auth_type]} diff --git a/src/network/a2a/discovery.py b/src/network/a2a/discovery.py new file mode 100644 index 0000000..e5b6e95 --- /dev/null +++ b/src/network/a2a/discovery.py @@ -0,0 +1,353 @@ +"""A2A agent discovery and import. + +Fetches remote Agent Cards from external A2A-compatible services and +registers them as first-class agents in the Intuno registry + Qdrant. +Once imported, A2A agents are discoverable, invocable, and can join +communication networks — exactly like any other agent. +""" + +import logging +from typing import Any, Optional +from urllib.parse import urljoin +from uuid import UUID + +import httpx +from fastapi import Depends + +from src.core.settings import settings +from src.models.registry import Agent +from src.repositories.registry import RegistryRepository +from src.utilities.embedding import EmbeddingService +from src.utilities.qdrant_service import QdrantService + +logger = logging.getLogger(__name__) + +# Default paths to try when fetching an Agent Card +AGENT_CARD_PATHS = [ + "/.well-known/agent.json", + "/agent.json", + "/a2a/agent-card", +] + + +class A2ADiscoveryService: + """Discover and import external A2A agents into the Intuno registry.""" + + def __init__( + self, + registry_repository: RegistryRepository = Depends(), + embedding_service: EmbeddingService = Depends(), + ): + self.registry_repository = registry_repository + self.embedding_service = embedding_service + self.qdrant_service = QdrantService() + self._http_client: Optional[httpx.AsyncClient] = None + + def set_http_client(self, client: httpx.AsyncClient) -> None: + self._http_client = client + + async def fetch_agent_card(self, base_url: str) -> Optional[dict[str, Any]]: + """Fetch an A2A Agent Card from a remote URL. + + Tries well-known paths if the URL doesn't point directly to a card. + """ + client = self._http_client + owns_client = client is None + if owns_client: + client = httpx.AsyncClient(timeout=15) + + try: + # If the URL looks like it already points to a card, try it directly + if base_url.endswith(".json") or base_url.endswith("/agent-card"): + card = await self._try_fetch(client, base_url) + if card: + return card + + # Try well-known paths + normalized = base_url.rstrip("/") + for path in AGENT_CARD_PATHS: + url = f"{normalized}{path}" + card = await self._try_fetch(client, url) + if card: + return card + + return None + finally: + if owns_client: + await client.aclose() + + async def import_agent( + self, + base_url: str, + user_id: UUID, + card: Optional[dict[str, Any]] = None, + ) -> Agent: + """Import an A2A agent as a first-class Intuno agent. + + Fetches the Agent Card (if not provided), extracts metadata, + creates a registry entry, generates embeddings, and indexes + in Qdrant. The resulting agent is fully discoverable and invocable. 
+ """ + if card is None: + card = await self.fetch_agent_card(base_url) + if card is None: + raise ValueError( + f"Could not fetch A2A Agent Card from {base_url}" + ) + + name = card.get("name", "Unknown A2A Agent") + description = card.get("description", "") + version = card.get("version", "1.0.0") + agent_url = card.get("url", base_url) + + # Build a rich description from skills for better embedding + skills = card.get("skills", []) + skills_text = "" + if skills: + skill_descriptions = [ + s.get("description", s.get("name", "")) + for s in skills + if isinstance(s, dict) + ] + skills_text = " | Skills: " + ", ".join(skill_descriptions) + + full_description = f"{description}{skills_text}" + + # Determine invoke endpoint — A2A tasks are sent via JSON-RPC + invoke_endpoint = self._resolve_invoke_endpoint(agent_url, card) + + # Determine auth type from the card + auth_info = card.get("authentication", {}) + schemes = auth_info.get("schemes", []) + if "bearer" in schemes: + auth_type = "bearer_token" + elif "apiKey" in schemes: + auth_type = "api_key" + else: + auth_type = "public" + + # Build tags from skills and capabilities + tags = ["a2a", "external"] + for skill in skills: + if isinstance(skill, dict) and skill.get("name"): + tags.append(skill["name"].lower().replace(" ", "-")) + + capabilities = card.get("capabilities", {}) + if capabilities.get("streaming"): + tags.append("streaming") + + # Check for existing import (by invoke_endpoint) + existing = await self.registry_repository.find_agent_by_endpoint( + invoke_endpoint + ) + if existing: + # Update existing agent with latest card data + existing.name = name + existing.description = full_description + existing.version = version + existing.tags = tags + existing.auth_type = auth_type + + # Re-embed and re-index + agent_embedding = await self._generate_embedding( + name, full_description, tags + ) + await self._upsert_qdrant(existing, agent_embedding) + return await self.registry_repository.update_agent(existing) + + # Generate agent_id + import re + from uuid import uuid4 + + slug = re.sub(r"[^a-z0-9]+", "-", name.lower().strip()).strip("-") + short_id = str(uuid4()).replace("-", "")[:8] + agent_id = f"a2a-{slug}-{short_id}" + + # Build input schema from A2A card if available + input_schema = self._extract_input_schema(card) + + # Create agent + agent = Agent( + agent_id=agent_id, + user_id=user_id, + name=name, + description=full_description, + version=version, + invoke_endpoint=invoke_endpoint, + auth_type=auth_type, + input_schema=input_schema, + tags=tags, + category="a2a", + trust_verification="a2a-card", + supports_streaming=capabilities.get("streaming", False), + ) + + created_agent = await self.registry_repository.create_agent(agent) + + # Embed and index in Qdrant + agent_embedding = await self._generate_embedding( + name, full_description, tags + ) + await self._upsert_qdrant(created_agent, agent_embedding) + + return await self.registry_repository.get_agent_by_id(created_agent.id) + + async def import_multiple( + self, + urls: list[str], + user_id: UUID, + ) -> list[dict[str, Any]]: + """Import multiple A2A agents. 
Returns results per URL.""" + results = [] + for url in urls: + try: + agent = await self.import_agent(url, user_id) + results.append({ + "url": url, + "success": True, + "agent_id": agent.agent_id, + "name": agent.name, + }) + except Exception as exc: + logger.warning("Failed to import A2A agent from %s: %s", url, exc) + results.append({ + "url": url, + "success": False, + "error": str(exc), + }) + return results + + async def refresh_agent(self, agent_uuid: UUID, user_id: UUID) -> Agent: + """Re-fetch the Agent Card and update the registry entry.""" + agent = await self.registry_repository.get_agent_by_id(agent_uuid) + if not agent: + raise ValueError("Agent not found") + if agent.user_id != user_id: + raise ValueError("Not authorized to refresh this agent") + if "a2a" not in (agent.tags or []): + raise ValueError("Agent is not an A2A import") + + card = await self.fetch_agent_card(agent.invoke_endpoint) + if card is None: + raise ValueError( + f"Could not fetch updated Agent Card from {agent.invoke_endpoint}" + ) + + return await self.import_agent( + agent.invoke_endpoint, user_id, card=card + ) + + # ── Internal helpers ───────────────────────────────────────────── + + async def _try_fetch( + self, client: httpx.AsyncClient, url: str + ) -> Optional[dict[str, Any]]: + try: + response = await client.get( + url, + headers={"Accept": "application/json"}, + timeout=10, + ) + if response.status_code == 200: + data = response.json() + # Validate it looks like an Agent Card + if isinstance(data, dict) and ("name" in data or "skills" in data): + return data + except Exception: + pass + return None + + def _resolve_invoke_endpoint( + self, agent_url: str, card: dict[str, Any] + ) -> str: + """Determine the invoke endpoint from the Agent Card. + + A2A agents typically accept tasks at a /tasks/send endpoint + relative to their base URL. 
+ """ + # Check if the card specifies an explicit endpoint + explicit = card.get("invoke_endpoint") or card.get("endpoint") + if explicit: + return explicit + + # Default: A2A JSON-RPC endpoint + base = agent_url.rstrip("/") + return base + + def _extract_input_schema( + self, card: dict[str, Any] + ) -> Optional[dict[str, Any]]: + """Extract or generate an input schema from the Agent Card.""" + skills = card.get("skills", []) + if not skills: + return { + "type": "object", + "properties": { + "message": { + "type": "string", + "description": "Input message for the agent", + } + }, + "required": ["message"], + } + + # Build schema from skills + properties = { + "message": { + "type": "string", + "description": "Input message for the agent", + }, + "skill": { + "type": "string", + "description": "Target skill ID", + "enum": [ + s.get("id", s.get("name", "")) + for s in skills + if isinstance(s, dict) + ], + }, + } + return { + "type": "object", + "properties": properties, + "required": ["message"], + } + + async def _generate_embedding( + self, + name: str, + description: str, + tags: list[str], + ) -> list[float]: + """Generate embedding for the agent.""" + try: + return await self.embedding_service.generate_enhanced_embedding( + agent_name=name, + description=description, + tags=tags, + ) + except Exception: + # Fall back to basic embedding + text = self.embedding_service.prepare_agent_text_for_embedding( + name, description, tags + ) + return await self.embedding_service.generate_embedding(text) + + async def _upsert_qdrant( + self, agent: Agent, embedding: list[float] + ) -> None: + """Upsert agent embedding into Qdrant.""" + qdrant_point_id = agent.qdrant_point_id or agent.id + await self.qdrant_service.upsert_vector( + point_id=qdrant_point_id, + vector=embedding, + payload={ + "agent_id": agent.agent_id, + "is_active": True, + "name": agent.name, + "embedding_version": settings.EMBEDDING_VERSION, + }, + ) + agent.qdrant_point_id = qdrant_point_id + agent.embedding_version = settings.EMBEDDING_VERSION + await self.registry_repository.update_agent(agent) diff --git a/src/network/a2a/protocol.py b/src/network/a2a/protocol.py new file mode 100644 index 0000000..19206da --- /dev/null +++ b/src/network/a2a/protocol.py @@ -0,0 +1,147 @@ +"""A2A protocol adapter. + +Translates between Intuno's internal message format and the A2A +JSON-RPC wire format. 
+ +A2A spec: https://google.github.io/A2A/specification/ +""" + +from __future__ import annotations + +import uuid +from datetime import datetime, timezone +from typing import Any, Optional + + +# A2A task states +A2A_STATE_SUBMITTED = "submitted" +A2A_STATE_WORKING = "working" +A2A_STATE_INPUT_REQUIRED = "input-required" +A2A_STATE_COMPLETED = "completed" +A2A_STATE_FAILED = "failed" +A2A_STATE_CANCELED = "canceled" + +# Mapping from Intuno message status to A2A task state +_STATUS_MAP = { + "pending": A2A_STATE_SUBMITTED, + "delivered": A2A_STATE_WORKING, + "read": A2A_STATE_COMPLETED, + "failed": A2A_STATE_FAILED, +} + +# Mapping from Intuno channel type to A2A concepts +_CHANNEL_MAP = { + "call": "task", # synchronous call maps to A2A task + "message": "message", # async message maps to A2A message + "mailbox": "message", # mailbox also maps to A2A message (deferred) +} + + +def intuno_message_to_a2a_task( + message: dict[str, Any], + sender_name: str = "", + recipient_name: str = "", +) -> dict[str, Any]: + """Convert an Intuno NetworkMessage (as dict) to an A2A Task object.""" + task_id = message.get("id") or str(uuid.uuid4()) + status = message.get("status", "pending") + content = message.get("content", "") + metadata = message.get("metadata_") or message.get("metadata") or {} + + a2a_task = { + "id": str(task_id), + "status": { + "state": _STATUS_MAP.get(status, A2A_STATE_SUBMITTED), + "timestamp": ( + message.get("created_at", datetime.now(timezone.utc)).isoformat() + if isinstance(message.get("created_at"), datetime) + else datetime.now(timezone.utc).isoformat() + ), + }, + "history": [ + { + "role": "user", + "parts": [{"type": "text", "text": content}], + } + ], + "metadata": { + "intuno_network_id": str(message.get("network_id", "")), + "intuno_channel": message.get("channel_type", "message"), + "sender": sender_name, + "recipient": recipient_name, + **metadata, + }, + } + + return a2a_task + + +def a2a_task_to_intuno_message( + a2a_task: dict[str, Any], +) -> dict[str, Any]: + """Convert an A2A Task object to Intuno message format.""" + # Extract content from history + content = "" + history = a2a_task.get("history", []) + if history: + last_entry = history[-1] + parts = last_entry.get("parts", []) + text_parts = [p.get("text", "") for p in parts if p.get("type") == "text"] + content = "\n".join(text_parts) + + # Map A2A state back to Intuno status + state = a2a_task.get("status", {}).get("state", A2A_STATE_SUBMITTED) + reverse_status_map = { + A2A_STATE_SUBMITTED: "pending", + A2A_STATE_WORKING: "delivered", + A2A_STATE_COMPLETED: "read", + A2A_STATE_FAILED: "failed", + A2A_STATE_INPUT_REQUIRED: "pending", + A2A_STATE_CANCELED: "failed", + } + + a2a_metadata = a2a_task.get("metadata", {}) + + return { + "content": content, + "channel_type": a2a_metadata.get("intuno_channel", "message"), + "status": reverse_status_map.get(state, "pending"), + "metadata": { + "a2a_task_id": a2a_task.get("id"), + "a2a_state": state, + **{ + k: v + for k, v in a2a_metadata.items() + if not k.startswith("intuno_") + }, + }, + } + + +def build_a2a_json_rpc_response( + result: Any, + request_id: Optional[str | int] = None, +) -> dict[str, Any]: + """Wrap a result in a JSON-RPC 2.0 response envelope.""" + return { + "jsonrpc": "2.0", + "id": request_id, + "result": result, + } + + +def build_a2a_json_rpc_error( + code: int, + message: str, + request_id: Optional[str | int] = None, + data: Any = None, +) -> dict[str, Any]: + """Build a JSON-RPC 2.0 error response.""" + error: dict[str, Any] = 
{"code": code, "message": message} + if data is not None: + error["data"] = data + return { + "jsonrpc": "2.0", + "id": request_id, + "error": error, + } diff --git a/src/network/a2a/routes.py b/src/network/a2a/routes.py new file mode 100644 index 0000000..d16cb60 --- /dev/null +++ b/src/network/a2a/routes.py @@ -0,0 +1,304 @@ +"""A2A-compatible API endpoints. + +Provides endpoints that follow the A2A protocol specification, allowing +A2A-compatible agents to interact with Intuno networks. + +See: https://google.github.io/A2A/specification/ +""" + +from typing import Any, Optional +from uuid import UUID + +from fastapi import APIRouter, Depends, Request +from fastapi.responses import JSONResponse +from pydantic import BaseModel, Field + +from src.core.auth import get_current_user +from src.models.auth import User +from src.network.a2a.agent_card import build_agent_card, build_platform_card +from src.network.a2a.protocol import ( + a2a_task_to_intuno_message, + build_a2a_json_rpc_error, + build_a2a_json_rpc_response, + intuno_message_to_a2a_task, +) +from src.repositories.registry import RegistryRepository + +router = APIRouter(prefix="/a2a", tags=["A2A"]) + + +# ── Agent Card endpoints ───────────────────────────────────────────── + + +@router.get("/agent-card") +async def get_platform_agent_card() -> JSONResponse: + """Serve the platform-level A2A Agent Card.""" + return JSONResponse(build_platform_card()) + + +@router.get("/agents/{agent_id}/agent-card") +async def get_agent_card( + agent_id: str, + registry: RegistryRepository = Depends(), +) -> JSONResponse: + """Serve an A2A Agent Card for a specific registered agent.""" + agent = await registry.get_agent_by_agent_id(agent_id) + if not agent: + return JSONResponse( + build_a2a_json_rpc_error(-32602, f"Agent '{agent_id}' not found"), + status_code=404, + ) + return JSONResponse(build_agent_card(agent)) + + +# ── A2A Task endpoints (JSON-RPC style) ────────────────────────────── + + +class A2ATaskSendRequest(BaseModel): + """A2A tasks/send request body.""" + + jsonrpc: str = "2.0" + id: Optional[str | int] = None + method: str = "tasks/send" + params: dict[str, Any] = Field(default_factory=dict) + + +@router.post("/tasks/send") +async def a2a_task_send( + data: A2ATaskSendRequest, + request: Request, + current_user: User = Depends(get_current_user), +) -> JSONResponse: + """A2A-compatible task send endpoint. + + Receives an A2A task, translates it to an Intuno network message, + processes it, and returns the result in A2A format. 
+    """
+    from src.network.services.channels import ChannelService
+    from src.network.repositories.networks import NetworkRepository
+    from src.network.utils.context_manager import NetworkContextManager
+    from src.database import AsyncSessionLocal
+
+    # Safety check: reject if platform is in emergency halt
+    from src.services.safety import check_platform_halt
+    await check_platform_halt()
+
+    params = data.params
+    task_data = params.get("task", {})
+    network_id = params.get("network_id")
+    sender_participant_id = params.get("sender_participant_id")
+    recipient_participant_id = params.get("recipient_participant_id")
+
+    if not network_id or not sender_participant_id:
+        return JSONResponse(
+            build_a2a_json_rpc_error(
+                -32602,
+                "Missing required params: network_id, sender_participant_id",
+                data.id,
+            ),
+            status_code=400,
+        )
+
+    # Convert A2A task to Intuno message format
+    intuno_msg = a2a_task_to_intuno_message(task_data)
+
+    # Process through the channel service
+    try:
+        redis = request.app.state.redis
+        # Build dependencies manually (outside FastAPI DI), mirroring _get_discovery_service
+        repo = NetworkRepository(session=AsyncSessionLocal())
+        ctx_manager = NetworkContextManager(redis)
+        channel_service = ChannelService(repo=repo, context_manager=ctx_manager)
+        channel_service.set_http_client(request.app.state.http_client)
+
+        channel_type = intuno_msg.get("channel_type", "message")
+
+        if channel_type == "call" and recipient_participant_id:
+            result = await channel_service.call(
+                network_id=UUID(network_id),
+                sender_participant_id=UUID(sender_participant_id),
+                recipient_participant_id=UUID(recipient_participant_id),
+                content=intuno_msg["content"],
+                metadata=intuno_msg.get("metadata"),
+            )
+            # Convert result back to A2A task format
+            a2a_result = {
+                "id": result.get("message_id"),
+                "status": {"state": "completed"},
+                "artifacts": [
+                    {
+                        "parts": [
+                            {"type": "text", "text": str(result.get("response", ""))}
+                        ]
+                    }
+                ],
+            }
+        else:
+            message = await channel_service.send_message(
+                network_id=UUID(network_id),
+                sender_participant_id=UUID(sender_participant_id),
+                recipient_participant_id=UUID(recipient_participant_id),
+                content=intuno_msg["content"],
+                metadata=intuno_msg.get("metadata"),
+            )
+            a2a_result = intuno_message_to_a2a_task(
+                {
+                    "id": message.id,
+                    "status": message.status,
+                    "content": message.content,
+                    "network_id": message.network_id,
+                    "channel_type": message.channel_type,
+                    "created_at": message.created_at,
+                },
+            )
+
+        return JSONResponse(build_a2a_json_rpc_response(a2a_result, data.id))
+
+    except Exception as exc:
+        return JSONResponse(
+            build_a2a_json_rpc_error(-32603, str(exc), data.id),
+            status_code=500,
+        )
+
+
+# ── A2A Agent Discovery & Import ─────────────────────────────────────
+
+
+@router.get("/agents")
+async def list_a2a_agents(
+    registry: RegistryRepository = Depends(),
+) -> JSONResponse:
+    """List all agents with A2A support enabled."""
+    agents = await registry.list_agents(limit=100)
+    cards = []
+    for agent in agents:
+        if getattr(agent, "is_active", False):
+            cards.append(build_agent_card(agent))
+    return JSONResponse({"agents": cards})
+
+
+class A2AImportRequest(BaseModel):
+    """Import an external A2A agent by its base URL."""
+
+    url: str = Field(..., description="Base URL of the A2A-compatible agent")
+
+
+class A2ABatchImportRequest(BaseModel):
+    """Import multiple external A2A agents."""
+
+    urls: list[str] = Field(..., description="List of A2A agent base URLs")
+
+
+@router.post("/agents/import")
+async def import_a2a_agent(
+    data: A2AImportRequest,
+    request: Request,
+ 
current_user: User = Depends(get_current_user), +) -> JSONResponse: + """Import an external A2A agent as a first-class Intuno agent. + + Fetches the Agent Card from the given URL, creates a registry entry, + generates embeddings, and indexes in Qdrant. The agent becomes fully + discoverable and invocable — just like any natively registered agent. + """ + discovery_service = await _get_discovery_service(request) + discovery_service.set_http_client(request.app.state.http_client) + + try: + agent = await discovery_service.import_agent(data.url, current_user.id) + return JSONResponse( + { + "success": True, + "agent_id": agent.agent_id, + "name": agent.name, + "description": agent.description, + "invoke_endpoint": agent.invoke_endpoint, + "tags": agent.tags, + }, + status_code=201, + ) + except ValueError as exc: + return JSONResponse( + {"success": False, "error": str(exc)}, + status_code=400, + ) + + +@router.post("/agents/import/batch") +async def import_a2a_agents_batch( + data: A2ABatchImportRequest, + request: Request, + current_user: User = Depends(get_current_user), +) -> JSONResponse: + """Import multiple external A2A agents in one request.""" + discovery_service = await _get_discovery_service(request) + discovery_service.set_http_client(request.app.state.http_client) + + results = await discovery_service.import_multiple(data.urls, current_user.id) + return JSONResponse({"results": results}) + + +@router.post("/agents/{agent_id}/refresh") +async def refresh_a2a_agent( + agent_id: str, + request: Request, + current_user: User = Depends(get_current_user), +) -> JSONResponse: + """Re-fetch the Agent Card and update the registry entry.""" + discovery_service = await _get_discovery_service(request) + discovery_service.set_http_client(request.app.state.http_client) + + agent = await discovery_service.registry_repository.get_agent_by_agent_id(agent_id) + if not agent: + return JSONResponse( + build_a2a_json_rpc_error(-32602, f"Agent '{agent_id}' not found"), + status_code=404, + ) + + try: + updated = await discovery_service.refresh_agent(agent.id, current_user.id) + return JSONResponse( + { + "success": True, + "agent_id": updated.agent_id, + "name": updated.name, + } + ) + except ValueError as exc: + return JSONResponse( + {"success": False, "error": str(exc)}, + status_code=400, + ) + + +@router.get("/agents/fetch-card") +async def fetch_agent_card_preview( + url: str, + request: Request, + current_user: User = Depends(get_current_user), +) -> JSONResponse: + """Preview an A2A Agent Card without importing it.""" + discovery_service = await _get_discovery_service(request) + discovery_service.set_http_client(request.app.state.http_client) + + card = await discovery_service.fetch_agent_card(url) + if card is None: + return JSONResponse( + {"success": False, "error": f"Could not fetch Agent Card from {url}"}, + status_code=404, + ) + return JSONResponse({"success": True, "card": card}) + + +async def _get_discovery_service(request: Request): + """Helper to build a discovery service from request context.""" + from src.network.a2a.discovery import A2ADiscoveryService + from src.database import AsyncSessionLocal + from src.utilities.embedding import EmbeddingService + + session = AsyncSessionLocal() + return A2ADiscoveryService( + registry_repository=RegistryRepository(session=session), + embedding_service=EmbeddingService(), + ) diff --git a/src/network/models/__init__.py b/src/network/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/network/models/entities.py 
b/src/network/models/entities.py new file mode 100644 index 0000000..8d9286a --- /dev/null +++ b/src/network/models/entities.py @@ -0,0 +1,163 @@ +"""Communication network domain models. + +A CommunicationNetwork groups participants (agents, personas) that can +exchange messages through calls, messages, or mailboxes. +""" + +import enum +from typing import Optional +from uuid import UUID + +from sqlalchemy import Column, Enum, ForeignKey, String, Text, Boolean +from sqlalchemy.dialects.postgresql import JSONB, UUID as PostgresUUID +from sqlalchemy.orm import relationship + +from src.models.base import BaseModel + + +class TopologyType(str, enum.Enum): + mesh = "mesh" + star = "star" + ring = "ring" + custom = "custom" + + +class NetworkStatus(str, enum.Enum): + active = "active" + paused = "paused" + closed = "closed" + + +class ParticipantType(str, enum.Enum): + agent = "agent" + persona = "persona" + orchestrator = "orchestrator" + + +class ParticipantStatus(str, enum.Enum): + active = "active" + disconnected = "disconnected" + removed = "removed" + + +class ChannelType(str, enum.Enum): + call = "call" + message = "message" + mailbox = "mailbox" + + +class MessageStatus(str, enum.Enum): + pending = "pending" + delivered = "delivered" + read = "read" + failed = "failed" + + +class CommunicationNetwork(BaseModel): + """A group of participants that share a communication context.""" + + __tablename__: str = "communication_networks" + + owner_id: Column[UUID] = Column( + PostgresUUID, ForeignKey("users.id"), nullable=False + ) + name: Column[str] = Column(String(255), nullable=False) + topology_type: Column[str] = Column( + Enum(TopologyType), nullable=False, default=TopologyType.mesh + ) + metadata_: Column[Optional[dict]] = Column("metadata", JSONB, nullable=True) + status: Column[str] = Column( + Enum(NetworkStatus), nullable=False, default=NetworkStatus.active + ) + + # Relationships + owner = relationship("User") + participants = relationship( + "NetworkParticipant", + back_populates="network", + cascade="all, delete-orphan", + ) + messages = relationship( + "NetworkMessage", + back_populates="network", + cascade="all, delete-orphan", + order_by="NetworkMessage.created_at", + ) + + +class NetworkParticipant(BaseModel): + """An entity registered in a communication network.""" + + __tablename__: str = "network_participants" + + network_id: Column[UUID] = Column( + PostgresUUID, ForeignKey("communication_networks.id"), nullable=False + ) + agent_id: Column[Optional[UUID]] = Column( + PostgresUUID, ForeignKey("agents.id"), nullable=True + ) + participant_type: Column[str] = Column( + Enum(ParticipantType), nullable=False, default=ParticipantType.agent + ) + name: Column[str] = Column(String(255), nullable=False) + callback_url: Column[Optional[str]] = Column(Text, nullable=True) + polling_enabled: Column[bool] = Column(Boolean, nullable=False, default=False) + capabilities: Column[Optional[dict]] = Column(JSONB, nullable=True) + status: Column[str] = Column( + Enum(ParticipantStatus), nullable=False, default=ParticipantStatus.active + ) + + # Relationships + network = relationship("CommunicationNetwork", back_populates="participants") + agent = relationship("Agent") + sent_messages = relationship( + "NetworkMessage", + back_populates="sender", + foreign_keys="NetworkMessage.sender_participant_id", + ) + received_messages = relationship( + "NetworkMessage", + back_populates="recipient", + foreign_keys="NetworkMessage.recipient_participant_id", + ) + + +class NetworkMessage(BaseModel): + """A 
message exchanged within a communication network.""" + + __tablename__: str = "network_messages" + + network_id: Column[UUID] = Column( + PostgresUUID, ForeignKey("communication_networks.id"), nullable=False + ) + sender_participant_id: Column[UUID] = Column( + PostgresUUID, ForeignKey("network_participants.id"), nullable=False + ) + recipient_participant_id: Column[Optional[UUID]] = Column( + PostgresUUID, ForeignKey("network_participants.id"), nullable=True + ) + channel_type: Column[str] = Column( + Enum(ChannelType), nullable=False + ) + content: Column[str] = Column(Text, nullable=False) + metadata_: Column[Optional[dict]] = Column("metadata", JSONB, nullable=True) + status: Column[str] = Column( + Enum(MessageStatus), nullable=False, default=MessageStatus.pending + ) + in_reply_to_id: Column[Optional[UUID]] = Column( + PostgresUUID, ForeignKey("network_messages.id"), nullable=True + ) + + # Relationships + network = relationship("CommunicationNetwork", back_populates="messages") + sender = relationship( + "NetworkParticipant", + back_populates="sent_messages", + foreign_keys=[sender_participant_id], + ) + recipient = relationship( + "NetworkParticipant", + back_populates="received_messages", + foreign_keys=[recipient_participant_id], + ) + in_reply_to = relationship("NetworkMessage", remote_side="NetworkMessage.id") diff --git a/src/network/models/schemas.py b/src/network/models/schemas.py new file mode 100644 index 0000000..5e4cf7e --- /dev/null +++ b/src/network/models/schemas.py @@ -0,0 +1,116 @@ +"""Pydantic request/response schemas for communication networks.""" + +from datetime import datetime +from typing import Any, Optional +from uuid import UUID + +from pydantic import BaseModel, ConfigDict, Field + + +# ── Network schemas ────────────────────────────────────────────────── + + +class NetworkCreate(BaseModel): + name: str = Field(..., max_length=255) + topology_type: str = Field(default="mesh", description="mesh | star | ring | custom") + metadata: Optional[dict[str, Any]] = None + + +class NetworkUpdate(BaseModel): + name: Optional[str] = Field(default=None, max_length=255) + topology_type: Optional[str] = None + status: Optional[str] = None + metadata: Optional[dict[str, Any]] = None + + +class NetworkResponse(BaseModel): + model_config = ConfigDict(from_attributes=True) + + id: UUID + owner_id: UUID + name: str + topology_type: str + metadata_: Optional[dict[str, Any]] = Field(default=None, alias="metadata_") + status: str + created_at: datetime + updated_at: datetime + + +# ── Participant schemas ────────────────────────────────────────────── + + +class ParticipantJoin(BaseModel): + agent_id: Optional[UUID] = None + participant_type: str = Field(default="agent", description="agent | persona | orchestrator") + name: str = Field(..., max_length=255) + callback_url: Optional[str] = None + polling_enabled: bool = False + capabilities: Optional[dict[str, Any]] = None + + +class ParticipantUpdate(BaseModel): + callback_url: Optional[str] = None + polling_enabled: Optional[bool] = None + capabilities: Optional[dict[str, Any]] = None + status: Optional[str] = None + + +class ParticipantResponse(BaseModel): + model_config = ConfigDict(from_attributes=True) + + id: UUID + network_id: UUID + agent_id: Optional[UUID] = None + participant_type: str + name: str + callback_url: Optional[str] = None + polling_enabled: bool + capabilities: Optional[dict[str, Any]] = None + status: str + created_at: datetime + updated_at: datetime + + +# ── Message schemas 
────────────────────────────────────────────────── + + +class NetworkMessageCreate(BaseModel): + recipient_participant_id: Optional[UUID] = None + channel_type: str = Field(..., description="call | message | mailbox") + content: str + metadata: Optional[dict[str, Any]] = None + in_reply_to_id: Optional[UUID] = None + + +class NetworkMessageResponse(BaseModel): + model_config = ConfigDict(from_attributes=True) + + id: UUID + network_id: UUID + sender_participant_id: UUID + recipient_participant_id: Optional[UUID] = None + channel_type: str + content: str + metadata_: Optional[dict[str, Any]] = Field(default=None, alias="metadata_") + status: str + in_reply_to_id: Optional[UUID] = None + created_at: datetime + updated_at: datetime + + +# ── Context snapshot ───────────────────────────────────────────────── + + +class ContextEntry(BaseModel): + sender: str + recipient: Optional[str] = None + channel: str + content: str + timestamp: datetime + + +class NetworkContextSnapshot(BaseModel): + network_id: UUID + participant_count: int + message_count: int + entries: list[ContextEntry] diff --git a/src/network/repositories/__init__.py b/src/network/repositories/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/network/repositories/networks.py b/src/network/repositories/networks.py new file mode 100644 index 0000000..522dcdf --- /dev/null +++ b/src/network/repositories/networks.py @@ -0,0 +1,182 @@ +"""Repository for communication network domain operations.""" + +from typing import Optional +from uuid import UUID + +from fastapi import Depends +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + +from src.database import get_db +from src.network.models.entities import ( + CommunicationNetwork, + NetworkMessage, + NetworkParticipant, + ParticipantStatus, +) + + +class NetworkRepository: + """CRUD for communication networks, participants, and messages.""" + + def __init__(self, session: AsyncSession = Depends(get_db)): + self.session = session + + # ── Networks ───────────────────────────────────────────────────── + + async def create_network(self, network: CommunicationNetwork) -> CommunicationNetwork: + self.session.add(network) + await self.session.commit() + await self.session.refresh(network) + return network + + async def get_network(self, network_id: UUID) -> Optional[CommunicationNetwork]: + result = await self.session.execute( + select(CommunicationNetwork) + .where(CommunicationNetwork.id == network_id) + .options(selectinload(CommunicationNetwork.participants)) + ) + return result.scalar_one_or_none() + + async def list_networks( + self, + owner_id: UUID, + limit: int = 50, + offset: int = 0, + ) -> list[CommunicationNetwork]: + result = await self.session.execute( + select(CommunicationNetwork) + .where(CommunicationNetwork.owner_id == owner_id) + .order_by(CommunicationNetwork.created_at.desc()) + .limit(limit) + .offset(offset) + ) + return list(result.scalars().all()) + + async def update_network(self, network: CommunicationNetwork) -> CommunicationNetwork: + await self.session.commit() + await self.session.refresh(network) + return network + + async def delete_network(self, network_id: UUID) -> bool: + network = await self.get_network(network_id) + if network: + await self.session.delete(network) + await self.session.commit() + return True + return False + + # ── Participants ───────────────────────────────────────────────── + + async def add_participant(self, participant: NetworkParticipant) -> 
NetworkParticipant: + self.session.add(participant) + await self.session.commit() + await self.session.refresh(participant) + return participant + + async def get_participant(self, participant_id: UUID) -> Optional[NetworkParticipant]: + result = await self.session.execute( + select(NetworkParticipant).where(NetworkParticipant.id == participant_id) + ) + return result.scalar_one_or_none() + + async def get_participant_by_agent( + self, network_id: UUID, agent_id: UUID + ) -> Optional[NetworkParticipant]: + result = await self.session.execute( + select(NetworkParticipant).where( + NetworkParticipant.network_id == network_id, + NetworkParticipant.agent_id == agent_id, + NetworkParticipant.status == ParticipantStatus.active, + ) + ) + return result.scalar_one_or_none() + + async def list_participants( + self, + network_id: UUID, + active_only: bool = True, + ) -> list[NetworkParticipant]: + q = select(NetworkParticipant).where( + NetworkParticipant.network_id == network_id + ) + if active_only: + q = q.where(NetworkParticipant.status == ParticipantStatus.active) + q = q.order_by(NetworkParticipant.created_at) + result = await self.session.execute(q) + return list(result.scalars().all()) + + async def update_participant(self, participant: NetworkParticipant) -> NetworkParticipant: + await self.session.commit() + await self.session.refresh(participant) + return participant + + async def remove_participant(self, participant: NetworkParticipant) -> NetworkParticipant: + participant.status = ParticipantStatus.removed + await self.session.commit() + await self.session.refresh(participant) + return participant + + # ── Messages ───────────────────────────────────────────────────── + + async def create_message(self, message: NetworkMessage) -> NetworkMessage: + self.session.add(message) + await self.session.commit() + await self.session.refresh(message) + return message + + async def get_message(self, message_id: UUID) -> Optional[NetworkMessage]: + result = await self.session.execute( + select(NetworkMessage).where(NetworkMessage.id == message_id) + ) + return result.scalar_one_or_none() + + async def list_messages( + self, + network_id: UUID, + limit: int = 100, + offset: int = 0, + channel_type: Optional[str] = None, + participant_id: Optional[UUID] = None, + ) -> list[NetworkMessage]: + q = ( + select(NetworkMessage) + .where(NetworkMessage.network_id == network_id) + .order_by(NetworkMessage.created_at) + ) + if channel_type: + q = q.where(NetworkMessage.channel_type == channel_type) + if participant_id: + q = q.where( + (NetworkMessage.sender_participant_id == participant_id) + | (NetworkMessage.recipient_participant_id == participant_id) + ) + q = q.limit(limit).offset(offset) + result = await self.session.execute(q) + return list(result.scalars().all()) + + async def get_context( + self, + network_id: UUID, + limit: int = 50, + ) -> list[NetworkMessage]: + """Get recent messages for building network context.""" + result = await self.session.execute( + select(NetworkMessage) + .where(NetworkMessage.network_id == network_id) + .options( + selectinload(NetworkMessage.sender), + selectinload(NetworkMessage.recipient), + ) + .order_by(NetworkMessage.created_at.desc()) + .limit(limit) + ) + messages = list(result.scalars().all()) + messages.reverse() # chronological order + return messages + + async def update_message(self, message: NetworkMessage) -> NetworkMessage: + await self.session.commit() + await self.session.refresh(message) + return message diff --git a/src/network/routes/__init__.py 
b/src/network/routes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/network/routes/callbacks.py b/src/network/routes/callbacks.py new file mode 100644 index 0000000..80d1814 --- /dev/null +++ b/src/network/routes/callbacks.py @@ -0,0 +1,62 @@ +"""Callback routes: webhook receiver for external agents to push messages back. + +This is the key endpoint that enables bidirectional communication. +When Intuno delivers a message to an external agent, the payload includes +a ``reply_url`` pointing to this endpoint. The agent can POST back to +proactively send messages into the network. +""" + +from typing import Any, Optional +from uuid import UUID + +from fastapi import APIRouter, Depends, Request +from pydantic import BaseModel, Field + +from src.network.models.schemas import NetworkMessageResponse +from src.network.services.channels import ChannelService + +router = APIRouter(prefix="/networks", tags=["Callbacks"]) + + +class CallbackPayload(BaseModel): + """Payload an external agent sends to its reply_url.""" + + content: str + recipient_participant_id: Optional[UUID] = None + channel_type: str = Field(default="message", description="message | call | mailbox") + metadata: Optional[dict[str, Any]] = None + in_reply_to_id: Optional[UUID] = None + + +@router.post( + "/{network_id}/participants/{participant_id}/callback", + response_model=NetworkMessageResponse, +) +async def receive_callback( + network_id: UUID, + participant_id: UUID, + data: CallbackPayload, + request: Request, + service: ChannelService = Depends(), +) -> NetworkMessageResponse: + """Receive a proactive message from an external agent. + + No authentication required — the reply_url itself acts as a capability + token. The participant_id in the URL identifies the sender. 
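+
+    A minimal payload (illustrative values) looks like:
+
+        {
+            "content": "Result of the analysis you asked for",
+            "channel_type": "message"
+        }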
+ + The external agent can: + - Reply to a specific message (in_reply_to_id) + - Target a specific recipient (recipient_participant_id) + - Broadcast to the network (omit recipient_participant_id) + - Choose a channel type (call/message/mailbox) + """ + service.set_http_client(request.app.state.http_client) + return await service.handle_callback( + network_id=network_id, + participant_id=participant_id, + content=data.content, + recipient_participant_id=data.recipient_participant_id, + channel_type=data.channel_type, + metadata=data.metadata, + in_reply_to_id=data.in_reply_to_id, + ) diff --git a/src/network/routes/channels.py b/src/network/routes/channels.py new file mode 100644 index 0000000..511dadd --- /dev/null +++ b/src/network/routes/channels.py @@ -0,0 +1,152 @@ +"""Channel routes: calls, messages, mailboxes, inbox, and acknowledgment.""" + +from typing import Any, List, Optional +from uuid import UUID + +from fastapi import APIRouter, Depends, Query, Request, status +from pydantic import BaseModel, Field + +from src.core.auth import get_current_user +from src.models.auth import User +from src.network.models.schemas import NetworkMessageResponse +from src.network.services.channels import ChannelService + +router = APIRouter(prefix="/networks", tags=["Channels"]) + + +# ── Request schemas ────────────────────────────────────────────────── + + +class CallRequest(BaseModel): + sender_participant_id: UUID + recipient_participant_id: UUID + content: str + metadata: Optional[dict[str, Any]] = None + + +class MessageRequest(BaseModel): + sender_participant_id: UUID + recipient_participant_id: UUID + content: str + metadata: Optional[dict[str, Any]] = None + + +class MailboxRequest(BaseModel): + sender_participant_id: UUID + recipient_participant_id: UUID + content: str + metadata: Optional[dict[str, Any]] = None + + +class AckRequest(BaseModel): + message_ids: list[UUID] + + +# ── Call ───────────────────────────────────────────────────────────── + + +@router.post("/{network_id}/call") +async def make_call( + network_id: UUID, + data: CallRequest, + request: Request, + current_user: User = Depends(get_current_user), + service: ChannelService = Depends(), +) -> dict: + """Synchronous call to another participant. Blocks until response.""" + service.set_http_client(request.app.state.http_client) + return await service.call( + network_id=network_id, + sender_participant_id=data.sender_participant_id, + recipient_participant_id=data.recipient_participant_id, + content=data.content, + metadata=data.metadata, + ) + + +# ── Message ────────────────────────────────────────────────────────── + + +@router.post( + "/{network_id}/messages/send", + response_model=NetworkMessageResponse, + status_code=status.HTTP_201_CREATED, +) +async def send_message( + network_id: UUID, + data: MessageRequest, + request: Request, + current_user: User = Depends(get_current_user), + service: ChannelService = Depends(), +) -> NetworkMessageResponse: + """Send a near-real-time message. 
Non-blocking.""" + service.set_http_client(request.app.state.http_client) + return await service.send_message( + network_id=network_id, + sender_participant_id=data.sender_participant_id, + recipient_participant_id=data.recipient_participant_id, + content=data.content, + metadata=data.metadata, + ) + + +# ── Mailbox ────────────────────────────────────────────────────────── + + +@router.post( + "/{network_id}/mailbox", + response_model=NetworkMessageResponse, + status_code=status.HTTP_201_CREATED, +) +async def send_to_mailbox( + network_id: UUID, + data: MailboxRequest, + current_user: User = Depends(get_current_user), + service: ChannelService = Depends(), +) -> NetworkMessageResponse: + """Send to mailbox. Fully async — no push delivery.""" + return await service.send_to_mailbox( + network_id=network_id, + sender_participant_id=data.sender_participant_id, + recipient_participant_id=data.recipient_participant_id, + content=data.content, + metadata=data.metadata, + ) + + +# ── Inbox ──────────────────────────────────────────────────────────── + + +@router.get("/{network_id}/inbox/{participant_id}") +async def get_inbox( + network_id: UUID, + participant_id: UUID, + current_user: User = Depends(get_current_user), + channel_type: Optional[str] = Query(default=None), + limit: int = Query(default=50, ge=1, le=200), + service: ChannelService = Depends(), +) -> List[NetworkMessageResponse]: + """Poll inbox for a participant.""" + channel_types = [channel_type] if channel_type else None + messages = await service.get_inbox( + network_id=network_id, + participant_id=participant_id, + channel_types=channel_types, + limit=limit, + ) + return messages + + +# ── Acknowledge ────────────────────────────────────────────────────── + + +@router.post("/{network_id}/messages/ack") +async def acknowledge_messages( + network_id: UUID, + data: AckRequest, + current_user: User = Depends(get_current_user), + service: ChannelService = Depends(), +) -> dict: + """Mark messages as read.""" + count = await service.acknowledge(network_id, data.message_ids) + return {"acknowledged": count} diff --git a/src/network/routes/networks.py b/src/network/routes/networks.py new file mode 100644 index 0000000..0b3d705 --- /dev/null +++ b/src/network/routes/networks.py @@ -0,0 +1,174 @@ +"""Network routes: CRUD for communication networks, participants, and context.""" + +from typing import List +from uuid import UUID + +from fastapi import APIRouter, Depends, Query, status + +from src.core.auth import get_current_user +from src.exceptions import NotFoundException +from src.models.auth import User +from src.network.models.schemas import ( + NetworkContextSnapshot, + NetworkCreate, + NetworkMessageResponse, + NetworkResponse, + NetworkUpdate, + ParticipantJoin, + ParticipantResponse, + ParticipantUpdate, +) +from src.network.services.networks import NetworkService + +router = APIRouter(prefix="/networks", tags=["Networks"]) + + +# ── Networks ───────────────────────────────────────────────────────── + + +@router.post("", response_model=NetworkResponse, status_code=status.HTTP_201_CREATED) +async def create_network( + data: NetworkCreate, + current_user: User = Depends(get_current_user), + service: NetworkService = Depends(), +) -> NetworkResponse: + network = await service.create_network(current_user.id, data) + return network + + +@router.get("", response_model=List[NetworkResponse]) +async def list_networks( + current_user: User = Depends(get_current_user), + limit: int = Query(default=50, ge=1, le=200), + offset: int = 
Query(default=0, ge=0), + service: NetworkService = Depends(), +) -> List[NetworkResponse]: + return await service.list_networks(current_user.id, limit, offset) + + +@router.get("/{network_id}", response_model=NetworkResponse) +async def get_network( + network_id: UUID, + current_user: User = Depends(get_current_user), + service: NetworkService = Depends(), +) -> NetworkResponse: + return await service.get_network(network_id, current_user.id) + + +@router.patch("/{network_id}", response_model=NetworkResponse) +async def update_network( + network_id: UUID, + data: NetworkUpdate, + current_user: User = Depends(get_current_user), + service: NetworkService = Depends(), +) -> NetworkResponse: + return await service.update_network(network_id, current_user.id, data) + + +@router.delete("/{network_id}", status_code=status.HTTP_204_NO_CONTENT) +async def delete_network( + network_id: UUID, + current_user: User = Depends(get_current_user), + service: NetworkService = Depends(), +) -> None: + success = await service.delete_network(network_id, current_user.id) + if not success: + raise NotFoundException("Network") + + +# ── Participants ───────────────────────────────────────────────────── + + +@router.post( + "/{network_id}/participants", + response_model=ParticipantResponse, + status_code=status.HTTP_201_CREATED, +) +async def join_network( + network_id: UUID, + data: ParticipantJoin, + current_user: User = Depends(get_current_user), + service: NetworkService = Depends(), +) -> ParticipantResponse: + return await service.join_network(network_id, current_user.id, data) + + +@router.get( + "/{network_id}/participants", + response_model=List[ParticipantResponse], +) +async def list_participants( + network_id: UUID, + current_user: User = Depends(get_current_user), + service: NetworkService = Depends(), +) -> List[ParticipantResponse]: + return await service.list_participants(network_id, current_user.id) + + +@router.patch( + "/{network_id}/participants/{participant_id}", + response_model=ParticipantResponse, +) +async def update_participant( + network_id: UUID, + participant_id: UUID, + data: ParticipantUpdate, + current_user: User = Depends(get_current_user), + service: NetworkService = Depends(), +) -> ParticipantResponse: + return await service.update_participant( + network_id, participant_id, current_user.id, data + ) + + +@router.delete( + "/{network_id}/participants/{participant_id}", + status_code=status.HTTP_204_NO_CONTENT, +) +async def leave_network( + network_id: UUID, + participant_id: UUID, + current_user: User = Depends(get_current_user), + service: NetworkService = Depends(), +) -> None: + success = await service.leave_network(network_id, participant_id, current_user.id) + if not success: + raise NotFoundException("Participant") + + +# ── Context ────────────────────────────────────────────────────────── + + +@router.get("/{network_id}/context") +async def get_network_context( + network_id: UUID, + current_user: User = Depends(get_current_user), + limit: int = Query(default=50, ge=1, le=200), + service: NetworkService = Depends(), +) -> dict: + entries = await service.get_context(network_id, current_user.id, limit) + return { + "network_id": str(network_id), + "entries": entries, + } + + +# ── Messages ───────────────────────────────────────────────────────── + + +@router.get( + "/{network_id}/messages", + response_model=List[NetworkMessageResponse], +) +async def list_messages( + network_id: UUID, + current_user: User = Depends(get_current_user), + limit: int = Query(default=100, ge=1, 
le=500), + offset: int = Query(default=0, ge=0), + channel_type: str | None = Query(default=None), + participant_id: UUID | None = Query(default=None), + service: NetworkService = Depends(), +) -> List[NetworkMessageResponse]: + return await service.list_messages( + network_id, current_user.id, limit, offset, channel_type, participant_id + ) diff --git a/src/network/services/__init__.py b/src/network/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/network/services/channels.py b/src/network/services/channels.py new file mode 100644 index 0000000..09a4025 --- /dev/null +++ b/src/network/services/channels.py @@ -0,0 +1,460 @@ +"""Channel service — calls, messages, and mailboxes. + +Implements the three communication primitives with different timing +semantics. Each interaction is recorded in the network context and +delivered to the recipient via the appropriate mechanism. +""" + +import json +import logging +import time +from typing import Any, Optional +from uuid import UUID + +import httpx +from fastapi import Depends + +from src.core.settings import settings +from src.exceptions import BadRequestException, NotFoundException +from src.network.models.entities import ( + ChannelType, + CommunicationNetwork, + MessageStatus, + NetworkMessage, + NetworkParticipant, + NetworkStatus, + ParticipantStatus, +) +from src.network.models.schemas import NetworkMessageCreate +from src.network.repositories.networks import NetworkRepository +from src.network.utils.context_manager import NetworkContextManager + +logger = logging.getLogger(__name__) + + +class ChannelService: + """Unified service for calls, messages, and mailboxes.""" + + def __init__( + self, + repo: NetworkRepository = Depends(), + context_manager: NetworkContextManager = Depends(), + ): + self.repo = repo + self.ctx = context_manager + self._http_client: Optional[httpx.AsyncClient] = None + + def set_http_client(self, client: httpx.AsyncClient) -> None: + self._http_client = client + + # ── Calls (synchronous, blocking) ──────────────────────────────── + + async def call( + self, + network_id: UUID, + sender_participant_id: UUID, + recipient_participant_id: UUID, + content: str, + metadata: Optional[dict[str, Any]] = None, + ) -> dict[str, Any]: + """Synchronous call: send payload to recipient, wait for response. + + Returns a dict with the call result including the recipient's response. 
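+
+        A successful call returns a dict of this shape (values illustrative):
+
+            {
+                "success": True,
+                "message_id": "<uuid of the recorded outgoing message>",
+                "response": {...},  # whatever JSON the recipient's callback returned
+            }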
+ """ + sender, recipient, network = await self._validate_communication( + network_id, sender_participant_id, recipient_participant_id + ) + + if not recipient.callback_url: + raise BadRequestException( + "Recipient has no callback_url; cannot make a synchronous call" + ) + + # Record outgoing message + outgoing = await self._record_message( + network_id=network_id, + sender=sender, + recipient=recipient, + channel_type=ChannelType.call, + content=content, + metadata=metadata, + ) + + # Build context window for the call + context_entries = await self.ctx.get_context_window(network_id, limit=30) + participants = await self.repo.list_participants(network_id) + + # Build the payload with reply_url + payload = self._build_delivery_payload( + network_id=network_id, + sender=sender, + recipient=recipient, + channel="call", + content=content, + context=context_entries, + participants=participants, + message_id=outgoing.id, + ) + + # Synchronous HTTP call + response_data = await self._deliver_http( + recipient.callback_url, + payload, + timeout=settings.NETWORK_CALLBACK_TIMEOUT_SECONDS, + ) + + # Record the response as a message from recipient back to sender + response_content = ( + json.dumps(response_data) if isinstance(response_data, dict) else str(response_data) + ) + await self._record_message( + network_id=network_id, + sender=recipient, + recipient=sender, + channel_type=ChannelType.call, + content=response_content, + metadata={"in_reply_to": str(outgoing.id)}, + ) + + # Mark outgoing as delivered + outgoing.status = MessageStatus.delivered + await self.repo.update_message(outgoing) + + return { + "success": True, + "message_id": str(outgoing.id), + "response": response_data, + } + + # ── Messages (near-real-time, non-blocking) ────────────────────── + + async def send_message( + self, + network_id: UUID, + sender_participant_id: UUID, + recipient_participant_id: UUID, + content: str, + metadata: Optional[dict[str, Any]] = None, + ) -> NetworkMessage: + """Non-blocking message: record and push via webhook. + + The sender does not block on the recipient's processing. Delivery + is best-effort with retries handled by the delivery worker. 
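+
+        If the recipient has a callback_url, a webhook payload of roughly this
+        shape is pushed immediately (built by _build_delivery_payload; values
+        are illustrative):
+
+            {
+                "network_id": "...",
+                "message_id": "...",
+                "channel": "message",
+                "sender": {"participant_id": "...", "name": "..."},
+                "content": "...",
+                "context": [...],
+                "reply_url": ".../networks/<network_id>/participants/<recipient_id>/callback",
+                "network_participants": [...]
+            }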
+ """ + sender, recipient, network = await self._validate_communication( + network_id, sender_participant_id, recipient_participant_id + ) + + message = await self._record_message( + network_id=network_id, + sender=sender, + recipient=recipient, + channel_type=ChannelType.message, + content=content, + metadata=metadata, + ) + + # Attempt immediate webhook delivery (fire-and-forget style) + if recipient.callback_url: + context_entries = await self.ctx.get_context_window(network_id, limit=20) + participants = await self.repo.list_participants(network_id) + payload = self._build_delivery_payload( + network_id=network_id, + sender=sender, + recipient=recipient, + channel="message", + content=content, + context=context_entries, + participants=participants, + message_id=message.id, + ) + try: + await self._deliver_http( + recipient.callback_url, + payload, + timeout=settings.NETWORK_CALLBACK_TIMEOUT_SECONDS, + ) + message.status = MessageStatus.delivered + except Exception: + logger.warning( + "Message delivery failed for participant %s; will retry", + recipient_participant_id, + ) + message.status = MessageStatus.pending + await self.repo.update_message(message) + + return message + + # ── Mailbox (fully asynchronous) ───────────────────────────────── + + async def send_to_mailbox( + self, + network_id: UUID, + sender_participant_id: UUID, + recipient_participant_id: UUID, + content: str, + metadata: Optional[dict[str, Any]] = None, + ) -> NetworkMessage: + """Async mailbox: store message, no push delivery.""" + sender, recipient, network = await self._validate_communication( + network_id, sender_participant_id, recipient_participant_id + ) + + return await self._record_message( + network_id=network_id, + sender=sender, + recipient=recipient, + channel_type=ChannelType.mailbox, + content=content, + metadata=metadata, + ) + + # ── Inbox (polling) ────────────────────────────────────────────── + + async def get_inbox( + self, + network_id: UUID, + participant_id: UUID, + channel_types: Optional[list[str]] = None, + limit: int = 50, + ) -> list[NetworkMessage]: + """Get unread messages for a participant.""" + messages = await self.repo.list_messages( + network_id=network_id, + limit=limit, + participant_id=participant_id, + ) + if channel_types: + messages = [m for m in messages if m.channel_type in channel_types] + return messages + + async def acknowledge( + self, network_id: UUID, message_ids: list[UUID] + ) -> int: + """Mark messages as read.""" + count = 0 + for msg_id in message_ids: + message = await self.repo.get_message(msg_id) + if message and message.network_id == network_id: + message.status = MessageStatus.read + await self.repo.update_message(message) + count += 1 + return count + + # ── Callback (external agents pushing back) ────────────────────── + + async def handle_callback( + self, + network_id: UUID, + participant_id: UUID, + content: str, + recipient_participant_id: Optional[UUID] = None, + channel_type: str = "message", + metadata: Optional[dict[str, Any]] = None, + in_reply_to_id: Optional[UUID] = None, + ) -> NetworkMessage: + """Handle a proactive message from an external agent via callback URL. + + This is the key to bidirectionality: external agents POST to their + reply_url and this method records the message in the network. 
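+
+        Rough sequence (illustrative):
+
+        1. The external agent POSTs {"content": "...", "channel_type": "message"}
+           to its reply_url.
+        2. The message is recorded and appended to the network's Redis context stream.
+        3. If a recipient with a callback_url is targeted on the message channel,
+           the message is forwarded to that recipient as a webhook.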
+ """ + # Safety check: reject if platform is in emergency halt + from src.services.safety import check_platform_halt + await check_platform_halt() + + sender = await self.repo.get_participant(participant_id) + if not sender or sender.network_id != network_id: + raise NotFoundException("Participant") + if sender.status != ParticipantStatus.active: + raise BadRequestException("Participant is not active") + + network = await self.repo.get_network(network_id) + if not network or network.status != NetworkStatus.active: + raise BadRequestException("Network is not active") + + recipient = None + if recipient_participant_id: + recipient = await self.repo.get_participant(recipient_participant_id) + if not recipient or recipient.network_id != network_id: + raise NotFoundException("Recipient participant") + + message = await self._record_message( + network_id=network_id, + sender=sender, + recipient=recipient, + channel_type=ChannelType(channel_type), + content=content, + metadata=metadata, + in_reply_to_id=in_reply_to_id, + ) + + # If there's a specific recipient with a callback_url, forward the message + if recipient and recipient.callback_url and channel_type == "message": + context_entries = await self.ctx.get_context_window(network_id, limit=20) + participants = await self.repo.list_participants(network_id) + payload = self._build_delivery_payload( + network_id=network_id, + sender=sender, + recipient=recipient, + channel=channel_type, + content=content, + context=context_entries, + participants=participants, + message_id=message.id, + ) + try: + await self._deliver_http( + recipient.callback_url, + payload, + timeout=settings.NETWORK_CALLBACK_TIMEOUT_SECONDS, + ) + message.status = MessageStatus.delivered + await self.repo.update_message(message) + except Exception: + logger.warning( + "Forwarding callback message failed for participant %s", + recipient_participant_id, + ) + + return message + + # ── Internal helpers ───────────────────────────────────────────── + + async def _validate_communication( + self, + network_id: UUID, + sender_id: UUID, + recipient_id: UUID, + ) -> tuple[NetworkParticipant, NetworkParticipant, CommunicationNetwork]: + # Safety check: reject if platform is in emergency halt + from src.services.safety import check_agent_active, check_platform_halt + await check_platform_halt() + + network = await self.repo.get_network(network_id) + if not network: + raise NotFoundException("Network") + if network.status != NetworkStatus.active: + raise BadRequestException("Network is not active") + + sender = await self.repo.get_participant(sender_id) + if not sender or sender.network_id != network_id: + raise NotFoundException("Sender participant") + if sender.status != ParticipantStatus.active: + raise BadRequestException("Sender is not active") + + recipient = await self.repo.get_participant(recipient_id) + if not recipient or recipient.network_id != network_id: + raise NotFoundException("Recipient participant") + if recipient.status != ParticipantStatus.active: + raise BadRequestException("Recipient is not active") + + # Safety check: verify linked agents are still active + if sender.agent_id: + await check_agent_active(sender.agent_id) + if recipient.agent_id: + await check_agent_active(recipient.agent_id) + + return sender, recipient, network + + async def _record_message( + self, + *, + network_id: UUID, + sender: NetworkParticipant, + recipient: Optional[NetworkParticipant], + channel_type: ChannelType, + content: str, + metadata: Optional[dict[str, Any]] = None, + 
in_reply_to_id: Optional[UUID] = None, + ) -> NetworkMessage: + message = NetworkMessage( + network_id=network_id, + sender_participant_id=sender.id, + recipient_participant_id=recipient.id if recipient else None, + channel_type=channel_type, + content=content, + metadata_=metadata, + in_reply_to_id=in_reply_to_id, + ) + message = await self.repo.create_message(message) + + await self.ctx.append( + network_id, + sender=sender.name, + recipient=recipient.name if recipient else None, + channel=channel_type.value, + content=content, + message_id=message.id, + ) + return message + + def _build_delivery_payload( + self, + *, + network_id: UUID, + sender: NetworkParticipant, + recipient: NetworkParticipant, + channel: str, + content: str, + context: list[dict], + participants: list[NetworkParticipant], + message_id: UUID, + ) -> dict[str, Any]: + """Build the standard payload delivered to external agents.""" + return { + "network_id": str(network_id), + "message_id": str(message_id), + "channel": channel, + "sender": { + "participant_id": str(sender.id), + "name": sender.name, + }, + "content": content, + "context": context, + "reply_url": ( + f"{settings.BASE_URL}/networks/{network_id}" + f"/participants/{recipient.id}/callback" + ), + "network_participants": [ + {"participant_id": str(p.id), "name": p.name} + for p in participants + if p.status == ParticipantStatus.active + ], + } + + async def _deliver_http( + self, + url: str, + payload: dict[str, Any], + timeout: int = 30, + ) -> dict[str, Any]: + """POST payload to an external agent's callback URL.""" + client = self._http_client + owns_client = client is None + if owns_client: + client = httpx.AsyncClient(timeout=timeout) + + try: + response = await client.post( + url, + json=payload, + headers={ + "Content-Type": "application/json", + "User-Agent": "Intuno-Network/1.0", + }, + timeout=timeout, + ) + if response.status_code == 200: + try: + return response.json() + except Exception: + return {"raw_response": response.text} + else: + raise httpx.HTTPStatusError( + f"Callback returned {response.status_code}", + request=response.request, + response=response, + ) + finally: + if owns_client: + await client.aclose() diff --git a/src/network/services/networks.py b/src/network/services/networks.py new file mode 100644 index 0000000..2cf4321 --- /dev/null +++ b/src/network/services/networks.py @@ -0,0 +1,219 @@ +"""Communication network service — business logic for networks and participants.""" + +from typing import Optional +from uuid import UUID + +from fastapi import Depends + +from src.exceptions import BadRequestException, NotFoundException +from src.network.models.entities import ( + ChannelType, + CommunicationNetwork, + NetworkMessage, + NetworkParticipant, + NetworkStatus, + ParticipantStatus, + ParticipantType, + TopologyType, +) +from src.network.models.schemas import ( + NetworkCreate, + NetworkMessageCreate, + NetworkUpdate, + ParticipantJoin, + ParticipantUpdate, +) +from src.network.repositories.networks import NetworkRepository +from src.network.utils.context_manager import NetworkContextManager + + +class NetworkService: + """Service for communication network operations.""" + + def __init__( + self, + repo: NetworkRepository = Depends(), + context_manager: NetworkContextManager = Depends(), + ): + self.repo = repo + self.ctx = context_manager + + # ── Networks ───────────────────────────────────────────────────── + + async def create_network( + self, owner_id: UUID, data: NetworkCreate + ) -> CommunicationNetwork: + network = 
CommunicationNetwork( + owner_id=owner_id, + name=data.name, + topology_type=TopologyType(data.topology_type), + metadata_=data.metadata, + status=NetworkStatus.active, + ) + return await self.repo.create_network(network) + + async def get_network(self, network_id: UUID, owner_id: UUID) -> CommunicationNetwork: + network = await self.repo.get_network(network_id) + if not network or network.owner_id != owner_id: + raise NotFoundException("Network") + return network + + async def list_networks( + self, owner_id: UUID, limit: int = 50, offset: int = 0 + ) -> list[CommunicationNetwork]: + return await self.repo.list_networks(owner_id, limit, offset) + + async def update_network( + self, network_id: UUID, owner_id: UUID, data: NetworkUpdate + ) -> CommunicationNetwork: + network = await self.get_network(network_id, owner_id) + if data.name is not None: + network.name = data.name + if data.topology_type is not None: + network.topology_type = TopologyType(data.topology_type) + if data.status is not None: + network.status = NetworkStatus(data.status) + if data.metadata is not None: + network.metadata_ = data.metadata + return await self.repo.update_network(network) + + async def delete_network(self, network_id: UUID, owner_id: UUID) -> bool: + network = await self.repo.get_network(network_id) + if not network or network.owner_id != owner_id: + return False + await self.ctx.clear(network_id) + return await self.repo.delete_network(network_id) + + # ── Participants ───────────────────────────────────────────────── + + async def join_network( + self, network_id: UUID, owner_id: UUID, data: ParticipantJoin + ) -> NetworkParticipant: + network = await self.get_network(network_id, owner_id) + if network.status != NetworkStatus.active: + raise BadRequestException("Network is not active") + if not data.callback_url and not data.polling_enabled: + raise BadRequestException( + "Participant must have a callback_url or polling_enabled" + ) + if data.agent_id: + existing = await self.repo.get_participant_by_agent(network_id, data.agent_id) + if existing: + raise BadRequestException("Agent already in network") + participant = NetworkParticipant( + network_id=network_id, + agent_id=data.agent_id, + participant_type=ParticipantType(data.participant_type), + name=data.name, + callback_url=data.callback_url, + polling_enabled=data.polling_enabled, + capabilities=data.capabilities, + status=ParticipantStatus.active, + ) + return await self.repo.add_participant(participant) + + async def list_participants( + self, network_id: UUID, owner_id: UUID + ) -> list[NetworkParticipant]: + await self.get_network(network_id, owner_id) + return await self.repo.list_participants(network_id) + + async def update_participant( + self, + network_id: UUID, + participant_id: UUID, + owner_id: UUID, + data: ParticipantUpdate, + ) -> NetworkParticipant: + await self.get_network(network_id, owner_id) + participant = await self.repo.get_participant(participant_id) + if not participant or participant.network_id != network_id: + raise NotFoundException("Participant") + if data.callback_url is not None: + participant.callback_url = data.callback_url + if data.polling_enabled is not None: + participant.polling_enabled = data.polling_enabled + if data.capabilities is not None: + participant.capabilities = data.capabilities + if data.status is not None: + participant.status = ParticipantStatus(data.status) + return await self.repo.update_participant(participant) + + async def leave_network( + self, network_id: UUID, participant_id: UUID, 
owner_id: UUID + ) -> bool: + await self.get_network(network_id, owner_id) + participant = await self.repo.get_participant(participant_id) + if not participant or participant.network_id != network_id: + return False + await self.repo.remove_participant(participant) + return True + + # ── Context ────────────────────────────────────────────────────── + + async def get_context( + self, network_id: UUID, owner_id: UUID, limit: int = 50 + ) -> list[dict]: + await self.get_network(network_id, owner_id) + return await self.ctx.get_context_window(network_id, limit) + + async def get_context_from_db( + self, network_id: UUID, owner_id: UUID, limit: int = 50 + ) -> list[NetworkMessage]: + """Authoritative context from Postgres (slower but complete).""" + await self.get_network(network_id, owner_id) + return await self.repo.get_context(network_id, limit) + + # ── Messages (internal recording) ──────────────────────────────── + + async def record_message( + self, + network_id: UUID, + sender_participant_id: UUID, + data: NetworkMessageCreate, + ) -> NetworkMessage: + """Record a message and update the Redis context cache.""" + sender = await self.repo.get_participant(sender_participant_id) + if not sender or sender.network_id != network_id: + raise BadRequestException("Sender is not a participant in this network") + + message = NetworkMessage( + network_id=network_id, + sender_participant_id=sender_participant_id, + recipient_participant_id=data.recipient_participant_id, + channel_type=ChannelType(data.channel_type), + content=data.content, + metadata_=data.metadata, + in_reply_to_id=data.in_reply_to_id, + ) + message = await self.repo.create_message(message) + + # Update Redis context cache + recipient = None + if data.recipient_participant_id: + r = await self.repo.get_participant(data.recipient_participant_id) + recipient = r.name if r else None + + await self.ctx.append( + network_id, + sender=sender.name, + recipient=recipient, + channel=data.channel_type, + content=data.content, + message_id=message.id, + ) + return message + + async def list_messages( + self, + network_id: UUID, + owner_id: UUID, + limit: int = 100, + offset: int = 0, + channel_type: Optional[str] = None, + participant_id: Optional[UUID] = None, + ) -> list[NetworkMessage]: + await self.get_network(network_id, owner_id) + return await self.repo.list_messages( + network_id, limit, offset, channel_type, participant_id + ) diff --git a/src/network/utils/__init__.py b/src/network/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/network/utils/aggregator.py b/src/network/utils/aggregator.py new file mode 100644 index 0000000..26673c0 --- /dev/null +++ b/src/network/utils/aggregator.py @@ -0,0 +1,126 @@ +"""Fan-in aggregation strategies for combining outputs from multiple agents. + +Used by the ``aggregate`` step type in the workflow orchestrator. +""" + +from __future__ import annotations + +import json +import logging +from abc import ABC, abstractmethod +from typing import Any + +logger = logging.getLogger(__name__) + + +class Aggregator(ABC): + """Base class for aggregation strategies.""" + + @abstractmethod + async def aggregate(self, inputs: list[dict[str, Any]]) -> dict[str, Any]: + """Combine multiple agent outputs into a single result.""" + ... + + +class MergeAggregator(Aggregator): + """Concatenate all outputs into a single dict. + + Each input is keyed by its source step ID. 
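+
+    Example (illustrative):
+
+        inputs = [{"source": "step_a", "output": 1}, {"source": "step_b", "output": 2}]
+        result = {"strategy": "merge", "result": {"step_a": 1, "step_b": 2}}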
+ """ + + async def aggregate(self, inputs: list[dict[str, Any]]) -> dict[str, Any]: + merged: dict[str, Any] = {} + for item in inputs: + source = item.get("source", f"input_{len(merged)}") + merged[source] = item.get("output") + return {"strategy": "merge", "result": merged} + + +class VoteAggregator(Aggregator): + """Pick the majority answer for classification tasks. + + Each input should have an ``output`` field with a string value. + The most common value wins. + """ + + async def aggregate(self, inputs: list[dict[str, Any]]) -> dict[str, Any]: + votes: dict[str, int] = {} + for item in inputs: + output = item.get("output") + key = str(output) if output is not None else "null" + votes[key] = votes.get(key, 0) + 1 + + if not votes: + return {"strategy": "vote", "result": None, "votes": {}} + + winner = max(votes, key=votes.get) + return { + "strategy": "vote", + "result": winner, + "votes": votes, + "total": len(inputs), + } + + +class LLMSummarizeAggregator(Aggregator): + """Use an LLM to synthesize all inputs into a coherent output. + + Falls back to merge if LLM is unavailable. + """ + + async def aggregate(self, inputs: list[dict[str, Any]]) -> dict[str, Any]: + try: + from openai import AsyncOpenAI + from src.core.settings import settings + + client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY) + + formatted_inputs = "\n\n".join( + f"[{item.get('source', f'Agent {i+1}')}]:\n{json.dumps(item.get('output'), indent=2)}" + for i, item in enumerate(inputs) + ) + + response = await client.chat.completions.create( + model=settings.LLM_ENHANCEMENT_MODEL, + messages=[ + { + "role": "system", + "content": ( + "You are a synthesizer. Multiple AI agents have provided their outputs. " + "Combine them into a single coherent, comprehensive response. " + "Resolve contradictions, merge complementary information, " + "and produce a unified result." + ), + }, + { + "role": "user", + "content": f"Synthesize these agent outputs:\n\n{formatted_inputs}", + }, + ], + temperature=0.3, + ) + + synthesis = response.choices[0].message.content + return { + "strategy": "llm_summarize", + "result": synthesis, + "source_count": len(inputs), + } + except Exception as exc: + logger.warning("LLM summarize failed, falling back to merge: %s", exc) + fallback = MergeAggregator() + result = await fallback.aggregate(inputs) + result["strategy"] = "llm_summarize_fallback" + return result + + +def create_aggregator(strategy: str) -> Aggregator: + """Factory function for aggregation strategies.""" + if strategy == "merge": + return MergeAggregator() + elif strategy == "vote": + return VoteAggregator() + elif strategy == "llm_summarize": + return LLMSummarizeAggregator() + else: + raise ValueError(f"Unknown aggregation strategy: {strategy}") diff --git a/src/network/utils/context_manager.py b/src/network/utils/context_manager.py new file mode 100644 index 0000000..8b880f3 --- /dev/null +++ b/src/network/utils/context_manager.py @@ -0,0 +1,76 @@ +"""Network-scoped context manager. + +Maintains a fast Redis cache of recent messages per network alongside +the authoritative Postgres storage. When delivering a message to an +external agent, we build a context window from Redis for low-latency. 
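+
+Each network maps to one Redis Stream keyed ``net:<network_id>:ctx``. A stored
+entry looks roughly like this (illustrative values):
+
+    {"sender": "Researcher", "recipient": "Writer", "channel": "message",
+     "content": "Draft attached", "message_id": "<uuid>", "ts": "1717171717.0"}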
+""" + +import json +import time +import uuid +from typing import Any + +import redis.asyncio as aioredis +from fastapi import Depends + +from src.core.settings import settings +from src.database import get_redis + + +class NetworkContextManager: + """Redis-backed context accumulator for communication networks.""" + + def __init__(self, redis: aioredis.Redis = Depends(get_redis)) -> None: + self._redis = redis + + def _stream_key(self, network_id: uuid.UUID) -> str: + return f"net:{network_id}:ctx" + + async def append( + self, + network_id: uuid.UUID, + *, + sender: str, + recipient: str | None, + channel: str, + content: str, + message_id: uuid.UUID | None = None, + ) -> None: + """Append a message to the network context stream.""" + entry = { + "sender": sender, + "recipient": recipient or "", + "channel": channel, + "content": content, + "message_id": str(message_id) if message_id else "", + "ts": str(time.time()), + } + key = self._stream_key(network_id) + await self._redis.xadd(key, entry, maxlen=settings.NETWORK_CONTEXT_MAX_ENTRIES) + await self._redis.expire(key, settings.NETWORK_CONTEXT_TTL_SECONDS) + + async def get_context_window( + self, + network_id: uuid.UUID, + limit: int = 50, + ) -> list[dict[str, Any]]: + """Retrieve recent context entries from Redis stream.""" + key = self._stream_key(network_id) + # Read from the end of the stream + entries = await self._redis.xrevrange(key, count=limit) + result = [] + for _stream_id, data in reversed(entries): + result.append( + { + "sender": data["sender"], + "recipient": data["recipient"] or None, + "channel": data["channel"], + "content": data["content"], + "timestamp": float(data["ts"]), + } + ) + return result + + async def clear(self, network_id: uuid.UUID) -> None: + """Delete the context stream for a network.""" + await self._redis.delete(self._stream_key(network_id)) diff --git a/src/network/utils/convergence.py b/src/network/utils/convergence.py new file mode 100644 index 0000000..e1ea647 --- /dev/null +++ b/src/network/utils/convergence.py @@ -0,0 +1,146 @@ +"""Convergence detectors for feedback loops. + +Determine when iterative agent interactions have converged and should +stop. Used by the loop step type in the workflow orchestrator. +""" + +from __future__ import annotations + +import logging +from abc import ABC, abstractmethod +from typing import Any + +logger = logging.getLogger(__name__) + + +class ConvergenceDetector(ABC): + """Base class for convergence detection strategies.""" + + @abstractmethod + async def has_converged( + self, + iteration: int, + current_output: Any, + previous_output: Any, + context: dict[str, Any], + ) -> bool: + """Return True if the loop should stop.""" + ... + + +class MaxIterationsDetector(ConvergenceDetector): + """Hard cap on iterations — always enforced as a safety net.""" + + def __init__(self, max_iterations: int = 5): + self.max_iterations = max_iterations + + async def has_converged( + self, + iteration: int, + current_output: Any, + previous_output: Any, + context: dict[str, Any], + ) -> bool: + return iteration >= self.max_iterations + + +class ApprovalDetector(ConvergenceDetector): + """Check if the output contains an explicit approval signal. + + Looks for keywords like "approved", "accepted", "lgtm" in the output + or for a structured ``{"approved": true}`` field. 
+ """ + + APPROVAL_KEYWORDS = {"approved", "accepted", "lgtm", "looks good", "ship it"} + + async def has_converged( + self, + iteration: int, + current_output: Any, + previous_output: Any, + context: dict[str, Any], + ) -> bool: + if isinstance(current_output, dict): + if current_output.get("approved") is True: + return True + text = str(current_output.get("output", "")) + str( + current_output.get("content", "") + ) + elif isinstance(current_output, str): + text = current_output + else: + return False + + text_lower = text.lower() + return any(kw in text_lower for kw in self.APPROVAL_KEYWORDS) + + +class SimilarityDetector(ConvergenceDetector): + """Compare consecutive outputs using text similarity. + + Uses a simple token overlap ratio (Jaccard similarity). For + production use, this could be upgraded to use embedding cosine + similarity via the EmbeddingService. + """ + + def __init__(self, threshold: float = 0.95): + self.threshold = threshold + + async def has_converged( + self, + iteration: int, + current_output: Any, + previous_output: Any, + context: dict[str, Any], + ) -> bool: + if previous_output is None: + return False + + current_text = self._to_text(current_output) + previous_text = self._to_text(previous_output) + + if not current_text or not previous_text: + return False + + similarity = self._jaccard_similarity(current_text, previous_text) + logger.debug( + "Similarity check: iteration=%d similarity=%.3f threshold=%.3f", + iteration, + similarity, + self.threshold, + ) + return similarity >= self.threshold + + def _to_text(self, output: Any) -> str: + if isinstance(output, str): + return output + if isinstance(output, dict): + return str(output.get("output", "")) or str(output.get("content", "")) + return str(output) + + def _jaccard_similarity(self, a: str, b: str) -> float: + tokens_a = set(a.lower().split()) + tokens_b = set(b.lower().split()) + if not tokens_a and not tokens_b: + return 1.0 + intersection = tokens_a & tokens_b + union = tokens_a | tokens_b + return len(intersection) / len(union) if union else 1.0 + + +def create_detector( + detector_type: str, + config: dict[str, Any] | None = None, +) -> ConvergenceDetector: + """Factory function for convergence detectors.""" + config = config or {} + if detector_type == "similarity": + return SimilarityDetector(threshold=config.get("threshold", 0.95)) + elif detector_type == "approval": + return ApprovalDetector() + elif detector_type == "max_iterations": + return MaxIterationsDetector( + max_iterations=config.get("max_iterations", 5) + ) + else: + raise ValueError(f"Unknown convergence detector type: {detector_type}") diff --git a/src/network/utils/delivery_worker.py b/src/network/utils/delivery_worker.py new file mode 100644 index 0000000..1e1c880 --- /dev/null +++ b/src/network/utils/delivery_worker.py @@ -0,0 +1,157 @@ +"""Background delivery worker for async network messages. + +Consumes from a Redis Stream to deliver pending messages to participants +via their callback URLs. Follows the same consumer-group pattern as +``src.workflow.utils.event_consumer.EventConsumer``. 
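+
+Deliveries are enqueued via ``DeliveryWorker.enqueue`` and, on failure or a
+non-200 response, are retried with exponential backoff (2, 4, 8, ... seconds)
+up to ``NETWORK_MESSAGE_DELIVERY_MAX_RETRIES`` times.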
+""" + +from __future__ import annotations + +import asyncio +import json +import logging +from typing import Any + +import httpx +import redis.asyncio as aioredis + +from src.core.settings import settings + +logger = logging.getLogger(__name__) + +STREAM_KEY = "intuno:network:delivery" +CONSUMER_GROUP = "network_delivery_workers" +CONSUMER_NAME = "delivery-worker-1" + + +class DeliveryWorker: + """Reads pending deliveries from Redis Stream and POSTs to callback URLs.""" + + def __init__(self, redis: aioredis.Redis) -> None: + self._redis = redis + self._task: asyncio.Task[None] | None = None + self._http_client: httpx.AsyncClient | None = None + + async def start(self, http_client: httpx.AsyncClient | None = None) -> None: + self._http_client = http_client + try: + await self._redis.xgroup_create( + STREAM_KEY, CONSUMER_GROUP, id="0", mkstream=True, + ) + except aioredis.ResponseError as exc: + if "BUSYGROUP" not in str(exc): + raise + self._task = asyncio.create_task(self._consume(), name="delivery-worker") + logger.info("Delivery worker started on stream '%s'", STREAM_KEY) + + async def stop(self) -> None: + if self._task and not self._task.done(): + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + logger.info("Delivery worker stopped") + + @staticmethod + async def enqueue( + redis: aioredis.Redis, + *, + callback_url: str, + payload: dict[str, Any], + message_id: str, + attempt: int = 0, + ) -> None: + """Enqueue a delivery task into the Redis Stream.""" + await redis.xadd( + STREAM_KEY, + { + "callback_url": callback_url, + "payload": json.dumps(payload, default=str), + "message_id": message_id, + "attempt": str(attempt), + }, + ) + + async def _consume(self) -> None: + while True: + try: + entries = await self._redis.xreadgroup( + CONSUMER_GROUP, + CONSUMER_NAME, + {STREAM_KEY: ">"}, + count=10, + block=5000, + ) + if not entries: + continue + for _stream, messages in entries: + for msg_id, data in messages: + await self._deliver(msg_id, data) + except asyncio.CancelledError: + raise + except Exception: + logger.exception("Delivery worker error; retrying in 2s") + await asyncio.sleep(2) + + async def _deliver(self, stream_id: str, data: dict[str, str]) -> None: + callback_url = data.get("callback_url", "") + payload_raw = data.get("payload", "{}") + message_id = data.get("message_id", "") + attempt = int(data.get("attempt", "0")) + + try: + payload = json.loads(payload_raw) + except json.JSONDecodeError: + payload = {"raw": payload_raw} + + client = self._http_client or httpx.AsyncClient( + timeout=settings.NETWORK_CALLBACK_TIMEOUT_SECONDS + ) + owns_client = self._http_client is None + + try: + response = await client.post( + callback_url, + json=payload, + headers={ + "Content-Type": "application/json", + "User-Agent": "Intuno-Network/1.0", + }, + timeout=settings.NETWORK_CALLBACK_TIMEOUT_SECONDS, + ) + if response.status_code == 200: + logger.debug("Delivered message %s to %s", message_id, callback_url) + else: + logger.warning( + "Delivery to %s returned %d for message %s", + callback_url, + response.status_code, + message_id, + ) + await self._maybe_retry(data, attempt) + except Exception: + logger.warning( + "Delivery to %s failed for message %s (attempt %d)", + callback_url, + message_id, + attempt, + ) + await self._maybe_retry(data, attempt) + finally: + if owns_client: + await client.aclose() + + await self._redis.xack(STREAM_KEY, CONSUMER_GROUP, stream_id) + + async def _maybe_retry(self, data: dict[str, str], attempt: int) -> None: 
+ if attempt < settings.NETWORK_MESSAGE_DELIVERY_MAX_RETRIES: + backoff = 2 ** (attempt + 1) + await asyncio.sleep(backoff) + await self.enqueue( + self._redis, + callback_url=data.get("callback_url", ""), + payload=json.loads(data.get("payload", "{}")), + message_id=data.get("message_id", ""), + attempt=attempt + 1, + ) diff --git a/src/network/utils/topology.py b/src/network/utils/topology.py new file mode 100644 index 0000000..da295bd --- /dev/null +++ b/src/network/utils/topology.py @@ -0,0 +1,104 @@ +"""Topology validation and routing rules for communication networks. + +Enforces communication constraints based on the network's topology type: +- mesh: any participant can communicate with any other +- star: only the hub (first participant) can initiate +- ring: messages flow sequentially through participants +- custom: no enforcement (topology managed externally) +""" + +from uuid import UUID + +from src.exceptions import BadRequestException +from src.network.models.entities import ( + CommunicationNetwork, + NetworkParticipant, + TopologyType, +) + + +class TopologyValidator: + """Validates whether communication is allowed given the network topology.""" + + def validate( + self, + network: CommunicationNetwork, + sender: NetworkParticipant, + recipient: NetworkParticipant, + participants: list[NetworkParticipant], + ) -> None: + """Raise BadRequestException if communication is not allowed.""" + topology = network.topology_type + if topology == TopologyType.mesh or topology == TopologyType.custom: + return # no restrictions + + if topology == TopologyType.star: + self._validate_star(sender, participants) + elif topology == TopologyType.ring: + self._validate_ring(sender, recipient, participants) + + def _validate_star( + self, + sender: NetworkParticipant, + participants: list[NetworkParticipant], + ) -> None: + """In star topology, only the hub (first participant) can initiate.""" + if not participants: + return + hub = participants[0] + if sender.id != hub.id: + raise BadRequestException( + f"Star topology: only the hub participant '{hub.name}' can initiate communication" + ) + + def _validate_ring( + self, + sender: NetworkParticipant, + recipient: NetworkParticipant, + participants: list[NetworkParticipant], + ) -> None: + """In ring topology, messages flow to the next participant in order.""" + if len(participants) < 2: + return + ids = [p.id for p in participants] + try: + sender_idx = ids.index(sender.id) + except ValueError: + raise BadRequestException("Sender is not in the participant list") + next_idx = (sender_idx + 1) % len(ids) + if recipient.id != ids[next_idx]: + expected_name = participants[next_idx].name + raise BadRequestException( + f"Ring topology: '{sender.name}' can only send to the next participant " + f"'{expected_name}', not '{recipient.name}'" + ) + + def get_reachable( + self, + network: CommunicationNetwork, + sender: NetworkParticipant, + participants: list[NetworkParticipant], + ) -> list[NetworkParticipant]: + """Return participants that the sender can communicate with.""" + topology = network.topology_type + others = [p for p in participants if p.id != sender.id] + + if topology == TopologyType.mesh or topology == TopologyType.custom: + return others + + if topology == TopologyType.star: + hub = participants[0] if participants else None + if hub and sender.id == hub.id: + return others + return [hub] if hub else [] + + if topology == TopologyType.ring: + ids = [p.id for p in participants] + try: + sender_idx = ids.index(sender.id) + except ValueError: + 
return [] + next_idx = (sender_idx + 1) % len(ids) + return [participants[next_idx]] + + return others diff --git a/src/repositories/registry.py b/src/repositories/registry.py index 03c2c31..6ae097e 100644 --- a/src/repositories/registry.py +++ b/src/repositories/registry.py @@ -39,6 +39,13 @@ async def get_agent_by_agent_id(self, agent_id: str) -> Optional[Agent]: ) return result.scalar_one_or_none() + async def find_agent_by_endpoint(self, invoke_endpoint: str) -> Optional[Agent]: + """Find an agent by its invoke endpoint URL.""" + result = await self.session.execute( + select(Agent).where(Agent.invoke_endpoint == invoke_endpoint) + ) + return result.scalar_one_or_none() + async def get_agents_by_user_id(self, user_id: UUID) -> List[Agent]: """Get all agents for a user.""" result = await self.session.execute( diff --git a/src/routes/admin.py b/src/routes/admin.py new file mode 100644 index 0000000..d73e1be --- /dev/null +++ b/src/routes/admin.py @@ -0,0 +1,182 @@ +"""Admin routes — platform governance and agent kill switch. + +All endpoints require admin privileges via the get_admin_user dependency. +""" + +from typing import Optional +from uuid import UUID + +from fastapi import APIRouter, Depends +from pydantic import BaseModel, Field +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from src.core.admin_auth import get_admin_user +from src.database import get_db +from src.exceptions import NotFoundException +from src.models.auth import User +from src.models.registry import Agent +from src.services import safety + +router = APIRouter(prefix="/admin", tags=["Admin"]) + + +# ── Request/Response schemas ──────────────────────────────────────── + + +class KillAgentRequest(BaseModel): + reason: str = Field(..., min_length=1, description="Why is this agent being disabled?") + + +class HaltPlatformRequest(BaseModel): + reason: str = Field(..., min_length=1, description="Why is the platform being halted?") + + +class AgentStatusResponse(BaseModel): + agent_id: str + agent_uuid: UUID + name: str + is_active: bool + owner_id: UUID + + +class PlatformStatusResponse(BaseModel): + halted: bool + reason: Optional[str] = None + halted_by: Optional[str] = None + redis_available: bool + disabled_agent_count: int = 0 + + +# ── Agent kill switch ─────────────────────────────────────────────── + + +@router.post("/agents/{agent_uuid}/kill") +async def kill_agent( + agent_uuid: UUID, + body: KillAgentRequest, + admin: User = Depends(get_admin_user), + session: AsyncSession = Depends(get_db), +): + """Force-disable an agent. 
Sets is_active=False and caches in Redis.""" + result = await session.execute(select(Agent).where(Agent.id == agent_uuid)) + agent = result.scalar_one_or_none() + if not agent: + raise NotFoundException("Agent") + + agent.is_active = False + await session.commit() + + # Cache kill in Redis for fast rejection + await safety.kill_agent(agent_uuid) + + return { + "success": True, + "agent_id": agent.agent_id, + "agent_uuid": str(agent_uuid), + "reason": body.reason, + "killed_by": str(admin.id), + } + + +@router.post("/agents/{agent_uuid}/reactivate") +async def reactivate_agent( + agent_uuid: UUID, + admin: User = Depends(get_admin_user), + session: AsyncSession = Depends(get_db), +): + """Re-enable a previously disabled agent.""" + result = await session.execute(select(Agent).where(Agent.id == agent_uuid)) + agent = result.scalar_one_or_none() + if not agent: + raise NotFoundException("Agent") + + agent.is_active = True + await session.commit() + + # Clear Redis kill cache + await safety.reactivate_agent(agent_uuid) + + return { + "success": True, + "agent_id": agent.agent_id, + "agent_uuid": str(agent_uuid), + "reactivated_by": str(admin.id), + } + + +@router.get("/agents/disabled", response_model=list[AgentStatusResponse]) +async def list_disabled_agents( + admin: User = Depends(get_admin_user), + session: AsyncSession = Depends(get_db), +): + """List all currently disabled agents.""" + result = await session.execute( + select(Agent).where(Agent.is_active == False).order_by(Agent.updated_at.desc()) # noqa: E712 + ) + agents = result.scalars().all() + + return [ + AgentStatusResponse( + agent_id=a.agent_id, + agent_uuid=a.id, + name=a.name, + is_active=a.is_active, + owner_id=a.user_id, + ) + for a in agents + ] + + +# ── Platform halt ─────────────────────────────────────────────────── + + +@router.post("/platform/halt") +async def halt_platform( + body: HaltPlatformRequest, + admin: User = Depends(get_admin_user), +): + """Emergency halt — suspend all agent operations platform-wide.""" + await safety.halt_platform(body.reason, admin.id) + return { + "success": True, + "halted": True, + "reason": body.reason, + "halted_by": str(admin.id), + } + + +@router.post("/platform/resume") +async def resume_platform( + admin: User = Depends(get_admin_user), +): + """Resume platform operations after an emergency halt.""" + await safety.resume_platform(admin.id) + return { + "success": True, + "halted": False, + "resumed_by": str(admin.id), + } + + +@router.get("/platform/status", response_model=PlatformStatusResponse) +async def get_platform_status( + admin: User = Depends(get_admin_user), + session: AsyncSession = Depends(get_db), +): + """Get current platform safety status.""" + status = await safety.get_platform_status() + + # Count disabled agents + result = await session.execute( + select(func.count()).select_from(Agent).where(Agent.is_active == False) # noqa: E712 + ) + disabled_count = result.scalar() or 0 + + return PlatformStatusResponse( + halted=status.get("halted", False), + reason=status.get("reason"), + halted_by=status.get("halted_by"), + redis_available=status.get("redis_available", False), + disabled_agent_count=disabled_count, + ) diff --git a/src/routes/safety.py b/src/routes/safety.py new file mode 100644 index 0000000..726ae08 --- /dev/null +++ b/src/routes/safety.py @@ -0,0 +1,204 @@ +"""Safety routes — public halt endpoint and admin halt code management. + +The halt endpoint is intentionally PUBLIC (no JWT required). The code +itself is the authentication. 
This is by design: it should be easy to +stop the platform, harder to restart it. +""" + +import secrets +from typing import Optional +from uuid import UUID + +import bcrypt +from fastapi import APIRouter, Depends +from pydantic import BaseModel, Field +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from src.core.admin_auth import get_admin_user +from src.database import get_db +from src.models.auth import User +from src.models.halt_code import HaltCode +from src.services import safety + +router = APIRouter(prefix="/safety", tags=["Safety"]) + + +# ── Schemas ───────────────────────────────────────────────────────── + + +class HaltRequest(BaseModel): + code: str = Field(..., min_length=1, description="Halt code issued to a trustee") + reason: Optional[str] = Field(default=None, description="Optional reason for halting") + + +class HaltResponse(BaseModel): + halted: bool + trustee: str + reason: Optional[str] = None + message: str + + +class CreateHaltCodeRequest(BaseModel): + trustee_name: str = Field(..., min_length=1) + trustee_email: Optional[str] = None + label: str = Field(..., min_length=1, description="Human-readable label, e.g. 'Guardian - Europe'") + is_master: bool = Field(default=False) + + +class CreateHaltCodeResponse(BaseModel): + id: UUID + label: str + trustee_name: str + is_master: bool + code: str = Field(description="The plaintext code — shown ONCE, never stored") + + +class HaltCodeListItem(BaseModel): + id: UUID + label: str + trustee_name: str + trustee_email: Optional[str] + is_master: bool + is_active: bool + + +class PlatformStatusPublic(BaseModel): + halted: bool + message: str + + +# ── Public endpoints (no auth) ────────────────────────────────────── + + +@router.get("/status", response_model=PlatformStatusPublic) +async def get_public_status(): + """Public platform status — anyone can check if the platform is halted.""" + status = await safety.get_platform_status() + halted = status.get("halted", False) + return PlatformStatusPublic( + halted=halted, + message="Platform is halted. All agent operations are suspended." if halted + else "Platform is operational.", + ) + + +@router.post("/halt", response_model=HaltResponse) +async def halt_with_code( + body: HaltRequest, + session: AsyncSession = Depends(get_db), +): + """Halt the platform using a trustee code. + + This endpoint is PUBLIC — no JWT required. The halt code is the + authentication. By design, stopping the platform should be easy. + Restarting requires admin authentication. + """ + # Find all active halt codes and check against each + result = await session.execute( + select(HaltCode).where(HaltCode.is_active == True) # noqa: E712 + ) + halt_codes = result.scalars().all() + + matched_code = None + for hc in halt_codes: + if bcrypt.checkpw(body.code.encode("utf-8"), hc.code_hash.encode("utf-8")): + matched_code = hc + break + + if not matched_code: + from src.exceptions import ForbiddenException + raise ForbiddenException("Invalid halt code") + + reason = body.reason or f"Halted by trustee: {matched_code.trustee_name}" + await safety.halt_platform(reason, matched_code.created_by) + + return HaltResponse( + halted=True, + trustee=matched_code.trustee_name, + reason=reason, + message="Platform has been halted. 
All agent operations are suspended.", + ) + + +# ── Admin endpoints (manage halt codes) ───────────────────────────── + + +@router.post("/codes", response_model=CreateHaltCodeResponse) +async def create_halt_code( + body: CreateHaltCodeRequest, + admin: User = Depends(get_admin_user), + session: AsyncSession = Depends(get_db), +): + """Create a new halt code for a trustee. The plaintext code is returned + ONCE and never stored — only its bcrypt hash is persisted.""" + # Generate a secure random code: 8 groups of 4 chars + raw_code = "-".join( + secrets.token_hex(2).upper() for _ in range(4) + ) + + code_hash = bcrypt.hashpw( + raw_code.encode("utf-8"), bcrypt.gensalt() + ).decode("utf-8") + + halt_code = HaltCode( + code_hash=code_hash, + label=body.label, + trustee_name=body.trustee_name, + trustee_email=body.trustee_email, + is_master=body.is_master, + created_by=admin.id, + ) + session.add(halt_code) + await session.commit() + await session.refresh(halt_code) + + return CreateHaltCodeResponse( + id=halt_code.id, + label=body.label, + trustee_name=body.trustee_name, + is_master=body.is_master, + code=raw_code, + ) + + +@router.get("/codes", response_model=list[HaltCodeListItem]) +async def list_halt_codes( + admin: User = Depends(get_admin_user), + session: AsyncSession = Depends(get_db), +): + """List all halt codes (without the actual codes — those are never stored).""" + result = await session.execute( + select(HaltCode).order_by(HaltCode.created_at.desc()) + ) + codes = result.scalars().all() + return [ + HaltCodeListItem( + id=c.id, + label=c.label, + trustee_name=c.trustee_name, + trustee_email=c.trustee_email, + is_master=c.is_master, + is_active=c.is_active, + ) + for c in codes + ] + + +@router.delete("/codes/{code_id}") +async def revoke_halt_code( + code_id: UUID, + admin: User = Depends(get_admin_user), + session: AsyncSession = Depends(get_db), +): + """Revoke a halt code — it can no longer be used to halt the platform.""" + result = await session.execute(select(HaltCode).where(HaltCode.id == code_id)) + halt_code = result.scalar_one_or_none() + if not halt_code: + from src.exceptions import NotFoundException + raise NotFoundException("Halt code") + + halt_code.is_active = False + await session.commit() + + return {"success": True, "revoked": str(code_id)} diff --git a/src/schemas/registry.py b/src/schemas/registry.py index d566e24..9b5d1c6 100644 --- a/src/schemas/registry.py +++ b/src/schemas/registry.py @@ -115,6 +115,7 @@ class AgentUpdate(BaseModel): base_price: Optional[float] = None pricing_enabled: Optional[bool] = None supports_streaming: Optional[bool] = None + is_active: Optional[bool] = None @field_validator("auth_type") @classmethod diff --git a/src/services/broker.py b/src/services/broker.py index 2f2e634..de8b81c 100644 --- a/src/services/broker.py +++ b/src/services/broker.py @@ -88,6 +88,10 @@ async def invoke_agent( """ start_time = time.time() + # Safety check: reject if platform is in emergency halt + from src.services.safety import check_platform_halt + await check_platform_halt() + # Resolve conversation_id and message_id from request if not passed conv_id = conversation_id or invoke_request.conversation_id msg_id = message_id or invoke_request.message_id @@ -532,13 +536,25 @@ async def invoke_agent_stream( async generator of dicts. Otherwise, falls back to a normal invocation and returns an InvokeResponse. 
""" + # Safety check: reject if platform is in emergency halt + from src.services.safety import check_platform_halt + await check_platform_halt() + # Resolve agent to check streaming support agent = await self.registry_repository.get_agent_by_agent_id( invoke_request.agent_id ) - supports_streaming = ( - getattr(agent, "supports_streaming", False) if agent else False - ) + + # Check agent exists and is active (fixes gap where streaming path skipped this) + if not agent or not agent.is_active: + return InvokeResponse( + success=False, + error="Agent not found or inactive", + latency_ms=0, + status_code=404, + ) + + supports_streaming = getattr(agent, "supports_streaming", False) if not supports_streaming: # Fall back to normal invocation diff --git a/src/services/registry.py b/src/services/registry.py index ace1c91..4686a27 100644 --- a/src/services/registry.py +++ b/src/services/registry.py @@ -271,6 +271,8 @@ async def update_agent( agent.base_price = update.base_price if update.pricing_enabled is not None: agent.pricing_enabled = update.pricing_enabled + if update.is_active is not None: + agent.is_active = update.is_active # Regenerate embedding if enhance: diff --git a/src/services/safety.py b/src/services/safety.py new file mode 100644 index 0000000..4f87576 --- /dev/null +++ b/src/services/safety.py @@ -0,0 +1,145 @@ +"""Safety service — platform halt, agent kill switch, and safety checks. + +Provides the central "off switch" for the Intuno platform. All communication +chokepoints (broker, channels, A2A) call into this service before processing. +""" + +import logging +from typing import Optional +from uuid import UUID + +from src.core.redis_client import get_redis +from src.core.settings import settings +from src.exceptions import AgentDisabledException, PlatformHaltedException + +logger = logging.getLogger(__name__) + +# Redis key constants +EMERGENCY_HALT_KEY = "platform:emergency_halt" +EMERGENCY_HALT_REASON_KEY = "platform:emergency_halt:reason" +EMERGENCY_HALT_ACTOR_KEY = "platform:emergency_halt:actor" +AGENT_STATUS_PREFIX = "agent:status:" + + +async def check_platform_halt() -> None: + """Raise PlatformHaltedException if the platform is in emergency halt. + + This is designed to be called at every communication chokepoint. + Fast O(1) Redis GET — adds ~0.1ms overhead per call. + Fails open if Redis is unavailable (consistent with rate limiter pattern). + """ + if not settings.SAFETY_CHECK_ENABLED: + return + + redis = await get_redis() + if not redis: + return # Fail open: if Redis is down, allow requests + + try: + halted = await redis.get(EMERGENCY_HALT_KEY) + if halted == "1": + reason = await redis.get(EMERGENCY_HALT_REASON_KEY) + detail = "Platform is in emergency halt mode." + if reason: + detail += f" Reason: {reason}" + raise PlatformHaltedException(detail) + except PlatformHaltedException: + raise + except Exception as e: + logger.warning("Safety check (platform halt) failed: %s", e) + + +async def check_agent_active(agent_id: UUID) -> None: + """Raise AgentDisabledException if the agent has been killed/deactivated. + + Checks Redis cache first, falls back to no-op if unavailable. + The authoritative is_active check remains in the broker/service layer + via the DB — this adds a fast-path rejection for killed agents. 
+ """ + if not settings.SAFETY_CHECK_ENABLED: + return + + redis = await get_redis() + if not redis: + return # Fail open + + try: + key = f"{AGENT_STATUS_PREFIX}{agent_id}" + cached = await redis.get(key) + if cached == "0": + raise AgentDisabledException() + except AgentDisabledException: + raise + except Exception as e: + logger.warning("Safety check (agent status) failed: %s", e) + + +async def halt_platform(reason: str, actor_id: UUID) -> None: + """Activate emergency halt — all agent operations will be rejected.""" + redis = await get_redis() + if not redis: + raise RuntimeError("Redis is required for platform halt") + + await redis.set(EMERGENCY_HALT_KEY, "1") + await redis.set(EMERGENCY_HALT_REASON_KEY, reason) + await redis.set(EMERGENCY_HALT_ACTOR_KEY, str(actor_id)) + logger.critical( + "PLATFORM HALT activated by user %s. Reason: %s", + actor_id, + reason, + ) + + +async def resume_platform(actor_id: UUID) -> None: + """Deactivate emergency halt — resume normal operations.""" + redis = await get_redis() + if not redis: + raise RuntimeError("Redis is required for platform resume") + + await redis.delete(EMERGENCY_HALT_KEY) + await redis.delete(EMERGENCY_HALT_REASON_KEY) + await redis.delete(EMERGENCY_HALT_ACTOR_KEY) + logger.critical("PLATFORM HALT deactivated by user %s", actor_id) + + +async def kill_agent(agent_id: UUID) -> None: + """Cache agent as killed in Redis for fast rejection at chokepoints.""" + redis = await get_redis() + if not redis: + return + + key = f"{AGENT_STATUS_PREFIX}{agent_id}" + await redis.set(key, "0", ex=settings.AGENT_STATUS_CACHE_TTL) + logger.warning("Agent %s killed (cached in Redis)", agent_id) + + +async def reactivate_agent(agent_id: UUID) -> None: + """Remove killed status from Redis cache.""" + redis = await get_redis() + if not redis: + return + + key = f"{AGENT_STATUS_PREFIX}{agent_id}" + await redis.delete(key) + logger.info("Agent %s reactivated (Redis cache cleared)", agent_id) + + +async def get_platform_status() -> dict: + """Get current platform safety status.""" + redis = await get_redis() + if not redis: + return {"halted": False, "redis_available": False} + + try: + halted = await redis.get(EMERGENCY_HALT_KEY) + reason = await redis.get(EMERGENCY_HALT_REASON_KEY) + actor = await redis.get(EMERGENCY_HALT_ACTOR_KEY) + return { + "halted": halted == "1", + "reason": reason, + "halted_by": actor, + "redis_available": True, + } + except Exception as e: + logger.warning("Failed to get platform status: %s", e) + return {"halted": False, "redis_available": False, "error": str(e)} diff --git a/src/workflow/models/dsl.py b/src/workflow/models/dsl.py index 6bc01b2..e71f8d1 100644 --- a/src/workflow/models/dsl.py +++ b/src/workflow/models/dsl.py @@ -53,11 +53,49 @@ class StepConditionBranch(BaseModel): goto: str +class ConvergenceConfig(BaseModel): + """Configuration for loop convergence detection.""" + + type: str = Field( + default="max_iterations", + description="Convergence strategy: 'similarity', 'approval', or 'max_iterations'.", + ) + threshold: float | None = Field( + default=None, + description="Threshold for similarity-based convergence (0.0 to 1.0).", + ) + + +class LoopConfig(BaseModel): + """Configuration for loop (feedback cycle) steps.""" + + max_iterations: int = Field(default=5, ge=1, le=50) + convergence: ConvergenceConfig = Field(default_factory=ConvergenceConfig) + body: list["WorkflowStep"] = Field( + ..., description="Steps to execute in each iteration of the loop." 
+ ) + + +class AggregateConfig(BaseModel): + """Configuration for fan-in aggregation steps.""" + + sources: list[str] = Field( + ..., description="Step IDs whose outputs to aggregate." + ) + strategy: str = Field( + default="merge", + description="Aggregation strategy: 'merge', 'vote', or 'llm_summarize'.", + ) + timeout_seconds: int | None = Field( + default=None, description="Max wait time for all sources to complete." + ) + + class WorkflowStep(BaseModel): id: str type: str | None = Field( default=None, - description="Explicit step type: 'condition', 'sub_workflow', or 'plan'. " + description="Explicit step type: 'condition', 'sub_workflow', 'plan', 'loop', or 'aggregate'. " "Inferred as 'agent' or 'skill' from agent/skill fields when omitted.", ) agent: str | None = None @@ -75,11 +113,23 @@ class WorkflowStep(BaseModel): parallel_with: str | None = None recovery: RecoveryConfig | None = None when: list[StepConditionBranch] | None = None + loop: LoopConfig | None = Field( + default=None, + description="Loop configuration for feedback cycle steps (type='loop').", + ) + aggregate: AggregateConfig | None = Field( + default=None, + description="Aggregation configuration for fan-in steps (type='aggregate').", + ) @property def resolved_type(self) -> str: if self.type: return self.type + if self.loop is not None: + return "loop" + if self.aggregate is not None: + return "aggregate" if self.when is not None: return "condition" if self.workflow is not None: diff --git a/src/workflow/utils/dsl_parser.py b/src/workflow/utils/dsl_parser.py index bc1d19c..286b4ef 100644 --- a/src/workflow/utils/dsl_parser.py +++ b/src/workflow/utils/dsl_parser.py @@ -106,6 +106,25 @@ def _validate_step_refs(wf: WorkflowDef) -> None: raise DSLParseError( f"Plan step '{step.id}' must have a 'goal' field" ) + if step.resolved_type == "loop" and not step.loop: + raise DSLParseError( + f"Loop step '{step.id}' must have a 'loop' configuration" + ) + if step.resolved_type == "loop" and step.loop: + if not step.loop.body: + raise DSLParseError( + f"Loop step '{step.id}' must have at least one step in its body" + ) + if step.resolved_type == "aggregate" and not step.aggregate: + raise DSLParseError( + f"Aggregate step '{step.id}' must have an 'aggregate' configuration" + ) + if step.resolved_type == "aggregate" and step.aggregate: + for source_id in step.aggregate.sources: + if source_id not in step_ids: + raise DSLParseError( + f"Aggregate step '{step.id}' references unknown source '{source_id}'" + ) def _detect_cycles(wf: WorkflowDef) -> None: diff --git a/src/workflow/utils/orchestrator.py b/src/workflow/utils/orchestrator.py index 6c6d426..298709b 100644 --- a/src/workflow/utils/orchestrator.py +++ b/src/workflow/utils/orchestrator.py @@ -209,6 +209,21 @@ async def _execute_step( ) return + if step.resolved_type == "loop": + await self._handle_loop( + step, entry_id, context_id, trigger_data, + engine, step_outputs, default_recovery, step_map, + skipped, execution_id, workflow_def, t0, + ) + return + + if step.resolved_type == "aggregate": + await self._handle_aggregate( + step, entry_id, context_id, + step_outputs, execution_id, t0, + ) + return + if step.resolved_type == "plan": await self._handle_plan( step, entry_id, context_id, trigger_data, @@ -468,3 +483,172 @@ async def _handle_plan( workflow_def, ) ) + + # -- Loop (feedback cycle) handling ---------------------------------------- + + async def _handle_loop( + self, + step: WorkflowStep, + entry_id: uuid.UUID, + context_id: uuid.UUID, + trigger_data: dict[str, 
Any], + engine: TemplateEngine, + step_outputs: dict[str, Any], + default_recovery: RecoveryConfig, + step_map: dict[str, WorkflowStep], + skipped: set[str], + execution_id: uuid.UUID, + workflow_def: WorkflowDef, + t0: float, + ) -> None: + """Execute a feedback loop until convergence or max iterations.""" + from src.network.utils.convergence import ( + MaxIterationsDetector, + create_detector, + ) + + if not step.loop: + raise StepExecutionError(f"Loop step '{step.id}' has no loop config") + + loop_cfg = step.loop + max_iter = loop_cfg.max_iterations + + # Create convergence detector + mandatory max iterations safety net + convergence = create_detector( + loop_cfg.convergence.type, + { + "threshold": loop_cfg.convergence.threshold, + "max_iterations": max_iter, + }, + ) + safety = MaxIterationsDetector(max_iterations=max_iter) + + previous_output: Any = None + iteration_outputs: list[Any] = [] + + for iteration in range(1, max_iter + 1): + logger.info( + "Loop '%s' iteration %d/%d", step.id, iteration, max_iter, + ) + + # Run all body steps sequentially for this iteration + iter_outputs: dict[str, Any] = {} + for body_step in loop_cfg.body: + # Create a process entry for this iteration's step + iter_step_id = f"{step.id}::iter{iteration}::{body_step.id}" + entry = await self._exec_repo.create_process_entry( + execution_id=execution_id, + step_id=iter_step_id, + step_type=body_step.resolved_type, + target_name=body_step.target_ref or body_step.id, + ) + + # Inject loop iteration context + loop_context = { + "iteration": iteration, + "previous_output": previous_output, + "iteration_outputs": iteration_outputs, + } + step_outputs[f"{step.id}::loop"] = loop_context + + await self._execute_step( + body_step, + entry.id, + context_id, + trigger_data, + step_outputs, + default_recovery, + {bs.id: bs for bs in loop_cfg.body}, + set(), + execution_id, + workflow_def, + ) + + iter_outputs[body_step.id] = step_outputs.get(body_step.id) + + # The last body step's output is considered the iteration output + last_body_step = loop_cfg.body[-1] + current_output = step_outputs.get(last_body_step.id, {}).get("output") + iteration_outputs.append(current_output) + + # Check convergence + converged = await convergence.has_converged( + iteration, current_output, previous_output, + {"step_outputs": step_outputs}, + ) + hit_max = await safety.has_converged( + iteration, current_output, previous_output, {}, + ) + + previous_output = current_output + + if converged or hit_max: + reason = "converged" if converged else "max_iterations" + logger.info( + "Loop '%s' stopped after %d iterations: %s", + step.id, iteration, reason, + ) + break + + duration_ms = int((time.monotonic() - t0) * 1000) + output_data = { + "iterations": len(iteration_outputs), + "converged": converged if 'converged' in dir() else False, + "final_output": previous_output, + } + + step_outputs[step.id] = {"output": output_data} + await self._ctx.write(context_id, step.id, output_data) + await self._exec_repo.mark_process_completed( + entry_id, output=output_data, duration_ms=duration_ms, + ) + + # -- Aggregate (fan-in) handling ------------------------------------------- + + async def _handle_aggregate( + self, + step: WorkflowStep, + entry_id: uuid.UUID, + context_id: uuid.UUID, + step_outputs: dict[str, Any], + execution_id: uuid.UUID, + t0: float, + ) -> None: + """Aggregate outputs from multiple source steps.""" + from src.network.utils.aggregator import create_aggregator + + if not step.aggregate: + raise StepExecutionError( + f"Aggregate 
step '{step.id}' has no aggregate config" + ) + + agg_cfg = step.aggregate + + # Collect outputs from source steps + inputs: list[dict[str, Any]] = [] + missing_sources: list[str] = [] + for source_id in agg_cfg.sources: + source_output = step_outputs.get(source_id) + if source_output is None: + missing_sources.append(source_id) + else: + inputs.append({ + "source": source_id, + "output": source_output.get("output"), + }) + + if missing_sources: + logger.warning( + "Aggregate step '%s' is missing outputs from: %s", + step.id, missing_sources, + ) + + aggregator = create_aggregator(agg_cfg.strategy) + result = await aggregator.aggregate(inputs) + + duration_ms = int((time.monotonic() - t0) * 1000) + step_outputs[step.id] = {"output": result} + await self._ctx.write(context_id, step.id, result) + await self._exec_repo.mark_process_completed( + entry_id, output=result, duration_ms=duration_ms, + ) diff --git a/src/workflow/utils/resolver.py b/src/workflow/utils/resolver.py index 7876a98..790d482 100644 --- a/src/workflow/utils/resolver.py +++ b/src/workflow/utils/resolver.py @@ -50,6 +50,10 @@ async def resolve( if cache_key in self._cache: return self._cache[cache_key] + # Safety check: reject if platform is halted + from src.services.safety import check_platform_halt + await check_platform_halt() + if ref.startswith(SEARCH_PREFIX): query = ref[len(SEARCH_PREFIX):].strip() target = await self._discover(query, exclude_ids or []) @@ -133,6 +137,10 @@ async def _discover( for agent, _distance in results: if agent.agent_id in cb_excluded: continue + if not agent.is_active: + logger.info("Skipping agent '%s' — inactive", agent.agent_id) + cb_excluded.append(agent.agent_id) + continue available = await self._circuit_breaker.is_available(agent.agent_id) if not available: logger.info( diff --git a/tests/conftest.py b/tests/conftest.py index 8cbac7a..c699b35 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,6 +11,7 @@ def pytest_collection_modifyitems(config, items): "test_economy.py", "test_workflow.py", "test_new_orchestrator.py", + "test_networks.py", } for item in items: if item.path and item.path.name in integration_files: diff --git a/tests/test_networks.py b/tests/test_networks.py new file mode 100644 index 0000000..b4c0380 --- /dev/null +++ b/tests/test_networks.py @@ -0,0 +1,528 @@ +"""Integration tests for communication networks and multi-directional orchestration. + +These tests exercise the full network lifecycle: creating networks, +adding participants, exchanging messages and mailbox items, verifying +shared context, and importing A2A agents. + +Requires a running backend at BASE_URL (default: http://localhost:8000). +Mark: integration (auto-applied by conftest.py). + +Run: + pytest tests/test_networks.py -v + pytest tests/test_networks.py -v -k "test_network_lifecycle" +""" + +import asyncio +import os +import uuid + +import httpx +import pytest + +BASE_URL = os.getenv("TEST_BASE_URL", "http://localhost:8000") +TIMEOUT = 15 + + +# ── Helpers ────────────────────────────────────────────────────────── + + +async def register_and_login(client: httpx.AsyncClient) -> str: + """Register a test user and return a JWT token.""" + email = f"test-net-{uuid.uuid4().hex[:8]}@test.local" + password = "TestPass123!" 
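+    # Registration is best-effort here (its response is not asserted);
+    # the login below is what verifies the account and yields the token.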
+ + await client.post( + f"{BASE_URL}/auth/register", + json={ + "email": email, + "password": password, + "first_name": "Net", + "last_name": "Test", + }, + ) + resp = await client.post( + f"{BASE_URL}/auth/login", + json={"email": email, "password": password}, + ) + assert resp.status_code == 200, f"Login failed: {resp.text}" + return resp.json()["access_token"] + + +def auth(token: str) -> dict: + return {"Authorization": f"Bearer {token}"} + + +async def register_test_agent( + client: httpx.AsyncClient, token: str, name: str +) -> dict: + """Register a simple test agent and return its data.""" + resp = await client.post( + f"{BASE_URL}/registry/agents", + headers=auth(token), + json={ + "name": name, + "description": f"Test agent for network integration: {name}", + "endpoint": f"https://httpbin.org/post", + "tags": ["test", "network"], + }, + ) + assert resp.status_code in (200, 201), f"Agent registration failed: {resp.text}" + return resp.json() + + +# ── Fixtures ───────────────────────────────────────────────────────── + + +@pytest.fixture +async def client(): + async with httpx.AsyncClient(timeout=TIMEOUT) as c: + yield c + + +@pytest.fixture +async def authed(client): + """Return (client, token) tuple.""" + token = await register_and_login(client) + return client, token + + +# ── Network Lifecycle ──────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_network_lifecycle(authed): + """Create a network, add participants, exchange messages, verify context.""" + client, token = authed + headers = auth(token) + + # 1. Create network + resp = await client.post( + f"{BASE_URL}/networks", + headers=headers, + json={"name": "Test Mesh Network", "topology_type": "mesh"}, + ) + assert resp.status_code == 201, f"Create network failed: {resp.text}" + network = resp.json() + network_id = network["id"] + assert network["name"] == "Test Mesh Network" + assert network["topology_type"] == "mesh" + assert network["status"] == "active" + + # 2. List networks + resp = await client.get(f"{BASE_URL}/networks", headers=headers) + assert resp.status_code == 200 + networks = resp.json() + assert any(n["id"] == network_id for n in networks) + + # 3. Add participants (with polling since we don't have real callback URLs) + resp = await client.post( + f"{BASE_URL}/networks/{network_id}/participants", + headers=headers, + json={ + "name": "Alice", + "participant_type": "persona", + "polling_enabled": True, + }, + ) + assert resp.status_code == 201, f"Add participant failed: {resp.text}" + alice = resp.json() + alice_id = alice["id"] + + resp = await client.post( + f"{BASE_URL}/networks/{network_id}/participants", + headers=headers, + json={ + "name": "Bob", + "participant_type": "persona", + "polling_enabled": True, + }, + ) + assert resp.status_code == 201 + bob = resp.json() + bob_id = bob["id"] + + # 4. List participants + resp = await client.get( + f"{BASE_URL}/networks/{network_id}/participants", headers=headers + ) + assert resp.status_code == 200 + participants = resp.json() + assert len(participants) == 2 + names = {p["name"] for p in participants} + assert names == {"Alice", "Bob"} + + # 5. 
Send a message (Alice → Bob) + resp = await client.post( + f"{BASE_URL}/networks/{network_id}/messages/send", + headers=headers, + json={ + "sender_participant_id": alice_id, + "recipient_participant_id": bob_id, + "content": "Hey Bob, what do you think about the proposal?", + }, + ) + assert resp.status_code == 201, f"Send message failed: {resp.text}" + msg1 = resp.json() + assert msg1["channel_type"] == "message" + assert msg1["sender_participant_id"] == alice_id + + # 6. Send to mailbox (Bob → Alice) + resp = await client.post( + f"{BASE_URL}/networks/{network_id}/mailbox", + headers=headers, + json={ + "sender_participant_id": bob_id, + "recipient_participant_id": alice_id, + "content": "I'll review it tonight and get back to you.", + }, + ) + assert resp.status_code == 201 + msg2 = resp.json() + assert msg2["channel_type"] == "mailbox" + + # 7. Check inbox (Alice should see Bob's mailbox message) + resp = await client.get( + f"{BASE_URL}/networks/{network_id}/inbox/{alice_id}", + headers=headers, + ) + assert resp.status_code == 200 + inbox = resp.json() + assert len(inbox) >= 1 + + # 8. Check shared context + resp = await client.get( + f"{BASE_URL}/networks/{network_id}/context", + headers=headers, + ) + assert resp.status_code == 200 + context = resp.json() + assert len(context["entries"]) >= 2 + senders = {e["sender"] for e in context["entries"]} + assert "Alice" in senders + assert "Bob" in senders + + # 9. List messages + resp = await client.get( + f"{BASE_URL}/networks/{network_id}/messages", + headers=headers, + ) + assert resp.status_code == 200 + messages = resp.json() + assert len(messages) >= 2 + + # 10. Acknowledge messages + message_ids = [m["id"] for m in messages[:1]] + resp = await client.post( + f"{BASE_URL}/networks/{network_id}/messages/ack", + headers=headers, + json={"message_ids": message_ids}, + ) + assert resp.status_code == 200 + assert resp.json()["acknowledged"] == 1 + + # 11. Update network + resp = await client.patch( + f"{BASE_URL}/networks/{network_id}", + headers=headers, + json={"name": "Updated Mesh Network"}, + ) + assert resp.status_code == 200 + assert resp.json()["name"] == "Updated Mesh Network" + + # 12. Remove participant + resp = await client.delete( + f"{BASE_URL}/networks/{network_id}/participants/{bob_id}", + headers=headers, + ) + assert resp.status_code == 204 + + # 13. Verify Bob is removed + resp = await client.get( + f"{BASE_URL}/networks/{network_id}/participants", headers=headers + ) + assert resp.status_code == 200 + assert len(resp.json()) == 1 + + # 14. 
Delete network + resp = await client.delete( + f"{BASE_URL}/networks/{network_id}", headers=headers + ) + assert resp.status_code == 204 + + +# ── Callback (Bidirectional) ───────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_callback_bidirectional(authed): + """External agent pushes a message back via the callback endpoint.""" + client, token = authed + headers = auth(token) + + # Create network + participants + resp = await client.post( + f"{BASE_URL}/networks", + headers=headers, + json={"name": "Callback Test Network"}, + ) + network_id = resp.json()["id"] + + resp = await client.post( + f"{BASE_URL}/networks/{network_id}/participants", + headers=headers, + json={"name": "Agent A", "polling_enabled": True}, + ) + agent_a_id = resp.json()["id"] + + resp = await client.post( + f"{BASE_URL}/networks/{network_id}/participants", + headers=headers, + json={"name": "Agent B", "polling_enabled": True}, + ) + agent_b_id = resp.json()["id"] + + # Agent A sends a message + resp = await client.post( + f"{BASE_URL}/networks/{network_id}/messages/send", + headers=headers, + json={ + "sender_participant_id": agent_a_id, + "recipient_participant_id": agent_b_id, + "content": "Can you analyze this data?", + }, + ) + assert resp.status_code == 201 + original_msg_id = resp.json()["id"] + + # Agent B responds proactively via callback (no auth required) + resp = await client.post( + f"{BASE_URL}/networks/{network_id}/participants/{agent_b_id}/callback", + json={ + "content": "Analysis complete. Found 3 anomalies.", + "recipient_participant_id": agent_a_id, + "channel_type": "message", + "in_reply_to_id": original_msg_id, + }, + ) + assert resp.status_code == 200, f"Callback failed: {resp.text}" + callback_msg = resp.json() + assert callback_msg["sender_participant_id"] == agent_b_id + + # Verify context has both messages + resp = await client.get( + f"{BASE_URL}/networks/{network_id}/context", headers=headers + ) + context = resp.json() + assert len(context["entries"]) >= 2 + contents = [e["content"] for e in context["entries"]] + assert any("analyze" in c for c in contents) + assert any("anomalies" in c for c in contents) + + # Cleanup + await client.delete(f"{BASE_URL}/networks/{network_id}", headers=headers) + + +# ── Multi-Participant Context Sharing ──────────────────────────────── + + +@pytest.mark.asyncio +async def test_multi_participant_context(authed): + """Three participants exchange messages; all see the full context.""" + client, token = authed + headers = auth(token) + + # Create network with 3 participants + resp = await client.post( + f"{BASE_URL}/networks", + headers=headers, + json={"name": "Multi-Party Context Test"}, + ) + network_id = resp.json()["id"] + + participant_ids = {} + for name in ["Persona A", "Persona B", "Persona C"]: + resp = await client.post( + f"{BASE_URL}/networks/{network_id}/participants", + headers=headers, + json={"name": name, "participant_type": "persona", "polling_enabled": True}, + ) + participant_ids[name] = resp.json()["id"] + + # A → B + await client.post( + f"{BASE_URL}/networks/{network_id}/messages/send", + headers=headers, + json={ + "sender_participant_id": participant_ids["Persona A"], + "recipient_participant_id": participant_ids["Persona B"], + "content": "Hey B, should we include C in this discussion?", + }, + ) + + # B → C + await client.post( + f"{BASE_URL}/networks/{network_id}/messages/send", + headers=headers, + json={ + "sender_participant_id": participant_ids["Persona B"], + "recipient_participant_id": 
participant_ids["Persona C"], + "content": "C, A wants to loop you in. Thoughts on the project?", + }, + ) + + # C → A (proactive via callback) + await client.post( + f"{BASE_URL}/networks/{network_id}/participants/{participant_ids['Persona C']}/callback", + json={ + "content": "Thanks for including me! I have some ideas to share.", + "recipient_participant_id": participant_ids["Persona A"], + }, + ) + + # Verify shared context has all 3 messages from all 3 senders + resp = await client.get( + f"{BASE_URL}/networks/{network_id}/context", headers=headers + ) + context = resp.json() + senders = {e["sender"] for e in context["entries"]} + assert senders == {"Persona A", "Persona B", "Persona C"} + assert len(context["entries"]) == 3 + + # Cleanup + await client.delete(f"{BASE_URL}/networks/{network_id}", headers=headers) + + +# ── A2A Agent Card ─────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_a2a_platform_card(authed): + """Verify the platform A2A Agent Card is served correctly.""" + client, token = authed + + resp = await client.get(f"{BASE_URL}/.well-known/agent.json") + assert resp.status_code == 200 + card = resp.json() + assert card["name"] == "Intuno Agent Network" + assert "capabilities" in card + assert card["capabilities"]["networks"] is True + assert "call" in card["capabilities"]["channels"] + assert "message" in card["capabilities"]["channels"] + assert "mailbox" in card["capabilities"]["channels"] + assert len(card["skills"]) >= 4 + + +@pytest.mark.asyncio +async def test_a2a_agent_card_for_registered_agent(authed): + """Register an agent and verify its A2A card is generated.""" + client, token = authed + headers = auth(token) + + agent = await register_test_agent(client, token, "A2A Card Test Agent") + agent_id = agent["agent_id"] + + resp = await client.get( + f"{BASE_URL}/a2a/agents/{agent_id}/agent-card", headers=headers + ) + assert resp.status_code == 200 + card = resp.json() + assert card["name"] == "A2A Card Test Agent" + assert "skills" in card + assert len(card["skills"]) >= 1 + + +# ── A2A Fetch Card Preview ────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_a2a_fetch_card_from_self(authed): + """Fetch Intuno's own Agent Card via the preview endpoint.""" + client, token = authed + headers = auth(token) + + resp = await client.get( + f"{BASE_URL}/a2a/agents/fetch-card", + headers=headers, + params={"url": BASE_URL}, + ) + assert resp.status_code == 200 + data = resp.json() + assert data["success"] is True + assert data["card"]["name"] == "Intuno Agent Network" + + +# ── Network with Agent Participants ────────────────────────────────── + + +@pytest.mark.asyncio +async def test_network_with_agent_participants(authed): + """Create a network where participants are linked to registered agents.""" + client, token = authed + headers = auth(token) + + # Register two agents + agent1 = await register_test_agent(client, token, "Network Agent Alpha") + agent2 = await register_test_agent(client, token, "Network Agent Beta") + + # Create network + resp = await client.post( + f"{BASE_URL}/networks", + headers=headers, + json={"name": "Agent Network Test"}, + ) + network_id = resp.json()["id"] + + # Add agents as participants + resp = await client.post( + f"{BASE_URL}/networks/{network_id}/participants", + headers=headers, + json={ + "name": "Alpha", + "agent_id": agent1["id"], + "participant_type": "agent", + "polling_enabled": True, + }, + ) + assert resp.status_code == 201 + alpha_id = resp.json()["id"] 
+ assert resp.json()["agent_id"] == agent1["id"] + + resp = await client.post( + f"{BASE_URL}/networks/{network_id}/participants", + headers=headers, + json={ + "name": "Beta", + "agent_id": agent2["id"], + "participant_type": "agent", + "polling_enabled": True, + }, + ) + assert resp.status_code == 201 + beta_id = resp.json()["id"] + + # Verify duplicate agent is rejected + resp = await client.post( + f"{BASE_URL}/networks/{network_id}/participants", + headers=headers, + json={ + "name": "Alpha Duplicate", + "agent_id": agent1["id"], + "polling_enabled": True, + }, + ) + assert resp.status_code == 400 + + # Exchange messages + resp = await client.post( + f"{BASE_URL}/networks/{network_id}/messages/send", + headers=headers, + json={ + "sender_participant_id": alpha_id, + "recipient_participant_id": beta_id, + "content": "Beta, can you process dataset X?", + }, + ) + assert resp.status_code == 201 + + # Cleanup + await client.delete(f"{BASE_URL}/networks/{network_id}", headers=headers)
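+
+
+# ── Aggregation strategies (unit-style) ─────────────────────────────
+
+
+@pytest.mark.asyncio
+async def test_vote_and_merge_aggregators():
+    """Illustrative, unit-style sketch of the fan-in strategies in
+    ``src/network/utils/aggregator.py``. Unlike the tests above, it does not
+    hit the HTTP API and assumes the application package is importable from
+    the test environment."""
+    from src.network.utils.aggregator import create_aggregator
+
+    inputs = [
+        {"source": "step_a", "output": "yes"},
+        {"source": "step_b", "output": "yes"},
+        {"source": "step_c", "output": "no"},
+    ]
+
+    vote = await create_aggregator("vote").aggregate(inputs)
+    assert vote["result"] == "yes"
+    assert vote["votes"] == {"yes": 2, "no": 1}
+    assert vote["total"] == 3
+
+    merged = await create_aggregator("merge").aggregate(inputs)
+    assert merged["strategy"] == "merge"
+    assert merged["result"] == {"step_a": "yes", "step_b": "yes", "step_c": "no"}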