diff --git a/src/api/routes/billing.py b/src/api/routes/billing.py index d34a76e..d6feb94 100644 --- a/src/api/routes/billing.py +++ b/src/api/routes/billing.py @@ -82,6 +82,19 @@ def _pack_or_plan(package_id: str) -> tuple[str, dict[str, Any]]: raise HTTPException(status_code=400, detail="Unknown billing package") +def _checkout_package(package_id: str, package: dict[str, Any], region: str) -> dict[str, Any]: + checkout_package = dict(package) + if package_id in billing_config.PLANS: + checkout_package.update(billing_config.plan_price(package_id, region)) + return checkout_package + + +def _pro_plan_id_for_region(region: str) -> str | None: + if region == billing_config.BILLING_REGION_GLOBAL: + return settings.razorpay_global_pro_plan_id + return settings.razorpay_pro_plan_id + + @router.get("/plans", response_model=list[PlanPublic]) async def list_billing_plans() -> list[PlanPublic]: return public_plans() @@ -110,6 +123,8 @@ async def create_razorpay_checkout( user_id = _user_id(current_user) package_type, package = _pack_or_plan(request.package_id) + billing_region = billing_config.normalize_billing_region(request.billing_region) + checkout_package = _checkout_package(request.package_id, package, billing_region) service = get_default_billing_service() account = await asyncio.to_thread(service.ensure_billing_account, current_user) @@ -121,13 +136,15 @@ async def create_razorpay_checkout( "billing_account_id": account["id"], "package_id": request.package_id, "package_type": package_type, + "billing_region": billing_region, } receipt = _receipt(user_id, request.package_id) try: - if request.package_id == "pro" and settings.razorpay_pro_plan_id: + pro_plan_id = _pro_plan_id_for_region(billing_region) + if request.package_id == "pro" and pro_plan_id: subscription = await create_subscription( - plan_id=settings.razorpay_pro_plan_id, + plan_id=pro_plan_id, notes=notes, ) checkout_id = str(subscription["id"]) @@ -139,6 +156,7 @@ async def create_razorpay_checkout( "user_id": user_id, "billing_account_id": account["id"], "package_id": request.package_id, + "billing_region": billing_region, "subscription_id": checkout_id, "status": "created", }, @@ -147,16 +165,16 @@ async def create_razorpay_checkout( id=checkout_id, subscription_id=checkout_id, package_id=request.package_id, - amount=int(package["price_paise"]), - currency=str(package.get("currency") or "INR"), + amount=int(checkout_package["price_minor_unit"]), + currency=str(checkout_package.get("currency") or "INR"), key_id=key_id, receipt=receipt, ) - amount = int(package["price_paise"]) + amount = int(checkout_package.get("price_minor_unit") or checkout_package["price_paise"]) order = await create_order( amount_paise=amount, - currency=str(package.get("currency") or "INR"), + currency=str(checkout_package.get("currency") or "INR"), receipt=receipt, notes=notes, ) @@ -172,9 +190,10 @@ async def create_razorpay_checkout( "user_id": user_id, "billing_account_id": account["id"], "package_id": request.package_id, + "billing_region": billing_region, "order_id": order_id, "amount": amount, - "currency": str(package.get("currency") or "INR"), + "currency": str(checkout_package.get("currency") or "INR"), "status": "created", }, ) @@ -183,7 +202,7 @@ async def create_razorpay_checkout( order_id=order_id, package_id=request.package_id, amount=amount, - currency=str(package.get("currency") or "INR"), + currency=str(checkout_package.get("currency") or "INR"), key_id=key_id, receipt=receipt, ) diff --git a/src/billing/service.py b/src/billing/service.py index e555680..064ee66 100644 --- a/src/billing/service.py +++ b/src/billing/service.py @@ -43,6 +43,7 @@ def public_plans() -> list[PlanPublic]: trial_credits=int(plan.get("trial_credits") or 0), trial_days=int(plan.get("trial_days") or 0), nominal_paise_per_credit=billing_config.nominal_paise_per_credit(plan_id), + regional_prices=billing_config.plan_price_options(plan_id), ) for plan_id, plan in billing_config.PLANS.items() ] diff --git a/src/billing/types.py b/src/billing/types.py index 0eac742..fdf8754 100644 --- a/src/billing/types.py +++ b/src/billing/types.py @@ -6,6 +6,11 @@ from pydantic import BaseModel, Field +class PlanPricePublic(BaseModel): + price_minor_unit: int + currency: str = "INR" + + class PlanPublic(BaseModel): id: str name: str @@ -15,6 +20,7 @@ class PlanPublic(BaseModel): trial_credits: int = 0 trial_days: int = 0 nominal_paise_per_credit: float = 0.0 + regional_prices: dict[str, PlanPricePublic] = Field(default_factory=dict) class TopUpPackPublic(BaseModel): @@ -66,6 +72,13 @@ class ReservationResult(BaseModel): class CheckoutRequest(BaseModel): package_id: str = Field(..., description="Plan ID or top-up pack ID") + billing_region: Optional[str] = Field( + default=None, + description=( + "Client billing-region hint, e.g. IN for India or GLOBAL for non-India " + "pricing. Blank or missing hints use global pricing." + ), + ) class CheckoutResponse(BaseModel): @@ -85,6 +98,7 @@ class VerifyPaymentRequest(BaseModel): razorpay_signature: str razorpay_order_id: Optional[str] = None razorpay_subscription_id: Optional[str] = None + billing_region: Optional[str] = None class LedgerEntryPublic(BaseModel): diff --git a/src/config/settings.py b/src/config/settings.py index b3afff5..acc7fc0 100644 --- a/src/config/settings.py +++ b/src/config/settings.py @@ -499,7 +499,11 @@ class Settings(BaseSettings): ) razorpay_pro_plan_id: Optional[str] = Field( default=None, - description="Razorpay subscription plan ID for the Pro plan", + description="Razorpay subscription plan ID for the India Pro plan", + ) + razorpay_global_pro_plan_id: Optional[str] = Field( + default=None, + description="Razorpay subscription plan ID for the global USD Pro plan", ) @field_validator("fallback_order") diff --git a/src/utils/billing.py b/src/utils/billing.py index 0adb763..25372ac 100644 --- a/src/utils/billing.py +++ b/src/utils/billing.py @@ -27,6 +27,26 @@ }, } +BILLING_REGION_IN = "IN" +BILLING_REGION_GLOBAL = "GLOBAL" +BILLING_REGION_ALIASES = { + "IN": BILLING_REGION_IN, + "IND": BILLING_REGION_IN, + "INDIA": BILLING_REGION_IN, + "GLOBAL": BILLING_REGION_GLOBAL, + "US": BILLING_REGION_GLOBAL, + "USD": BILLING_REGION_GLOBAL, + "INTERNATIONAL": BILLING_REGION_GLOBAL, + "WORLD": BILLING_REGION_GLOBAL, +} + +PLAN_REGIONAL_PRICES: dict[str, dict[str, dict[str, Any]]] = { + "pro": { + BILLING_REGION_IN: {"price_minor_unit": 9_900, "currency": "INR"}, + BILLING_REGION_GLOBAL: {"price_minor_unit": 300, "currency": "USD"}, + }, +} + TOP_UP_PACKS: dict[str, dict[str, Any]] = { "topup_99": {"price_paise": 9_900, "credits": 5_000, "currency": "INR"}, "topup_199": {"price_paise": 19_900, "credits": 12_000, "currency": "INR"}, @@ -75,6 +95,42 @@ def workflow_multiplier(job_type: str, payload: Mapping[str, Any]) -> float: return WORKFLOW_MULTIPLIERS.get(job_type, 1.0) +def normalize_billing_region(region: str | None) -> str: + # Client-provided region is only a pricing hint; blank hints use global + # pricing to avoid undercharging when a client cannot derive location. + if not region or not region.strip(): + return BILLING_REGION_GLOBAL + return BILLING_REGION_ALIASES.get(region.strip().upper(), BILLING_REGION_GLOBAL) + + +def plan_price(plan_id: str, region: str | None = None) -> dict[str, Any]: + plan = PLANS[plan_id] + normalized_region = normalize_billing_region(region) + regional_price = PLAN_REGIONAL_PRICES.get(plan_id, {}).get(normalized_region) + if not regional_price: + return { + "price_minor_unit": int(plan.get("price_paise") or 0), + "currency": str(plan.get("currency") or "INR"), + } + return { + "price_minor_unit": int(regional_price["price_minor_unit"]), + "currency": str(regional_price.get("currency") or plan.get("currency") or "INR"), + } + + +def plan_price_options(plan_id: str) -> dict[str, dict[str, Any]]: + options = PLAN_REGIONAL_PRICES.get(plan_id) + if not options: + return {} + return { + region: { + "price_minor_unit": int(price["price_minor_unit"]), + "currency": str(price.get("currency") or "INR"), + } + for region, price in options.items() + } + + def nominal_paise_per_credit(plan_id: str) -> float: plan = PLANS[plan_id] credits = int(plan.get("monthly_credits") or plan.get("trial_credits") or 0) diff --git a/tests/unit/test_billing_config.py b/tests/unit/test_billing_config.py new file mode 100644 index 0000000..81f1e92 --- /dev/null +++ b/tests/unit/test_billing_config.py @@ -0,0 +1,29 @@ +from src.utils import billing + + +def test_pro_plan_has_india_and_global_prices() -> None: + assert billing.plan_price("pro", "IN") == { + "price_minor_unit": 9_900, + "currency": "INR", + } + assert billing.plan_price("pro", "GLOBAL") == { + "price_minor_unit": 300, + "currency": "USD", + } + assert billing.PLANS["pro"]["monthly_credits"] == 5_000 + + +def test_billing_region_defaults_to_global_and_unknowns_are_global() -> None: + assert billing.normalize_billing_region(None) == billing.BILLING_REGION_GLOBAL + assert billing.normalize_billing_region("") == billing.BILLING_REGION_GLOBAL + assert billing.normalize_billing_region(" ") == billing.BILLING_REGION_GLOBAL + assert billing.normalize_billing_region("india") == billing.BILLING_REGION_IN + assert billing.normalize_billing_region("outside-india") == billing.BILLING_REGION_GLOBAL + assert billing.normalize_billing_region("UK") == billing.BILLING_REGION_GLOBAL + + +def test_plan_price_options_are_serializable() -> None: + assert billing.plan_price_options("pro") == { + "IN": {"price_minor_unit": 9_900, "currency": "INR"}, + "GLOBAL": {"price_minor_unit": 300, "currency": "USD"}, + }