Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/forge/orchestrator/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1132,6 +1132,7 @@ async def _handle_resume_event(
updated_state["is_paused"] = False
updated_state["is_blocked"] = False
updated_state["last_error"] = None
updated_state["auto_retry_cap_notified"] = False
updated_state["revision_requested"] = True
updated_state["feedback_comment"] = "Regeneration requested via retry."
updated_state["retry_count"] = 0
Expand All @@ -1147,6 +1148,7 @@ async def _handle_resume_event(
updated_state["is_paused"] = False
updated_state["is_blocked"] = False
updated_state["last_error"] = None
updated_state["auto_retry_cap_notified"] = False
updated_state["revision_requested"] = False
updated_state["feedback_comment"] = None
updated_state["retry_count"] = 0
Expand Down Expand Up @@ -1216,11 +1218,23 @@ async def _handle_resume_event(
reason = (
"terminal state" if is_terminal else f"retry cap ({MAX_AUTO_RETRIES}) reached"
)
if cap_reached and current_state.get("auto_retry_cap_notified"):
logger.info(
f"Workflow for {message.ticket_key} is already blocked after "
f"auto-retry cap at '{current_node}'"
)
return current_state

logger.warning(
f"Workflow for {message.ticket_key} at '{current_node}' requires "
f"forge:retry ({reason})"
)
await self._post_terminal_error_comment(message.ticket_key, last_error)
if cap_reached:
updated_state["is_paused"] = True
updated_state["is_blocked"] = True
updated_state["auto_retry_cap_notified"] = True
return updated_state
return current_state
else:
# Transient failure — auto-resume and let the node retry
Expand Down
74 changes: 59 additions & 15 deletions src/forge/workflow/nodes/workspace_setup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Workspace setup node for LangGraph workflow."""

import logging
import shutil
from pathlib import Path
from typing import Any

Expand All @@ -20,6 +21,41 @@
logger = logging.getLogger(__name__)


def _recreate_workspace_from_fork(
    *,
    ticket_key: str,
    current_repo: str,
    branch_name: str,
    fork_owner: str,
    fork_repo: str,
    stale_workspace_path: str | None = None,
) -> tuple[str, GitOperations]:
    """Create a fresh workspace for a ticket and check out its fork branch.

    If ``stale_workspace_path`` is given and something exists there, it is
    removed first. A new workspace is then created under the configured
    ``workspace_base_dir``, the repository is cloned, the fork remote is
    added, and ``branch_name`` is checked out from the ``fork`` remote.

    Args:
        ticket_key: Ticket identifier; used for the workspace and in log and
            error messages.
        current_repo: Repository name passed to ``WorkspaceManager.create_workspace``.
        branch_name: Branch to check out from the ``fork`` remote.
        fork_owner: Owner of the fork, passed to ``GitOperations.add_fork_remote``.
        fork_repo: Repository name of the fork, passed to
            ``GitOperations.add_fork_remote``.
        stale_workspace_path: Optional path of an existing workspace to delete
            before the new one is created.

    Returns:
        A ``(workspace_path, git)`` tuple: the new workspace path as a string
        and the ``GitOperations`` instance bound to it.

    Raises:
        ValueError: If ``branch_name``, ``current_repo``, ``fork_owner`` or
            ``fork_repo`` is empty.
    """
    if not branch_name or not current_repo or not fork_owner or not fork_repo:
        raise ValueError(
            f"Cannot recreate workspace for {ticket_key}: "
            "missing branch_name, current_repo, fork_owner, or fork_repo in state"
        )

    if stale_workspace_path:
        stale_path = Path(stale_workspace_path)
        # exists() follows symlinks, so a dangling link would look absent;
        # is_symlink() makes sure it is still cleaned up.
        if stale_path.is_symlink() or stale_path.exists():
            logger.warning(
                "Removing existing workspace for %s before recreating from fork: %s",
                ticket_key,
                stale_path,
            )
            if stale_path.is_dir() and not stale_path.is_symlink():
                shutil.rmtree(stale_path)
            else:
                # shutil.rmtree refuses symlinks and plain files; remove the
                # entry itself instead of aborting the recovery.
                stale_path.unlink()

    manager = WorkspaceManager(base_dir=get_settings().workspace_base_dir)
    workspace_obj = manager.create_workspace(repo_name=current_repo, ticket_key=ticket_key)
    git = GitOperations(workspace_obj)
    git.clone()
    git.add_fork_remote(fork_owner, fork_repo)
    git.checkout_branch(branch_name, remote="fork")
    # Lazy %-args, matching the warning above; the rendered message is unchanged.
    logger.info("Workspace recreated at %s for %s", workspace_obj.path, ticket_key)
    return str(workspace_obj.path), git


def prepare_workspace(
state: WorkflowState,
remote: str = "fork",
Expand Down Expand Up @@ -61,24 +97,32 @@ def prepare_workspace(
ticket_key=ticket_key,
)
git = GitOperations(workspace)
git.pull_rebase(remote=remote)
try:
git.pull_rebase(remote=remote)
except Exception as e:
logger.warning(
"Workspace sync failed for %s; recreating workspace from fork: %s",
ticket_key,
e,
)
return _recreate_workspace_from_fork(
ticket_key=ticket_key,
current_repo=current_repo,
branch_name=branch_name,
fork_owner=fork_owner,
fork_repo=fork_repo,
stale_workspace_path=workspace_path,
)
return workspace_path, git

# Workspace is missing — recreate from fork branch.
if not branch_name or not current_repo or not fork_owner or not fork_repo:
raise ValueError(
f"Cannot recreate workspace for {ticket_key}: "
"missing branch_name, current_repo, fork_owner, or fork_repo in state"
)

manager = WorkspaceManager(base_dir=get_settings().workspace_base_dir)
workspace_obj = manager.create_workspace(repo_name=current_repo, ticket_key=ticket_key)
git = GitOperations(workspace_obj)
git.clone()
git.add_fork_remote(fork_owner, fork_repo)
git.checkout_branch(branch_name, remote="fork")
logger.info(f"Workspace recreated at {workspace_obj.path} for {ticket_key}")
return str(workspace_obj.path), git
return _recreate_workspace_from_fork(
ticket_key=ticket_key,
current_repo=current_repo,
branch_name=branch_name,
fork_owner=fork_owner,
fork_repo=fork_repo,
)


# Global workspace manager instance
Expand Down
28 changes: 28 additions & 0 deletions tests/unit/orchestrator/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,34 @@ async def test_prd_label_change_to_approved_sets_approved_flag(
assert result["revision_requested"] is False
assert result["is_paused"] is False

@pytest.mark.asyncio
async def test_auto_retry_cap_marks_workflow_blocked_once(
    self,
    worker: OrchestratorWorker,
    base_message: QueueMessage,
    base_state: dict,
):
    """A capped, errored workflow is paused, blocked, and reported exactly once.

    The state carries ``retry_count == 3`` with a ``last_error`` set
    (presumably 3 is the auto-retry cap; confirm against MAX_AUTO_RETRIES).
    """
    error_text = "cannot rebase dirty workspace"
    overrides = {
        "current_node": "implement_review",
        "is_paused": False,
        "last_error": error_text,
        "retry_count": 3,
        "is_blocked": False,
    }
    capped_state = {**base_state, **overrides}

    # Stub the comment poster so the test can count notifications.
    with patch.object(
        worker, "_post_terminal_error_comment", new_callable=AsyncMock
    ) as post_comment:
        result = await worker._handle_resume_event(base_message, capped_state)

    # Position, retry count, and error are carried through untouched.
    assert result["current_node"] == "implement_review"
    assert result["retry_count"] == 3
    assert result["last_error"] == error_text

    # The workflow is parked and flagged as already notified.
    assert result["is_paused"] is True
    assert result["is_blocked"] is True
    assert result["auto_retry_cap_notified"] is True

    post_comment.assert_awaited_once_with("TEST-123", error_text)

@pytest.mark.asyncio
async def test_question_with_leading_whitespace(
self, worker: OrchestratorWorker, base_message: QueueMessage, base_state: dict
Expand Down
44 changes: 43 additions & 1 deletion tests/unit/workflow/nodes/test_workspace_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from forge.models.workflow import ForgeLabel
from forge.workflow.feature.state import create_initial_feature_state
from forge.workflow.nodes.workspace_setup import setup_workspace
from forge.workflow.nodes.workspace_setup import prepare_workspace, setup_workspace


def create_mock_jira_client():
Expand Down Expand Up @@ -239,3 +239,45 @@ async def test_workspace_setup_continues_on_jira_failure(self, caplog):
# Verify workspace setup continued successfully
assert result["workspace_path"] == str(Path("/tmp/test-workspace"))
mock_jira.close.assert_called_once()


class TestPrepareWorkspaceRecovery:
    """Covers how prepare_workspace recovers when an existing workspace cannot sync."""

    def test_sync_failure_recreates_workspace_from_fork(self, tmp_path):
        """When pull_rebase raises, the old directory is wiped and re-cloned from the fork."""
        # An on-disk workspace containing a leftover file that must not survive.
        workspace_dir = tmp_path / "forge-TEST-123-org-repo"
        workspace_dir.mkdir()
        leftover = workspace_dir / "stale.txt"
        leftover.write_text("dirty")

        state = create_initial_feature_state(
            ticket_key="TEST-123",
            current_repo="org/repo",
            workspace_path=str(workspace_dir),
            fork_owner="forge-bot",
            fork_repo="repo",
            context={"branch_name": "forge/test-123"},
        )

        # First GitOperations instance fails to sync; the second is the fresh clone.
        failing_git = MagicMock()
        failing_git.pull_rebase.side_effect = RuntimeError("any workspace sync failure")
        fresh_git = MagicMock()
        fake_settings = MagicMock(workspace_base_dir=str(tmp_path))

        settings_patch = patch(
            "forge.workflow.nodes.workspace_setup.get_settings",
            return_value=fake_settings,
        )
        git_ops_patch = patch(
            "forge.workflow.nodes.workspace_setup.GitOperations",
            side_effect=[failing_git, fresh_git],
        )
        with settings_patch, git_ops_patch:
            returned_path, returned_git = prepare_workspace(state)

        assert returned_path == str(workspace_dir)
        assert returned_git is fresh_git
        assert not leftover.exists()

        failing_git.pull_rebase.assert_called_once_with(remote="fork")
        fresh_git.clone.assert_called_once()
        fresh_git.add_fork_remote.assert_called_once_with("forge-bot", "repo")
        fresh_git.checkout_branch.assert_called_once_with("forge/test-123", remote="fork")
Loading