From f930d01dfd01dd94b319604a67af2625d09e868b Mon Sep 17 00:00:00 2001 From: Ishaan Gupta Date: Mon, 1 Jun 2026 12:16:40 +0530 Subject: [PATCH 01/20] Add modular billing credit ledger --- src/api/routes/billing.py | 610 ++++++++++++------------------ src/api/routes/v2/activities.py | 10 +- src/api/routes/v2/jobs.py | 20 + src/api/routes/v2/memory.py | 91 ++++- src/billing/__init__.py | 31 ++ src/billing/metering.py | 49 +++ src/billing/razorpay.py | 107 ++++++ src/billing/service.py | 342 +++++++++++++++++ src/billing/store.py | 648 ++++++++++++++++++++++++++++++++ src/billing/types.py | 97 +++++ src/config/settings.py | 4 + src/utils/billing.py | 72 ++++ tests/test_billing.py | 126 +++++++ 13 files changed, 1829 insertions(+), 378 deletions(-) create mode 100644 src/billing/__init__.py create mode 100644 src/billing/metering.py create mode 100644 src/billing/razorpay.py create mode 100644 src/billing/service.py create mode 100644 src/billing/store.py create mode 100644 src/billing/types.py create mode 100644 src/utils/billing.py create mode 100644 tests/test_billing.py diff --git a/src/api/routes/billing.py b/src/api/routes/billing.py index 449d661c..6b54230f 100644 --- a/src/api/routes/billing.py +++ b/src/api/routes/billing.py @@ -1,255 +1,55 @@ -"""Billing and Razorpay payment routes.""" +"""Billing routes backed by the modular credit ledger.""" from __future__ import annotations -import hashlib -import hmac +import json import logging from datetime import datetime, timezone -from typing import Any, Dict, List, Literal, Optional +from typing import Any, Optional import httpx -from fastapi import APIRouter, Depends, HTTPException, status -from pydantic import BaseModel, Field +from fastapi import APIRouter, Depends, HTTPException, Query, Request, status +from pydantic import BaseModel from src.api.dependencies import get_current_user +from src.billing.razorpay import ( + RazorpayConfigError, + create_order, + create_subscription, + require_razorpay_keys, + verify_order_signature, + verify_subscription_signature, + verify_webhook_signature, +) +from src.billing.service import get_default_billing_service, public_plans, public_topups +from src.billing.types import ( + BillingSummary, + CheckoutRequest, + CheckoutResponse, + LedgerEntryPublic, + PlanPublic, + TopUpPackPublic, + VerifyPaymentRequest, +) from src.config import settings +from src.utils import billing as billing_config logger = logging.getLogger("xmem.api.billing") router = APIRouter(prefix="/api/billing", tags=["Billing"]) -class BillingPlan(BaseModel): - id: str - name: str - amount: int - currency: str - description: str - features: List[str] - - -class UsageSnapshot(BaseModel): - memories_written: int = 0 - retrievals: int = 0 - graph_queries: int = 0 - credits_used: int = 0 - credits_limit: int = 5000 - - -class Invoice(BaseModel): - id: str - date: datetime - amount_paise: int - status: Literal["paid", "pending", "failed"] - credits: int = 0 - receipt_url: Optional[str] = None - - -class BillingSummary(BaseModel): - plan_name: str - account_status: Literal["active", "trial", "paused", "past_due"] - currency: str - credit_balance: int - prepaid_balance_paise: int - current_month: UsageSnapshot - next_invoice_paise: int - last_payment_at: Optional[datetime] = None - invoices: List[Invoice] = Field(default_factory=list) - - class BillingSummaryResponse(BaseModel): summary: BillingSummary - plans: List[BillingPlan] - + plans: list[PlanPublic] + topups: list[TopUpPackPublic] -class CreateRazorpayOrderRequest(BaseModel): - package_id: str = Field(..., description="Plan/package ID selected by the user") - credits: int = Field(default=0, ge=0) - amount: int = Field(default=0, ge=0) - currency: str = Field(default="INR", min_length=3, max_length=3) - -class RazorpayOrderResponse(BaseModel): - id: str - order_id: str - amount: int - currency: str - key_id: str - receipt: str - package_id: str - - -class VerifyRazorpayPaymentRequest(BaseModel): - razorpay_payment_id: str - razorpay_order_id: str - razorpay_signature: str - package_id: str - credits: int = Field(default=0, ge=0) - amount: int = Field(default=0, ge=0) - currency: str = Field(default="INR", min_length=3, max_length=3) - - -class VerifyRazorpayPaymentResponse(BaseModel): - status: Literal["ok"] +class VerifyPaymentResponse(BaseModel): + status: str = "ok" summary: BillingSummary -_PLANS: Dict[str, BillingPlan] = { - "free": BillingPlan( - id="free", - name="Free", - amount=0, - currency="USD", - description="30 days free with access to the core platform, Chrome extension, MCP, and SDKs.", - features=[ - "Full XMem dashboard access", - "Chrome extension included", - "MCP server access included", - "Python and TypeScript SDKs included", - "No credit card required", - ], - ), - "pro": BillingPlan( - id="pro", - name="Pro", - amount=100, - currency="USD", - description="Full access for production apps, priority support, and pay-as-you-go usage.", - features=[ - "Everything in Free", - "Production-ready API access", - "Pay-as-you-go usage for higher volume", - "24/7 customer support", - "Access to exclusive features coming soon", - ], - ), - "enterprise": BillingPlan( - id="enterprise", - name="Enterprise", - amount=0, - currency="USD", - description="Dedicated onboarding, custom limits, security reviews, and team support.", - features=[ - "Everything in Pro", - "Custom usage limits", - "Security and procurement support", - "Dedicated onboarding", - ], - ), -} - -_in_memory_billing: Dict[str, Dict[str, Any]] = {} -_in_memory_orders: Dict[str, Dict[str, Any]] = {} - - -class BillingStore: - """Small billing metadata store backed by MongoDB with local memory fallback.""" - - def __init__(self) -> None: - self._client = None - self._db = None - self.billing = None - self.orders = None - self._connected = False - self._in_memory = False - self._try_connect() - - def _requires_durable_storage(self) -> bool: - return settings.environment.lower() in {"production", "prod"} - - def _enable_memory_fallback(self, error: Exception) -> None: - message = f"MongoDB connection failed for billing storage: {error}" - if self._requires_durable_storage(): - logger.error("%s; refusing in-memory fallback in production", message) - raise RuntimeError( - "MongoDB is required for billing storage when ENVIRONMENT=production" - ) from error - logger.warning("%s; using in-memory billing storage", message) - self._connected = False - self._in_memory = True - - def _try_connect(self) -> None: - provider = (settings.app_store_provider or "mongo").strip().lower() - if provider == "memory": - self._connected = False - self._in_memory = True - return - if provider == "postgres": - self._enable_memory_fallback(RuntimeError("Postgres billing storage is not implemented")) - return - - try: - from pymongo import ASCENDING, MongoClient - - self._client = MongoClient(settings.mongodb_uri, serverSelectionTimeoutMS=5000) - self._client.admin.command("ping") - self._db = self._client[settings.mongodb_database] - self.billing = self._db["billing_profiles"] - self.orders = self._db["billing_orders"] - self.billing.create_index([("user_id", ASCENDING)], unique=True) - self.orders.create_index([("order_id", ASCENDING)], unique=True) - self.orders.create_index([("user_id", ASCENDING)]) - self._connected = True - self._in_memory = False - except Exception as exc: - self._enable_memory_fallback(exc) - - def get_summary(self, user_id: str) -> BillingSummary: - if self._in_memory: - summary = _in_memory_billing.setdefault( - user_id, - _default_summary().model_dump(), - ) - return BillingSummary.model_validate(summary) - - doc = self.billing.find_one({"user_id": user_id}) - if not doc: - summary = _default_summary() - self.save_summary(user_id, summary) - return summary - - doc.pop("_id", None) - doc.pop("user_id", None) - return BillingSummary.model_validate(doc) - - def save_summary(self, user_id: str, summary: BillingSummary) -> None: - payload = summary.model_dump() - if self._in_memory: - _in_memory_billing[user_id] = payload - return - - self.billing.update_one( - {"user_id": user_id}, - {"$set": {"user_id": user_id, **payload}}, - upsert=True, - ) - - def save_order(self, order_id: str, order: Dict[str, Any]) -> None: - payload = {"order_id": order_id, **order} - if self._in_memory: - _in_memory_orders[order_id] = payload - return - - self.orders.update_one( - {"order_id": order_id}, - {"$set": payload}, - upsert=True, - ) - - def get_order(self, order_id: str) -> Optional[Dict[str, Any]]: - if self._in_memory: - return _in_memory_orders.get(order_id) - - doc = self.orders.find_one({"order_id": order_id}) - if doc: - doc.pop("_id", None) - return doc - - -_billing_store = BillingStore() - - async def require_auth(current_user: dict = Depends(get_current_user)) -> dict: if not current_user: raise HTTPException( @@ -267,169 +67,257 @@ def _user_id(user: dict) -> str: return str(user_id) -def _default_summary() -> BillingSummary: - return BillingSummary( - plan_name="Free trial", - account_status="trial", - currency="INR", - credit_balance=5000, - prepaid_balance_paise=0, - current_month=UsageSnapshot(), - next_invoice_paise=0, - invoices=[], - ) +def _receipt(user_id: str, package_id: str) -> str: + ts = int(datetime.now(timezone.utc).timestamp()) + safe_user = user_id.replace(":", "_")[:16] + return f"xmem-{package_id}-{safe_user}-{ts}" -def _get_summary(user_id: str) -> BillingSummary: - return _billing_store.get_summary(user_id) +def _pack_or_plan(package_id: str) -> tuple[str, dict[str, Any]]: + if package_id in billing_config.PLANS: + return "plan", billing_config.PLANS[package_id] + if package_id in billing_config.TOP_UP_PACKS: + return "topup", billing_config.TOP_UP_PACKS[package_id] + raise HTTPException(status_code=400, detail="Unknown billing package") -def _require_razorpay_config() -> tuple[str, str]: - if not settings.razorpay_key_id or not settings.razorpay_key_secret: - raise HTTPException( - status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Razorpay is not configured. Set RAZORPAY_KEY_ID and RAZORPAY_KEY_SECRET.", - ) - return settings.razorpay_key_id, settings.razorpay_key_secret - - -def _get_plan(package_id: str) -> BillingPlan: - plan = _PLANS.get(package_id) - if not plan: - raise HTTPException(status_code=400, detail="Unknown billing plan") - return plan - - -def _verify_signature(order_id: str, payment_id: str, signature: str, secret: str) -> bool: - payload = f"{order_id}|{payment_id}".encode("utf-8") - expected = hmac.new(secret.encode("utf-8"), payload, hashlib.sha256).hexdigest() - return hmac.compare_digest(expected, signature) - - -@router.get("/plans", response_model=List[BillingPlan]) -async def list_billing_plans() -> List[BillingPlan]: - """Return the server-authoritative billing plans.""" - return list(_PLANS.values()) +@router.get("/plans", response_model=list[PlanPublic]) +async def list_billing_plans() -> list[PlanPublic]: + return public_plans() @router.get("/summary", response_model=BillingSummaryResponse) async def billing_summary(current_user: dict = Depends(require_auth)) -> BillingSummaryResponse: - """Return the current user's billing summary.""" + service = get_default_billing_service() return BillingSummaryResponse( - summary=_get_summary(_user_id(current_user)), - plans=list(_PLANS.values()), + summary=service.get_billing_summary(current_user), + plans=public_plans(), + topups=public_topups(), ) -@router.post("/razorpay/order", response_model=RazorpayOrderResponse) -async def create_razorpay_order( - request: CreateRazorpayOrderRequest, +@router.post("/razorpay/order", response_model=CheckoutResponse) +async def create_razorpay_checkout( + request: CheckoutRequest, current_user: dict = Depends(require_auth), -) -> RazorpayOrderResponse: - """Create a Razorpay order for the selected plan. - - The server owns plan amount and currency. Client-supplied amount/currency are - accepted only for compatibility and are intentionally ignored. - """ - key_id, key_secret = _require_razorpay_config() - plan = _get_plan(request.package_id) - - if plan.id != "pro": - raise HTTPException(status_code=400, detail="Only the Pro plan can be purchased online") +) -> CheckoutResponse: + try: + key_id, _ = require_razorpay_keys() + except RazorpayConfigError as exc: + raise HTTPException(status_code=503, detail=str(exc)) from exc user_id = _user_id(current_user) - receipt = f"xmem-{user_id[:16]}-{int(datetime.now(timezone.utc).timestamp())}" + package_type, package = _pack_or_plan(request.package_id) + service = get_default_billing_service() + account = service.ensure_billing_account(current_user) - payload = { - "amount": plan.amount, - "currency": plan.currency, - "receipt": receipt, - "notes": { - "user_id": user_id, - "package_id": plan.id, - "plan_name": plan.name, - }, + if request.package_id == "free": + raise HTTPException(status_code=400, detail="Free plan does not require checkout") + + notes = { + "user_id": user_id, + "billing_account_id": account["id"], + "package_id": request.package_id, + "package_type": package_type, } + receipt = _receipt(user_id, request.package_id) try: - async with httpx.AsyncClient(timeout=20) as client: - response = await client.post( - "https://api.razorpay.com/v1/orders", - auth=(key_id, key_secret), - json=payload, + if request.package_id == "pro" and settings.razorpay_pro_plan_id: + subscription = await create_subscription( + plan_id=settings.razorpay_pro_plan_id, + notes=notes, + ) + checkout_id = str(subscription["id"]) + service.store.save_checkout( + checkout_id, + { + "type": "subscription", + "user_id": user_id, + "billing_account_id": account["id"], + "package_id": request.package_id, + "subscription_id": checkout_id, + "status": "created", + }, + ) + return CheckoutResponse( + id=checkout_id, + subscription_id=checkout_id, + package_id=request.package_id, + amount=int(package["price_paise"]), + currency=str(package.get("currency") or "INR"), + key_id=key_id, + receipt=receipt, ) - except httpx.HTTPError as exc: - logger.exception("Failed to create Razorpay order") - raise HTTPException(status_code=502, detail="Failed to reach Razorpay") from exc - if response.status_code >= 400: - logger.warning("Razorpay order creation failed: %s %s", response.status_code, response.text[:500]) - raise HTTPException(status_code=502, detail="Razorpay order creation failed") + amount = int(package["price_paise"]) + order = await create_order( + amount_paise=amount, + currency=str(package.get("currency") or "INR"), + receipt=receipt, + notes=notes, + ) + except httpx.HTTPError as exc: + raise HTTPException(status_code=502, detail="Razorpay checkout creation failed") from exc - order = response.json() - order_id = order["id"] - _billing_store.save_order(order_id, { - "user_id": user_id, - "package_id": plan.id, - "amount": plan.amount, - "currency": plan.currency, - "receipt": receipt, - "created_at": datetime.now(timezone.utc), - }) - - return RazorpayOrderResponse( + order_id = str(order["id"]) + service.store.save_checkout( + order_id, + { + "type": package_type, + "user_id": user_id, + "billing_account_id": account["id"], + "package_id": request.package_id, + "order_id": order_id, + "amount": amount, + "currency": str(package.get("currency") or "INR"), + "status": "created", + }, + ) + return CheckoutResponse( id=order_id, order_id=order_id, - amount=plan.amount, - currency=plan.currency, + package_id=request.package_id, + amount=amount, + currency=str(package.get("currency") or "INR"), key_id=key_id, receipt=receipt, - package_id=plan.id, ) -@router.post("/razorpay/verify", response_model=VerifyRazorpayPaymentResponse) -async def verify_razorpay_payment( - request: VerifyRazorpayPaymentRequest, +@router.post("/topups", response_model=CheckoutResponse) +async def create_topup_checkout( + request: CheckoutRequest, current_user: dict = Depends(require_auth), -) -> VerifyRazorpayPaymentResponse: - """Verify a Razorpay checkout signature and activate the paid plan.""" - _, key_secret = _require_razorpay_config() - plan = _get_plan(request.package_id) - - if not _verify_signature( - request.razorpay_order_id, - request.razorpay_payment_id, - request.razorpay_signature, - key_secret, - ): - raise HTTPException(status_code=400, detail="Invalid Razorpay signature") +) -> CheckoutResponse: + if request.package_id not in billing_config.TOP_UP_PACKS: + raise HTTPException(status_code=400, detail="Unknown top-up pack") + return await create_razorpay_checkout(request, current_user) + +@router.post("/razorpay/verify", response_model=VerifyPaymentResponse) +async def verify_razorpay_payment( + request: VerifyPaymentRequest, + current_user: dict = Depends(require_auth), +) -> VerifyPaymentResponse: + service = get_default_billing_service() user_id = _user_id(current_user) - tracked_order = _billing_store.get_order(request.razorpay_order_id) - if tracked_order and tracked_order.get("user_id") != user_id: - raise HTTPException(status_code=403, detail="Payment order does not belong to this user") - if tracked_order and tracked_order.get("package_id") != plan.id: - raise HTTPException(status_code=400, detail="Payment order package mismatch") - - now = datetime.now(timezone.utc) - summary = _get_summary(user_id) - summary.plan_name = plan.name - summary.account_status = "active" - summary.currency = plan.currency - summary.prepaid_balance_paise += plan.amount - summary.next_invoice_paise = 0 - summary.last_payment_at = now - summary.invoices.insert( - 0, - Invoice( - id=request.razorpay_payment_id, - date=now, - amount_paise=plan.amount, - status="paid", - credits=0, - ), + + if request.razorpay_subscription_id: + if not verify_subscription_signature( + request.razorpay_subscription_id, + request.razorpay_payment_id, + request.razorpay_signature, + ): + raise HTTPException(status_code=400, detail="Invalid Razorpay signature") + service.grant_pro_subscription( + user_id=user_id, + payment_id=request.razorpay_payment_id, + subscription_id=request.razorpay_subscription_id, + ) + elif request.razorpay_order_id: + if not verify_order_signature( + request.razorpay_order_id, + request.razorpay_payment_id, + request.razorpay_signature, + ): + raise HTTPException(status_code=400, detail="Invalid Razorpay signature") + checkout = service.store.get_checkout(request.razorpay_order_id) + if checkout and checkout.get("user_id") != user_id: + raise HTTPException(status_code=403, detail="Payment order does not belong to this user") + package_id = str((checkout or {}).get("package_id") or request.package_id) + if package_id == "pro": + service.grant_pro_subscription( + user_id=user_id, + payment_id=request.razorpay_payment_id, + subscription_id=request.razorpay_order_id, + ) + else: + service.grant_topup( + user_id=user_id, + pack_id=package_id, + payment_id=request.razorpay_payment_id, + order_id=request.razorpay_order_id, + ) + else: + raise HTTPException(status_code=400, detail="Missing Razorpay order or subscription id") + + return VerifyPaymentResponse(summary=service.get_billing_summary(current_user)) + + +@router.post("/razorpay/webhook") +async def razorpay_webhook(request: Request) -> dict[str, str]: + body = await request.body() + signature = request.headers.get("x-razorpay-signature", "") + try: + if not verify_webhook_signature(body, signature): + raise HTTPException(status_code=400, detail="Invalid Razorpay webhook signature") + except RazorpayConfigError as exc: + raise HTTPException(status_code=503, detail=str(exc)) from exc + + payload = json.loads(body.decode("utf-8")) + event_id = str( + request.headers.get("x-razorpay-event-id") + or payload.get("id") + or "" + ) + event_name = str(payload.get("event") or "") + service = get_default_billing_service() + first_seen = service.store.mark_payment_event( + event_id, + {"event": event_name, "payload": payload}, ) - _billing_store.save_summary(user_id, summary) + if not first_seen: + return {"status": "ignored_duplicate"} + + payment = (((payload.get("payload") or {}).get("payment") or {}).get("entity") or {}) + subscription = (((payload.get("payload") or {}).get("subscription") or {}).get("entity") or {}) + order = (((payload.get("payload") or {}).get("order") or {}).get("entity") or {}) + notes = payment.get("notes") or subscription.get("notes") or order.get("notes") or {} + user_id = str(notes.get("user_id") or "") + package_id = str(notes.get("package_id") or "") + payment_id = str(payment.get("id") or payload.get("id") or "") + order_id = str(payment.get("order_id") or order.get("id") or "") + subscription_id = str(payment.get("subscription_id") or subscription.get("id") or "") + + if not user_id or not package_id: + logger.info("Ignoring Razorpay webhook without XMem user/package notes: %s", event_name) + return {"status": "ignored"} + + if event_name in {"payment.captured", "order.paid", "subscription.charged"}: + if package_id == "pro": + service.grant_pro_subscription( + user_id=user_id, + payment_id=payment_id or event_id, + subscription_id=subscription_id or order_id or event_id, + ) + elif package_id in billing_config.TOP_UP_PACKS: + service.grant_topup( + user_id=user_id, + pack_id=package_id, + payment_id=payment_id or event_id, + order_id=order_id or event_id, + ) - return VerifyRazorpayPaymentResponse(status="ok", summary=summary) + return {"status": "ok"} + + +@router.get("/ledger", response_model=list[LedgerEntryPublic]) +async def billing_ledger( + current_user: dict = Depends(require_auth), + limit: int = Query(default=100, ge=1, le=500), +) -> list[LedgerEntryPublic]: + service = get_default_billing_service() + return [ + LedgerEntryPublic( + id=str(entry["id"]), + type=str(entry["type"]), + amount=int(entry["amount"]), + idempotency_key=str(entry["idempotency_key"]), + job_id=entry.get("job_id"), + source=entry.get("source"), + metadata=entry.get("metadata") or {}, + created_at=entry["created_at"], + ) + for entry in service.list_ledger(current_user, limit=limit) + ] diff --git a/src/api/routes/v2/activities.py b/src/api/routes/v2/activities.py index 4282898c..b1500f7f 100644 --- a/src/api/routes/v2/activities.py +++ b/src/api/routes/v2/activities.py @@ -9,6 +9,7 @@ from src.api.dependencies import get_ingest_pipeline from src.api.routes import memory as memory_v1 +from src.billing.service import commit_job_billing, release_job_billing from src.jobs.durable import get_default_job_store try: # pragma: no cover - no-op fallback keeps imports working without SDK. @@ -50,15 +51,22 @@ async def mark_job_progress_activity(payload: Dict[str, Any]) -> None: @activity.defn async def mark_job_succeeded_activity(payload: Dict[str, Any]) -> None: + job = await asyncio.to_thread(get_default_job_store().get, payload["job_id"]) + result = payload.get("result") or {} + if job: + result = await asyncio.to_thread(commit_job_billing, job, result) await asyncio.to_thread( get_default_job_store().mark_succeeded, payload["job_id"], - payload.get("result") or {}, + result, ) @activity.defn async def mark_job_dead_letter_activity(payload: Dict[str, Any]) -> None: + job = await asyncio.to_thread(get_default_job_store().get, payload["job_id"]) + if job: + await asyncio.to_thread(release_job_billing, job, "dead_letter") await asyncio.to_thread( get_default_job_store().mark_dead_letter, payload["job_id"], diff --git a/src/api/routes/v2/jobs.py b/src/api/routes/v2/jobs.py index 1e0cff33..d79a8367 100644 --- a/src/api/routes/v2/jobs.py +++ b/src/api/routes/v2/jobs.py @@ -17,6 +17,7 @@ ) from src.api.routes.v2.temporal_client import cancel_job_workflow, start_job_workflow from src.api.schemas import APIResponse +from src.billing.service import InsufficientCredits, get_default_billing_service, release_job_billing from src.jobs.durable import DEAD_LETTER, QUEUED, RUNNING, get_default_job_store router = APIRouter( @@ -112,6 +113,24 @@ async def retry_job(job_id: str, request: Request, user: dict = Depends(require_ if job.get("status") not in {"failed", "dead_letter", "cancelled"}: return _error(request, "Only failed, dead-lettered, or cancelled jobs can be retried.", 409, elapsed_ms(start)) + payload = job.get("payload") if isinstance(job.get("payload"), dict) else {} + billing_account_id = payload.get("billing_account_id") + if billing_account_id: + billing_service = get_default_billing_service() + try: + estimate = billing_service.estimate_required_credits(job.get("job_type") or "", payload) + reservation = await asyncio.to_thread( + billing_service.reserve_credits, + billing_account_id, + job_id, + estimate.reserved_credits, + ) + payload["billing_reservation_id"] = reservation.reservation_id + payload["billing_estimate"] = estimate.model_dump() + await asyncio.to_thread(get_default_job_store().update_payload, job_id, payload) + except InsufficientCredits as exc: + return _error(request, str(exc), 402, elapsed_ms(start)) + await asyncio.to_thread(get_default_job_store().reset_for_retry, job_id, True) job = await asyncio.to_thread(get_default_job_store().get, job_id) try: @@ -138,6 +157,7 @@ async def cancel_job(job_id: str, request: Request, user: dict = Depends(require except Exception as exc: error = str(exc) or exc.__class__.__name__ return _error(request, f"Cancel failed to reach workflow: {error}", 503, elapsed_ms(start)) + await asyncio.to_thread(release_job_billing, job, "cancelled") await asyncio.to_thread(get_default_job_store().mark_cancelled, job_id) await asyncio.to_thread(_mark_scanner_job_cancelled, job) job = await asyncio.to_thread(get_default_job_store().get, job_id) diff --git a/src/api/routes/v2/memory.py b/src/api/routes/v2/memory.py index 8fab333d..56e063b6 100644 --- a/src/api/routes/v2/memory.py +++ b/src/api/routes/v2/memory.py @@ -19,8 +19,9 @@ ) from src.api.routes.v2.temporal_client import start_job_workflow from src.api.schemas import APIResponse, BatchIngestRequest, IngestRequest, ScrapeRequest, StatusEnum +from src.billing.service import InsufficientCredits, get_default_billing_service from src.config import settings -from src.jobs.durable import QUEUED, get_default_job_store, new_attempt_id, stable_hash +from src.jobs.durable import QUEUED, get_default_job_store, idempotency_key, new_attempt_id, stable_hash router = APIRouter( prefix="/v2/memory", @@ -39,6 +40,10 @@ def _content_hash(payload: Dict[str, Any]) -> str: return stable_hash(payload) +def _durable_job_id(job_type: str, fields: Dict[str, Any]) -> str: + return f"{job_type}:{idempotency_key(job_type, fields)}" + + class WorkflowStartFailed(RuntimeError): def __init__(self, job: Dict[str, Any], error: str) -> None: super().__init__(error) @@ -117,22 +122,35 @@ async def ingest_memory_v2(req: IngestRequest, request: Request, user: dict = De payload = req.model_dump() payload["user_id"] = user_id payload["timeout_seconds"] = float(settings.memory_ingest_timeout_seconds) + idempotency_fields = { + "user_id": user_id, + "org_id": payload.get("org_id", "default"), + "content_hash": _content_hash({ + "user_query": req.user_query, + "agent_response": req.agent_response or "", + "session_datetime": req.session_datetime, + "image_url": req.image_url, + "effort_level": req.effort_level, + }), + } + job_id = _durable_job_id("memory_ingest", idempotency_fields) + billing_service = get_default_billing_service() try: + account, estimate, reservation = await asyncio.to_thread( + billing_service.reserve_job_credits, + user=user, + job_id=job_id, + job_type="memory_ingest", + payload=payload, + ) + payload["billing_account_id"] = account["id"] + payload["billing_reservation_id"] = reservation.reservation_id + payload["billing_estimate"] = estimate.model_dump() job, created = await _enqueue_and_start( job_type="memory_ingest", payload=payload, - idempotency_fields={ - "user_id": user_id, - "org_id": payload.get("org_id", "default"), - "content_hash": _content_hash({ - "user_query": req.user_query, - "agent_response": req.agent_response or "", - "session_datetime": req.session_datetime, - "image_url": req.image_url, - "effort_level": req.effort_level, - }), - }, + idempotency_fields=idempotency_fields, user_id=job_user_id, timeout_seconds=float(settings.memory_ingest_timeout_seconds), max_attempts=3, @@ -145,6 +163,12 @@ async def ingest_memory_v2(req: IngestRequest, request: Request, user: dict = De elapsed_ms(start), ) except WorkflowStartFailed as exc: + if payload.get("billing_account_id"): + await asyncio.to_thread( + billing_service.release_job_reservation, + payload["billing_account_id"], + job_id, + ) return _workflow_start_error( request, exc.job, @@ -152,7 +176,15 @@ async def ingest_memory_v2(req: IngestRequest, request: Request, user: dict = De f"/v2/memory/ingest/{exc.job['job_id']}/status", elapsed_ms(start), ) + except InsufficientCredits as exc: + return _error(request, str(exc), 402, elapsed_ms(start)) except Exception as exc: + if payload.get("billing_account_id"): + await asyncio.to_thread( + billing_service.release_job_reservation, + payload["billing_account_id"], + job_id, + ) return _error(request, str(exc), 500, elapsed_ms(start)) @@ -183,15 +215,28 @@ async def batch_ingest_memory_v2(req: BatchIngestRequest, request: Request, user min(len(req.items) * float(settings.memory_ingest_timeout_seconds), 3600.0), ), } + idempotency_fields = { + "user_id": user_id, + "content_hash": _content_hash({"items": items}), + } + job_id = _durable_job_id("memory_batch_ingest", idempotency_fields) + billing_service = get_default_billing_service() try: + account, estimate, reservation = await asyncio.to_thread( + billing_service.reserve_job_credits, + user=user, + job_id=job_id, + job_type="memory_batch_ingest", + payload=payload, + ) + payload["billing_account_id"] = account["id"] + payload["billing_reservation_id"] = reservation.reservation_id + payload["billing_estimate"] = estimate.model_dump() job, created = await _enqueue_and_start( job_type="memory_batch_ingest", payload=payload, - idempotency_fields={ - "user_id": user_id, - "content_hash": _content_hash({"items": items}), - }, + idempotency_fields=idempotency_fields, user_id=user_id, timeout_seconds=payload["timeout_seconds"], max_attempts=3, @@ -204,6 +249,12 @@ async def batch_ingest_memory_v2(req: BatchIngestRequest, request: Request, user elapsed_ms(start), ) except WorkflowStartFailed as exc: + if payload.get("billing_account_id"): + await asyncio.to_thread( + billing_service.release_job_reservation, + payload["billing_account_id"], + job_id, + ) return _workflow_start_error( request, exc.job, @@ -211,7 +262,15 @@ async def batch_ingest_memory_v2(req: BatchIngestRequest, request: Request, user f"/v2/memory/jobs/{exc.job['job_id']}/status", elapsed_ms(start), ) + except InsufficientCredits as exc: + return _error(request, str(exc), 402, elapsed_ms(start)) except Exception as exc: + if payload.get("billing_account_id"): + await asyncio.to_thread( + billing_service.release_job_reservation, + payload["billing_account_id"], + job_id, + ) return _error(request, str(exc), 500, elapsed_ms(start)) diff --git a/src/billing/__init__.py b/src/billing/__init__.py new file mode 100644 index 00000000..17ed2996 --- /dev/null +++ b/src/billing/__init__.py @@ -0,0 +1,31 @@ +"""Billing and credit ledger package.""" + +from .service import ( + InsufficientCredits, + commit_job_billing, + commit_job_debit, + ensure_billing_account, + estimate_required_credits, + get_billing_summary, + get_default_billing_service, + record_usage_event, + release_job_billing, + release_job_reservation, + reserve_credits, + reserve_job_credits, +) + +__all__ = [ + "InsufficientCredits", + "commit_job_billing", + "commit_job_debit", + "ensure_billing_account", + "estimate_required_credits", + "get_billing_summary", + "get_default_billing_service", + "record_usage_event", + "release_job_billing", + "release_job_reservation", + "reserve_credits", + "reserve_job_credits", +] diff --git a/src/billing/metering.py b/src/billing/metering.py new file mode 100644 index 00000000..8d03ea7a --- /dev/null +++ b/src/billing/metering.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +import math +from typing import Any, Mapping + +from src.billing.types import CreditEstimate +from src.utils import billing as billing_config + + +def _ingest_text(payload: Mapping[str, Any]) -> str: + return "\n".join( + str(payload.get(key) or "") + for key in ("user_query", "agent_response") + if payload.get(key) + ) + + +def content_tokens_for_job(job_type: str, payload: Mapping[str, Any]) -> int: + if job_type == "memory_batch_ingest": + return sum( + content_tokens_for_job("memory_ingest", item) + for item in list(payload.get("items") or []) + if isinstance(item, Mapping) + ) + if job_type == "memory_ingest": + return billing_config.estimate_tokens(_ingest_text(payload)) + if job_type == "memory_retrieve": + return billing_config.estimate_tokens(str(payload.get("query") or "")) + return billing_config.estimate_tokens(str(payload)) + + +def estimate_required_credits( + job_type: str, + payload: Mapping[str, Any], + *, + include_reservation_buffer: bool = True, +) -> CreditEstimate: + tokens = content_tokens_for_job(job_type, payload) + multiplier = billing_config.workflow_multiplier(job_type, payload) + billable = max(1, math.ceil(tokens * multiplier)) + buffer = billing_config.RESERVATION_BUFFER_MULTIPLIER if include_reservation_buffer else 1.0 + reserved = max(billable, math.ceil(billable * buffer)) + return CreditEstimate( + job_type=job_type, + content_tokens=tokens, + multiplier=multiplier, + billable_credits=billable, + reserved_credits=reserved, + ) diff --git a/src/billing/razorpay.py b/src/billing/razorpay.py new file mode 100644 index 00000000..aa5013bd --- /dev/null +++ b/src/billing/razorpay.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +import hashlib +import hmac +import logging +from typing import Any, Optional + +import httpx + +from src.config import settings + +logger = logging.getLogger("xmem.billing.razorpay") + +RAZORPAY_API = "https://api.razorpay.com/v1" + + +class RazorpayConfigError(RuntimeError): + pass + + +def require_razorpay_keys() -> tuple[str, str]: + if not settings.razorpay_key_id or not settings.razorpay_key_secret: + raise RazorpayConfigError( + "Razorpay is not configured. Set RAZORPAY_KEY_ID and RAZORPAY_KEY_SECRET." + ) + return settings.razorpay_key_id, settings.razorpay_key_secret + + +def verify_order_signature(order_id: str, payment_id: str, signature: str) -> bool: + _, secret = require_razorpay_keys() + payload = f"{order_id}|{payment_id}".encode("utf-8") + expected = hmac.new(secret.encode("utf-8"), payload, hashlib.sha256).hexdigest() + return hmac.compare_digest(expected, signature) + + +def verify_subscription_signature( + subscription_id: str, + payment_id: str, + signature: str, +) -> bool: + _, secret = require_razorpay_keys() + payload = f"{payment_id}|{subscription_id}".encode("utf-8") + expected = hmac.new(secret.encode("utf-8"), payload, hashlib.sha256).hexdigest() + return hmac.compare_digest(expected, signature) + + +def verify_webhook_signature(body: bytes, signature: str) -> bool: + if not settings.razorpay_webhook_secret: + raise RazorpayConfigError("RAZORPAY_WEBHOOK_SECRET is not configured.") + expected = hmac.new( + settings.razorpay_webhook_secret.encode("utf-8"), + body, + hashlib.sha256, + ).hexdigest() + return hmac.compare_digest(expected, signature) + + +async def create_order( + *, + amount_paise: int, + currency: str, + receipt: str, + notes: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + key_id, key_secret = require_razorpay_keys() + payload = { + "amount": amount_paise, + "currency": currency, + "receipt": receipt, + "notes": notes or {}, + } + async with httpx.AsyncClient(timeout=20) as client: + response = await client.post( + f"{RAZORPAY_API}/orders", + auth=(key_id, key_secret), + json=payload, + ) + if response.status_code >= 400: + logger.warning("Razorpay order creation failed: %s %s", response.status_code, response.text[:500]) + response.raise_for_status() + return response.json() + + +async def create_subscription( + *, + plan_id: str, + total_count: int = 120, + notes: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + key_id, key_secret = require_razorpay_keys() + payload = { + "plan_id": plan_id, + "total_count": total_count, + "quantity": 1, + "customer_notify": 1, + "notes": notes or {}, + } + async with httpx.AsyncClient(timeout=20) as client: + response = await client.post( + f"{RAZORPAY_API}/subscriptions", + auth=(key_id, key_secret), + json=payload, + ) + if response.status_code >= 400: + logger.warning("Razorpay subscription creation failed: %s %s", response.status_code, response.text[:500]) + response.raise_for_status() + return response.json() diff --git a/src/billing/service.py b/src/billing/service.py new file mode 100644 index 00000000..c90693e6 --- /dev/null +++ b/src/billing/service.py @@ -0,0 +1,342 @@ +from __future__ import annotations + +import logging +from datetime import timedelta +from typing import Any, Mapping, Optional + +from src.billing.metering import estimate_required_credits as _estimate_required_credits +from src.billing.store import BillingStore, InsufficientCredits, get_default_billing_store, utc_now +from src.billing.types import BillingSummary, CreditEstimate, CreditLotPublic, PlanPublic, ReservationResult, TopUpPackPublic +from src.utils import billing as billing_config + +logger = logging.getLogger("xmem.billing.service") + + +def _user_id(user: Mapping[str, Any]) -> str: + user_id = user.get("id") or user.get("_id") or user.get("sub") + if not user_id: + raise ValueError("Authenticated user is missing an id") + return str(user_id) + + +def public_plans() -> list[PlanPublic]: + return [ + PlanPublic( + id=plan_id, + name=str(plan["name"]), + price_paise=int(plan.get("price_paise") or 0), + currency=str(plan.get("currency") or "INR"), + monthly_credits=int(plan.get("monthly_credits") or 0), + trial_credits=int(plan.get("trial_credits") or 0), + trial_days=int(plan.get("trial_days") or 0), + nominal_paise_per_credit=billing_config.nominal_paise_per_credit(plan_id), + ) + for plan_id, plan in billing_config.PLANS.items() + ] + + +def public_topups() -> list[TopUpPackPublic]: + return [ + TopUpPackPublic( + id=pack_id, + price_paise=int(pack["price_paise"]), + currency=str(pack.get("currency") or "INR"), + credits=int(pack["credits"]), + ) + for pack_id, pack in billing_config.TOP_UP_PACKS.items() + ] + + +class BillingService: + def __init__(self, store: Optional[BillingStore] = None) -> None: + self.store = store or get_default_billing_store() + + def ensure_billing_account(self, user: Mapping[str, Any]) -> dict[str, Any]: + owner_id = _user_id(user) + account = self.store.ensure_account(owner_id=owner_id) + free_plan = billing_config.PLANS["free"] + trial_credits = int(free_plan.get("trial_credits") or 0) + trial_days = int(free_plan.get("trial_days") or 30) + if trial_credits > 0: + self.store.grant_credits( + account_id=account["id"], + amount=trial_credits, + source="free_trial", + expires_at=utc_now() + timedelta(days=trial_days), + idempotency_key=f"free_trial:{owner_id}", + metadata={"plan_id": "free", "trial_days": trial_days}, + ) + return account + + def estimate_required_credits( + self, + job_type: str, + payload: Mapping[str, Any], + *, + include_reservation_buffer: bool = True, + ) -> CreditEstimate: + return _estimate_required_credits( + job_type, + payload, + include_reservation_buffer=include_reservation_buffer, + ) + + def reserve_credits( + self, + account_id: str, + job_id: str, + estimated_credits: int, + *, + metadata: Optional[dict[str, Any]] = None, + ) -> ReservationResult: + reservation = self.store.reserve_credits( + account_id=account_id, + job_id=job_id, + amount=estimated_credits, + metadata=metadata, + ) + wallet = self.store.get_wallet(account_id) + return ReservationResult( + reservation_id=reservation["id"], + billing_account_id=account_id, + job_id=job_id, + reserved_credits=int(reservation.get("reserved_credits") or 0), + status=reservation.get("status", "active"), + available_credits=int(wallet.get("available_credits") or 0), + ) + + def reserve_job_credits( + self, + *, + user: Mapping[str, Any], + job_id: str, + job_type: str, + payload: Mapping[str, Any], + ) -> tuple[dict[str, Any], CreditEstimate, ReservationResult]: + account = self.ensure_billing_account(user) + estimate = self.estimate_required_credits(job_type, payload) + reservation = self.reserve_credits( + account["id"], + job_id, + estimate.reserved_credits, + metadata={ + "job_type": job_type, + "billable_credits": estimate.billable_credits, + "content_tokens": estimate.content_tokens, + }, + ) + return account, estimate, reservation + + def commit_job_debit( + self, + account_id: str, + job_id: str, + final_credits: int, + *, + metadata: Optional[dict[str, Any]] = None, + ) -> dict[str, Any]: + return self.store.commit_debit( + account_id=account_id, + job_id=job_id, + final_amount=final_credits, + metadata=metadata, + ) + + def release_job_reservation( + self, + account_id: str, + job_id: str, + *, + metadata: Optional[dict[str, Any]] = None, + ) -> Optional[dict[str, Any]]: + return self.store.release_reservation( + account_id=account_id, + job_id=job_id, + metadata=metadata, + ) + + def commit_job_billing(self, job: Mapping[str, Any], result: Mapping[str, Any]) -> dict[str, Any]: + payload = job.get("payload") if isinstance(job.get("payload"), Mapping) else {} + account_id = payload.get("billing_account_id") + if not account_id: + return dict(result) + job_type = str(job.get("job_type") or payload.get("job_type") or "memory_ingest") + estimate = self.estimate_required_credits( + job_type, + payload, + include_reservation_buffer=False, + ) + self.commit_job_debit( + str(account_id), + str(job["job_id"]), + estimate.billable_credits, + metadata={ + "job_type": job_type, + "content_tokens": estimate.content_tokens, + "multiplier": estimate.multiplier, + }, + ) + enriched = dict(result) + enriched["billing"] = { + "billing_account_id": account_id, + "billable_credits": estimate.billable_credits, + "content_tokens": estimate.content_tokens, + "multiplier": estimate.multiplier, + } + return enriched + + def release_job_billing(self, job: Mapping[str, Any], reason: str = "job_not_completed") -> None: + payload = job.get("payload") if isinstance(job.get("payload"), Mapping) else {} + account_id = payload.get("billing_account_id") + if not account_id: + return + self.release_job_reservation( + str(account_id), + str(job["job_id"]), + metadata={"reason": reason}, + ) + + def grant_pro_subscription( + self, + *, + user_id: str, + payment_id: str, + subscription_id: str, + period_end=None, + ) -> dict[str, Any]: + account = self.store.ensure_account(owner_id=user_id) + plan = billing_config.PLANS["pro"] + expires_at = period_end or (utc_now() + timedelta(days=30)) + self.store.update_account( + account["id"], + { + "plan_id": "pro", + "status": "active", + "razorpay_subscription_id": subscription_id, + "current_period_end": expires_at, + }, + ) + return self.store.grant_credits( + account_id=account["id"], + amount=int(plan["monthly_credits"]), + source="pro_monthly", + expires_at=expires_at, + idempotency_key=f"pro_grant:{subscription_id}:{payment_id}", + metadata={"payment_id": payment_id, "subscription_id": subscription_id}, + ) + + def grant_topup( + self, + *, + user_id: str, + pack_id: str, + payment_id: str, + order_id: str, + ) -> dict[str, Any]: + pack = billing_config.TOP_UP_PACKS[pack_id] + account = self.store.ensure_account(owner_id=user_id) + return self.store.grant_credits( + account_id=account["id"], + amount=int(pack["credits"]), + source=pack_id, + expires_at=utc_now() + timedelta(days=billing_config.TOP_UP_EXPIRY_DAYS), + idempotency_key=f"topup_grant:{order_id}:{payment_id}", + metadata={"payment_id": payment_id, "order_id": order_id, "pack_id": pack_id}, + ) + + def get_billing_summary(self, user: Mapping[str, Any]) -> BillingSummary: + account = self.ensure_billing_account(user) + wallet = self.store.get_wallet(account["id"]) + plan_id = str(account.get("plan_id") or "free") + plan = billing_config.PLANS.get(plan_id, billing_config.PLANS["free"]) + lots = [ + CreditLotPublic( + id=str(lot["id"]), + source=str(lot.get("source") or ""), + remaining_credits=int(lot.get("remaining_credits") or 0), + expires_at=lot.get("expires_at"), + ) + for lot in self.store.active_lots(account["id"]) + ] + return BillingSummary( + billing_account_id=account["id"], + owner_type=str(account.get("owner_type") or "user"), + owner_id=str(account.get("owner_id")), + plan_id=plan_id, + plan_name=str(plan.get("name") or plan_id), + status=str(account.get("status") or "trialing"), + currency=str(plan.get("currency") or "INR"), + available_credits=int(wallet.get("available_credits") or 0), + reserved_credits=int(wallet.get("reserved_credits") or 0), + current_period_start=account.get("current_period_start"), + current_period_end=account.get("current_period_end"), + credit_lots=lots, + ) + + def list_ledger(self, user: Mapping[str, Any], limit: int = 100) -> list[dict[str, Any]]: + account = self.ensure_billing_account(user) + return self.store.list_ledger(account["id"], limit=limit) + + def record_usage_event(self, **event: Any) -> None: + self.store.record_usage_event(event) + + +_default_service: Optional[BillingService] = None + + +def get_default_billing_service() -> BillingService: + global _default_service + if _default_service is None: + _default_service = BillingService() + return _default_service + + +def ensure_billing_account(user: Mapping[str, Any]) -> dict[str, Any]: + return get_default_billing_service().ensure_billing_account(user) + + +def estimate_required_credits(job_type: str, payload: Mapping[str, Any]) -> CreditEstimate: + return get_default_billing_service().estimate_required_credits(job_type, payload) + + +def reserve_credits(account_id: str, job_id: str, estimated_credits: int) -> ReservationResult: + return get_default_billing_service().reserve_credits(account_id, job_id, estimated_credits) + + +def reserve_job_credits( + *, + user: Mapping[str, Any], + job_id: str, + job_type: str, + payload: Mapping[str, Any], +) -> tuple[dict[str, Any], CreditEstimate, ReservationResult]: + return get_default_billing_service().reserve_job_credits( + user=user, + job_id=job_id, + job_type=job_type, + payload=payload, + ) + + +def commit_job_debit(account_id: str, job_id: str, final_credits: int) -> dict[str, Any]: + return get_default_billing_service().commit_job_debit(account_id, job_id, final_credits) + + +def release_job_reservation(account_id: str, job_id: str) -> Optional[dict[str, Any]]: + return get_default_billing_service().release_job_reservation(account_id, job_id) + + +def commit_job_billing(job: Mapping[str, Any], result: Mapping[str, Any]) -> dict[str, Any]: + return get_default_billing_service().commit_job_billing(job, result) + + +def release_job_billing(job: Mapping[str, Any], reason: str = "job_not_completed") -> None: + get_default_billing_service().release_job_billing(job, reason) + + +def get_billing_summary(user: Mapping[str, Any]) -> BillingSummary: + return get_default_billing_service().get_billing_summary(user) + + +def record_usage_event(**event: Any) -> None: + get_default_billing_service().record_usage_event(**event) diff --git a/src/billing/store.py b/src/billing/store.py new file mode 100644 index 00000000..ed572809 --- /dev/null +++ b/src/billing/store.py @@ -0,0 +1,648 @@ +from __future__ import annotations + +import logging +import uuid +from datetime import datetime, timezone +from typing import Any, Dict, Iterable, Optional + +from src.config import settings + +logger = logging.getLogger("xmem.billing.store") + +_memory_accounts: dict[str, dict[str, Any]] = {} +_memory_wallets: dict[str, dict[str, Any]] = {} +_memory_lots: dict[str, dict[str, Any]] = {} +_memory_ledger: dict[str, dict[str, Any]] = {} +_memory_reservations: dict[str, dict[str, Any]] = {} +_memory_usage_events: list[dict[str, Any]] = [] +_memory_payments: dict[str, dict[str, Any]] = {} + + +class BillingStoreError(RuntimeError): + pass + + +class InsufficientCredits(BillingStoreError): + def __init__(self, required: int, available: int) -> None: + super().__init__( + f"Insufficient credits: required {required}, available {available}." + ) + self.required = required + self.available = available + + +def utc_now() -> datetime: + return datetime.now(timezone.utc) + + +def _without_id(doc: Optional[dict[str, Any]]) -> Optional[dict[str, Any]]: + if not doc: + return None + result = dict(doc) + result.pop("_id", None) + return result + + +def _is_expired(doc: dict[str, Any], now: Optional[datetime] = None) -> bool: + expires_at = doc.get("expires_at") + if not expires_at: + return False + now = now or utc_now() + if expires_at.tzinfo is None: + expires_at = expires_at.replace(tzinfo=timezone.utc) + return expires_at <= now + + +class BillingStore: + """Mongo-backed credit ledger with in-memory fallback for local development.""" + + def __init__(self, uri: Optional[str] = None, database: Optional[str] = None) -> None: + self._uri = uri or settings.mongodb_uri + self._database = database or settings.mongodb_database + self._client = None + self._db = None + self._connected = False + self._in_memory = False + self._try_connect() + + def _requires_durable_storage(self) -> bool: + return settings.environment.lower() in {"production", "prod"} + + def _enable_memory_fallback(self, error: Exception) -> None: + message = f"Billing store connection failed: {error}" + if self._requires_durable_storage(): + logger.error("%s; refusing in-memory fallback in production", message) + raise RuntimeError( + "MongoDB is required for billing storage when ENVIRONMENT=production" + ) from error + logger.warning("%s; using in-memory billing storage", message) + self._connected = False + self._in_memory = True + + def _try_connect(self) -> None: + provider = (settings.app_store_provider or "mongo").strip().lower() + if provider == "memory": + self._in_memory = True + return + if provider == "postgres": + self._enable_memory_fallback( + RuntimeError("Postgres billing storage is not implemented") + ) + return + try: + from pymongo import ASCENDING, MongoClient + + self._client = MongoClient(self._uri, serverSelectionTimeoutMS=5000) + self._client.admin.command("ping") + self._db = self._client[self._database] + self.accounts = self._db["billing_accounts"] + self.wallets = self._db["credit_wallets"] + self.lots = self._db["credit_lots"] + self.ledger = self._db["credit_ledger"] + self.reservations = self._db["credit_reservations"] + self.usage_events = self._db["usage_events"] + self.payments = self._db["billing_payments"] + + self.accounts.create_index([("owner_type", ASCENDING), ("owner_id", ASCENDING)], unique=True) + self.accounts.create_index([("razorpay_subscription_id", ASCENDING)]) + self.wallets.create_index([("billing_account_id", ASCENDING)], unique=True) + self.lots.create_index([("billing_account_id", ASCENDING), ("expires_at", ASCENDING)]) + self.ledger.create_index([("idempotency_key", ASCENDING)], unique=True) + self.ledger.create_index([("billing_account_id", ASCENDING), ("created_at", ASCENDING)]) + self.reservations.create_index([("job_id", ASCENDING)], unique=True) + self.payments.create_index([("razorpay_event_id", ASCENDING)], unique=True, sparse=True) + self.payments.create_index([("razorpay_payment_id", ASCENDING)], unique=True, sparse=True) + + self._connected = True + self._in_memory = False + except Exception as exc: + self._enable_memory_fallback(exc) + + def ensure_account( + self, + *, + owner_id: str, + owner_type: str = "user", + plan_id: str = "free", + status: str = "trialing", + ) -> dict[str, Any]: + now = utc_now() + if self._in_memory: + key = f"{owner_type}:{owner_id}" + account = _memory_accounts.get(key) + if account: + return dict(account) + account = { + "id": uuid.uuid4().hex, + "owner_type": owner_type, + "owner_id": owner_id, + "plan_id": plan_id, + "status": status, + "created_at": now, + "updated_at": now, + } + _memory_accounts[key] = account + _memory_wallets[account["id"]] = { + "billing_account_id": account["id"], + "available_credits": 0, + "reserved_credits": 0, + "updated_at": now, + } + return dict(account) + + from pymongo import ReturnDocument + + doc = self.accounts.find_one_and_update( + {"owner_type": owner_type, "owner_id": owner_id}, + { + "$setOnInsert": { + "id": uuid.uuid4().hex, + "owner_type": owner_type, + "owner_id": owner_id, + "plan_id": plan_id, + "status": status, + "created_at": now, + }, + "$set": {"updated_at": now}, + }, + upsert=True, + return_document=ReturnDocument.AFTER, + ) + account = _without_id(doc) or {} + self.wallets.update_one( + {"billing_account_id": account["id"]}, + { + "$setOnInsert": { + "billing_account_id": account["id"], + "available_credits": 0, + "reserved_credits": 0, + }, + "$set": {"updated_at": now}, + }, + upsert=True, + ) + return account + + def get_account(self, account_id: str) -> Optional[dict[str, Any]]: + if self._in_memory: + for account in _memory_accounts.values(): + if account["id"] == account_id: + return dict(account) + return None + return _without_id(self.accounts.find_one({"id": account_id})) + + def update_account(self, account_id: str, updates: dict[str, Any]) -> None: + updates = {**updates, "updated_at": utc_now()} + if self._in_memory: + for account in _memory_accounts.values(): + if account["id"] == account_id: + account.update(updates) + return + return + self.accounts.update_one({"id": account_id}, {"$set": updates}) + + def get_wallet(self, account_id: str) -> dict[str, Any]: + if self._in_memory: + return dict( + _memory_wallets.setdefault( + account_id, + { + "billing_account_id": account_id, + "available_credits": 0, + "reserved_credits": 0, + "updated_at": utc_now(), + }, + ) + ) + wallet = self.wallets.find_one({"billing_account_id": account_id}) + if not wallet: + self.wallets.update_one( + {"billing_account_id": account_id}, + { + "$setOnInsert": { + "billing_account_id": account_id, + "available_credits": 0, + "reserved_credits": 0, + }, + "$set": {"updated_at": utc_now()}, + }, + upsert=True, + ) + wallet = self.wallets.find_one({"billing_account_id": account_id}) + return _without_id(wallet) or {} + + def _insert_ledger(self, entry: dict[str, Any]) -> Optional[dict[str, Any]]: + if self._in_memory: + key = entry["idempotency_key"] + if key in _memory_ledger: + return dict(_memory_ledger[key]) + _memory_ledger[key] = dict(entry) + return None + try: + self.ledger.insert_one(entry) + return None + except Exception as exc: + if exc.__class__.__name__ != "DuplicateKeyError": + raise + return _without_id(self.ledger.find_one({"idempotency_key": entry["idempotency_key"]})) + + def grant_credits( + self, + *, + account_id: str, + amount: int, + source: str, + expires_at: Optional[datetime], + idempotency_key: str, + metadata: Optional[dict[str, Any]] = None, + ) -> dict[str, Any]: + if amount <= 0: + raise ValueError("Credit grant amount must be positive") + now = utc_now() + ledger_entry = { + "id": uuid.uuid4().hex, + "billing_account_id": account_id, + "type": "grant", + "amount": amount, + "source": source, + "expires_at": expires_at, + "idempotency_key": idempotency_key, + "metadata": metadata or {}, + "created_at": now, + } + duplicate = self._insert_ledger(ledger_entry) + if duplicate: + return duplicate + + lot = { + "id": uuid.uuid4().hex, + "billing_account_id": account_id, + "source": source, + "remaining_credits": amount, + "ledger_id": ledger_entry["id"], + "expires_at": expires_at, + "created_at": now, + "updated_at": now, + } + if self._in_memory: + _memory_lots[lot["id"]] = lot + wallet = _memory_wallets.setdefault( + account_id, + {"billing_account_id": account_id, "available_credits": 0, "reserved_credits": 0}, + ) + wallet["available_credits"] += amount + wallet["updated_at"] = now + return dict(ledger_entry) + + self.lots.insert_one(lot) + self.wallets.update_one( + {"billing_account_id": account_id}, + {"$inc": {"available_credits": amount}, "$set": {"updated_at": now}}, + upsert=True, + ) + return ledger_entry + + def reserve_credits( + self, + *, + account_id: str, + job_id: str, + amount: int, + metadata: Optional[dict[str, Any]] = None, + ) -> dict[str, Any]: + if amount <= 0: + raise ValueError("Reservation amount must be positive") + existing = self.get_reservation(job_id) + if existing and existing.get("status") in {"active", "committed"}: + return existing + + now = utc_now() + if self._in_memory: + wallet = self.get_wallet(account_id) + if int(wallet.get("available_credits") or 0) < amount: + raise InsufficientCredits(amount, int(wallet.get("available_credits") or 0)) + _memory_wallets[account_id]["available_credits"] -= amount + _memory_wallets[account_id]["reserved_credits"] += amount + version = int((existing or {}).get("version") or 0) + 1 + reservation = { + "id": (existing or {}).get("id") or uuid.uuid4().hex, + "billing_account_id": account_id, + "job_id": job_id, + "reserved_credits": amount, + "status": "active", + "version": version, + "metadata": metadata or {}, + "created_at": (existing or {}).get("created_at") or now, + "updated_at": now, + } + _memory_reservations[job_id] = reservation + self._insert_ledger( + { + "id": uuid.uuid4().hex, + "billing_account_id": account_id, + "type": "reserve", + "amount": amount, + "job_id": job_id, + "idempotency_key": f"reserve:{job_id}:{version}", + "metadata": metadata or {}, + "created_at": now, + } + ) + return dict(reservation) + + wallet = self.wallets.find_one_and_update( + {"billing_account_id": account_id, "available_credits": {"$gte": amount}}, + { + "$inc": {"available_credits": -amount, "reserved_credits": amount}, + "$set": {"updated_at": now}, + }, + return_document=True, + ) + if not wallet: + current = self.get_wallet(account_id) + raise InsufficientCredits(amount, int(current.get("available_credits") or 0)) + + version = int((existing or {}).get("version") or 0) + 1 + reservation = { + "id": (existing or {}).get("id") or uuid.uuid4().hex, + "billing_account_id": account_id, + "job_id": job_id, + "reserved_credits": amount, + "status": "active", + "version": version, + "metadata": metadata or {}, + "created_at": (existing or {}).get("created_at") or now, + "updated_at": now, + } + self.reservations.update_one({"job_id": job_id}, {"$set": reservation}, upsert=True) + self._insert_ledger( + { + "id": uuid.uuid4().hex, + "billing_account_id": account_id, + "type": "reserve", + "amount": amount, + "job_id": job_id, + "idempotency_key": f"reserve:{job_id}:{version}", + "metadata": metadata or {}, + "created_at": now, + } + ) + return reservation + + def get_reservation(self, job_id: str) -> Optional[dict[str, Any]]: + if self._in_memory: + reservation = _memory_reservations.get(job_id) + return dict(reservation) if reservation else None + return _without_id(self.reservations.find_one({"job_id": job_id})) + + def commit_debit( + self, + *, + account_id: str, + job_id: str, + final_amount: int, + metadata: Optional[dict[str, Any]] = None, + ) -> dict[str, Any]: + if final_amount <= 0: + raise ValueError("Debit amount must be positive") + duplicate = self.find_ledger_by_key(f"debit:{job_id}") + if duplicate: + return duplicate + reservation = self.get_reservation(job_id) + if not reservation: + raise BillingStoreError(f"No credit reservation exists for job {job_id}") + if reservation.get("status") == "committed": + existing = self.find_ledger_by_key(f"debit:{job_id}") + if existing: + return existing + raise BillingStoreError(f"Reservation for job {job_id} is already committed") + if reservation.get("status") != "active": + raise BillingStoreError(f"Reservation for job {job_id} is not active") + + now = utc_now() + reserved = int(reservation.get("reserved_credits") or 0) + extra = max(final_amount - reserved, 0) + refund = max(reserved - final_amount, 0) + if extra: + if self._in_memory: + wallet = self.get_wallet(account_id) + if int(wallet.get("available_credits") or 0) < extra: + raise InsufficientCredits(extra, int(wallet.get("available_credits") or 0)) + _memory_wallets[account_id]["available_credits"] -= extra + else: + wallet = self.wallets.find_one_and_update( + {"billing_account_id": account_id, "available_credits": {"$gte": extra}}, + {"$inc": {"available_credits": -extra}, "$set": {"updated_at": now}}, + return_document=True, + ) + if not wallet: + current = self.get_wallet(account_id) + raise InsufficientCredits(extra, int(current.get("available_credits") or 0)) + + self._consume_lots(account_id, final_amount) + if self._in_memory: + _memory_wallets[account_id]["reserved_credits"] -= reserved + _memory_wallets[account_id]["available_credits"] += refund + _memory_wallets[account_id]["updated_at"] = now + _memory_reservations[job_id]["status"] = "committed" + _memory_reservations[job_id]["final_credits"] = final_amount + _memory_reservations[job_id]["updated_at"] = now + else: + self.wallets.update_one( + {"billing_account_id": account_id}, + { + "$inc": {"reserved_credits": -reserved, "available_credits": refund}, + "$set": {"updated_at": now}, + }, + ) + self.reservations.update_one( + {"job_id": job_id}, + { + "$set": { + "status": "committed", + "final_credits": final_amount, + "updated_at": now, + } + }, + ) + + entry = { + "id": uuid.uuid4().hex, + "billing_account_id": account_id, + "type": "debit", + "amount": -final_amount, + "job_id": job_id, + "idempotency_key": f"debit:{job_id}", + "metadata": metadata or {}, + "created_at": now, + } + self._insert_ledger(entry) + if refund: + self._insert_ledger( + { + "id": uuid.uuid4().hex, + "billing_account_id": account_id, + "type": "refund", + "amount": refund, + "job_id": job_id, + "idempotency_key": f"refund:{job_id}", + "metadata": {"reason": "unused_reservation"}, + "created_at": now, + } + ) + return entry + + def release_reservation( + self, + *, + account_id: str, + job_id: str, + metadata: Optional[dict[str, Any]] = None, + ) -> Optional[dict[str, Any]]: + reservation = self.get_reservation(job_id) + if not reservation or reservation.get("status") != "active": + return reservation + now = utc_now() + amount = int(reservation.get("reserved_credits") or 0) + if self._in_memory: + _memory_wallets[account_id]["available_credits"] += amount + _memory_wallets[account_id]["reserved_credits"] -= amount + _memory_wallets[account_id]["updated_at"] = now + _memory_reservations[job_id]["status"] = "released" + _memory_reservations[job_id]["updated_at"] = now + else: + self.wallets.update_one( + {"billing_account_id": account_id}, + { + "$inc": {"available_credits": amount, "reserved_credits": -amount}, + "$set": {"updated_at": now}, + }, + ) + self.reservations.update_one( + {"job_id": job_id}, {"$set": {"status": "released", "updated_at": now}} + ) + self._insert_ledger( + { + "id": uuid.uuid4().hex, + "billing_account_id": account_id, + "type": "release", + "amount": amount, + "job_id": job_id, + "idempotency_key": f"release:{job_id}:{reservation.get('version', 1)}", + "metadata": metadata or {}, + "created_at": now, + } + ) + return self.get_reservation(job_id) + + def _consume_lots(self, account_id: str, amount: int) -> None: + remaining = amount + now = utc_now() + lots = list(self.active_lots(account_id)) + for lot in lots: + if remaining <= 0: + break + take = min(remaining, int(lot.get("remaining_credits") or 0)) + if take <= 0: + continue + remaining -= take + if self._in_memory: + _memory_lots[lot["id"]]["remaining_credits"] -= take + _memory_lots[lot["id"]]["updated_at"] = now + else: + self.lots.update_one( + {"id": lot["id"]}, + {"$inc": {"remaining_credits": -take}, "$set": {"updated_at": now}}, + ) + if remaining > 0: + raise BillingStoreError( + f"Wallet had credits but credit lots were short by {remaining}" + ) + + def active_lots(self, account_id: str) -> Iterable[dict[str, Any]]: + now = utc_now() + if self._in_memory: + lots = [ + dict(lot) + for lot in _memory_lots.values() + if lot["billing_account_id"] == account_id + and int(lot.get("remaining_credits") or 0) > 0 + and not _is_expired(lot, now) + ] + return sorted(lots, key=lambda item: item.get("expires_at") or datetime.max.replace(tzinfo=timezone.utc)) + cursor = self.lots.find( + { + "billing_account_id": account_id, + "remaining_credits": {"$gt": 0}, + "$or": [{"expires_at": None}, {"expires_at": {"$gt": now}}], + } + ).sort("expires_at", 1) + return [_without_id(lot) or {} for lot in cursor] + + def find_ledger_by_key(self, idempotency_key: str) -> Optional[dict[str, Any]]: + if self._in_memory: + entry = _memory_ledger.get(idempotency_key) + return dict(entry) if entry else None + return _without_id(self.ledger.find_one({"idempotency_key": idempotency_key})) + + def list_ledger(self, account_id: str, limit: int = 100) -> list[dict[str, Any]]: + if self._in_memory: + entries = [ + dict(entry) + for entry in _memory_ledger.values() + if entry.get("billing_account_id") == account_id + ] + return sorted(entries, key=lambda item: item["created_at"], reverse=True)[:limit] + return [ + _without_id(entry) or {} + for entry in self.ledger.find({"billing_account_id": account_id}) + .sort("created_at", -1) + .limit(limit) + ] + + def record_usage_event(self, event: dict[str, Any]) -> None: + payload = {"id": uuid.uuid4().hex, "created_at": utc_now(), **event} + if self._in_memory: + _memory_usage_events.append(payload) + return + self.usage_events.insert_one(payload) + + def save_checkout(self, checkout_id: str, payload: dict[str, Any]) -> None: + now = utc_now() + doc = {"id": checkout_id, **payload, "updated_at": now} + if self._in_memory: + _memory_payments[checkout_id] = doc + return + self.payments.update_one({"id": checkout_id}, {"$set": doc, "$setOnInsert": {"created_at": now}}, upsert=True) + + def get_checkout(self, checkout_id: str) -> Optional[dict[str, Any]]: + if self._in_memory: + return dict(_memory_payments[checkout_id]) if checkout_id in _memory_payments else None + return _without_id(self.payments.find_one({"id": checkout_id})) + + def mark_payment_event(self, event_id: str, payload: dict[str, Any]) -> bool: + if not event_id: + event_id = uuid.uuid4().hex + now = utc_now() + if self._in_memory: + if event_id in _memory_payments: + return False + _memory_payments[event_id] = {"razorpay_event_id": event_id, **payload, "created_at": now} + return True + try: + self.payments.insert_one({"razorpay_event_id": event_id, **payload, "created_at": now}) + return True + except Exception as exc: + if exc.__class__.__name__ == "DuplicateKeyError": + return False + raise + + +_default_store: Optional[BillingStore] = None + + +def get_default_billing_store() -> BillingStore: + global _default_store + if _default_store is None: + _default_store = BillingStore() + return _default_store diff --git a/src/billing/types.py b/src/billing/types.py new file mode 100644 index 00000000..10f1394b --- /dev/null +++ b/src/billing/types.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +from datetime import datetime +from typing import Any, Literal, Optional + +from pydantic import BaseModel, Field + + +class PlanPublic(BaseModel): + id: str + name: str + price_paise: int + currency: str = "INR" + monthly_credits: int = 0 + trial_credits: int = 0 + trial_days: int = 0 + nominal_paise_per_credit: float = 0.0 + + +class TopUpPackPublic(BaseModel): + id: str + price_paise: int + currency: str = "INR" + credits: int + + +class CreditLotPublic(BaseModel): + id: str + source: str + remaining_credits: int + expires_at: Optional[datetime] = None + + +class BillingSummary(BaseModel): + billing_account_id: str + owner_type: str = "user" + owner_id: str + plan_id: str + plan_name: str + status: str + currency: str = "INR" + available_credits: int = 0 + reserved_credits: int = 0 + current_period_start: Optional[datetime] = None + current_period_end: Optional[datetime] = None + credit_lots: list[CreditLotPublic] = Field(default_factory=list) + + +class CreditEstimate(BaseModel): + job_type: str + content_tokens: int + multiplier: float + billable_credits: int + reserved_credits: int + + +class ReservationResult(BaseModel): + reservation_id: str + billing_account_id: str + job_id: str + reserved_credits: int + status: Literal["active", "committed", "released", "expired"] + available_credits: int + + +class CheckoutRequest(BaseModel): + package_id: str = Field(..., description="Plan ID or top-up pack ID") + + +class CheckoutResponse(BaseModel): + id: str + package_id: str + amount: int + currency: str + key_id: str + order_id: Optional[str] = None + subscription_id: Optional[str] = None + receipt: Optional[str] = None + + +class VerifyPaymentRequest(BaseModel): + package_id: str + razorpay_payment_id: str + razorpay_signature: str + razorpay_order_id: Optional[str] = None + razorpay_subscription_id: Optional[str] = None + + +class LedgerEntryPublic(BaseModel): + id: str + type: str + amount: int + idempotency_key: str + job_id: Optional[str] = None + source: Optional[str] = None + metadata: dict[str, Any] = Field(default_factory=dict) + created_at: datetime diff --git a/src/config/settings.py b/src/config/settings.py index 78630cee..d227259b 100644 --- a/src/config/settings.py +++ b/src/config/settings.py @@ -473,6 +473,10 @@ class Settings(BaseSettings): default=None, description="Optional Razorpay webhook signing secret" ) + razorpay_pro_plan_id: Optional[str] = Field( + default=None, + description="Razorpay subscription plan ID for the Pro plan", + ) @field_validator("fallback_order") @classmethod diff --git a/src/utils/billing.py b/src/utils/billing.py new file mode 100644 index 00000000..921c4739 --- /dev/null +++ b/src/utils/billing.py @@ -0,0 +1,72 @@ +"""Editable billing knobs for XMem. + +Change this file when plan pricing, monthly credits, top-up packs, or +workflow credit consumption rules need to change. The billing service imports +these values at runtime so the rest of the implementation can stay stable. +""" + +from __future__ import annotations + +import math +from typing import Any, Mapping + +PLANS: dict[str, dict[str, Any]] = { + "free": { + "name": "Free Trial", + "price_paise": 0, + "currency": "INR", + "trial_credits": 10_000, + "trial_days": 30, + "monthly_credits": 0, + }, + "pro": { + "name": "Pro", + "price_paise": 9_900, + "currency": "INR", + "monthly_credits": 5_000, + }, +} + +TOP_UP_PACKS: dict[str, dict[str, Any]] = { + "topup_99": {"price_paise": 9_900, "credits": 5_000, "currency": "INR"}, + "topup_199": {"price_paise": 19_900, "credits": 12_000, "currency": "INR"}, + "topup_499": {"price_paise": 49_900, "credits": 35_000, "currency": "INR"}, +} + +WORKFLOW_MULTIPLIERS: dict[str, float] = { + "memory_ingest_low": 1.0, + "memory_ingest_standard": 1.5, + "memory_ingest_high": 2.5, + "memory_batch_ingest": 1.5, + "memory_retrieve": 0.5, +} + +RESERVATION_BUFFER_MULTIPLIER = 1.25 +TOKEN_ESTIMATE_CHARS_PER_TOKEN = 4 +TOP_UP_EXPIRY_DAYS = 365 + + +def estimate_tokens(text: str) -> int: + """Approximate billable tokens when provider usage is not available.""" + if not text: + return 0 + return max(1, math.ceil(len(text) / TOKEN_ESTIMATE_CHARS_PER_TOKEN)) + + +def workflow_multiplier(job_type: str, payload: Mapping[str, Any]) -> float: + if job_type == "memory_ingest": + effort = str(payload.get("effort_level") or "low").strip().lower() + if effort == "high": + return WORKFLOW_MULTIPLIERS["memory_ingest_high"] + if effort in {"standard", "medium"}: + return WORKFLOW_MULTIPLIERS["memory_ingest_standard"] + return WORKFLOW_MULTIPLIERS["memory_ingest_low"] + return WORKFLOW_MULTIPLIERS.get(job_type, 1.0) + + +def nominal_paise_per_credit(plan_id: str) -> float: + plan = PLANS[plan_id] + credits = int(plan.get("monthly_credits") or plan.get("trial_credits") or 0) + if credits <= 0: + return 0.0 + return float(plan["price_paise"]) / credits diff --git a/tests/test_billing.py b/tests/test_billing.py new file mode 100644 index 00000000..919fcbf8 --- /dev/null +++ b/tests/test_billing.py @@ -0,0 +1,126 @@ +import os +from datetime import timedelta + +os.environ.setdefault("APP_STORE_PROVIDER", "memory") +os.environ.setdefault("FALLBACK_ORDER", '["ollama"]') +os.environ.setdefault("NEO4J_PASSWORD", "test") + +import pytest + +from src.billing import store as billing_store +from src.billing.service import BillingService +from src.billing.store import BillingStore, InsufficientCredits, utc_now +from src.utils import billing as billing_config + + +@pytest.fixture(autouse=True) +def clear_memory_billing(): + billing_store._memory_accounts.clear() + billing_store._memory_wallets.clear() + billing_store._memory_lots.clear() + billing_store._memory_ledger.clear() + billing_store._memory_reservations.clear() + billing_store._memory_usage_events.clear() + billing_store._memory_payments.clear() + + +def service() -> BillingService: + return BillingService(BillingStore()) + + +def test_token_estimate_uses_configured_chars_per_token(): + assert billing_config.estimate_tokens("a" * 9) == 3 + + +def test_pro_nominal_credit_value(): + assert billing_config.nominal_paise_per_credit("pro") == pytest.approx(1.98) + + +def test_free_trial_grant_is_idempotent(): + svc = service() + user = {"id": "user_1"} + + first = svc.ensure_billing_account(user) + second = svc.ensure_billing_account(user) + + assert first["id"] == second["id"] + summary = svc.get_billing_summary(user) + assert summary.available_credits == 10_000 + grants = [entry for entry in svc.list_ledger(user) if entry["type"] == "grant"] + assert len(grants) == 1 + + +def test_reservation_debit_and_refund_flow(): + svc = service() + user = {"id": "user_1"} + account = svc.ensure_billing_account(user) + + reservation = svc.reserve_credits(account["id"], "job_1", 1000) + assert reservation.available_credits == 9000 + + svc.commit_job_debit(account["id"], "job_1", 750) + summary = svc.get_billing_summary(user) + + assert summary.available_credits == 9250 + assert summary.reserved_credits == 0 + ledger_types = {entry["type"] for entry in svc.list_ledger(user)} + assert {"reserve", "debit", "refund"}.issubset(ledger_types) + + +def test_failed_job_releases_reserved_credits(): + svc = service() + user = {"id": "user_1"} + account = svc.ensure_billing_account(user) + + svc.reserve_credits(account["id"], "job_1", 1000) + svc.release_job_reservation(account["id"], "job_1") + + summary = svc.get_billing_summary(user) + assert summary.available_credits == 10_000 + assert summary.reserved_credits == 0 + + +def test_insufficient_credits_blocks_reservation(): + svc = service() + user = {"id": "user_1"} + account = svc.ensure_billing_account(user) + + with pytest.raises(InsufficientCredits): + svc.reserve_credits(account["id"], "job_1", 10_001) + + +def test_pro_grant_is_idempotent_per_payment(): + svc = service() + + svc.grant_pro_subscription( + user_id="user_1", + payment_id="pay_1", + subscription_id="sub_1", + period_end=utc_now() + timedelta(days=30), + ) + svc.grant_pro_subscription( + user_id="user_1", + payment_id="pay_1", + subscription_id="sub_1", + period_end=utc_now() + timedelta(days=30), + ) + + summary = svc.get_billing_summary({"id": "user_1"}) + assert summary.available_credits == 15_000 + pro_grants = [ + entry + for entry in svc.list_ledger({"id": "user_1"}) + if entry["source"] == "pro_monthly" + ] + assert len(pro_grants) == 1 + + +def test_billing_config_changes_affect_estimates(monkeypatch): + svc = service() + payload = {"user_query": "a" * 400, "agent_response": "", "effort_level": "low"} + baseline = svc.estimate_required_credits("memory_ingest", payload) + + monkeypatch.setitem(billing_config.WORKFLOW_MULTIPLIERS, "memory_ingest_low", 2.0) + changed = svc.estimate_required_credits("memory_ingest", payload) + + assert changed.billable_credits == baseline.billable_credits * 2 From 2dcc9fea78806b2299521f22d7eb3bd0addf56e9 Mon Sep 17 00:00:00 2001 From: Ishaan Gupta Date: Mon, 1 Jun 2026 12:47:51 +0530 Subject: [PATCH 02/20] Address billing review findings --- src/api/routes/billing.py | 67 ++++++--- src/billing/__init__.py | 2 +- src/billing/service.py | 2 +- src/billing/store.py | 291 +++++++++++++++++++++++++++++--------- tests/test_billing.py | 28 +++- 5 files changed, 302 insertions(+), 88 deletions(-) diff --git a/src/api/routes/billing.py b/src/api/routes/billing.py index 6b54230f..06712638 100644 --- a/src/api/routes/billing.py +++ b/src/api/routes/billing.py @@ -2,10 +2,11 @@ from __future__ import annotations +import asyncio import json import logging from datetime import datetime, timezone -from typing import Any, Optional +from typing import Any import httpx from fastapi import APIRouter, Depends, HTTPException, Query, Request, status @@ -89,8 +90,9 @@ async def list_billing_plans() -> list[PlanPublic]: @router.get("/summary", response_model=BillingSummaryResponse) async def billing_summary(current_user: dict = Depends(require_auth)) -> BillingSummaryResponse: service = get_default_billing_service() + summary = await asyncio.to_thread(service.get_billing_summary, current_user) return BillingSummaryResponse( - summary=service.get_billing_summary(current_user), + summary=summary, plans=public_plans(), topups=public_topups(), ) @@ -109,7 +111,7 @@ async def create_razorpay_checkout( user_id = _user_id(current_user) package_type, package = _pack_or_plan(request.package_id) service = get_default_billing_service() - account = service.ensure_billing_account(current_user) + account = await asyncio.to_thread(service.ensure_billing_account, current_user) if request.package_id == "free": raise HTTPException(status_code=400, detail="Free plan does not require checkout") @@ -129,7 +131,8 @@ async def create_razorpay_checkout( notes=notes, ) checkout_id = str(subscription["id"]) - service.store.save_checkout( + await asyncio.to_thread( + service.store.save_checkout, checkout_id, { "type": "subscription", @@ -161,7 +164,8 @@ async def create_razorpay_checkout( raise HTTPException(status_code=502, detail="Razorpay checkout creation failed") from exc order_id = str(order["id"]) - service.store.save_checkout( + await asyncio.to_thread( + service.store.save_checkout, order_id, { "type": package_type, @@ -204,13 +208,24 @@ async def verify_razorpay_payment( user_id = _user_id(current_user) if request.razorpay_subscription_id: + checkout = await asyncio.to_thread( + service.store.get_checkout, + request.razorpay_subscription_id, + ) + if not checkout: + raise HTTPException(status_code=400, detail="Unknown Razorpay subscription checkout") + if checkout.get("user_id") != user_id: + raise HTTPException(status_code=403, detail="Payment subscription does not belong to this user") + if checkout.get("package_id") != "pro": + raise HTTPException(status_code=400, detail="Payment subscription package mismatch") if not verify_subscription_signature( request.razorpay_subscription_id, request.razorpay_payment_id, request.razorpay_signature, ): raise HTTPException(status_code=400, detail="Invalid Razorpay signature") - service.grant_pro_subscription( + await asyncio.to_thread( + service.grant_pro_subscription, user_id=user_id, payment_id=request.razorpay_payment_id, subscription_id=request.razorpay_subscription_id, @@ -222,18 +237,27 @@ async def verify_razorpay_payment( request.razorpay_signature, ): raise HTTPException(status_code=400, detail="Invalid Razorpay signature") - checkout = service.store.get_checkout(request.razorpay_order_id) - if checkout and checkout.get("user_id") != user_id: + checkout = await asyncio.to_thread( + service.store.get_checkout, + request.razorpay_order_id, + ) + if not checkout: + raise HTTPException(status_code=400, detail="Unknown Razorpay payment order") + if checkout.get("user_id") != user_id: raise HTTPException(status_code=403, detail="Payment order does not belong to this user") - package_id = str((checkout or {}).get("package_id") or request.package_id) + if checkout.get("package_id") != request.package_id: + raise HTTPException(status_code=400, detail="Payment order package mismatch") + package_id = str(checkout.get("package_id")) if package_id == "pro": - service.grant_pro_subscription( + await asyncio.to_thread( + service.grant_pro_subscription, user_id=user_id, payment_id=request.razorpay_payment_id, subscription_id=request.razorpay_order_id, ) else: - service.grant_topup( + await asyncio.to_thread( + service.grant_topup, user_id=user_id, pack_id=package_id, payment_id=request.razorpay_payment_id, @@ -242,7 +266,8 @@ async def verify_razorpay_payment( else: raise HTTPException(status_code=400, detail="Missing Razorpay order or subscription id") - return VerifyPaymentResponse(summary=service.get_billing_summary(current_user)) + summary = await asyncio.to_thread(service.get_billing_summary, current_user) + return VerifyPaymentResponse(summary=summary) @router.post("/razorpay/webhook") @@ -255,7 +280,11 @@ async def razorpay_webhook(request: Request) -> dict[str, str]: except RazorpayConfigError as exc: raise HTTPException(status_code=503, detail=str(exc)) from exc - payload = json.loads(body.decode("utf-8")) + try: + payload = json.loads(body.decode("utf-8")) + except (json.JSONDecodeError, UnicodeDecodeError) as exc: + logger.warning("Razorpay webhook body is not valid JSON: %s", exc) + raise HTTPException(status_code=400, detail="Webhook body must be valid JSON") from exc event_id = str( request.headers.get("x-razorpay-event-id") or payload.get("id") @@ -263,7 +292,8 @@ async def razorpay_webhook(request: Request) -> dict[str, str]: ) event_name = str(payload.get("event") or "") service = get_default_billing_service() - first_seen = service.store.mark_payment_event( + first_seen = await asyncio.to_thread( + service.store.mark_payment_event, event_id, {"event": event_name, "payload": payload}, ) @@ -286,13 +316,15 @@ async def razorpay_webhook(request: Request) -> dict[str, str]: if event_name in {"payment.captured", "order.paid", "subscription.charged"}: if package_id == "pro": - service.grant_pro_subscription( + await asyncio.to_thread( + service.grant_pro_subscription, user_id=user_id, payment_id=payment_id or event_id, subscription_id=subscription_id or order_id or event_id, ) elif package_id in billing_config.TOP_UP_PACKS: - service.grant_topup( + await asyncio.to_thread( + service.grant_topup, user_id=user_id, pack_id=package_id, payment_id=payment_id or event_id, @@ -308,6 +340,7 @@ async def billing_ledger( limit: int = Query(default=100, ge=1, le=500), ) -> list[LedgerEntryPublic]: service = get_default_billing_service() + entries = await asyncio.to_thread(service.list_ledger, current_user, limit) return [ LedgerEntryPublic( id=str(entry["id"]), @@ -319,5 +352,5 @@ async def billing_ledger( metadata=entry.get("metadata") or {}, created_at=entry["created_at"], ) - for entry in service.list_ledger(current_user, limit=limit) + for entry in entries ] diff --git a/src/billing/__init__.py b/src/billing/__init__.py index 17ed2996..64cb4935 100644 --- a/src/billing/__init__.py +++ b/src/billing/__init__.py @@ -1,7 +1,7 @@ """Billing and credit ledger package.""" +from .store import InsufficientCredits from .service import ( - InsufficientCredits, commit_job_billing, commit_job_debit, ensure_billing_account, diff --git a/src/billing/service.py b/src/billing/service.py index c90693e6..569c195d 100644 --- a/src/billing/service.py +++ b/src/billing/service.py @@ -5,7 +5,7 @@ from typing import Any, Mapping, Optional from src.billing.metering import estimate_required_credits as _estimate_required_credits -from src.billing.store import BillingStore, InsufficientCredits, get_default_billing_store, utc_now +from src.billing.store import BillingStore, get_default_billing_store, utc_now from src.billing.types import BillingSummary, CreditEstimate, CreditLotPublic, PlanPublic, ReservationResult, TopUpPackPublic from src.utils import billing as billing_config diff --git a/src/billing/store.py b/src/billing/store.py index ed572809..e67aff74 100644 --- a/src/billing/store.py +++ b/src/billing/store.py @@ -3,7 +3,7 @@ import logging import uuid from datetime import datetime, timezone -from typing import Any, Dict, Iterable, Optional +from typing import Any, Iterable, Optional from src.config import settings @@ -103,13 +103,18 @@ def _try_connect(self) -> None: self.usage_events = self._db["usage_events"] self.payments = self._db["billing_payments"] + self.accounts.create_index([("id", ASCENDING)], unique=True) self.accounts.create_index([("owner_type", ASCENDING), ("owner_id", ASCENDING)], unique=True) self.accounts.create_index([("razorpay_subscription_id", ASCENDING)]) self.wallets.create_index([("billing_account_id", ASCENDING)], unique=True) + self.lots.create_index([("id", ASCENDING)], unique=True) self.lots.create_index([("billing_account_id", ASCENDING), ("expires_at", ASCENDING)]) + self.ledger.create_index([("id", ASCENDING)], unique=True) self.ledger.create_index([("idempotency_key", ASCENDING)], unique=True) self.ledger.create_index([("billing_account_id", ASCENDING), ("created_at", ASCENDING)]) + self.reservations.create_index([("id", ASCENDING)], unique=True) self.reservations.create_index([("job_id", ASCENDING)], unique=True) + self.payments.create_index([("id", ASCENDING)], unique=True, sparse=True) self.payments.create_index([("razorpay_event_id", ASCENDING)], unique=True, sparse=True) self.payments.create_index([("razorpay_payment_id", ASCENDING)], unique=True, sparse=True) @@ -314,10 +319,18 @@ def reserve_credits( raise ValueError("Reservation amount must be positive") existing = self.get_reservation(job_id) if existing and existing.get("status") in {"active", "committed"}: + if existing.get("billing_account_id") != account_id: + raise BillingStoreError( + f"Reservation for job {job_id} belongs to a different billing account" + ) return existing now = utc_now() if self._in_memory: + if existing and existing.get("billing_account_id") != account_id: + raise BillingStoreError( + f"Reservation for job {job_id} belongs to a different billing account" + ) wallet = self.get_wallet(account_id) if int(wallet.get("available_credits") or 0) < amount: raise InsufficientCredits(amount, int(wallet.get("available_credits") or 0)) @@ -350,6 +363,74 @@ def reserve_credits( ) return dict(reservation) + from pymongo import ReturnDocument + + if existing: + if existing.get("billing_account_id") != account_id: + raise BillingStoreError( + f"Reservation for job {job_id} belongs to a different billing account" + ) + reserved_doc = self.reservations.find_one_and_update( + { + "job_id": job_id, + "billing_account_id": account_id, + "status": {"$nin": ["active", "committed", "reserving"]}, + }, + { + "$set": { + "status": "reserving", + "reserved_credits": amount, + "metadata": metadata or {}, + "updated_at": now, + }, + "$inc": {"version": 1}, + }, + return_document=ReturnDocument.AFTER, + ) + if not reserved_doc: + current = self.get_reservation(job_id) + if ( + current + and current.get("billing_account_id") == account_id + and current.get("status") != "reserving" + ): + return current + raise BillingStoreError(f"Reservation for job {job_id} is already active") + reservation = _without_id(reserved_doc) or {} + version = int(reservation.get("version") or 1) + else: + reservation = { + "id": uuid.uuid4().hex, + "billing_account_id": account_id, + "job_id": job_id, + "reserved_credits": amount, + "status": "reserving", + "version": 1, + "metadata": metadata or {}, + "created_at": now, + "updated_at": now, + } + try: + self.reservations.insert_one(reservation) + except Exception as exc: + if exc.__class__.__name__ != "DuplicateKeyError": + raise + current = self.get_reservation(job_id) + if ( + current + and current.get("billing_account_id") == account_id + and current.get("status") != "reserving" + ): + return current + if current and current.get("billing_account_id") == account_id: + raise BillingStoreError( + f"Reservation for job {job_id} is already being created" + ) from exc + raise BillingStoreError( + f"Reservation for job {job_id} belongs to a different billing account" + ) from exc + version = 1 + wallet = self.wallets.find_one_and_update( {"billing_account_id": account_id, "available_credits": {"$gte": amount}}, { @@ -359,22 +440,18 @@ def reserve_credits( return_document=True, ) if not wallet: + self.reservations.update_one( + {"job_id": job_id, "status": "reserving"}, + {"$set": {"status": "released", "updated_at": now}}, + ) current = self.get_wallet(account_id) raise InsufficientCredits(amount, int(current.get("available_credits") or 0)) - version = int((existing or {}).get("version") or 0) + 1 - reservation = { - "id": (existing or {}).get("id") or uuid.uuid4().hex, - "billing_account_id": account_id, - "job_id": job_id, - "reserved_credits": amount, - "status": "active", - "version": version, - "metadata": metadata or {}, - "created_at": (existing or {}).get("created_at") or now, - "updated_at": now, - } - self.reservations.update_one({"job_id": job_id}, {"$set": reservation}, upsert=True) + self.reservations.update_one( + {"job_id": job_id, "billing_account_id": account_id, "status": "reserving"}, + {"$set": {"status": "active", "updated_at": now}}, + ) + reservation = self.get_reservation(job_id) or reservation self._insert_ledger( { "id": uuid.uuid4().hex, @@ -408,28 +485,21 @@ def commit_debit( duplicate = self.find_ledger_by_key(f"debit:{job_id}") if duplicate: return duplicate - reservation = self.get_reservation(job_id) - if not reservation: - raise BillingStoreError(f"No credit reservation exists for job {job_id}") - if reservation.get("status") == "committed": - existing = self.find_ledger_by_key(f"debit:{job_id}") - if existing: - return existing - raise BillingStoreError(f"Reservation for job {job_id} is already committed") - if reservation.get("status") != "active": - raise BillingStoreError(f"Reservation for job {job_id} is not active") - now = utc_now() + reservation = self._claim_reservation_for_commit(account_id, job_id, final_amount, now) + if reservation.get("type") == "debit": + return reservation reserved = int(reservation.get("reserved_credits") or 0) extra = max(final_amount - reserved, 0) refund = max(reserved - final_amount, 0) - if extra: - if self._in_memory: + + try: + if extra and self._in_memory: wallet = self.get_wallet(account_id) if int(wallet.get("available_credits") or 0) < extra: raise InsufficientCredits(extra, int(wallet.get("available_credits") or 0)) _memory_wallets[account_id]["available_credits"] -= extra - else: + elif extra: wallet = self.wallets.find_one_and_update( {"billing_account_id": account_id, "available_credits": {"$gte": extra}}, {"$inc": {"available_credits": -extra}, "$set": {"updated_at": now}}, @@ -439,32 +509,35 @@ def commit_debit( current = self.get_wallet(account_id) raise InsufficientCredits(extra, int(current.get("available_credits") or 0)) - self._consume_lots(account_id, final_amount) - if self._in_memory: - _memory_wallets[account_id]["reserved_credits"] -= reserved - _memory_wallets[account_id]["available_credits"] += refund - _memory_wallets[account_id]["updated_at"] = now - _memory_reservations[job_id]["status"] = "committed" - _memory_reservations[job_id]["final_credits"] = final_amount - _memory_reservations[job_id]["updated_at"] = now - else: - self.wallets.update_one( - {"billing_account_id": account_id}, - { - "$inc": {"reserved_credits": -reserved, "available_credits": refund}, - "$set": {"updated_at": now}, - }, - ) - self.reservations.update_one( - {"job_id": job_id}, - { - "$set": { - "status": "committed", - "final_credits": final_amount, - "updated_at": now, - } - }, - ) + self._consume_lots(account_id, final_amount) + if self._in_memory: + _memory_wallets[account_id]["reserved_credits"] -= reserved + _memory_wallets[account_id]["available_credits"] += refund + _memory_wallets[account_id]["updated_at"] = now + _memory_reservations[job_id]["status"] = "committed" + _memory_reservations[job_id]["final_credits"] = final_amount + _memory_reservations[job_id]["updated_at"] = now + else: + self.wallets.update_one( + {"billing_account_id": account_id}, + { + "$inc": {"reserved_credits": -reserved, "available_credits": refund}, + "$set": {"updated_at": now}, + }, + ) + self.reservations.update_one( + {"job_id": job_id, "billing_account_id": account_id}, + { + "$set": { + "status": "committed", + "final_credits": final_amount, + "updated_at": now, + } + }, + ) + except Exception: + self._release_commit_claim(account_id, job_id) + raise entry = { "id": uuid.uuid4().hex, @@ -492,6 +565,75 @@ def commit_debit( ) return entry + def _claim_reservation_for_commit( + self, + account_id: str, + job_id: str, + final_amount: int, + now: datetime, + ) -> dict[str, Any]: + if self._in_memory: + reservation = self.get_reservation(job_id) + if not reservation: + raise BillingStoreError(f"No credit reservation exists for job {job_id}") + if reservation.get("billing_account_id") != account_id: + raise BillingStoreError( + f"Reservation for job {job_id} belongs to a different billing account" + ) + if reservation.get("status") == "committed": + existing = self.find_ledger_by_key(f"debit:{job_id}") + if existing: + return existing + raise BillingStoreError(f"Reservation for job {job_id} is already committed") + if reservation.get("status") != "active": + raise BillingStoreError(f"Reservation for job {job_id} is not active") + _memory_reservations[job_id]["status"] = "committing" + _memory_reservations[job_id]["final_credits"] = final_amount + _memory_reservations[job_id]["updated_at"] = now + return reservation + + from pymongo import ReturnDocument + + reservation = self.reservations.find_one_and_update( + {"job_id": job_id, "billing_account_id": account_id, "status": "active"}, + { + "$set": { + "status": "committing", + "final_credits": final_amount, + "updated_at": now, + } + }, + return_document=ReturnDocument.BEFORE, + ) + if reservation: + return _without_id(reservation) or {} + + existing = self.find_ledger_by_key(f"debit:{job_id}") + if existing: + return existing + current = self.get_reservation(job_id) + if current and current.get("billing_account_id") != account_id: + raise BillingStoreError( + f"Reservation for job {job_id} belongs to a different billing account" + ) + raise BillingStoreError(f"Reservation for job {job_id} is not active") + + def _release_commit_claim(self, account_id: str, job_id: str) -> None: + if self._in_memory: + reservation = _memory_reservations.get(job_id) + if reservation and reservation.get("billing_account_id") == account_id: + reservation["status"] = "active" + reservation.pop("final_credits", None) + reservation["updated_at"] = utc_now() + return + self.reservations.update_one( + {"job_id": job_id, "billing_account_id": account_id, "status": "committing"}, + { + "$set": {"status": "active", "updated_at": utc_now()}, + "$unset": {"final_credits": ""}, + }, + ) + def release_reservation( self, *, @@ -538,22 +680,35 @@ def release_reservation( def _consume_lots(self, account_id: str, amount: int) -> None: remaining = amount now = utc_now() - lots = list(self.active_lots(account_id)) - for lot in lots: - if remaining <= 0: + while remaining > 0: + lots = list(self.active_lots(account_id)) + if not lots: break - take = min(remaining, int(lot.get("remaining_credits") or 0)) - if take <= 0: - continue - remaining -= take - if self._in_memory: - _memory_lots[lot["id"]]["remaining_credits"] -= take - _memory_lots[lot["id"]]["updated_at"] = now - else: - self.lots.update_one( - {"id": lot["id"]}, - {"$inc": {"remaining_credits": -take}, "$set": {"updated_at": now}}, + progressed = False + for lot in lots: + if remaining <= 0: + break + take = min(remaining, int(lot.get("remaining_credits") or 0)) + if take <= 0: + continue + if self._in_memory: + _memory_lots[lot["id"]]["remaining_credits"] -= take + _memory_lots[lot["id"]]["updated_at"] = now + remaining -= take + progressed = True + continue + result = self.lots.update_one( + {"id": lot["id"], "remaining_credits": {"$gte": take}}, + { + "$inc": {"remaining_credits": -take}, + "$set": {"updated_at": now}, + }, ) + if getattr(result, "modified_count", 0) == 1: + remaining -= take + progressed = True + if not progressed: + break if remaining > 0: raise BillingStoreError( f"Wallet had credits but credit lots were short by {remaining}" diff --git a/tests/test_billing.py b/tests/test_billing.py index 919fcbf8..106bb9e9 100644 --- a/tests/test_billing.py +++ b/tests/test_billing.py @@ -9,7 +9,7 @@ from src.billing import store as billing_store from src.billing.service import BillingService -from src.billing.store import BillingStore, InsufficientCredits, utc_now +from src.billing.store import BillingStore, BillingStoreError, InsufficientCredits, utc_now from src.utils import billing as billing_config @@ -89,6 +89,32 @@ def test_insufficient_credits_blocks_reservation(): svc.reserve_credits(account["id"], "job_1", 10_001) +def test_reservation_cannot_be_reused_by_another_account(): + svc = service() + account_1 = svc.ensure_billing_account({"id": "user_1"}) + account_2 = svc.ensure_billing_account({"id": "user_2"}) + + svc.reserve_credits(account_1["id"], "job_1", 100) + + with pytest.raises(BillingStoreError): + svc.reserve_credits(account_2["id"], "job_1", 100) + + +def test_duplicate_commit_does_not_double_debit(): + svc = service() + user = {"id": "user_1"} + account = svc.ensure_billing_account(user) + svc.reserve_credits(account["id"], "job_1", 1000) + + svc.commit_job_debit(account["id"], "job_1", 750) + svc.commit_job_debit(account["id"], "job_1", 750) + + summary = svc.get_billing_summary(user) + debits = [entry for entry in svc.list_ledger(user) if entry["type"] == "debit"] + assert summary.available_credits == 9250 + assert len(debits) == 1 + + def test_pro_grant_is_idempotent_per_payment(): svc = service() From 0aaa7a24020f18f479576f24f6bc2fbefff51335 Mon Sep 17 00:00:00 2001 From: Ishaan Gupta Date: Mon, 1 Jun 2026 12:51:27 +0530 Subject: [PATCH 03/20] Fix billing imports for CI --- CHANGELOG.md | 1 + src/api/routes/v2/jobs.py | 2 +- src/api/routes/v2/memory.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index af837a3a..0e069a92 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ## Unreleased +- Add modular Razorpay billing, credit wallets, ledger reservations, and v2 memory workflow metering. - Add durable Temporal-backed v2 memory and scanner workflow APIs with job status, retry, cancel, and dead-letter endpoints. - Add modular LoCoMo and BEAM benchmark runners for the Python XMem API. - Add local XMem setup through `npx create-xmem@latest` and `npm run dev`. diff --git a/src/api/routes/v2/jobs.py b/src/api/routes/v2/jobs.py index d79a8367..668833f3 100644 --- a/src/api/routes/v2/jobs.py +++ b/src/api/routes/v2/jobs.py @@ -17,7 +17,7 @@ ) from src.api.routes.v2.temporal_client import cancel_job_workflow, start_job_workflow from src.api.schemas import APIResponse -from src.billing.service import InsufficientCredits, get_default_billing_service, release_job_billing +from src.billing import InsufficientCredits, get_default_billing_service, release_job_billing from src.jobs.durable import DEAD_LETTER, QUEUED, RUNNING, get_default_job_store router = APIRouter( diff --git a/src/api/routes/v2/memory.py b/src/api/routes/v2/memory.py index 56e063b6..9468576b 100644 --- a/src/api/routes/v2/memory.py +++ b/src/api/routes/v2/memory.py @@ -19,7 +19,7 @@ ) from src.api.routes.v2.temporal_client import start_job_workflow from src.api.schemas import APIResponse, BatchIngestRequest, IngestRequest, ScrapeRequest, StatusEnum -from src.billing.service import InsufficientCredits, get_default_billing_service +from src.billing import InsufficientCredits, get_default_billing_service from src.config import settings from src.jobs.durable import QUEUED, get_default_job_store, idempotency_key, new_attempt_id, stable_hash From f22c59f1a6554e8ce434303f8b280b290a9e0627 Mon Sep 17 00:00:00 2001 From: Ishaan Gupta Date: Mon, 1 Jun 2026 12:55:32 +0530 Subject: [PATCH 04/20] Harden billing webhook and reservation idempotency --- src/api/routes/billing.py | 2 ++ src/billing/store.py | 64 ++++++++++++++++++++++++++++----------- tests/test_billing.py | 36 +++++++++++++++++++++- 3 files changed, 83 insertions(+), 19 deletions(-) diff --git a/src/api/routes/billing.py b/src/api/routes/billing.py index 06712638..7ce7c2af 100644 --- a/src/api/routes/billing.py +++ b/src/api/routes/billing.py @@ -290,6 +290,8 @@ async def razorpay_webhook(request: Request) -> dict[str, str]: or payload.get("id") or "" ) + if not event_id: + raise HTTPException(status_code=400, detail="Webhook event id is required") event_name = str(payload.get("event") or "") service = get_default_billing_service() first_seen = await asyncio.to_thread( diff --git a/src/billing/store.py b/src/billing/store.py index e67aff74..5204d7b5 100644 --- a/src/billing/store.py +++ b/src/billing/store.py @@ -7,6 +7,11 @@ from src.config import settings +try: + from pymongo.errors import DuplicateKeyError +except Exception: # pragma: no cover - pymongo may be absent in memory-only dev/test. + DuplicateKeyError = None # type: ignore[assignment] + logger = logging.getLogger("xmem.billing.store") _memory_accounts: dict[str, dict[str, Any]] = {} @@ -15,7 +20,9 @@ _memory_ledger: dict[str, dict[str, Any]] = {} _memory_reservations: dict[str, dict[str, Any]] = {} _memory_usage_events: list[dict[str, Any]] = [] -_memory_payments: dict[str, dict[str, Any]] = {} +_memory_checkouts: dict[str, dict[str, Any]] = {} +_memory_payment_events: dict[str, dict[str, Any]] = {} +_memory_payments = _memory_checkouts class BillingStoreError(RuntimeError): @@ -53,6 +60,10 @@ def _is_expired(doc: dict[str, Any], now: Optional[datetime] = None) -> bool: return expires_at <= now +def _is_duplicate_key_error(exc: Exception) -> bool: + return DuplicateKeyError is not None and isinstance(exc, DuplicateKeyError) + + class BillingStore: """Mongo-backed credit ledger with in-memory fallback for local development.""" @@ -247,7 +258,7 @@ def _insert_ledger(self, entry: dict[str, Any]) -> Optional[dict[str, Any]]: self.ledger.insert_one(entry) return None except Exception as exc: - if exc.__class__.__name__ != "DuplicateKeyError": + if not _is_duplicate_key_error(exc): raise return _without_id(self.ledger.find_one({"idempotency_key": entry["idempotency_key"]})) @@ -413,7 +424,7 @@ def reserve_credits( try: self.reservations.insert_one(reservation) except Exception as exc: - if exc.__class__.__name__ != "DuplicateKeyError": + if not _is_duplicate_key_error(exc): raise current = self.get_reservation(job_id) if ( @@ -641,18 +652,38 @@ def release_reservation( job_id: str, metadata: Optional[dict[str, Any]] = None, ) -> Optional[dict[str, Any]]: - reservation = self.get_reservation(job_id) - if not reservation or reservation.get("status") != "active": - return reservation now = utc_now() - amount = int(reservation.get("reserved_credits") or 0) if self._in_memory: + reservation = _memory_reservations.get(job_id) + if not reservation or reservation.get("status") != "active": + return dict(reservation) if reservation else None + if reservation.get("billing_account_id") != account_id: + raise BillingStoreError( + f"Reservation for job {job_id} belongs to a different billing account" + ) + reservation["status"] = "released" + reservation["updated_at"] = now + amount = int(reservation.get("reserved_credits") or 0) _memory_wallets[account_id]["available_credits"] += amount _memory_wallets[account_id]["reserved_credits"] -= amount _memory_wallets[account_id]["updated_at"] = now - _memory_reservations[job_id]["status"] = "released" - _memory_reservations[job_id]["updated_at"] = now else: + from pymongo import ReturnDocument + + reservation_doc = self.reservations.find_one_and_update( + {"job_id": job_id, "billing_account_id": account_id, "status": "active"}, + {"$set": {"status": "released", "updated_at": now}}, + return_document=ReturnDocument.BEFORE, + ) + if not reservation_doc: + current = self.get_reservation(job_id) + if current and current.get("billing_account_id") != account_id: + raise BillingStoreError( + f"Reservation for job {job_id} belongs to a different billing account" + ) + return current + reservation = _without_id(reservation_doc) or {} + amount = int(reservation.get("reserved_credits") or 0) self.wallets.update_one( {"billing_account_id": account_id}, { @@ -660,9 +691,6 @@ def release_reservation( "$set": {"updated_at": now}, }, ) - self.reservations.update_one( - {"job_id": job_id}, {"$set": {"status": "released", "updated_at": now}} - ) self._insert_ledger( { "id": uuid.uuid4().hex, @@ -766,29 +794,29 @@ def save_checkout(self, checkout_id: str, payload: dict[str, Any]) -> None: now = utc_now() doc = {"id": checkout_id, **payload, "updated_at": now} if self._in_memory: - _memory_payments[checkout_id] = doc + _memory_checkouts[checkout_id] = doc return self.payments.update_one({"id": checkout_id}, {"$set": doc, "$setOnInsert": {"created_at": now}}, upsert=True) def get_checkout(self, checkout_id: str) -> Optional[dict[str, Any]]: if self._in_memory: - return dict(_memory_payments[checkout_id]) if checkout_id in _memory_payments else None + return dict(_memory_checkouts[checkout_id]) if checkout_id in _memory_checkouts else None return _without_id(self.payments.find_one({"id": checkout_id})) def mark_payment_event(self, event_id: str, payload: dict[str, Any]) -> bool: if not event_id: - event_id = uuid.uuid4().hex + raise ValueError("Razorpay webhook event id is required") now = utc_now() if self._in_memory: - if event_id in _memory_payments: + if event_id in _memory_payment_events: return False - _memory_payments[event_id] = {"razorpay_event_id": event_id, **payload, "created_at": now} + _memory_payment_events[event_id] = {"razorpay_event_id": event_id, **payload, "created_at": now} return True try: self.payments.insert_one({"razorpay_event_id": event_id, **payload, "created_at": now}) return True except Exception as exc: - if exc.__class__.__name__ == "DuplicateKeyError": + if _is_duplicate_key_error(exc): return False raise diff --git a/tests/test_billing.py b/tests/test_billing.py index 106bb9e9..8884452e 100644 --- a/tests/test_billing.py +++ b/tests/test_billing.py @@ -21,7 +21,8 @@ def clear_memory_billing(): billing_store._memory_ledger.clear() billing_store._memory_reservations.clear() billing_store._memory_usage_events.clear() - billing_store._memory_payments.clear() + billing_store._memory_checkouts.clear() + billing_store._memory_payment_events.clear() def service() -> BillingService: @@ -80,6 +81,22 @@ def test_failed_job_releases_reserved_credits(): assert summary.reserved_credits == 0 +def test_duplicate_release_does_not_double_refund(): + svc = service() + user = {"id": "user_1"} + account = svc.ensure_billing_account(user) + + svc.reserve_credits(account["id"], "job_1", 1000) + svc.release_job_reservation(account["id"], "job_1") + svc.release_job_reservation(account["id"], "job_1") + + summary = svc.get_billing_summary(user) + releases = [entry for entry in svc.list_ledger(user) if entry["type"] == "release"] + assert summary.available_credits == 10_000 + assert summary.reserved_credits == 0 + assert len(releases) == 1 + + def test_insufficient_credits_blocks_reservation(): svc = service() user = {"id": "user_1"} @@ -150,3 +167,20 @@ def test_billing_config_changes_affect_estimates(monkeypatch): changed = svc.estimate_required_credits("memory_ingest", payload) assert changed.billable_credits == baseline.billable_credits * 2 + + +def test_missing_payment_event_id_is_rejected(): + store = BillingStore() + + with pytest.raises(ValueError): + store.mark_payment_event("", {"event": "payment.captured"}) + + +def test_in_memory_checkout_and_webhook_events_are_isolated(): + store = BillingStore() + + store.save_checkout("same_id", {"user_id": "user_1", "package_id": "topup_99"}) + + assert store.mark_payment_event("same_id", {"event": "payment.captured"}) + assert not store.mark_payment_event("same_id", {"event": "payment.captured"}) + assert store.get_checkout("same_id")["package_id"] == "topup_99" From bfc5254e7b8303eebb0511470adc1ac53a21447b Mon Sep 17 00:00:00 2001 From: Ishaan Gupta Date: Mon, 1 Jun 2026 13:12:47 +0530 Subject: [PATCH 05/20] Remove unused billing memory alias --- src/billing/store.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/billing/store.py b/src/billing/store.py index 5204d7b5..397b25ae 100644 --- a/src/billing/store.py +++ b/src/billing/store.py @@ -22,7 +22,6 @@ _memory_usage_events: list[dict[str, Any]] = [] _memory_checkouts: dict[str, dict[str, Any]] = {} _memory_payment_events: dict[str, dict[str, Any]] = {} -_memory_payments = _memory_checkouts class BillingStoreError(RuntimeError): From fd1acfd2b8bd18e72cacdcd7b7567d64d8b20ca6 Mon Sep 17 00:00:00 2001 From: Ishaan Gupta Date: Mon, 1 Jun 2026 13:17:00 +0530 Subject: [PATCH 06/20] Avoid releasing reused billing reservations --- src/api/routes/v2/memory.py | 12 ++++++++---- src/billing/service.py | 1 + src/billing/store.py | 10 +++++----- src/billing/types.py | 1 + tests/test_billing.py | 15 +++++++++++++++ 5 files changed, 30 insertions(+), 9 deletions(-) diff --git a/src/api/routes/v2/memory.py b/src/api/routes/v2/memory.py index 9468576b..f6c68b29 100644 --- a/src/api/routes/v2/memory.py +++ b/src/api/routes/v2/memory.py @@ -135,6 +135,7 @@ async def ingest_memory_v2(req: IngestRequest, request: Request, user: dict = De } job_id = _durable_job_id("memory_ingest", idempotency_fields) billing_service = get_default_billing_service() + billing_reservation_created = False try: account, estimate, reservation = await asyncio.to_thread( @@ -147,6 +148,7 @@ async def ingest_memory_v2(req: IngestRequest, request: Request, user: dict = De payload["billing_account_id"] = account["id"] payload["billing_reservation_id"] = reservation.reservation_id payload["billing_estimate"] = estimate.model_dump() + billing_reservation_created = reservation.created job, created = await _enqueue_and_start( job_type="memory_ingest", payload=payload, @@ -163,7 +165,7 @@ async def ingest_memory_v2(req: IngestRequest, request: Request, user: dict = De elapsed_ms(start), ) except WorkflowStartFailed as exc: - if payload.get("billing_account_id"): + if billing_reservation_created and payload.get("billing_account_id"): await asyncio.to_thread( billing_service.release_job_reservation, payload["billing_account_id"], @@ -179,7 +181,7 @@ async def ingest_memory_v2(req: IngestRequest, request: Request, user: dict = De except InsufficientCredits as exc: return _error(request, str(exc), 402, elapsed_ms(start)) except Exception as exc: - if payload.get("billing_account_id"): + if billing_reservation_created and payload.get("billing_account_id"): await asyncio.to_thread( billing_service.release_job_reservation, payload["billing_account_id"], @@ -221,6 +223,7 @@ async def batch_ingest_memory_v2(req: BatchIngestRequest, request: Request, user } job_id = _durable_job_id("memory_batch_ingest", idempotency_fields) billing_service = get_default_billing_service() + billing_reservation_created = False try: account, estimate, reservation = await asyncio.to_thread( @@ -233,6 +236,7 @@ async def batch_ingest_memory_v2(req: BatchIngestRequest, request: Request, user payload["billing_account_id"] = account["id"] payload["billing_reservation_id"] = reservation.reservation_id payload["billing_estimate"] = estimate.model_dump() + billing_reservation_created = reservation.created job, created = await _enqueue_and_start( job_type="memory_batch_ingest", payload=payload, @@ -249,7 +253,7 @@ async def batch_ingest_memory_v2(req: BatchIngestRequest, request: Request, user elapsed_ms(start), ) except WorkflowStartFailed as exc: - if payload.get("billing_account_id"): + if billing_reservation_created and payload.get("billing_account_id"): await asyncio.to_thread( billing_service.release_job_reservation, payload["billing_account_id"], @@ -265,7 +269,7 @@ async def batch_ingest_memory_v2(req: BatchIngestRequest, request: Request, user except InsufficientCredits as exc: return _error(request, str(exc), 402, elapsed_ms(start)) except Exception as exc: - if payload.get("billing_account_id"): + if billing_reservation_created and payload.get("billing_account_id"): await asyncio.to_thread( billing_service.release_job_reservation, payload["billing_account_id"], diff --git a/src/billing/service.py b/src/billing/service.py index 569c195d..3618d68b 100644 --- a/src/billing/service.py +++ b/src/billing/service.py @@ -103,6 +103,7 @@ def reserve_credits( reserved_credits=int(reservation.get("reserved_credits") or 0), status=reservation.get("status", "active"), available_credits=int(wallet.get("available_credits") or 0), + created=bool(reservation.get("created")), ) def reserve_job_credits( diff --git a/src/billing/store.py b/src/billing/store.py index 397b25ae..f72677a1 100644 --- a/src/billing/store.py +++ b/src/billing/store.py @@ -333,7 +333,7 @@ def reserve_credits( raise BillingStoreError( f"Reservation for job {job_id} belongs to a different billing account" ) - return existing + return {**existing, "created": False} now = utc_now() if self._in_memory: @@ -371,7 +371,7 @@ def reserve_credits( "created_at": now, } ) - return dict(reservation) + return {**reservation, "created": True} from pymongo import ReturnDocument @@ -404,7 +404,7 @@ def reserve_credits( and current.get("billing_account_id") == account_id and current.get("status") != "reserving" ): - return current + return {**current, "created": False} raise BillingStoreError(f"Reservation for job {job_id} is already active") reservation = _without_id(reserved_doc) or {} version = int(reservation.get("version") or 1) @@ -431,7 +431,7 @@ def reserve_credits( and current.get("billing_account_id") == account_id and current.get("status") != "reserving" ): - return current + return {**current, "created": False} if current and current.get("billing_account_id") == account_id: raise BillingStoreError( f"Reservation for job {job_id} is already being created" @@ -474,7 +474,7 @@ def reserve_credits( "created_at": now, } ) - return reservation + return {**reservation, "created": True} def get_reservation(self, job_id: str) -> Optional[dict[str, Any]]: if self._in_memory: diff --git a/src/billing/types.py b/src/billing/types.py index 10f1394b..0eac7427 100644 --- a/src/billing/types.py +++ b/src/billing/types.py @@ -61,6 +61,7 @@ class ReservationResult(BaseModel): reserved_credits: int status: Literal["active", "committed", "released", "expired"] available_credits: int + created: bool = False class CheckoutRequest(BaseModel): diff --git a/tests/test_billing.py b/tests/test_billing.py index 8884452e..295431fc 100644 --- a/tests/test_billing.py +++ b/tests/test_billing.py @@ -68,6 +68,21 @@ def test_reservation_debit_and_refund_flow(): assert {"reserve", "debit", "refund"}.issubset(ledger_types) +def test_reused_reservation_is_marked_not_created(): + svc = service() + user = {"id": "user_1"} + account = svc.ensure_billing_account(user) + + first = svc.reserve_credits(account["id"], "job_1", 1000) + second = svc.reserve_credits(account["id"], "job_1", 1000) + + summary = svc.get_billing_summary(user) + assert first.created + assert not second.created + assert summary.available_credits == 9000 + assert summary.reserved_credits == 1000 + + def test_failed_job_releases_reserved_credits(): svc = service() user = {"id": "user_1"} From 5ddb4b1102c7f82c701f9e2b7622909ff024156c Mon Sep 17 00:00:00 2001 From: Ishaan Gupta Date: Mon, 1 Jun 2026 13:36:58 +0530 Subject: [PATCH 07/20] Release retry billing reservation on start failure --- src/api/routes/v2/jobs.py | 15 +++- tests/api/test_memory_versioning.py | 114 ++++++++++++++++++++++++++++ 2 files changed, 126 insertions(+), 3 deletions(-) diff --git a/src/api/routes/v2/jobs.py b/src/api/routes/v2/jobs.py index 668833f3..4cb38308 100644 --- a/src/api/routes/v2/jobs.py +++ b/src/api/routes/v2/jobs.py @@ -17,7 +17,12 @@ ) from src.api.routes.v2.temporal_client import cancel_job_workflow, start_job_workflow from src.api.schemas import APIResponse -from src.billing import InsufficientCredits, get_default_billing_service, release_job_billing +from src.billing import ( + InsufficientCredits, + get_default_billing_service, + release_job_billing, + release_job_reservation, +) from src.jobs.durable import DEAD_LETTER, QUEUED, RUNNING, get_default_job_store router = APIRouter( @@ -115,6 +120,7 @@ async def retry_job(job_id: str, request: Request, user: dict = Depends(require_ payload = job.get("payload") if isinstance(job.get("payload"), dict) else {} billing_account_id = payload.get("billing_account_id") + billing_reservation_created = False if billing_account_id: billing_service = get_default_billing_service() try: @@ -127,15 +133,18 @@ async def retry_job(job_id: str, request: Request, user: dict = Depends(require_ ) payload["billing_reservation_id"] = reservation.reservation_id payload["billing_estimate"] = estimate.model_dump() + billing_reservation_created = reservation.created await asyncio.to_thread(get_default_job_store().update_payload, job_id, payload) except InsufficientCredits as exc: return _error(request, str(exc), 402, elapsed_ms(start)) - await asyncio.to_thread(get_default_job_store().reset_for_retry, job_id, True) - job = await asyncio.to_thread(get_default_job_store().get, job_id) try: + await asyncio.to_thread(get_default_job_store().reset_for_retry, job_id, True) + job = await asyncio.to_thread(get_default_job_store().get, job_id) await start_job_workflow(job) except Exception as exc: + if billing_reservation_created and billing_account_id: + await asyncio.to_thread(release_job_reservation, billing_account_id, job_id) error = str(exc) or exc.__class__.__name__ await asyncio.to_thread(get_default_job_store().mark_failed, job_id, error) return _error(request, f"Retry failed to start workflow: {error}", 503, elapsed_ms(start)) diff --git a/tests/api/test_memory_versioning.py b/tests/api/test_memory_versioning.py index 8c45bc7f..16aa03a7 100644 --- a/tests/api/test_memory_versioning.py +++ b/tests/api/test_memory_versioning.py @@ -7,7 +7,9 @@ from src.api import dependencies as deps from src.api.routes import memory +from src.api.routes.v2 import jobs as jobs_v2 from src.api.routes.v2 import memory as memory_v2 +from src.jobs import durable class FakeIngestPipeline: @@ -54,6 +56,17 @@ def mark_failed(self, job_id, error): job["error"] = error return "failed" + def reset_for_retry(self, job_id, clear_workflow=False): + job = self.jobs[job_id] + job["status"] = "queued" + job["error"] = None + if clear_workflow: + job["workflow_id"] = None + job["run_id"] = None + + def update_payload(self, job_id, payload): + self.jobs[job_id]["payload"] = payload + def reserve_workflow_start(self, job_id, workflow_id): job = self.jobs[job_id] if job["status"] != "queued" or job.get("workflow_id"): @@ -86,6 +99,7 @@ async def fake_rate_limit(): app.include_router(memory.router) app.include_router(memory_v2.scrape_router) app.include_router(memory_v2.router) + app.include_router(jobs_v2.router) return app, ingest @@ -182,6 +196,106 @@ async def fake_start_job_workflow(job): assert ingest.calls == [] +def test_v2_retry_start_failure_releases_fresh_billing_reservation(monkeypatch): + app, _ = _build_app(monkeypatch) + store = FakeJobStore() + store.jobs["job-1"] = { + "job_id": "job-1", + "job_type": "memory_ingest", + "payload": {"billing_account_id": "acct-1", "user_id": "hunter"}, + "user_id": "hunter", + "status": "failed", + "timeout_seconds": 30, + "max_attempts": 3, + "retry_count": 1, + "attempt_count": 1, + "workflow_id": "old-workflow", + } + released = [] + + class FakeEstimate: + reserved_credits = 100 + + def model_dump(self): + return {"reserved_credits": self.reserved_credits} + + class FakeBillingService: + def estimate_required_credits(self, job_type, payload): + return FakeEstimate() + + def reserve_credits(self, account_id, job_id, estimated_credits): + return SimpleNamespace(reservation_id="reservation-1", created=True) + + async def fake_start_job_workflow(job): + raise RuntimeError("temporal unavailable") + + def fake_release(account_id, job_id): + released.append((account_id, job_id)) + + monkeypatch.setattr(jobs_v2, "get_default_job_store", lambda: store) + monkeypatch.setattr(durable, "get_default_job_store", lambda: store) + monkeypatch.setattr(jobs_v2, "get_default_billing_service", lambda: FakeBillingService()) + monkeypatch.setattr(jobs_v2, "release_job_reservation", fake_release) + monkeypatch.setattr(jobs_v2, "start_job_workflow", fake_start_job_workflow) + + response = TestClient(app).post("/v2/jobs/job-1/retry") + + assert response.status_code == 503 + assert released == [("acct-1", "job-1")] + assert store.jobs["job-1"]["status"] == "failed" + assert store.jobs["job-1"]["error"] == "temporal unavailable" + + +def test_v2_retry_start_failure_keeps_reused_billing_reservation(monkeypatch): + app, _ = _build_app(monkeypatch) + store = FakeJobStore() + store.jobs["job-1"] = { + "job_id": "job-1", + "job_type": "memory_ingest", + "payload": {"billing_account_id": "acct-1", "user_id": "hunter"}, + "user_id": "hunter", + "status": "failed", + "timeout_seconds": 30, + "max_attempts": 3, + "retry_count": 1, + "attempt_count": 1, + "workflow_id": "old-workflow", + } + released = [] + + class FakeEstimate: + reserved_credits = 100 + + def model_dump(self): + return {"reserved_credits": self.reserved_credits} + + class FakeBillingService: + def estimate_required_credits(self, job_type, payload): + return FakeEstimate() + + def reserve_credits(self, account_id, job_id, estimated_credits): + return SimpleNamespace(reservation_id="reservation-1", created=False) + + async def fake_start_job_workflow(job): + raise RuntimeError("temporal unavailable") + + monkeypatch.setattr(jobs_v2, "get_default_job_store", lambda: store) + monkeypatch.setattr(durable, "get_default_job_store", lambda: store) + monkeypatch.setattr(jobs_v2, "get_default_billing_service", lambda: FakeBillingService()) + monkeypatch.setattr( + jobs_v2, + "release_job_reservation", + lambda account_id, job_id: released.append((account_id, job_id)), + ) + monkeypatch.setattr(jobs_v2, "start_job_workflow", fake_start_job_workflow) + + response = TestClient(app).post("/v2/jobs/job-1/retry") + + assert response.status_code == 503 + assert released == [] + assert store.jobs["job-1"]["status"] == "failed" + + def test_v1_batch_ingest_scopes_each_item_for_local_static_key(monkeypatch): monkeypatch.setattr(memory.settings, "environment", "development", raising=False) static_user = {"id": "static-key", "name": "Static Key User", "email": "static@xmem.ai"} From 620bbcb5bb89aa14a09ac0ba914eefc33617c631 Mon Sep 17 00:00:00 2001 From: Ishaan Gupta Date: Mon, 1 Jun 2026 14:04:28 +0530 Subject: [PATCH 08/20] Make billing grants and debits atomic --- src/api/routes/billing.py | 20 ++-- src/billing/store.py | 189 +++++++++++++++++++++++--------------- 2 files changed, 131 insertions(+), 78 deletions(-) diff --git a/src/api/routes/billing.py b/src/api/routes/billing.py index 7ce7c2af..eca97f84 100644 --- a/src/api/routes/billing.py +++ b/src/api/routes/billing.py @@ -294,12 +294,7 @@ async def razorpay_webhook(request: Request) -> dict[str, str]: raise HTTPException(status_code=400, detail="Webhook event id is required") event_name = str(payload.get("event") or "") service = get_default_billing_service() - first_seen = await asyncio.to_thread( - service.store.mark_payment_event, - event_id, - {"event": event_name, "payload": payload}, - ) - if not first_seen: + if await asyncio.to_thread(service.store.has_payment_event, event_id): return {"status": "ignored_duplicate"} payment = (((payload.get("payload") or {}).get("payment") or {}).get("entity") or {}) @@ -314,6 +309,11 @@ async def razorpay_webhook(request: Request) -> dict[str, str]: if not user_id or not package_id: logger.info("Ignoring Razorpay webhook without XMem user/package notes: %s", event_name) + await asyncio.to_thread( + service.store.mark_payment_event, + event_id, + {"event": event_name, "payload": payload}, + ) return {"status": "ignored"} if event_name in {"payment.captured", "order.paid", "subscription.charged"}: @@ -333,6 +333,14 @@ async def razorpay_webhook(request: Request) -> dict[str, str]: order_id=order_id or event_id, ) + first_seen = await asyncio.to_thread( + service.store.mark_payment_event, + event_id, + {"event": event_name, "payload": payload}, + ) + if not first_seen: + return {"status": "ignored_duplicate"} + return {"status": "ok"} diff --git a/src/billing/store.py b/src/billing/store.py index f72677a1..d906ee15 100644 --- a/src/billing/store.py +++ b/src/billing/store.py @@ -285,10 +285,6 @@ def grant_credits( "metadata": metadata or {}, "created_at": now, } - duplicate = self._insert_ledger(ledger_entry) - if duplicate: - return duplicate - lot = { "id": uuid.uuid4().hex, "billing_account_id": account_id, @@ -300,6 +296,9 @@ def grant_credits( "updated_at": now, } if self._in_memory: + duplicate = self.find_ledger_by_key(idempotency_key) + if duplicate: + return duplicate _memory_lots[lot["id"]] = lot wallet = _memory_wallets.setdefault( account_id, @@ -307,14 +306,32 @@ def grant_credits( ) wallet["available_credits"] += amount wallet["updated_at"] = now + self._insert_ledger(ledger_entry) return dict(ledger_entry) - self.lots.insert_one(lot) - self.wallets.update_one( - {"billing_account_id": account_id}, - {"$inc": {"available_credits": amount}, "$set": {"updated_at": now}}, - upsert=True, - ) + try: + with self._client.start_session() as session: + with session.start_transaction(): + existing = self.ledger.find_one( + {"idempotency_key": idempotency_key}, + session=session, + ) + if existing: + return _without_id(existing) or {} + self.lots.insert_one(lot, session=session) + self.wallets.update_one( + {"billing_account_id": account_id}, + {"$inc": {"available_credits": amount}, "$set": {"updated_at": now}}, + upsert=True, + session=session, + ) + self.ledger.insert_one(ledger_entry, session=session) + except Exception as exc: + if _is_duplicate_key_error(exc): + existing = self.find_ledger_by_key(idempotency_key) + if existing: + return existing + raise return ledger_entry def reserve_credits( @@ -503,76 +520,95 @@ def commit_debit( extra = max(final_amount - reserved, 0) refund = max(reserved - final_amount, 0) - try: - if extra and self._in_memory: - wallet = self.get_wallet(account_id) - if int(wallet.get("available_credits") or 0) < extra: - raise InsufficientCredits(extra, int(wallet.get("available_credits") or 0)) - _memory_wallets[account_id]["available_credits"] -= extra - elif extra: - wallet = self.wallets.find_one_and_update( - {"billing_account_id": account_id, "available_credits": {"$gte": extra}}, - {"$inc": {"available_credits": -extra}, "$set": {"updated_at": now}}, - return_document=True, - ) - if not wallet: - current = self.get_wallet(account_id) - raise InsufficientCredits(extra, int(current.get("available_credits") or 0)) + entry = { + "id": uuid.uuid4().hex, + "billing_account_id": account_id, + "type": "debit", + "amount": -final_amount, + "job_id": job_id, + "idempotency_key": f"debit:{job_id}", + "metadata": metadata or {}, + "created_at": now, + } + refund_entry = None + if refund: + refund_entry = { + "id": uuid.uuid4().hex, + "billing_account_id": account_id, + "type": "refund", + "amount": refund, + "job_id": job_id, + "idempotency_key": f"refund:{job_id}", + "metadata": {"reason": "unused_reservation"}, + "created_at": now, + } - self._consume_lots(account_id, final_amount) + try: if self._in_memory: + if extra: + wallet = self.get_wallet(account_id) + if int(wallet.get("available_credits") or 0) < extra: + raise InsufficientCredits(extra, int(wallet.get("available_credits") or 0)) + _memory_wallets[account_id]["available_credits"] -= extra + self._consume_lots(account_id, final_amount) _memory_wallets[account_id]["reserved_credits"] -= reserved _memory_wallets[account_id]["available_credits"] += refund _memory_wallets[account_id]["updated_at"] = now _memory_reservations[job_id]["status"] = "committed" _memory_reservations[job_id]["final_credits"] = final_amount _memory_reservations[job_id]["updated_at"] = now - else: - self.wallets.update_one( - {"billing_account_id": account_id}, - { - "$inc": {"reserved_credits": -reserved, "available_credits": refund}, - "$set": {"updated_at": now}, - }, - ) - self.reservations.update_one( - {"job_id": job_id, "billing_account_id": account_id}, - { - "$set": { - "status": "committed", - "final_credits": final_amount, - "updated_at": now, - } - }, - ) + self._insert_ledger(entry) + if refund_entry: + self._insert_ledger(refund_entry) + return entry + + with self._client.start_session() as session: + with session.start_transaction(): + existing = self.ledger.find_one( + {"idempotency_key": f"debit:{job_id}"}, + session=session, + ) + if existing: + return _without_id(existing) or {} + + if extra: + wallet = self.wallets.find_one_and_update( + {"billing_account_id": account_id, "available_credits": {"$gte": extra}}, + {"$inc": {"available_credits": -extra}, "$set": {"updated_at": now}}, + return_document=True, + session=session, + ) + if not wallet: + current = self.get_wallet(account_id) + raise InsufficientCredits(extra, int(current.get("available_credits") or 0)) + + self._consume_lots(account_id, final_amount, session=session) + self.wallets.update_one( + {"billing_account_id": account_id}, + { + "$inc": {"reserved_credits": -reserved, "available_credits": refund}, + "$set": {"updated_at": now}, + }, + session=session, + ) + self.reservations.update_one( + {"job_id": job_id, "billing_account_id": account_id}, + { + "$set": { + "status": "committed", + "final_credits": final_amount, + "updated_at": now, + } + }, + session=session, + ) + self.ledger.insert_one(entry, session=session) + if refund_entry: + self.ledger.insert_one(refund_entry, session=session) except Exception: self._release_commit_claim(account_id, job_id) raise - entry = { - "id": uuid.uuid4().hex, - "billing_account_id": account_id, - "type": "debit", - "amount": -final_amount, - "job_id": job_id, - "idempotency_key": f"debit:{job_id}", - "metadata": metadata or {}, - "created_at": now, - } - self._insert_ledger(entry) - if refund: - self._insert_ledger( - { - "id": uuid.uuid4().hex, - "billing_account_id": account_id, - "type": "refund", - "amount": refund, - "job_id": job_id, - "idempotency_key": f"refund:{job_id}", - "metadata": {"reason": "unused_reservation"}, - "created_at": now, - } - ) return entry def _claim_reservation_for_commit( @@ -704,11 +740,11 @@ def release_reservation( ) return self.get_reservation(job_id) - def _consume_lots(self, account_id: str, amount: int) -> None: + def _consume_lots(self, account_id: str, amount: int, *, session: Any = None) -> None: remaining = amount now = utc_now() while remaining > 0: - lots = list(self.active_lots(account_id)) + lots = list(self.active_lots(account_id, session=session)) if not lots: break progressed = False @@ -730,6 +766,7 @@ def _consume_lots(self, account_id: str, amount: int) -> None: "$inc": {"remaining_credits": -take}, "$set": {"updated_at": now}, }, + session=session, ) if getattr(result, "modified_count", 0) == 1: remaining -= take @@ -741,7 +778,7 @@ def _consume_lots(self, account_id: str, amount: int) -> None: f"Wallet had credits but credit lots were short by {remaining}" ) - def active_lots(self, account_id: str) -> Iterable[dict[str, Any]]: + def active_lots(self, account_id: str, *, session: Any = None) -> Iterable[dict[str, Any]]: now = utc_now() if self._in_memory: lots = [ @@ -757,7 +794,8 @@ def active_lots(self, account_id: str) -> Iterable[dict[str, Any]]: "billing_account_id": account_id, "remaining_credits": {"$gt": 0}, "$or": [{"expires_at": None}, {"expires_at": {"$gt": now}}], - } + }, + session=session, ).sort("expires_at", 1) return [_without_id(lot) or {} for lot in cursor] @@ -767,6 +805,13 @@ def find_ledger_by_key(self, idempotency_key: str) -> Optional[dict[str, Any]]: return dict(entry) if entry else None return _without_id(self.ledger.find_one({"idempotency_key": idempotency_key})) + def has_payment_event(self, event_id: str) -> bool: + if not event_id: + return False + if self._in_memory: + return event_id in _memory_payment_events + return self.payments.find_one({"razorpay_event_id": event_id}) is not None + def list_ledger(self, account_id: str, limit: int = 100) -> list[dict[str, Any]]: if self._in_memory: entries = [ From 70e10afd8e272958040bba7ef409921f39505ea2 Mon Sep 17 00:00:00 2001 From: Ishaan Gupta Date: Mon, 1 Jun 2026 14:14:46 +0530 Subject: [PATCH 09/20] Make billing reserve and release atomic --- src/billing/store.py | 248 +++++++++++++++++++++++-------------------- 1 file changed, 134 insertions(+), 114 deletions(-) diff --git a/src/billing/store.py b/src/billing/store.py index d906ee15..554092ce 100644 --- a/src/billing/store.py +++ b/src/billing/store.py @@ -392,106 +392,107 @@ def reserve_credits( from pymongo import ReturnDocument - if existing: - if existing.get("billing_account_id") != account_id: - raise BillingStoreError( - f"Reservation for job {job_id} belongs to a different billing account" - ) - reserved_doc = self.reservations.find_one_and_update( - { - "job_id": job_id, - "billing_account_id": account_id, - "status": {"$nin": ["active", "committed", "reserving"]}, - }, - { - "$set": { - "status": "reserving", - "reserved_credits": amount, - "metadata": metadata or {}, - "updated_at": now, - }, - "$inc": {"version": 1}, - }, - return_document=ReturnDocument.AFTER, - ) - if not reserved_doc: - current = self.get_reservation(job_id) - if ( - current - and current.get("billing_account_id") == account_id - and current.get("status") != "reserving" - ): - return {**current, "created": False} - raise BillingStoreError(f"Reservation for job {job_id} is already active") - reservation = _without_id(reserved_doc) or {} - version = int(reservation.get("version") or 1) - else: - reservation = { - "id": uuid.uuid4().hex, - "billing_account_id": account_id, - "job_id": job_id, - "reserved_credits": amount, - "status": "reserving", - "version": 1, - "metadata": metadata or {}, - "created_at": now, - "updated_at": now, - } - try: - self.reservations.insert_one(reservation) - except Exception as exc: - if not _is_duplicate_key_error(exc): - raise + try: + with self._client.start_session() as session: + with session.start_transaction(): + current = self.reservations.find_one({"job_id": job_id}, session=session) + if current: + current = _without_id(current) or {} + if current.get("billing_account_id") != account_id: + raise BillingStoreError( + f"Reservation for job {job_id} belongs to a different billing account" + ) + if current.get("status") in {"active", "committed"}: + return {**current, "created": False} + if current.get("status") == "reserving": + raise BillingStoreError( + f"Reservation for job {job_id} is already being created" + ) + reserved_doc = self.reservations.find_one_and_update( + { + "job_id": job_id, + "billing_account_id": account_id, + "status": {"$nin": ["active", "committed", "reserving"]}, + }, + { + "$set": { + "status": "reserving", + "reserved_credits": amount, + "metadata": metadata or {}, + "updated_at": now, + }, + "$inc": {"version": 1}, + }, + return_document=ReturnDocument.AFTER, + session=session, + ) + if not reserved_doc: + raise BillingStoreError(f"Reservation for job {job_id} is already active") + reservation = _without_id(reserved_doc) or {} + version = int(reservation.get("version") or 1) + else: + reservation = { + "id": uuid.uuid4().hex, + "billing_account_id": account_id, + "job_id": job_id, + "reserved_credits": amount, + "status": "reserving", + "version": 1, + "metadata": metadata or {}, + "created_at": now, + "updated_at": now, + } + self.reservations.insert_one(reservation, session=session) + version = 1 + + wallet = self.wallets.find_one_and_update( + {"billing_account_id": account_id, "available_credits": {"$gte": amount}}, + { + "$inc": {"available_credits": -amount, "reserved_credits": amount}, + "$set": {"updated_at": now}, + }, + return_document=True, + session=session, + ) + if not wallet: + current_wallet = self.wallets.find_one( + {"billing_account_id": account_id}, + session=session, + ) or {} + raise InsufficientCredits( + amount, + int(current_wallet.get("available_credits") or 0), + ) + + self.reservations.update_one( + {"job_id": job_id, "billing_account_id": account_id, "status": "reserving"}, + {"$set": {"status": "active", "updated_at": now}}, + session=session, + ) + self.ledger.insert_one( + { + "id": uuid.uuid4().hex, + "billing_account_id": account_id, + "type": "reserve", + "amount": amount, + "job_id": job_id, + "idempotency_key": f"reserve:{job_id}:{version}", + "metadata": metadata or {}, + "created_at": now, + }, + session=session, + ) + except Exception as exc: + if _is_duplicate_key_error(exc): current = self.get_reservation(job_id) if ( current and current.get("billing_account_id") == account_id - and current.get("status") != "reserving" + and current.get("status") in {"active", "committed"} ): return {**current, "created": False} - if current and current.get("billing_account_id") == account_id: - raise BillingStoreError( - f"Reservation for job {job_id} is already being created" - ) from exc - raise BillingStoreError( - f"Reservation for job {job_id} belongs to a different billing account" - ) from exc - version = 1 - - wallet = self.wallets.find_one_and_update( - {"billing_account_id": account_id, "available_credits": {"$gte": amount}}, - { - "$inc": {"available_credits": -amount, "reserved_credits": amount}, - "$set": {"updated_at": now}, - }, - return_document=True, - ) - if not wallet: - self.reservations.update_one( - {"job_id": job_id, "status": "reserving"}, - {"$set": {"status": "released", "updated_at": now}}, - ) - current = self.get_wallet(account_id) - raise InsufficientCredits(amount, int(current.get("available_credits") or 0)) - - self.reservations.update_one( - {"job_id": job_id, "billing_account_id": account_id, "status": "reserving"}, - {"$set": {"status": "active", "updated_at": now}}, - ) - reservation = self.get_reservation(job_id) or reservation - self._insert_ledger( - { - "id": uuid.uuid4().hex, - "billing_account_id": account_id, - "type": "reserve", - "amount": amount, - "job_id": job_id, - "idempotency_key": f"reserve:{job_id}:{version}", - "metadata": metadata or {}, - "created_at": now, - } - ) - return {**reservation, "created": True} + raise + return {**reservation, "status": "active", "updated_at": now, "created": True} def get_reservation(self, job_id: str) -> Optional[dict[str, Any]]: if self._in_memory: @@ -705,27 +706,46 @@ def release_reservation( else: from pymongo import ReturnDocument - reservation_doc = self.reservations.find_one_and_update( - {"job_id": job_id, "billing_account_id": account_id, "status": "active"}, - {"$set": {"status": "released", "updated_at": now}}, - return_document=ReturnDocument.BEFORE, - ) - if not reservation_doc: - current = self.get_reservation(job_id) - if current and current.get("billing_account_id") != account_id: - raise BillingStoreError( - f"Reservation for job {job_id} belongs to a different billing account" + with self._client.start_session() as session: + with session.start_transaction(): + reservation_doc = self.reservations.find_one_and_update( + {"job_id": job_id, "billing_account_id": account_id, "status": "active"}, + {"$set": {"status": "released", "updated_at": now}}, + return_document=ReturnDocument.BEFORE, + session=session, ) - return current - reservation = _without_id(reservation_doc) or {} - amount = int(reservation.get("reserved_credits") or 0) - self.wallets.update_one( - {"billing_account_id": account_id}, - { - "$inc": {"available_credits": amount, "reserved_credits": -amount}, - "$set": {"updated_at": now}, - }, - ) + if not reservation_doc: + current = self.reservations.find_one({"job_id": job_id}, session=session) + current = _without_id(current) + if current and current.get("billing_account_id") != account_id: + raise BillingStoreError( + f"Reservation for job {job_id} belongs to a different billing account" + ) + return current + reservation = _without_id(reservation_doc) or {} + amount = int(reservation.get("reserved_credits") or 0) + self.wallets.update_one( + {"billing_account_id": account_id}, + { + "$inc": {"available_credits": amount, "reserved_credits": -amount}, + "$set": {"updated_at": now}, + }, + session=session, + ) + self.ledger.insert_one( + { + "id": uuid.uuid4().hex, + "billing_account_id": account_id, + "type": "release", + "amount": amount, + "job_id": job_id, + "idempotency_key": f"release:{job_id}:{reservation.get('version', 1)}", + "metadata": metadata or {}, + "created_at": now, + }, + session=session, + ) + return self.get_reservation(job_id) self._insert_ledger( { "id": uuid.uuid4().hex, From 1beb1e6356596486e5397ae3ff25b497113a672f Mon Sep 17 00:00:00 2001 From: Ishaan Gupta Date: Mon, 1 Jun 2026 14:29:25 +0530 Subject: [PATCH 10/20] Mark jobs cancelled before billing release --- src/api/routes/v2/jobs.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/api/routes/v2/jobs.py b/src/api/routes/v2/jobs.py index 4cb38308..311d5b85 100644 --- a/src/api/routes/v2/jobs.py +++ b/src/api/routes/v2/jobs.py @@ -23,7 +23,7 @@ release_job_billing, release_job_reservation, ) -from src.jobs.durable import DEAD_LETTER, QUEUED, RUNNING, get_default_job_store +from src.jobs.durable import CANCELLED, DEAD_LETTER, QUEUED, RUNNING, get_default_job_store router = APIRouter( prefix="/v2/jobs", @@ -159,15 +159,17 @@ async def cancel_job(job_id: str, request: Request, user: dict = Depends(require job = await read_user_job(job_id, user_id) if not job: return _error(request, "Job not found.", 404, elapsed_ms(start)) - if job.get("status") not in {QUEUED, RUNNING}: - return _error(request, "Only queued or running jobs can be cancelled.", 409, elapsed_ms(start)) - try: - await cancel_job_workflow(job) - except Exception as exc: - error = str(exc) or exc.__class__.__name__ - return _error(request, f"Cancel failed to reach workflow: {error}", 503, elapsed_ms(start)) + if job.get("status") not in {QUEUED, RUNNING, CANCELLED}: + return _error(request, "Only queued, running, or already-cancelled jobs can be cancelled.", 409, elapsed_ms(start)) + if job.get("status") != CANCELLED: + try: + await cancel_job_workflow(job) + except Exception as exc: + error = str(exc) or exc.__class__.__name__ + return _error(request, f"Cancel failed to reach workflow: {error}", 503, elapsed_ms(start)) + await asyncio.to_thread(get_default_job_store().mark_cancelled, job_id) + await asyncio.to_thread(_mark_scanner_job_cancelled, job) + job = await asyncio.to_thread(get_default_job_store().get, job_id) await asyncio.to_thread(release_job_billing, job, "cancelled") - await asyncio.to_thread(get_default_job_store().mark_cancelled, job_id) - await asyncio.to_thread(_mark_scanner_job_cancelled, job) job = await asyncio.to_thread(get_default_job_store().get, job_id) return _wrap(request, job_status_data(job), elapsed_ms(start)) From 68dc4feef370383218e3602d91d900a5b6cb415a Mon Sep 17 00:00:00 2001 From: Ishaan Gupta Date: Mon, 1 Jun 2026 14:40:06 +0530 Subject: [PATCH 11/20] Keep billing debit claim transactional --- src/billing/store.py | 187 +++++++++++++++++++++++++++---------------- 1 file changed, 120 insertions(+), 67 deletions(-) diff --git a/src/billing/store.py b/src/billing/store.py index 554092ce..a4d62e11 100644 --- a/src/billing/store.py +++ b/src/billing/store.py @@ -514,38 +514,37 @@ def commit_debit( if duplicate: return duplicate now = utc_now() - reservation = self._claim_reservation_for_commit(account_id, job_id, final_amount, now) - if reservation.get("type") == "debit": - return reservation - reserved = int(reservation.get("reserved_credits") or 0) - extra = max(final_amount - reserved, 0) - refund = max(reserved - final_amount, 0) - entry = { - "id": uuid.uuid4().hex, - "billing_account_id": account_id, - "type": "debit", - "amount": -final_amount, - "job_id": job_id, - "idempotency_key": f"debit:{job_id}", - "metadata": metadata or {}, - "created_at": now, - } - refund_entry = None - if refund: - refund_entry = { + if self._in_memory: + reservation = self._claim_reservation_for_commit(account_id, job_id, final_amount, now) + if reservation.get("type") == "debit": + return reservation + reserved = int(reservation.get("reserved_credits") or 0) + extra = max(final_amount - reserved, 0) + refund = max(reserved - final_amount, 0) + entry = { "id": uuid.uuid4().hex, "billing_account_id": account_id, - "type": "refund", - "amount": refund, + "type": "debit", + "amount": -final_amount, "job_id": job_id, - "idempotency_key": f"refund:{job_id}", - "metadata": {"reason": "unused_reservation"}, + "idempotency_key": f"debit:{job_id}", + "metadata": metadata or {}, "created_at": now, } - - try: - if self._in_memory: + refund_entry = None + if refund: + refund_entry = { + "id": uuid.uuid4().hex, + "billing_account_id": account_id, + "type": "refund", + "amount": refund, + "job_id": job_id, + "idempotency_key": f"refund:{job_id}", + "metadata": {"reason": "unused_reservation"}, + "created_at": now, + } + try: if extra: wallet = self.get_wallet(account_id) if int(wallet.get("available_credits") or 0) < extra: @@ -562,53 +561,107 @@ def commit_debit( if refund_entry: self._insert_ledger(refund_entry) return entry + except Exception: + self._release_commit_claim(account_id, job_id) + raise - with self._client.start_session() as session: - with session.start_transaction(): - existing = self.ledger.find_one( - {"idempotency_key": f"debit:{job_id}"}, - session=session, - ) - if existing: - return _without_id(existing) or {} + from pymongo import ReturnDocument - if extra: - wallet = self.wallets.find_one_and_update( - {"billing_account_id": account_id, "available_credits": {"$gte": extra}}, - {"$inc": {"available_credits": -extra}, "$set": {"updated_at": now}}, - return_document=True, - session=session, + entry = None + with self._client.start_session() as session: + with session.start_transaction(): + existing = self.ledger.find_one( + {"idempotency_key": f"debit:{job_id}"}, + session=session, + ) + if existing: + return _without_id(existing) or {} + + reservation_doc = self.reservations.find_one_and_update( + {"job_id": job_id, "billing_account_id": account_id, "status": "active"}, + { + "$set": { + "status": "committing", + "final_credits": final_amount, + "updated_at": now, + } + }, + return_document=ReturnDocument.BEFORE, + session=session, + ) + if not reservation_doc: + current = self.reservations.find_one({"job_id": job_id}, session=session) + current = _without_id(current) + if current and current.get("billing_account_id") != account_id: + raise BillingStoreError( + f"Reservation for job {job_id} belongs to a different billing account" ) - if not wallet: - current = self.get_wallet(account_id) - raise InsufficientCredits(extra, int(current.get("available_credits") or 0)) + raise BillingStoreError(f"Reservation for job {job_id} is not active") - self._consume_lots(account_id, final_amount, session=session) - self.wallets.update_one( - {"billing_account_id": account_id}, - { - "$inc": {"reserved_credits": -reserved, "available_credits": refund}, - "$set": {"updated_at": now}, - }, - session=session, - ) - self.reservations.update_one( - {"job_id": job_id, "billing_account_id": account_id}, - { - "$set": { - "status": "committed", - "final_credits": final_amount, - "updated_at": now, - } - }, + reservation = _without_id(reservation_doc) or {} + reserved = int(reservation.get("reserved_credits") or 0) + extra = max(final_amount - reserved, 0) + refund = max(reserved - final_amount, 0) + entry = { + "id": uuid.uuid4().hex, + "billing_account_id": account_id, + "type": "debit", + "amount": -final_amount, + "job_id": job_id, + "idempotency_key": f"debit:{job_id}", + "metadata": metadata or {}, + "created_at": now, + } + refund_entry = None + if refund: + refund_entry = { + "id": uuid.uuid4().hex, + "billing_account_id": account_id, + "type": "refund", + "amount": refund, + "job_id": job_id, + "idempotency_key": f"refund:{job_id}", + "metadata": {"reason": "unused_reservation"}, + "created_at": now, + } + + if extra: + wallet = self.wallets.find_one_and_update( + {"billing_account_id": account_id, "available_credits": {"$gte": extra}}, + {"$inc": {"available_credits": -extra}, "$set": {"updated_at": now}}, + return_document=ReturnDocument.AFTER, session=session, ) - self.ledger.insert_one(entry, session=session) - if refund_entry: - self.ledger.insert_one(refund_entry, session=session) - except Exception: - self._release_commit_claim(account_id, job_id) - raise + if not wallet: + current = self.wallets.find_one( + {"billing_account_id": account_id}, + session=session, + ) or {} + raise InsufficientCredits(extra, int(current.get("available_credits") or 0)) + + self._consume_lots(account_id, final_amount, session=session) + self.wallets.update_one( + {"billing_account_id": account_id}, + { + "$inc": {"reserved_credits": -reserved, "available_credits": refund}, + "$set": {"updated_at": now}, + }, + session=session, + ) + self.reservations.update_one( + {"job_id": job_id, "billing_account_id": account_id, "status": "committing"}, + { + "$set": { + "status": "committed", + "final_credits": final_amount, + "updated_at": now, + } + }, + session=session, + ) + self.ledger.insert_one(entry, session=session) + if refund_entry: + self.ledger.insert_one(refund_entry, session=session) return entry From c4b90b72fe974f8310906d01809ea7459ee1b22f Mon Sep 17 00:00:00 2001 From: Ishaan Gupta Date: Mon, 1 Jun 2026 14:45:40 +0530 Subject: [PATCH 12/20] Release retry reservation on payload update failure --- src/api/routes/v2/jobs.py | 12 +++---- tests/api/test_memory_versioning.py | 51 +++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 7 deletions(-) diff --git a/src/api/routes/v2/jobs.py b/src/api/routes/v2/jobs.py index 311d5b85..5e5a97c8 100644 --- a/src/api/routes/v2/jobs.py +++ b/src/api/routes/v2/jobs.py @@ -121,9 +121,9 @@ async def retry_job(job_id: str, request: Request, user: dict = Depends(require_ payload = job.get("payload") if isinstance(job.get("payload"), dict) else {} billing_account_id = payload.get("billing_account_id") billing_reservation_created = False - if billing_account_id: - billing_service = get_default_billing_service() - try: + try: + if billing_account_id: + billing_service = get_default_billing_service() estimate = billing_service.estimate_required_credits(job.get("job_type") or "", payload) reservation = await asyncio.to_thread( billing_service.reserve_credits, @@ -135,13 +135,11 @@ async def retry_job(job_id: str, request: Request, user: dict = Depends(require_ payload["billing_estimate"] = estimate.model_dump() billing_reservation_created = reservation.created await asyncio.to_thread(get_default_job_store().update_payload, job_id, payload) - except InsufficientCredits as exc: - return _error(request, str(exc), 402, elapsed_ms(start)) - - try: await asyncio.to_thread(get_default_job_store().reset_for_retry, job_id, True) job = await asyncio.to_thread(get_default_job_store().get, job_id) await start_job_workflow(job) + except InsufficientCredits as exc: + return _error(request, str(exc), 402, elapsed_ms(start)) except Exception as exc: if billing_reservation_created and billing_account_id: await asyncio.to_thread(release_job_reservation, billing_account_id, job_id) diff --git a/tests/api/test_memory_versioning.py b/tests/api/test_memory_versioning.py index 16aa03a7..fc4fbfbe 100644 --- a/tests/api/test_memory_versioning.py +++ b/tests/api/test_memory_versioning.py @@ -246,6 +246,57 @@ def fake_release(account_id, job_id): assert store.jobs["job-1"]["error"] == "temporal unavailable" +def test_v2_retry_payload_update_failure_releases_fresh_billing_reservation(monkeypatch): + app, _ = _build_app(monkeypatch) + store = FakeJobStore() + store.jobs["job-1"] = { + "job_id": "job-1", + "job_type": "memory_ingest", + "payload": {"billing_account_id": "acct-1", "user_id": "hunter"}, + "user_id": "hunter", + "status": "failed", + "timeout_seconds": 30, + "max_attempts": 3, + "retry_count": 1, + "attempt_count": 1, + "workflow_id": "old-workflow", + } + released = [] + + class FakeEstimate: + reserved_credits = 100 + + def model_dump(self): + return {"reserved_credits": self.reserved_credits} + + class FakeBillingService: + def estimate_required_credits(self, job_type, payload): + return FakeEstimate() + + def reserve_credits(self, account_id, job_id, estimated_credits): + return SimpleNamespace(reservation_id="reservation-1", created=True) + + def fail_update_payload(job_id, payload): + raise RuntimeError("payload write failed") + + monkeypatch.setattr(jobs_v2, "get_default_job_store", lambda: store) + monkeypatch.setattr(durable, "get_default_job_store", lambda: store) + monkeypatch.setattr(jobs_v2, "get_default_billing_service", lambda: FakeBillingService()) + monkeypatch.setattr(store, "update_payload", fail_update_payload) + monkeypatch.setattr( + jobs_v2, + "release_job_reservation", + lambda account_id, job_id: released.append((account_id, job_id)), + ) + + response = TestClient(app).post("/v2/jobs/job-1/retry") + + assert response.status_code == 503 + assert released == [("acct-1", "job-1")] + assert store.jobs["job-1"]["status"] == "failed" + assert store.jobs["job-1"]["error"] == "payload write failed" + + def test_v2_retry_start_failure_keeps_reused_billing_reservation(monkeypatch): app, _ = _build_app(monkeypatch) store = FakeJobStore() From 800a82c9d3aed0b023f6598997df443618b10f8a Mon Sep 17 00:00:00 2001 From: Ishaan Gupta Date: Mon, 1 Jun 2026 14:59:17 +0530 Subject: [PATCH 13/20] Close remaining Greptile billing gaps --- src/api/routes/billing.py | 12 ++++---- src/api/routes/v2/jobs.py | 9 ++++-- tests/api/test_billing_routes.py | 40 +++++++++++++++++++++++++ tests/api/test_memory_versioning.py | 46 +++++++++++++++++++++++++++++ 4 files changed, 99 insertions(+), 8 deletions(-) create mode 100644 tests/api/test_billing_routes.py diff --git a/src/api/routes/billing.py b/src/api/routes/billing.py index eca97f84..34c451ae 100644 --- a/src/api/routes/billing.py +++ b/src/api/routes/billing.py @@ -208,6 +208,12 @@ async def verify_razorpay_payment( user_id = _user_id(current_user) if request.razorpay_subscription_id: + if not verify_subscription_signature( + request.razorpay_subscription_id, + request.razorpay_payment_id, + request.razorpay_signature, + ): + raise HTTPException(status_code=400, detail="Invalid Razorpay signature") checkout = await asyncio.to_thread( service.store.get_checkout, request.razorpay_subscription_id, @@ -218,12 +224,6 @@ async def verify_razorpay_payment( raise HTTPException(status_code=403, detail="Payment subscription does not belong to this user") if checkout.get("package_id") != "pro": raise HTTPException(status_code=400, detail="Payment subscription package mismatch") - if not verify_subscription_signature( - request.razorpay_subscription_id, - request.razorpay_payment_id, - request.razorpay_signature, - ): - raise HTTPException(status_code=400, detail="Invalid Razorpay signature") await asyncio.to_thread( service.grant_pro_subscription, user_id=user_id, diff --git a/src/api/routes/v2/jobs.py b/src/api/routes/v2/jobs.py index 5e5a97c8..3fe47011 100644 --- a/src/api/routes/v2/jobs.py +++ b/src/api/routes/v2/jobs.py @@ -160,13 +160,18 @@ async def cancel_job(job_id: str, request: Request, user: dict = Depends(require if job.get("status") not in {QUEUED, RUNNING, CANCELLED}: return _error(request, "Only queued, running, or already-cancelled jobs can be cancelled.", 409, elapsed_ms(start)) if job.get("status") != CANCELLED: + cancel_signal_sent = False try: await cancel_job_workflow(job) + cancel_signal_sent = True + await asyncio.to_thread(get_default_job_store().mark_cancelled, job_id) + await asyncio.to_thread(_mark_scanner_job_cancelled, job) except Exception as exc: error = str(exc) or exc.__class__.__name__ + if cancel_signal_sent: + await asyncio.to_thread(release_job_billing, job, "cancel_signal_sent") + return _error(request, f"Cancel failed after reaching workflow: {error}", 503, elapsed_ms(start)) return _error(request, f"Cancel failed to reach workflow: {error}", 503, elapsed_ms(start)) - await asyncio.to_thread(get_default_job_store().mark_cancelled, job_id) - await asyncio.to_thread(_mark_scanner_job_cancelled, job) job = await asyncio.to_thread(get_default_job_store().get, job_id) await asyncio.to_thread(release_job_billing, job, "cancelled") job = await asyncio.to_thread(get_default_job_store().get, job_id) diff --git a/tests/api/test_billing_routes.py b/tests/api/test_billing_routes.py new file mode 100644 index 00000000..78f8d47e --- /dev/null +++ b/tests/api/test_billing_routes.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +import pytest + +fastapi = pytest.importorskip("fastapi") +testclient = pytest.importorskip("fastapi.testclient") + +from src.api.routes import billing + + +def test_subscription_verify_checks_signature_before_checkout_lookup(monkeypatch): + class Store: + def get_checkout(self, checkout_id): + raise AssertionError("checkout lookup should not run before signature verification") + + class Service: + store = Store() + + async def fake_auth(): + return {"id": "user-1"} + + monkeypatch.setattr(billing, "get_default_billing_service", lambda: Service()) + monkeypatch.setattr(billing, "verify_subscription_signature", lambda *args: False) + + app = fastapi.FastAPI() + app.dependency_overrides[billing.require_auth] = fake_auth + app.include_router(billing.router) + + response = testclient.TestClient(app).post( + "/api/billing/razorpay/verify", + json={ + "package_id": "pro", + "razorpay_payment_id": "pay_1", + "razorpay_signature": "bad", + "razorpay_subscription_id": "sub_1", + }, + ) + + assert response.status_code == 400 + assert response.json()["detail"] == "Invalid Razorpay signature" diff --git a/tests/api/test_memory_versioning.py b/tests/api/test_memory_versioning.py index fc4fbfbe..9a8cb314 100644 --- a/tests/api/test_memory_versioning.py +++ b/tests/api/test_memory_versioning.py @@ -56,6 +56,12 @@ def mark_failed(self, job_id, error): job["error"] = error return "failed" + def mark_cancelled(self, job_id): + job = self.jobs[job_id] + job["status"] = "cancelled" + job["cancelled_at"] = "now" + job["completed_at"] = "now" + def reset_for_retry(self, job_id, clear_workflow=False): job = self.jobs[job_id] job["status"] = "queued" @@ -347,6 +353,46 @@ async def fake_start_job_workflow(job): assert store.jobs["job-1"]["status"] == "failed" +def test_v2_cancel_mark_failure_releases_billing_after_signal(monkeypatch): + app, _ = _build_app(monkeypatch) + store = FakeJobStore() + store.jobs["job-1"] = { + "job_id": "job-1", + "job_type": "memory_ingest", + "payload": {"billing_account_id": "acct-1", "user_id": "hunter"}, + "user_id": "hunter", + "status": "running", + "timeout_seconds": 30, + "max_attempts": 3, + "retry_count": 0, + "attempt_count": 1, + "workflow_id": "workflow-1", + } + released = [] + cancelled = [] + + async def fake_cancel_job_workflow(job): + cancelled.append(job["job_id"]) + + def fail_mark_cancelled(job_id): + raise RuntimeError("cancel status write failed") + + monkeypatch.setattr(jobs_v2, "get_default_job_store", lambda: store) + monkeypatch.setattr(jobs_v2, "cancel_job_workflow", fake_cancel_job_workflow) + monkeypatch.setattr(store, "mark_cancelled", fail_mark_cancelled) + monkeypatch.setattr( + jobs_v2, + "release_job_billing", + lambda job, reason: released.append((job["job_id"], reason)), + ) + + response = TestClient(app).post("/v2/jobs/job-1/cancel") + + assert response.status_code == 503 + assert cancelled == ["job-1"] + assert released == [("job-1", "cancel_signal_sent")] + + def test_v1_batch_ingest_scopes_each_item_for_local_static_key(monkeypatch): monkeypatch.setattr(memory.settings, "environment", "development", raising=False) static_user = {"id": "static-key", "name": "Static Key User", "email": "static@xmem.ai"} From 3a04784ca7cb4eb4d53d5750ec3d345fe7b01795 Mon Sep 17 00:00:00 2001 From: Ishaan Gupta Date: Mon, 1 Jun 2026 15:08:51 +0530 Subject: [PATCH 14/20] Stabilize billing verification route test --- tests/api/test_billing_routes.py | 39 ++++++++++++++------------------ 1 file changed, 17 insertions(+), 22 deletions(-) diff --git a/tests/api/test_billing_routes.py b/tests/api/test_billing_routes.py index 78f8d47e..39eb62b5 100644 --- a/tests/api/test_billing_routes.py +++ b/tests/api/test_billing_routes.py @@ -2,13 +2,14 @@ import pytest -fastapi = pytest.importorskip("fastapi") -testclient = pytest.importorskip("fastapi.testclient") +pytest.importorskip("fastapi") from src.api.routes import billing +from src.billing.types import VerifyPaymentRequest -def test_subscription_verify_checks_signature_before_checkout_lookup(monkeypatch): +@pytest.mark.asyncio +async def test_subscription_verify_checks_signature_before_checkout_lookup(monkeypatch): class Store: def get_checkout(self, checkout_id): raise AssertionError("checkout lookup should not run before signature verification") @@ -16,25 +17,19 @@ def get_checkout(self, checkout_id): class Service: store = Store() - async def fake_auth(): - return {"id": "user-1"} - monkeypatch.setattr(billing, "get_default_billing_service", lambda: Service()) monkeypatch.setattr(billing, "verify_subscription_signature", lambda *args: False) - app = fastapi.FastAPI() - app.dependency_overrides[billing.require_auth] = fake_auth - app.include_router(billing.router) - - response = testclient.TestClient(app).post( - "/api/billing/razorpay/verify", - json={ - "package_id": "pro", - "razorpay_payment_id": "pay_1", - "razorpay_signature": "bad", - "razorpay_subscription_id": "sub_1", - }, - ) - - assert response.status_code == 400 - assert response.json()["detail"] == "Invalid Razorpay signature" + with pytest.raises(billing.HTTPException) as exc: + await billing.verify_razorpay_payment( + VerifyPaymentRequest( + package_id="pro", + razorpay_payment_id="pay_1", + razorpay_signature="bad", + razorpay_subscription_id="sub_1", + ), + current_user={"id": "user-1"}, + ) + + assert exc.value.status_code == 400 + assert exc.value.detail == "Invalid Razorpay signature" From 6744051f4d47fbb83ac7b81519db47e61e2cb502 Mon Sep 17 00:00:00 2001 From: Ishaan Gupta Date: Mon, 1 Jun 2026 15:16:18 +0530 Subject: [PATCH 15/20] Fix cancel billing regression test store patch --- tests/api/test_memory_versioning.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/api/test_memory_versioning.py b/tests/api/test_memory_versioning.py index 9a8cb314..9a9dcf73 100644 --- a/tests/api/test_memory_versioning.py +++ b/tests/api/test_memory_versioning.py @@ -378,6 +378,7 @@ def fail_mark_cancelled(job_id): raise RuntimeError("cancel status write failed") monkeypatch.setattr(jobs_v2, "get_default_job_store", lambda: store) + monkeypatch.setattr(durable, "get_default_job_store", lambda: store) monkeypatch.setattr(jobs_v2, "cancel_job_workflow", fake_cancel_job_workflow) monkeypatch.setattr(store, "mark_cancelled", fail_mark_cancelled) monkeypatch.setattr( From 2a49af573049e481397d6588897449c3f054c70e Mon Sep 17 00:00:00 2001 From: Ishaan Gupta Date: Mon, 1 Jun 2026 15:25:27 +0530 Subject: [PATCH 16/20] Avoid webhook grants without payment id --- src/api/routes/billing.py | 20 ++++++++----- tests/api/test_billing_routes.py | 48 ++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 7 deletions(-) diff --git a/src/api/routes/billing.py b/src/api/routes/billing.py index 34c451ae..5e068d82 100644 --- a/src/api/routes/billing.py +++ b/src/api/routes/billing.py @@ -303,7 +303,7 @@ async def razorpay_webhook(request: Request) -> dict[str, str]: notes = payment.get("notes") or subscription.get("notes") or order.get("notes") or {} user_id = str(notes.get("user_id") or "") package_id = str(notes.get("package_id") or "") - payment_id = str(payment.get("id") or payload.get("id") or "") + payment_id = str(payment.get("id") or "") order_id = str(payment.get("order_id") or order.get("id") or "") subscription_id = str(payment.get("subscription_id") or subscription.get("id") or "") @@ -317,21 +317,27 @@ async def razorpay_webhook(request: Request) -> dict[str, str]: return {"status": "ignored"} if event_name in {"payment.captured", "order.paid", "subscription.charged"}: - if package_id == "pro": + if not payment_id: + logger.info("Ignoring Razorpay webhook without payment id: %s", event_name) + elif package_id == "pro" and (subscription_id or order_id): await asyncio.to_thread( service.grant_pro_subscription, user_id=user_id, - payment_id=payment_id or event_id, - subscription_id=subscription_id or order_id or event_id, + payment_id=payment_id, + subscription_id=subscription_id or order_id, ) - elif package_id in billing_config.TOP_UP_PACKS: + elif package_id == "pro": + logger.info("Ignoring Razorpay pro webhook without subscription/order id: %s", event_name) + elif package_id in billing_config.TOP_UP_PACKS and order_id: await asyncio.to_thread( service.grant_topup, user_id=user_id, pack_id=package_id, - payment_id=payment_id or event_id, - order_id=order_id or event_id, + payment_id=payment_id, + order_id=order_id, ) + elif package_id in billing_config.TOP_UP_PACKS: + logger.info("Ignoring Razorpay top-up webhook without order id: %s", event_name) first_seen = await asyncio.to_thread( service.store.mark_payment_event, diff --git a/tests/api/test_billing_routes.py b/tests/api/test_billing_routes.py index 39eb62b5..29c49547 100644 --- a/tests/api/test_billing_routes.py +++ b/tests/api/test_billing_routes.py @@ -1,5 +1,7 @@ from __future__ import annotations +import json + import pytest pytest.importorskip("fastapi") @@ -33,3 +35,49 @@ class Service: assert exc.value.status_code == 400 assert exc.value.detail == "Invalid Razorpay signature" + + +@pytest.mark.asyncio +async def test_webhook_without_payment_id_does_not_grant_with_event_id(monkeypatch): + events = [] + + class Store: + def has_payment_event(self, event_id): + return False + + def mark_payment_event(self, event_id, payload): + events.append((event_id, payload["event"])) + return True + + class Service: + store = Store() + + def grant_pro_subscription(self, **kwargs): + raise AssertionError("webhook without payment id must not grant credits") + + class Request: + headers = {"x-razorpay-signature": "valid", "x-razorpay-event-id": "evt_1"} + + async def body(self): + return json.dumps( + { + "id": "evt_1", + "event": "subscription.charged", + "payload": { + "subscription": { + "entity": { + "id": "sub_1", + "notes": {"user_id": "user-1", "package_id": "pro"}, + } + } + }, + } + ).encode("utf-8") + + monkeypatch.setattr(billing, "get_default_billing_service", lambda: Service()) + monkeypatch.setattr(billing, "verify_webhook_signature", lambda body, signature: True) + + response = await billing.razorpay_webhook(Request()) + + assert response == {"status": "ok"} + assert events == [("evt_1", "subscription.charged")] From fb3237c9b68d98b3c3cc7f77ce4952cf69cc218d Mon Sep 17 00:00:00 2001 From: Ishaan Gupta Date: Mon, 1 Jun 2026 15:30:16 +0530 Subject: [PATCH 17/20] Keep billing route test lightweight --- tests/api/test_billing_routes.py | 48 -------------------------------- 1 file changed, 48 deletions(-) diff --git a/tests/api/test_billing_routes.py b/tests/api/test_billing_routes.py index 29c49547..39eb62b5 100644 --- a/tests/api/test_billing_routes.py +++ b/tests/api/test_billing_routes.py @@ -1,7 +1,5 @@ from __future__ import annotations -import json - import pytest pytest.importorskip("fastapi") @@ -35,49 +33,3 @@ class Service: assert exc.value.status_code == 400 assert exc.value.detail == "Invalid Razorpay signature" - - -@pytest.mark.asyncio -async def test_webhook_without_payment_id_does_not_grant_with_event_id(monkeypatch): - events = [] - - class Store: - def has_payment_event(self, event_id): - return False - - def mark_payment_event(self, event_id, payload): - events.append((event_id, payload["event"])) - return True - - class Service: - store = Store() - - def grant_pro_subscription(self, **kwargs): - raise AssertionError("webhook without payment id must not grant credits") - - class Request: - headers = {"x-razorpay-signature": "valid", "x-razorpay-event-id": "evt_1"} - - async def body(self): - return json.dumps( - { - "id": "evt_1", - "event": "subscription.charged", - "payload": { - "subscription": { - "entity": { - "id": "sub_1", - "notes": {"user_id": "user-1", "package_id": "pro"}, - } - } - }, - } - ).encode("utf-8") - - monkeypatch.setattr(billing, "get_default_billing_service", lambda: Service()) - monkeypatch.setattr(billing, "verify_webhook_signature", lambda body, signature: True) - - response = await billing.razorpay_webhook(Request()) - - assert response == {"status": "ok"} - assert events == [("evt_1", "subscription.charged")] From 894795e9c89a781d3d3e5fb298b4577282cd76c7 Mon Sep 17 00:00:00 2001 From: Ishaan Gupta Date: Mon, 1 Jun 2026 15:40:49 +0530 Subject: [PATCH 18/20] Keep billing reserved if cancel status write fails --- src/api/routes/v2/jobs.py | 3 +-- tests/api/test_memory_versioning.py | 5 +++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/api/routes/v2/jobs.py b/src/api/routes/v2/jobs.py index 3fe47011..23f0a082 100644 --- a/src/api/routes/v2/jobs.py +++ b/src/api/routes/v2/jobs.py @@ -169,8 +169,7 @@ async def cancel_job(job_id: str, request: Request, user: dict = Depends(require except Exception as exc: error = str(exc) or exc.__class__.__name__ if cancel_signal_sent: - await asyncio.to_thread(release_job_billing, job, "cancel_signal_sent") - return _error(request, f"Cancel failed after reaching workflow: {error}", 503, elapsed_ms(start)) + return _error(request, f"Cancel reached workflow but failed to persist status: {error}", 503, elapsed_ms(start)) return _error(request, f"Cancel failed to reach workflow: {error}", 503, elapsed_ms(start)) job = await asyncio.to_thread(get_default_job_store().get, job_id) await asyncio.to_thread(release_job_billing, job, "cancelled") diff --git a/tests/api/test_memory_versioning.py b/tests/api/test_memory_versioning.py index 9a9dcf73..efb05d9d 100644 --- a/tests/api/test_memory_versioning.py +++ b/tests/api/test_memory_versioning.py @@ -353,7 +353,7 @@ async def fake_start_job_workflow(job): assert store.jobs["job-1"]["status"] == "failed" -def test_v2_cancel_mark_failure_releases_billing_after_signal(monkeypatch): +def test_v2_cancel_mark_failure_keeps_billing_reserved_after_signal(monkeypatch): app, _ = _build_app(monkeypatch) store = FakeJobStore() store.jobs["job-1"] = { @@ -391,7 +391,8 @@ def fail_mark_cancelled(job_id): assert response.status_code == 503 assert cancelled == ["job-1"] - assert released == [("job-1", "cancel_signal_sent")] + assert released == [] + assert store.jobs["job-1"]["status"] == "running" def test_v1_batch_ingest_scopes_each_item_for_local_static_key(monkeypatch): From c85e9a12c175c161226bc1fe21ffc10a491d701c Mon Sep 17 00:00:00 2001 From: Ishaan Gupta Date: Mon, 1 Jun 2026 15:53:20 +0530 Subject: [PATCH 19/20] Reject grantable webhooks missing payment ids --- src/api/routes/billing.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/api/routes/billing.py b/src/api/routes/billing.py index 5e068d82..36c5164d 100644 --- a/src/api/routes/billing.py +++ b/src/api/routes/billing.py @@ -318,7 +318,8 @@ async def razorpay_webhook(request: Request) -> dict[str, str]: if event_name in {"payment.captured", "order.paid", "subscription.charged"}: if not payment_id: - logger.info("Ignoring Razorpay webhook without payment id: %s", event_name) + logger.warning("Razorpay webhook missing payment id for grantable event: %s", event_name) + raise HTTPException(status_code=400, detail="Webhook payment id is required for credit grant") elif package_id == "pro" and (subscription_id or order_id): await asyncio.to_thread( service.grant_pro_subscription, @@ -327,7 +328,8 @@ async def razorpay_webhook(request: Request) -> dict[str, str]: subscription_id=subscription_id or order_id, ) elif package_id == "pro": - logger.info("Ignoring Razorpay pro webhook without subscription/order id: %s", event_name) + logger.warning("Razorpay pro webhook missing subscription/order id: %s", event_name) + raise HTTPException(status_code=400, detail="Webhook subscription or order id is required for credit grant") elif package_id in billing_config.TOP_UP_PACKS and order_id: await asyncio.to_thread( service.grant_topup, @@ -337,7 +339,8 @@ async def razorpay_webhook(request: Request) -> dict[str, str]: order_id=order_id, ) elif package_id in billing_config.TOP_UP_PACKS: - logger.info("Ignoring Razorpay top-up webhook without order id: %s", event_name) + logger.warning("Razorpay top-up webhook missing order id: %s", event_name) + raise HTTPException(status_code=400, detail="Webhook order id is required for credit grant") first_seen = await asyncio.to_thread( service.store.mark_payment_event, From 3ead8fa1135676842447d3dc57f371fc1d744bf0 Mon Sep 17 00:00:00 2001 From: Ishaan Gupta Date: Mon, 1 Jun 2026 16:07:28 +0530 Subject: [PATCH 20/20] Handle grant webhook and retry failure edges --- src/api/routes/billing.py | 3 ++ src/api/routes/v2/jobs.py | 8 ++++- tests/api/test_memory_versioning.py | 50 +++++++++++++++++++++++++++++ 3 files changed, 60 insertions(+), 1 deletion(-) diff --git a/src/api/routes/billing.py b/src/api/routes/billing.py index 36c5164d..d34a76e2 100644 --- a/src/api/routes/billing.py +++ b/src/api/routes/billing.py @@ -341,6 +341,9 @@ async def razorpay_webhook(request: Request) -> dict[str, str]: elif package_id in billing_config.TOP_UP_PACKS: logger.warning("Razorpay top-up webhook missing order id: %s", event_name) raise HTTPException(status_code=400, detail="Webhook order id is required for credit grant") + else: + logger.warning("Razorpay webhook has unknown grant package id: %s", package_id) + raise HTTPException(status_code=400, detail="Webhook package id is not configured for credit grant") first_seen = await asyncio.to_thread( service.store.mark_payment_event, diff --git a/src/api/routes/v2/jobs.py b/src/api/routes/v2/jobs.py index 23f0a082..7796eebc 100644 --- a/src/api/routes/v2/jobs.py +++ b/src/api/routes/v2/jobs.py @@ -141,9 +141,15 @@ async def retry_job(job_id: str, request: Request, user: dict = Depends(require_ except InsufficientCredits as exc: return _error(request, str(exc), 402, elapsed_ms(start)) except Exception as exc: + release_error = None if billing_reservation_created and billing_account_id: - await asyncio.to_thread(release_job_reservation, billing_account_id, job_id) + try: + await asyncio.to_thread(release_job_reservation, billing_account_id, job_id) + except Exception as release_exc: + release_error = str(release_exc) or release_exc.__class__.__name__ error = str(exc) or exc.__class__.__name__ + if release_error: + error = f"{error}; billing reservation release failed: {release_error}" await asyncio.to_thread(get_default_job_store().mark_failed, job_id, error) return _error(request, f"Retry failed to start workflow: {error}", 503, elapsed_ms(start)) job = await asyncio.to_thread(get_default_job_store().get, job_id) diff --git a/tests/api/test_memory_versioning.py b/tests/api/test_memory_versioning.py index efb05d9d..d7e1f900 100644 --- a/tests/api/test_memory_versioning.py +++ b/tests/api/test_memory_versioning.py @@ -303,6 +303,56 @@ def fail_update_payload(job_id, payload): assert store.jobs["job-1"]["error"] == "payload write failed" +def test_v2_retry_release_failure_still_marks_job_failed(monkeypatch): + app, _ = _build_app(monkeypatch) + store = FakeJobStore() + store.jobs["job-1"] = { + "job_id": "job-1", + "job_type": "memory_ingest", + "payload": {"billing_account_id": "acct-1", "user_id": "hunter"}, + "user_id": "hunter", + "status": "failed", + "timeout_seconds": 30, + "max_attempts": 3, + "retry_count": 1, + "attempt_count": 1, + "workflow_id": "old-workflow", + } + + class FakeEstimate: + reserved_credits = 100 + + def model_dump(self): + return {"reserved_credits": self.reserved_credits} + + class FakeBillingService: + def estimate_required_credits(self, job_type, payload): + return FakeEstimate() + + def reserve_credits(self, account_id, job_id, estimated_credits): + return SimpleNamespace(reservation_id="reservation-1", created=True) + + async def fake_start_job_workflow(job): + raise RuntimeError("temporal unavailable") + + def fail_release(account_id, job_id): + raise RuntimeError("mongo unavailable") + + monkeypatch.setattr(jobs_v2, "get_default_job_store", lambda: store) + monkeypatch.setattr(durable, "get_default_job_store", lambda: store) + monkeypatch.setattr(jobs_v2, "get_default_billing_service", lambda: FakeBillingService()) + monkeypatch.setattr(jobs_v2, "release_job_reservation", fail_release) + monkeypatch.setattr(jobs_v2, "start_job_workflow", fake_start_job_workflow) + + response = TestClient(app).post("/v2/jobs/job-1/retry") + + assert response.status_code == 503 + assert store.jobs["job-1"]["status"] == "failed" + assert store.jobs["job-1"]["error"] == ( + "temporal unavailable; billing reservation release failed: mongo unavailable" + ) + + def test_v2_retry_start_failure_keeps_reused_billing_reservation(monkeypatch): app, _ = _build_app(monkeypatch) store = FakeJobStore()