Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,18 @@ class NodeServerConfig:
# from the field name).
rate_limit_per_node: int = 100

# Number of times to retry an exec/read/write call when the
# target node is not connected, the server is unreachable, or
# the server answers with a 5xx status. Retries use exponential
# backoff starting at retry_backoff_seconds (2s, 4s, 8s... with
# the default), capped at 30s. Set to 0 to disable retries.
# Env var: ``HERMES_NODES_MAX_RETRIES``.
max_retries: int = 3

# Base backoff between retries, in seconds. The actual delay
# before retry N (counting from 0) is
# min(retry_backoff_seconds * 2^N, 30).
# Env var: ``HERMES_NODES_RETRY_BACKOFF_SECONDS``.
retry_backoff_seconds: float = 2.0

def __post_init__(self) -> None:
# TLS partial-config is the most common deployment footgun: an
# operator sets tls_cert_path but forgets tls_key_path (or vice
Expand Down Expand Up @@ -514,6 +526,60 @@ def _build(
f"got {rate_limit_raw!r}"
) from exc

# -- max retries (int) ---------------------------------------------------
max_retries_raw: Any = None
max_retries_source = "default"
if env.get(_env_name("max_retries")) is not None:
max_retries_raw = env[_env_name("max_retries")]
max_retries_source = "env"
elif _read_file_value(file_data, "max_retries") is not None:
max_retries_raw = _read_file_value(file_data, "max_retries")
max_retries_source = "file"
max_retries_val: int | None = None
if max_retries_raw is not None:
if isinstance(max_retries_raw, str) and max_retries_raw.strip() == "":
max_retries_val = None
else:
try:
max_retries_val = int(max_retries_raw)
except (TypeError, ValueError) as exc:
raise ConfigError(
f"{max_retries_source}: max_retries must be an integer, "
f"got {max_retries_raw!r}"
) from exc
if max_retries_val < 0:
raise ConfigError(
f"{max_retries_source}: max_retries must be >= 0, "
f"got {max_retries_val}"
)

# -- retry backoff seconds (float) --------------------------------------
retry_backoff_raw: Any = None
retry_backoff_source = "default"
if env.get(_env_name("retry_backoff_seconds")) is not None:
retry_backoff_raw = env[_env_name("retry_backoff_seconds")]
retry_backoff_source = "env"
elif _read_file_value(file_data, "retry_backoff_seconds") is not None:
retry_backoff_raw = _read_file_value(file_data, "retry_backoff_seconds")
retry_backoff_source = "file"
retry_backoff_val: float | None = None
if retry_backoff_raw is not None:
if isinstance(retry_backoff_raw, str) and retry_backoff_raw.strip() == "":
retry_backoff_val = None
else:
try:
retry_backoff_val = float(retry_backoff_raw)
except (TypeError, ValueError) as exc:
raise ConfigError(
f"{retry_backoff_source}: retry_backoff_seconds must be a number, "
f"got {retry_backoff_raw!r}"
) from exc
if retry_backoff_val <= 0:
raise ConfigError(
f"{retry_backoff_source}: retry_backoff_seconds must be > 0, "
f"got {retry_backoff_val}"
)

# Now assemble. We use a partial dict + NodeServerConfig defaults for
# any key we didn't resolve — dataclass handles the "default" leg of
# the precedence chain.
Expand Down Expand Up @@ -543,6 +609,10 @@ def _build(
resolved["heartbeat_sweep_interval_seconds"] = sweep_interval
if rate_limit is not None:
resolved["rate_limit_per_node"] = rate_limit
if max_retries_val is not None:
resolved["max_retries"] = max_retries_val
if retry_backoff_val is not None:
resolved["retry_backoff_seconds"] = retry_backoff_val

return NodeServerConfig(**resolved)

Expand Down
145 changes: 123 additions & 22 deletions tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def handler(args, **kw) -> str:

import json
import logging
import time

from typing import TYPE_CHECKING, Any

Expand All @@ -30,6 +31,99 @@ def handler(args, **kw) -> str:
from .registry import NodeConnection, NodeRegistry


# ---------------------------------------------------------------------------
# Retry helper
# ---------------------------------------------------------------------------


def _retry_config() -> tuple[int, float]:
    """Fetch the retry settings from the currently loaded configuration.

    Returns:
        A ``(max_retries, retry_backoff_seconds)`` pair taken from the
        object returned by ``load_config()``.
    """
    # Resolved at call time rather than at module import.
    from .config import load_config

    settings = load_config()
    retries = settings.max_retries
    base_delay = settings.retry_backoff_seconds
    return retries, base_delay


def _should_retry(status_code: int, reason: str = "") -> bool:
"""Return True if the response suggests a transient failure worth retrying.

Retries on:
- Server errors (5xx)
- Node not connected (any status where reason mentions "not connected")
"""
if status_code >= 500:
return True
if "not connected" in reason.lower():
return True
return False


def _retry_delay(backoff: float, attempt: int, max_backoff: float) -> float:
    """Return the sleep before retry *attempt* (0-based), within [0, max_backoff]."""
    try:
        raw = backoff * (2 ** attempt)
    except OverflowError:
        # 2**attempt no longer fits in a float; the cap applies anyway.
        return max_backoff
    # max() first so a NaN or negative product becomes 0.0 instead of
    # reaching time.sleep(), which raises ValueError for both.
    return min(max(0.0, raw), max_backoff)


def _request_with_retry(
    method: str,
    url: str,
    *,
    json_body: dict[str, Any] | None = None,
    timeout: float = 30.0,
    max_backoff: float = 30.0,
) -> dict[str, Any]:
    """Make an HTTP request, retrying transient failures with backoff.

    The request is attempted up to ``max_retries + 1`` times, where
    ``max_retries`` and the base backoff come from ``_retry_config()``.
    An attempt is retried when it raises (connection failure, timeout,
    non-JSON body, ...) or when ``_should_retry`` flags the response.
    Before retry N (0-based) the call sleeps
    ``min(backoff * 2^N, max_backoff)`` seconds.

    NOTE(review): the request is re-sent after *any* exception, including
    read timeouts where the server may already have acted on it, so a
    non-idempotent POST can be executed more than once.

    Args:
        method: ``"POST"`` (case-insensitive) sends ``json_body``; any
            other value issues a GET without a body.
        url: Absolute URL to call.
        json_body: JSON payload for POST requests.
        timeout: Per-attempt timeout in seconds passed to ``httpx.Client``.
        max_backoff: Upper bound, in seconds, for a single backoff sleep.

    Returns:
        The decoded JSON body of the first response that is not flagged
        for retry. Once attempts are exhausted: the last retryable body
        if it is non-empty, otherwise an ``{"error": ...}`` dict
        describing the last exception or HTTP status. Never raises for
        request failures.
    """
    import httpx

    max_retries, backoff = _retry_config()
    total_attempts = max_retries + 1
    use_post = method.upper() == "POST"
    last_error: Exception | None = None
    last_result: Any = None
    last_status: int | None = None

    for attempt in range(total_attempts):
        retries_left = attempt < max_retries
        # Only the outcome of the most recent attempt is reported.
        last_error = None
        last_result = None
        last_status = None
        try:
            with httpx.Client(timeout=timeout) as client:
                if use_post:
                    resp = client.post(url, json=json_body)
                else:
                    resp = client.get(url)
                result = resp.json()
        except Exception as e:
            # Deliberately broad: callers rely on an error dict rather
            # than an exception, whatever went wrong with the request.
            last_error = e
            if not retries_left:
                break
            delay = _retry_delay(backoff, attempt, max_backoff)
            logger.warning(
                "Request to %s failed (attempt %d/%d): %s. Retrying in %.1fs...",
                url, attempt + 1, total_attempts, e, delay,
            )
            time.sleep(delay)
            continue

        # Check if the response indicates a transient error. The body is
        # untyped JSON, so the reason is coerced to text before matching.
        reason = ""
        if isinstance(result, dict):
            raw_reason = result.get("reason") or result.get("error") or ""
            reason = raw_reason if isinstance(raw_reason, str) else str(raw_reason)
        if not _should_retry(resp.status_code, reason):
            return result  # success or non-retryable error

        last_result = result
        last_status = resp.status_code
        if not retries_left:
            break
        delay = _retry_delay(backoff, attempt, max_backoff)
        logger.warning(
            "Request to %s returned %d (attempt %d/%d): %s. Retrying in %.1fs...",
            url, resp.status_code, attempt + 1, total_attempts, reason, delay,
        )
        time.sleep(delay)

    # All retries exhausted.
    if last_error is not None:
        return {"error": f"Request failed after {total_attempts} attempts: {last_error}"}
    if last_result:
        return last_result
    if last_status is not None:
        # Retryable response with an empty body: keep the status visible.
        return {
            "error": f"Request failed after {total_attempts} attempts: HTTP {last_status}"
        }
    return {"error": "Request failed: unknown error"}


# ---------------------------------------------------------------------------
# Tool implementations (called by the wrapper handlers below)
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -65,8 +159,6 @@ def _node_exec_impl(
)

try:
import httpx

url = f"http://{cfg.connect_host}:{cfg.port}/nodes/{target}/exec"
payload: dict[str, Any] = {"command": command}
if cwd:
Expand All @@ -76,10 +168,9 @@ def _node_exec_impl(
if timeout_ms is not None:
payload["timeout_ms"] = timeout_ms

with httpx.Client(timeout=timeout_s + 5.0) as client:
response = client.post(url, json=payload)
response.raise_for_status()
result = response.json()
result = _request_with_retry(
"POST", url, json_body=payload, timeout=timeout_s + 5.0,
)

# Normalise to {"output": str, "returncode": int}
if result.get("status") == "ok":
Expand All @@ -93,9 +184,14 @@ def _node_exec_impl(
"returncode": exec_result.get("returncode", 0),
})
else:
# timeout or other non-error status
# timeout, node not connected, or other non-ok status
error_msg = (
result.get("reason")
or result.get("error")
or "unknown error"
)
return json.dumps({
"output": result.get("reason", ""),
"output": error_msg,
"returncode": 1,
})
except Exception as e:
Expand Down Expand Up @@ -125,14 +221,11 @@ def _node_read_impl(
)

try:
import httpx

cfg = load_config()
url = f"http://{cfg.connect_host}:{cfg.port}/nodes/{target}/read"
with httpx.Client(timeout=timeout_s + 5.0) as client:
response = client.post(url, json={"path": path})
response.raise_for_status()
result = response.json()
result = _request_with_retry(
"POST", url, json_body={"path": path}, timeout=timeout_s + 5.0,
)

if result.get("status") == "ok":
read_result = result.get("read_result", {})
Expand All @@ -143,8 +236,13 @@ def _node_read_impl(
"encoding": "utf-8",
})
else:
error_msg = (
result.get("reason")
or result.get("error")
or "read failed"
)
return json.dumps({
"error": result.get("reason", "read failed"),
"error": error_msg,
"code": result.get("code", 0),
})
except Exception as e:
Expand Down Expand Up @@ -185,23 +283,26 @@ def _node_write_impl(
})

try:
import httpx

cfg = load_config()
url = f"http://{cfg.connect_host}:{cfg.port}/nodes/{target}/write"
with httpx.Client(timeout=timeout_s + 5.0) as client:
response = client.post(url, json={"path": path, "content": content, "mode": mode})
response.raise_for_status()
result = response.json()
result = _request_with_retry(
"POST", url, json_body={"path": path, "content": content, "mode": mode},
timeout=timeout_s + 5.0,
)

if result.get("status") == "ok":
write_result = result.get("write_result", {})
return json.dumps({
"bytes_written": write_result.get("bytes_written", 0),
})
else:
error_msg = (
result.get("reason")
or result.get("error")
or "write failed"
)
return json.dumps({
"error": result.get("reason", "write failed"),
"error": error_msg,
"code": result.get("code", 0),
})
except Exception as e:
Expand Down
Loading