diff --git a/tests/global_controller/test_runtime_backed_controller.py b/tests/global_controller/test_runtime_backed_controller.py index ae3eb5d..be7db03 100644 --- a/tests/global_controller/test_runtime_backed_controller.py +++ b/tests/global_controller/test_runtime_backed_controller.py @@ -14,6 +14,34 @@ class GlobalControllerRuntimeBackedTests(unittest.TestCase): + def test_load_autoscale_policies_applies_scale_down_defaults(self): + controller = make_global_controller([]) + controller.config = {"agents": [{"name": "Beta", "replicas": 2}]} + + policies = GlobalController._load_autoscale_policies( + controller, + { + "autoscale": { + "Beta": { + "queue_length_scale_up_threshold": 3, + "max_replicas": 5, + } + } + }, + ) + + self.assertEqual( + policies, + { + "Beta": { + "queue_length_scale_up_threshold": 3, + "max_replicas": 5, + "min_replicas": 2, + "idle_seconds_before_scale_down": 60, + } + }, + ) + def test_wait_for_healthy_reads_local_status_from_host_redis(self): instance = make_instance("Alpha", 0, host="localhost", host_port=8000) controller = make_global_controller([instance]) @@ -22,10 +50,16 @@ def test_wait_for_healthy_reads_local_status_from_host_redis(self): "controller:host.docker.internal:8000:status", "healthy", ) + published = [] + controller.runtime_manager.publish_routing_snapshot = lambda controllers, **kwargs: published.append( + [item["name"] for item in controllers] + ) + controller.controllers = [{"name": "Alpha"}] GlobalController._wait_for_healthy(controller, timeout=1, interval=0) self.assertEqual(controller._last_status, {("localhost", "8000"): "healthy"}) + self.assertEqual(published, [["Alpha"]]) def test_poll_controllers_calls_healthy_hook_for_runtime_instance(self): instance = make_instance("Beta", 0, host="localhost", host_port=8001) @@ -220,6 +254,332 @@ def test_ensure_host_redis_prepares_remote_docker_first(self): self.assertEqual(calls[0], ("prep", "10.0.0.7", "ec2-user")) self.assertEqual(calls[1][0], "run") + def test_poll_controllers_scales_service_when_queue_depth_exceeds_threshold(self): + instances = [make_instance("Beta", 0, host="localhost", host_port=8001)] + controller = make_global_controller(instances) + controller.controllers = [{"name": "Beta", "replicas": 1}] + controller._autoscale_policies = { + "Beta": {"queue_length_scale_up_threshold": 3, "max_replicas": 2} + } + controller.node_redis["localhost"] = FakeRedis() + controller.node_redis["localhost"].set( + "controller:host.docker.internal:8001:status", + "healthy", + ) + controller.node_redis["localhost"].set( + "queue_depth:Beta:host.docker.internal:8001", + "4", + ) + publish_calls = [] + waited = [] + + def ensure_replica(_controller_spec, replica_index, publish=True): + self.assertFalse(publish) + self.assertEqual(replica_index, 1) + instances.append(make_instance("Beta", 1, host="localhost", host_port=8002)) + return instances[-1] + + controller.runtime_manager.ensure_replica = ensure_replica + controller.runtime_manager.list_instances = lambda agent_name=None: [ + instance for instance in instances if agent_name in (None, instance["agent_name"]) + ] + controller.runtime_manager.remove_instance = lambda instance_id, publish=True: None + controller.runtime_manager.publish_routing_snapshot = lambda controllers, **kwargs: publish_calls.append( + [ctrl["name"] for ctrl in controllers] + ) + controller._wait_for_pending_healthy = lambda pending, timeout=30, interval=2, require_healthy=False: waited.extend(pending) + + GlobalController._poll_controllers(controller) + + self.assertEqual(controller.controllers[0]["replicas"], 2) + self.assertEqual(waited, [("Beta", "localhost", "8002")]) + self.assertTrue(publish_calls) + self.assertEqual(publish_calls[-1], ["Beta"]) + + def test_reconcile_marks_instance_idling_only_after_idle_window(self): + instance = make_instance("Beta", 0, host="localhost", host_port=8001) + controller = make_global_controller([instance]) + controller.controllers = [{"name": "Beta", "replicas": 1}] + controller._autoscale_policies = { + "Beta": { + "queue_length_scale_up_threshold": 3, + "max_replicas": 2, + "min_replicas": 1, + "idle_seconds_before_scale_down": 60, + } + } + controller.node_redis["localhost"] = FakeRedis() + controller.node_redis["localhost"].set("controller:host.docker.internal:8001:status", "healthy") + controller.node_redis["localhost"].set("queue_depth:Beta:host.docker.internal:8001", "0") + controller.node_redis["localhost"].set("controller:host.docker.internal:8001:active_work", "0") + controller._last_status[("localhost", "8001")] = "healthy" + publish_calls = [] + controller.runtime_manager.publish_routing_snapshot = lambda controllers, **kwargs: publish_calls.append(kwargs) + + with patch("ventis.controller.global_controller.time.time", return_value=100): + GlobalController._reconcile_instance_lifecycle(controller) + self.assertEqual(controller._lifecycle_statuses["host.docker.internal:8001"], "Healthy") + + with patch("ventis.controller.global_controller.time.time", return_value=161): + GlobalController._reconcile_instance_lifecycle(controller) + self.assertEqual(controller._lifecycle_statuses["host.docker.internal:8001"], "Idling") + self.assertTrue(publish_calls) + + def test_reconcile_resets_idling_when_queue_depth_returns(self): + instance = make_instance("Beta", 0, host="localhost", host_port=8001) + controller = make_global_controller([instance]) + controller.controllers = [{"name": "Beta", "replicas": 1}] + controller._autoscale_policies = { + "Beta": { + "queue_length_scale_up_threshold": 3, + "max_replicas": 2, + "min_replicas": 1, + "idle_seconds_before_scale_down": 1, + } + } + controller.node_redis["localhost"] = FakeRedis() + controller.node_redis["localhost"].set("queue_depth:Beta:host.docker.internal:8001", "0") + controller.node_redis["localhost"].set("controller:host.docker.internal:8001:active_work", "0") + controller._last_status[("localhost", "8001")] = "healthy" + controller.runtime_manager.publish_routing_snapshot = lambda controllers, **kwargs: None + + with patch("ventis.controller.global_controller.time.time", return_value=10): + GlobalController._reconcile_instance_lifecycle(controller) + with patch("ventis.controller.global_controller.time.time", return_value=12): + GlobalController._reconcile_instance_lifecycle(controller) + self.assertEqual(controller._lifecycle_statuses["host.docker.internal:8001"], "Idling") + + controller.node_redis["localhost"].set("queue_depth:Beta:host.docker.internal:8001", "2") + with patch("ventis.controller.global_controller.time.time", return_value=13): + GlobalController._reconcile_instance_lifecycle(controller) + self.assertEqual(controller._lifecycle_statuses["host.docker.internal:8001"], "Healthy") + self.assertNotIn("host.docker.internal:8001", controller._idle_since) + + def test_reconcile_removes_endpoint_before_marking_shutting_down(self): + instances = [ + make_instance("Beta", 0, host="localhost", host_port=8001), + make_instance("Beta", 1, host="localhost", host_port=8002), + ] + controller = make_global_controller(instances) + controller.controllers = [{"name": "Beta", "replicas": 2}] + controller._autoscale_policies = { + "Beta": { + "queue_length_scale_up_threshold": 3, + "max_replicas": 3, + "min_replicas": 1, + "idle_seconds_before_scale_down": 1, + } + } + controller.node_redis["localhost"] = FakeRedis() + for port in ("8001", "8002"): + controller.node_redis["localhost"].set(f"queue_depth:Beta:host.docker.internal:{port}", "0") + controller.node_redis["localhost"].set(f"controller:host.docker.internal:{port}:active_work", "0") + controller._last_status[("localhost", port)] = "healthy" + + publish_calls = [] + controller.runtime_manager.publish_routing_snapshot = ( + lambda controllers, **kwargs: publish_calls.append(kwargs) + ) + + with patch("ventis.controller.global_controller.time.time", return_value=10): + GlobalController._reconcile_instance_lifecycle(controller) + with patch("ventis.controller.global_controller.time.time", return_value=12): + GlobalController._reconcile_instance_lifecycle(controller) + + self.assertGreaterEqual(len(publish_calls), 2) + first = publish_calls[-2] + second = publish_calls[-1] + self.assertEqual(first["routable_endpoints"]["Beta"], {"host.docker.internal:8001"}) + self.assertEqual( + first["lifecycle_statuses"]["Beta"]["host.docker.internal:8002"], + "Idling", + ) + self.assertEqual( + second["lifecycle_statuses"]["Beta"]["host.docker.internal:8002"], + "Shutting down", + ) + + def test_stateful_instance_does_not_idle_with_affine_running_request(self): + instance = make_instance("Sticky", 0, host="localhost", host_port=8001) + controller = make_global_controller([instance]) + controller.controllers = [{"name": "Sticky", "replicas": 1, "stateful": True}] + controller._autoscale_policies = { + "Sticky": { + "queue_length_scale_up_threshold": 3, + "max_replicas": 2, + "min_replicas": 1, + "idle_seconds_before_scale_down": 1, + } + } + controller.node_redis["localhost"] = FakeRedis() + controller.node_redis["localhost"].set("queue_depth:Sticky:host.docker.internal:8001", "0") + controller.node_redis["localhost"].set("controller:host.docker.internal:8001:active_work", "0") + controller._last_status[("localhost", "8001")] = "healthy" + controller.redis.hset("affinity:req-1", "Sticky", "host.docker.internal:8001") + controller.redis.set("request:req-1:status", "running") + controller.runtime_manager.publish_routing_snapshot = lambda controllers, **kwargs: None + + with patch("ventis.controller.global_controller.time.time", return_value=10): + GlobalController._reconcile_instance_lifecycle(controller) + with patch("ventis.controller.global_controller.time.time", return_value=12): + GlobalController._reconcile_instance_lifecycle(controller) + + self.assertEqual(controller._lifecycle_statuses["host.docker.internal:8001"], "Healthy") + + controller.redis.set("request:req-1:status", "done") + with patch("ventis.controller.global_controller.time.time", return_value=14): + GlobalController._reconcile_instance_lifecycle(controller) + with patch("ventis.controller.global_controller.time.time", return_value=16): + GlobalController._reconcile_instance_lifecycle(controller) + self.assertEqual(controller._lifecycle_statuses["host.docker.internal:8001"], "Idling") + + def test_delete_ready_removes_runtime_and_metadata(self): + instances = [make_instance("Beta", 0, host="localhost", host_port=8001)] + controller = make_global_controller(instances) + controller.controllers = [{"name": "Beta", "replicas": 2}] + controller._autoscale_policies = { + "Beta": { + "queue_length_scale_up_threshold": 3, + "max_replicas": 3, + "min_replicas": 1, + "idle_seconds_before_scale_down": 1, + } + } + controller.node_redis["localhost"] = FakeRedis() + controller.node_redis["localhost"].set("controller:host.docker.internal:8001:active_work", "0") + controller.node_redis["localhost"].set("queue_depth:Beta:host.docker.internal:8001", "0") + controller.node_redis["localhost"].set("controller:host.docker.internal:8001:lifecycle", "Delete Ready") + controller._last_status[("localhost", "8001")] = "healthy" + controller._draining_endpoints.add("host.docker.internal:8001") + removed = [] + controller.runtime_manager.remove_instance = lambda instance_id, publish=False: removed.append((instance_id, publish)) + controller.runtime_manager.publish_routing_snapshot = lambda controllers, **kwargs: None + + GlobalController._reconcile_instance_lifecycle(controller) + + self.assertEqual(removed, [("local:Beta:0", False)]) + self.assertNotIn("host.docker.internal:8001", controller._lifecycle_statuses) + + def test_poll_controllers_does_not_scale_beyond_max_replicas(self): + instances = [ + make_instance("Beta", 0, host="localhost", host_port=8001), + make_instance("Beta", 1, host="localhost", host_port=8002), + ] + controller = make_global_controller(instances) + controller.controllers = [{"name": "Beta", "replicas": 2}] + controller._autoscale_policies = { + "Beta": {"queue_length_scale_up_threshold": 3, "max_replicas": 2} + } + controller.node_redis["localhost"] = FakeRedis() + controller.node_redis["localhost"].set( + "queue_depth:Beta:host.docker.internal:8001", + "99", + ) + controller.node_redis["localhost"].set( + "queue_depth:Beta:host.docker.internal:8002", + "99", + ) + calls = [] + controller.runtime_manager.ensure_instances = lambda controllers, publish=True: calls.append( + (controllers, publish) + ) + + GlobalController._maybe_scale_from_queue_depth(controller) + + self.assertEqual(calls, []) + + def test_scale_up_timeout_rolls_back_after_publishing_initializing_status(self): + instances = [make_instance("Beta", 0, host="localhost", host_port=8001)] + controller = make_global_controller(instances) + controller.controllers = [{"name": "Beta", "replicas": 1}] + controller._autoscale_policies = { + "Beta": {"queue_length_scale_up_threshold": 3, "max_replicas": 2} + } + controller.node_redis["localhost"] = FakeRedis() + controller.node_redis["localhost"].set( + "queue_depth:Beta:host.docker.internal:8001", + "4", + ) + publish_calls = [] + removed = [] + + def ensure_replica(_controller_spec, replica_index, publish=True): + self.assertEqual(replica_index, 1) + instances.append(make_instance("Beta", 1, host="localhost", host_port=8002)) + return instances[-1] + + def remove_instance(instance_id, publish=True): + removed.append((instance_id, publish)) + instances[:] = [ + instance for instance in instances + if controller.runtime_manager._instance_id_from_record(instance) != instance_id + ] + + controller.runtime_manager.ensure_replica = ensure_replica + controller.runtime_manager.remove_instance = remove_instance + controller.runtime_manager.list_instances = lambda agent_name=None: [ + instance for instance in instances if agent_name in (None, instance["agent_name"]) + ] + controller.runtime_manager.publish_routing_snapshot = lambda controllers, **kwargs: publish_calls.append( + [ctrl["name"] for ctrl in controllers] + ) + controller._wait_for_pending_healthy = lambda pending, timeout=30, interval=2, require_healthy=False: (_ for _ in ()).throw(TimeoutError("not ready")) + + with self.assertRaises(TimeoutError): + GlobalController._maybe_scale_from_queue_depth(controller) + + self.assertEqual(controller.controllers[0]["replicas"], 1) + self.assertEqual(removed, [("local:Beta:1", False)]) + self.assertEqual(publish_calls, [["Beta"]]) + + def test_scale_up_timeout_only_rolls_back_new_replica_slot(self): + instances = [make_instance("Beta", 0, host="localhost", host_port=8001)] + instances[0]["runtime_id"] = "stale-runtime" + controller = make_global_controller(instances) + controller.controllers = [{"name": "Beta", "replicas": 1}] + controller._autoscale_policies = { + "Beta": {"queue_length_scale_up_threshold": 3, "max_replicas": 2} + } + controller.node_redis["localhost"] = FakeRedis() + controller.node_redis["localhost"].set( + "queue_depth:Beta:host.docker.internal:8001", + "4", + ) + removed = [] + + def ensure_replica(_controller_spec, replica_index, publish=True): + self.assertEqual(replica_index, 1) + instances.append(make_instance("Beta", 1, host="localhost", host_port=8002)) + return instances[-1] + + def remove_instance(instance_id, publish=True): + removed.append((instance_id, publish)) + instances[:] = [ + instance for instance in instances + if controller.runtime_manager._instance_id_from_record(instance) != instance_id + ] + + controller.runtime_manager.ensure_replica = ensure_replica + controller.runtime_manager.remove_instance = remove_instance + controller.runtime_manager.list_instances = lambda agent_name=None: [ + instance for instance in instances if agent_name in (None, instance["agent_name"]) + ] + publish_calls = [] + controller.runtime_manager.publish_routing_snapshot = lambda controllers, **kwargs: publish_calls.append( + [ctrl["name"] for ctrl in controllers] + ) + controller._wait_for_pending_healthy = lambda pending, timeout=30, interval=2, require_healthy=False: (_ for _ in ()).throw(TimeoutError("not ready")) + + with self.assertRaises(TimeoutError): + GlobalController._maybe_scale_from_queue_depth(controller) + + self.assertEqual(controller.controllers[0]["replicas"], 1) + self.assertEqual( + removed, + [("local:Beta:1", False)], + ) + self.assertEqual(publish_calls, [["Beta"]]) + def test_wait_for_remote_ssh_retries_until_success(self): controller = make_global_controller([]) controller.config = {"ec2": {"ssh_private_key_path": "/tmp/test.pem"}} diff --git a/tests/live/test_full_local_deploy.py b/tests/live/test_full_local_deploy.py index c411527..fb4275d 100644 --- a/tests/live/test_full_local_deploy.py +++ b/tests/live/test_full_local_deploy.py @@ -84,6 +84,16 @@ def _wait_for_done(request_id, timeout=60): raise TimeoutError(f"request {request_id} did not finish within {timeout}s") +def _docker_inspect_exists(name): + result = subprocess.run( + ["docker", "inspect", name], + capture_output=True, + text=True, + check=False, + ) + return result.returncode == 0 + + @unittest.skipUnless( RUN_FULL_LOCAL, "set VENTIS_RUN_FULL_LOCAL=1 to run the full local build/deploy smoke test", @@ -111,6 +121,20 @@ def tearDown(self): self.deploy.stdout.close() shutil.rmtree(self.tmpdir, ignore_errors=True) + def _stop_deploy(self): + if self.deploy and self.deploy.poll() is None: + self.deploy.send_signal(signal.SIGTERM) + self.deploy.wait(timeout=30) + + def _assert_runtime_containers_removed(self, runtime_ids, timeout=30): + deadline = time.time() + timeout + while time.time() < deadline: + if all(not _docker_inspect_exists(runtime_id) for runtime_id in runtime_ids): + return + time.sleep(1) + still_present = [runtime_id for runtime_id in runtime_ids if _docker_inspect_exists(runtime_id)] + self.fail(f"Runtime containers were not removed: {still_present}") + def test_generated_local_project_builds_deploys_and_routes(self): result = _run_ventis(["new-project", self.project_name], cwd=self.tmpdir) self.assertEqual(result.returncode, 0, result.stderr) @@ -163,17 +187,330 @@ def _force_local_only_config(self): with open(config_path, "w") as f: yaml.safe_dump(config, f, sort_keys=False) + def _force_two_example_agent_replicas(self): + config_path = os.path.join(self.project_dir, "config", "global_controller.yaml") + policy_path = os.path.join(self.project_dir, "config", "policy.yaml") + + with open(config_path, "r") as f: + config = yaml.safe_load(f) + with open(policy_path, "r") as f: + policy = yaml.safe_load(f) + + config["agents"] = [ + agent for agent in config["agents"] if agent["name"] in {"ExampleAgent", "Workflow"} + ] + + for agent in config["agents"]: + agent["provider"] = "local" + agent.setdefault("resources", {}).pop("gpu", None) + if agent["name"] == "ExampleAgent": + agent["replicas"] = 2 + elif agent["name"] == "Workflow": + agent["replicas"] = 1 + + policy.setdefault("autoscale", {}).setdefault("ExampleAgent", {}) + policy["autoscale"]["ExampleAgent"].update( + { + "queue_length_scale_up_threshold": 10, + "min_replicas": 1, + "idle_seconds_before_scale_down": 60, + "max_replicas": 3, + } + ) + + with open(config_path, "w") as f: + yaml.safe_dump(config, f, sort_keys=False) + with open(policy_path, "w") as f: + yaml.safe_dump(policy, f, sort_keys=False) + def _assert_routing_metadata(self): redis = RedisClient(host="localhost", port=6379) - services = redis.smembers("routing_table:services") + deadline = time.time() + 10 + services = set() + raw_endpoints = None + raw_status = None + while time.time() < deadline: + services = redis.smembers("routing_table:services") + raw_endpoints = redis.hget("routing_table:endpoints", "ExampleAgent") + raw_status = redis.hget("routing_table:status", "ExampleAgent") + if "ExampleAgent" in services and "Workflow" in services and raw_endpoints and raw_status: + break + time.sleep(0.5) + self.assertIn("ExampleAgent", services) self.assertIn("Workflow", services) - - raw_endpoints = redis.hget("routing_table:endpoints", "ExampleAgent") self.assertIsNotNone(raw_endpoints) + self.assertIsNotNone(raw_status) endpoints = json.loads(raw_endpoints) + statuses = json.loads(raw_status) self.assertEqual(len(endpoints), 1) self.assertTrue(endpoints[0].startswith("host.docker.internal:")) + self.assertEqual(statuses[endpoints[0]], "Healthy") + + def _wait_for_example_agent_idle_and_scale_down(self): + redis = RedisClient(host="localhost", port=6379) + deadline = time.time() + 140 + started_at = time.time() + seen_idle = None + initial_endpoints = None + + while time.time() < deadline: + raw_endpoints = redis.hget("routing_table:endpoints", "ExampleAgent") + raw_status = redis.hget("routing_table:status", "ExampleAgent") + if not raw_endpoints or not raw_status: + time.sleep(1) + continue + + endpoints = json.loads(raw_endpoints) + statuses = json.loads(raw_status) + if initial_endpoints is None and len(endpoints) == 2: + initial_endpoints = list(endpoints) + + if seen_idle is None: + idle_endpoints = [ + endpoint + for endpoint, status in statuses.items() + if status == "Idling" + ] + if idle_endpoints: + seen_idle = (time.time(), idle_endpoints[0], list(endpoints)) + + if seen_idle and initial_endpoints and len(endpoints) == 1: + removed = [endpoint for endpoint in initial_endpoints if endpoint not in endpoints] + self.assertEqual(len(removed), 1) + return { + "idle_after_seconds": round(seen_idle[0] - started_at, 1), + "removed_endpoint": removed[0], + "remaining_endpoint": endpoints[0], + } + + time.sleep(1) + + raise TimeoutError("ExampleAgent did not idle and scale down within the timeout") + + def test_example_agent_two_replicas_idle_then_scale_down(self): + result = _run_ventis(["new-project", self.project_name], cwd=self.tmpdir) + self.assertEqual(result.returncode, 0, result.stderr) + + self._force_two_example_agent_replicas() + + result = _run_ventis(["build"], cwd=self.project_dir, timeout=300) + self.assertEqual(result.returncode, 0, result.stderr) + + self.deploy = subprocess.Popen( + [sys.executable, "-m", "ventis.cli", "deploy"], + cwd=self.project_dir, + env={ + **os.environ, + "PYTHONPATH": f"{REPO_ROOT}{os.pathsep}{os.environ.get('PYTHONPATH', '')}", + }, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + ) + + deadline = time.time() + 60 + while time.time() < deadline: + redis = RedisClient(host="localhost", port=6379) + raw_endpoints = redis.hget("routing_table:endpoints", "ExampleAgent") + raw_status = redis.hget("routing_table:status", "ExampleAgent") + if raw_endpoints and raw_status: + endpoints = json.loads(raw_endpoints) + statuses = json.loads(raw_status) + if len(endpoints) == 2 and all(statuses.get(endpoint) == "Healthy" for endpoint in endpoints): + break + time.sleep(1) + else: + self.fail("ExampleAgent routing metadata did not show two healthy replicas") + + verification = self._wait_for_example_agent_idle_and_scale_down() + self.assertGreaterEqual(verification["idle_after_seconds"], 55) + self.assertLessEqual(verification["idle_after_seconds"], 90) + + self._stop_deploy() + self._assert_runtime_containers_removed( + [ + "ventis-local-exampleagent-0", + "ventis-local-exampleagent-1", + "ventis-local-workflow-0", + ] + ) + + def _configure_three_agent_chain(self): + config_path = os.path.join(self.project_dir, "config", "global_controller.yaml") + policy_path = os.path.join(self.project_dir, "config", "policy.yaml") + workflow_path = os.path.join(self.project_dir, "workflows", "example_workflow.py") + agents_dir = os.path.join(self.project_dir, "agents") + + with open(config_path, "w") as f: + yaml.safe_dump( + { + "agents": [ + { + "name": "AlphaAgent", + "replicas": 1, + "redis_port": 6379, + "resources": {"cpu": 1, "memory": 512}, + "entrypoint": "agents/alpha_agent.py", + "provider": "local", + }, + { + "name": "BetaAgent", + "replicas": 1, + "redis_port": 6379, + "resources": {"cpu": 1, "memory": 512}, + "entrypoint": "agents/beta_agent.py", + "provider": "local", + }, + { + "name": "GammaAgent", + "replicas": 1, + "redis_port": 6379, + "resources": {"cpu": 1, "memory": 512}, + "entrypoint": "agents/gamma_agent.py", + "provider": "local", + }, + { + "name": "Workflow", + "replicas": 1, + "type": "workflow", + "redis_port": 6379, + "workflow_file": "workflows/example_workflow.py", + "provider": "local", + }, + ], + "poll_interval": 5, + "redis": {"host": "localhost", "port": 6379, "db": 0}, + }, + f, + sort_keys=False, + ) + + with open(policy_path, "w") as f: + yaml.safe_dump( + { + "rules": [ + { + "match": {}, + "access": ["AlphaAgent", "BetaAgent", "GammaAgent", "Workflow"], + } + ] + }, + f, + sort_keys=False, + ) + + with open(os.path.join(agents_dir, "alpha_agent.yaml"), "w") as f: + f.write( + """agent: + name: AlphaAgent + functions: + - name: start + arguments: + - name: text + type: str + returns: + type: str +""" + ) + with open(os.path.join(agents_dir, "beta_agent.yaml"), "w") as f: + f.write( + """agent: + name: BetaAgent + functions: + - name: step + arguments: + - name: text + type: str + returns: + type: str +""" + ) + with open(os.path.join(agents_dir, "gamma_agent.yaml"), "w") as f: + f.write( + """agent: + name: GammaAgent + functions: + - name: finish + arguments: + - name: text + type: str + returns: + type: str +""" + ) + + with open(os.path.join(agents_dir, "alpha_agent.py"), "w") as f: + f.write( + """import os\nimport sys\nsys.path.insert(0, os.path.dirname(__file__))\nfrom beta_agent_stub import BetaAgentStub\n\n\nclass AlphaAgent(object):\n def __init__(self):\n self.tools = [self.start]\n\n def start(self, text: str) -> str:\n return BetaAgentStub().step(text=f\"{text} -> alpha\").value()\n""" + ) + with open(os.path.join(agents_dir, "beta_agent.py"), "w") as f: + f.write( + """import os\nimport sys\nsys.path.insert(0, os.path.dirname(__file__))\nfrom gamma_agent_stub import GammaAgentStub\n\n\nclass BetaAgent(object):\n def __init__(self):\n self.tools = [self.step]\n\n def step(self, text: str) -> str:\n return GammaAgentStub().finish(text=f\"{text} -> beta\").value()\n""" + ) + with open(os.path.join(agents_dir, "gamma_agent.py"), "w") as f: + f.write( + """class GammaAgent(object):\n def __init__(self):\n self.tools = [self.finish]\n\n def finish(self, text: str) -> str:\n return f\"{text} -> gamma\"\n""" + ) + + with open(workflow_path, "w") as f: + f.write( + """import os\nimport sys\nsys.path.insert(0, os.path.dirname(__file__))\nsys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'stubs'))\nsys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'grpc_stubs'))\nfrom deploy import deploy\nfrom alpha_agent_stub import AlphaAgentStub\n\n\ndef main(text: str = 'start', name: str = None):\n if name is not None:\n text = name\n return {'result': AlphaAgentStub().start(text=text).value()}\n\n\ndeploy(main, port=8080)\n""" + ) + + def test_three_agent_chain_creates_routes_and_deletes_all_instances(self): + result = _run_ventis(["new-project", self.project_name], cwd=self.tmpdir) + self.assertEqual(result.returncode, 0, result.stderr) + + self._configure_three_agent_chain() + + result = _run_ventis(["build"], cwd=self.project_dir, timeout=300) + self.assertEqual(result.returncode, 0, result.stderr) + + self.deploy = subprocess.Popen( + [sys.executable, "-m", "ventis.cli", "deploy"], + cwd=self.project_dir, + env={ + **os.environ, + "PYTHONPATH": f"{REPO_ROOT}{os.pathsep}{os.environ.get('PYTHONPATH', '')}", + }, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + ) + + _wait_for_http("http://localhost:8080/main", timeout=90) + + redis = RedisClient(host="localhost", port=6379) + deadline = time.time() + 30 + while time.time() < deadline: + services = redis.smembers("routing_table:services") + if {"AlphaAgent", "BetaAgent", "GammaAgent", "Workflow"}.issubset(services): + break + time.sleep(1) + else: + self.fail(f"Expected 4 services, got {services}") + + status, submitted = _request_json( + "POST", + "http://localhost:8080/main", + {"text": "start"}, + timeout=5, + ) + self.assertEqual(status, 202) + completed = _wait_for_done(submitted["request_id"], timeout=90) + self.assertEqual(completed["status"], "done") + self.assertEqual(completed["result"], {"result": "start -> alpha -> beta -> gamma"}) + + self._stop_deploy() + self._assert_runtime_containers_removed( + [ + "ventis-local-alphaagent-0", + "ventis-local-betaagent-0", + "ventis-local-gammaagent-0", + "ventis-local-workflow-0", + ] + ) if __name__ == "__main__": diff --git a/tests/local/test_queue_depth_publication.py b/tests/local/test_queue_depth_publication.py new file mode 100644 index 0000000..61e466d --- /dev/null +++ b/tests/local/test_queue_depth_publication.py @@ -0,0 +1,134 @@ +import queue +import unittest + +from tests.support.runtime_fakes import FakeRedis +from tests.support.runtime_fakes import install_grpc_stubs + +install_grpc_stubs() + +from ventis.controller.local_controller import LocalController + + +class LocalControllerQueueDepthTests(unittest.TestCase): + def test_publish_queue_depth_writes_current_queue_size(self): + controller = LocalController.__new__(LocalController) + controller.agent_name = "ExampleAgent" + controller.agent_host = "host.docker.internal" + controller.public_port = "8000" + controller.redis = FakeRedis() + controller._active_requests = 0 + controller.request_queue = queue.Queue() + controller.request_queue.put("a") + controller.request_queue.put("b") + + LocalController._publish_queue_depth(controller) + + self.assertEqual( + controller.redis.get("queue_depth:ExampleAgent:host.docker.internal:8000"), + "2", + ) + + def test_publish_queue_depth_includes_active_requests(self): + controller = LocalController.__new__(LocalController) + controller.agent_name = "ExampleAgent" + controller.agent_host = "host.docker.internal" + controller.public_port = "8000" + controller.redis = FakeRedis() + controller._active_requests = 3 + controller.request_queue = queue.Queue() + controller.request_queue.put("a") + + LocalController._publish_queue_depth(controller) + + self.assertEqual( + controller.redis.get("queue_depth:ExampleAgent:host.docker.internal:8000"), + "4", + ) + + def test_publish_active_work_writes_current_active_requests(self): + controller = LocalController.__new__(LocalController) + controller.redis = FakeRedis() + controller._active_requests = 3 + controller._active_work_key_name = "controller:host.docker.internal:8000:active_work" + + LocalController._publish_active_work(controller) + + self.assertEqual( + controller.redis.get("controller:host.docker.internal:8000:active_work"), + "3", + ) + + def test_update_lifecycle_signal_marks_delete_ready_only_after_drain(self): + controller = LocalController.__new__(LocalController) + controller.agent_name = "ExampleAgent" + controller.agent_host = "host.docker.internal" + controller.public_port = "8000" + controller._my_endpoint = "host.docker.internal:8000" + controller._shutting_down = False + controller._active_requests = 0 + controller._lifecycle_signal_key = "controller:host.docker.internal:8000:lifecycle" + controller.redis = FakeRedis() + controller.request_queue = queue.Queue() + controller.redis.hset( + "routing_table:status", + "ExampleAgent", + '{"host.docker.internal:8000":"Shutting down"}', + ) + + LocalController._update_lifecycle_signal(controller) + + self.assertEqual( + controller.redis.get("controller:host.docker.internal:8000:lifecycle"), + "Delete Ready", + ) + + def test_update_lifecycle_signal_blocks_stateful_affine_inflight_request(self): + controller = LocalController.__new__(LocalController) + controller.agent_name = "ExampleAgent" + controller.agent_host = "host.docker.internal" + controller.public_port = "8000" + controller._my_endpoint = "host.docker.internal:8000" + controller._shutting_down = False + controller._active_requests = 0 + controller._lifecycle_signal_key = "controller:host.docker.internal:8000:lifecycle" + controller.redis = FakeRedis() + controller.request_queue = queue.Queue() + controller.redis.hset( + "routing_table:status", + "ExampleAgent", + '{"host.docker.internal:8000":"Shutting down"}', + ) + controller.redis.hset("routing_table:stateful", "ExampleAgent", "true") + controller.redis.hset("affinity:req-1", "ExampleAgent", "host.docker.internal:8000") + controller.redis.set("request:req-1:status", "running") + + LocalController._update_lifecycle_signal(controller) + + self.assertIsNone( + controller.redis.get("controller:host.docker.internal:8000:lifecycle") + ) + + def test_stop_without_agent_name_does_not_write_bogus_queue_key(self): + controller = LocalController.__new__(LocalController) + controller.agent_name = None + controller.agent_host = "host.docker.internal" + controller.public_port = "8000" + controller.redis = FakeRedis() + controller._status_key = "controller:host.docker.internal:8000:status" + controller._active_work_key_name = "controller:host.docker.internal:8000:active_work" + controller._lifecycle_signal_key = "controller:host.docker.internal:8000:lifecycle" + controller._executor = type("Executor", (), {"shutdown": lambda self, wait=True: None})() + controller.server = type("Server", (), {"stop": lambda self, code: None})() + + LocalController.stop(controller) + + self.assertEqual(controller.redis.get(controller._status_key), "stopped") + self.assertEqual( + controller.redis.get("controller:host.docker.internal:8000:active_work"), + "0", + ) + self.assertEqual(controller.redis.scan_keys("queue_depth:*"), []) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/routing/test_routing_publication.py b/tests/routing/test_routing_publication.py index 2049d65..0838aba 100644 --- a/tests/routing/test_routing_publication.py +++ b/tests/routing/test_routing_publication.py @@ -30,22 +30,32 @@ def test_publish_routing_snapshot_orders_endpoints_by_replica_index(self): ["host.docker.internal:8000", "host.docker.internal:8001"], ) self.assertEqual(self.controller.redis.hget("routing_table:stateful", "Alpha"), "true") + self.assertEqual( + json.loads(self.controller.redis.hget("routing_table:status", "Alpha")), + { + "host.docker.internal:8000": "Healthy", + "host.docker.internal:8001": "Healthy", + }, + ) self.assertEqual(self.controller.redis.smembers("routing_table:services"), {"Alpha"}) def test_publish_routing_snapshot_removes_endpoint_when_no_instances_exist(self): redis = self.controller.redis redis.sadd("routing_table:services", "Alpha") redis.hset("routing_table:endpoints", "Alpha", json.dumps(["localhost:8000"])) + redis.hset("routing_table:status", "Alpha", json.dumps({"localhost:8000": "Healthy"})) self.manager.publish_routing_snapshot([{"name": "Alpha", "stateful": False}]) self.assertEqual(redis.smembers("routing_table:services"), {"Alpha"}) self.assertIsNone(redis.hget("routing_table:endpoints", "Alpha")) + self.assertIsNone(redis.hget("routing_table:status", "Alpha")) def test_publish_routing_snapshot_clears_stale_service_metadata(self): redis = self.controller.redis redis.sadd("routing_table:services", "Old", "Keep") redis.hset("routing_table:endpoints", "Old", json.dumps(["localhost:9000"])) + redis.hset("routing_table:status", "Old", json.dumps({"localhost:9000": "Healthy"})) redis.hset("routing_table:stateful", "Old", "true") redis.hset("routing_table:stateful", "Keep", "true") @@ -53,6 +63,7 @@ def test_publish_routing_snapshot_clears_stale_service_metadata(self): self.assertEqual(redis.smembers("routing_table:services"), {"Keep"}) self.assertIsNone(redis.hget("routing_table:endpoints", "Old")) + self.assertIsNone(redis.hget("routing_table:status", "Old")) self.assertIsNone(redis.hget("routing_table:stateful", "Old")) self.assertIsNone(redis.hget("routing_table:stateful", "Keep")) @@ -81,9 +92,40 @@ def test_publish_routing_snapshot_targets_each_node_redis(self): json.loads(redis.hget("routing_table:endpoints", "Beta")), ["host.docker.internal:8001"], ) + self.assertEqual( + json.loads(redis.hget("routing_table:status", "Beta")), + {"host.docker.internal:8001": "Healthy"}, + ) self.assertIsNone(self.controller.redis.hget("routing_table:endpoints", "Alpha")) + def test_publish_routing_snapshot_can_exclude_draining_endpoint_but_keep_status(self): + self.write_instance(make_instance("Alpha", 0, host_port=8000)) + self.write_instance(make_instance("Alpha", 1, host_port=8001)) + + self.manager.publish_routing_snapshot( + [{"name": "Alpha", "stateful": False}], + lifecycle_statuses={ + "Alpha": { + "host.docker.internal:8000": "Healthy", + "host.docker.internal:8001": "Shutting down", + } + }, + routable_endpoints={"Alpha": {"host.docker.internal:8000"}}, + ) + + self.assertEqual( + json.loads(self.controller.redis.hget("routing_table:endpoints", "Alpha")), + ["host.docker.internal:8000"], + ) + self.assertEqual( + json.loads(self.controller.redis.hget("routing_table:status", "Alpha")), + { + "host.docker.internal:8000": "Healthy", + "host.docker.internal:8001": "Shutting down", + }, + ) + def test_routing_targets_fall_back_to_central_redis(self): self.assertEqual(self.manager._routing_redis_targets(), [self.controller.redis]) @@ -109,6 +151,44 @@ def test_publish_policy_rules_writes_empty_rule_list(self): self.assertEqual(count, 1) self.assertEqual(json.loads(self.controller.redis.get("policy:rules")), []) + def test_ensure_instances_can_defer_routing_publication(self): + self.manager.ensure_instances( + [{"name": "Alpha", "provider": "local", "replicas": 1}], + publish=False, + ) + + self.assertIsNone(self.controller.redis.hget("routing_table:endpoints", "Alpha")) + + def test_remove_instance_can_defer_routing_publication(self): + instance = make_instance("Alpha", 0, host_port=8000) + self.write_instance(instance) + self.controller.runtime_ids.add(instance["runtime_id"]) + + self.manager.remove_instance("local:Alpha:0", publish=False) + + self.assertIsNone(self.controller.redis.hget("routing_table:endpoints", "Alpha")) + + def test_recreate_with_publish_false_keeps_routing_unpublished(self): + stale = make_instance("Alpha", 0, host_port=8000) + stale["runtime_id"] = "stale-runtime" + self.write_instance(stale) + + self.manager.ensure_instances( + [{"name": "Alpha", "provider": "local", "replicas": 1}], + publish=False, + ) + + self.assertIsNone(self.controller.redis.hget("routing_table:endpoints", "Alpha")) + + def test_ensure_replica_with_publish_false_keeps_routing_unpublished(self): + self.manager.ensure_replica( + {"name": "Alpha", "provider": "local", "replicas": 2}, + replica_index=1, + publish=False, + ) + + self.assertIsNone(self.controller.redis.hget("routing_table:endpoints", "Alpha")) + if __name__ == "__main__": unittest.main() diff --git a/tests/support/runtime_fakes.py b/tests/support/runtime_fakes.py index dba4901..96aae18 100644 --- a/tests/support/runtime_fakes.py +++ b/tests/support/runtime_fakes.py @@ -24,6 +24,8 @@ def __init__(self, resonse): local_controler_pb2_grpc = ModuleType("local_controler_pb2_grpc") local_controler_pb2_grpc.LocalControllerStub = type("LocalControllerStub", (), {}) + local_controler_pb2_grpc.LocalControllerServicer = type("LocalControllerServicer", (), {}) + local_controler_pb2_grpc.add_LocalControllerServicer_to_server = lambda servicer, server: None local_controler_pb2_grpc.__file__ = "local_controler_pb2_grpc.py" sys.modules.setdefault("local_controler_pb2_grpc", local_controler_pb2_grpc) sys.modules.setdefault("grpc", ModuleType("grpc")) @@ -167,6 +169,10 @@ def make_global_controller(instances): controller.containers = {} controller._last_status = {} controller._lc_stubs = {} + controller._lifecycle_statuses = {} + controller._idle_since = {} + controller._draining_endpoints = set() + controller._autoscale_policies = {} controller._healthy_calls = [] controller._unhealthy_calls = [] controller._run_cmd_calls = [] @@ -175,6 +181,11 @@ def make_global_controller(instances): list_instances=lambda agent_name=None: list(instances), list_runtime_nodes=lambda agent_specs=None: {}, _user_for_instance=lambda instance: instance.get("user"), + _routing_endpoint_for=lambda instance: ( + f"host.docker.internal:{instance['host_port']}" + if _is_local_host(instance["host"]) and instance.get("provider", "local").lower() == "local" + else instance["endpoint"] + ), _instance_id_from_record=lambda instance: ( f"{instance['provider']}:{instance['agent_name']}:{instance['replica_index']}" ), diff --git a/tests/test_runtime_manager.py b/tests/test_runtime_manager.py index d11ed36..3185a6f 100644 --- a/tests/test_runtime_manager.py +++ b/tests/test_runtime_manager.py @@ -473,6 +473,42 @@ def test_ec2_launch_includes_redis_and_agent_env(self): "10.0.0.30", ) + def test_ec2_instance_can_be_created_and_removed(self): + controller = FakeController() + manager = RuntimeManager(controller, controller.redis) + + manager.ensure_instances( + [ + { + "name": "Destroyable", + "provider": "EC2", + "replicas": 1, + "redis_port": 6380, + } + ] + ) + + self.assertIn( + "EC2:Destroyable:0", + controller.redis.smembers("agent:Destroyable:instances"), + ) + self.assertEqual( + controller.redis.hgetall("agent_instance:EC2:Destroyable:0")["endpoint"], + "10.0.0.30:50051", + ) + + manager.remove_instance("EC2:Destroyable:0") + + self.assertEqual(self.fake_ec2_client.terminate_requests, [["i-test1"]]) + self.assertEqual( + controller.redis.smembers("agent:Destroyable:instances"), + set(), + ) + self.assertEqual( + controller.redis.hgetall("agent_instance:EC2:Destroyable:0"), + {}, + ) + def test_ec2_agent_entrypoint_uses_generic_image_and_bind_mount(self): controller = FakeController() manager = RuntimeManager(controller, controller.redis) diff --git a/ventis/controller/global_controller.py b/ventis/controller/global_controller.py index c5eedf3..835ac75 100644 --- a/ventis/controller/global_controller.py +++ b/ventis/controller/global_controller.py @@ -48,10 +48,15 @@ class GlobalController(object): Designed to be subclassed — override the _on_* hooks to extend behavior. """ - ROUTING_ENDPOINTS_KEY = "routing_table:endpoints" - ROUTING_STATEFUL_KEY = "routing_table:stateful" - SERVICES_SET_KEY = "routing_table:services" - POLICY_RULES_KEY = "policy:rules" + QUEUE_DEPTH_KEY_PREFIX = "queue_depth" + ACTIVE_WORK_KEY_SUFFIX = "active_work" + LIFECYCLE_SIGNAL_SUFFIX = "lifecycle" + + STATUS_INITIALIZING = "Initializing" + STATUS_HEALTHY = "Healthy" + STATUS_IDLING = "Idling" + STATUS_SHUTTING_DOWN = "Shutting down" + STATUS_DELETE_READY = "Delete Ready" def __init__(self, config_path): self.config_path = config_path @@ -75,6 +80,10 @@ def __init__(self, config_path): self._lc_stubs = {} # endpoint -> gRPC stub self._shipped_images = set() # (image, host) already shipped this session self._synced_projects = set() # (host, remote_dir) synced this session + self._autoscale_policies = {} + self._lifecycle_statuses = {} + self._idle_since = {} + self._draining_endpoints = set() self.runtime_manager = RuntimeManager(self) # Clean up any stale containers from previous runs @@ -86,7 +95,7 @@ def __init__(self, config_path): write_agent_specs(self.config_path, self.redis) self._write_resource_specs() self._load_and_write_policies() - self.runtime_manager.publish_routing_snapshot(self.controllers) + self._publish_routing_state() logger.info("Global controller initialized with %d controller(s).", len(self.controllers)) # Start background cleanup thread @@ -100,7 +109,7 @@ def __init__(self, config_path): def _cleanup_stale_containers(self): """Remove only Redis containers from previous runs. - ponytail: agent containers are now reused by RuntimeManager, so startup + Agent containers are now reused by RuntimeManager, so startup cleanup must not delete them preemptively. """ logger.info("Checking for stale Redis containers from previous runs...") @@ -142,7 +151,8 @@ def reload_config(self): self.config = self._load_config(self.config_path) self.controllers = self.config.get("agents", []) self.poll_interval = self.config.get("poll_interval", 5) - self.runtime_manager.publish_routing_snapshot(self.controllers) + self._load_and_write_policies() + self._publish_routing_state() def _write_resource_specs(self): """Write the per-agent resource specs to Redis.""" @@ -155,17 +165,20 @@ def _write_resource_specs(self): "replicas": str(int(ctrl.get("replicas", 1))), }) - def _load_policy_rules(self): - """Load policy rules from config/policy.yaml.""" + def _load_policy_config(self): config_dir = os.path.dirname(os.path.abspath(self.config_path)) policy_path = os.path.join(config_dir, "policy.yaml") if not os.path.isfile(policy_path): logger.info("No policy file found at %s, skipping policy setup.", policy_path) - return [] + return {} with open(policy_path, "r") as f: - policy_config = yaml.safe_load(f) + return yaml.safe_load(f) or {} + + def _load_policy_rules(self, policy_config=None): + """Load access policy rules from config/policy.yaml.""" + policy_config = policy_config if policy_config is not None else self._load_policy_config() rules = policy_config.get("rules", []) @@ -174,16 +187,45 @@ def _load_policy_rules(self): rules.sort(key=lambda r: len(r.get("match", {})), reverse=True) return rules + def _load_autoscale_policies(self, policy_config=None): + """Normalize autoscale policy entries from policy.yaml. + + Services without queue_length_scale_up_threshold and max_replicas are + skipped. min_replicas defaults to the agent's configured replica count. + """ + policy_config = policy_config if policy_config is not None else self._load_policy_config() + autoscale = policy_config.get("autoscale", {}) + base_replicas = { + item["name"]: int(item.get("replicas", 1)) + for item in self.config.get("agents", []) + if isinstance(item, dict) and item.get("name") + } + normalized = {} + for service, cfg in autoscale.items(): + if not isinstance(cfg, dict): + continue + threshold = cfg.get("queue_length_scale_up_threshold") + max_replicas = cfg.get("max_replicas") + if threshold is None or max_replicas is None: + continue + normalized[service] = { + "queue_length_scale_up_threshold": int(threshold), + "max_replicas": int(max_replicas), + "min_replicas": int(cfg.get("min_replicas", base_replicas.get(service, 1))), + "idle_seconds_before_scale_down": int(cfg.get("idle_seconds_before_scale_down", 60)), + } + return normalized + def _load_and_write_policies(self): """Load policy rules and publish them to every host Redis.""" - rules = self._load_policy_rules() + policy_config = self._load_policy_config() + rules = self._load_policy_rules(policy_config) + self._autoscale_policies = self._load_autoscale_policies(policy_config) target_count = self.runtime_manager.publish_policy_rules(rules) logger.info("Policy rules written to %d Redis instance(s): %d rule(s)", target_count, len(rules)) - # Routing reads are direct Redis calls now that RuntimeManager owns publication: - # - self.redis.hgetall(self.ROUTING_ENDPOINTS_KEY) - # - self.redis.hget(self.ROUTING_ENDPOINTS_KEY, service_name) + # Routing reads are direct Redis calls; RuntimeManager owns publication. def get_node_redis(self, host): """Get the RedisClient for a specific node.""" @@ -311,11 +353,21 @@ def _wait_for_healthy(self, timeout=30, interval=2): timeout: Maximum seconds to wait. interval: Seconds between checks. """ - deadline = time.time() + timeout pending = [ (instance["agent_name"], instance["host"], instance["host_port"]) for instance in self.runtime_manager.list_instances() ] + self._wait_for_pending_healthy(pending, timeout=timeout, interval=interval) + self._publish_routing_state() + + def _wait_for_pending_healthy(self, pending, timeout=30, interval=2, require_healthy=False): + """Poll node Redis until listed replicas report healthy. + + When require_healthy is False, logs a warning on timeout and returns. + When True, raises TimeoutError (used during autoscale scale-up rollback). + """ + pending = list(pending) + deadline = time.time() + timeout logger.info("Waiting for %d replica(s) to become healthy (timeout=%ds)...", len(pending), timeout) @@ -341,6 +393,8 @@ def _wait_for_healthy(self, timeout=30, interval=2): "Controller %s (%s:%s) not ready after %ds.", name, host, port, timeout, ) + if require_healthy: + raise TimeoutError(f"Controllers not ready after {timeout}s: {pending}") # ------------------------------------------------------------------ # # Polling loop # @@ -389,6 +443,8 @@ def _poll_controllers(self): self._on_controller_healthy(name, host, port) else: self._on_controller_unhealthy(name, host, port) + self._reconcile_instance_lifecycle() + self._maybe_scale_from_queue_depth() # ------------------------------------------------------------------ # # Extensibility hooks — override in subclasses # @@ -406,6 +462,269 @@ def _on_routing_table_updated(self, table): """Called after the routing table has been written to Redis.""" pass + def _queue_depth_key(self, name, host, port): + return f"{self.QUEUE_DEPTH_KEY_PREFIX}:{name}:{self._agent_host_key(host)}:{port}" + + def _active_work_key(self, host, port): + return f"controller:{self._agent_host_key(host)}:{port}:{self.ACTIVE_WORK_KEY_SUFFIX}" + + def _lifecycle_signal_key(self, host, port): + return f"controller:{self._agent_host_key(host)}:{port}:{self.LIFECYCLE_SIGNAL_SUFFIX}" + + def _routing_endpoint(self, instance): + return self.runtime_manager._routing_endpoint_for(instance) + + def _routing_status_payload(self): + """Per-endpoint lifecycle status for routing_table:status publication.""" + payload = {} + for instance in self.runtime_manager.list_instances(): + service = instance["agent_name"] + endpoint = self._routing_endpoint(instance) + payload.setdefault(service, {})[endpoint] = self._lifecycle_statuses.get( + endpoint, + self.STATUS_INITIALIZING, + ) + return payload + + def _routable_endpoints_payload(self, extra_excluded_endpoints=None): + """Endpoints that should receive new requests. + + Excludes draining endpoints and replicas in Shutting down or Delete Ready. + Status may still list excluded endpoints so clients see the transition. + """ + excluded = set(self._draining_endpoints) + if extra_excluded_endpoints: + excluded.update(extra_excluded_endpoints) + payload = {} + for instance in self.runtime_manager.list_instances(): + endpoint = self._routing_endpoint(instance) + if endpoint in excluded: + continue + status = self._lifecycle_statuses.get(endpoint) + if status in {self.STATUS_SHUTTING_DOWN, self.STATUS_DELETE_READY}: + continue + payload.setdefault(instance["agent_name"], set()).add(endpoint) + return payload + + def _publish_routing_state(self, extra_excluded_endpoints=None): + """Publish routing endpoints and lifecycle status to host Redis.""" + publish_routing_snapshot = getattr(self.runtime_manager, "publish_routing_snapshot", None) + if publish_routing_snapshot: + publish_routing_snapshot( + self.controllers, + lifecycle_statuses=self._routing_status_payload(), + routable_endpoints=self._routable_endpoints_payload(extra_excluded_endpoints), + ) + + def _read_active_work(self, instance): + raw = self._get_node_redis_for(instance["host"]).get( + self._active_work_key(instance["host"], instance["host_port"]) + ) + try: + return max(0, int(raw or 0)) + except (TypeError, ValueError): + return 0 + + def _read_lifecycle_signal(self, instance): + return self._get_node_redis_for(instance["host"]).get( + self._lifecycle_signal_key(instance["host"], instance["host_port"]) + ) + + def _is_instance_healthy(self, instance): + return self._last_status.get((instance["host"], instance["host_port"])) == "healthy" + + def _has_affine_inflight_requests(self, service, endpoint): + for affinity_key in self.redis.scan_keys("affinity:*"): + request_endpoint = self.redis.hget(affinity_key, service) + if request_endpoint != endpoint: + continue + request_id = affinity_key.split(":", 1)[1] + request_status = self.redis.get(f"request:{request_id}:status") + if request_status not in {"done", "error"}: + return True + return False + + def _mark_instance_status(self, instance, status): + self._lifecycle_statuses[self._routing_endpoint(instance)] = status + + def _clear_instance_state(self, instance): + endpoint = self._routing_endpoint(instance) + self._lifecycle_statuses.pop(endpoint, None) + self._idle_since.pop(endpoint, None) + self._draining_endpoints.discard(endpoint) + + def _reconcile_instance_lifecycle(self): + """Drive per-replica lifecycle and scale-down. + + Each poll: mark instances Healthy/Idling from queue depth and active work, + pick one idling replica above min_replicas to drain (remove from routing, + mark Shutting down), then remove instances whose local controller reports + Delete Ready after draining. + """ + now = time.time() + delete_ready_instances = [] + + for controller_spec in self.controllers: + service = controller_spec["name"] + policy = self._autoscale_policies.get(service) + idle_window = int((policy or {}).get("idle_seconds_before_scale_down", 60)) + min_replicas = int((policy or {}).get("min_replicas", int(controller_spec.get("replicas", 1)))) + instances = sorted( + self.runtime_manager.list_instances(service), + key=lambda item: int(item["replica_index"]), + ) + current_replicas = len(instances) + + for instance in instances: + endpoint = self._routing_endpoint(instance) + if not self._is_instance_healthy(instance): + self._idle_since.pop(endpoint, None) + if endpoint not in self._draining_endpoints: + self._lifecycle_statuses[endpoint] = self.STATUS_INITIALIZING + continue + + if endpoint in self._draining_endpoints: + signal = self._read_lifecycle_signal(instance) + if signal == self.STATUS_DELETE_READY: + self._lifecycle_statuses[endpoint] = self.STATUS_DELETE_READY + delete_ready_instances.append((controller_spec, instance)) + else: + self._lifecycle_statuses[endpoint] = self.STATUS_SHUTTING_DOWN + continue + + active_work = self._read_active_work(instance) + queue_depth = self._queue_depth_for_instance(instance) + if queue_depth != 0 or active_work != 0: + self._idle_since.pop(endpoint, None) + self._lifecycle_statuses[endpoint] = self.STATUS_HEALTHY + continue + if controller_spec.get("stateful") and self._has_affine_inflight_requests(service, endpoint): + self._idle_since.pop(endpoint, None) + self._lifecycle_statuses[endpoint] = self.STATUS_HEALTHY + continue + idle_since = self._idle_since.setdefault(endpoint, now) + if (now - idle_since) >= idle_window: + self._lifecycle_statuses[endpoint] = self.STATUS_IDLING + else: + self._lifecycle_statuses[endpoint] = self.STATUS_HEALTHY + + if policy and current_replicas > min_replicas: + candidates = [ + instance + for instance in instances + if self._lifecycle_statuses.get(self._routing_endpoint(instance)) == self.STATUS_IDLING + ] + if candidates: + candidate = max(candidates, key=lambda item: int(item["replica_index"])) + endpoint = self._routing_endpoint(candidate) + self._draining_endpoints.add(endpoint) + self._publish_routing_state(extra_excluded_endpoints={endpoint}) + self._lifecycle_statuses[endpoint] = self.STATUS_SHUTTING_DOWN + + self._publish_routing_state() + + for controller_spec, instance in delete_ready_instances: + self._publish_routing_state() + self.runtime_manager.remove_instance( + self.runtime_manager._instance_id_from_record(instance), + publish=False, + ) + controller_spec["replicas"] = max( + int((self._autoscale_policies.get(controller_spec["name"]) or {}).get("min_replicas", 1)), + int(controller_spec.get("replicas", 1)) - 1, + ) + self._clear_instance_state(instance) + self._publish_routing_state() + + def _queue_depth_for_instance(self, instance): + key = self._queue_depth_key( + instance["agent_name"], + instance["host"], + instance["host_port"], + ) + raw = self._get_node_redis_for(instance["host"]).get(key) + try: + return max(0, int(raw or 0)) + except (TypeError, ValueError): + return 0 + + def _queue_depth_by_service(self): + queue_depths = {} + for instance in self.runtime_manager.list_instances(): + depth = self._queue_depth_for_instance(instance) + queue_depths[instance["agent_name"]] = queue_depths.get(instance["agent_name"], 0) + depth + return queue_depths + + def _maybe_scale_from_queue_depth(self): + """Check each service's aggregate queue depth and scale up if needed.""" + queue_depths = self._queue_depth_by_service() + for controller_spec in self.controllers: + self._maybe_scale_service(controller_spec, queue_depths) + + def _maybe_scale_service(self, controller_spec, queue_depths): + """Scale up one replica when aggregate queue depth exceeds policy threshold. + + Creates the new runtime with publish=False, publishes Initializing status, + waits for health, then publishes Healthy. On failure, removes only the new + replica and restores the previous replica count. + """ + name = controller_spec["name"] + policy = self._autoscale_policies.get(name) + if not policy: + return + + threshold = policy["queue_length_scale_up_threshold"] + max_replicas = policy["max_replicas"] + observed_depth = queue_depths.get(name, 0) + current_replicas = len(self.runtime_manager.list_instances(name)) + + if observed_depth <= threshold or current_replicas >= max_replicas: + return + + target_replicas = min(max_replicas, max(current_replicas, int(controller_spec.get("replicas", 1))) + 1) + if target_replicas <= current_replicas: + return + + logger.info( + "Scaling %s from %d to %d replicas (queue_depth=%d threshold=%d)", + name, + current_replicas, + target_replicas, + observed_depth, + threshold, + ) + + previous_replicas = int(controller_spec.get("replicas", current_replicas or 1)) + controller_spec["replicas"] = target_replicas + replica_index = target_replicas - 1 + rollback_instance_id = f"{controller_spec.get('provider', 'local')}:{name}:{replica_index}" + instance = None + try: + instance = self.runtime_manager.ensure_replica( + controller_spec, + replica_index, + publish=False, + ) + self._mark_instance_status(instance, self.STATUS_INITIALIZING) + self._publish_routing_state() + self._wait_for_pending_healthy( + [(instance["agent_name"], instance["host"], instance["host_port"])], + timeout=30, + interval=2, + require_healthy=True, + ) + self._mark_instance_status(instance, self.STATUS_HEALTHY) + self._publish_routing_state() + except Exception: + self.runtime_manager.remove_instance( + rollback_instance_id, + publish=False, + ) + controller_spec["replicas"] = previous_replicas + if instance is not None: + self._clear_instance_state(instance) + raise + # ------------------------------------------------------------------ # # Cleanup trigger # # ------------------------------------------------------------------ # @@ -677,7 +996,7 @@ def launch_agents(self): """Create or reuse agent containers through RuntimeManager.""" try: self.containers = {} - instances = self.runtime_manager.ensure_instances(self.controllers) + instances = self.runtime_manager.ensure_instances(self.controllers, publish=False) total = len(instances) logger.info( "Ensured %d Docker container(s) across %d service(s).", diff --git a/ventis/controller/local_controller.py b/ventis/controller/local_controller.py index 98e1c2b..48166d7 100644 --- a/ventis/controller/local_controller.py +++ b/ventis/controller/local_controller.py @@ -9,6 +9,7 @@ import sys import time import importlib.util +import threading from concurrent.futures import ThreadPoolExecutor import grpc @@ -36,8 +37,14 @@ logger = logging.getLogger(__name__) ROUTING_ENDPOINTS_KEY = "routing_table:endpoints" +ROUTING_STATUS_KEY = "routing_table:status" ROUTING_STATEFUL_KEY = "routing_table:stateful" POLICY_RULES_KEY = "policy:rules" +QUEUE_DEPTH_KEY_PREFIX = "queue_depth" +ACTIVE_WORK_KEY_SUFFIX = "active_work" +LIFECYCLE_SIGNAL_SUFFIX = "lifecycle" +STATUS_SHUTTING_DOWN = "Shutting down" +STATUS_DELETE_READY = "Delete Ready" class LocalController(object): @@ -63,7 +70,12 @@ def __init__(self, port=50051): redis_port = int(os.environ.get("VENTIS_REDIS_PORT", 6379)) self.redis = RedisClient(host=redis_host, port=redis_port) self._status_key = f"controller:{self.agent_host}:{self.public_port}:status" + self._active_work_key_name = f"controller:{self.agent_host}:{self.public_port}:{ACTIVE_WORK_KEY_SUFFIX}" + self._lifecycle_signal_key = f"controller:{self.agent_host}:{self.public_port}:{LIFECYCLE_SIGNAL_SUFFIX}" self.redis.set(self._status_key, "healthy") + self.redis.set(self._active_work_key_name, "0") + self.redis.delete(self._lifecycle_signal_key) + self._publish_queue_depth() # Cache for gRPC stubs to remote controllers self._remote_channels = {} # endpoint -> grpc.Channel @@ -77,6 +89,9 @@ def __init__(self, port=50051): # that need to be routed through the same controller's request queue. max_instances = int(os.environ.get("VENTIS_MAX_AGENT_INSTANCES", 8)) self._executor = ThreadPoolExecutor(max_workers=max_instances) + self._active_requests = 0 + self._active_requests_lock = threading.Lock() + self._shutting_down = False logger.info("Local controller initialized at %s (max_agent_instances=%d), reported healthy to Redis.", self._my_endpoint, max_instances) @@ -133,6 +148,92 @@ def _load_policy_rules(self): self._policy_rules = [] return self._policy_rules + def _queue_depth_key(self): + return f"{QUEUE_DEPTH_KEY_PREFIX}:{self.agent_name}:{self.agent_host}:{self.public_port}" + + def _current_queue_depth(self): + return self.request_queue.qsize() + getattr(self, "_active_requests", 0) + + def _publish_queue_depth(self): + if not self.agent_name: + return + self.redis.set(self._queue_depth_key(), str(self._current_queue_depth())) + + def _publish_active_work(self): + key = getattr(self, "_active_work_key_name", None) + if not key: + return + self.redis.set(key, str(max(0, getattr(self, "_active_requests", 0)))) + + def _increment_active_requests(self): + lock = getattr(self, "_active_requests_lock", None) + if lock is None: + self._active_requests = getattr(self, "_active_requests", 0) + 1 + self._publish_active_work() + return + with lock: + self._active_requests += 1 + self._publish_active_work() + + def _decrement_active_requests(self): + lock = getattr(self, "_active_requests_lock", None) + if lock is None: + self._active_requests = max(0, getattr(self, "_active_requests", 0) - 1) + self._publish_active_work() + return + with lock: + self._active_requests = max(0, self._active_requests - 1) + self._publish_active_work() + + def _is_shutting_down(self): + """Return True when routing status marks this endpoint as Shutting down.""" + if not self.agent_name: + return False + raw = self.redis.hget(ROUTING_STATUS_KEY, self.agent_name) + if not raw: + return self._shutting_down + try: + statuses = json.loads(raw) + except json.JSONDecodeError: + return self._shutting_down + if statuses.get(self._my_endpoint) == STATUS_SHUTTING_DOWN: + self._shutting_down = True + return self._shutting_down + + def _has_affine_inflight_requests(self): + if not self.agent_name: + return False + for affinity_key in self.redis.scan_keys("affinity:*"): + if self.redis.hget(affinity_key, self.agent_name) != self._my_endpoint: + continue + request_id = affinity_key.split(":", 1)[1] + status = self.redis.get(f"request:{request_id}:status") + if status not in {"done", "error"}: + return True + return False + + def _is_delete_ready(self): + """Return True when local queue and active work are drained.""" + if not self.request_queue.empty(): + return False + if getattr(self, "_active_requests", 0) != 0: + return False + is_stateful = self.redis.hget(ROUTING_STATEFUL_KEY, self.agent_name) == "true" + if is_stateful and self._has_affine_inflight_requests(): + return False + return True + + def _update_lifecycle_signal(self): + """Tell the global controller this replica has finished draining. + + Writes Delete Ready to Redis only when routing status is Shutting down + and local queue, active work, and stateful affinity constraints are clear. + """ + if self._is_shutting_down() and self._is_delete_ready(): + self.redis.set(self._lifecycle_signal_key, STATUS_DELETE_READY) + else: + self.redis.delete(self._lifecycle_signal_key) + def _check_policy(self, service, context): """ Check if the given service is accessible for the given request context. @@ -212,8 +313,12 @@ def run(self): logger.info("Local controller started, polling request queue...") try: while True: + self._publish_queue_depth() + self._publish_active_work() + self._update_lifecycle_signal() if not self.request_queue.empty(): raw = self.request_queue.get() + self._publish_queue_depth() try: data = json.loads(raw) self._process_request(data) @@ -279,6 +384,8 @@ def _process_request(self, data): return if endpoint == self._my_endpoint: + self._increment_active_requests() + self._publish_queue_depth() self._executor.submit(self._execute_locally, service, function, args, future_id, origin, request_id) else: # Register the target as a consumer for any Future args @@ -391,6 +498,9 @@ def _execute_locally(self, service, function, args, future_id, origin=None, requ self.redis.hset(f"future:{future_id}", "result", f"Execution failed: {e}") if origin and origin != self._my_endpoint: self._send_result_callback(origin, future_id, f"Execution failed: {e}") + finally: + self._decrement_active_requests() + self._publish_queue_depth() # ------------------------------------------------------------------ # # Request forwarding # @@ -443,6 +553,12 @@ def stop(self): logger.info("Shutting down local controller...") self._executor.shutdown(wait=True) self.redis.set(self._status_key, "stopped") + if getattr(self, "_active_work_key_name", None): + self.redis.set(self._active_work_key_name, "0") + if getattr(self, "_lifecycle_signal_key", None): + self.redis.delete(self._lifecycle_signal_key) + if self.agent_name: + self.redis.set(self._queue_depth_key(), "0") self.server.stop(0) diff --git a/ventis/controller/runtime_manager.py b/ventis/controller/runtime_manager.py index 7e485f6..a043be6 100644 --- a/ventis/controller/runtime_manager.py +++ b/ventis/controller/runtime_manager.py @@ -51,6 +51,7 @@ class RuntimeManager: """Create, reuse, and publish agent runtimes.""" ROUTING_ENDPOINTS_KEY = "routing_table:endpoints" + ROUTING_STATUS_KEY = "routing_table:status" ROUTING_STATEFUL_KEY = "routing_table:stateful" SERVICES_SET_KEY = "routing_table:services" CONTAINER_PORT = 50051 @@ -64,10 +65,11 @@ def __init__(self, controller, redis_client=None): def redis(self): return self._redis or self.controller.redis - def ensure_instances(self, agent_specs): + def ensure_instances(self, agent_specs, publish=True): instances = [] self._agent_specs = list(agent_specs) - self.publish_routing_snapshot(self._agent_specs) + if publish: + self.publish_routing_snapshot(self._agent_specs) ec2_runtime._set_controller(self.controller) for agent_spec in agent_specs: @@ -82,7 +84,7 @@ def ensure_instances(self, agent_specs): key = self._instance_key(provider, agent_name, replica_index) instance = self.redis.hgetall(key) if instance: - self.remove_instance(instance_id) + self.remove_instance(instance_id, publish=publish) provisioned = ec2_runtime.provision_instance(agent_spec, replica_index) redis_port = int(agent_spec.get("redis_port", provisioned.get("redis_port", 6379))) self.controller.ensure_host_redis( @@ -91,7 +93,8 @@ def ensure_instances(self, agent_specs): redis_port, ssh_host=provisioned.get("ssh_host"), ) - self.publish_routing_snapshot(self._agent_specs) + if publish: + self.publish_routing_snapshot(self._agent_specs) instance = ec2_runtime.bootstrap_instance( provisioned, agent_spec, @@ -102,7 +105,8 @@ def ensure_instances(self, agent_specs): self._write_instance(instance) self._add_instance_to_agent(agent_name, instance_id) self._track_runtime(agent_name, instance["runtime_id"]) - self.publish_routing_snapshot(self._agent_specs) + if publish: + self.publish_routing_snapshot(self._agent_specs) instances.append(instance) continue @@ -117,7 +121,7 @@ def ensure_instances(self, agent_specs): pass else: if instance: - self.remove_instance(instance_id) + self.remove_instance(instance_id, publish=publish) instance = self._create_instance( agent_spec=agent_spec, host=host, @@ -130,11 +134,81 @@ def ensure_instances(self, agent_specs): self._add_instance_to_agent(agent_name, instance_id) self._track_runtime(agent_name, instance["runtime_id"]) - self.publish_routing_snapshot(self._agent_specs) + if publish: + self.publish_routing_snapshot(self._agent_specs) instances.append(instance) return instances + def ensure_replica(self, agent_spec, replica_index, publish=True): + """Create or reuse a single replica slot without re-provisioning siblings. + + Used for autoscale scale-up. Callers typically pass publish=False and + publish routing through GlobalController._publish_routing_state() so + lifecycle status (Initializing/Healthy) is included. + """ + self._agent_specs = list(self._current_agent_specs()) + if publish: + self.publish_routing_snapshot(self._agent_specs) + ec2_runtime._set_controller(self.controller) + + agent_name = agent_spec["name"] + provider = agent_spec.get("provider", "local") + self.controller.containers.setdefault(agent_name, []) + instance_id = self._instance_id(provider, agent_name, replica_index) + key = self._instance_key(provider, agent_name, replica_index) + instance = self.redis.hgetall(key) + + if provider.upper() == "EC2": + ec2_runtime.validate_config() + if instance: + self.remove_instance(instance_id, publish=publish) + provisioned = ec2_runtime.provision_instance(agent_spec, replica_index) + redis_port = int(agent_spec.get("redis_port", provisioned.get("redis_port", 6379))) + self.controller.ensure_host_redis( + provisioned["host"], + provisioned.get("user"), + redis_port, + ssh_host=provisioned.get("ssh_host"), + ) + if publish: + self.publish_routing_snapshot(self._agent_specs) + instance = ec2_runtime.bootstrap_instance( + provisioned, + agent_spec, + replica_index, + redis_host=provisioned["host"], + redis_port=redis_port, + ) + self._write_instance(instance) + self._add_instance_to_agent(agent_name, instance_id) + self._track_runtime(agent_name, instance["runtime_id"]) + if publish: + self.publish_routing_snapshot(self._agent_specs) + return instance + + placement = self._replica_placements(agent_spec)[replica_index] + host = placement["host"] + host_port = placement.get("host_port") + if instance and self._runtime_exists(instance) and self._placement_matches(instance, host, host_port): + return instance + if instance: + self.remove_instance(instance_id, publish=publish) + instance = self._create_instance( + agent_spec=agent_spec, + host=host, + host_port=host_port, + replica_index=replica_index, + instance_id=instance_id, + previous_instance=instance, + ) + self._write_instance(instance) + self._add_instance_to_agent(agent_name, instance_id) + self._track_runtime(agent_name, instance["runtime_id"]) + if publish: + self.publish_routing_snapshot(self._agent_specs) + return instance + def _write_instance(self, instance): key = self._instance_key( instance["provider"], @@ -160,10 +234,7 @@ def _write_instance(self, instance): def _add_instance_to_agent(self, agent_name, instance_id): self.redis.sadd(f"agent:{agent_name}:instances", instance_id) - def _publish_endpoint(self, instance): - self.publish_routing_snapshot(self._current_agent_specs()) - - def remove_instance(self, instance_id): + def remove_instance(self, instance_id, publish=True): key = f"agent_instance:{instance_id}" instance = self.redis.hgetall(key) if not instance: @@ -173,7 +244,8 @@ def remove_instance(self, instance_id): self.redis.delete(key) self.redis.srem(f"agent:{instance['agent_name']}:instances", instance_id) - self.publish_routing_snapshot(self._current_agent_specs()) + if publish: + self.publish_routing_snapshot(self._current_agent_specs()) containers = self.controller.containers.get(instance["agent_name"], []) self.controller.containers[instance["agent_name"]] = [ @@ -397,14 +469,22 @@ def _instance_key(cls, provider, agent_name, replica_index): def sync_routing_metadata(self, agent_specs): self.publish_routing_snapshot(agent_specs) - def publish_routing_snapshot(self, agent_specs): - """Copy routing metadata derived from central records to host Redis.""" + def publish_routing_snapshot(self, agent_specs, lifecycle_statuses=None, routable_endpoints=None): + """Copy routing metadata derived from central records to host Redis. + + lifecycle_statuses maps service -> {endpoint: status} for routing_table:status. + routable_endpoints maps service -> set of endpoints that accept new traffic; + endpoints omitted from routable_endpoints stay in status but are excluded + from routing_table:endpoints (used during scale-down draining). + """ services = {agent_spec["name"] for agent_spec in agent_specs} stateful = { agent_spec["name"] for agent_spec in agent_specs if agent_spec.get("stateful", False) } + lifecycle_statuses = lifecycle_statuses or {} + routable_endpoints = routable_endpoints or {} for redis_client in self._routing_redis_targets(): existing_services = redis_client.smembers(self.SERVICES_SET_KEY) @@ -412,18 +492,23 @@ def publish_routing_snapshot(self, agent_specs): redis_client.srem(self.SERVICES_SET_KEY, stale) self._hdel(redis_client, self.ROUTING_STATEFUL_KEY, stale) self._hdel(redis_client, self.ROUTING_ENDPOINTS_KEY, stale) + self._hdel(redis_client, self.ROUTING_STATUS_KEY, stale) for service in services: redis_client.sadd(self.SERVICES_SET_KEY, service) if service in stateful: redis_client.hset(self.ROUTING_STATEFUL_KEY, service, "true") else: self._hdel(redis_client, self.ROUTING_STATEFUL_KEY, service) + service_instances = sorted( + self.list_instances(service), + key=lambda item: int(item["replica_index"]), + ) endpoints = [ self._routing_endpoint_for(item) - for item in sorted( - self.list_instances(service), - key=lambda item: int(item["replica_index"]), - ) + for item in service_instances + if self._routing_endpoint_for(item) in routable_endpoints.get(service, { + self._routing_endpoint_for(instance) for instance in service_instances + }) ] if endpoints: redis_client.hset( @@ -433,6 +518,20 @@ def publish_routing_snapshot(self, agent_specs): ) else: self._hdel(redis_client, self.ROUTING_ENDPOINTS_KEY, service) + statuses = lifecycle_statuses.get(service) + if statuses is None: + statuses = { + self._routing_endpoint_for(item): "Healthy" + for item in service_instances + } + if statuses: + redis_client.hset( + self.ROUTING_STATUS_KEY, + service, + json.dumps(statuses), + ) + else: + self._hdel(redis_client, self.ROUTING_STATUS_KEY, service) def publish_policy_rules(self, rules): rules_json = json.dumps(rules) diff --git a/ventis/templates/config/policy.yaml b/ventis/templates/config/policy.yaml index 7367437..348b0c9 100644 --- a/ventis/templates/config/policy.yaml +++ b/ventis/templates/config/policy.yaml @@ -14,3 +14,10 @@ rules: - ExampleAgent - VllmAgent - Workflow + +autoscale: + ExampleAgent: + queue_length_scale_up_threshold: 10 + min_replicas: 1 + idle_seconds_before_scale_down: 60 + max_replicas: 3