diff --git a/src/runpod_flash/cli/commands/build.py b/src/runpod_flash/cli/commands/build.py index 2146ad31..65414a97 100644 --- a/src/runpod_flash/cli/commands/build.py +++ b/src/runpod_flash/cli/commands/build.py @@ -16,11 +16,6 @@ import typer from rich.console import Console -try: - import tomllib # Python 3.11+ -except ImportError: - import tomli as tomllib # Python 3.10 - from runpod_flash.cli.utils.formatting import print_error, print_warning from runpod_flash.core.resources.constants import MAX_TARBALL_SIZE_MB @@ -167,52 +162,6 @@ def _bundle_runpod_flash(build_dir: Path, flash_pkg: Path) -> None: logger.debug("bundled runpod_flash from %s", flash_pkg) -def _extract_runpod_flash_dependencies(flash_pkg_dir: Path) -> list[str]: - """Extract runtime dependencies from runpod_flash's pyproject.toml. - - When bundling local runpod_flash source, we need to also install its dependencies - so they're available in the build environment. - - Args: - flash_pkg_dir: Path to runpod_flash package directory (src/runpod_flash) - - Returns: - List of dependency strings, empty list if parsing fails - """ - try: - # Navigate from runpod_flash package to project root - # flash_pkg_dir is src/runpod_flash, need to go up 2 levels to reach project root - project_root = flash_pkg_dir.parent.parent - pyproject_path = project_root / "pyproject.toml" - - if not pyproject_path.exists(): - console.print( - "[yellow]⚠ runpod_flash pyproject.toml not found, " - "dependencies may be missing[/yellow]" - ) - return [] - - # Parse TOML - with open(pyproject_path, "rb") as f: - data = tomllib.load(f) - - # Extract dependencies from [project.dependencies] - dependencies = data.get("project", {}).get("dependencies", []) - - if dependencies: - console.print( - f"[dim]Found {len(dependencies)} runpod_flash dependencies to install[/dim]" - ) - - return dependencies - - except Exception as e: - console.print( - f"[yellow]⚠ Failed to parse runpod_flash dependencies: {e}[/yellow]" - ) - return [] - - def _normalize_package_name(name: str) -> str: """Normalize a package name for comparison (PEP 503: lowercase, hyphens to underscores).""" return name.lower().replace("-", "_") @@ -1085,17 +1034,6 @@ def _is_excluded_top_dir(top_dir: str) -> bool: tar.add(str(item), arcname=arcname) -def cleanup_build_directory(build_base: Path) -> None: - """ - Remove build directory. - - Args: - build_base: .build directory to remove - """ - if build_base.exists(): - shutil.rmtree(build_base) - - def _display_build_summary( archive_path: Path, app_name: str, diff --git a/src/runpod_flash/cli/commands/build_utils/manifest.py b/src/runpod_flash/cli/commands/build_utils/manifest.py index 87cd0975..520792ec 100644 --- a/src/runpod_flash/cli/commands/build_utils/manifest.py +++ b/src/runpod_flash/cli/commands/build_utils/manifest.py @@ -1,10 +1,8 @@ """Builder for flash_manifest.json.""" import importlib.util -import json import logging import sys -from dataclasses import dataclass from datetime import datetime, timezone from pathlib import Path from typing import Any, Dict, List, Optional @@ -43,41 +41,6 @@ def _serialize_network_volume(nv) -> dict: return nv_config -@dataclass -class ManifestFunction: - """Function entry in manifest.""" - - name: str - module: str - is_async: bool - is_class: bool - http_method: Optional[str] = None # HTTP method for LB endpoints (GET, POST, etc.) - http_path: Optional[str] = None # HTTP path for LB endpoints (/api/process) - is_load_balanced: bool = False # Determined by isinstance() at scan time - is_live_resource: bool = False # LiveLoadBalancer vs LoadBalancerSlsResource - config_variable: Optional[str] = None # Variable name like "gpu_config" - - -@dataclass -class ManifestResource: - """Resource config entry in manifest.""" - - resource_type: str - functions: List[ManifestFunction] - is_load_balanced: bool = False # Determined by isinstance() at scan time - is_live_resource: bool = False # LiveLoadBalancer vs LoadBalancerSlsResource - config_variable: Optional[str] = None # Variable name for config discovery - is_load_balanced_endpoint: bool = False # Flag for load-balanced endpoint - is_explicit: bool = False # Flag indicating explicit load balancer configuration - main_file: Optional[str] = None # Filename of main entry point - app_variable: Optional[str] = None # Variable name of FastAPI app - imageName: Optional[str] = None # Docker image name for auto-provisioning - templateId: Optional[str] = None # RunPod template ID for auto-provisioning - gpuIds: Optional[list] = None # GPU types/IDs for auto-provisioning - workersMin: Optional[int] = None # Min worker count for auto-provisioning - workersMax: Optional[int] = None # Max worker count for auto-provisioning - - class ManifestBuilder: """Builds flash_manifest.json from discovered remote functions.""" @@ -571,9 +534,3 @@ def build(self) -> Dict[str, Any]: manifest["routes"] = routes_dict return manifest - - def write_to_file(self, output_path: Path) -> Path: - """Write manifest to file.""" - manifest = self.build() - output_path.write_text(json.dumps(manifest, indent=2)) - return output_path diff --git a/src/runpod_flash/cli/commands/resource.py b/src/runpod_flash/cli/commands/resource.py deleted file mode 100644 index 76025bef..00000000 --- a/src/runpod_flash/cli/commands/resource.py +++ /dev/null @@ -1,94 +0,0 @@ -"""Resource management commands.""" - -import asyncio -import time -import typer -from rich.console import Console -from rich.live import Live - -from ...core.resources.resource_manager import ResourceManager - -console = Console() - - -def report_command( - live: bool = typer.Option(False, "--live", "-l", help="Live updating status"), - refresh: int = typer.Option( - 2, "--refresh", "-r", help="Refresh interval for live mode" - ), -): - """Show resource status dashboard.""" - - resource_manager = ResourceManager() - - if live: - try: - with Live( - _render_resource_report(resource_manager), - console=console, - refresh_per_second=1 / refresh, - screen=True, - ) as live_display: - while True: - time.sleep(refresh) - live_display.update(_render_resource_report(resource_manager)) - except KeyboardInterrupt: - console.print("\nStopped") - else: - output = _render_resource_report(resource_manager) - console.print(output) - - -def _render_resource_report(resource_manager: ResourceManager): - """Build a rich renderable for the current resource state.""" - from rich.text import Text - - resources = resource_manager._resources - - if not resources: - return Text("No resources tracked.") - - lines = Text() - lines.append("\nResources\n\n", style="bold") - - active_count = 0 - inactive_count = 0 - - for uid, resource in resources.items(): - try: - is_deployed = asyncio.run(resource.is_deployed()) - if is_deployed: - color, status_text = "green", "active" - active_count += 1 - else: - color, status_text = "red", "inactive" - inactive_count += 1 - except Exception: - color, status_text = "yellow", "unknown" - - resource_type = resource.__class__.__name__ - try: - url = resource.url if hasattr(resource, "url") else "" - except Exception: - url = "" - - display_uid = uid[:20] + "..." if len(uid) > 20 else uid - - lines.append(f" {display_uid}", style="bold") - lines.append(f" {status_text}", style=color) - lines.append(f" {resource_type}") - if url: - lines.append(f" {url}") - lines.append("\n") - - total = len(resources) - unknown_count = total - active_count - inactive_count - parts = [f"{active_count} active"] - if inactive_count > 0: - parts.append(f"{inactive_count} inactive") - if unknown_count > 0: - parts.append(f"{unknown_count} unknown") - - lines.append(f"\n{total} resources ({', '.join(parts)})\n") - - return lines diff --git a/src/runpod_flash/cli/main.py b/src/runpod_flash/cli/main.py index 57d51ab7..1c37b76f 100644 --- a/src/runpod_flash/cli/main.py +++ b/src/runpod_flash/cli/main.py @@ -45,7 +45,6 @@ def get_version() -> str: app.command("login")(login.login_command) app.command("deploy")(deploy.deploy_command) app.command("update")(update.update_command) -# app.command("report")(resource.report_command) # command: flash env diff --git a/src/runpod_flash/cli/utils/conda.py b/src/runpod_flash/cli/utils/conda.py deleted file mode 100644 index 86213767..00000000 --- a/src/runpod_flash/cli/utils/conda.py +++ /dev/null @@ -1,127 +0,0 @@ -"""Conda environment management utilities.""" - -import subprocess -from typing import List, Tuple -from rich.console import Console - -console = Console() - - -def check_conda_available() -> bool: - """Check if conda is available on the system.""" - try: - result = subprocess.run( - ["conda", "--version"], - capture_output=True, - text=True, - timeout=5, - ) - return result.returncode == 0 - except (subprocess.SubprocessError, FileNotFoundError): - return False - - -def create_conda_environment( - env_name: str, python_version: str = "3.11" -) -> Tuple[bool, str]: - """ - Create a new conda environment. - - Args: - env_name: Name of the conda environment - python_version: Python version to use - - Returns: - Tuple of (success, message) - """ - try: - console.print(f"Creating conda environment: {env_name}") - - result = subprocess.run( - ["conda", "create", "-n", env_name, f"python={python_version}", "-y"], - capture_output=True, - text=True, - timeout=300, # 5 minutes timeout - ) - - if result.returncode == 0: - return True, f"Conda environment '{env_name}' created successfully" - else: - return False, f"Failed to create environment: {result.stderr}" - - except subprocess.TimeoutExpired: - return False, "Environment creation timed out" - except Exception as e: - return False, f"Error creating environment: {e}" - - -def install_packages_in_env( - env_name: str, packages: List[str], use_pip: bool = True -) -> Tuple[bool, str]: - """ - Install packages in a conda environment. - - Args: - env_name: Name of the conda environment - packages: List of packages to install - use_pip: If True, use pip install; otherwise use conda install - - Returns: - Tuple of (success, message) - """ - try: - console.print(f"Installing packages: {', '.join(packages)}") - - if use_pip: - # Use conda run to execute pip in the environment - cmd = [ - "conda", - "run", - "-n", - env_name, - "pip", - "install", - ] + packages - else: - cmd = ["conda", "install", "-n", env_name, "-y"] + packages - - result = subprocess.run( - cmd, - capture_output=True, - text=True, - timeout=600, # 10 minutes timeout - ) - - if result.returncode == 0: - return True, "Packages installed successfully" - else: - return False, f"Failed to install packages: {result.stderr}" - - except subprocess.TimeoutExpired: - return False, "Package installation timed out" - except Exception as e: - return False, f"Error installing packages: {e}" - - -def environment_exists(env_name: str) -> bool: - """Check if a conda environment exists.""" - try: - result = subprocess.run( - ["conda", "env", "list"], - capture_output=True, - text=True, - timeout=10, - ) - - if result.returncode == 0: - # Check if environment name appears in the output - return env_name in result.stdout - return False - - except Exception: - return False - - -def get_activation_command(env_name: str) -> str: - """Get the command to activate the conda environment.""" - return f"conda activate {env_name}" diff --git a/src/runpod_flash/cli/utils/deployment.py b/src/runpod_flash/cli/utils/deployment.py index 900c7ca2..3d7f1f84 100644 --- a/src/runpod_flash/cli/utils/deployment.py +++ b/src/runpod_flash/cli/utils/deployment.py @@ -5,10 +5,8 @@ import json import logging from typing import Dict, Any -from datetime import datetime from pathlib import Path -from runpod_flash.config import get_paths from runpod_flash.core.resources.serverless import ServerlessResource from runpod_flash.core.resources.app import FlashApp from runpod_flash.core.resources.resource_manager import ResourceManager @@ -50,171 +48,6 @@ def _resource_config_for_compare(config: Dict[str, Any]) -> Dict[str, Any]: return compare_config -async def upload_build(app_name: str, build_path: str | Path): - app = await FlashApp.from_name(app_name) - await app.upload_build(build_path) - - -def get_deployment_environments() -> Dict[str, Dict[str, Any]]: - """Get all deployment environments.""" - paths = get_paths() - deployments_file = paths.deployments_file - - if not deployments_file.exists(): - return {} - - try: - with open(deployments_file) as f: - return json.load(f) - except (json.JSONDecodeError, FileNotFoundError): - return {} - - -def save_deployment_environments(environments: Dict[str, Dict[str, Any]]): - """Save deployment environments to file.""" - paths = get_paths() - deployments_file = paths.deployments_file - - # Ensure .flash directory exists - paths.ensure_flash_dir() - - with open(deployments_file, "w") as f: - json.dump(environments, f, indent=2) - - -def create_deployment_environment(name: str, config: Dict[str, Any]): - """Create a new deployment environment.""" - environments = get_deployment_environments() - - # Mock environment creation - environments[name] = { - "status": "idle", - "config": config, - "created_at": datetime.now().isoformat(), - "current_version": None, - "last_deployed": None, - "url": None, - "version_history": [], - } - - save_deployment_environments(environments) - - -def remove_deployment_environment(name: str): - """Remove a deployment environment.""" - environments = get_deployment_environments() - - if name in environments: - del environments[name] - save_deployment_environments(environments) - - -async def provision_resources_for_build( - app: FlashApp, build_id: str, environment_name: str, show_progress: bool = True -) -> Dict[str, str]: - """Provision all resources upfront before environment activation. - - Args: - app: FlashApp instance - build_id: ID of the build to provision resources for - environment_name: Name of environment (for logging/context) - show_progress: Whether to show CLI progress - - Returns: - Mapping of resource_name -> endpoint_url - - Raises: - RuntimeError: If provisioning fails for any resource - """ - # Load manifest from build - manifest = await app.get_build_manifest(build_id) - - if not manifest or "resources" not in manifest: - log.warning(f"No resources in manifest for build {build_id}") - return {} - - # Create resource manager - manager = ResourceManager() - resources_to_provision = [] - - # Create resource configs from manifest - manifest_python_version = manifest.get("python_version") - for resource_name, resource_config in manifest["resources"].items(): - resource = create_resource_from_manifest( - resource_name, - resource_config, - python_version=manifest_python_version, - flash_app_name=app.name, - flash_env_name=environment_name, - ) - resources_to_provision.append((resource_name, resource)) - - if show_progress: - print( - f"Provisioning {len(resources_to_provision)} resources for environment '{environment_name}'..." - ) - - # Provision resources in parallel - resources_endpoints = {} - provisioning_results = [] - - try: - # Use asyncio.gather for parallel provisioning - tasks = [ - manager.get_or_deploy_resource(resource) - for _, resource in resources_to_provision - ] - provisioning_results = await asyncio.gather(*tasks) - - except Exception as e: - log.error(f"Provisioning failed: {e}") - raise RuntimeError(f"Failed to provision resources: {e}") from e - - # Build resources_endpoints mapping - lb_endpoint_url = None - for (resource_name, _), deployed_resource in zip( - resources_to_provision, provisioning_results - ): - # Get endpoint URL (both LoadBalancer and Serverless have endpoint_url) - if hasattr(deployed_resource, "endpoint_url"): - endpoint_url = deployed_resource.endpoint_url - else: - log.warning(f"Resource {resource_name} has no endpoint_url attribute") - continue - - resources_endpoints[resource_name] = endpoint_url - - endpoint_id = _normalized_resource_attr(deployed_resource, "endpoint_id", "id") - if endpoint_id: - manifest["resources"][resource_name]["endpoint_id"] = endpoint_id - - ai_key = _normalized_resource_attr(deployed_resource, "aiKey", "ai_key") - if ai_key: - manifest["resources"][resource_name]["aiKey"] = ai_key - - # Track load balancer URL for prominent logging - if manifest["resources"][resource_name].get("is_load_balanced"): - lb_endpoint_url = endpoint_url - - if show_progress: - print(f" ✓ {resource_name}: {endpoint_url}") - - # Update manifest in FlashApp with resources_endpoints - manifest["resources_endpoints"] = resources_endpoints - await app.update_build_manifest(build_id, manifest) - - if show_progress: - print("✓ All resources provisioned and manifest updated") - # Display load balancer URL prominently if present - if lb_endpoint_url: - print() - print("=" * 60) - print(f"Load Balancer Endpoint: {lb_endpoint_url}") - print("=" * 60) - - return resources_endpoints - - async def reconcile_and_provision_resources( app: FlashApp, build_id: str, @@ -508,73 +341,3 @@ async def deploy_from_uploaded_build( result["resources_endpoints"] = resources_endpoints result["local_manifest"] = local_manifest return result - - -def rollback_deployment(name: str, target_version: str): - """Rollback deployment to a previous version (mock implementation).""" - environments = get_deployment_environments() - - if name not in environments: - raise ValueError(f"Environment {name} not found") - - # Find target version - target_version_info = None - for version in environments[name]["version_history"]: - if version["version"] == target_version: - target_version_info = version - break - - if not target_version_info: - raise ValueError(f"Version {target_version} not found") - - # Update current version - environments[name]["current_version"] = target_version - environments[name]["last_deployed"] = datetime.now().isoformat() - - # Update version history - for version in environments[name]["version_history"]: - version["is_current"] = version["version"] == target_version - - save_deployment_environments(environments) - - -def get_environment_info(name: str) -> Dict[str, Any]: - """Get detailed information about an environment.""" - environments = get_deployment_environments() - - if name not in environments: - raise ValueError(f"Environment {name} not found") - - env_info = environments[name].copy() - - # Add mock metrics and additional info - if env_info["status"] == "active": - env_info.update( - { - "uptime": "99.9%", - "requests_24h": 145234, - "avg_response_time": "245ms", - "error_rate": "0.02%", - "cpu_usage": "45%", - "memory_usage": "62%", - } - ) - - # Ensure version history exists and is properly formatted - if "version_history" not in env_info: - env_info["version_history"] = [] - - # Add sample version history if empty - if not env_info["version_history"] and env_info["current_version"]: - env_info["version_history"] = [ - { - "version": env_info["current_version"], - "deployed_at": env_info.get( - "last_deployed", datetime.now().isoformat() - ), - "description": "Initial deployment", - "is_current": True, - } - ] - - return env_info diff --git a/src/runpod_flash/cli/utils/formatting.py b/src/runpod_flash/cli/utils/formatting.py index 1bf16a9d..223236bd 100644 --- a/src/runpod_flash/cli/utils/formatting.py +++ b/src/runpod_flash/cli/utils/formatting.py @@ -4,8 +4,6 @@ from rich.console import Console -STATE_STYLE = {"HEALTHY": "green", "BUILDING": "yellow", "ERROR": "red"} - def print_error(console: Console, message: str) -> None: """Print a standardized error message.""" @@ -17,12 +15,6 @@ def print_warning(console: Console, message: str) -> None: console.print(f"[yellow]![/yellow] {message.lstrip()}") -def state_dot(state: str) -> str: - """Colored ● indicator for a resource/environment state.""" - color = STATE_STYLE.get(state, "yellow") - return f"[{color}]●[/{color}]" - - def format_datetime(value: str | None) -> str: """Format an ISO 8601 datetime string into a human-readable local time. diff --git a/src/runpod_flash/core/api/runpod.py b/src/runpod_flash/core/api/runpod.py index abb56c13..42c23d87 100644 --- a/src/runpod_flash/core/api/runpod.py +++ b/src/runpod_flash/core/api/runpod.py @@ -341,56 +341,6 @@ async def save_endpoint(self, input_data: Dict[str, Any]) -> Dict[str, Any]: return endpoint_data - async def get_cpu_types(self) -> Dict[str, Any]: - """Get available CPU types.""" - query = """ - query getCpuTypes { - cpuTypes { - id - displayName - manufacturer - cores - threadsPerCore - groupId - } - } - """ - - result = await self._execute_graphql(query) - return result.get("cpuTypes", []) - - async def get_gpu_types( - self, gpu_filter: Optional[Dict[str, Any]] = None - ) -> Dict[str, Any]: - """Get available GPU types.""" - query = """ - query getGpuTypes($input: GpuTypeFilter) { - gpuTypes(input: $input) { - id - displayName - manufacturer - memoryInGb - cudaCores - secureCloud - communityCloud - securePrice - communityPrice - communitySpotPrice - secureSpotPrice - maxGpuCount - maxGpuCountCommunityCloud - maxGpuCountSecureCloud - minPodGpuCount - nodeGroupGpuSizes - throughput - } - } - """ - - variables = {"input": gpu_filter} if gpu_filter else {} - result = await self._execute_graphql(query, variables) - return result.get("gpuTypes", []) - async def get_gpu_lowest_price_stock_status( self, gpu_id: str, @@ -465,12 +415,6 @@ async def get_cpu_specific_stock_status( return status.strip() return None - async def get_endpoint(self, endpoint_id: str) -> Dict[str, Any]: - """Get endpoint details.""" - # Note: The schema doesn't show a specific endpoint query - # This would need to be implemented if such query exists - raise NotImplementedError("Get endpoint query not available in current schema") - async def delete_endpoint(self, endpoint_id: str) -> Dict[str, Any]: """Delete a serverless endpoint.""" mutation = """ @@ -682,12 +626,6 @@ async def update_build_manifest( f"Expected 'updateFlashBuildManifest' in response, got: {list(result.keys())}" ) - async def get_flash_artifact_url(self, environment_id: str) -> Dict[str, Any]: - result = await self.get_flash_environment( - {"flashEnvironmentId": environment_id} - ) - return result - async def deploy_build_to_environment( self, input_data: Dict[str, Any] ) -> Dict[str, Any]: @@ -759,75 +697,6 @@ async def create_flash_environment( return result["createFlashEnvironment"] - async def register_endpoint_to_environment( - self, input_data: Dict[str, Any] - ) -> Dict[str, Any]: - """Register an endpoint to a Flash environment""" - - log.debug( - f"Registering endpoint to flash environment with input data: {input_data}" - ) - - mutation = """ - mutation addEndpointToFlashEnvironment($input: AddEndpointToEnvironmentInput!) { - addEndpointToFlashEnvironment(input: $input) { - id - name - flashEnvironmentId - } - } - """ - - variables = {"input": input_data} - - result = await self._execute_graphql(mutation, variables) - - return result["addEndpointToFlashEnvironment"] - - async def register_network_volume_to_environment( - self, input_data: Dict[str, Any] - ) -> Dict[str, Any]: - """Register an endpoint to a Flash environment""" - - log.debug( - f"Registering endpoint to flash environment with input data: {input_data}" - ) - - mutation = """ - mutation addNetworkVolumeToFlashEnvironment($input: AddNetworkVolumeToEnvironmentInput!) { - addNetworkVolumeToFlashEnvironment(input: $input) { - id - name - flashEnvironmentId - } - } - """ - - variables = {"input": input_data} - - result = await self._execute_graphql(mutation, variables) - - return result["addNetworkVolumeToFlashEnvironment"] - - async def set_environment_state(self, input_data: Dict[str, Any]) -> Dict[str, Any]: - log.debug(f"Setting Flash environment status with input data: {input_data}") - - mutation = """ - mutation updateFlashEnvironment($input: UpdateFlashEnvironmentInput!) { - updateFlashEnvironment(input: $input) { - id - name - state - } - } - """ - - variables = {"input": input_data} - - result = await self._execute_graphql(mutation, variables) - - return result["updateFlashEnvironment"] - async def get_flash_build(self, build_id: str) -> Dict[str, Any]: """Fetch flash build by ID. diff --git a/src/runpod_flash/core/resources/app.py b/src/runpod_flash/core/resources/app.py index 6d8d5ca5..1a69a342 100644 --- a/src/runpod_flash/core/resources/app.py +++ b/src/runpod_flash/core/resources/app.py @@ -5,8 +5,7 @@ import logging from ..api.runpod import RunpodGraphQLClient -from ..resources.resource_manager import ResourceManager -from ..resources.serverless import ServerlessEndpoint, NetworkVolume + from .constants import ( TARBALL_CONTENT_TYPE, MAX_TARBALL_SIZE_MB, @@ -42,12 +41,6 @@ class FlashEnvironmentNotFoundError(FlashAppError): pass -class FlashBuildNotFoundError(FlashAppError): - """Raised when a Flash build cannot be found.""" - - pass - - def _validate_exclusive_params( param1: Any, param2: Any, name1: str, name2: str ) -> None: @@ -332,21 +325,6 @@ async def _hydrate(self) -> None: self._hydrated = True return - async def _get_id_by_name(self) -> str: - """Get the app ID from the server by name. - - Returns: - The app ID string - - Raises: - FlashAppNotFoundError: If the app is not found on the server - """ - async with RunpodGraphQLClient() as client: - result = await client.get_flash_app_by_name(self.name) - if not result.get("id"): - raise FlashAppNotFoundError(f"Flash app '{self.name}' not found") - return result["id"] - async def create_environment(self, environment_name: str) -> Dict[str, Any]: """Create an environment within an app. @@ -384,28 +362,6 @@ async def _get_tarball_upload_url(self, tarball_size: int) -> Dict[str, str]: {"flashAppId": self.id, "tarballSize": tarball_size} ) - async def _get_active_artifact(self, environment_id: str) -> Dict[str, Any]: - """Get the active artifact for an environment. - - Args: - environment_id: ID of the environment - - Returns: - Dictionary containing artifact information including downloadUrl - - Raises: - RuntimeError: If app is not hydrated (no ID available) - ValueError: If environment has no active artifact - """ - await self._hydrate() - async with RunpodGraphQLClient() as client: - result = await client.get_flash_artifact_url(environment_id) - if not result.get("activeArtifact"): - raise ValueError( - f"No active artifact found for environment ID: {environment_id}" - ) - return result["activeArtifact"] - async def deploy_build_to_environment( self, build_id: str, @@ -442,32 +398,6 @@ async def deploy_build_to_environment( ) return result - async def download_tarball(self, environment_id: str, dest_file: str) -> None: - """Download the active build tarball from an environment. - - Args: - environment_id: ID of the environment to download from - dest_file: Path where the tarball should be saved - - Raises: - RuntimeError: If app is not hydrated (no ID available) - ValueError: If environment has no active artifact - requests.HTTPError: If download fails - """ - from runpod_flash.core.utils.http import get_authenticated_requests_session - - await self._hydrate() - result = await self._get_active_artifact(environment_id) - url = result["downloadUrl"] - - with open(dest_file, "wb") as stream: - with get_authenticated_requests_session() as session: - with session.get(url, stream=True) as resp: - resp.raise_for_status() - for chunk in resp.iter_content(): - if chunk: - stream.write(chunk) - async def _finalize_upload_build( self, object_key: str, manifest: Dict[str, Any] ) -> Dict[str, Any]: @@ -493,53 +423,6 @@ async def _finalize_upload_build( ) return result - async def _register_endpoint_to_environment( - self, environment_id: str, endpoint_id: str - ) -> Dict[str, Any]: - """Register a serverless endpoint to an environment. - - Args: - environment_id: ID of the environment - endpoint_id: ID of the endpoint to register - - Returns: - Dictionary containing registration result - - Raises: - RuntimeError: If app is not hydrated (no ID available) - """ - await self._hydrate() - async with RunpodGraphQLClient() as client: - result = await client.register_endpoint_to_environment( - {"flashEnvironmentId": environment_id, "endpointId": endpoint_id} - ) - return result - - async def _register_network_volume_to_environment( - self, environment_id: str, network_volume_id: str - ) -> Dict[str, Any]: - """Register a network volume to an environment. - - Args: - environment_id: ID of the environment - network_volume_id: ID of the network volume to register - - Returns: - Dictionary containing registration result - - Raises: - RuntimeError: If app is not hydrated (no ID available) - """ - await self._hydrate() - async with RunpodGraphQLClient() as client: - result = await client.register_network_volume_to_environment( - { - "flashEnvironmentId": environment_id, - "networkVolumeId": network_volume_id, - } - ) - return result - async def upload_build(self, tar_path: Union[str, Path]) -> Dict[str, Any]: """Upload a build tarball to the server. @@ -599,81 +482,6 @@ async def upload_build(self, tar_path: Union[str, Path]) -> Dict[str, Any]: _upload_progress_callback = None - async def _set_environment_state(self, environment_id: str, status: str) -> None: - """Set the state of an environment. - - Args: - environment_id: ID of the environment - status: State to set (e.g., "HEALTHY", "DEPLOYING", "PENDING") - - Raises: - RuntimeError: If app is not hydrated (no ID available) - """ - await self._hydrate() - async with RunpodGraphQLClient() as client: - await client.set_environment_state( - {"flashEnvironmentId": environment_id, "status": status} - ) - - async def _get_environment_by_name(self, environment_name: str) -> Dict[str, Any]: - """Get an environment by name. - - Args: - environment_name: Name of the environment to retrieve - - Returns: - Dictionary containing environment data - - Raises: - RuntimeError: If app is not hydrated (no ID available) - ValueError: If environment is not found - """ - await self._hydrate() - async with RunpodGraphQLClient() as client: - result = await client.get_flash_environment_by_name( - {"flashAppId": self.id, "name": environment_name} - ) - return result["flashEnvironmentByName"] - - async def deploy_resources(self, environment_name: str) -> None: - """Deploy all registered resources to an environment. - - This method iterates through all resources registered with the app - (via @remote decorator with resource_config) and deploys them, - then registers them to the specified environment. - - Args: - environment_name: Name of the environment to deploy resources to - - Raises: - RuntimeError: If app is not hydrated (no ID available) - ValueError: If environment is not found - """ - await self._hydrate() - resource_manager = ResourceManager() - environment = await self._get_environment_by_name(environment_name) - - # NOTE(jhcipar) it's pretty fragile to have client managed state like this - # we should enforce this on the server side eventually and either debounce or not allow subsequent deploys - await self._set_environment_state(environment["id"], "DEPLOYING") - - for resource_id, resource in self.resources.items(): - deployed_resource = await resource_manager.get_or_deploy_resource(resource) - if isinstance(deployed_resource, ServerlessEndpoint): - if deployed_resource.id: - await self._register_endpoint_to_environment( - environment["id"], deployed_resource.id - ) - if isinstance(deployed_resource, NetworkVolume): - if deployed_resource.id: - await self._register_network_volume_to_environment( - environment["id"], deployed_resource.id - ) - - # NOTE(jhcipar) we should healthcheck endpoints after provisioning them, for right now we just - # assume this is healthy - await self._set_environment_state(environment["id"], "HEALTHY") - @classmethod async def from_name(cls, app_name: str) -> "FlashApp": async with RunpodGraphQLClient() as client: @@ -753,22 +561,6 @@ async def delete_environment(self, environment_name: str) -> bool: result = await client.delete_flash_environment(environment_id) return result.get("success", False) - async def get_build(self, build_id: str) -> Dict[str, Any]: - """Get a build by ID. - - Args: - build_id: ID of the build to retrieve - - Returns: - Dictionary containing build data - - Raises: - RuntimeError: If app is not hydrated (no ID available) - """ - await self._hydrate() - async with RunpodGraphQLClient() as client: - return await client.get_flash_build(build_id) - async def list_builds(self) -> List[Dict[str, Any]]: """List all builds for this app. diff --git a/src/runpod_flash/core/resources/constants.py b/src/runpod_flash/core/resources/constants.py index 188a2567..d82d6d2c 100644 --- a/src/runpod_flash/core/resources/constants.py +++ b/src/runpod_flash/core/resources/constants.py @@ -4,24 +4,15 @@ # so all resources must share a single Python version. GPU images ship 3.12 # with torch pre-installed; 3.10 and 3.11 are available via side-by-side # install (~7 GB alt-Python overhead) in the same base image. -WORKER_PYTHON_VERSION: str = "3.12" GPU_PYTHON_VERSIONS: tuple[str, ...] = ("3.10", "3.11", "3.12") CPU_PYTHON_VERSIONS: tuple[str, ...] = ("3.10", "3.11", "3.12") -# Base image ships 3.12 with torch pre-installed; non-3.12 targets reinstall -# torch side-by-side for the selected interpreter. -GPU_BASE_IMAGE_PYTHON_VERSION: str = "3.12" DEFAULT_PYTHON_VERSION: str = "3.12" # Python versions that can run the flash SDK locally (for flash build, etc.) SUPPORTED_PYTHON_VERSIONS: tuple[str, ...] = ("3.10", "3.11", "3.12") -def local_python_version() -> str: - """Return the default worker Python version.""" - return DEFAULT_PYTHON_VERSION - - # Image type to repository mapping _IMAGE_REPOS: dict[str, str] = { "gpu": "runpod/flash", @@ -117,20 +108,10 @@ def get_image_name( # Docker image configuration FLASH_IMAGE_TAG = os.environ.get("FLASH_IMAGE_TAG", "latest") -_RESOLVED_TAG = FLASH_IMAGE_TAG -FLASH_GPU_IMAGE = os.environ.get( - "FLASH_GPU_IMAGE", f"runpod/flash:py{DEFAULT_PYTHON_VERSION}-{_RESOLVED_TAG}" -) -FLASH_CPU_IMAGE = os.environ.get( - "FLASH_CPU_IMAGE", f"runpod/flash-cpu:py{DEFAULT_PYTHON_VERSION}-{_RESOLVED_TAG}" -) -FLASH_LB_IMAGE = os.environ.get( - "FLASH_LB_IMAGE", f"runpod/flash-lb:py{DEFAULT_PYTHON_VERSION}-{_RESOLVED_TAG}" -) FLASH_CPU_LB_IMAGE = os.environ.get( "FLASH_CPU_LB_IMAGE", - f"runpod/flash-lb-cpu:py{DEFAULT_PYTHON_VERSION}-{_RESOLVED_TAG}", + f"runpod/flash-lb-cpu:py{DEFAULT_PYTHON_VERSION}-{FLASH_IMAGE_TAG}", ) # Worker configuration defaults diff --git a/src/runpod_flash/core/resources/load_balancer_sls_resource.py b/src/runpod_flash/core/resources/load_balancer_sls_resource.py index 8b9eb875..aa23604c 100644 --- a/src/runpod_flash/core/resources/load_balancer_sls_resource.py +++ b/src/runpod_flash/core/resources/load_balancer_sls_resource.py @@ -13,13 +13,11 @@ - Health checks via /ping endpoint """ -import asyncio import logging from typing import List, Optional from pydantic import model_validator -from runpod_flash.core.utils.http import get_authenticated_httpx_client from ..api.runpod import RunpodGraphQLClient from ..urls import ENDPOINT_DOMAIN from .cpu import CpuInstanceType @@ -29,15 +27,6 @@ log = logging.getLogger(__name__) -# Configuration constants -DEFAULT_HEALTH_CHECK_RETRIES = 10 -DEFAULT_HEALTH_CHECK_INTERVAL = 5 # seconds between retries -DEFAULT_PING_REQUEST_TIMEOUT = ( - 15.0 # seconds (load-balanced workers need time for cold starts) -) -HEALTHY_STATUS_CODES = (200, 204) - - class LoadBalancerSlsResource(ServerlessResource): """ Resource configuration for RunPod Load-Balanced Serverless endpoints. @@ -129,108 +118,6 @@ def _validate_lb_configuration(self) -> None: f"LoadBalancerSlsResource type must be LB, got {self.type.value}" ) - async def is_deployed_async(self) -> bool: - """ - Check if LB endpoint is deployed and /ping endpoint is responding. - - For LB endpoints, we verify: - 1. Endpoint ID exists (created in RunPod) - 2. /ping endpoint returns 200 or 204 - 3. Endpoint is in healthy state - - Returns: - True if endpoint is deployed and healthy, False otherwise - """ - try: - if not self.id: - return False - - # Use async health check for LB endpoints - return await self._check_ping_endpoint() - - except Exception as e: - log.debug(f"Error checking {self}: {e}") - return False - - async def _check_ping_endpoint(self) -> bool: - """ - Check if /ping endpoint is accessible and healthy. - - RunPod load-balancer endpoints require a /ping endpoint that returns: - - 200 OK: Worker is healthy and ready - - 204 No Content: Worker is initializing - - Other status: Worker is unhealthy - - Returns: - True if /ping endpoint responds with 200 or 204 - """ - try: - if not self.id: - return False - - ping_url = f"{self.endpoint_url}/ping" - - async with get_authenticated_httpx_client( - timeout=DEFAULT_PING_REQUEST_TIMEOUT - ) as client: - response = await client.get(ping_url) - return response.status_code in HEALTHY_STATUS_CODES - except Exception as e: - log.debug(f"Ping check failed for {self.name}: {e}") - return False - - async def _wait_for_health( - self, - max_retries: int = DEFAULT_HEALTH_CHECK_RETRIES, - retry_interval: int = DEFAULT_HEALTH_CHECK_INTERVAL, - ) -> bool: - """ - Poll /ping endpoint until endpoint is healthy or timeout. - - Args: - max_retries: Number of health check attempts - retry_interval: Seconds between health check attempts - - Returns: - True if endpoint became healthy, False if timeout - - Raises: - ValueError: If endpoint ID not set - """ - if not self.id: - raise ValueError("Cannot wait for health: endpoint not deployed") - - log.debug( - f"Waiting for LB endpoint {self.name} ({self.id}) to become healthy... " - f"(max {max_retries} retries, {retry_interval}s interval)" - ) - - for attempt in range(max_retries): - try: - if await self._check_ping_endpoint(): - log.debug( - f"LB endpoint {self.name} is healthy (attempt {attempt + 1})" - ) - return True - - log.debug( - f"Health check attempt {attempt + 1}/{max_retries} - " - f"endpoint not ready yet" - ) - - except Exception as e: - log.debug(f"Health check attempt {attempt + 1} failed: {e}") - - # Wait before next attempt (except on last attempt) - if attempt < max_retries - 1: - await asyncio.sleep(retry_interval) - - log.debug( - f"LB endpoint {self.name} failed to become healthy after " - f"{max_retries} attempts" - ) - return False - async def _do_deploy(self) -> "LoadBalancerSlsResource": """ Deploy LB endpoint without blocking on health checks. diff --git a/src/runpod_flash/core/resources/resource_manager.py b/src/runpod_flash/core/resources/resource_manager.py index d6cd4241..2d51e593 100644 --- a/src/runpod_flash/core/resources/resource_manager.py +++ b/src/runpod_flash/core/resources/resource_manager.py @@ -219,9 +219,6 @@ async def _deploy_with_error_context( """ return await config._do_deploy() - async def get_resource_from_store(self, uid: str): - return self._resources.get(uid) - async def get_or_deploy_resource( self, config: DeployableResource ) -> DeployableResource: diff --git a/src/runpod_flash/core/utils/file_lock.py b/src/runpod_flash/core/utils/file_lock.py index b1866c34..bfa2dcb9 100644 --- a/src/runpod_flash/core/utils/file_lock.py +++ b/src/runpod_flash/core/utils/file_lock.py @@ -245,13 +245,3 @@ def _release_fallback_lock(file_handle: BinaryIO) -> None: except Exception as e: log.error(f"Failed to remove fallback lock file: {e}") - - -def get_platform_info() -> dict: - """Get information about current platform and available locking mechanisms.""" - return { - "platform": platform.system(), - "windows_locking": _IS_WINDOWS and _WINDOWS_LOCKING_AVAILABLE, - "unix_locking": _IS_UNIX and _UNIX_LOCKING_AVAILABLE, - "fallback_only": not (_WINDOWS_LOCKING_AVAILABLE or _UNIX_LOCKING_AVAILABLE), - } diff --git a/src/runpod_flash/logger.py b/src/runpod_flash/logger.py index 88283edc..9dcf7d04 100644 --- a/src/runpod_flash/logger.py +++ b/src/runpod_flash/logger.py @@ -50,11 +50,6 @@ class SensitiveDataFilter(logging.Filter): "authorization", } - # Pattern for generic tokens: 32+ char alphanumeric/underscore/hyphen/dot strings - # Excludes pure hex strings (commit SHAs, hashes) which are less likely to be tokens - # Using negative lookahead to exclude pure hex: (?![0-9a-fA-F]+$) - TOKEN_PATTERN = re.compile(r"(?![0-9a-fA-F]+$)\b[A-Za-z0-9_.-]{32,}\b") - # Pattern for common API key formats - capture prefix, separator, and quotes for proper redaction API_KEY_PATTERN = re.compile( r"((?:api[_-]?key|apikey|runpod[_-]?api[_-]?key)\s*[:=]\s*['\"]?)([A-Za-z0-9_-]+)(['\"]?)", @@ -136,18 +131,7 @@ def _redact_string(self, text: str) -> str: # Redact common prefixed API keys (sk-, key_, api_) text = self.PREFIXED_KEY_PATTERN.sub(self._redact_token, text) - # Generic token pattern disabled - causes false positives with Job IDs, Template IDs, etc. - # Specific patterns above catch actual sensitive tokens. - # text = self.TOKEN_PATTERN.sub(self._redact_token, text) - # Redact common password/secret patterns - # Match field names with : or = separators and redact the value, preserving separator - # Handles quoted values (captures until closing quote) and unquoted values (captures until whitespace/comma) - def redact_password_pattern(match): - field_name = match.group(1) - separator = match.group(2) - return f"{field_name}{separator}***REDACTED***" - # Pattern handles: password="value", password=value, password: value, etc. # For quoted values: captures everything until closing quote # For unquoted: captures until whitespace or comma diff --git a/src/runpod_flash/runtime/circuit_breaker.py b/src/runpod_flash/runtime/circuit_breaker.py deleted file mode 100644 index c24d4f69..00000000 --- a/src/runpod_flash/runtime/circuit_breaker.py +++ /dev/null @@ -1,274 +0,0 @@ -"""Circuit breaker pattern for handling endpoint failures.""" - -import asyncio -import logging -import time -from dataclasses import dataclass, field -from datetime import datetime, timezone -from enum import Enum -from typing import Any, Callable, Optional - -logger = logging.getLogger(__name__) - - -class CircuitState(Enum): - """Circuit breaker state machine.""" - - CLOSED = "closed" - OPEN = "open" - HALF_OPEN = "half_open" - - -@dataclass -class CircuitBreakerStats: - """Statistics for a circuit breaker instance.""" - - state: CircuitState = CircuitState.CLOSED - failure_count: int = 0 - success_count: int = 0 - last_failure_at: Optional[datetime] = None - last_success_at: Optional[datetime] = None - state_changed_at: datetime = field( - default_factory=lambda: datetime.now(timezone.utc) - ) - total_requests: int = 0 - total_failures: int = 0 - - -class EndpointCircuitBreaker: - """Circuit breaker for a single endpoint with sliding window.""" - - def __init__( - self, - endpoint_url: str, - failure_threshold: int = 5, - success_threshold: int = 2, - timeout_seconds: int = 60, - window_size: int = 10, - ): - """Initialize circuit breaker for an endpoint. - - Args: - endpoint_url: URL of the endpoint to protect - failure_threshold: Failures required to open circuit - success_threshold: Successes required to close circuit - timeout_seconds: Time before attempting recovery - window_size: Size of sliding window for counting failures - """ - self.endpoint_url = endpoint_url - self.failure_threshold = failure_threshold - self.success_threshold = success_threshold - self.timeout_seconds = timeout_seconds - self.window_size = window_size - self.stats = CircuitBreakerStats() - self._lock = asyncio.Lock() - self._failure_times: list[float] = [] - - async def execute(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: - """Execute function with circuit breaker protection. - - Args: - func: Async function to execute - *args: Positional arguments for func - **kwargs: Keyword arguments for func - - Returns: - Result from func - - Raises: - CircuitBreakerOpenError: If circuit is open - Exception: Any exception raised by func - """ - async with self._lock: - state = self.stats.state - - if state == CircuitState.OPEN: - # Check if timeout has passed - if self._should_attempt_recovery(): - self._transition_to_half_open() - else: - raise CircuitBreakerOpenError( - f"Circuit OPEN for {self.endpoint_url}. " - f"Retry in {self._seconds_until_recovery()}s" - ) - - # Execute function - try: - result = await func(*args, **kwargs) - await self._on_success() - return result - except Exception as e: - await self._on_failure(e) - raise - - async def _on_success(self) -> None: - """Record successful request.""" - async with self._lock: - self.stats.success_count += 1 - self.stats.total_requests += 1 - self.stats.last_success_at = datetime.now(timezone.utc) - - logger.debug( - f"Circuit breaker {self.endpoint_url}: " - f"success {self.stats.success_count}/{self.success_threshold}" - ) - - if self.stats.state == CircuitState.HALF_OPEN: - # Close circuit after successes - if self.stats.success_count >= self.success_threshold: - self._transition_to_closed() - elif self.stats.state == CircuitState.CLOSED: - # Reset failure count on success - self.stats.failure_count = 0 - self._failure_times.clear() - - async def _on_failure(self, error: Exception) -> None: - """Record failed request.""" - async with self._lock: - self.stats.failure_count += 1 - self.stats.total_failures += 1 - self.stats.total_requests += 1 - self.stats.last_failure_at = datetime.now(timezone.utc) - - # Track failure times for sliding window - now = time.time() - self._failure_times.append(now) - - # Keep only failures within window - cutoff = now - self.timeout_seconds - self._failure_times = [t for t in self._failure_times if t > cutoff] - - logger.debug( - f"Circuit breaker {self.endpoint_url}: " - f"failure {self.stats.failure_count}/{self.failure_threshold}, " - f"error: {error}" - ) - - if self.stats.state == CircuitState.HALF_OPEN: - # Open circuit on first failure in half-open - self._transition_to_open() - elif self.stats.state == CircuitState.CLOSED: - # Open circuit if threshold reached - if len(self._failure_times) >= self.failure_threshold: - self._transition_to_open() - - def _transition_to_open(self) -> None: - """Transition circuit to OPEN state.""" - if self.stats.state == CircuitState.OPEN: - return # Already open - self.stats.state = CircuitState.OPEN - self.stats.state_changed_at = datetime.now(timezone.utc) - self.stats.success_count = 0 - logger.warning( - f"Circuit breaker OPEN for {self.endpoint_url} " - f"after {self.stats.failure_count} failures" - ) - - def _transition_to_half_open(self) -> None: - """Transition circuit to HALF_OPEN state.""" - self.stats.state = CircuitState.HALF_OPEN - self.stats.state_changed_at = datetime.now(timezone.utc) - self.stats.failure_count = 0 - self.stats.success_count = 0 - logger.info( - f"Circuit breaker HALF_OPEN for {self.endpoint_url}, testing recovery" - ) - - def _transition_to_closed(self) -> None: - """Transition circuit to CLOSED state.""" - self.stats.state = CircuitState.CLOSED - self.stats.state_changed_at = datetime.now(timezone.utc) - self.stats.failure_count = 0 - self.stats.success_count = 0 - self._failure_times.clear() - logger.info(f"Circuit breaker CLOSED for {self.endpoint_url}, recovered") - - def _should_attempt_recovery(self) -> bool: - """Check if enough time has passed to attempt recovery.""" - if not self.stats.last_failure_at: - return False - elapsed = datetime.now(timezone.utc) - self.stats.state_changed_at - return elapsed.total_seconds() >= self.timeout_seconds - - def _seconds_until_recovery(self) -> int: - """Get seconds until recovery can be attempted.""" - if not self.stats.state_changed_at: - return self.timeout_seconds - elapsed = datetime.now(timezone.utc) - self.stats.state_changed_at - remaining = self.timeout_seconds - int(elapsed.total_seconds()) - return max(0, remaining) - - def get_state(self) -> CircuitState: - """Get current circuit state.""" - return self.stats.state - - def get_stats(self) -> CircuitBreakerStats: - """Get circuit breaker statistics.""" - return self.stats - - -class CircuitBreakerRegistry: - """Manages circuit breakers for multiple endpoints.""" - - def __init__( - self, - failure_threshold: int = 5, - success_threshold: int = 2, - timeout_seconds: int = 60, - ): - """Initialize circuit breaker registry. - - Args: - failure_threshold: Failures required to open circuit - success_threshold: Successes required to close circuit - timeout_seconds: Time before attempting recovery - """ - self.failure_threshold = failure_threshold - self.success_threshold = success_threshold - self.timeout_seconds = timeout_seconds - self._breakers: dict[str, EndpointCircuitBreaker] = {} - self._lock = asyncio.Lock() - - def get_breaker(self, endpoint_url: str) -> EndpointCircuitBreaker: - """Get or create circuit breaker for endpoint. - - Args: - endpoint_url: URL of the endpoint - - Returns: - EndpointCircuitBreaker instance - """ - if endpoint_url not in self._breakers: - self._breakers[endpoint_url] = EndpointCircuitBreaker( - endpoint_url, - failure_threshold=self.failure_threshold, - success_threshold=self.success_threshold, - timeout_seconds=self.timeout_seconds, - ) - return self._breakers[endpoint_url] - - def get_state(self, endpoint_url: str) -> CircuitState: - """Get state of circuit breaker for endpoint. - - Args: - endpoint_url: URL of the endpoint - - Returns: - Current circuit state - """ - breaker = self.get_breaker(endpoint_url) - return breaker.get_state() - - def get_all_stats(self) -> dict[str, CircuitBreakerStats]: - """Get statistics for all circuit breakers. - - Returns: - Mapping of endpoint URLs to statistics - """ - return {url: breaker.get_stats() for url, breaker in self._breakers.items()} - - -class CircuitBreakerOpenError(Exception): - """Raised when circuit breaker is open.""" - - pass diff --git a/src/runpod_flash/runtime/exceptions.py b/src/runpod_flash/runtime/exceptions.py index 520f129f..41ebf583 100644 --- a/src/runpod_flash/runtime/exceptions.py +++ b/src/runpod_flash/runtime/exceptions.py @@ -49,12 +49,6 @@ class GraphQLQueryError(GraphQLError): pass -class ManifestError(FlashRuntimeError): - """Raised when manifest is invalid, missing, or has unexpected structure.""" - - pass - - class ManifestServiceUnavailableError(FlashRuntimeError): """Raised when manifest service is unavailable.""" diff --git a/src/runpod_flash/runtime/load_balancer.py b/src/runpod_flash/runtime/load_balancer.py deleted file mode 100644 index 0c4b6f44..00000000 --- a/src/runpod_flash/runtime/load_balancer.py +++ /dev/null @@ -1,151 +0,0 @@ -"""Load balancing strategies for distributed endpoint routing.""" - -import asyncio -import logging -import random -from typing import TYPE_CHECKING, List, Optional - -from runpod_flash.runtime.reliability_config import LoadBalancerStrategy - -if TYPE_CHECKING: - from runpod_flash.runtime.circuit_breaker import CircuitBreakerRegistry - -logger = logging.getLogger(__name__) - - -class LoadBalancer: - """Load balancer for selecting endpoints using various strategies.""" - - def __init__( - self, strategy: LoadBalancerStrategy = LoadBalancerStrategy.ROUND_ROBIN - ): - """Initialize load balancer. - - Args: - strategy: Load balancing strategy to use - """ - self.strategy = strategy - self._round_robin_index = 0 - self._lock = asyncio.Lock() - self._in_flight_requests: dict[str, int] = {} - - async def select_endpoint( - self, - endpoints: List[str], - circuit_breaker_registry: Optional["CircuitBreakerRegistry"] = None, - ) -> Optional[str]: - """Select an endpoint using configured strategy. - - Args: - endpoints: List of available endpoint URLs - circuit_breaker_registry: Optional circuit breaker registry to check health - - Returns: - Selected endpoint URL or None if all endpoints are unhealthy - """ - if not endpoints: - return None - - # Filter out unhealthy endpoints if circuit breaker available - healthy_endpoints = endpoints - if circuit_breaker_registry is not None: - from runpod_flash.runtime.circuit_breaker import CircuitState - - healthy_endpoints = [ - url - for url in endpoints - if circuit_breaker_registry.get_state(url) != CircuitState.OPEN - ] - - if not healthy_endpoints: - logger.warning( - f"All {len(endpoints)} endpoints are unhealthy (circuit open)" - ) - return None - - if self.strategy == LoadBalancerStrategy.ROUND_ROBIN: - return await self._round_robin_select(healthy_endpoints) - elif self.strategy == LoadBalancerStrategy.LEAST_CONNECTIONS: - return await self._least_connections_select(healthy_endpoints) - elif self.strategy == LoadBalancerStrategy.RANDOM: - return await self._random_select(healthy_endpoints) - else: - # Default to round-robin - return await self._round_robin_select(healthy_endpoints) - - async def _round_robin_select(self, endpoints: List[str]) -> str: - """Select endpoint using round-robin strategy. - - Args: - endpoints: List of available endpoints - - Returns: - Selected endpoint URL - """ - async with self._lock: - selected = endpoints[self._round_robin_index % len(endpoints)] - self._round_robin_index += 1 - return selected - - async def _least_connections_select(self, endpoints: List[str]) -> str: - """Select endpoint with fewest in-flight requests. - - Args: - endpoints: List of available endpoints - - Returns: - Selected endpoint URL - """ - async with self._lock: - # Initialize counts for endpoints - for endpoint in endpoints: - if endpoint not in self._in_flight_requests: - self._in_flight_requests[endpoint] = 0 - - # Find endpoint with minimum connections - selected = min(endpoints, key=lambda e: self._in_flight_requests.get(e, 0)) - - return selected - - async def _random_select(self, endpoints: List[str]) -> str: - """Select endpoint using random strategy. - - Args: - endpoints: List of available endpoints - - Returns: - Selected endpoint URL - """ - selected = random.choice(endpoints) - return selected - - async def record_request(self, endpoint: str) -> None: - """Record that a request is starting on endpoint. - - Args: - endpoint: Endpoint URL - """ - async with self._lock: - self._in_flight_requests[endpoint] = ( - self._in_flight_requests.get(endpoint, 0) + 1 - ) - - async def record_request_complete(self, endpoint: str) -> None: - """Record that a request completed on endpoint. - - Args: - endpoint: Endpoint URL - """ - async with self._lock: - if endpoint in self._in_flight_requests: - self._in_flight_requests[endpoint] = max( - 0, self._in_flight_requests[endpoint] - 1 - ) - - def get_stats(self) -> dict[str, int]: - """Get current in-flight request counts. - - Returns: - Mapping of endpoint URLs to in-flight request counts - """ - return dict(self._in_flight_requests) diff --git a/src/runpod_flash/runtime/metrics.py b/src/runpod_flash/runtime/metrics.py deleted file mode 100644 index 7095d9e8..00000000 --- a/src/runpod_flash/runtime/metrics.py +++ /dev/null @@ -1,325 +0,0 @@ -"""Metrics collection via structured logging for observability.""" - -import logging -from dataclasses import asdict, dataclass -from enum import Enum -from typing import Any, Dict, Optional - -logger = logging.getLogger(__name__) - - -class MetricType(Enum): - """Types of metrics that can be collected.""" - - COUNTER = "counter" - GAUGE = "gauge" - HISTOGRAM = "histogram" - - -@dataclass -class Metric: - """Representation of a single metric.""" - - metric_type: MetricType - metric_name: str - value: float - labels: Dict[str, Any] - - def to_dict(self) -> Dict[str, Any]: - """Convert metric to dictionary. - - Returns: - Dictionary representation of metric - """ - return asdict(self) - - -class MetricsCollector: - """Collect metrics via structured logging.""" - - def __init__(self, namespace: str = "flash.metrics", enabled: bool = True): - """Initialize metrics collector. - - Args: - namespace: Namespace for metrics (used in structured logging) - enabled: Whether metrics collection is enabled - """ - self.namespace = namespace - self.enabled = enabled - - def counter( - self, - name: str, - value: float = 1.0, - labels: Optional[Dict[str, Any]] = None, - ) -> None: - """Record a counter metric (cumulative). - - Args: - name: Name of the metric - value: Value to add to counter (default: 1.0) - labels: Optional labels/tags for the metric - """ - if not self.enabled: - return - - metric = Metric(MetricType.COUNTER, name, value, labels or {}) - self._emit(metric) - - def gauge( - self, - name: str, - value: float, - labels: Optional[Dict[str, Any]] = None, - ) -> None: - """Record a gauge metric (point-in-time value). - - Args: - name: Name of the metric - value: Current value of the gauge - labels: Optional labels/tags for the metric - """ - if not self.enabled: - return - - metric = Metric(MetricType.GAUGE, name, value, labels or {}) - self._emit(metric) - - def histogram( - self, - name: str, - value: float, - labels: Optional[Dict[str, Any]] = None, - ) -> None: - """Record a histogram metric (distribution). - - Args: - name: Name of the metric - value: Value to add to histogram - labels: Optional labels/tags for the metric - """ - if not self.enabled: - return - - metric = Metric(MetricType.HISTOGRAM, name, value, labels or {}) - self._emit(metric) - - def _emit(self, metric: Metric) -> None: - """Emit metric via structured logging. - - Args: - metric: Metric to emit - """ - try: - logger.info( - f"[METRIC] {metric.metric_name}={metric.value}", - extra={ - "namespace": self.namespace, - "metric": metric.to_dict(), - }, - ) - except Exception as e: - logger.error(f"Failed to emit metric {metric.metric_name}: {e}") - - -# Global metrics collector instance -_collector: Optional[MetricsCollector] = None - - -def get_metrics_collector( - namespace: str = "flash.metrics", enabled: bool = True -) -> MetricsCollector: - """Get global metrics collector (lazy-loaded). - - Args: - namespace: Namespace for metrics - enabled: Whether metrics collection is enabled - - Returns: - MetricsCollector instance - """ - global _collector - if _collector is None: - _collector = MetricsCollector(namespace=namespace, enabled=enabled) - return _collector - - -def set_metrics_collector(collector: MetricsCollector) -> None: - """Set global metrics collector (for testing). - - Args: - collector: MetricsCollector instance - """ - global _collector - _collector = collector - - -class CircuitBreakerMetrics: - """Helper for emitting circuit breaker metrics.""" - - def __init__(self, collector: Optional[MetricsCollector] = None): - """Initialize circuit breaker metrics helper. - - Args: - collector: Optional MetricsCollector instance (uses global if not provided) - """ - self.collector = collector or get_metrics_collector() - - def state_changed( - self, endpoint_url: str, new_state: str, previous_state: str - ) -> None: - """Emit metric when circuit breaker state changes. - - Args: - endpoint_url: URL of the endpoint - new_state: New circuit state - previous_state: Previous circuit state - """ - self.collector.counter( - "circuit_breaker_state_changes", - value=1.0, - labels={ - "endpoint_url": endpoint_url, - "new_state": new_state, - "previous_state": previous_state, - }, - ) - - def endpoint_requests(self, endpoint_url: str, status: str, count: int = 1) -> None: - """Emit metric for endpoint requests. - - Args: - endpoint_url: URL of the endpoint - status: Request status (success, failure, etc.) - count: Number of requests - """ - self.collector.counter( - "endpoint_requests", - value=float(count), - labels={"endpoint_url": endpoint_url, "status": status}, - ) - - def endpoint_latency(self, endpoint_url: str, latency_ms: float) -> None: - """Emit metric for endpoint latency. - - Args: - endpoint_url: URL of the endpoint - latency_ms: Latency in milliseconds - """ - self.collector.histogram( - "endpoint_latency", - value=latency_ms, - labels={"endpoint_url": endpoint_url}, - ) - - def in_flight_requests(self, endpoint_url: str, count: int) -> None: - """Emit metric for in-flight requests. - - Args: - endpoint_url: URL of the endpoint - count: Current number of in-flight requests - """ - self.collector.gauge( - "in_flight_requests", - value=float(count), - labels={"endpoint_url": endpoint_url}, - ) - - -class RetryMetrics: - """Helper for emitting retry metrics.""" - - def __init__(self, collector: Optional[MetricsCollector] = None): - """Initialize retry metrics helper. - - Args: - collector: Optional MetricsCollector instance (uses global if not provided) - """ - self.collector = collector or get_metrics_collector() - - def retry_attempt( - self, function_name: str, attempt: int, error: Optional[str] = None - ) -> None: - """Emit metric for retry attempt. - - Args: - function_name: Name of the function being retried - attempt: Attempt number - error: Optional error message - """ - labels = { - "function_name": function_name, - "attempt": str(attempt), - } - if error: - labels["error"] = error - - self.collector.counter( - "retry_attempts", - value=1.0, - labels=labels, - ) - - def retry_success(self, function_name: str, total_attempts: int) -> None: - """Emit metric for successful retry. - - Args: - function_name: Name of the function - total_attempts: Total attempts made before success - """ - self.collector.counter( - "retry_success", - value=1.0, - labels={ - "function_name": function_name, - "attempts": str(total_attempts), - }, - ) - - def retry_exhausted(self, function_name: str, max_attempts: int) -> None: - """Emit metric when max retries exceeded. - - Args: - function_name: Name of the function - max_attempts: Maximum attempts configured - """ - self.collector.counter( - "retry_exhausted", - value=1.0, - labels={ - "function_name": function_name, - "max_attempts": str(max_attempts), - }, - ) - - -class LoadBalancerMetrics: - """Helper for emitting load balancer metrics.""" - - def __init__(self, collector: Optional[MetricsCollector] = None): - """Initialize load balancer metrics helper. - - Args: - collector: Optional MetricsCollector instance (uses global if not provided) - """ - self.collector = collector or get_metrics_collector() - - def endpoint_selected( - self, strategy: str, endpoint_url: str, total_candidates: int - ) -> None: - """Emit metric when endpoint is selected. - - Args: - strategy: Load balancing strategy used - endpoint_url: Selected endpoint URL - total_candidates: Total candidate endpoints - """ - self.collector.counter( - "load_balancer_selection", - value=1.0, - labels={ - "strategy": strategy, - "endpoint_url": endpoint_url, - "candidates": str(total_candidates), - }, - ) diff --git a/src/runpod_flash/runtime/reliability_config.py b/src/runpod_flash/runtime/reliability_config.py deleted file mode 100644 index ecae6440..00000000 --- a/src/runpod_flash/runtime/reliability_config.py +++ /dev/null @@ -1,149 +0,0 @@ -"""Centralized configuration for reliability features.""" - -import os -from dataclasses import dataclass, field -from enum import Enum -from typing import Optional - - -class LoadBalancerStrategy(Enum): - """Load balancing strategies for endpoint selection.""" - - ROUND_ROBIN = "round_robin" - LEAST_CONNECTIONS = "least_connections" - RANDOM = "random" - - -@dataclass -class CircuitBreakerConfig: - """Configuration for circuit breaker behavior.""" - - enabled: bool = True - failure_threshold: int = 5 - success_threshold: int = 2 - timeout_seconds: int = 60 - window_size: int = 10 - - -@dataclass -class LoadBalancerConfig: - """Configuration for load balancer behavior.""" - - enabled: bool = False - strategy: LoadBalancerStrategy = LoadBalancerStrategy.ROUND_ROBIN - - -@dataclass -class RetryConfig: - """Configuration for retry behavior with exponential backoff.""" - - enabled: bool = True - max_attempts: int = 3 - base_delay: float = 0.5 - max_delay: float = 10.0 - jitter: float = 0.2 - retryable_exceptions: tuple = field( - default_factory=lambda: (TimeoutError, ConnectionError) - ) - retryable_status_codes: set = field( - default_factory=lambda: {408, 429, 500, 502, 503, 504} - ) - - -@dataclass -class MetricsConfig: - """Configuration for metrics collection.""" - - enabled: bool = True - namespace: str = "flash.metrics" - - -@dataclass -class ReliabilityConfig: - """Centralized reliability features configuration.""" - - circuit_breaker: CircuitBreakerConfig = field(default_factory=CircuitBreakerConfig) - load_balancer: LoadBalancerConfig = field(default_factory=LoadBalancerConfig) - retry: RetryConfig = field(default_factory=RetryConfig) - metrics: MetricsConfig = field(default_factory=MetricsConfig) - - @classmethod - def from_env(cls) -> "ReliabilityConfig": - """Load configuration from environment variables. - - Environment variables: - - FLASH_CIRCUIT_BREAKER_ENABLED: Enable circuit breaker (default: true) - - FLASH_CB_FAILURE_THRESHOLD: Failures before opening (default: 5) - - FLASH_CB_SUCCESS_THRESHOLD: Successes to close (default: 2) - - FLASH_CB_TIMEOUT_SECONDS: Time before half-open (default: 60) - - FLASH_LOAD_BALANCER_ENABLED: Enable load balancer (default: false) - - FLASH_LB_STRATEGY: Load balancer strategy (default: round_robin) - - FLASH_RETRY_ENABLED: Enable retry (default: true) - - FLASH_RETRY_MAX_ATTEMPTS: Max retry attempts (default: 3) - - FLASH_RETRY_BASE_DELAY: Base delay for backoff (default: 0.5) - - FLASH_METRICS_ENABLED: Enable metrics (default: true) - - Returns: - ReliabilityConfig initialized from environment variables. - """ - circuit_breaker = CircuitBreakerConfig( - enabled=os.getenv("FLASH_CIRCUIT_BREAKER_ENABLED", "true").lower() - == "true", - failure_threshold=int(os.getenv("FLASH_CB_FAILURE_THRESHOLD", "5")), - success_threshold=int(os.getenv("FLASH_CB_SUCCESS_THRESHOLD", "2")), - timeout_seconds=int(os.getenv("FLASH_CB_TIMEOUT_SECONDS", "60")), - ) - - strategy_str = os.getenv("FLASH_LB_STRATEGY", "round_robin").lower() - try: - strategy = LoadBalancerStrategy(strategy_str) - except ValueError: - strategy = LoadBalancerStrategy.ROUND_ROBIN - - load_balancer = LoadBalancerConfig( - enabled=os.getenv("FLASH_LOAD_BALANCER_ENABLED", "false").lower() == "true", - strategy=strategy, - ) - - retry = RetryConfig( - enabled=os.getenv("FLASH_RETRY_ENABLED", "true").lower() == "true", - max_attempts=int(os.getenv("FLASH_RETRY_MAX_ATTEMPTS", "3")), - base_delay=float(os.getenv("FLASH_RETRY_BASE_DELAY", "0.5")), - ) - - metrics = MetricsConfig( - enabled=os.getenv("FLASH_METRICS_ENABLED", "true").lower() == "true", - ) - - return cls( - circuit_breaker=circuit_breaker, - load_balancer=load_balancer, - retry=retry, - metrics=metrics, - ) - - -# Global default configuration -_config: Optional[ReliabilityConfig] = None - - -def get_reliability_config() -> ReliabilityConfig: - """Get global reliability configuration (lazy-loaded). - - Returns: - ReliabilityConfig instance initialized from environment. - """ - global _config - if _config is None: - _config = ReliabilityConfig.from_env() - return _config - - -def set_reliability_config(config: ReliabilityConfig) -> None: - """Set global reliability configuration (for testing). - - Args: - config: ReliabilityConfig to set as global. - """ - global _config - _config = config diff --git a/src/runpod_flash/runtime/retry_manager.py b/src/runpod_flash/runtime/retry_manager.py deleted file mode 100644 index bc43d937..00000000 --- a/src/runpod_flash/runtime/retry_manager.py +++ /dev/null @@ -1,118 +0,0 @@ -"""Retry logic with exponential backoff for failed remote calls.""" - -import asyncio -import logging -from typing import Any, Callable, Optional, Set, Tuple, Type - -from runpod_flash.core.utils.backoff import get_backoff_delay - -logger = logging.getLogger(__name__) - - -class RetryExhaustedError(Exception): - """Raised when max retry attempts are exceeded.""" - - pass - - -async def retry_with_backoff( - func: Callable[..., Any], - max_attempts: int = 3, - base_delay: float = 0.5, - max_delay: float = 10.0, - jitter: float = 0.2, - retryable_exceptions: Optional[Tuple[Type[Exception], ...]] = None, - retryable_status_codes: Optional[Set[int]] = None, - circuit_breaker: Optional[Any] = None, - *args: Any, - **kwargs: Any, -) -> Any: - """Execute async function with retry and exponential backoff. - - Args: - func: Async function to execute - max_attempts: Maximum number of attempts (default: 3) - base_delay: Base delay between retries in seconds (default: 0.5) - max_delay: Maximum delay between retries (default: 10.0) - jitter: Jitter factor (0.0-1.0) to add randomness (default: 0.2) - retryable_exceptions: Tuple of exception types to retry on - (default: (asyncio.TimeoutError, ConnectionError)) - retryable_status_codes: Set of HTTP status codes to retry on - (default: {408, 429, 500, 502, 503, 504}) - circuit_breaker: Optional circuit breaker to check before retry - *args: Positional arguments for func - **kwargs: Keyword arguments for func - - Returns: - Result from successful function call - - Raises: - RetryExhaustedError: If max attempts exceeded - Exception: If non-retryable exception occurs - """ - if retryable_exceptions is None: - retryable_exceptions = (asyncio.TimeoutError, ConnectionError) - - if retryable_status_codes is None: - retryable_status_codes = {408, 429, 500, 502, 503, 504} - - last_exception: Optional[Exception] = None - - for attempt in range(max_attempts): - try: - # Check circuit breaker before attempting - if circuit_breaker is not None: - from runpod_flash.runtime.circuit_breaker import CircuitState - - if circuit_breaker.get_state() == CircuitState.OPEN: - raise RuntimeError( - f"Circuit breaker OPEN, skipping retry attempt {attempt + 1}" - ) - - result = await func(*args, **kwargs) - - # Log success on retry - if attempt > 0: - logger.info(f"Retry succeeded on attempt {attempt + 1}/{max_attempts}") - - return result - - except Exception as e: - last_exception = e - - # Check if exception is retryable - if not isinstance(e, retryable_exceptions): - logger.debug( - f"Non-retryable exception in {func.__name__}: {type(e).__name__}" - ) - raise - - # Check for retryable status codes (if exception has status_code) - if hasattr(e, "status_code"): - if e.status_code not in retryable_status_codes: # type: ignore - logger.debug( - f"Non-retryable status code {e.status_code} in {func.__name__}" - ) - raise - - # If this is the last attempt, don't retry - if attempt >= max_attempts - 1: - logger.warning( - f"Max retries ({max_attempts}) exhausted for {func.__name__}" - ) - raise RetryExhaustedError( - f"Failed after {max_attempts} attempts: {e}" - ) from e - - # Calculate delay with exponential backoff and jitter - delay = get_backoff_delay(attempt, base_delay, max_delay, jitter=jitter) - logger.debug( - f"Retry {attempt + 1}/{max_attempts} for {func.__name__} " - f"after {delay:.2f}s" - ) - await asyncio.sleep(delay) - - # Should never reach here, but handle edge case - if last_exception: - raise last_exception - raise RetryExhaustedError(f"Failed after {max_attempts} attempts") diff --git a/tests/unit/cli/commands/build_utils/test_manifest.py b/tests/unit/cli/commands/build_utils/test_manifest.py index 3a86c121..72efaa9b 100644 --- a/tests/unit/cli/commands/build_utils/test_manifest.py +++ b/tests/unit/cli/commands/build_utils/test_manifest.py @@ -1,6 +1,5 @@ """Tests for ManifestBuilder.""" -import json import sys import tempfile from pathlib import Path @@ -154,37 +153,6 @@ def test_build_manifest_includes_metadata(): assert test_class["is_class"] is True -def test_write_manifest_to_file(): - """Test writing manifest to file.""" - with tempfile.TemporaryDirectory() as tmpdir: - output_path = Path(tmpdir) / "flash_manifest.json" - - functions = [ - RemoteFunctionMetadata( - function_name="test_func", - module_path="workers.test", - resource_config_name="test_config", - resource_type="LiveServerless", - is_async=True, - is_class=False, - file_path=Path("workers/test.py"), - ) - ] - - builder = ManifestBuilder("test_app", functions) - result_path = builder.write_to_file(output_path) - - assert result_path.exists() - assert result_path == output_path - - # Read and verify content - with open(output_path) as f: - manifest = json.load(f) - - assert manifest["project_name"] == "test_app" - assert "test_config" in manifest["resources"] - - def test_manifest_empty_functions(): """Test building manifest with no functions.""" builder = ManifestBuilder("empty_app", []) diff --git a/tests/unit/cli/commands/test_build_helpers.py b/tests/unit/cli/commands/test_build_helpers.py index 3e46d262..8b8ec4eb 100644 --- a/tests/unit/cli/commands/test_build_helpers.py +++ b/tests/unit/cli/commands/test_build_helpers.py @@ -1,5 +1,5 @@ """Tests for build.py helper functions: _bundle_runpod_flash, -_remove_runpod_flash_from_requirements, _extract_runpod_flash_dependencies. +_remove_runpod_flash_from_requirements. These functions are always mocked in existing build tests; these tests exercise them directly. @@ -7,7 +7,6 @@ from runpod_flash.cli.commands.build import ( _bundle_runpod_flash, - _extract_runpod_flash_dependencies, _find_runpod_flash, _remove_runpod_flash_from_requirements, ) @@ -161,60 +160,6 @@ def test_keeps_other_packages(self, tmp_path): assert any("aiohttp" in line for line in lines) -class TestExtractRunpodFlashDependencies: - """Direct tests for _extract_runpod_flash_dependencies.""" - - def test_extracts_dependencies_from_pyproject(self, tmp_path): - """Extracts [project.dependencies] from pyproject.toml.""" - # Create flash_pkg_dir at src/runpod_flash - pkg_dir = tmp_path / "src" / "runpod_flash" - pkg_dir.mkdir(parents=True) - - # pyproject.toml is at project root (2 levels up from pkg_dir) - pyproject = tmp_path / "pyproject.toml" - pyproject.write_text( - '[project]\nname = "runpod-flash"\n' - 'dependencies = ["cloudpickle>=3.0", "pydantic>=2.0", "rich>=14.0"]\n' - ) - - deps = _extract_runpod_flash_dependencies(pkg_dir) - - assert len(deps) == 3 - assert "cloudpickle>=3.0" in deps - assert "pydantic>=2.0" in deps - assert "rich>=14.0" in deps - - def test_returns_empty_when_no_pyproject(self, tmp_path): - """Returns empty list when pyproject.toml doesn't exist.""" - pkg_dir = tmp_path / "src" / "runpod_flash" - pkg_dir.mkdir(parents=True) - - deps = _extract_runpod_flash_dependencies(pkg_dir) - assert deps == [] - - def test_returns_empty_on_parse_error(self, tmp_path): - """Returns empty list when pyproject.toml is invalid.""" - pkg_dir = tmp_path / "src" / "runpod_flash" - pkg_dir.mkdir(parents=True) - - pyproject = tmp_path / "pyproject.toml" - pyproject.write_text("this is not valid toml {{{{") - - deps = _extract_runpod_flash_dependencies(pkg_dir) - assert deps == [] - - def test_returns_empty_when_no_dependencies_key(self, tmp_path): - """Returns empty list when [project.dependencies] is missing.""" - pkg_dir = tmp_path / "src" / "runpod_flash" - pkg_dir.mkdir(parents=True) - - pyproject = tmp_path / "pyproject.toml" - pyproject.write_text('[project]\nname = "runpod-flash"\n') - - deps = _extract_runpod_flash_dependencies(pkg_dir) - assert deps == [] - - class TestFindRunpodFlashEdgeCases: """Edge cases for _find_runpod_flash not covered by existing tests.""" diff --git a/tests/unit/cli/commands/test_resource.py b/tests/unit/cli/commands/test_resource.py deleted file mode 100644 index cc2ab745..00000000 --- a/tests/unit/cli/commands/test_resource.py +++ /dev/null @@ -1,359 +0,0 @@ -"""Tests for resource management commands.""" - -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest - -from runpod_flash.cli.commands.resource import _render_resource_report, report_command - - -@pytest.fixture -def mock_resource_manager(): - """Provide mock resource manager.""" - manager = MagicMock() - manager._resources = {} - return manager - - -class TestGenerateResourceTableEmpty: - """Tests for _render_resource_report with empty resources.""" - - def test_empty_resources_returns_renderable(self, mock_resource_manager): - """Test that empty resources returns a renderable.""" - result = _render_resource_report(mock_resource_manager) - assert result is not None - - def test_empty_resources_no_error(self, mock_resource_manager): - """Test that empty resources doesn't raise error.""" - try: - _render_resource_report(mock_resource_manager) - except Exception as e: - pytest.fail(f"_render_resource_report raised {type(e).__name__}: {e}") - - -class TestGenerateResourceTableSingleResource: - """Tests for _render_resource_report with single resource.""" - - def test_single_active_resource_no_error(self, mock_resource_manager): - """Test with single active resource doesn't error.""" - resource = MagicMock() - resource.is_deployed = AsyncMock(return_value=True) - resource.__class__.__name__ = "ServerlessEndpoint" - resource.url = "https://example.com/endpoint-123" - - mock_resource_manager._resources = {"endpoint-001": resource} - - try: - result = _render_resource_report(mock_resource_manager) - assert result is not None - except Exception as e: - pytest.fail(f"_render_resource_report raised {type(e).__name__}: {e}") - - def test_single_inactive_resource_no_error(self, mock_resource_manager): - """Test with single inactive resource doesn't error.""" - resource = MagicMock() - resource.is_deployed = AsyncMock(return_value=False) - resource.__class__.__name__ = "ServerlessEndpoint" - resource.url = "https://example.com/endpoint-456" - - mock_resource_manager._resources = {"endpoint-002": resource} - - try: - result = _render_resource_report(mock_resource_manager) - assert result is not None - except Exception as e: - pytest.fail(f"_render_resource_report raised {type(e).__name__}: {e}") - - def test_resource_is_deployed_exception_handled(self, mock_resource_manager): - """Test handles is_deployed exception.""" - resource = MagicMock() - resource.is_deployed = AsyncMock(side_effect=Exception("Connection failed")) - resource.__class__.__name__ = "ServerlessEndpoint" - resource.url = "https://example.com/endpoint-789" - - mock_resource_manager._resources = {"endpoint-003": resource} - - try: - result = _render_resource_report(mock_resource_manager) - assert result is not None - except Exception as e: - pytest.fail(f"_render_resource_report raised {type(e).__name__}: {e}") - - def test_resource_without_url_attribute_handled(self, mock_resource_manager): - """Test resource without url attribute is handled.""" - resource = MagicMock(spec=["is_deployed", "__class__"]) - resource.is_deployed = AsyncMock(return_value=True) - resource.__class__.__name__ = "LoadBalancer" - - mock_resource_manager._resources = {"lb-001": resource} - - try: - result = _render_resource_report(mock_resource_manager) - assert result is not None - except Exception as e: - pytest.fail(f"_render_resource_report raised {type(e).__name__}: {e}") - - def test_resource_with_empty_url(self, mock_resource_manager): - """Test resource with empty string URL.""" - resource = MagicMock() - resource.is_deployed = AsyncMock(return_value=True) - resource.__class__.__name__ = "ServerlessEndpoint" - resource.url = "" - - mock_resource_manager._resources = {"endpoint-empty-url": resource} - - try: - result = _render_resource_report(mock_resource_manager) - assert result is not None - except Exception as e: - pytest.fail(f"_render_resource_report raised {type(e).__name__}: {e}") - - -class TestGenerateResourceTableMultipleResources: - """Tests for _render_resource_report with multiple resources.""" - - def test_multiple_resources_mixed_status_no_error(self, mock_resource_manager): - """Test with mixed statuses doesn't error.""" - active_resource = MagicMock() - active_resource.is_deployed = AsyncMock(return_value=True) - active_resource.__class__.__name__ = "ServerlessEndpoint" - active_resource.url = "https://api.example.com/active" - - inactive_resource = MagicMock() - inactive_resource.is_deployed = AsyncMock(return_value=False) - inactive_resource.__class__.__name__ = "ServerlessEndpoint" - inactive_resource.url = "https://api.example.com/inactive" - - mock_resource_manager._resources = { - "endpoint-001": active_resource, - "endpoint-002": inactive_resource, - } - - try: - result = _render_resource_report(mock_resource_manager) - assert result is not None - except Exception as e: - pytest.fail(f"_render_resource_report raised {type(e).__name__}: {e}") - - def test_multiple_resources_all_active_no_error(self, mock_resource_manager): - """Test with all active resources doesn't error.""" - resources = {} - for i in range(3): - resource = MagicMock() - resource.is_deployed = AsyncMock(return_value=True) - resource.__class__.__name__ = "ServerlessEndpoint" - resource.url = f"https://api.example.com/endpoint-{i}" - resources[f"endpoint-{i}"] = resource - - mock_resource_manager._resources = resources - - try: - result = _render_resource_report(mock_resource_manager) - assert result is not None - except Exception as e: - pytest.fail(f"_render_resource_report raised {type(e).__name__}: {e}") - - def test_long_resource_id_handling(self, mock_resource_manager): - """Test that long resource IDs are handled (truncated).""" - resource = MagicMock() - resource.is_deployed = AsyncMock(return_value=True) - resource.__class__.__name__ = "ServerlessEndpoint" - resource.url = "https://example.com" - - long_id = "a" * 30 - - mock_resource_manager._resources = {long_id: resource} - - try: - result = _render_resource_report(mock_resource_manager) - assert result is not None - except Exception as e: - pytest.fail(f"_render_resource_report raised {type(e).__name__}: {e}") - - def test_short_resource_id_no_error(self, mock_resource_manager): - """Test short resource IDs work.""" - resource = MagicMock() - resource.is_deployed = AsyncMock(return_value=True) - resource.__class__.__name__ = "ServerlessEndpoint" - resource.url = "https://example.com" - - mock_resource_manager._resources = {"endpoint-123": resource} - - try: - result = _render_resource_report(mock_resource_manager) - assert result is not None - except Exception as e: - pytest.fail(f"_render_resource_report raised {type(e).__name__}: {e}") - - -class TestGenerateResourceTableSummary: - """Tests for _render_resource_report summary calculation.""" - - def test_summary_all_active_no_error(self, mock_resource_manager): - """Test summary with all active resources.""" - resources = {} - for i in range(5): - resource = MagicMock() - resource.is_deployed = AsyncMock(return_value=True) - resource.__class__.__name__ = "ServerlessEndpoint" - resource.url = f"https://example.com/{i}" - resources[f"endpoint-{i}"] = resource - - mock_resource_manager._resources = resources - - try: - result = _render_resource_report(mock_resource_manager) - assert result is not None - except Exception as e: - pytest.fail(f"_render_resource_report raised {type(e).__name__}: {e}") - - def test_summary_mixed_status_no_error(self, mock_resource_manager): - """Test summary with mixed status resources.""" - resources = {} - - for i in range(2): - resource = MagicMock() - resource.is_deployed = AsyncMock(return_value=True) - resource.__class__.__name__ = "ServerlessEndpoint" - resource.url = f"https://example.com/active-{i}" - resources[f"endpoint-{i}"] = resource - - resource = MagicMock() - resource.is_deployed = AsyncMock(return_value=False) - resource.__class__.__name__ = "ServerlessEndpoint" - resource.url = "https://example.com/inactive" - resources["endpoint-2"] = resource - - for i in range(3, 5): - resource = MagicMock() - resource.is_deployed = AsyncMock(side_effect=Exception("Error")) - resource.__class__.__name__ = "ServerlessEndpoint" - resource.url = f"https://example.com/unknown-{i}" - resources[f"endpoint-{i}"] = resource - - mock_resource_manager._resources = resources - - try: - result = _render_resource_report(mock_resource_manager) - assert result is not None - except Exception as e: - pytest.fail(f"_render_resource_report raised {type(e).__name__}: {e}") - - -class TestGenerateResourceTableResourceTypes: - """Tests for different resource types.""" - - def test_various_resource_types_no_error(self, mock_resource_manager): - """Test displaying various resource types.""" - resource_types = [ - "ServerlessEndpoint", - "LoadBalancer", - "NetworkVolume", - "CustomResource", - ] - - resources = {} - for i, res_type in enumerate(resource_types): - resource = MagicMock() - resource.is_deployed = AsyncMock(return_value=True) - resource.__class__.__name__ = res_type - resource.url = f"https://example.com/{res_type.lower()}-{i}" - resources[f"res-{i}"] = resource - - mock_resource_manager._resources = resources - - try: - result = _render_resource_report(mock_resource_manager) - assert result is not None - except Exception as e: - pytest.fail(f"_render_resource_report raised {type(e).__name__}: {e}") - - -@patch("runpod_flash.cli.commands.resource.ResourceManager") -@patch("runpod_flash.cli.commands.resource.console") -def test_report_command_static_mode(mock_console, mock_resource_manager_class): - """Test report_command in static (non-live) mode.""" - mock_manager_instance = MagicMock() - mock_manager_instance._resources = {} - mock_resource_manager_class.return_value = mock_manager_instance - - report_command(live=False) - - mock_resource_manager_class.assert_called_once() - mock_console.print.assert_called_once() - - -@patch("runpod_flash.cli.commands.resource.time") -@patch("runpod_flash.cli.commands.resource.Live") -@patch("runpod_flash.cli.commands.resource.ResourceManager") -@patch("runpod_flash.cli.commands.resource.console") -def test_report_command_live_mode( - mock_console, mock_resource_manager_class, mock_live_class, mock_time -): - """Test report_command in live mode with keyboard interrupt.""" - mock_manager_instance = MagicMock() - mock_manager_instance._resources = {} - mock_resource_manager_class.return_value = mock_manager_instance - - mock_live_instance = MagicMock() - mock_live_class.return_value.__enter__ = MagicMock(return_value=mock_live_instance) - mock_live_class.return_value.__exit__ = MagicMock(return_value=False) - - call_count = [0] - - def sleep_side_effect(duration): - call_count[0] += 1 - if call_count[0] > 0: - raise KeyboardInterrupt() - - mock_time.sleep.side_effect = sleep_side_effect - - report_command(live=True, refresh=2) - - mock_live_class.assert_called_once() - assert any("stopped" in str(c).lower() for c in mock_console.print.call_args_list) - - -@patch("runpod_flash.cli.commands.resource.ResourceManager") -@patch("runpod_flash.cli.commands.resource.console") -def test_report_command_with_custom_refresh(mock_console, mock_resource_manager_class): - """Test report_command accepts custom refresh interval.""" - mock_manager_instance = MagicMock() - mock_manager_instance._resources = {} - mock_resource_manager_class.return_value = mock_manager_instance - - report_command(live=False, refresh=5) - - mock_resource_manager_class.assert_called_once() - mock_console.print.assert_called_once() - - -@patch("runpod_flash.cli.commands.resource.ResourceManager") -@patch("runpod_flash.cli.commands.resource.console") -def test_report_command_instantiates_resource_manager( - mock_console, mock_resource_manager_class -): - """Test that report_command instantiates ResourceManager.""" - mock_manager_instance = MagicMock() - mock_manager_instance._resources = {} - mock_resource_manager_class.return_value = mock_manager_instance - - report_command(live=False) - - mock_resource_manager_class.assert_called_once_with() - - -@patch("runpod_flash.cli.commands.resource._render_resource_report") -@patch("runpod_flash.cli.commands.resource.ResourceManager") -@patch("runpod_flash.cli.commands.resource.console") -def test_report_command_calls_render_report( - mock_console, mock_resource_manager_class, mock_render -): - """Test that report_command calls _render_resource_report.""" - mock_manager_instance = MagicMock() - mock_resource_manager_class.return_value = mock_manager_instance - mock_render.return_value = MagicMock() - - report_command(live=False) - - mock_render.assert_called_once_with(mock_manager_instance) diff --git a/tests/unit/cli/utils/test_conda.py b/tests/unit/cli/utils/test_conda.py deleted file mode 100644 index 7968035c..00000000 --- a/tests/unit/cli/utils/test_conda.py +++ /dev/null @@ -1,178 +0,0 @@ -"""Tests for cli/utils/conda.py - conda environment management.""" - -import subprocess -from unittest.mock import MagicMock, patch - - -from runpod_flash.cli.utils.conda import ( - check_conda_available, - create_conda_environment, - environment_exists, - get_activation_command, - install_packages_in_env, -) - - -class TestCheckCondaAvailable: - """Test check_conda_available function.""" - - @patch("runpod_flash.cli.utils.conda.subprocess.run") - def test_returns_true_when_available(self, mock_run): - """Returns True when conda --version succeeds.""" - mock_run.return_value = MagicMock(returncode=0) - assert check_conda_available() is True - mock_run.assert_called_once() - - @patch("runpod_flash.cli.utils.conda.subprocess.run") - def test_returns_false_on_nonzero_exit(self, mock_run): - """Returns False when conda returns non-zero exit code.""" - mock_run.return_value = MagicMock(returncode=1) - assert check_conda_available() is False - - @patch("runpod_flash.cli.utils.conda.subprocess.run") - def test_returns_false_on_file_not_found(self, mock_run): - """Returns False when conda binary is not found.""" - mock_run.side_effect = FileNotFoundError() - assert check_conda_available() is False - - @patch("runpod_flash.cli.utils.conda.subprocess.run") - def test_returns_false_on_subprocess_error(self, mock_run): - """Returns False on subprocess errors.""" - mock_run.side_effect = subprocess.SubprocessError("error") - assert check_conda_available() is False - - @patch("runpod_flash.cli.utils.conda.subprocess.run") - def test_timeout_is_set(self, mock_run): - """Verify timeout is set when calling conda.""" - mock_run.return_value = MagicMock(returncode=0) - check_conda_available() - call_kwargs = mock_run.call_args[1] - assert call_kwargs["timeout"] == 5 - - -class TestCreateCondaEnvironment: - """Test create_conda_environment function.""" - - @patch("runpod_flash.cli.utils.conda.subprocess.run") - def test_success(self, mock_run): - """Returns (True, message) on success.""" - mock_run.return_value = MagicMock(returncode=0) - success, msg = create_conda_environment("test-env", "3.11") - assert success is True - assert "successfully" in msg.lower() - - @patch("runpod_flash.cli.utils.conda.subprocess.run") - def test_failure(self, mock_run): - """Returns (False, message) on failure.""" - mock_run.return_value = MagicMock(returncode=1, stderr="conda error") - success, msg = create_conda_environment("test-env") - assert success is False - assert "conda error" in msg - - @patch("runpod_flash.cli.utils.conda.subprocess.run") - def test_timeout(self, mock_run): - """Returns (False, message) on timeout.""" - mock_run.side_effect = subprocess.TimeoutExpired("conda", 300) - success, msg = create_conda_environment("test-env") - assert success is False - assert "timed out" in msg.lower() - - @patch("runpod_flash.cli.utils.conda.subprocess.run") - def test_generic_exception(self, mock_run): - """Returns (False, message) on unexpected error.""" - mock_run.side_effect = RuntimeError("unexpected") - success, msg = create_conda_environment("test-env") - assert success is False - assert "unexpected" in msg - - @patch("runpod_flash.cli.utils.conda.subprocess.run") - def test_passes_python_version(self, mock_run): - """Verify python version is included in command.""" - mock_run.return_value = MagicMock(returncode=0) - create_conda_environment("test-env", "3.12") - cmd = mock_run.call_args[0][0] - assert "python=3.12" in cmd - - -class TestInstallPackagesInEnv: - """Test install_packages_in_env function.""" - - @patch("runpod_flash.cli.utils.conda.subprocess.run") - def test_pip_install_success(self, mock_run): - """Installs with pip by default.""" - mock_run.return_value = MagicMock(returncode=0) - success, msg = install_packages_in_env("test-env", ["numpy", "pandas"]) - assert success is True - cmd = mock_run.call_args[0][0] - assert "pip" in cmd - assert "numpy" in cmd - assert "pandas" in cmd - - @patch("runpod_flash.cli.utils.conda.subprocess.run") - def test_conda_install(self, mock_run): - """Installs with conda when use_pip=False.""" - mock_run.return_value = MagicMock(returncode=0) - success, msg = install_packages_in_env("test-env", ["scipy"], use_pip=False) - assert success is True - cmd = mock_run.call_args[0][0] - assert "conda" == cmd[0] - assert "install" in cmd - assert "scipy" in cmd - - @patch("runpod_flash.cli.utils.conda.subprocess.run") - def test_install_failure(self, mock_run): - """Returns (False, message) on failure.""" - mock_run.return_value = MagicMock(returncode=1, stderr="install error") - success, msg = install_packages_in_env("test-env", ["bad-pkg"]) - assert success is False - assert "install error" in msg - - @patch("runpod_flash.cli.utils.conda.subprocess.run") - def test_install_timeout(self, mock_run): - """Returns (False, message) on timeout.""" - mock_run.side_effect = subprocess.TimeoutExpired("cmd", 600) - success, msg = install_packages_in_env("test-env", ["torch"]) - assert success is False - assert "timed out" in msg.lower() - - -class TestEnvironmentExists: - """Test environment_exists function.""" - - @patch("runpod_flash.cli.utils.conda.subprocess.run") - def test_exists_true(self, mock_run): - """Returns True when environment is in the list.""" - mock_run.return_value = MagicMock( - returncode=0, stdout="base\ntest-env\nother-env" - ) - assert environment_exists("test-env") is True - - @patch("runpod_flash.cli.utils.conda.subprocess.run") - def test_exists_false(self, mock_run): - """Returns False when environment is not in the list.""" - mock_run.return_value = MagicMock(returncode=0, stdout="base\nother-env") - assert environment_exists("test-env") is False - - @patch("runpod_flash.cli.utils.conda.subprocess.run") - def test_exists_on_failure(self, mock_run): - """Returns False on subprocess failure.""" - mock_run.return_value = MagicMock(returncode=1) - assert environment_exists("test-env") is False - - @patch("runpod_flash.cli.utils.conda.subprocess.run") - def test_exists_on_exception(self, mock_run): - """Returns False on exception.""" - mock_run.side_effect = Exception("unexpected") - assert environment_exists("test-env") is False - - -class TestGetActivationCommand: - """Test get_activation_command function.""" - - def test_returns_correct_command(self): - """Returns correct conda activate command.""" - assert get_activation_command("my-env") == "conda activate my-env" - - def test_with_different_name(self): - """Works with different environment names.""" - assert get_activation_command("flash-3.11") == "conda activate flash-3.11" diff --git a/tests/unit/cli/utils/test_deployment.py b/tests/unit/cli/utils/test_deployment.py index f6c8e2b8..9dab333d 100644 --- a/tests/unit/cli/utils/test_deployment.py +++ b/tests/unit/cli/utils/test_deployment.py @@ -1,12 +1,10 @@ """Unit tests for CLI deployment utilities.""" -import asyncio from unittest.mock import AsyncMock, MagicMock, patch import pytest from runpod_flash.cli.utils.deployment import ( - provision_resources_for_build, deploy_from_uploaded_build, reconcile_and_provision_resources, ) @@ -38,170 +36,6 @@ def mock_deployed_resource(): return resource -@pytest.mark.asyncio -async def test_provision_resources_for_build_success( - mock_flash_app, - mock_deployed_resource, -): - """Test successful resource provisioning.""" - manifest = { - "resources": { - "cpu": {"resource_type": "ServerlessResource"}, - "gpu": {"resource_type": "ServerlessResource"}, - } - } - mock_flash_app.get_build_manifest.return_value = manifest - - with ( - patch("runpod_flash.cli.utils.deployment.ResourceManager") as mock_manager_cls, - patch( - "runpod_flash.cli.utils.deployment.create_resource_from_manifest" - ) as mock_create_resource, - ): - mock_manager = MagicMock() - mock_manager.get_or_deploy_resource = AsyncMock( - return_value=mock_deployed_resource - ) - mock_manager_cls.return_value = mock_manager - mock_create_resource.side_effect = [MagicMock(), MagicMock()] - - result = await provision_resources_for_build( - mock_flash_app, - "build-123", - "dev", - show_progress=False, - ) - - assert len(result) == 2 - assert result["cpu"] == "https://example.com/endpoint" - assert result["gpu"] == "https://example.com/endpoint" - mock_flash_app.update_build_manifest.assert_called_once() - - # Verify manifest was updated with resources_endpoints - call_args = mock_flash_app.update_build_manifest.call_args - updated_manifest = call_args[0][1] - assert "resources_endpoints" in updated_manifest - assert len(updated_manifest["resources_endpoints"]) == 2 - - -@pytest.mark.asyncio -async def test_provision_resources_for_build_no_resources(mock_flash_app): - """Test provisioning with empty manifest.""" - mock_flash_app.get_build_manifest.return_value = {} - - result = await provision_resources_for_build( - mock_flash_app, - "build-123", - "dev", - show_progress=False, - ) - - assert result == {} - mock_flash_app.update_build_manifest.assert_not_called() - - -@pytest.mark.asyncio -async def test_provision_resources_for_build_missing_resources_key(mock_flash_app): - """Test provisioning when resources key is missing.""" - mock_flash_app.get_build_manifest.return_value = {"other_field": "value"} - - result = await provision_resources_for_build( - mock_flash_app, - "build-123", - "dev", - show_progress=False, - ) - - assert result == {} - mock_flash_app.update_build_manifest.assert_not_called() - - -@pytest.mark.asyncio -async def test_provision_resources_for_build_failure( - mock_flash_app, -): - """Test provisioning failure handling.""" - manifest = { - "resources": { - "cpu": {"resource_type": "ServerlessResource"}, - } - } - mock_flash_app.get_build_manifest.return_value = manifest - - with ( - patch("runpod_flash.cli.utils.deployment.ResourceManager") as mock_manager_cls, - patch( - "runpod_flash.cli.utils.deployment.create_resource_from_manifest" - ) as mock_create_resource, - ): - mock_manager = MagicMock() - mock_manager.get_or_deploy_resource = AsyncMock( - side_effect=Exception("Deployment failed") - ) - mock_manager_cls.return_value = mock_manager - mock_create_resource.return_value = MagicMock() - - with pytest.raises(RuntimeError) as exc_info: - await provision_resources_for_build( - mock_flash_app, - "build-123", - "dev", - show_progress=False, - ) - - assert "Failed to provision resources" in str(exc_info.value) - mock_flash_app.update_build_manifest.assert_not_called() - - -@pytest.mark.asyncio -async def test_provision_resources_for_build_parallel_execution( - mock_flash_app, -): - """Test that multiple resources provision in parallel.""" - manifest = { - "resources": { - "resource1": {"resource_type": "ServerlessResource"}, - "resource2": {"resource_type": "ServerlessResource"}, - "resource3": {"resource_type": "ServerlessResource"}, - } - } - mock_flash_app.get_build_manifest.return_value = manifest - - call_times = [] - - async def mock_get_or_deploy_resource(resource): - call_times.append(__import__("time").time()) - await asyncio.sleep(0.1) - resource_mock = MagicMock() - resource_mock.endpoint_url = "https://example.com" - return resource_mock - - with ( - patch("runpod_flash.cli.utils.deployment.ResourceManager") as mock_manager_cls, - patch( - "runpod_flash.cli.utils.deployment.create_resource_from_manifest" - ) as mock_create_resource, - ): - mock_manager = MagicMock() - mock_manager.get_or_deploy_resource = mock_get_or_deploy_resource - mock_manager_cls.return_value = mock_manager - mock_create_resource.side_effect = [MagicMock(), MagicMock(), MagicMock()] - - await provision_resources_for_build( - mock_flash_app, - "build-123", - "dev", - show_progress=False, - ) - - # All tasks should start roughly at the same time (parallel execution) - # If serial, the time between first and last call would be > 0.2s - # If parallel, it would be < 0.1s - if len(call_times) > 1: - time_diff = max(call_times) - min(call_times) - assert time_diff < 0.15, f"Tasks not parallel: {time_diff}s difference" - - @pytest.mark.asyncio async def test_deploy_from_uploaded_build_success( mock_flash_app, mock_deployed_resource, tmp_path @@ -485,44 +319,6 @@ async def test_deploy_succeeds_without_api_key_when_no_remote_calls(tmp_path): ) -@pytest.mark.asyncio -async def test_provision_resources_persists_ai_key_to_manifest(mock_flash_app): - manifest = { - "resources": { - "cpu": {"resource_type": "ServerlessResource"}, - } - } - mock_flash_app.get_build_manifest.return_value = manifest - - deployed = MagicMock() - deployed.endpoint_url = "https://example.com/endpoint" - deployed.id = "endpoint-123" - deployed.aiKey = "ai-key-123" - - with ( - patch("runpod_flash.cli.utils.deployment.ResourceManager") as mock_manager_cls, - patch( - "runpod_flash.cli.utils.deployment.create_resource_from_manifest" - ) as mock_create_resource, - ): - mock_manager = MagicMock() - mock_manager.get_or_deploy_resource = AsyncMock(return_value=deployed) - mock_manager_cls.return_value = mock_manager - mock_create_resource.return_value = MagicMock() - - await provision_resources_for_build( - mock_flash_app, - "build-123", - "dev", - show_progress=False, - ) - - call_args = mock_flash_app.update_build_manifest.call_args - updated_manifest = call_args[0][1] - assert updated_manifest["resources"]["cpu"]["endpoint_id"] == "endpoint-123" - assert updated_manifest["resources"]["cpu"]["aiKey"] == "ai-key-123" - - @pytest.mark.asyncio async def test_reconciliation_copies_ai_key_from_state_manifest(tmp_path): import json diff --git a/tests/unit/cli/utils/test_formatting.py b/tests/unit/cli/utils/test_formatting.py index 1b57e014..407757ac 100644 --- a/tests/unit/cli/utils/test_formatting.py +++ b/tests/unit/cli/utils/test_formatting.py @@ -8,7 +8,6 @@ format_datetime, print_error, print_warning, - state_dot, ) @@ -111,17 +110,3 @@ def test_strips_leading_whitespace_from_message(self): print_warning(console, "\nwatch out") output = buf.getvalue() assert "watch out" in output - - -class TestStateDot: - def test_healthy(self): - assert "[green]●[/green]" in state_dot("HEALTHY") - - def test_building(self): - assert "[yellow]●[/yellow]" in state_dot("BUILDING") - - def test_error(self): - assert "[red]●[/red]" in state_dot("ERROR") - - def test_unknown_defaults_yellow(self): - assert "[yellow]●[/yellow]" in state_dot("WHATEVER") diff --git a/tests/unit/core/api/test_runpod_graphql.py b/tests/unit/core/api/test_runpod_graphql.py index 8b4e33a4..3635238f 100644 --- a/tests/unit/core/api/test_runpod_graphql.py +++ b/tests/unit/core/api/test_runpod_graphql.py @@ -372,17 +372,6 @@ async def test_save_endpoint_update_existing(self): assert result["id"] == "endpoint_123" assert result["name"] == "updated_endpoint" - @pytest.mark.asyncio - async def test_get_endpoint_not_implemented(self): - """Test that get_endpoint is not currently implemented.""" - client = RunpodGraphQLClient(api_key="test_key") - - # get_endpoint is not implemented in current schema - with pytest.raises( - NotImplementedError, match="not available in current schema" - ): - await client.get_endpoint("endpoint_123") - @pytest.mark.asyncio async def test_delete_endpoint(self): """Test deleting an endpoint.""" @@ -423,45 +412,6 @@ async def test_endpoint_exists_false(self): assert exists is False -class TestRunpodGraphQLClientGPUCPU: - """Test GPU and CPU type methods.""" - - @pytest.mark.asyncio - async def test_get_cpu_types(self): - """Test getting available CPU types.""" - client = RunpodGraphQLClient(api_key="test_key") - - expected_cpu_types = [ - {"id": "cpu1", "displayName": "CPU Type 1"}, - {"id": "cpu2", "displayName": "CPU Type 2"}, - ] - - with patch.object(client, "_execute_graphql") as mock_execute: - mock_execute.return_value = {"cpuTypes": expected_cpu_types} - - result = await client.get_cpu_types() - - assert result == expected_cpu_types - - @pytest.mark.asyncio - async def test_get_gpu_types(self): - """Test getting available GPU types.""" - client = RunpodGraphQLClient(api_key="test_key") - - gpu_filter = {"available": True} - expected_gpu_types = [ - {"id": "gpu1", "displayName": "NVIDIA RTX 4090"}, - {"id": "gpu2", "displayName": "NVIDIA A100"}, - ] - - with patch.object(client, "_execute_graphql") as mock_execute: - mock_execute.return_value = {"gpuTypes": expected_gpu_types} - - result = await client.get_gpu_types(gpu_filter) - - assert result == expected_gpu_types - - class TestRunpodGraphQLClientContextManager: """Test async context manager support.""" diff --git a/tests/unit/core/api/test_runpod_graphql_extended.py b/tests/unit/core/api/test_runpod_graphql_extended.py index 9aa9725f..2d6cb734 100644 --- a/tests/unit/core/api/test_runpod_graphql_extended.py +++ b/tests/unit/core/api/test_runpod_graphql_extended.py @@ -42,27 +42,6 @@ class TestGraphQLMutations: "id", "env-1", ), - ( - "register_endpoint_to_environment", - ({"flashEnvironmentId": "env-1", "endpointId": "ep-1"},), - {"addEndpointToFlashEnvironment": {"id": "ep-1", "name": "gpu"}}, - "id", - "ep-1", - ), - ( - "register_network_volume_to_environment", - ({"flashEnvironmentId": "env-1", "networkVolumeId": "nv-1"},), - {"addNetworkVolumeToFlashEnvironment": {"id": "nv-1", "name": "vol"}}, - "id", - "nv-1", - ), - ( - "set_environment_state", - ({"flashEnvironmentId": "env-1", "status": "HEALTHY"},), - {"updateFlashEnvironment": {"id": "env-1", "state": "HEALTHY"}}, - "state", - "HEALTHY", - ), ( "prepare_artifact_upload", ({"flashAppId": "app-1", "tarballSize": 1024},), @@ -216,12 +195,6 @@ class TestGraphQLQueries: {"flashApp": {"flashEnvironments": [{"id": "e1"}, {"id": "e2"}]}}, lambda r: len(r) == 2, ), - ( - "get_gpu_types", - (), - {"gpuTypes": [{"id": "gpu-1"}]}, - lambda r: len(r) == 1, - ), ], ids=lambda x: ( x @@ -262,18 +235,6 @@ async def test_endpoint_exists_handles_api_failure(self): result = await client.endpoint_exists("ep-123") assert result is False - @pytest.mark.asyncio - async def test_get_flash_artifact_url(self): - client = RunpodGraphQLClient(api_key="test") - with patch.object( - client, "get_flash_environment", new_callable=AsyncMock - ) as mock: - mock.return_value = { - "activeArtifact": {"downloadUrl": "https://example.com/dl"} - } - result = await client.get_flash_artifact_url("env-1") - assert "activeArtifact" in result - # ──── REST client ──── diff --git a/tests/unit/core/resources/test_constants.py b/tests/unit/core/resources/test_constants.py index 6384c6a4..d1f201a2 100644 --- a/tests/unit/core/resources/test_constants.py +++ b/tests/unit/core/resources/test_constants.py @@ -8,11 +8,9 @@ from runpod_flash.core.resources.constants import ( CPU_PYTHON_VERSIONS, DEFAULT_PYTHON_VERSION, - GPU_BASE_IMAGE_PYTHON_VERSION, GPU_PYTHON_VERSIONS, SUPPORTED_PYTHON_VERSIONS, get_image_name, - local_python_version, validate_python_version, ) @@ -30,9 +28,6 @@ def test_cpu_python_versions(self): def test_default_python_version_is_3_12(self): assert DEFAULT_PYTHON_VERSION == "3.12" - def test_gpu_base_image_python_version(self): - assert GPU_BASE_IMAGE_PYTHON_VERSION == "3.12" - class TestGetImageName: def test_gpu_3_12(self): @@ -113,14 +108,6 @@ def test_env_var_override_bypasses_gpu_version_constraint(self): assert get_image_name("gpu", "3.8") == "custom/gpu:mine" -class TestLocalPythonVersion: - def test_returns_3_12(self): - assert local_python_version() == "3.12" - - def test_returns_string_type(self): - assert isinstance(local_python_version(), str) - - class TestValidatePythonVersion: def test_valid_versions(self): for v in SUPPORTED_PYTHON_VERSIONS: diff --git a/tests/unit/core/utils/test_file_lock_extended.py b/tests/unit/core/utils/test_file_lock_extended.py index 6bc94d00..0904495f 100644 --- a/tests/unit/core/utils/test_file_lock_extended.py +++ b/tests/unit/core/utils/test_file_lock_extended.py @@ -13,7 +13,6 @@ _release_fallback_lock, _release_unix_lock, file_lock, - get_platform_info, ) @@ -165,27 +164,3 @@ def test_release_fallback_handles_missing_lock_file(self, tmp_path): with open(data_file, "rb") as f: _release_fallback_lock(f) # Should not raise - - -class TestGetPlatformInfo: - """Test get_platform_info function.""" - - def test_returns_dict_with_expected_keys(self): - """Returns dict with platform, locking availability.""" - info = get_platform_info() - assert "platform" in info - assert "windows_locking" in info - assert "unix_locking" in info - assert "fallback_only" in info - - def test_platform_matches_system(self): - """Platform matches current system.""" - info = get_platform_info() - assert info["platform"] == platform.system() - - @pytest.mark.skipif(platform.system() == "Windows", reason="Unix-only test") - def test_unix_locking_available_on_unix(self): - """Unix locking should be available on Unix/macOS.""" - info = get_platform_info() - assert info["unix_locking"] is True - assert info["fallback_only"] is False diff --git a/tests/unit/resources/test_live_serverless.py b/tests/unit/resources/test_live_serverless.py index 4dfeb929..a9c978eb 100644 --- a/tests/unit/resources/test_live_serverless.py +++ b/tests/unit/resources/test_live_serverless.py @@ -4,7 +4,7 @@ import pytest from runpod_flash.core.resources.constants import ( - GPU_BASE_IMAGE_PYTHON_VERSION, + DEFAULT_PYTHON_VERSION, ) from runpod_flash.core.resources.cpu import CpuInstanceType from runpod_flash.core.resources.live_serverless import ( @@ -223,7 +223,7 @@ class TestLiveServerlessPythonVersion: def test_gpu_default_image_uses_gpu_base_python(self): ls = LiveServerless(name="test") - assert f"py{GPU_BASE_IMAGE_PYTHON_VERSION}" in ls.imageName + assert f"py{DEFAULT_PYTHON_VERSION}" in ls.imageName @pytest.mark.parametrize("version", ["3.10", "3.11", "3.12"]) def test_gpu_explicit_supported_versions(self, version): @@ -250,7 +250,7 @@ class TestLiveLoadBalancerPythonVersion: def test_lb_default_image_uses_gpu_base_python(self): lb = LiveLoadBalancer(name="test") - assert f"py{GPU_BASE_IMAGE_PYTHON_VERSION}" in lb.imageName + assert f"py{DEFAULT_PYTHON_VERSION}" in lb.imageName assert "runpod/flash-lb:" in lb.imageName @pytest.mark.parametrize("version", ["3.10", "3.11", "3.12"]) diff --git a/tests/unit/runtime/test_circuit_breaker.py b/tests/unit/runtime/test_circuit_breaker.py deleted file mode 100644 index ec9c6f81..00000000 --- a/tests/unit/runtime/test_circuit_breaker.py +++ /dev/null @@ -1,188 +0,0 @@ -"""Tests for circuit breaker module.""" - -import asyncio - -import pytest - -from runpod_flash.runtime.circuit_breaker import ( - CircuitBreakerOpenError, - CircuitState, - EndpointCircuitBreaker, -) - - -class TestCircuitState: - """Test CircuitState enum.""" - - def test_states(self): - """Test circuit states exist.""" - assert CircuitState.CLOSED.value == "closed" - assert CircuitState.OPEN.value == "open" - assert CircuitState.HALF_OPEN.value == "half_open" - - -class TestEndpointCircuitBreaker: - """Test EndpointCircuitBreaker class.""" - - @pytest.mark.asyncio - async def test_successful_execution(self): - """Test successful function execution.""" - breaker = EndpointCircuitBreaker( - "http://example.com", - failure_threshold=5, - timeout_seconds=60, - ) - - async def success_func(): - return "success" - - result = await breaker.execute(success_func) - assert result == "success" - assert breaker.get_state() == CircuitState.CLOSED - - @pytest.mark.asyncio - async def test_failed_execution_within_threshold(self): - """Test failed execution within threshold.""" - breaker = EndpointCircuitBreaker( - "http://example.com", - failure_threshold=5, - timeout_seconds=60, - ) - - async def failing_func(): - raise ConnectionError("Connection failed") - - for _ in range(4): # 4 failures, threshold is 5 - with pytest.raises(ConnectionError): - await breaker.execute(failing_func) - assert breaker.get_state() == CircuitState.CLOSED - - @pytest.mark.asyncio - async def test_circuit_opens_at_threshold(self): - """Test circuit opens when failure threshold reached.""" - breaker = EndpointCircuitBreaker( - "http://example.com", - failure_threshold=3, - timeout_seconds=60, - ) - - async def failing_func(): - raise ConnectionError("Connection failed") - - # Reach threshold - for _ in range(3): - with pytest.raises(ConnectionError): - await breaker.execute(failing_func) - - # Circuit should be OPEN now - assert breaker.get_state() == CircuitState.OPEN - - # Further requests should fail immediately - with pytest.raises(CircuitBreakerOpenError): - await breaker.execute(failing_func) - - @pytest.mark.asyncio - async def test_circuit_half_open_after_timeout(self): - """Test circuit transitions to HALF_OPEN after timeout.""" - breaker = EndpointCircuitBreaker( - "http://example.com", - failure_threshold=2, - timeout_seconds=1, - ) - - async def failing_func(): - raise ConnectionError("Connection failed") - - # Open circuit - for _ in range(2): - with pytest.raises(ConnectionError): - await breaker.execute(failing_func) - - assert breaker.get_state() == CircuitState.OPEN - - # Wait for timeout - await asyncio.sleep(1.1) - - # Next attempt should transition to HALF_OPEN - async def success_func(): - return "recovered" - - await breaker.execute(success_func) - # First success in HALF_OPEN doesn't close circuit yet - assert breaker.get_state() == CircuitState.HALF_OPEN - - @pytest.mark.asyncio - async def test_circuit_closes_after_success_threshold(self): - """Test circuit closes after enough successes.""" - breaker = EndpointCircuitBreaker( - "http://example.com", - failure_threshold=2, - success_threshold=2, - timeout_seconds=1, - ) - - async def failing_func(): - raise ConnectionError("Connection failed") - - # Open circuit - for _ in range(2): - with pytest.raises(ConnectionError): - await breaker.execute(failing_func) - - assert breaker.get_state() == CircuitState.OPEN - - # Wait for timeout - await asyncio.sleep(1.1) - - # Succeed enough times to close circuit - async def success_func(): - return "success" - - for _ in range(2): - result = await breaker.execute(success_func) - assert result == "success" - - assert breaker.get_state() == CircuitState.CLOSED - - @pytest.mark.asyncio - async def test_get_stats(self): - """Test getting circuit breaker statistics.""" - breaker = EndpointCircuitBreaker("http://example.com") - - async def success_func(): - return "ok" - - await breaker.execute(success_func) - stats = breaker.get_stats() - assert stats.success_count == 1 - assert stats.failure_count == 0 - assert stats.state == CircuitState.CLOSED - assert stats.total_requests == 1 - - @pytest.mark.asyncio - async def test_half_open_resets_on_failure(self): - """Test that failure in HALF_OPEN opens circuit again.""" - breaker = EndpointCircuitBreaker( - "http://example.com", - failure_threshold=1, - timeout_seconds=1, - ) - - async def failing_func(): - raise ConnectionError("Connection failed") - - # Open circuit - with pytest.raises(ConnectionError): - await breaker.execute(failing_func) - - assert breaker.get_state() == CircuitState.OPEN - - # Wait for timeout - await asyncio.sleep(1.1) - - # Attempt recovery, should transition to HALF_OPEN - with pytest.raises(ConnectionError): - await breaker.execute(failing_func) - - # Should transition back to OPEN on first failure - assert breaker.get_state() == CircuitState.OPEN diff --git a/tests/unit/runtime/test_load_balancer.py b/tests/unit/runtime/test_load_balancer.py deleted file mode 100644 index 89806f2a..00000000 --- a/tests/unit/runtime/test_load_balancer.py +++ /dev/null @@ -1,145 +0,0 @@ -"""Tests for load balancer module.""" - -import pytest - -from runpod_flash.runtime.load_balancer import LoadBalancer -from runpod_flash.runtime.reliability_config import LoadBalancerStrategy - - -class TestLoadBalancer: - """Test LoadBalancer class.""" - - def test_round_robin_selection(self): - """Test round-robin endpoint selection.""" - lb = LoadBalancer(strategy=LoadBalancerStrategy.ROUND_ROBIN) - endpoints = ["http://a.com", "http://b.com", "http://c.com"] - - selected = [] - for _ in range(9): - endpoint = lb._round_robin_index - selected_ep = endpoints[endpoint % len(endpoints)] - lb._round_robin_index += 1 - selected.append(selected_ep) - - # Should cycle through endpoints - assert selected[0] == "http://a.com" - assert selected[1] == "http://b.com" - assert selected[2] == "http://c.com" - assert selected[3] == "http://a.com" - - @pytest.mark.asyncio - async def test_select_endpoint_round_robin(self): - """Test select_endpoint with round-robin.""" - lb = LoadBalancer(strategy=LoadBalancerStrategy.ROUND_ROBIN) - endpoints = ["http://a.com", "http://b.com"] - - selected1 = await lb.select_endpoint(endpoints) - selected2 = await lb.select_endpoint(endpoints) - selected3 = await lb.select_endpoint(endpoints) - - assert selected1 == "http://a.com" - assert selected2 == "http://b.com" - assert selected3 == "http://a.com" - - @pytest.mark.asyncio - async def test_select_endpoint_random(self): - """Test select_endpoint with random strategy.""" - lb = LoadBalancer(strategy=LoadBalancerStrategy.RANDOM) - endpoints = ["http://a.com", "http://b.com"] - - selected = await lb.select_endpoint(endpoints) - assert selected in endpoints - - @pytest.mark.asyncio - async def test_select_endpoint_least_connections(self): - """Test select_endpoint with least connections strategy.""" - lb = LoadBalancer(strategy=LoadBalancerStrategy.LEAST_CONNECTIONS) - endpoints = ["http://a.com", "http://b.com"] - - await lb.record_request(endpoints[0]) - await lb.record_request(endpoints[0]) - - selected = await lb.select_endpoint(endpoints) - assert selected == endpoints[1] - - @pytest.mark.asyncio - async def test_empty_endpoints_returns_none(self): - """Test that empty endpoint list returns None.""" - lb = LoadBalancer() - selected = await lb.select_endpoint([]) - assert selected is None - - @pytest.mark.asyncio - async def test_record_request_and_complete(self): - """Test recording in-flight requests.""" - lb = LoadBalancer() - endpoint = "http://a.com" - - await lb.record_request(endpoint) - stats = lb.get_stats() - assert stats[endpoint] == 1 - - await lb.record_request(endpoint) - stats = lb.get_stats() - assert stats[endpoint] == 2 - - await lb.record_request_complete(endpoint) - stats = lb.get_stats() - assert stats[endpoint] == 1 - - @pytest.mark.asyncio - async def test_record_request_complete_does_not_go_negative(self): - """Test that in-flight count doesn't go negative.""" - lb = LoadBalancer() - endpoint = "http://a.com" - - await lb.record_request_complete(endpoint) - stats = lb.get_stats() - assert stats.get(endpoint, 0) == 0 - - @pytest.mark.asyncio - async def test_select_endpoint_with_circuit_breaker(self): - """Test select_endpoint filters unhealthy endpoints.""" - - class MockCircuitBreaker: - def __init__(self, open_endpoints): - self.open_endpoints = open_endpoints - - def get_state(self, endpoint): - from runpod_flash.runtime.circuit_breaker import CircuitState - - if endpoint in self.open_endpoints: - return CircuitState.OPEN - return CircuitState.CLOSED - - lb = LoadBalancer(strategy=LoadBalancerStrategy.ROUND_ROBIN) - endpoints = ["http://a.com", "http://b.com", "http://c.com"] - circuit_breaker = MockCircuitBreaker({"http://a.com"}) - - # Should skip the open endpoint - selected = await lb.select_endpoint(endpoints, circuit_breaker) - assert selected != "http://a.com" - - @pytest.mark.asyncio - async def test_all_endpoints_unhealthy_returns_none(self): - """Test that all unhealthy endpoints returns None.""" - - class MockCircuitBreaker: - def get_state(self, endpoint): - from runpod_flash.runtime.circuit_breaker import CircuitState - - return CircuitState.OPEN - - lb = LoadBalancer() - endpoints = ["http://a.com", "http://b.com"] - circuit_breaker = MockCircuitBreaker() - - selected = await lb.select_endpoint(endpoints, circuit_breaker) - assert selected is None - - def test_get_stats(self): - """Test getting load balancer statistics.""" - lb = LoadBalancer() - stats = lb.get_stats() - assert isinstance(stats, dict) - assert len(stats) == 0 diff --git a/tests/unit/runtime/test_metrics.py b/tests/unit/runtime/test_metrics.py deleted file mode 100644 index 564ac965..00000000 --- a/tests/unit/runtime/test_metrics.py +++ /dev/null @@ -1,502 +0,0 @@ -"""Tests for metrics collection module.""" - -import logging -from unittest.mock import patch - - -from runpod_flash.runtime.metrics import ( - CircuitBreakerMetrics, - LoadBalancerMetrics, - Metric, - MetricsCollector, - MetricType, - RetryMetrics, - get_metrics_collector, - set_metrics_collector, -) - - -class TestMetricType: - """Test MetricType enum.""" - - def test_metric_types(self): - """Test all metric type values.""" - assert MetricType.COUNTER.value == "counter" - assert MetricType.GAUGE.value == "gauge" - assert MetricType.HISTOGRAM.value == "histogram" - - -class TestMetric: - """Test Metric dataclass.""" - - def test_metric_creation(self): - """Test creating a Metric instance.""" - metric = Metric( - metric_type=MetricType.COUNTER, - metric_name="test_counter", - value=5.0, - labels={"endpoint": "test"}, - ) - - assert metric.metric_type == MetricType.COUNTER - assert metric.metric_name == "test_counter" - assert metric.value == 5.0 - assert metric.labels["endpoint"] == "test" - - def test_metric_to_dict(self): - """Test converting Metric to dictionary.""" - metric = Metric( - metric_type=MetricType.GAUGE, - metric_name="memory_usage", - value=75.5, - labels={"unit": "percent"}, - ) - - metric_dict = metric.to_dict() - - assert metric_dict["metric_type"] == MetricType.GAUGE - assert metric_dict["metric_name"] == "memory_usage" - assert metric_dict["value"] == 75.5 - assert metric_dict["labels"]["unit"] == "percent" - - def test_metric_with_empty_labels(self): - """Test Metric with empty labels.""" - metric = Metric( - metric_type=MetricType.HISTOGRAM, - metric_name="latency", - value=100.0, - labels={}, - ) - - assert metric.labels == {} - - def test_metric_with_multiple_labels(self): - """Test Metric with multiple labels.""" - metric = Metric( - metric_type=MetricType.COUNTER, - metric_name="requests", - value=1.0, - labels={ - "endpoint": "/api/users", - "method": "GET", - "status": "200", - "region": "us-east-1", - }, - ) - - assert len(metric.labels) == 4 - assert metric.labels["method"] == "GET" - - -class TestMetricsCollector: - """Test MetricsCollector functionality.""" - - def test_collector_initialization(self): - """Test MetricsCollector initialization.""" - collector = MetricsCollector(namespace="test.metrics", enabled=True) - - assert collector.namespace == "test.metrics" - assert collector.enabled is True - - def test_collector_disabled(self): - """Test that disabled collector doesn't emit metrics.""" - collector = MetricsCollector(enabled=False) - - with patch.object(collector, "_emit") as mock_emit: - collector.counter("test_counter", value=1.0) - collector.gauge("test_gauge", value=50.0) - collector.histogram("test_histogram", value=100.0) - - # _emit should never be called - mock_emit.assert_not_called() - - def test_counter_metric(self): - """Test emitting counter metric.""" - collector = MetricsCollector() - - with patch.object(collector, "_emit") as mock_emit: - collector.counter("request_count", value=5.0, labels={"endpoint": "api"}) - - mock_emit.assert_called_once() - metric = mock_emit.call_args[0][0] - assert metric.metric_type == MetricType.COUNTER - assert metric.metric_name == "request_count" - assert metric.value == 5.0 - - def test_counter_default_value(self): - """Test counter with default value of 1.0.""" - collector = MetricsCollector() - - with patch.object(collector, "_emit") as mock_emit: - collector.counter("page_views") - - metric = mock_emit.call_args[0][0] - assert metric.value == 1.0 - - def test_gauge_metric(self): - """Test emitting gauge metric.""" - collector = MetricsCollector() - - with patch.object(collector, "_emit") as mock_emit: - collector.gauge("cpu_usage", value=75.5, labels={"host": "server1"}) - - mock_emit.assert_called_once() - metric = mock_emit.call_args[0][0] - assert metric.metric_type == MetricType.GAUGE - assert metric.metric_name == "cpu_usage" - assert metric.value == 75.5 - - def test_histogram_metric(self): - """Test emitting histogram metric.""" - collector = MetricsCollector() - - with patch.object(collector, "_emit") as mock_emit: - collector.histogram("request_duration", value=250.0, labels={"path": "/"}) - - mock_emit.assert_called_once() - metric = mock_emit.call_args[0][0] - assert metric.metric_type == MetricType.HISTOGRAM - assert metric.metric_name == "request_duration" - assert metric.value == 250.0 - - def test_emit_logs_metric(self, caplog): - """Test that _emit logs metrics.""" - collector = MetricsCollector(namespace="test.metrics") - - with caplog.at_level(logging.INFO): - metric = Metric( - metric_type=MetricType.COUNTER, - metric_name="test_metric", - value=1.0, - labels={}, - ) - collector._emit(metric) - - assert "[METRIC] test_metric=1.0" in caplog.text - - def test_emit_handles_exceptions(self, caplog): - """Test that _emit handles logging exceptions gracefully.""" - collector = MetricsCollector() - - # Create a metric that will cause an exception during logging - metric = Metric( - metric_type=MetricType.COUNTER, - metric_name="bad_metric", - value=1.0, - labels={}, - ) - - # Patch logger.info to raise exception - with patch( - "runpod_flash.runtime.metrics.logger.info", - side_effect=Exception("Log error"), - ): - with caplog.at_level(logging.ERROR): - collector._emit(metric) - - assert "Failed to emit metric" in caplog.text - - def test_collector_with_no_labels(self): - """Test metrics without labels.""" - collector = MetricsCollector() - - with patch.object(collector, "_emit") as mock_emit: - collector.counter("simple_counter") - - metric = mock_emit.call_args[0][0] - assert metric.labels == {} - - -class TestGlobalMetricsCollector: - """Test global metrics collector functions.""" - - def test_get_metrics_collector_lazy_load(self): - """Test lazy loading of global metrics collector.""" - # Reset global collector - import runpod_flash.runtime.metrics as metrics_module - - metrics_module._collector = None - - collector1 = get_metrics_collector() - collector2 = get_metrics_collector() - - # Should return same instance - assert collector1 is collector2 - - def test_get_metrics_collector_with_params(self): - """Test get_metrics_collector with custom parameters.""" - import runpod_flash.runtime.metrics as metrics_module - - metrics_module._collector = None - - collector = get_metrics_collector(namespace="custom.metrics", enabled=False) - - # Note: After first call, namespace is set and won't change - assert collector.namespace == "custom.metrics" - assert collector.enabled is False - - def test_set_metrics_collector(self): - """Test setting custom metrics collector.""" - custom_collector = MetricsCollector(namespace="custom") - - set_metrics_collector(custom_collector) - - retrieved = get_metrics_collector() - assert retrieved is custom_collector - - -class TestCircuitBreakerMetrics: - """Test CircuitBreakerMetrics helper.""" - - def test_initialization(self): - """Test CircuitBreakerMetrics initialization.""" - collector = MetricsCollector() - cb_metrics = CircuitBreakerMetrics(collector=collector) - - assert cb_metrics.collector is collector - - def test_initialization_uses_global_collector(self): - """Test that CircuitBreakerMetrics uses global collector by default.""" - cb_metrics = CircuitBreakerMetrics() - - assert cb_metrics.collector is not None - - def test_state_changed_metric(self): - """Test emitting state change metric.""" - collector = MetricsCollector() - cb_metrics = CircuitBreakerMetrics(collector=collector) - - with patch.object(collector, "counter") as mock_counter: - cb_metrics.state_changed( - endpoint_url="https://test.com", - new_state="OPEN", - previous_state="CLOSED", - ) - - mock_counter.assert_called_once() - assert mock_counter.call_args[0][0] == "circuit_breaker_state_changes" - labels = mock_counter.call_args[1]["labels"] - assert labels["new_state"] == "OPEN" - assert labels["previous_state"] == "CLOSED" - - def test_endpoint_requests_metric(self): - """Test emitting endpoint requests metric.""" - collector = MetricsCollector() - cb_metrics = CircuitBreakerMetrics(collector=collector) - - with patch.object(collector, "counter") as mock_counter: - cb_metrics.endpoint_requests( - endpoint_url="https://test.com", - status="success", - count=5, - ) - - mock_counter.assert_called_once() - assert mock_counter.call_args[1]["value"] == 5.0 - - def test_endpoint_latency_metric(self): - """Test emitting endpoint latency metric.""" - collector = MetricsCollector() - cb_metrics = CircuitBreakerMetrics(collector=collector) - - with patch.object(collector, "histogram") as mock_histogram: - cb_metrics.endpoint_latency( - endpoint_url="https://test.com", - latency_ms=150.5, - ) - - mock_histogram.assert_called_once() - assert mock_histogram.call_args[1]["value"] == 150.5 - - def test_in_flight_requests_metric(self): - """Test emitting in-flight requests metric.""" - collector = MetricsCollector() - cb_metrics = CircuitBreakerMetrics(collector=collector) - - with patch.object(collector, "gauge") as mock_gauge: - cb_metrics.in_flight_requests( - endpoint_url="https://test.com", - count=3, - ) - - mock_gauge.assert_called_once() - assert mock_gauge.call_args[1]["value"] == 3.0 - - -class TestRetryMetrics: - """Test RetryMetrics helper.""" - - def test_initialization(self): - """Test RetryMetrics initialization.""" - collector = MetricsCollector() - retry_metrics = RetryMetrics(collector=collector) - - assert retry_metrics.collector is collector - - def test_retry_attempt_metric(self): - """Test emitting retry attempt metric.""" - collector = MetricsCollector() - retry_metrics = RetryMetrics(collector=collector) - - with patch.object(collector, "counter") as mock_counter: - retry_metrics.retry_attempt( - function_name="test_function", - attempt=2, - error="Connection timeout", - ) - - mock_counter.assert_called_once() - labels = mock_counter.call_args[1]["labels"] - assert labels["function_name"] == "test_function" - assert labels["attempt"] == "2" - assert labels["error"] == "Connection timeout" - - def test_retry_attempt_without_error(self): - """Test retry attempt metric without error message.""" - collector = MetricsCollector() - retry_metrics = RetryMetrics(collector=collector) - - with patch.object(collector, "counter") as mock_counter: - retry_metrics.retry_attempt( - function_name="test_function", - attempt=1, - ) - - labels = mock_counter.call_args[1]["labels"] - assert "error" not in labels - - def test_retry_success_metric(self): - """Test emitting retry success metric.""" - collector = MetricsCollector() - retry_metrics = RetryMetrics(collector=collector) - - with patch.object(collector, "counter") as mock_counter: - retry_metrics.retry_success( - function_name="test_function", - total_attempts=3, - ) - - mock_counter.assert_called_once() - assert mock_counter.call_args[0][0] == "retry_success" - labels = mock_counter.call_args[1]["labels"] - assert labels["attempts"] == "3" - - def test_retry_exhausted_metric(self): - """Test emitting retry exhausted metric.""" - collector = MetricsCollector() - retry_metrics = RetryMetrics(collector=collector) - - with patch.object(collector, "counter") as mock_counter: - retry_metrics.retry_exhausted( - function_name="test_function", - max_attempts=5, - ) - - mock_counter.assert_called_once() - assert mock_counter.call_args[0][0] == "retry_exhausted" - labels = mock_counter.call_args[1]["labels"] - assert labels["max_attempts"] == "5" - - -class TestLoadBalancerMetrics: - """Test LoadBalancerMetrics helper.""" - - def test_initialization(self): - """Test LoadBalancerMetrics initialization.""" - collector = MetricsCollector() - lb_metrics = LoadBalancerMetrics(collector=collector) - - assert lb_metrics.collector is collector - - def test_endpoint_selected_metric(self): - """Test emitting endpoint selected metric.""" - collector = MetricsCollector() - lb_metrics = LoadBalancerMetrics(collector=collector) - - with patch.object(collector, "counter") as mock_counter: - lb_metrics.endpoint_selected( - strategy="round_robin", - endpoint_url="https://endpoint1.com", - total_candidates=3, - ) - - mock_counter.assert_called_once() - assert mock_counter.call_args[0][0] == "load_balancer_selection" - labels = mock_counter.call_args[1]["labels"] - assert labels["strategy"] == "round_robin" - assert labels["endpoint_url"] == "https://endpoint1.com" - assert labels["candidates"] == "3" - - def test_endpoint_selected_with_different_strategies(self): - """Test endpoint selection with various strategies.""" - collector = MetricsCollector() - lb_metrics = LoadBalancerMetrics(collector=collector) - - strategies = ["round_robin", "random", "least_connections"] - - with patch.object(collector, "counter") as mock_counter: - for strategy in strategies: - lb_metrics.endpoint_selected( - strategy=strategy, - endpoint_url=f"https://{strategy}.com", - total_candidates=5, - ) - - assert mock_counter.call_count == 3 - - -class TestMetricsIntegration: - """Test integration scenarios with metrics.""" - - def test_multiple_metric_types_together(self): - """Test emitting different metric types.""" - collector = MetricsCollector() - - with patch.object(collector, "_emit") as mock_emit: - collector.counter("requests", value=100.0) - collector.gauge("memory_usage", value=75.0) - collector.histogram("latency", value=50.0) - - assert mock_emit.call_count == 3 - - def test_metrics_with_complex_labels(self): - """Test metrics with complex label structures.""" - collector = MetricsCollector() - - labels = { - "endpoint": "https://api.example.com/users", - "method": "POST", - "status_code": "201", - "user_agent": "Mozilla/5.0", - "region": "us-west-2", - "version": "v1.2.3", - } - - with patch.object(collector, "_emit") as mock_emit: - collector.counter("api_requests", value=1.0, labels=labels) - - metric = mock_emit.call_args[0][0] - assert len(metric.labels) == 6 - - def test_metrics_lifecycle(self): - """Test complete metrics lifecycle.""" - # Initialize - collector = MetricsCollector(namespace="app.metrics") - - # Emit various metrics - cb_metrics = CircuitBreakerMetrics(collector=collector) - retry_metrics = RetryMetrics(collector=collector) - lb_metrics = LoadBalancerMetrics(collector=collector) - - with patch.object(collector, "_emit") as mock_emit: - # Circuit breaker metrics - cb_metrics.state_changed("https://test.com", "OPEN", "CLOSED") - - # Retry metrics - retry_metrics.retry_attempt("test_func", 1) - - # Load balancer metrics - lb_metrics.endpoint_selected("random", "https://test.com", 3) - - # Should have emitted 3 metrics - assert mock_emit.call_count == 3 diff --git a/tests/unit/runtime/test_reliability_config.py b/tests/unit/runtime/test_reliability_config.py deleted file mode 100644 index 91099a45..00000000 --- a/tests/unit/runtime/test_reliability_config.py +++ /dev/null @@ -1,181 +0,0 @@ -"""Tests for reliability configuration module.""" - -from runpod_flash.runtime.reliability_config import ( - CircuitBreakerConfig, - LoadBalancerConfig, - LoadBalancerStrategy, - MetricsConfig, - ReliabilityConfig, - RetryConfig, - get_reliability_config, -) - - -class TestCircuitBreakerConfig: - """Test CircuitBreakerConfig dataclass.""" - - def test_defaults(self): - """Test default values.""" - config = CircuitBreakerConfig() - assert config.enabled is True - assert config.failure_threshold == 5 - assert config.success_threshold == 2 - assert config.timeout_seconds == 60 - assert config.window_size == 10 - - def test_custom_values(self): - """Test with custom values.""" - config = CircuitBreakerConfig( - enabled=False, - failure_threshold=10, - success_threshold=3, - timeout_seconds=30, - window_size=20, - ) - assert config.enabled is False - assert config.failure_threshold == 10 - assert config.success_threshold == 3 - assert config.timeout_seconds == 30 - assert config.window_size == 20 - - -class TestLoadBalancerConfig: - """Test LoadBalancerConfig dataclass.""" - - def test_defaults(self): - """Test default values.""" - config = LoadBalancerConfig() - assert config.enabled is False - assert config.strategy == LoadBalancerStrategy.ROUND_ROBIN - - def test_custom_values(self): - """Test with custom values.""" - config = LoadBalancerConfig( - enabled=True, - strategy=LoadBalancerStrategy.LEAST_CONNECTIONS, - ) - assert config.enabled is True - assert config.strategy == LoadBalancerStrategy.LEAST_CONNECTIONS - - -class TestRetryConfig: - """Test RetryConfig dataclass.""" - - def test_defaults(self): - """Test default values.""" - config = RetryConfig() - assert config.enabled is True - assert config.max_attempts == 3 - assert config.base_delay == 0.5 - assert config.max_delay == 10.0 - assert config.jitter == 0.2 - assert 408 in config.retryable_status_codes - assert 500 in config.retryable_status_codes - - def test_custom_values(self): - """Test with custom values.""" - config = RetryConfig( - enabled=False, - max_attempts=5, - base_delay=1.0, - max_delay=20.0, - jitter=0.1, - ) - assert config.enabled is False - assert config.max_attempts == 5 - assert config.base_delay == 1.0 - assert config.max_delay == 20.0 - assert config.jitter == 0.1 - - -class TestMetricsConfig: - """Test MetricsConfig dataclass.""" - - def test_defaults(self): - """Test default values.""" - config = MetricsConfig() - assert config.enabled is True - assert config.namespace == "flash.metrics" - - def test_custom_values(self): - """Test with custom values.""" - config = MetricsConfig(enabled=False, namespace="custom.metrics") - assert config.enabled is False - assert config.namespace == "custom.metrics" - - -class TestReliabilityConfig: - """Test ReliabilityConfig dataclass.""" - - def test_defaults(self): - """Test default values.""" - config = ReliabilityConfig() - assert config.circuit_breaker is not None - assert config.load_balancer is not None - assert config.retry is not None - assert config.metrics is not None - assert config.circuit_breaker.enabled is True - assert config.load_balancer.enabled is False - assert config.retry.enabled is True - assert config.metrics.enabled is True - - def test_custom_nested_configs(self): - """Test with custom nested configurations.""" - cb_config = CircuitBreakerConfig(enabled=False) - lb_config = LoadBalancerConfig(enabled=True) - config = ReliabilityConfig( - circuit_breaker=cb_config, - load_balancer=lb_config, - ) - assert config.circuit_breaker.enabled is False - assert config.load_balancer.enabled is True - - def test_from_env_default(self, monkeypatch): - """Test from_env with no environment variables.""" - monkeypatch.delenv("FLASH_CIRCUIT_BREAKER_ENABLED", raising=False) - config = ReliabilityConfig.from_env() - assert config.circuit_breaker.enabled is True - assert config.load_balancer.enabled is False - assert config.retry.enabled is True - - def test_from_env_custom(self, monkeypatch): - """Test from_env with custom environment variables.""" - monkeypatch.setenv("FLASH_CIRCUIT_BREAKER_ENABLED", "false") - monkeypatch.setenv("FLASH_LOAD_BALANCER_ENABLED", "true") - monkeypatch.setenv("FLASH_CB_FAILURE_THRESHOLD", "10") - config = ReliabilityConfig.from_env() - assert config.circuit_breaker.enabled is False - assert config.load_balancer.enabled is True - assert config.circuit_breaker.failure_threshold == 10 - - def test_from_env_load_balancer_strategy(self, monkeypatch): - """Test from_env with load balancer strategy.""" - monkeypatch.setenv("FLASH_LB_STRATEGY", "least_connections") - config = ReliabilityConfig.from_env() - assert config.load_balancer.strategy == LoadBalancerStrategy.LEAST_CONNECTIONS - - def test_from_env_invalid_strategy_defaults(self, monkeypatch): - """Test from_env with invalid strategy defaults to round_robin.""" - monkeypatch.setenv("FLASH_LB_STRATEGY", "invalid_strategy") - config = ReliabilityConfig.from_env() - assert config.load_balancer.strategy == LoadBalancerStrategy.ROUND_ROBIN - - -class TestLoadBalancerStrategy: - """Test LoadBalancerStrategy enum.""" - - def test_strategy_values(self): - """Test that strategies have correct values.""" - assert LoadBalancerStrategy.ROUND_ROBIN.value == "round_robin" - assert LoadBalancerStrategy.LEAST_CONNECTIONS.value == "least_connections" - assert LoadBalancerStrategy.RANDOM.value == "random" - - -class TestGlobalConfig: - """Test global configuration accessor.""" - - def test_get_reliability_config(self): - """Test getting global reliability config.""" - config = get_reliability_config() - assert isinstance(config, ReliabilityConfig) - assert config.circuit_breaker is not None diff --git a/tests/unit/runtime/test_retry_manager.py b/tests/unit/runtime/test_retry_manager.py deleted file mode 100644 index cdfb326a..00000000 --- a/tests/unit/runtime/test_retry_manager.py +++ /dev/null @@ -1,165 +0,0 @@ -"""Tests for retry manager module.""" - -import asyncio - -import pytest - -from runpod_flash.runtime.retry_manager import RetryExhaustedError, retry_with_backoff - - -class TestRetryWithBackoff: - """Test retry_with_backoff function.""" - - @pytest.mark.asyncio - async def test_successful_first_attempt(self): - """Test successful execution on first attempt.""" - - async def success_func(): - return "success" - - result = await retry_with_backoff(success_func, max_attempts=3) - assert result == "success" - - @pytest.mark.asyncio - async def test_non_retryable_exception_raises_immediately(self): - """Test that non-retryable exceptions raise immediately.""" - - async def failing_func(): - raise ValueError("Non-retryable error") - - with pytest.raises(ValueError): - await retry_with_backoff(failing_func, max_attempts=3) - - @pytest.mark.asyncio - async def test_retryable_exception_retries(self): - """Test that retryable exceptions are retried.""" - attempt_count = 0 - - async def failing_then_success(): - nonlocal attempt_count - attempt_count += 1 - if attempt_count < 2: - raise ConnectionError("Connection failed") - return "success" - - result = await retry_with_backoff( - failing_then_success, - max_attempts=3, - base_delay=0.01, - max_delay=0.1, - ) - assert result == "success" - assert attempt_count == 2 - - @pytest.mark.asyncio - async def test_max_retries_exhausted(self): - """Test that RetryExhaustedError is raised after max attempts.""" - - async def always_fails(): - raise ConnectionError("Always fails") - - with pytest.raises(RetryExhaustedError): - await retry_with_backoff( - always_fails, - max_attempts=2, - base_delay=0.01, - max_delay=0.1, - ) - - @pytest.mark.asyncio - async def test_timeout_is_retryable(self): - """Test that asyncio.TimeoutError is retried by default.""" - attempt_count = 0 - - async def timeout_then_success(): - nonlocal attempt_count - attempt_count += 1 - if attempt_count < 2: - raise asyncio.TimeoutError("Request timed out") - return "success" - - result = await retry_with_backoff( - timeout_then_success, - max_attempts=3, - base_delay=0.01, - max_delay=0.1, - ) - assert result == "success" - - @pytest.mark.asyncio - async def test_custom_retryable_exceptions(self): - """Test with custom retryable exceptions.""" - - class CustomError(Exception): - pass - - attempt_count = 0 - - async def custom_error_then_success(): - nonlocal attempt_count - attempt_count += 1 - if attempt_count < 2: - raise CustomError("Custom error") - return "success" - - result = await retry_with_backoff( - custom_error_then_success, - max_attempts=3, - retryable_exceptions=(CustomError,), - base_delay=0.01, - max_delay=0.1, - ) - assert result == "success" - - @pytest.mark.asyncio - async def test_exponential_backoff(self): - """Test that backoff increases exponentially.""" - attempt_times = [] - - async def track_attempts(): - attempt_times.append(asyncio.get_event_loop().time()) - if len(attempt_times) < 3: - raise ConnectionError("Failed") - return "success" - - result = await retry_with_backoff( - track_attempts, - max_attempts=3, - base_delay=0.05, - max_delay=1.0, - jitter=0.0, # No jitter for predictable timing - ) - assert result == "success" - # Should have at least 3 attempts with delays between them - assert len(attempt_times) == 3 - - @pytest.mark.asyncio - async def test_with_args_and_kwargs(self): - """Test retry with function arguments.""" - - async def add(a, b): - return a + b - - result = await retry_with_backoff(add, max_attempts=1, a=2, b=3) - assert result == 5 - - @pytest.mark.asyncio - async def test_retry_with_circuit_breaker_open(self): - """Test that open circuit breaker prevents retries.""" - - class MockCircuitBreaker: - def get_state(self): - from runpod_flash.runtime.circuit_breaker import CircuitState - - return CircuitState.OPEN - - async def failing_func(): - raise ConnectionError("Failed") - - with pytest.raises(RuntimeError, match="Circuit breaker OPEN"): - await retry_with_backoff( - failing_func, - max_attempts=3, - circuit_breaker=MockCircuitBreaker(), - base_delay=0.01, - ) diff --git a/tests/unit/test_dotenv_loading.py b/tests/unit/test_dotenv_loading.py index 208db63d..73196796 100644 --- a/tests/unit/test_dotenv_loading.py +++ b/tests/unit/test_dotenv_loading.py @@ -188,8 +188,6 @@ def test_env_vars_available_after_flash_import(self, preserve_runpod_flash_modul # Set up test environment variables test_env_vars = { "RUNPOD_API_KEY": "test_key_12345", - "FLASH_GPU_IMAGE": "test/gpu:latest", - "FLASH_CPU_IMAGE": "test/cpu:latest", "LOG_LEVEL": "WARNING", } @@ -221,14 +219,6 @@ def test_env_vars_available_after_flash_import(self, preserve_runpod_flash_modul # Import specific modules that use environment variables from runpod_flash.core.api.runpod import RunpodGraphQLClient - from runpod_flash.core.resources.constants import ( - FLASH_GPU_IMAGE, - FLASH_CPU_IMAGE, - ) - - # Verify that the environment variables are accessible in imported modules - assert FLASH_GPU_IMAGE == "test/gpu:latest" - assert FLASH_CPU_IMAGE == "test/cpu:latest" # Test that RunpodGraphQLClient can access the API key try: diff --git a/tests/unit/test_file_locking.py b/tests/unit/test_file_locking.py index 12b77352..2c2f7a80 100644 --- a/tests/unit/test_file_locking.py +++ b/tests/unit/test_file_locking.py @@ -13,7 +13,6 @@ import tempfile import time from pathlib import Path -from unittest.mock import patch import pytest @@ -21,49 +20,11 @@ file_lock, FileLockError, FileLockTimeout, - get_platform_info, _acquire_fallback_lock, _release_fallback_lock, ) -class TestPlatformDetection: - """Test platform detection and capabilities.""" - - def test_get_platform_info(self): - """Test that platform info returns expected structure.""" - info = get_platform_info() - - required_keys = ["platform", "windows_locking", "unix_locking", "fallback_only"] - assert all(key in info for key in required_keys) - - # Platform should be one of the expected values - assert info["platform"] in ("Windows", "Linux", "Darwin") - - # Exactly one locking mechanism should be available (or fallback) - locking_mechanisms = [ - info["windows_locking"], - info["unix_locking"], - info["fallback_only"], - ] - assert sum(locking_mechanisms) >= 1 # At least fallback should work - - @patch("runpod_flash.core.utils.file_lock.platform.system", return_value="Windows") - def test_platform_detection_windows(self, mock_system): - """Test Windows platform detection via get_platform_info().""" - # Don't use reload() — it pollutes module-level state (_IS_WINDOWS, - # _UNIX_LOCKING_AVAILABLE, etc.) for all subsequent tests. - # get_platform_info() calls platform.system() at runtime, so patching suffices. - info = get_platform_info() - assert info["platform"] == "Windows" - - @patch("runpod_flash.core.utils.file_lock.platform.system", return_value="Linux") - def test_platform_detection_linux(self, mock_system): - """Test Linux platform detection via get_platform_info().""" - info = get_platform_info() - assert info["platform"] == "Linux" - - class TestFileLocking: """Test cross-platform file locking functionality.""" diff --git a/tests/unit/test_load_balancer_sls_resource.py b/tests/unit/test_load_balancer_sls_resource.py index 233ca84d..6f3336d1 100644 --- a/tests/unit/test_load_balancer_sls_resource.py +++ b/tests/unit/test_load_balancer_sls_resource.py @@ -118,248 +118,6 @@ def test_endpoint_url_raises_without_id(self): _ = resource.endpoint_url -class TestLoadBalancerSlsResourceHealthCheck: - """Test health check functionality.""" - - @staticmethod - def _create_mock_client( - status_code: int = 200, error: Exception = None - ) -> MagicMock: - """Create properly configured async context manager mock client.""" - mock_response = AsyncMock() - mock_response.status_code = status_code - mock_client = MagicMock() - if error: - mock_client.get = AsyncMock(side_effect=error) - else: - mock_client.get = AsyncMock(return_value=mock_response) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - return mock_client - - @pytest.mark.asyncio - async def test_check_ping_endpoint_success(self): - """Test successful ping endpoint check with ID set.""" - resource = LoadBalancerSlsResource( - name="test", - imageName="image", - id="test-endpoint-id", - ) - - mock_client = self._create_mock_client(200) - with ( - patch.object( - LoadBalancerSlsResource, - "endpoint_url", - new_callable=lambda: property(lambda self: "https://test-endpoint.com"), - ), - patch( - "runpod_flash.core.utils.http.httpx.AsyncClient", - return_value=mock_client, - ), - ): - result = await resource._check_ping_endpoint() - - assert result is True - - @pytest.mark.asyncio - async def test_check_ping_endpoint_initializing(self): - """Test ping endpoint returning 204 (initializing).""" - resource = LoadBalancerSlsResource( - name="test", - imageName="image", - id="test-endpoint-id", - ) - - mock_client = self._create_mock_client(204) - with ( - patch.object( - LoadBalancerSlsResource, - "endpoint_url", - new_callable=lambda: property(lambda self: "https://test-endpoint.com"), - ), - patch( - "runpod_flash.core.utils.http.httpx.AsyncClient", - return_value=mock_client, - ), - ): - result = await resource._check_ping_endpoint() - - assert result is True - - @pytest.mark.asyncio - async def test_check_ping_endpoint_failure(self): - """Test ping endpoint returning unhealthy status.""" - resource = LoadBalancerSlsResource( - name="test", - imageName="image", - id="test-endpoint-id", - ) - - mock_client = self._create_mock_client(503) - with ( - patch.object( - LoadBalancerSlsResource, - "endpoint_url", - new_callable=lambda: property(lambda self: "https://test-endpoint.com"), - ), - patch( - "runpod_flash.core.utils.http.httpx.AsyncClient", - return_value=mock_client, - ), - ): - result = await resource._check_ping_endpoint() - - assert result is False - - @pytest.mark.asyncio - async def test_check_ping_endpoint_connection_error(self): - """Test ping endpoint with connection error.""" - resource = LoadBalancerSlsResource( - name="test", - imageName="image", - id="test-endpoint-id", - ) - - mock_client = self._create_mock_client( - error=ConnectionError("Connection refused") - ) - with ( - patch.object( - LoadBalancerSlsResource, - "endpoint_url", - new_callable=lambda: property(lambda self: "https://test-endpoint.com"), - ), - patch( - "runpod_flash.core.utils.http.httpx.AsyncClient", - return_value=mock_client, - ), - ): - result = await resource._check_ping_endpoint() - - assert result is False - - @pytest.mark.asyncio - async def test_check_ping_endpoint_no_id(self): - """Test ping check when endpoint ID is not set.""" - resource = LoadBalancerSlsResource( - name="test", - imageName="image", - # id not set - ) - - result = await resource._check_ping_endpoint() - assert result is False - - @pytest.mark.asyncio - async def test_wait_for_health_success(self): - """Test health check polling with successful response.""" - resource = LoadBalancerSlsResource( - name="test", - imageName="image", - id="test-endpoint-id", - ) - - with patch.object(resource, "_check_ping_endpoint") as mock_check: - mock_check.return_value = True - - result = await resource._wait_for_health(max_retries=3) - - assert result is True - mock_check.assert_called_once() - - @pytest.mark.asyncio - async def test_wait_for_health_retry_then_success(self): - """Test health check polling with retries before success.""" - resource = LoadBalancerSlsResource( - name="test", - imageName="image", - id="test-endpoint-id", - ) - - with patch.object(resource, "_check_ping_endpoint") as mock_check: - # Fail twice, then succeed - mock_check.side_effect = [False, False, True] - - result = await resource._wait_for_health(max_retries=5, retry_interval=0) - - assert result is True - assert mock_check.call_count == 3 - - @pytest.mark.asyncio - async def test_wait_for_health_timeout(self): - """Test health check polling timeout after max retries.""" - resource = LoadBalancerSlsResource( - name="test", - imageName="image", - id="test-endpoint-id", - ) - - with patch.object(resource, "_check_ping_endpoint") as mock_check: - mock_check.return_value = False - - result = await resource._wait_for_health(max_retries=3, retry_interval=0) - - assert result is False - assert mock_check.call_count == 3 - - @pytest.mark.asyncio - async def test_wait_for_health_no_id(self): - """Test health check when endpoint ID not set.""" - resource = LoadBalancerSlsResource( - name="test", - imageName="image", - # id not set - ) - - with pytest.raises(ValueError, match="Cannot wait for health"): - await resource._wait_for_health() - - @pytest.mark.asyncio - async def test_is_deployed_async_with_id(self): - """Test is_deployed_async returns True when healthy.""" - resource = LoadBalancerSlsResource( - name="test", - imageName="image", - id="test-endpoint-id", - ) - - with patch.object(resource, "_check_ping_endpoint") as mock_check: - mock_check.return_value = True - - result = await resource.is_deployed_async() - - assert result is True - - @pytest.mark.asyncio - async def test_is_deployed_async_without_id(self): - """Test is_deployed_async returns False when ID not set.""" - resource = LoadBalancerSlsResource( - name="test", - imageName="image", - ) - - result = await resource.is_deployed_async() - - assert result is False - - @pytest.mark.asyncio - async def test_is_deployed_async_unhealthy(self): - """Test is_deployed_async returns False when unhealthy.""" - resource = LoadBalancerSlsResource( - name="test", - imageName="image", - id="test-endpoint-id", - ) - - with patch.object(resource, "_check_ping_endpoint") as mock_check: - mock_check.return_value = False - - result = await resource.is_deployed_async() - - assert result is False - - class TestLoadBalancerSlsResourceDeployment: """Test deployment flow.""" @@ -415,9 +173,6 @@ async def test_do_deploy_success(self): new_callable=AsyncMock, return_value=False, ), - patch.object( - resource, "_wait_for_health", new_callable=AsyncMock - ) as mock_wait, patch.object( ServerlessResource, "_do_deploy", @@ -428,8 +183,6 @@ async def test_do_deploy_success(self): result = await resource._do_deploy() assert result == mock_deployed - # Health check should not be called during deployment - mock_wait.assert_not_called() @pytest.mark.asyncio async def test_do_deploy_parent_deploy_failure(self): diff --git a/tests/unit/test_logger_sensitive_data.py b/tests/unit/test_logger_sensitive_data.py index e5ad2640..939b3cf9 100644 --- a/tests/unit/test_logger_sensitive_data.py +++ b/tests/unit/test_logger_sensitive_data.py @@ -150,7 +150,7 @@ def test_long_token_partial_redaction(self): assert long_token not in record.msg def test_short_token_not_redacted(self): - """Verify short tokens (<32 chars) are not redacted by TOKEN_PATTERN.""" + """Verify short tokens (<32 chars) are not redacted.""" filter_instance = SensitiveDataFilter() # Short string won't match the 32+ pattern, so it's not redacted @@ -166,7 +166,7 @@ def test_short_token_not_redacted(self): ) filter_instance.filter(record) - # Short tokens aren't matched by TOKEN_PATTERN (requires 32+ chars) + # short tokens are not redacted by the sensitive data filter assert short_token in record.msg # Should not be redacted def test_multiple_sensitive_patterns(self): diff --git a/tests/unit/test_p2_remaining_gaps.py b/tests/unit/test_p2_remaining_gaps.py index 8d09cbf6..ff0788ba 100644 --- a/tests/unit/test_p2_remaining_gaps.py +++ b/tests/unit/test_p2_remaining_gaps.py @@ -521,60 +521,6 @@ class SharedClass: # --------------------------------------------------------------------------- -class TestLoadBalancerRandomStrategy: - """LoadBalancer RANDOM strategy selects from the endpoint pool and varies over runs.""" - - @pytest.mark.asyncio - async def test_random_strategy_selects_from_pool(self): - """LB-ROUTE-003: _random_select returns one of the provided endpoints.""" - from runpod_flash.runtime.load_balancer import LoadBalancer - from runpod_flash.runtime.reliability_config import LoadBalancerStrategy - - lb = LoadBalancer(strategy=LoadBalancerStrategy.RANDOM) - endpoints = ["http://ep1", "http://ep2", "http://ep3"] - - result = await lb._random_select(endpoints) - assert result in endpoints - - @pytest.mark.asyncio - async def test_random_strategy_produces_varied_results(self): - """LB-ROUTE-003: Over many calls, random strategy selects more than one endpoint.""" - from runpod_flash.runtime.load_balancer import LoadBalancer - from runpod_flash.runtime.reliability_config import LoadBalancerStrategy - - lb = LoadBalancer(strategy=LoadBalancerStrategy.RANDOM) - endpoints = ["http://ep-a", "http://ep-b", "http://ep-c"] - - seen = {await lb._random_select(endpoints) for _ in range(60)} - # With 60 draws from 3 options, seeing at least 2 different values is - # virtually certain (probability of seeing only 1 is (1/3)^59 ≈ 0). - assert len(seen) >= 2 - - @pytest.mark.asyncio - async def test_random_strategy_via_select_endpoint(self): - """LB-ROUTE-003: select_endpoint with RANDOM strategy returns an endpoint.""" - from runpod_flash.runtime.load_balancer import LoadBalancer - from runpod_flash.runtime.reliability_config import LoadBalancerStrategy - - lb = LoadBalancer(strategy=LoadBalancerStrategy.RANDOM) - endpoints = ["http://ep-x", "http://ep-y"] - - result = await lb.select_endpoint(endpoints) - assert result in endpoints - - @pytest.mark.asyncio - async def test_random_strategy_single_endpoint_returns_it(self): - """LB-ROUTE-003: With a single endpoint, random always returns it.""" - from runpod_flash.runtime.load_balancer import LoadBalancer - from runpod_flash.runtime.reliability_config import LoadBalancerStrategy - - lb = LoadBalancer(strategy=LoadBalancerStrategy.RANDOM) - endpoints = ["http://only-one"] - - result = await lb._random_select(endpoints) - assert result == "http://only-one" - - # --------------------------------------------------------------------------- # RT-SER-005: Serialize/deserialize with complex objects (no numpy/PIL) # ---------------------------------------------------------------------------