diff --git a/CHANGELOG.md b/CHANGELOG.md index af837a3a..0e069a92 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ## Unreleased +- Add modular Razorpay billing, credit wallets, ledger reservations, and v2 memory workflow metering. - Add durable Temporal-backed v2 memory and scanner workflow APIs with job status, retry, cancel, and dead-letter endpoints. - Add modular LoCoMo and BEAM benchmark runners for the Python XMem API. - Add local XMem setup through `npx create-xmem@latest` and `npm run dev`. diff --git a/src/api/routes/billing.py b/src/api/routes/billing.py index 449d661c..d34a76e2 100644 --- a/src/api/routes/billing.py +++ b/src/api/routes/billing.py @@ -1,255 +1,56 @@ -"""Billing and Razorpay payment routes.""" +"""Billing routes backed by the modular credit ledger.""" from __future__ import annotations -import hashlib -import hmac +import asyncio +import json import logging from datetime import datetime, timezone -from typing import Any, Dict, List, Literal, Optional +from typing import Any import httpx -from fastapi import APIRouter, Depends, HTTPException, status -from pydantic import BaseModel, Field +from fastapi import APIRouter, Depends, HTTPException, Query, Request, status +from pydantic import BaseModel from src.api.dependencies import get_current_user +from src.billing.razorpay import ( + RazorpayConfigError, + create_order, + create_subscription, + require_razorpay_keys, + verify_order_signature, + verify_subscription_signature, + verify_webhook_signature, +) +from src.billing.service import get_default_billing_service, public_plans, public_topups +from src.billing.types import ( + BillingSummary, + CheckoutRequest, + CheckoutResponse, + LedgerEntryPublic, + PlanPublic, + TopUpPackPublic, + VerifyPaymentRequest, +) from src.config import settings +from src.utils import billing as billing_config logger = logging.getLogger("xmem.api.billing") router = APIRouter(prefix="/api/billing", tags=["Billing"]) -class BillingPlan(BaseModel): - id: str - name: str - amount: int - currency: str - description: str - features: List[str] - - -class UsageSnapshot(BaseModel): - memories_written: int = 0 - retrievals: int = 0 - graph_queries: int = 0 - credits_used: int = 0 - credits_limit: int = 5000 - - -class Invoice(BaseModel): - id: str - date: datetime - amount_paise: int - status: Literal["paid", "pending", "failed"] - credits: int = 0 - receipt_url: Optional[str] = None - - -class BillingSummary(BaseModel): - plan_name: str - account_status: Literal["active", "trial", "paused", "past_due"] - currency: str - credit_balance: int - prepaid_balance_paise: int - current_month: UsageSnapshot - next_invoice_paise: int - last_payment_at: Optional[datetime] = None - invoices: List[Invoice] = Field(default_factory=list) - - class BillingSummaryResponse(BaseModel): summary: BillingSummary - plans: List[BillingPlan] - - -class CreateRazorpayOrderRequest(BaseModel): - package_id: str = Field(..., description="Plan/package ID selected by the user") - credits: int = Field(default=0, ge=0) - amount: int = Field(default=0, ge=0) - currency: str = Field(default="INR", min_length=3, max_length=3) + plans: list[PlanPublic] + topups: list[TopUpPackPublic] -class RazorpayOrderResponse(BaseModel): - id: str - order_id: str - amount: int - currency: str - key_id: str - receipt: str - package_id: str - - -class VerifyRazorpayPaymentRequest(BaseModel): - razorpay_payment_id: str - razorpay_order_id: str - razorpay_signature: str - package_id: str - credits: int = Field(default=0, ge=0) - amount: int = Field(default=0, ge=0) - currency: str = Field(default="INR", min_length=3, max_length=3) - - -class VerifyRazorpayPaymentResponse(BaseModel): - status: Literal["ok"] +class VerifyPaymentResponse(BaseModel): + status: str = "ok" summary: BillingSummary -_PLANS: Dict[str, BillingPlan] = { - "free": BillingPlan( - id="free", - name="Free", - amount=0, - currency="USD", - description="30 days free with access to the core platform, Chrome extension, MCP, and SDKs.", - features=[ - "Full XMem dashboard access", - "Chrome extension included", - "MCP server access included", - "Python and TypeScript SDKs included", - "No credit card required", - ], - ), - "pro": BillingPlan( - id="pro", - name="Pro", - amount=100, - currency="USD", - description="Full access for production apps, priority support, and pay-as-you-go usage.", - features=[ - "Everything in Free", - "Production-ready API access", - "Pay-as-you-go usage for higher volume", - "24/7 customer support", - "Access to exclusive features coming soon", - ], - ), - "enterprise": BillingPlan( - id="enterprise", - name="Enterprise", - amount=0, - currency="USD", - description="Dedicated onboarding, custom limits, security reviews, and team support.", - features=[ - "Everything in Pro", - "Custom usage limits", - "Security and procurement support", - "Dedicated onboarding", - ], - ), -} - -_in_memory_billing: Dict[str, Dict[str, Any]] = {} -_in_memory_orders: Dict[str, Dict[str, Any]] = {} - - -class BillingStore: - """Small billing metadata store backed by MongoDB with local memory fallback.""" - - def __init__(self) -> None: - self._client = None - self._db = None - self.billing = None - self.orders = None - self._connected = False - self._in_memory = False - self._try_connect() - - def _requires_durable_storage(self) -> bool: - return settings.environment.lower() in {"production", "prod"} - - def _enable_memory_fallback(self, error: Exception) -> None: - message = f"MongoDB connection failed for billing storage: {error}" - if self._requires_durable_storage(): - logger.error("%s; refusing in-memory fallback in production", message) - raise RuntimeError( - "MongoDB is required for billing storage when ENVIRONMENT=production" - ) from error - logger.warning("%s; using in-memory billing storage", message) - self._connected = False - self._in_memory = True - - def _try_connect(self) -> None: - provider = (settings.app_store_provider or "mongo").strip().lower() - if provider == "memory": - self._connected = False - self._in_memory = True - return - if provider == "postgres": - self._enable_memory_fallback(RuntimeError("Postgres billing storage is not implemented")) - return - - try: - from pymongo import ASCENDING, MongoClient - - self._client = MongoClient(settings.mongodb_uri, serverSelectionTimeoutMS=5000) - self._client.admin.command("ping") - self._db = self._client[settings.mongodb_database] - self.billing = self._db["billing_profiles"] - self.orders = self._db["billing_orders"] - self.billing.create_index([("user_id", ASCENDING)], unique=True) - self.orders.create_index([("order_id", ASCENDING)], unique=True) - self.orders.create_index([("user_id", ASCENDING)]) - self._connected = True - self._in_memory = False - except Exception as exc: - self._enable_memory_fallback(exc) - - def get_summary(self, user_id: str) -> BillingSummary: - if self._in_memory: - summary = _in_memory_billing.setdefault( - user_id, - _default_summary().model_dump(), - ) - return BillingSummary.model_validate(summary) - - doc = self.billing.find_one({"user_id": user_id}) - if not doc: - summary = _default_summary() - self.save_summary(user_id, summary) - return summary - - doc.pop("_id", None) - doc.pop("user_id", None) - return BillingSummary.model_validate(doc) - - def save_summary(self, user_id: str, summary: BillingSummary) -> None: - payload = summary.model_dump() - if self._in_memory: - _in_memory_billing[user_id] = payload - return - - self.billing.update_one( - {"user_id": user_id}, - {"$set": {"user_id": user_id, **payload}}, - upsert=True, - ) - - def save_order(self, order_id: str, order: Dict[str, Any]) -> None: - payload = {"order_id": order_id, **order} - if self._in_memory: - _in_memory_orders[order_id] = payload - return - - self.orders.update_one( - {"order_id": order_id}, - {"$set": payload}, - upsert=True, - ) - - def get_order(self, order_id: str) -> Optional[Dict[str, Any]]: - if self._in_memory: - return _in_memory_orders.get(order_id) - - doc = self.orders.find_one({"order_id": order_id}) - if doc: - doc.pop("_id", None) - return doc - - -_billing_store = BillingStore() - - async def require_auth(current_user: dict = Depends(get_current_user)) -> dict: if not current_user: raise HTTPException( @@ -267,169 +68,311 @@ def _user_id(user: dict) -> str: return str(user_id) -def _default_summary() -> BillingSummary: - return BillingSummary( - plan_name="Free trial", - account_status="trial", - currency="INR", - credit_balance=5000, - prepaid_balance_paise=0, - current_month=UsageSnapshot(), - next_invoice_paise=0, - invoices=[], - ) - +def _receipt(user_id: str, package_id: str) -> str: + ts = int(datetime.now(timezone.utc).timestamp()) + safe_user = user_id.replace(":", "_")[:16] + return f"xmem-{package_id}-{safe_user}-{ts}" -def _get_summary(user_id: str) -> BillingSummary: - return _billing_store.get_summary(user_id) +def _pack_or_plan(package_id: str) -> tuple[str, dict[str, Any]]: + if package_id in billing_config.PLANS: + return "plan", billing_config.PLANS[package_id] + if package_id in billing_config.TOP_UP_PACKS: + return "topup", billing_config.TOP_UP_PACKS[package_id] + raise HTTPException(status_code=400, detail="Unknown billing package") -def _require_razorpay_config() -> tuple[str, str]: - if not settings.razorpay_key_id or not settings.razorpay_key_secret: - raise HTTPException( - status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Razorpay is not configured. Set RAZORPAY_KEY_ID and RAZORPAY_KEY_SECRET.", - ) - return settings.razorpay_key_id, settings.razorpay_key_secret - -def _get_plan(package_id: str) -> BillingPlan: - plan = _PLANS.get(package_id) - if not plan: - raise HTTPException(status_code=400, detail="Unknown billing plan") - return plan - - -def _verify_signature(order_id: str, payment_id: str, signature: str, secret: str) -> bool: - payload = f"{order_id}|{payment_id}".encode("utf-8") - expected = hmac.new(secret.encode("utf-8"), payload, hashlib.sha256).hexdigest() - return hmac.compare_digest(expected, signature) - - -@router.get("/plans", response_model=List[BillingPlan]) -async def list_billing_plans() -> List[BillingPlan]: - """Return the server-authoritative billing plans.""" - return list(_PLANS.values()) +@router.get("/plans", response_model=list[PlanPublic]) +async def list_billing_plans() -> list[PlanPublic]: + return public_plans() @router.get("/summary", response_model=BillingSummaryResponse) async def billing_summary(current_user: dict = Depends(require_auth)) -> BillingSummaryResponse: - """Return the current user's billing summary.""" + service = get_default_billing_service() + summary = await asyncio.to_thread(service.get_billing_summary, current_user) return BillingSummaryResponse( - summary=_get_summary(_user_id(current_user)), - plans=list(_PLANS.values()), + summary=summary, + plans=public_plans(), + topups=public_topups(), ) -@router.post("/razorpay/order", response_model=RazorpayOrderResponse) -async def create_razorpay_order( - request: CreateRazorpayOrderRequest, +@router.post("/razorpay/order", response_model=CheckoutResponse) +async def create_razorpay_checkout( + request: CheckoutRequest, current_user: dict = Depends(require_auth), -) -> RazorpayOrderResponse: - """Create a Razorpay order for the selected plan. - - The server owns plan amount and currency. Client-supplied amount/currency are - accepted only for compatibility and are intentionally ignored. - """ - key_id, key_secret = _require_razorpay_config() - plan = _get_plan(request.package_id) - - if plan.id != "pro": - raise HTTPException(status_code=400, detail="Only the Pro plan can be purchased online") +) -> CheckoutResponse: + try: + key_id, _ = require_razorpay_keys() + except RazorpayConfigError as exc: + raise HTTPException(status_code=503, detail=str(exc)) from exc user_id = _user_id(current_user) - receipt = f"xmem-{user_id[:16]}-{int(datetime.now(timezone.utc).timestamp())}" + package_type, package = _pack_or_plan(request.package_id) + service = get_default_billing_service() + account = await asyncio.to_thread(service.ensure_billing_account, current_user) - payload = { - "amount": plan.amount, - "currency": plan.currency, - "receipt": receipt, - "notes": { - "user_id": user_id, - "package_id": plan.id, - "plan_name": plan.name, - }, + if request.package_id == "free": + raise HTTPException(status_code=400, detail="Free plan does not require checkout") + + notes = { + "user_id": user_id, + "billing_account_id": account["id"], + "package_id": request.package_id, + "package_type": package_type, } + receipt = _receipt(user_id, request.package_id) try: - async with httpx.AsyncClient(timeout=20) as client: - response = await client.post( - "https://api.razorpay.com/v1/orders", - auth=(key_id, key_secret), - json=payload, + if request.package_id == "pro" and settings.razorpay_pro_plan_id: + subscription = await create_subscription( + plan_id=settings.razorpay_pro_plan_id, + notes=notes, + ) + checkout_id = str(subscription["id"]) + await asyncio.to_thread( + service.store.save_checkout, + checkout_id, + { + "type": "subscription", + "user_id": user_id, + "billing_account_id": account["id"], + "package_id": request.package_id, + "subscription_id": checkout_id, + "status": "created", + }, + ) + return CheckoutResponse( + id=checkout_id, + subscription_id=checkout_id, + package_id=request.package_id, + amount=int(package["price_paise"]), + currency=str(package.get("currency") or "INR"), + key_id=key_id, + receipt=receipt, ) - except httpx.HTTPError as exc: - logger.exception("Failed to create Razorpay order") - raise HTTPException(status_code=502, detail="Failed to reach Razorpay") from exc - - if response.status_code >= 400: - logger.warning("Razorpay order creation failed: %s %s", response.status_code, response.text[:500]) - raise HTTPException(status_code=502, detail="Razorpay order creation failed") - order = response.json() - order_id = order["id"] - _billing_store.save_order(order_id, { - "user_id": user_id, - "package_id": plan.id, - "amount": plan.amount, - "currency": plan.currency, - "receipt": receipt, - "created_at": datetime.now(timezone.utc), - }) - - return RazorpayOrderResponse( + amount = int(package["price_paise"]) + order = await create_order( + amount_paise=amount, + currency=str(package.get("currency") or "INR"), + receipt=receipt, + notes=notes, + ) + except httpx.HTTPError as exc: + raise HTTPException(status_code=502, detail="Razorpay checkout creation failed") from exc + + order_id = str(order["id"]) + await asyncio.to_thread( + service.store.save_checkout, + order_id, + { + "type": package_type, + "user_id": user_id, + "billing_account_id": account["id"], + "package_id": request.package_id, + "order_id": order_id, + "amount": amount, + "currency": str(package.get("currency") or "INR"), + "status": "created", + }, + ) + return CheckoutResponse( id=order_id, order_id=order_id, - amount=plan.amount, - currency=plan.currency, + package_id=request.package_id, + amount=amount, + currency=str(package.get("currency") or "INR"), key_id=key_id, receipt=receipt, - package_id=plan.id, ) -@router.post("/razorpay/verify", response_model=VerifyRazorpayPaymentResponse) -async def verify_razorpay_payment( - request: VerifyRazorpayPaymentRequest, +@router.post("/topups", response_model=CheckoutResponse) +async def create_topup_checkout( + request: CheckoutRequest, current_user: dict = Depends(require_auth), -) -> VerifyRazorpayPaymentResponse: - """Verify a Razorpay checkout signature and activate the paid plan.""" - _, key_secret = _require_razorpay_config() - plan = _get_plan(request.package_id) - - if not _verify_signature( - request.razorpay_order_id, - request.razorpay_payment_id, - request.razorpay_signature, - key_secret, - ): - raise HTTPException(status_code=400, detail="Invalid Razorpay signature") +) -> CheckoutResponse: + if request.package_id not in billing_config.TOP_UP_PACKS: + raise HTTPException(status_code=400, detail="Unknown top-up pack") + return await create_razorpay_checkout(request, current_user) + +@router.post("/razorpay/verify", response_model=VerifyPaymentResponse) +async def verify_razorpay_payment( + request: VerifyPaymentRequest, + current_user: dict = Depends(require_auth), +) -> VerifyPaymentResponse: + service = get_default_billing_service() user_id = _user_id(current_user) - tracked_order = _billing_store.get_order(request.razorpay_order_id) - if tracked_order and tracked_order.get("user_id") != user_id: - raise HTTPException(status_code=403, detail="Payment order does not belong to this user") - if tracked_order and tracked_order.get("package_id") != plan.id: - raise HTTPException(status_code=400, detail="Payment order package mismatch") - - now = datetime.now(timezone.utc) - summary = _get_summary(user_id) - summary.plan_name = plan.name - summary.account_status = "active" - summary.currency = plan.currency - summary.prepaid_balance_paise += plan.amount - summary.next_invoice_paise = 0 - summary.last_payment_at = now - summary.invoices.insert( - 0, - Invoice( - id=request.razorpay_payment_id, - date=now, - amount_paise=plan.amount, - status="paid", - credits=0, - ), + + if request.razorpay_subscription_id: + if not verify_subscription_signature( + request.razorpay_subscription_id, + request.razorpay_payment_id, + request.razorpay_signature, + ): + raise HTTPException(status_code=400, detail="Invalid Razorpay signature") + checkout = await asyncio.to_thread( + service.store.get_checkout, + request.razorpay_subscription_id, + ) + if not checkout: + raise HTTPException(status_code=400, detail="Unknown Razorpay subscription checkout") + if checkout.get("user_id") != user_id: + raise HTTPException(status_code=403, detail="Payment subscription does not belong to this user") + if checkout.get("package_id") != "pro": + raise HTTPException(status_code=400, detail="Payment subscription package mismatch") + await asyncio.to_thread( + service.grant_pro_subscription, + user_id=user_id, + payment_id=request.razorpay_payment_id, + subscription_id=request.razorpay_subscription_id, + ) + elif request.razorpay_order_id: + if not verify_order_signature( + request.razorpay_order_id, + request.razorpay_payment_id, + request.razorpay_signature, + ): + raise HTTPException(status_code=400, detail="Invalid Razorpay signature") + checkout = await asyncio.to_thread( + service.store.get_checkout, + request.razorpay_order_id, + ) + if not checkout: + raise HTTPException(status_code=400, detail="Unknown Razorpay payment order") + if checkout.get("user_id") != user_id: + raise HTTPException(status_code=403, detail="Payment order does not belong to this user") + if checkout.get("package_id") != request.package_id: + raise HTTPException(status_code=400, detail="Payment order package mismatch") + package_id = str(checkout.get("package_id")) + if package_id == "pro": + await asyncio.to_thread( + service.grant_pro_subscription, + user_id=user_id, + payment_id=request.razorpay_payment_id, + subscription_id=request.razorpay_order_id, + ) + else: + await asyncio.to_thread( + service.grant_topup, + user_id=user_id, + pack_id=package_id, + payment_id=request.razorpay_payment_id, + order_id=request.razorpay_order_id, + ) + else: + raise HTTPException(status_code=400, detail="Missing Razorpay order or subscription id") + + summary = await asyncio.to_thread(service.get_billing_summary, current_user) + return VerifyPaymentResponse(summary=summary) + + +@router.post("/razorpay/webhook") +async def razorpay_webhook(request: Request) -> dict[str, str]: + body = await request.body() + signature = request.headers.get("x-razorpay-signature", "") + try: + if not verify_webhook_signature(body, signature): + raise HTTPException(status_code=400, detail="Invalid Razorpay webhook signature") + except RazorpayConfigError as exc: + raise HTTPException(status_code=503, detail=str(exc)) from exc + + try: + payload = json.loads(body.decode("utf-8")) + except (json.JSONDecodeError, UnicodeDecodeError) as exc: + logger.warning("Razorpay webhook body is not valid JSON: %s", exc) + raise HTTPException(status_code=400, detail="Webhook body must be valid JSON") from exc + event_id = str( + request.headers.get("x-razorpay-event-id") + or payload.get("id") + or "" + ) + if not event_id: + raise HTTPException(status_code=400, detail="Webhook event id is required") + event_name = str(payload.get("event") or "") + service = get_default_billing_service() + if await asyncio.to_thread(service.store.has_payment_event, event_id): + return {"status": "ignored_duplicate"} + + payment = (((payload.get("payload") or {}).get("payment") or {}).get("entity") or {}) + subscription = (((payload.get("payload") or {}).get("subscription") or {}).get("entity") or {}) + order = (((payload.get("payload") or {}).get("order") or {}).get("entity") or {}) + notes = payment.get("notes") or subscription.get("notes") or order.get("notes") or {} + user_id = str(notes.get("user_id") or "") + package_id = str(notes.get("package_id") or "") + payment_id = str(payment.get("id") or "") + order_id = str(payment.get("order_id") or order.get("id") or "") + subscription_id = str(payment.get("subscription_id") or subscription.get("id") or "") + + if not user_id or not package_id: + logger.info("Ignoring Razorpay webhook without XMem user/package notes: %s", event_name) + await asyncio.to_thread( + service.store.mark_payment_event, + event_id, + {"event": event_name, "payload": payload}, + ) + return {"status": "ignored"} + + if event_name in {"payment.captured", "order.paid", "subscription.charged"}: + if not payment_id: + logger.warning("Razorpay webhook missing payment id for grantable event: %s", event_name) + raise HTTPException(status_code=400, detail="Webhook payment id is required for credit grant") + elif package_id == "pro" and (subscription_id or order_id): + await asyncio.to_thread( + service.grant_pro_subscription, + user_id=user_id, + payment_id=payment_id, + subscription_id=subscription_id or order_id, + ) + elif package_id == "pro": + logger.warning("Razorpay pro webhook missing subscription/order id: %s", event_name) + raise HTTPException(status_code=400, detail="Webhook subscription or order id is required for credit grant") + elif package_id in billing_config.TOP_UP_PACKS and order_id: + await asyncio.to_thread( + service.grant_topup, + user_id=user_id, + pack_id=package_id, + payment_id=payment_id, + order_id=order_id, + ) + elif package_id in billing_config.TOP_UP_PACKS: + logger.warning("Razorpay top-up webhook missing order id: %s", event_name) + raise HTTPException(status_code=400, detail="Webhook order id is required for credit grant") + else: + logger.warning("Razorpay webhook has unknown grant package id: %s", package_id) + raise HTTPException(status_code=400, detail="Webhook package id is not configured for credit grant") + + first_seen = await asyncio.to_thread( + service.store.mark_payment_event, + event_id, + {"event": event_name, "payload": payload}, ) - _billing_store.save_summary(user_id, summary) + if not first_seen: + return {"status": "ignored_duplicate"} - return VerifyRazorpayPaymentResponse(status="ok", summary=summary) + return {"status": "ok"} + + +@router.get("/ledger", response_model=list[LedgerEntryPublic]) +async def billing_ledger( + current_user: dict = Depends(require_auth), + limit: int = Query(default=100, ge=1, le=500), +) -> list[LedgerEntryPublic]: + service = get_default_billing_service() + entries = await asyncio.to_thread(service.list_ledger, current_user, limit) + return [ + LedgerEntryPublic( + id=str(entry["id"]), + type=str(entry["type"]), + amount=int(entry["amount"]), + idempotency_key=str(entry["idempotency_key"]), + job_id=entry.get("job_id"), + source=entry.get("source"), + metadata=entry.get("metadata") or {}, + created_at=entry["created_at"], + ) + for entry in entries + ] diff --git a/src/api/routes/v2/activities.py b/src/api/routes/v2/activities.py index 4282898c..b1500f7f 100644 --- a/src/api/routes/v2/activities.py +++ b/src/api/routes/v2/activities.py @@ -9,6 +9,7 @@ from src.api.dependencies import get_ingest_pipeline from src.api.routes import memory as memory_v1 +from src.billing.service import commit_job_billing, release_job_billing from src.jobs.durable import get_default_job_store try: # pragma: no cover - no-op fallback keeps imports working without SDK. @@ -50,15 +51,22 @@ async def mark_job_progress_activity(payload: Dict[str, Any]) -> None: @activity.defn async def mark_job_succeeded_activity(payload: Dict[str, Any]) -> None: + job = await asyncio.to_thread(get_default_job_store().get, payload["job_id"]) + result = payload.get("result") or {} + if job: + result = await asyncio.to_thread(commit_job_billing, job, result) await asyncio.to_thread( get_default_job_store().mark_succeeded, payload["job_id"], - payload.get("result") or {}, + result, ) @activity.defn async def mark_job_dead_letter_activity(payload: Dict[str, Any]) -> None: + job = await asyncio.to_thread(get_default_job_store().get, payload["job_id"]) + if job: + await asyncio.to_thread(release_job_billing, job, "dead_letter") await asyncio.to_thread( get_default_job_store().mark_dead_letter, payload["job_id"], diff --git a/src/api/routes/v2/jobs.py b/src/api/routes/v2/jobs.py index 1e0cff33..7796eebc 100644 --- a/src/api/routes/v2/jobs.py +++ b/src/api/routes/v2/jobs.py @@ -17,7 +17,13 @@ ) from src.api.routes.v2.temporal_client import cancel_job_workflow, start_job_workflow from src.api.schemas import APIResponse -from src.jobs.durable import DEAD_LETTER, QUEUED, RUNNING, get_default_job_store +from src.billing import ( + InsufficientCredits, + get_default_billing_service, + release_job_billing, + release_job_reservation, +) +from src.jobs.durable import CANCELLED, DEAD_LETTER, QUEUED, RUNNING, get_default_job_store router = APIRouter( prefix="/v2/jobs", @@ -112,12 +118,38 @@ async def retry_job(job_id: str, request: Request, user: dict = Depends(require_ if job.get("status") not in {"failed", "dead_letter", "cancelled"}: return _error(request, "Only failed, dead-lettered, or cancelled jobs can be retried.", 409, elapsed_ms(start)) - await asyncio.to_thread(get_default_job_store().reset_for_retry, job_id, True) - job = await asyncio.to_thread(get_default_job_store().get, job_id) + payload = job.get("payload") if isinstance(job.get("payload"), dict) else {} + billing_account_id = payload.get("billing_account_id") + billing_reservation_created = False try: + if billing_account_id: + billing_service = get_default_billing_service() + estimate = billing_service.estimate_required_credits(job.get("job_type") or "", payload) + reservation = await asyncio.to_thread( + billing_service.reserve_credits, + billing_account_id, + job_id, + estimate.reserved_credits, + ) + payload["billing_reservation_id"] = reservation.reservation_id + payload["billing_estimate"] = estimate.model_dump() + billing_reservation_created = reservation.created + await asyncio.to_thread(get_default_job_store().update_payload, job_id, payload) + await asyncio.to_thread(get_default_job_store().reset_for_retry, job_id, True) + job = await asyncio.to_thread(get_default_job_store().get, job_id) await start_job_workflow(job) + except InsufficientCredits as exc: + return _error(request, str(exc), 402, elapsed_ms(start)) except Exception as exc: + release_error = None + if billing_reservation_created and billing_account_id: + try: + await asyncio.to_thread(release_job_reservation, billing_account_id, job_id) + except Exception as release_exc: + release_error = str(release_exc) or release_exc.__class__.__name__ error = str(exc) or exc.__class__.__name__ + if release_error: + error = f"{error}; billing reservation release failed: {release_error}" await asyncio.to_thread(get_default_job_store().mark_failed, job_id, error) return _error(request, f"Retry failed to start workflow: {error}", 503, elapsed_ms(start)) job = await asyncio.to_thread(get_default_job_store().get, job_id) @@ -131,14 +163,21 @@ async def cancel_job(job_id: str, request: Request, user: dict = Depends(require job = await read_user_job(job_id, user_id) if not job: return _error(request, "Job not found.", 404, elapsed_ms(start)) - if job.get("status") not in {QUEUED, RUNNING}: - return _error(request, "Only queued or running jobs can be cancelled.", 409, elapsed_ms(start)) - try: - await cancel_job_workflow(job) - except Exception as exc: - error = str(exc) or exc.__class__.__name__ - return _error(request, f"Cancel failed to reach workflow: {error}", 503, elapsed_ms(start)) - await asyncio.to_thread(get_default_job_store().mark_cancelled, job_id) - await asyncio.to_thread(_mark_scanner_job_cancelled, job) + if job.get("status") not in {QUEUED, RUNNING, CANCELLED}: + return _error(request, "Only queued, running, or already-cancelled jobs can be cancelled.", 409, elapsed_ms(start)) + if job.get("status") != CANCELLED: + cancel_signal_sent = False + try: + await cancel_job_workflow(job) + cancel_signal_sent = True + await asyncio.to_thread(get_default_job_store().mark_cancelled, job_id) + await asyncio.to_thread(_mark_scanner_job_cancelled, job) + except Exception as exc: + error = str(exc) or exc.__class__.__name__ + if cancel_signal_sent: + return _error(request, f"Cancel reached workflow but failed to persist status: {error}", 503, elapsed_ms(start)) + return _error(request, f"Cancel failed to reach workflow: {error}", 503, elapsed_ms(start)) + job = await asyncio.to_thread(get_default_job_store().get, job_id) + await asyncio.to_thread(release_job_billing, job, "cancelled") job = await asyncio.to_thread(get_default_job_store().get, job_id) return _wrap(request, job_status_data(job), elapsed_ms(start)) diff --git a/src/api/routes/v2/memory.py b/src/api/routes/v2/memory.py index 8fab333d..f6c68b29 100644 --- a/src/api/routes/v2/memory.py +++ b/src/api/routes/v2/memory.py @@ -19,8 +19,9 @@ ) from src.api.routes.v2.temporal_client import start_job_workflow from src.api.schemas import APIResponse, BatchIngestRequest, IngestRequest, ScrapeRequest, StatusEnum +from src.billing import InsufficientCredits, get_default_billing_service from src.config import settings -from src.jobs.durable import QUEUED, get_default_job_store, new_attempt_id, stable_hash +from src.jobs.durable import QUEUED, get_default_job_store, idempotency_key, new_attempt_id, stable_hash router = APIRouter( prefix="/v2/memory", @@ -39,6 +40,10 @@ def _content_hash(payload: Dict[str, Any]) -> str: return stable_hash(payload) +def _durable_job_id(job_type: str, fields: Dict[str, Any]) -> str: + return f"{job_type}:{idempotency_key(job_type, fields)}" + + class WorkflowStartFailed(RuntimeError): def __init__(self, job: Dict[str, Any], error: str) -> None: super().__init__(error) @@ -117,22 +122,37 @@ async def ingest_memory_v2(req: IngestRequest, request: Request, user: dict = De payload = req.model_dump() payload["user_id"] = user_id payload["timeout_seconds"] = float(settings.memory_ingest_timeout_seconds) + idempotency_fields = { + "user_id": user_id, + "org_id": payload.get("org_id", "default"), + "content_hash": _content_hash({ + "user_query": req.user_query, + "agent_response": req.agent_response or "", + "session_datetime": req.session_datetime, + "image_url": req.image_url, + "effort_level": req.effort_level, + }), + } + job_id = _durable_job_id("memory_ingest", idempotency_fields) + billing_service = get_default_billing_service() + billing_reservation_created = False try: + account, estimate, reservation = await asyncio.to_thread( + billing_service.reserve_job_credits, + user=user, + job_id=job_id, + job_type="memory_ingest", + payload=payload, + ) + payload["billing_account_id"] = account["id"] + payload["billing_reservation_id"] = reservation.reservation_id + payload["billing_estimate"] = estimate.model_dump() + billing_reservation_created = reservation.created job, created = await _enqueue_and_start( job_type="memory_ingest", payload=payload, - idempotency_fields={ - "user_id": user_id, - "org_id": payload.get("org_id", "default"), - "content_hash": _content_hash({ - "user_query": req.user_query, - "agent_response": req.agent_response or "", - "session_datetime": req.session_datetime, - "image_url": req.image_url, - "effort_level": req.effort_level, - }), - }, + idempotency_fields=idempotency_fields, user_id=job_user_id, timeout_seconds=float(settings.memory_ingest_timeout_seconds), max_attempts=3, @@ -145,6 +165,12 @@ async def ingest_memory_v2(req: IngestRequest, request: Request, user: dict = De elapsed_ms(start), ) except WorkflowStartFailed as exc: + if billing_reservation_created and payload.get("billing_account_id"): + await asyncio.to_thread( + billing_service.release_job_reservation, + payload["billing_account_id"], + job_id, + ) return _workflow_start_error( request, exc.job, @@ -152,7 +178,15 @@ async def ingest_memory_v2(req: IngestRequest, request: Request, user: dict = De f"/v2/memory/ingest/{exc.job['job_id']}/status", elapsed_ms(start), ) + except InsufficientCredits as exc: + return _error(request, str(exc), 402, elapsed_ms(start)) except Exception as exc: + if billing_reservation_created and payload.get("billing_account_id"): + await asyncio.to_thread( + billing_service.release_job_reservation, + payload["billing_account_id"], + job_id, + ) return _error(request, str(exc), 500, elapsed_ms(start)) @@ -183,15 +217,30 @@ async def batch_ingest_memory_v2(req: BatchIngestRequest, request: Request, user min(len(req.items) * float(settings.memory_ingest_timeout_seconds), 3600.0), ), } + idempotency_fields = { + "user_id": user_id, + "content_hash": _content_hash({"items": items}), + } + job_id = _durable_job_id("memory_batch_ingest", idempotency_fields) + billing_service = get_default_billing_service() + billing_reservation_created = False try: + account, estimate, reservation = await asyncio.to_thread( + billing_service.reserve_job_credits, + user=user, + job_id=job_id, + job_type="memory_batch_ingest", + payload=payload, + ) + payload["billing_account_id"] = account["id"] + payload["billing_reservation_id"] = reservation.reservation_id + payload["billing_estimate"] = estimate.model_dump() + billing_reservation_created = reservation.created job, created = await _enqueue_and_start( job_type="memory_batch_ingest", payload=payload, - idempotency_fields={ - "user_id": user_id, - "content_hash": _content_hash({"items": items}), - }, + idempotency_fields=idempotency_fields, user_id=user_id, timeout_seconds=payload["timeout_seconds"], max_attempts=3, @@ -204,6 +253,12 @@ async def batch_ingest_memory_v2(req: BatchIngestRequest, request: Request, user elapsed_ms(start), ) except WorkflowStartFailed as exc: + if billing_reservation_created and payload.get("billing_account_id"): + await asyncio.to_thread( + billing_service.release_job_reservation, + payload["billing_account_id"], + job_id, + ) return _workflow_start_error( request, exc.job, @@ -211,7 +266,15 @@ async def batch_ingest_memory_v2(req: BatchIngestRequest, request: Request, user f"/v2/memory/jobs/{exc.job['job_id']}/status", elapsed_ms(start), ) + except InsufficientCredits as exc: + return _error(request, str(exc), 402, elapsed_ms(start)) except Exception as exc: + if billing_reservation_created and payload.get("billing_account_id"): + await asyncio.to_thread( + billing_service.release_job_reservation, + payload["billing_account_id"], + job_id, + ) return _error(request, str(exc), 500, elapsed_ms(start)) diff --git a/src/billing/__init__.py b/src/billing/__init__.py new file mode 100644 index 00000000..64cb4935 --- /dev/null +++ b/src/billing/__init__.py @@ -0,0 +1,31 @@ +"""Billing and credit ledger package.""" + +from .store import InsufficientCredits +from .service import ( + commit_job_billing, + commit_job_debit, + ensure_billing_account, + estimate_required_credits, + get_billing_summary, + get_default_billing_service, + record_usage_event, + release_job_billing, + release_job_reservation, + reserve_credits, + reserve_job_credits, +) + +__all__ = [ + "InsufficientCredits", + "commit_job_billing", + "commit_job_debit", + "ensure_billing_account", + "estimate_required_credits", + "get_billing_summary", + "get_default_billing_service", + "record_usage_event", + "release_job_billing", + "release_job_reservation", + "reserve_credits", + "reserve_job_credits", +] diff --git a/src/billing/metering.py b/src/billing/metering.py new file mode 100644 index 00000000..8d03ea7a --- /dev/null +++ b/src/billing/metering.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +import math +from typing import Any, Mapping + +from src.billing.types import CreditEstimate +from src.utils import billing as billing_config + + +def _ingest_text(payload: Mapping[str, Any]) -> str: + return "\n".join( + str(payload.get(key) or "") + for key in ("user_query", "agent_response") + if payload.get(key) + ) + + +def content_tokens_for_job(job_type: str, payload: Mapping[str, Any]) -> int: + if job_type == "memory_batch_ingest": + return sum( + content_tokens_for_job("memory_ingest", item) + for item in list(payload.get("items") or []) + if isinstance(item, Mapping) + ) + if job_type == "memory_ingest": + return billing_config.estimate_tokens(_ingest_text(payload)) + if job_type == "memory_retrieve": + return billing_config.estimate_tokens(str(payload.get("query") or "")) + return billing_config.estimate_tokens(str(payload)) + + +def estimate_required_credits( + job_type: str, + payload: Mapping[str, Any], + *, + include_reservation_buffer: bool = True, +) -> CreditEstimate: + tokens = content_tokens_for_job(job_type, payload) + multiplier = billing_config.workflow_multiplier(job_type, payload) + billable = max(1, math.ceil(tokens * multiplier)) + buffer = billing_config.RESERVATION_BUFFER_MULTIPLIER if include_reservation_buffer else 1.0 + reserved = max(billable, math.ceil(billable * buffer)) + return CreditEstimate( + job_type=job_type, + content_tokens=tokens, + multiplier=multiplier, + billable_credits=billable, + reserved_credits=reserved, + ) diff --git a/src/billing/razorpay.py b/src/billing/razorpay.py new file mode 100644 index 00000000..aa5013bd --- /dev/null +++ b/src/billing/razorpay.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +import hashlib +import hmac +import logging +from typing import Any, Optional + +import httpx + +from src.config import settings + +logger = logging.getLogger("xmem.billing.razorpay") + +RAZORPAY_API = "https://api.razorpay.com/v1" + + +class RazorpayConfigError(RuntimeError): + pass + + +def require_razorpay_keys() -> tuple[str, str]: + if not settings.razorpay_key_id or not settings.razorpay_key_secret: + raise RazorpayConfigError( + "Razorpay is not configured. Set RAZORPAY_KEY_ID and RAZORPAY_KEY_SECRET." + ) + return settings.razorpay_key_id, settings.razorpay_key_secret + + +def verify_order_signature(order_id: str, payment_id: str, signature: str) -> bool: + _, secret = require_razorpay_keys() + payload = f"{order_id}|{payment_id}".encode("utf-8") + expected = hmac.new(secret.encode("utf-8"), payload, hashlib.sha256).hexdigest() + return hmac.compare_digest(expected, signature) + + +def verify_subscription_signature( + subscription_id: str, + payment_id: str, + signature: str, +) -> bool: + _, secret = require_razorpay_keys() + payload = f"{payment_id}|{subscription_id}".encode("utf-8") + expected = hmac.new(secret.encode("utf-8"), payload, hashlib.sha256).hexdigest() + return hmac.compare_digest(expected, signature) + + +def verify_webhook_signature(body: bytes, signature: str) -> bool: + if not settings.razorpay_webhook_secret: + raise RazorpayConfigError("RAZORPAY_WEBHOOK_SECRET is not configured.") + expected = hmac.new( + settings.razorpay_webhook_secret.encode("utf-8"), + body, + hashlib.sha256, + ).hexdigest() + return hmac.compare_digest(expected, signature) + + +async def create_order( + *, + amount_paise: int, + currency: str, + receipt: str, + notes: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + key_id, key_secret = require_razorpay_keys() + payload = { + "amount": amount_paise, + "currency": currency, + "receipt": receipt, + "notes": notes or {}, + } + async with httpx.AsyncClient(timeout=20) as client: + response = await client.post( + f"{RAZORPAY_API}/orders", + auth=(key_id, key_secret), + json=payload, + ) + if response.status_code >= 400: + logger.warning("Razorpay order creation failed: %s %s", response.status_code, response.text[:500]) + response.raise_for_status() + return response.json() + + +async def create_subscription( + *, + plan_id: str, + total_count: int = 120, + notes: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + key_id, key_secret = require_razorpay_keys() + payload = { + "plan_id": plan_id, + "total_count": total_count, + "quantity": 1, + "customer_notify": 1, + "notes": notes or {}, + } + async with httpx.AsyncClient(timeout=20) as client: + response = await client.post( + f"{RAZORPAY_API}/subscriptions", + auth=(key_id, key_secret), + json=payload, + ) + if response.status_code >= 400: + logger.warning("Razorpay subscription creation failed: %s %s", response.status_code, response.text[:500]) + response.raise_for_status() + return response.json() diff --git a/src/billing/service.py b/src/billing/service.py new file mode 100644 index 00000000..3618d68b --- /dev/null +++ b/src/billing/service.py @@ -0,0 +1,343 @@ +from __future__ import annotations + +import logging +from datetime import timedelta +from typing import Any, Mapping, Optional + +from src.billing.metering import estimate_required_credits as _estimate_required_credits +from src.billing.store import BillingStore, get_default_billing_store, utc_now +from src.billing.types import BillingSummary, CreditEstimate, CreditLotPublic, PlanPublic, ReservationResult, TopUpPackPublic +from src.utils import billing as billing_config + +logger = logging.getLogger("xmem.billing.service") + + +def _user_id(user: Mapping[str, Any]) -> str: + user_id = user.get("id") or user.get("_id") or user.get("sub") + if not user_id: + raise ValueError("Authenticated user is missing an id") + return str(user_id) + + +def public_plans() -> list[PlanPublic]: + return [ + PlanPublic( + id=plan_id, + name=str(plan["name"]), + price_paise=int(plan.get("price_paise") or 0), + currency=str(plan.get("currency") or "INR"), + monthly_credits=int(plan.get("monthly_credits") or 0), + trial_credits=int(plan.get("trial_credits") or 0), + trial_days=int(plan.get("trial_days") or 0), + nominal_paise_per_credit=billing_config.nominal_paise_per_credit(plan_id), + ) + for plan_id, plan in billing_config.PLANS.items() + ] + + +def public_topups() -> list[TopUpPackPublic]: + return [ + TopUpPackPublic( + id=pack_id, + price_paise=int(pack["price_paise"]), + currency=str(pack.get("currency") or "INR"), + credits=int(pack["credits"]), + ) + for pack_id, pack in billing_config.TOP_UP_PACKS.items() + ] + + +class BillingService: + def __init__(self, store: Optional[BillingStore] = None) -> None: + self.store = store or get_default_billing_store() + + def ensure_billing_account(self, user: Mapping[str, Any]) -> dict[str, Any]: + owner_id = _user_id(user) + account = self.store.ensure_account(owner_id=owner_id) + free_plan = billing_config.PLANS["free"] + trial_credits = int(free_plan.get("trial_credits") or 0) + trial_days = int(free_plan.get("trial_days") or 30) + if trial_credits > 0: + self.store.grant_credits( + account_id=account["id"], + amount=trial_credits, + source="free_trial", + expires_at=utc_now() + timedelta(days=trial_days), + idempotency_key=f"free_trial:{owner_id}", + metadata={"plan_id": "free", "trial_days": trial_days}, + ) + return account + + def estimate_required_credits( + self, + job_type: str, + payload: Mapping[str, Any], + *, + include_reservation_buffer: bool = True, + ) -> CreditEstimate: + return _estimate_required_credits( + job_type, + payload, + include_reservation_buffer=include_reservation_buffer, + ) + + def reserve_credits( + self, + account_id: str, + job_id: str, + estimated_credits: int, + *, + metadata: Optional[dict[str, Any]] = None, + ) -> ReservationResult: + reservation = self.store.reserve_credits( + account_id=account_id, + job_id=job_id, + amount=estimated_credits, + metadata=metadata, + ) + wallet = self.store.get_wallet(account_id) + return ReservationResult( + reservation_id=reservation["id"], + billing_account_id=account_id, + job_id=job_id, + reserved_credits=int(reservation.get("reserved_credits") or 0), + status=reservation.get("status", "active"), + available_credits=int(wallet.get("available_credits") or 0), + created=bool(reservation.get("created")), + ) + + def reserve_job_credits( + self, + *, + user: Mapping[str, Any], + job_id: str, + job_type: str, + payload: Mapping[str, Any], + ) -> tuple[dict[str, Any], CreditEstimate, ReservationResult]: + account = self.ensure_billing_account(user) + estimate = self.estimate_required_credits(job_type, payload) + reservation = self.reserve_credits( + account["id"], + job_id, + estimate.reserved_credits, + metadata={ + "job_type": job_type, + "billable_credits": estimate.billable_credits, + "content_tokens": estimate.content_tokens, + }, + ) + return account, estimate, reservation + + def commit_job_debit( + self, + account_id: str, + job_id: str, + final_credits: int, + *, + metadata: Optional[dict[str, Any]] = None, + ) -> dict[str, Any]: + return self.store.commit_debit( + account_id=account_id, + job_id=job_id, + final_amount=final_credits, + metadata=metadata, + ) + + def release_job_reservation( + self, + account_id: str, + job_id: str, + *, + metadata: Optional[dict[str, Any]] = None, + ) -> Optional[dict[str, Any]]: + return self.store.release_reservation( + account_id=account_id, + job_id=job_id, + metadata=metadata, + ) + + def commit_job_billing(self, job: Mapping[str, Any], result: Mapping[str, Any]) -> dict[str, Any]: + payload = job.get("payload") if isinstance(job.get("payload"), Mapping) else {} + account_id = payload.get("billing_account_id") + if not account_id: + return dict(result) + job_type = str(job.get("job_type") or payload.get("job_type") or "memory_ingest") + estimate = self.estimate_required_credits( + job_type, + payload, + include_reservation_buffer=False, + ) + self.commit_job_debit( + str(account_id), + str(job["job_id"]), + estimate.billable_credits, + metadata={ + "job_type": job_type, + "content_tokens": estimate.content_tokens, + "multiplier": estimate.multiplier, + }, + ) + enriched = dict(result) + enriched["billing"] = { + "billing_account_id": account_id, + "billable_credits": estimate.billable_credits, + "content_tokens": estimate.content_tokens, + "multiplier": estimate.multiplier, + } + return enriched + + def release_job_billing(self, job: Mapping[str, Any], reason: str = "job_not_completed") -> None: + payload = job.get("payload") if isinstance(job.get("payload"), Mapping) else {} + account_id = payload.get("billing_account_id") + if not account_id: + return + self.release_job_reservation( + str(account_id), + str(job["job_id"]), + metadata={"reason": reason}, + ) + + def grant_pro_subscription( + self, + *, + user_id: str, + payment_id: str, + subscription_id: str, + period_end=None, + ) -> dict[str, Any]: + account = self.store.ensure_account(owner_id=user_id) + plan = billing_config.PLANS["pro"] + expires_at = period_end or (utc_now() + timedelta(days=30)) + self.store.update_account( + account["id"], + { + "plan_id": "pro", + "status": "active", + "razorpay_subscription_id": subscription_id, + "current_period_end": expires_at, + }, + ) + return self.store.grant_credits( + account_id=account["id"], + amount=int(plan["monthly_credits"]), + source="pro_monthly", + expires_at=expires_at, + idempotency_key=f"pro_grant:{subscription_id}:{payment_id}", + metadata={"payment_id": payment_id, "subscription_id": subscription_id}, + ) + + def grant_topup( + self, + *, + user_id: str, + pack_id: str, + payment_id: str, + order_id: str, + ) -> dict[str, Any]: + pack = billing_config.TOP_UP_PACKS[pack_id] + account = self.store.ensure_account(owner_id=user_id) + return self.store.grant_credits( + account_id=account["id"], + amount=int(pack["credits"]), + source=pack_id, + expires_at=utc_now() + timedelta(days=billing_config.TOP_UP_EXPIRY_DAYS), + idempotency_key=f"topup_grant:{order_id}:{payment_id}", + metadata={"payment_id": payment_id, "order_id": order_id, "pack_id": pack_id}, + ) + + def get_billing_summary(self, user: Mapping[str, Any]) -> BillingSummary: + account = self.ensure_billing_account(user) + wallet = self.store.get_wallet(account["id"]) + plan_id = str(account.get("plan_id") or "free") + plan = billing_config.PLANS.get(plan_id, billing_config.PLANS["free"]) + lots = [ + CreditLotPublic( + id=str(lot["id"]), + source=str(lot.get("source") or ""), + remaining_credits=int(lot.get("remaining_credits") or 0), + expires_at=lot.get("expires_at"), + ) + for lot in self.store.active_lots(account["id"]) + ] + return BillingSummary( + billing_account_id=account["id"], + owner_type=str(account.get("owner_type") or "user"), + owner_id=str(account.get("owner_id")), + plan_id=plan_id, + plan_name=str(plan.get("name") or plan_id), + status=str(account.get("status") or "trialing"), + currency=str(plan.get("currency") or "INR"), + available_credits=int(wallet.get("available_credits") or 0), + reserved_credits=int(wallet.get("reserved_credits") or 0), + current_period_start=account.get("current_period_start"), + current_period_end=account.get("current_period_end"), + credit_lots=lots, + ) + + def list_ledger(self, user: Mapping[str, Any], limit: int = 100) -> list[dict[str, Any]]: + account = self.ensure_billing_account(user) + return self.store.list_ledger(account["id"], limit=limit) + + def record_usage_event(self, **event: Any) -> None: + self.store.record_usage_event(event) + + +_default_service: Optional[BillingService] = None + + +def get_default_billing_service() -> BillingService: + global _default_service + if _default_service is None: + _default_service = BillingService() + return _default_service + + +def ensure_billing_account(user: Mapping[str, Any]) -> dict[str, Any]: + return get_default_billing_service().ensure_billing_account(user) + + +def estimate_required_credits(job_type: str, payload: Mapping[str, Any]) -> CreditEstimate: + return get_default_billing_service().estimate_required_credits(job_type, payload) + + +def reserve_credits(account_id: str, job_id: str, estimated_credits: int) -> ReservationResult: + return get_default_billing_service().reserve_credits(account_id, job_id, estimated_credits) + + +def reserve_job_credits( + *, + user: Mapping[str, Any], + job_id: str, + job_type: str, + payload: Mapping[str, Any], +) -> tuple[dict[str, Any], CreditEstimate, ReservationResult]: + return get_default_billing_service().reserve_job_credits( + user=user, + job_id=job_id, + job_type=job_type, + payload=payload, + ) + + +def commit_job_debit(account_id: str, job_id: str, final_credits: int) -> dict[str, Any]: + return get_default_billing_service().commit_job_debit(account_id, job_id, final_credits) + + +def release_job_reservation(account_id: str, job_id: str) -> Optional[dict[str, Any]]: + return get_default_billing_service().release_job_reservation(account_id, job_id) + + +def commit_job_billing(job: Mapping[str, Any], result: Mapping[str, Any]) -> dict[str, Any]: + return get_default_billing_service().commit_job_billing(job, result) + + +def release_job_billing(job: Mapping[str, Any], reason: str = "job_not_completed") -> None: + get_default_billing_service().release_job_billing(job, reason) + + +def get_billing_summary(user: Mapping[str, Any]) -> BillingSummary: + return get_default_billing_service().get_billing_summary(user) + + +def record_usage_event(**event: Any) -> None: + get_default_billing_service().record_usage_event(**event) diff --git a/src/billing/store.py b/src/billing/store.py new file mode 100644 index 00000000..a4d62e11 --- /dev/null +++ b/src/billing/store.py @@ -0,0 +1,948 @@ +from __future__ import annotations + +import logging +import uuid +from datetime import datetime, timezone +from typing import Any, Iterable, Optional + +from src.config import settings + +try: + from pymongo.errors import DuplicateKeyError +except Exception: # pragma: no cover - pymongo may be absent in memory-only dev/test. + DuplicateKeyError = None # type: ignore[assignment] + +logger = logging.getLogger("xmem.billing.store") + +_memory_accounts: dict[str, dict[str, Any]] = {} +_memory_wallets: dict[str, dict[str, Any]] = {} +_memory_lots: dict[str, dict[str, Any]] = {} +_memory_ledger: dict[str, dict[str, Any]] = {} +_memory_reservations: dict[str, dict[str, Any]] = {} +_memory_usage_events: list[dict[str, Any]] = [] +_memory_checkouts: dict[str, dict[str, Any]] = {} +_memory_payment_events: dict[str, dict[str, Any]] = {} + + +class BillingStoreError(RuntimeError): + pass + + +class InsufficientCredits(BillingStoreError): + def __init__(self, required: int, available: int) -> None: + super().__init__( + f"Insufficient credits: required {required}, available {available}." + ) + self.required = required + self.available = available + + +def utc_now() -> datetime: + return datetime.now(timezone.utc) + + +def _without_id(doc: Optional[dict[str, Any]]) -> Optional[dict[str, Any]]: + if not doc: + return None + result = dict(doc) + result.pop("_id", None) + return result + + +def _is_expired(doc: dict[str, Any], now: Optional[datetime] = None) -> bool: + expires_at = doc.get("expires_at") + if not expires_at: + return False + now = now or utc_now() + if expires_at.tzinfo is None: + expires_at = expires_at.replace(tzinfo=timezone.utc) + return expires_at <= now + + +def _is_duplicate_key_error(exc: Exception) -> bool: + return DuplicateKeyError is not None and isinstance(exc, DuplicateKeyError) + + +class BillingStore: + """Mongo-backed credit ledger with in-memory fallback for local development.""" + + def __init__(self, uri: Optional[str] = None, database: Optional[str] = None) -> None: + self._uri = uri or settings.mongodb_uri + self._database = database or settings.mongodb_database + self._client = None + self._db = None + self._connected = False + self._in_memory = False + self._try_connect() + + def _requires_durable_storage(self) -> bool: + return settings.environment.lower() in {"production", "prod"} + + def _enable_memory_fallback(self, error: Exception) -> None: + message = f"Billing store connection failed: {error}" + if self._requires_durable_storage(): + logger.error("%s; refusing in-memory fallback in production", message) + raise RuntimeError( + "MongoDB is required for billing storage when ENVIRONMENT=production" + ) from error + logger.warning("%s; using in-memory billing storage", message) + self._connected = False + self._in_memory = True + + def _try_connect(self) -> None: + provider = (settings.app_store_provider or "mongo").strip().lower() + if provider == "memory": + self._in_memory = True + return + if provider == "postgres": + self._enable_memory_fallback( + RuntimeError("Postgres billing storage is not implemented") + ) + return + try: + from pymongo import ASCENDING, MongoClient + + self._client = MongoClient(self._uri, serverSelectionTimeoutMS=5000) + self._client.admin.command("ping") + self._db = self._client[self._database] + self.accounts = self._db["billing_accounts"] + self.wallets = self._db["credit_wallets"] + self.lots = self._db["credit_lots"] + self.ledger = self._db["credit_ledger"] + self.reservations = self._db["credit_reservations"] + self.usage_events = self._db["usage_events"] + self.payments = self._db["billing_payments"] + + self.accounts.create_index([("id", ASCENDING)], unique=True) + self.accounts.create_index([("owner_type", ASCENDING), ("owner_id", ASCENDING)], unique=True) + self.accounts.create_index([("razorpay_subscription_id", ASCENDING)]) + self.wallets.create_index([("billing_account_id", ASCENDING)], unique=True) + self.lots.create_index([("id", ASCENDING)], unique=True) + self.lots.create_index([("billing_account_id", ASCENDING), ("expires_at", ASCENDING)]) + self.ledger.create_index([("id", ASCENDING)], unique=True) + self.ledger.create_index([("idempotency_key", ASCENDING)], unique=True) + self.ledger.create_index([("billing_account_id", ASCENDING), ("created_at", ASCENDING)]) + self.reservations.create_index([("id", ASCENDING)], unique=True) + self.reservations.create_index([("job_id", ASCENDING)], unique=True) + self.payments.create_index([("id", ASCENDING)], unique=True, sparse=True) + self.payments.create_index([("razorpay_event_id", ASCENDING)], unique=True, sparse=True) + self.payments.create_index([("razorpay_payment_id", ASCENDING)], unique=True, sparse=True) + + self._connected = True + self._in_memory = False + except Exception as exc: + self._enable_memory_fallback(exc) + + def ensure_account( + self, + *, + owner_id: str, + owner_type: str = "user", + plan_id: str = "free", + status: str = "trialing", + ) -> dict[str, Any]: + now = utc_now() + if self._in_memory: + key = f"{owner_type}:{owner_id}" + account = _memory_accounts.get(key) + if account: + return dict(account) + account = { + "id": uuid.uuid4().hex, + "owner_type": owner_type, + "owner_id": owner_id, + "plan_id": plan_id, + "status": status, + "created_at": now, + "updated_at": now, + } + _memory_accounts[key] = account + _memory_wallets[account["id"]] = { + "billing_account_id": account["id"], + "available_credits": 0, + "reserved_credits": 0, + "updated_at": now, + } + return dict(account) + + from pymongo import ReturnDocument + + doc = self.accounts.find_one_and_update( + {"owner_type": owner_type, "owner_id": owner_id}, + { + "$setOnInsert": { + "id": uuid.uuid4().hex, + "owner_type": owner_type, + "owner_id": owner_id, + "plan_id": plan_id, + "status": status, + "created_at": now, + }, + "$set": {"updated_at": now}, + }, + upsert=True, + return_document=ReturnDocument.AFTER, + ) + account = _without_id(doc) or {} + self.wallets.update_one( + {"billing_account_id": account["id"]}, + { + "$setOnInsert": { + "billing_account_id": account["id"], + "available_credits": 0, + "reserved_credits": 0, + }, + "$set": {"updated_at": now}, + }, + upsert=True, + ) + return account + + def get_account(self, account_id: str) -> Optional[dict[str, Any]]: + if self._in_memory: + for account in _memory_accounts.values(): + if account["id"] == account_id: + return dict(account) + return None + return _without_id(self.accounts.find_one({"id": account_id})) + + def update_account(self, account_id: str, updates: dict[str, Any]) -> None: + updates = {**updates, "updated_at": utc_now()} + if self._in_memory: + for account in _memory_accounts.values(): + if account["id"] == account_id: + account.update(updates) + return + return + self.accounts.update_one({"id": account_id}, {"$set": updates}) + + def get_wallet(self, account_id: str) -> dict[str, Any]: + if self._in_memory: + return dict( + _memory_wallets.setdefault( + account_id, + { + "billing_account_id": account_id, + "available_credits": 0, + "reserved_credits": 0, + "updated_at": utc_now(), + }, + ) + ) + wallet = self.wallets.find_one({"billing_account_id": account_id}) + if not wallet: + self.wallets.update_one( + {"billing_account_id": account_id}, + { + "$setOnInsert": { + "billing_account_id": account_id, + "available_credits": 0, + "reserved_credits": 0, + }, + "$set": {"updated_at": utc_now()}, + }, + upsert=True, + ) + wallet = self.wallets.find_one({"billing_account_id": account_id}) + return _without_id(wallet) or {} + + def _insert_ledger(self, entry: dict[str, Any]) -> Optional[dict[str, Any]]: + if self._in_memory: + key = entry["idempotency_key"] + if key in _memory_ledger: + return dict(_memory_ledger[key]) + _memory_ledger[key] = dict(entry) + return None + try: + self.ledger.insert_one(entry) + return None + except Exception as exc: + if not _is_duplicate_key_error(exc): + raise + return _without_id(self.ledger.find_one({"idempotency_key": entry["idempotency_key"]})) + + def grant_credits( + self, + *, + account_id: str, + amount: int, + source: str, + expires_at: Optional[datetime], + idempotency_key: str, + metadata: Optional[dict[str, Any]] = None, + ) -> dict[str, Any]: + if amount <= 0: + raise ValueError("Credit grant amount must be positive") + now = utc_now() + ledger_entry = { + "id": uuid.uuid4().hex, + "billing_account_id": account_id, + "type": "grant", + "amount": amount, + "source": source, + "expires_at": expires_at, + "idempotency_key": idempotency_key, + "metadata": metadata or {}, + "created_at": now, + } + lot = { + "id": uuid.uuid4().hex, + "billing_account_id": account_id, + "source": source, + "remaining_credits": amount, + "ledger_id": ledger_entry["id"], + "expires_at": expires_at, + "created_at": now, + "updated_at": now, + } + if self._in_memory: + duplicate = self.find_ledger_by_key(idempotency_key) + if duplicate: + return duplicate + _memory_lots[lot["id"]] = lot + wallet = _memory_wallets.setdefault( + account_id, + {"billing_account_id": account_id, "available_credits": 0, "reserved_credits": 0}, + ) + wallet["available_credits"] += amount + wallet["updated_at"] = now + self._insert_ledger(ledger_entry) + return dict(ledger_entry) + + try: + with self._client.start_session() as session: + with session.start_transaction(): + existing = self.ledger.find_one( + {"idempotency_key": idempotency_key}, + session=session, + ) + if existing: + return _without_id(existing) or {} + self.lots.insert_one(lot, session=session) + self.wallets.update_one( + {"billing_account_id": account_id}, + {"$inc": {"available_credits": amount}, "$set": {"updated_at": now}}, + upsert=True, + session=session, + ) + self.ledger.insert_one(ledger_entry, session=session) + except Exception as exc: + if _is_duplicate_key_error(exc): + existing = self.find_ledger_by_key(idempotency_key) + if existing: + return existing + raise + return ledger_entry + + def reserve_credits( + self, + *, + account_id: str, + job_id: str, + amount: int, + metadata: Optional[dict[str, Any]] = None, + ) -> dict[str, Any]: + if amount <= 0: + raise ValueError("Reservation amount must be positive") + existing = self.get_reservation(job_id) + if existing and existing.get("status") in {"active", "committed"}: + if existing.get("billing_account_id") != account_id: + raise BillingStoreError( + f"Reservation for job {job_id} belongs to a different billing account" + ) + return {**existing, "created": False} + + now = utc_now() + if self._in_memory: + if existing and existing.get("billing_account_id") != account_id: + raise BillingStoreError( + f"Reservation for job {job_id} belongs to a different billing account" + ) + wallet = self.get_wallet(account_id) + if int(wallet.get("available_credits") or 0) < amount: + raise InsufficientCredits(amount, int(wallet.get("available_credits") or 0)) + _memory_wallets[account_id]["available_credits"] -= amount + _memory_wallets[account_id]["reserved_credits"] += amount + version = int((existing or {}).get("version") or 0) + 1 + reservation = { + "id": (existing or {}).get("id") or uuid.uuid4().hex, + "billing_account_id": account_id, + "job_id": job_id, + "reserved_credits": amount, + "status": "active", + "version": version, + "metadata": metadata or {}, + "created_at": (existing or {}).get("created_at") or now, + "updated_at": now, + } + _memory_reservations[job_id] = reservation + self._insert_ledger( + { + "id": uuid.uuid4().hex, + "billing_account_id": account_id, + "type": "reserve", + "amount": amount, + "job_id": job_id, + "idempotency_key": f"reserve:{job_id}:{version}", + "metadata": metadata or {}, + "created_at": now, + } + ) + return {**reservation, "created": True} + + from pymongo import ReturnDocument + + try: + with self._client.start_session() as session: + with session.start_transaction(): + current = self.reservations.find_one({"job_id": job_id}, session=session) + if current: + current = _without_id(current) or {} + if current.get("billing_account_id") != account_id: + raise BillingStoreError( + f"Reservation for job {job_id} belongs to a different billing account" + ) + if current.get("status") in {"active", "committed"}: + return {**current, "created": False} + if current.get("status") == "reserving": + raise BillingStoreError( + f"Reservation for job {job_id} is already being created" + ) + reserved_doc = self.reservations.find_one_and_update( + { + "job_id": job_id, + "billing_account_id": account_id, + "status": {"$nin": ["active", "committed", "reserving"]}, + }, + { + "$set": { + "status": "reserving", + "reserved_credits": amount, + "metadata": metadata or {}, + "updated_at": now, + }, + "$inc": {"version": 1}, + }, + return_document=ReturnDocument.AFTER, + session=session, + ) + if not reserved_doc: + raise BillingStoreError(f"Reservation for job {job_id} is already active") + reservation = _without_id(reserved_doc) or {} + version = int(reservation.get("version") or 1) + else: + reservation = { + "id": uuid.uuid4().hex, + "billing_account_id": account_id, + "job_id": job_id, + "reserved_credits": amount, + "status": "reserving", + "version": 1, + "metadata": metadata or {}, + "created_at": now, + "updated_at": now, + } + self.reservations.insert_one(reservation, session=session) + version = 1 + + wallet = self.wallets.find_one_and_update( + {"billing_account_id": account_id, "available_credits": {"$gte": amount}}, + { + "$inc": {"available_credits": -amount, "reserved_credits": amount}, + "$set": {"updated_at": now}, + }, + return_document=True, + session=session, + ) + if not wallet: + current_wallet = self.wallets.find_one( + {"billing_account_id": account_id}, + session=session, + ) or {} + raise InsufficientCredits( + amount, + int(current_wallet.get("available_credits") or 0), + ) + + self.reservations.update_one( + {"job_id": job_id, "billing_account_id": account_id, "status": "reserving"}, + {"$set": {"status": "active", "updated_at": now}}, + session=session, + ) + self.ledger.insert_one( + { + "id": uuid.uuid4().hex, + "billing_account_id": account_id, + "type": "reserve", + "amount": amount, + "job_id": job_id, + "idempotency_key": f"reserve:{job_id}:{version}", + "metadata": metadata or {}, + "created_at": now, + }, + session=session, + ) + except Exception as exc: + if _is_duplicate_key_error(exc): + current = self.get_reservation(job_id) + if ( + current + and current.get("billing_account_id") == account_id + and current.get("status") in {"active", "committed"} + ): + return {**current, "created": False} + raise + return {**reservation, "status": "active", "updated_at": now, "created": True} + + def get_reservation(self, job_id: str) -> Optional[dict[str, Any]]: + if self._in_memory: + reservation = _memory_reservations.get(job_id) + return dict(reservation) if reservation else None + return _without_id(self.reservations.find_one({"job_id": job_id})) + + def commit_debit( + self, + *, + account_id: str, + job_id: str, + final_amount: int, + metadata: Optional[dict[str, Any]] = None, + ) -> dict[str, Any]: + if final_amount <= 0: + raise ValueError("Debit amount must be positive") + duplicate = self.find_ledger_by_key(f"debit:{job_id}") + if duplicate: + return duplicate + now = utc_now() + + if self._in_memory: + reservation = self._claim_reservation_for_commit(account_id, job_id, final_amount, now) + if reservation.get("type") == "debit": + return reservation + reserved = int(reservation.get("reserved_credits") or 0) + extra = max(final_amount - reserved, 0) + refund = max(reserved - final_amount, 0) + entry = { + "id": uuid.uuid4().hex, + "billing_account_id": account_id, + "type": "debit", + "amount": -final_amount, + "job_id": job_id, + "idempotency_key": f"debit:{job_id}", + "metadata": metadata or {}, + "created_at": now, + } + refund_entry = None + if refund: + refund_entry = { + "id": uuid.uuid4().hex, + "billing_account_id": account_id, + "type": "refund", + "amount": refund, + "job_id": job_id, + "idempotency_key": f"refund:{job_id}", + "metadata": {"reason": "unused_reservation"}, + "created_at": now, + } + try: + if extra: + wallet = self.get_wallet(account_id) + if int(wallet.get("available_credits") or 0) < extra: + raise InsufficientCredits(extra, int(wallet.get("available_credits") or 0)) + _memory_wallets[account_id]["available_credits"] -= extra + self._consume_lots(account_id, final_amount) + _memory_wallets[account_id]["reserved_credits"] -= reserved + _memory_wallets[account_id]["available_credits"] += refund + _memory_wallets[account_id]["updated_at"] = now + _memory_reservations[job_id]["status"] = "committed" + _memory_reservations[job_id]["final_credits"] = final_amount + _memory_reservations[job_id]["updated_at"] = now + self._insert_ledger(entry) + if refund_entry: + self._insert_ledger(refund_entry) + return entry + except Exception: + self._release_commit_claim(account_id, job_id) + raise + + from pymongo import ReturnDocument + + entry = None + with self._client.start_session() as session: + with session.start_transaction(): + existing = self.ledger.find_one( + {"idempotency_key": f"debit:{job_id}"}, + session=session, + ) + if existing: + return _without_id(existing) or {} + + reservation_doc = self.reservations.find_one_and_update( + {"job_id": job_id, "billing_account_id": account_id, "status": "active"}, + { + "$set": { + "status": "committing", + "final_credits": final_amount, + "updated_at": now, + } + }, + return_document=ReturnDocument.BEFORE, + session=session, + ) + if not reservation_doc: + current = self.reservations.find_one({"job_id": job_id}, session=session) + current = _without_id(current) + if current and current.get("billing_account_id") != account_id: + raise BillingStoreError( + f"Reservation for job {job_id} belongs to a different billing account" + ) + raise BillingStoreError(f"Reservation for job {job_id} is not active") + + reservation = _without_id(reservation_doc) or {} + reserved = int(reservation.get("reserved_credits") or 0) + extra = max(final_amount - reserved, 0) + refund = max(reserved - final_amount, 0) + entry = { + "id": uuid.uuid4().hex, + "billing_account_id": account_id, + "type": "debit", + "amount": -final_amount, + "job_id": job_id, + "idempotency_key": f"debit:{job_id}", + "metadata": metadata or {}, + "created_at": now, + } + refund_entry = None + if refund: + refund_entry = { + "id": uuid.uuid4().hex, + "billing_account_id": account_id, + "type": "refund", + "amount": refund, + "job_id": job_id, + "idempotency_key": f"refund:{job_id}", + "metadata": {"reason": "unused_reservation"}, + "created_at": now, + } + + if extra: + wallet = self.wallets.find_one_and_update( + {"billing_account_id": account_id, "available_credits": {"$gte": extra}}, + {"$inc": {"available_credits": -extra}, "$set": {"updated_at": now}}, + return_document=ReturnDocument.AFTER, + session=session, + ) + if not wallet: + current = self.wallets.find_one( + {"billing_account_id": account_id}, + session=session, + ) or {} + raise InsufficientCredits(extra, int(current.get("available_credits") or 0)) + + self._consume_lots(account_id, final_amount, session=session) + self.wallets.update_one( + {"billing_account_id": account_id}, + { + "$inc": {"reserved_credits": -reserved, "available_credits": refund}, + "$set": {"updated_at": now}, + }, + session=session, + ) + self.reservations.update_one( + {"job_id": job_id, "billing_account_id": account_id, "status": "committing"}, + { + "$set": { + "status": "committed", + "final_credits": final_amount, + "updated_at": now, + } + }, + session=session, + ) + self.ledger.insert_one(entry, session=session) + if refund_entry: + self.ledger.insert_one(refund_entry, session=session) + + return entry + + def _claim_reservation_for_commit( + self, + account_id: str, + job_id: str, + final_amount: int, + now: datetime, + ) -> dict[str, Any]: + if self._in_memory: + reservation = self.get_reservation(job_id) + if not reservation: + raise BillingStoreError(f"No credit reservation exists for job {job_id}") + if reservation.get("billing_account_id") != account_id: + raise BillingStoreError( + f"Reservation for job {job_id} belongs to a different billing account" + ) + if reservation.get("status") == "committed": + existing = self.find_ledger_by_key(f"debit:{job_id}") + if existing: + return existing + raise BillingStoreError(f"Reservation for job {job_id} is already committed") + if reservation.get("status") != "active": + raise BillingStoreError(f"Reservation for job {job_id} is not active") + _memory_reservations[job_id]["status"] = "committing" + _memory_reservations[job_id]["final_credits"] = final_amount + _memory_reservations[job_id]["updated_at"] = now + return reservation + + from pymongo import ReturnDocument + + reservation = self.reservations.find_one_and_update( + {"job_id": job_id, "billing_account_id": account_id, "status": "active"}, + { + "$set": { + "status": "committing", + "final_credits": final_amount, + "updated_at": now, + } + }, + return_document=ReturnDocument.BEFORE, + ) + if reservation: + return _without_id(reservation) or {} + + existing = self.find_ledger_by_key(f"debit:{job_id}") + if existing: + return existing + current = self.get_reservation(job_id) + if current and current.get("billing_account_id") != account_id: + raise BillingStoreError( + f"Reservation for job {job_id} belongs to a different billing account" + ) + raise BillingStoreError(f"Reservation for job {job_id} is not active") + + def _release_commit_claim(self, account_id: str, job_id: str) -> None: + if self._in_memory: + reservation = _memory_reservations.get(job_id) + if reservation and reservation.get("billing_account_id") == account_id: + reservation["status"] = "active" + reservation.pop("final_credits", None) + reservation["updated_at"] = utc_now() + return + self.reservations.update_one( + {"job_id": job_id, "billing_account_id": account_id, "status": "committing"}, + { + "$set": {"status": "active", "updated_at": utc_now()}, + "$unset": {"final_credits": ""}, + }, + ) + + def release_reservation( + self, + *, + account_id: str, + job_id: str, + metadata: Optional[dict[str, Any]] = None, + ) -> Optional[dict[str, Any]]: + now = utc_now() + if self._in_memory: + reservation = _memory_reservations.get(job_id) + if not reservation or reservation.get("status") != "active": + return dict(reservation) if reservation else None + if reservation.get("billing_account_id") != account_id: + raise BillingStoreError( + f"Reservation for job {job_id} belongs to a different billing account" + ) + reservation["status"] = "released" + reservation["updated_at"] = now + amount = int(reservation.get("reserved_credits") or 0) + _memory_wallets[account_id]["available_credits"] += amount + _memory_wallets[account_id]["reserved_credits"] -= amount + _memory_wallets[account_id]["updated_at"] = now + else: + from pymongo import ReturnDocument + + with self._client.start_session() as session: + with session.start_transaction(): + reservation_doc = self.reservations.find_one_and_update( + {"job_id": job_id, "billing_account_id": account_id, "status": "active"}, + {"$set": {"status": "released", "updated_at": now}}, + return_document=ReturnDocument.BEFORE, + session=session, + ) + if not reservation_doc: + current = self.reservations.find_one({"job_id": job_id}, session=session) + current = _without_id(current) + if current and current.get("billing_account_id") != account_id: + raise BillingStoreError( + f"Reservation for job {job_id} belongs to a different billing account" + ) + return current + reservation = _without_id(reservation_doc) or {} + amount = int(reservation.get("reserved_credits") or 0) + self.wallets.update_one( + {"billing_account_id": account_id}, + { + "$inc": {"available_credits": amount, "reserved_credits": -amount}, + "$set": {"updated_at": now}, + }, + session=session, + ) + self.ledger.insert_one( + { + "id": uuid.uuid4().hex, + "billing_account_id": account_id, + "type": "release", + "amount": amount, + "job_id": job_id, + "idempotency_key": f"release:{job_id}:{reservation.get('version', 1)}", + "metadata": metadata or {}, + "created_at": now, + }, + session=session, + ) + return self.get_reservation(job_id) + self._insert_ledger( + { + "id": uuid.uuid4().hex, + "billing_account_id": account_id, + "type": "release", + "amount": amount, + "job_id": job_id, + "idempotency_key": f"release:{job_id}:{reservation.get('version', 1)}", + "metadata": metadata or {}, + "created_at": now, + } + ) + return self.get_reservation(job_id) + + def _consume_lots(self, account_id: str, amount: int, *, session: Any = None) -> None: + remaining = amount + now = utc_now() + while remaining > 0: + lots = list(self.active_lots(account_id, session=session)) + if not lots: + break + progressed = False + for lot in lots: + if remaining <= 0: + break + take = min(remaining, int(lot.get("remaining_credits") or 0)) + if take <= 0: + continue + if self._in_memory: + _memory_lots[lot["id"]]["remaining_credits"] -= take + _memory_lots[lot["id"]]["updated_at"] = now + remaining -= take + progressed = True + continue + result = self.lots.update_one( + {"id": lot["id"], "remaining_credits": {"$gte": take}}, + { + "$inc": {"remaining_credits": -take}, + "$set": {"updated_at": now}, + }, + session=session, + ) + if getattr(result, "modified_count", 0) == 1: + remaining -= take + progressed = True + if not progressed: + break + if remaining > 0: + raise BillingStoreError( + f"Wallet had credits but credit lots were short by {remaining}" + ) + + def active_lots(self, account_id: str, *, session: Any = None) -> Iterable[dict[str, Any]]: + now = utc_now() + if self._in_memory: + lots = [ + dict(lot) + for lot in _memory_lots.values() + if lot["billing_account_id"] == account_id + and int(lot.get("remaining_credits") or 0) > 0 + and not _is_expired(lot, now) + ] + return sorted(lots, key=lambda item: item.get("expires_at") or datetime.max.replace(tzinfo=timezone.utc)) + cursor = self.lots.find( + { + "billing_account_id": account_id, + "remaining_credits": {"$gt": 0}, + "$or": [{"expires_at": None}, {"expires_at": {"$gt": now}}], + }, + session=session, + ).sort("expires_at", 1) + return [_without_id(lot) or {} for lot in cursor] + + def find_ledger_by_key(self, idempotency_key: str) -> Optional[dict[str, Any]]: + if self._in_memory: + entry = _memory_ledger.get(idempotency_key) + return dict(entry) if entry else None + return _without_id(self.ledger.find_one({"idempotency_key": idempotency_key})) + + def has_payment_event(self, event_id: str) -> bool: + if not event_id: + return False + if self._in_memory: + return event_id in _memory_payment_events + return self.payments.find_one({"razorpay_event_id": event_id}) is not None + + def list_ledger(self, account_id: str, limit: int = 100) -> list[dict[str, Any]]: + if self._in_memory: + entries = [ + dict(entry) + for entry in _memory_ledger.values() + if entry.get("billing_account_id") == account_id + ] + return sorted(entries, key=lambda item: item["created_at"], reverse=True)[:limit] + return [ + _without_id(entry) or {} + for entry in self.ledger.find({"billing_account_id": account_id}) + .sort("created_at", -1) + .limit(limit) + ] + + def record_usage_event(self, event: dict[str, Any]) -> None: + payload = {"id": uuid.uuid4().hex, "created_at": utc_now(), **event} + if self._in_memory: + _memory_usage_events.append(payload) + return + self.usage_events.insert_one(payload) + + def save_checkout(self, checkout_id: str, payload: dict[str, Any]) -> None: + now = utc_now() + doc = {"id": checkout_id, **payload, "updated_at": now} + if self._in_memory: + _memory_checkouts[checkout_id] = doc + return + self.payments.update_one({"id": checkout_id}, {"$set": doc, "$setOnInsert": {"created_at": now}}, upsert=True) + + def get_checkout(self, checkout_id: str) -> Optional[dict[str, Any]]: + if self._in_memory: + return dict(_memory_checkouts[checkout_id]) if checkout_id in _memory_checkouts else None + return _without_id(self.payments.find_one({"id": checkout_id})) + + def mark_payment_event(self, event_id: str, payload: dict[str, Any]) -> bool: + if not event_id: + raise ValueError("Razorpay webhook event id is required") + now = utc_now() + if self._in_memory: + if event_id in _memory_payment_events: + return False + _memory_payment_events[event_id] = {"razorpay_event_id": event_id, **payload, "created_at": now} + return True + try: + self.payments.insert_one({"razorpay_event_id": event_id, **payload, "created_at": now}) + return True + except Exception as exc: + if _is_duplicate_key_error(exc): + return False + raise + + +_default_store: Optional[BillingStore] = None + + +def get_default_billing_store() -> BillingStore: + global _default_store + if _default_store is None: + _default_store = BillingStore() + return _default_store diff --git a/src/billing/types.py b/src/billing/types.py new file mode 100644 index 00000000..0eac7427 --- /dev/null +++ b/src/billing/types.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +from datetime import datetime +from typing import Any, Literal, Optional + +from pydantic import BaseModel, Field + + +class PlanPublic(BaseModel): + id: str + name: str + price_paise: int + currency: str = "INR" + monthly_credits: int = 0 + trial_credits: int = 0 + trial_days: int = 0 + nominal_paise_per_credit: float = 0.0 + + +class TopUpPackPublic(BaseModel): + id: str + price_paise: int + currency: str = "INR" + credits: int + + +class CreditLotPublic(BaseModel): + id: str + source: str + remaining_credits: int + expires_at: Optional[datetime] = None + + +class BillingSummary(BaseModel): + billing_account_id: str + owner_type: str = "user" + owner_id: str + plan_id: str + plan_name: str + status: str + currency: str = "INR" + available_credits: int = 0 + reserved_credits: int = 0 + current_period_start: Optional[datetime] = None + current_period_end: Optional[datetime] = None + credit_lots: list[CreditLotPublic] = Field(default_factory=list) + + +class CreditEstimate(BaseModel): + job_type: str + content_tokens: int + multiplier: float + billable_credits: int + reserved_credits: int + + +class ReservationResult(BaseModel): + reservation_id: str + billing_account_id: str + job_id: str + reserved_credits: int + status: Literal["active", "committed", "released", "expired"] + available_credits: int + created: bool = False + + +class CheckoutRequest(BaseModel): + package_id: str = Field(..., description="Plan ID or top-up pack ID") + + +class CheckoutResponse(BaseModel): + id: str + package_id: str + amount: int + currency: str + key_id: str + order_id: Optional[str] = None + subscription_id: Optional[str] = None + receipt: Optional[str] = None + + +class VerifyPaymentRequest(BaseModel): + package_id: str + razorpay_payment_id: str + razorpay_signature: str + razorpay_order_id: Optional[str] = None + razorpay_subscription_id: Optional[str] = None + + +class LedgerEntryPublic(BaseModel): + id: str + type: str + amount: int + idempotency_key: str + job_id: Optional[str] = None + source: Optional[str] = None + metadata: dict[str, Any] = Field(default_factory=dict) + created_at: datetime diff --git a/src/config/settings.py b/src/config/settings.py index 78630cee..d227259b 100644 --- a/src/config/settings.py +++ b/src/config/settings.py @@ -473,6 +473,10 @@ class Settings(BaseSettings): default=None, description="Optional Razorpay webhook signing secret" ) + razorpay_pro_plan_id: Optional[str] = Field( + default=None, + description="Razorpay subscription plan ID for the Pro plan", + ) @field_validator("fallback_order") @classmethod diff --git a/src/utils/billing.py b/src/utils/billing.py new file mode 100644 index 00000000..921c4739 --- /dev/null +++ b/src/utils/billing.py @@ -0,0 +1,72 @@ +"""Editable billing knobs for XMem. + +Change this file when plan pricing, monthly credits, top-up packs, or +workflow credit consumption rules need to change. The billing service imports +these values at runtime so the rest of the implementation can stay stable. +""" + +from __future__ import annotations + +import math +from typing import Any, Mapping + +PLANS: dict[str, dict[str, Any]] = { + "free": { + "name": "Free Trial", + "price_paise": 0, + "currency": "INR", + "trial_credits": 10_000, + "trial_days": 30, + "monthly_credits": 0, + }, + "pro": { + "name": "Pro", + "price_paise": 9_900, + "currency": "INR", + "monthly_credits": 5_000, + }, +} + +TOP_UP_PACKS: dict[str, dict[str, Any]] = { + "topup_99": {"price_paise": 9_900, "credits": 5_000, "currency": "INR"}, + "topup_199": {"price_paise": 19_900, "credits": 12_000, "currency": "INR"}, + "topup_499": {"price_paise": 49_900, "credits": 35_000, "currency": "INR"}, +} + +WORKFLOW_MULTIPLIERS: dict[str, float] = { + "memory_ingest_low": 1.0, + "memory_ingest_standard": 1.5, + "memory_ingest_high": 2.5, + "memory_batch_ingest": 1.5, + "memory_retrieve": 0.5, +} + +RESERVATION_BUFFER_MULTIPLIER = 1.25 +TOKEN_ESTIMATE_CHARS_PER_TOKEN = 4 +TOP_UP_EXPIRY_DAYS = 365 + + +def estimate_tokens(text: str) -> int: + """Approximate billable tokens when provider usage is not available.""" + if not text: + return 0 + return max(1, math.ceil(len(text) / TOKEN_ESTIMATE_CHARS_PER_TOKEN)) + + +def workflow_multiplier(job_type: str, payload: Mapping[str, Any]) -> float: + if job_type == "memory_ingest": + effort = str(payload.get("effort_level") or "low").strip().lower() + if effort == "high": + return WORKFLOW_MULTIPLIERS["memory_ingest_high"] + if effort in {"standard", "medium"}: + return WORKFLOW_MULTIPLIERS["memory_ingest_standard"] + return WORKFLOW_MULTIPLIERS["memory_ingest_low"] + return WORKFLOW_MULTIPLIERS.get(job_type, 1.0) + + +def nominal_paise_per_credit(plan_id: str) -> float: + plan = PLANS[plan_id] + credits = int(plan.get("monthly_credits") or plan.get("trial_credits") or 0) + if credits <= 0: + return 0.0 + return float(plan["price_paise"]) / credits diff --git a/tests/api/test_billing_routes.py b/tests/api/test_billing_routes.py new file mode 100644 index 00000000..39eb62b5 --- /dev/null +++ b/tests/api/test_billing_routes.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +import pytest + +pytest.importorskip("fastapi") + +from src.api.routes import billing +from src.billing.types import VerifyPaymentRequest + + +@pytest.mark.asyncio +async def test_subscription_verify_checks_signature_before_checkout_lookup(monkeypatch): + class Store: + def get_checkout(self, checkout_id): + raise AssertionError("checkout lookup should not run before signature verification") + + class Service: + store = Store() + + monkeypatch.setattr(billing, "get_default_billing_service", lambda: Service()) + monkeypatch.setattr(billing, "verify_subscription_signature", lambda *args: False) + + with pytest.raises(billing.HTTPException) as exc: + await billing.verify_razorpay_payment( + VerifyPaymentRequest( + package_id="pro", + razorpay_payment_id="pay_1", + razorpay_signature="bad", + razorpay_subscription_id="sub_1", + ), + current_user={"id": "user-1"}, + ) + + assert exc.value.status_code == 400 + assert exc.value.detail == "Invalid Razorpay signature" diff --git a/tests/api/test_memory_versioning.py b/tests/api/test_memory_versioning.py index 8c45bc7f..d7e1f900 100644 --- a/tests/api/test_memory_versioning.py +++ b/tests/api/test_memory_versioning.py @@ -7,7 +7,9 @@ from src.api import dependencies as deps from src.api.routes import memory +from src.api.routes.v2 import jobs as jobs_v2 from src.api.routes.v2 import memory as memory_v2 +from src.jobs import durable class FakeIngestPipeline: @@ -54,6 +56,23 @@ def mark_failed(self, job_id, error): job["error"] = error return "failed" + def mark_cancelled(self, job_id): + job = self.jobs[job_id] + job["status"] = "cancelled" + job["cancelled_at"] = "now" + job["completed_at"] = "now" + + def reset_for_retry(self, job_id, clear_workflow=False): + job = self.jobs[job_id] + job["status"] = "queued" + job["error"] = None + if clear_workflow: + job["workflow_id"] = None + job["run_id"] = None + + def update_payload(self, job_id, payload): + self.jobs[job_id]["payload"] = payload + def reserve_workflow_start(self, job_id, workflow_id): job = self.jobs[job_id] if job["status"] != "queued" or job.get("workflow_id"): @@ -86,6 +105,7 @@ async def fake_rate_limit(): app.include_router(memory.router) app.include_router(memory_v2.scrape_router) app.include_router(memory_v2.router) + app.include_router(jobs_v2.router) return app, ingest @@ -182,6 +202,249 @@ async def fake_start_job_workflow(job): assert ingest.calls == [] +def test_v2_retry_start_failure_releases_fresh_billing_reservation(monkeypatch): + app, _ = _build_app(monkeypatch) + store = FakeJobStore() + store.jobs["job-1"] = { + "job_id": "job-1", + "job_type": "memory_ingest", + "payload": {"billing_account_id": "acct-1", "user_id": "hunter"}, + "user_id": "hunter", + "status": "failed", + "timeout_seconds": 30, + "max_attempts": 3, + "retry_count": 1, + "attempt_count": 1, + "workflow_id": "old-workflow", + } + released = [] + + class FakeEstimate: + reserved_credits = 100 + + def model_dump(self): + return {"reserved_credits": self.reserved_credits} + + class FakeBillingService: + def estimate_required_credits(self, job_type, payload): + return FakeEstimate() + + def reserve_credits(self, account_id, job_id, estimated_credits): + return SimpleNamespace(reservation_id="reservation-1", created=True) + + async def fake_start_job_workflow(job): + raise RuntimeError("temporal unavailable") + + def fake_release(account_id, job_id): + released.append((account_id, job_id)) + + monkeypatch.setattr(jobs_v2, "get_default_job_store", lambda: store) + monkeypatch.setattr(durable, "get_default_job_store", lambda: store) + monkeypatch.setattr(jobs_v2, "get_default_billing_service", lambda: FakeBillingService()) + monkeypatch.setattr(jobs_v2, "release_job_reservation", fake_release) + monkeypatch.setattr(jobs_v2, "start_job_workflow", fake_start_job_workflow) + + response = TestClient(app).post("/v2/jobs/job-1/retry") + + assert response.status_code == 503 + assert released == [("acct-1", "job-1")] + assert store.jobs["job-1"]["status"] == "failed" + assert store.jobs["job-1"]["error"] == "temporal unavailable" + + +def test_v2_retry_payload_update_failure_releases_fresh_billing_reservation(monkeypatch): + app, _ = _build_app(monkeypatch) + store = FakeJobStore() + store.jobs["job-1"] = { + "job_id": "job-1", + "job_type": "memory_ingest", + "payload": {"billing_account_id": "acct-1", "user_id": "hunter"}, + "user_id": "hunter", + "status": "failed", + "timeout_seconds": 30, + "max_attempts": 3, + "retry_count": 1, + "attempt_count": 1, + "workflow_id": "old-workflow", + } + released = [] + + class FakeEstimate: + reserved_credits = 100 + + def model_dump(self): + return {"reserved_credits": self.reserved_credits} + + class FakeBillingService: + def estimate_required_credits(self, job_type, payload): + return FakeEstimate() + + def reserve_credits(self, account_id, job_id, estimated_credits): + return SimpleNamespace(reservation_id="reservation-1", created=True) + + def fail_update_payload(job_id, payload): + raise RuntimeError("payload write failed") + + monkeypatch.setattr(jobs_v2, "get_default_job_store", lambda: store) + monkeypatch.setattr(durable, "get_default_job_store", lambda: store) + monkeypatch.setattr(jobs_v2, "get_default_billing_service", lambda: FakeBillingService()) + monkeypatch.setattr(store, "update_payload", fail_update_payload) + monkeypatch.setattr( + jobs_v2, + "release_job_reservation", + lambda account_id, job_id: released.append((account_id, job_id)), + ) + + response = TestClient(app).post("/v2/jobs/job-1/retry") + + assert response.status_code == 503 + assert released == [("acct-1", "job-1")] + assert store.jobs["job-1"]["status"] == "failed" + assert store.jobs["job-1"]["error"] == "payload write failed" + + +def test_v2_retry_release_failure_still_marks_job_failed(monkeypatch): + app, _ = _build_app(monkeypatch) + store = FakeJobStore() + store.jobs["job-1"] = { + "job_id": "job-1", + "job_type": "memory_ingest", + "payload": {"billing_account_id": "acct-1", "user_id": "hunter"}, + "user_id": "hunter", + "status": "failed", + "timeout_seconds": 30, + "max_attempts": 3, + "retry_count": 1, + "attempt_count": 1, + "workflow_id": "old-workflow", + } + + class FakeEstimate: + reserved_credits = 100 + + def model_dump(self): + return {"reserved_credits": self.reserved_credits} + + class FakeBillingService: + def estimate_required_credits(self, job_type, payload): + return FakeEstimate() + + def reserve_credits(self, account_id, job_id, estimated_credits): + return SimpleNamespace(reservation_id="reservation-1", created=True) + + async def fake_start_job_workflow(job): + raise RuntimeError("temporal unavailable") + + def fail_release(account_id, job_id): + raise RuntimeError("mongo unavailable") + + monkeypatch.setattr(jobs_v2, "get_default_job_store", lambda: store) + monkeypatch.setattr(durable, "get_default_job_store", lambda: store) + monkeypatch.setattr(jobs_v2, "get_default_billing_service", lambda: FakeBillingService()) + monkeypatch.setattr(jobs_v2, "release_job_reservation", fail_release) + monkeypatch.setattr(jobs_v2, "start_job_workflow", fake_start_job_workflow) + + response = TestClient(app).post("/v2/jobs/job-1/retry") + + assert response.status_code == 503 + assert store.jobs["job-1"]["status"] == "failed" + assert store.jobs["job-1"]["error"] == ( + "temporal unavailable; billing reservation release failed: mongo unavailable" + ) + + +def test_v2_retry_start_failure_keeps_reused_billing_reservation(monkeypatch): + app, _ = _build_app(monkeypatch) + store = FakeJobStore() + store.jobs["job-1"] = { + "job_id": "job-1", + "job_type": "memory_ingest", + "payload": {"billing_account_id": "acct-1", "user_id": "hunter"}, + "user_id": "hunter", + "status": "failed", + "timeout_seconds": 30, + "max_attempts": 3, + "retry_count": 1, + "attempt_count": 1, + "workflow_id": "old-workflow", + } + released = [] + + class FakeEstimate: + reserved_credits = 100 + + def model_dump(self): + return {"reserved_credits": self.reserved_credits} + + class FakeBillingService: + def estimate_required_credits(self, job_type, payload): + return FakeEstimate() + + def reserve_credits(self, account_id, job_id, estimated_credits): + return SimpleNamespace(reservation_id="reservation-1", created=False) + + async def fake_start_job_workflow(job): + raise RuntimeError("temporal unavailable") + + monkeypatch.setattr(jobs_v2, "get_default_job_store", lambda: store) + monkeypatch.setattr(durable, "get_default_job_store", lambda: store) + monkeypatch.setattr(jobs_v2, "get_default_billing_service", lambda: FakeBillingService()) + monkeypatch.setattr( + jobs_v2, + "release_job_reservation", + lambda account_id, job_id: released.append((account_id, job_id)), + ) + monkeypatch.setattr(jobs_v2, "start_job_workflow", fake_start_job_workflow) + + response = TestClient(app).post("/v2/jobs/job-1/retry") + + assert response.status_code == 503 + assert released == [] + assert store.jobs["job-1"]["status"] == "failed" + + +def test_v2_cancel_mark_failure_keeps_billing_reserved_after_signal(monkeypatch): + app, _ = _build_app(monkeypatch) + store = FakeJobStore() + store.jobs["job-1"] = { + "job_id": "job-1", + "job_type": "memory_ingest", + "payload": {"billing_account_id": "acct-1", "user_id": "hunter"}, + "user_id": "hunter", + "status": "running", + "timeout_seconds": 30, + "max_attempts": 3, + "retry_count": 0, + "attempt_count": 1, + "workflow_id": "workflow-1", + } + released = [] + cancelled = [] + + async def fake_cancel_job_workflow(job): + cancelled.append(job["job_id"]) + + def fail_mark_cancelled(job_id): + raise RuntimeError("cancel status write failed") + + monkeypatch.setattr(jobs_v2, "get_default_job_store", lambda: store) + monkeypatch.setattr(durable, "get_default_job_store", lambda: store) + monkeypatch.setattr(jobs_v2, "cancel_job_workflow", fake_cancel_job_workflow) + monkeypatch.setattr(store, "mark_cancelled", fail_mark_cancelled) + monkeypatch.setattr( + jobs_v2, + "release_job_billing", + lambda job, reason: released.append((job["job_id"], reason)), + ) + + response = TestClient(app).post("/v2/jobs/job-1/cancel") + + assert response.status_code == 503 + assert cancelled == ["job-1"] + assert released == [] + assert store.jobs["job-1"]["status"] == "running" + + def test_v1_batch_ingest_scopes_each_item_for_local_static_key(monkeypatch): monkeypatch.setattr(memory.settings, "environment", "development", raising=False) static_user = {"id": "static-key", "name": "Static Key User", "email": "static@xmem.ai"} diff --git a/tests/test_billing.py b/tests/test_billing.py new file mode 100644 index 00000000..295431fc --- /dev/null +++ b/tests/test_billing.py @@ -0,0 +1,201 @@ +import os +from datetime import timedelta + +os.environ.setdefault("APP_STORE_PROVIDER", "memory") +os.environ.setdefault("FALLBACK_ORDER", '["ollama"]') +os.environ.setdefault("NEO4J_PASSWORD", "test") + +import pytest + +from src.billing import store as billing_store +from src.billing.service import BillingService +from src.billing.store import BillingStore, BillingStoreError, InsufficientCredits, utc_now +from src.utils import billing as billing_config + + +@pytest.fixture(autouse=True) +def clear_memory_billing(): + billing_store._memory_accounts.clear() + billing_store._memory_wallets.clear() + billing_store._memory_lots.clear() + billing_store._memory_ledger.clear() + billing_store._memory_reservations.clear() + billing_store._memory_usage_events.clear() + billing_store._memory_checkouts.clear() + billing_store._memory_payment_events.clear() + + +def service() -> BillingService: + return BillingService(BillingStore()) + + +def test_token_estimate_uses_configured_chars_per_token(): + assert billing_config.estimate_tokens("a" * 9) == 3 + + +def test_pro_nominal_credit_value(): + assert billing_config.nominal_paise_per_credit("pro") == pytest.approx(1.98) + + +def test_free_trial_grant_is_idempotent(): + svc = service() + user = {"id": "user_1"} + + first = svc.ensure_billing_account(user) + second = svc.ensure_billing_account(user) + + assert first["id"] == second["id"] + summary = svc.get_billing_summary(user) + assert summary.available_credits == 10_000 + grants = [entry for entry in svc.list_ledger(user) if entry["type"] == "grant"] + assert len(grants) == 1 + + +def test_reservation_debit_and_refund_flow(): + svc = service() + user = {"id": "user_1"} + account = svc.ensure_billing_account(user) + + reservation = svc.reserve_credits(account["id"], "job_1", 1000) + assert reservation.available_credits == 9000 + + svc.commit_job_debit(account["id"], "job_1", 750) + summary = svc.get_billing_summary(user) + + assert summary.available_credits == 9250 + assert summary.reserved_credits == 0 + ledger_types = {entry["type"] for entry in svc.list_ledger(user)} + assert {"reserve", "debit", "refund"}.issubset(ledger_types) + + +def test_reused_reservation_is_marked_not_created(): + svc = service() + user = {"id": "user_1"} + account = svc.ensure_billing_account(user) + + first = svc.reserve_credits(account["id"], "job_1", 1000) + second = svc.reserve_credits(account["id"], "job_1", 1000) + + summary = svc.get_billing_summary(user) + assert first.created + assert not second.created + assert summary.available_credits == 9000 + assert summary.reserved_credits == 1000 + + +def test_failed_job_releases_reserved_credits(): + svc = service() + user = {"id": "user_1"} + account = svc.ensure_billing_account(user) + + svc.reserve_credits(account["id"], "job_1", 1000) + svc.release_job_reservation(account["id"], "job_1") + + summary = svc.get_billing_summary(user) + assert summary.available_credits == 10_000 + assert summary.reserved_credits == 0 + + +def test_duplicate_release_does_not_double_refund(): + svc = service() + user = {"id": "user_1"} + account = svc.ensure_billing_account(user) + + svc.reserve_credits(account["id"], "job_1", 1000) + svc.release_job_reservation(account["id"], "job_1") + svc.release_job_reservation(account["id"], "job_1") + + summary = svc.get_billing_summary(user) + releases = [entry for entry in svc.list_ledger(user) if entry["type"] == "release"] + assert summary.available_credits == 10_000 + assert summary.reserved_credits == 0 + assert len(releases) == 1 + + +def test_insufficient_credits_blocks_reservation(): + svc = service() + user = {"id": "user_1"} + account = svc.ensure_billing_account(user) + + with pytest.raises(InsufficientCredits): + svc.reserve_credits(account["id"], "job_1", 10_001) + + +def test_reservation_cannot_be_reused_by_another_account(): + svc = service() + account_1 = svc.ensure_billing_account({"id": "user_1"}) + account_2 = svc.ensure_billing_account({"id": "user_2"}) + + svc.reserve_credits(account_1["id"], "job_1", 100) + + with pytest.raises(BillingStoreError): + svc.reserve_credits(account_2["id"], "job_1", 100) + + +def test_duplicate_commit_does_not_double_debit(): + svc = service() + user = {"id": "user_1"} + account = svc.ensure_billing_account(user) + svc.reserve_credits(account["id"], "job_1", 1000) + + svc.commit_job_debit(account["id"], "job_1", 750) + svc.commit_job_debit(account["id"], "job_1", 750) + + summary = svc.get_billing_summary(user) + debits = [entry for entry in svc.list_ledger(user) if entry["type"] == "debit"] + assert summary.available_credits == 9250 + assert len(debits) == 1 + + +def test_pro_grant_is_idempotent_per_payment(): + svc = service() + + svc.grant_pro_subscription( + user_id="user_1", + payment_id="pay_1", + subscription_id="sub_1", + period_end=utc_now() + timedelta(days=30), + ) + svc.grant_pro_subscription( + user_id="user_1", + payment_id="pay_1", + subscription_id="sub_1", + period_end=utc_now() + timedelta(days=30), + ) + + summary = svc.get_billing_summary({"id": "user_1"}) + assert summary.available_credits == 15_000 + pro_grants = [ + entry + for entry in svc.list_ledger({"id": "user_1"}) + if entry["source"] == "pro_monthly" + ] + assert len(pro_grants) == 1 + + +def test_billing_config_changes_affect_estimates(monkeypatch): + svc = service() + payload = {"user_query": "a" * 400, "agent_response": "", "effort_level": "low"} + baseline = svc.estimate_required_credits("memory_ingest", payload) + + monkeypatch.setitem(billing_config.WORKFLOW_MULTIPLIERS, "memory_ingest_low", 2.0) + changed = svc.estimate_required_credits("memory_ingest", payload) + + assert changed.billable_credits == baseline.billable_credits * 2 + + +def test_missing_payment_event_id_is_rejected(): + store = BillingStore() + + with pytest.raises(ValueError): + store.mark_payment_event("", {"event": "payment.captured"}) + + +def test_in_memory_checkout_and_webhook_events_are_isolated(): + store = BillingStore() + + store.save_checkout("same_id", {"user_id": "user_1", "package_id": "topup_99"}) + + assert store.mark_payment_event("same_id", {"event": "payment.captured"}) + assert not store.mark_payment_event("same_id", {"event": "payment.captured"}) + assert store.get_checkout("same_id")["package_id"] == "topup_99"