diff --git a/.claude-plugin/plugin.json b/.claude-plugin/plugin.json index e965c14..baaf577 100644 --- a/.claude-plugin/plugin.json +++ b/.claude-plugin/plugin.json @@ -1,6 +1,6 @@ { "name": "autocode", - "version": "0.6.0", + "version": "0.7.0", "description": "Claude Code plugin for competitive programming problem-setting workflows.", "author": { "name": "SummerOneTwo", diff --git a/CHANGELOG.md b/CHANGELOG.md index 968dbf0..cbb5906 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,27 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.7.0] - 2026-04-27 + +### Features + +- **source_path 直接编译**: 当使用 `source_path` 参数时,直接从原始文件编译,不再覆盖到标准位置。标准位置仍保留副本以供其他工具使用。所有构建工具返回 `canonical_path`(标准位置副本)和 `source_path`(实际编译源)。 +- **resolve_source() 公共函数**: 提取 5 个构建工具中的源码解析逻辑到 `mixins.py` 的 `resolve_source()` 函数和 `ResolvedSource` 数据类,消除约 100 行重复代码。 +- **name 参数**: `solution_build` 和 `solution_run` 新增 `name` 参数,支持自定义文件名(如 `name="brute_force"` 替代默认 `brute`)。 +- **sol_name / brute_name**: `stress_test_run` 新增 `sol_name` 和 `brute_name` 参数,支持查找自定义命名的解法二进制文件。 +- **output_dir 参数**: `problem_generate_tests` 新增 `output_dir` 参数,可指定测试数据输出目录(默认 `problem_dir/tests`)。 +- **extra_args 参数**: `stress_test_run`、`generator_run`、`problem_generate_tests` 的 `test_configs` 新增 `extra_args` 参数,支持传递自定义命令行参数给 generator。协议扩展为 `gen.exe [extra_args...]`。 +- **types 参数**: `stress_test_run` 新增 `types` 参数,支持在对拍中循环使用多种生成策略(如 `["1","2","3","4"]`)。 +- **problem_verify_tests 工具**: 新增测试数据验证工具,检查文件配对、答案一致性(重新运行 sol)、validator 验证、无空文件等。 +- **stress_test_run 统计信息**: 对拍通过/失败时返回详细统计,包括 sol/brute 运行时间分布、N 值分布、最慢轮次等。 +- **构建结果透明度**: 所有构建工具返回 `binary_size` 和 `canonical_path`,`source_path` 返回实际编译源文件路径。 + +### Improvements + +- **smart mode 文档**: `problem_generate_tests` 的 `constraints` 参数说明更明确,返回 `effective_test_configs` 展示实际使用的配置。 +- **workflow_guard 自定义命名**: `infer_state()` 支持自定义解法文件名(前缀匹配),新增 `tests_verified` 状态字段。 +- **工作流步骤更新**: 新增 `problem_verify_tests(passed)` 步骤,位于 `problem_generate_tests` 和 `problem_pack_polygon` 之间。 + ## [0.6.0] - 2026-04-25 ### Features diff --git a/CLAUDE.md b/CLAUDE.md index 279bd40..58c8d7f 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -66,6 +66,7 @@ AutoCode/ | stress_test_run | 压力测试 | | problem_create | 初始化题目 | | problem_generate_tests | 生成测试数据 | +| problem_verify_tests | 验证测试数据质量 | | problem_validate | 验证题面样例 | | problem_pack_polygon | 打包为 Polygon 格式 | @@ -102,7 +103,8 @@ AutoCode/ 6. 运行压力测试 (`stress_test_run`, completed_rounds == total_rounds) 7. 按需构建检查器 (`checker_build`, accuracy >= 0.9) 8. 生成测试数据 (`problem_generate_tests`, generated_test_count > 0) -9. 打包 Polygon (`problem_pack_polygon`) +9. 验证测试数据 (`problem_verify_tests`, passed) +10. 打包 Polygon (`problem_pack_polygon`) 该顺序会被 [hooks/hooks.json](/c:/userProgram/program/AutoCode/hooks/hooks.json) 和 [scripts/workflow_guard.py](/c:/userProgram/program/AutoCode/scripts/workflow_guard.py) 实际强制执行。 diff --git a/pyproject.toml b/pyproject.toml index 77cce83..2efc2dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "autocode-mcp" -version = "0.6.0" +version = "0.7.0" description = "MCP Server for competitive programming problem creation, based on AutoCode paper" readme = "README.md" requires-python = ">=3.10" diff --git a/scripts/workflow_guard.py b/scripts/workflow_guard.py index 0780e40..763b2cf 100644 --- a/scripts/workflow_guard.py +++ b/scripts/workflow_guard.py @@ -36,11 +36,12 @@ def state_file(problem_dir: str) -> Path: def infer_state(problem_dir: str) -> dict[str, Any]: root = Path(problem_dir) + solutions_dir = root / "solutions" return { "problem_dir": str(root), "created": root.exists() and (root / "files").exists() and (root / "solutions").exists(), - "sol_built": (root / "solutions" / "sol.cpp").exists() or any(root.glob("solutions/sol.*")), - "brute_built": (root / "solutions" / "brute.cpp").exists() or any(root.glob("solutions/brute.*")), + "sol_built": _has_solution(solutions_dir, "sol"), + "brute_built": _has_solution(solutions_dir, "brute"), "validator_ready": (root / "files" / "val.cpp").exists() or any(root.glob("files/val.*")), "validator_accuracy": None, "generator_built": (root / "files" / "gen.cpp").exists() or any(root.glob("files/gen.*")), @@ -54,10 +55,25 @@ def infer_state(problem_dir: str) -> dict[str, Any]: "validation_passed": False, "tests_generated": any((root / "tests").glob("*.in")) if (root / "tests").exists() else False, "generated_test_count": len(list((root / "tests").glob("*.in"))) if (root / "tests").exists() else 0, + "tests_verified": False, "packaged": (root / "problem.xml").exists(), } +def _has_solution(solutions_dir: Path, prefix: str) -> bool: + """检查 solutions/ 下是否有指定前缀的解法文件(支持自定义命名)。""" + if not solutions_dir.exists(): + return False + # 精确匹配(如 sol.cpp, brute.cpp) + if (solutions_dir / f"{prefix}.cpp").exists(): + return True + # 前缀匹配(如 brute_force.cpp) + for f in solutions_dir.iterdir(): + if f.is_file() and f.stem.startswith(prefix) and f.suffix == ".cpp": + return True + return False + + def load_state(problem_dir: str) -> dict[str, Any]: path = state_file(problem_dir) if path.exists(): @@ -120,7 +136,7 @@ def pre_tool(payload: dict[str, Any]) -> int: "checker_build": "必须先通过 stress_test_run(completed_rounds == total_rounds),再构建 checker。", "problem_validate": "必须先通过 stress_test_run(completed_rounds == total_rounds),再进行验证。", "problem_generate_tests": "必须先通过 problem_validate(验证通过),才能生成最终测试数据。", - "problem_pack_polygon": "必须先生成最终测试数据,并且生成数量 > 0,再进行打包。", + "problem_pack_polygon": "必须先生成最终测试数据并通过 problem_verify_tests(passed),再进行打包。", } tool_input = payload.get("tool_input", {}) @@ -169,6 +185,7 @@ def pre_tool(payload: dict[str, Any]) -> int: if short_name == "problem_pack_polygon" and not ( state["tests_generated"] and state.get("generated_test_count", 0) > 0 + and state.get("tests_verified", False) ): deny(reasons["problem_pack_polygon"]) return 0 @@ -194,6 +211,12 @@ def post_tool(payload: dict[str, Any]) -> int: save_state(problem_dir, state) return 0 + if short_name == "problem_verify_tests" and not success: + state = load_state(problem_dir) + state["tests_verified"] = False + save_state(problem_dir, state) + return 0 + if not success: return 0 @@ -229,6 +252,9 @@ def post_tool(payload: dict[str, Any]) -> int: generated_tests = data.get("generated_tests", []) state["tests_generated"] = bool(generated_tests) state["generated_test_count"] = len(generated_tests) + state["tests_verified"] = False + elif short_name == "problem_verify_tests": + state["tests_verified"] = bool(data.get("passed", False)) elif short_name == "problem_pack_polygon": state["packaged"] = True @@ -244,7 +270,8 @@ def session_start() -> int: "stress_test_run(completed_rounds == total_rounds) -> " "checker_build if needed (accuracy >= 0.9) -> " "problem_validate(validation_passed) -> " - "problem_generate_tests(generated_test_count > 0) -> problem_pack_polygon. " + "problem_generate_tests(generated_test_count > 0) -> " + "problem_verify_tests(passed) -> problem_pack_polygon. " "If a hook blocks a step, complete the missing prerequisite instead of retrying blindly." ) print( diff --git a/src/autocode_mcp/__init__.py b/src/autocode_mcp/__init__.py index 63594bd..2988cc1 100644 --- a/src/autocode_mcp/__init__.py +++ b/src/autocode_mcp/__init__.py @@ -6,7 +6,7 @@ """ import os -__version__ = "0.6.0" +__version__ = "0.7.0" # 获取 templates 目录路径(包内目录) _PACKAGE_DIR = os.path.dirname(__file__) diff --git a/src/autocode_mcp/server.py b/src/autocode_mcp/server.py index 577fb4f..6cadd2d 100644 --- a/src/autocode_mcp/server.py +++ b/src/autocode_mcp/server.py @@ -1,7 +1,7 @@ """ MCP Server 入口。 -提供 15 个原子工具,基于 AutoCode 论文框架。 +提供 17 个原子工具,基于 AutoCode 论文框架。 """ from __future__ import annotations @@ -35,6 +35,7 @@ from .tools.problem import ProblemCreateTool, ProblemGenerateTestsTool, ProblemPackPolygonTool from .tools.solution import SolutionBuildTool, SolutionRunTool from .tools.stress_test import StressTestRunTool +from .tools.test_verify import ProblemVerifyTestsTool from .tools.validation import ProblemValidateTool from .tools.validator import ValidatorBuildTool, ValidatorSelectTool @@ -67,6 +68,7 @@ def register_all_tools() -> None: # Problem 工具组 register_tool(ProblemCreateTool()) register_tool(ProblemGenerateTestsTool()) + register_tool(ProblemVerifyTestsTool()) register_tool(ProblemPackPolygonTool()) register_tool(ProblemValidateTool()) diff --git a/src/autocode_mcp/tools/checker.py b/src/autocode_mcp/tools/checker.py index 4d13fb8..df13826 100644 --- a/src/autocode_mcp/tools/checker.py +++ b/src/autocode_mcp/tools/checker.py @@ -11,7 +11,7 @@ from ..utils.compiler import run_binary_with_args from ..utils.platform import get_exe_extension from .base import Tool, ToolResult -from .mixins import BuildToolMixin +from .mixins import BuildToolMixin, resolve_source class CheckerBuildTool(Tool, BuildToolMixin): @@ -91,58 +91,45 @@ async def execute( compiler: str = "g++", ) -> ToolResult: """执行 Checker 构建。""" - # 解析源代码:source_path 优先于 code - source_dir = None - if source_path: - if not os.path.isabs(source_path): - source_path = os.path.join(problem_dir, source_path) - if not os.path.exists(source_path): - return ToolResult.fail(f"Source file not found: {source_path}") - try: - with open(source_path, encoding="utf-8") as f: - code = f.read() - except UnicodeDecodeError: - try: - with open(source_path, encoding="latin-1") as f: - code = f.read() - except Exception as e: - return ToolResult.fail(f"Failed to read source file: {e}") - source_dir = os.path.dirname(os.path.abspath(source_path)) - elif code is None: - return ToolResult.fail("Either 'code' or 'source_path' must be provided") + resolved, err = resolve_source(problem_dir, code, source_path) + if err is not None: + return err + assert resolved is not None os.makedirs(problem_dir, exist_ok=True) - - # 保存到 files/ 子目录 files_dir = os.path.join(problem_dir, "files") os.makedirs(files_dir, exist_ok=True) - # 保存代码 - source_path = os.path.join(files_dir, "checker.cpp") + canonical_path = os.path.join(files_dir, "checker.cpp") try: - with open(source_path, "w", encoding="utf-8") as f: - f.write(code) + with open(canonical_path, "w", encoding="utf-8") as f: + f.write(resolved.code) except Exception as e: return ToolResult.fail(f"Failed to save code: {str(e)}") - # 编译 binary_path = os.path.join(files_dir, f"checker{get_exe_extension()}") - include_dirs = [source_dir] if source_dir else None - compile_result = await self.build(source_path, binary_path, compiler=compiler, include_dirs=include_dirs) + compile_source = resolved.original_source_path or canonical_path + include_dirs = [resolved.include_dir] if resolved.include_dir else None + compile_result = await self.build(compile_source, binary_path, compiler=compiler, include_dirs=include_dirs) if not compile_result.success: return ToolResult.fail( f"Compilation failed: {compile_result.error}", - source_path=source_path, + source_path=compile_source, + canonical_path=canonical_path, compile_log=compile_result.stderr, ) + binary_size = os.path.getsize(binary_path) if os.path.exists(binary_path) else 0 + # 如果没有测试场景,直接返回成功 if not test_scenarios: return ToolResult.ok( - source_path=source_path, + source_path=compile_source, + canonical_path=canonical_path, binary_path=binary_path, + binary_size=binary_size, compile_log=compile_result.stderr, message="Checker built successfully (no test scenarios provided)", ) @@ -214,8 +201,10 @@ async def execute( accuracy = correct_count / total if total > 0 else 0 return ToolResult.ok( - source_path=source_path, + source_path=compile_source, + canonical_path=canonical_path, binary_path=binary_path, + binary_size=binary_size, compile_log=compile_result.stderr, test_results=test_results, correct_count=correct_count, diff --git a/src/autocode_mcp/tools/generator.py b/src/autocode_mcp/tools/generator.py index 479f269..a487523 100644 --- a/src/autocode_mcp/tools/generator.py +++ b/src/autocode_mcp/tools/generator.py @@ -12,7 +12,7 @@ from ..utils.compiler import run_binary, run_binary_with_args from ..utils.platform import get_exe_extension from .base import Tool, ToolResult -from .mixins import BuildToolMixin +from .mixins import BuildToolMixin, resolve_source class GeneratorBuildTool(Tool, BuildToolMixin): @@ -74,55 +74,44 @@ async def execute( compiler: str = "g++", ) -> ToolResult: """执行 Generator 构建。""" - # 解析源代码:source_path 优先于 code - source_dir = None - if source_path: - if not os.path.isabs(source_path): - source_path = os.path.join(problem_dir, source_path) - if not os.path.exists(source_path): - return ToolResult.fail(f"Source file not found: {source_path}") - try: - with open(source_path, encoding="utf-8") as f: - code = f.read() - except UnicodeDecodeError: - try: - with open(source_path, encoding="latin-1") as f: - code = f.read() - except Exception as e: - return ToolResult.fail(f"Failed to read source file: {e}") - source_dir = os.path.dirname(os.path.abspath(source_path)) - elif code is None: - return ToolResult.fail("Either 'code' or 'source_path' must be provided") + resolved, err = resolve_source(problem_dir, code, source_path) + if err is not None: + return err + assert resolved is not None os.makedirs(problem_dir, exist_ok=True) - - # 保存到 files/ 子目录 files_dir = os.path.join(problem_dir, "files") os.makedirs(files_dir, exist_ok=True) - source_path = os.path.join(files_dir, "gen.cpp") + canonical_path = os.path.join(files_dir, "gen.cpp") try: - with open(source_path, "w", encoding="utf-8") as f: - f.write(code) + with open(canonical_path, "w", encoding="utf-8") as f: + f.write(resolved.code) except Exception as e: return ToolResult.fail(f"Failed to save code: {str(e)}") exe_ext = get_exe_extension() binary_path = os.path.join(files_dir, f"gen{exe_ext}") - include_dirs = [source_dir] if source_dir else None - compile_result = await self.build(source_path, binary_path, compiler=compiler, include_dirs=include_dirs) + compile_source = resolved.original_source_path or canonical_path + include_dirs = [resolved.include_dir] if resolved.include_dir else None + compile_result = await self.build(compile_source, binary_path, compiler=compiler, include_dirs=include_dirs) if not compile_result.success: return ToolResult.fail( f"Compilation failed: {compile_result.error}", - source_path=source_path, + source_path=compile_source, + canonical_path=canonical_path, compile_log=compile_result.stderr, ) + binary_size = os.path.getsize(binary_path) if os.path.exists(binary_path) else 0 + return ToolResult.ok( - source_path=source_path, + source_path=compile_source, + canonical_path=canonical_path, binary_path=binary_path, + binary_size=binary_size, compile_log=compile_result.stderr, message="Generator built successfully", ) @@ -207,6 +196,12 @@ def input_schema(self) -> dict: "description": "T 最大值", "default": 1, }, + "extra_args": { + "type": "array", + "items": {"type": "string"}, + "description": "附加命令行参数,追加在标准 6 参数之后传递给 generator", + "default": [], + }, }, "required": ["problem_dir", "strategies"], } @@ -222,9 +217,11 @@ async def execute( n_max: int = 100000, t_min: int = 1, t_max: int = 1, + extra_args: list[str] | None = None, ) -> ToolResult: """执行数据生成。""" exe_ext = get_exe_extension() + extra_args = extra_args or [] # 检查 generator - 优先查找 files/ 子目录 gen_exe = os.path.join(problem_dir, "files", f"gen{exe_ext}") @@ -262,8 +259,8 @@ async def execute( type_param = strategy_type_map.get(strategy, "2") # 运行 generator - # gen.exe - cmd_args = [str(seed), type_param, str(n_min), str(n_max), str(t_min), str(t_max)] + # gen.exe [extra_args...] + cmd_args = [str(seed), type_param, str(n_min), str(n_max), str(t_min), str(t_max)] + extra_args gen_result = await run_binary_with_args( gen_exe, diff --git a/src/autocode_mcp/tools/interactor.py b/src/autocode_mcp/tools/interactor.py index cf4de91..cdd6691 100644 --- a/src/autocode_mcp/tools/interactor.py +++ b/src/autocode_mcp/tools/interactor.py @@ -12,6 +12,7 @@ from ..utils.compiler import compile_cpp from ..utils.platform import get_exe_extension from .base import Tool, ToolResult +from .mixins import resolve_source class InteractorBuildTool(Tool): @@ -84,58 +85,45 @@ async def execute( compiler: str = "g++", ) -> ToolResult: """执行 Interactor 构建。""" - # 解析源代码:source_path 优先于 code - source_dir = None - if source_path: - if not os.path.isabs(source_path): - source_path = os.path.join(problem_dir, source_path) - if not os.path.exists(source_path): - return ToolResult.fail(f"Source file not found: {source_path}") - try: - with open(source_path, encoding="utf-8") as f: - code = f.read() - except UnicodeDecodeError: - try: - with open(source_path, encoding="latin-1") as f: - code = f.read() - except Exception as e: - return ToolResult.fail(f"Failed to read source file: {e}") - source_dir = os.path.dirname(os.path.abspath(source_path)) - elif code is None: - return ToolResult.fail("Either 'code' or 'source_path' must be provided") + resolved, err = resolve_source(problem_dir, code, source_path) + if err is not None: + return err + assert resolved is not None os.makedirs(problem_dir, exist_ok=True) - - # 保存到 files/ 子目录 files_dir = os.path.join(problem_dir, "files") os.makedirs(files_dir, exist_ok=True) - # 保存代码 - source_path = os.path.join(files_dir, "interactor.cpp") + canonical_path = os.path.join(files_dir, "interactor.cpp") try: - with open(source_path, "w", encoding="utf-8") as f: - f.write(code) + with open(canonical_path, "w", encoding="utf-8") as f: + f.write(resolved.code) except Exception as e: return ToolResult.fail(f"Failed to save code: {str(e)}") - # 编译 binary_path = os.path.join(files_dir, f"interactor{get_exe_extension()}") - include_dirs = [source_dir] if source_dir else None - compile_result = await compile_cpp(source_path, binary_path, compiler=compiler, include_dirs=include_dirs) + compile_source = resolved.original_source_path or canonical_path + include_dirs = [resolved.include_dir] if resolved.include_dir else None + compile_result = await compile_cpp(compile_source, binary_path, compiler=compiler, include_dirs=include_dirs) if not compile_result.success: return ToolResult.fail( f"Compilation failed: {compile_result.error}", - source_path=source_path, + source_path=compile_source, + canonical_path=canonical_path, compile_log=compile_result.stderr, ) + binary_size = os.path.getsize(binary_path) if os.path.exists(binary_path) else 0 + # 如果没有提供参考解和变异解,直接返回成功(但 pass_rate 为 0) if not reference_solution_path and not mutant_solutions: return ToolResult.ok( - source_path=source_path, + source_path=compile_source, + canonical_path=canonical_path, binary_path=binary_path, + binary_size=binary_size, compile_log=compile_result.stderr, pass_rate=0.0, fail_rate=0.0, @@ -155,7 +143,8 @@ async def execute( if not os.path.exists(reference_solution_path): return ToolResult.fail( f"Reference solution not found: {reference_solution_path}", - source_path=source_path, + source_path=compile_source, + canonical_path=canonical_path, binary_path=binary_path, ) pass_total = 1 @@ -186,8 +175,10 @@ async def execute( fail_rate = fail_count / fail_total if fail_total > 0 else 0.0 return ToolResult.ok( - source_path=source_path, + source_path=compile_source, + canonical_path=canonical_path, binary_path=binary_path, + binary_size=binary_size, compile_log=compile_result.stderr, pass_rate=pass_rate, fail_rate=fail_rate, diff --git a/src/autocode_mcp/tools/mixins.py b/src/autocode_mcp/tools/mixins.py index 86f9b29..ec056a6 100644 --- a/src/autocode_mcp/tools/mixins.py +++ b/src/autocode_mcp/tools/mixins.py @@ -6,10 +6,65 @@ from __future__ import annotations +import os +from dataclasses import dataclass from typing import Literal from ..utils.compiler import CompileResult, RunResult, compile_cpp, run_binary from ..utils.resource_limit import get_resource_limit +from .base import ToolResult + + +@dataclass +class ResolvedSource: + """解析后的源代码信息。""" + + code: str + original_source_path: str | None # 用户提供的源文件绝对路径 + include_dir: str | None # 源文件目录(用于 -I 编译选项) + from_source_path: bool # True 表示通过 source_path 读取 + + +def resolve_source( + problem_dir: str, + code: str | None, + source_path: str | None, +) -> tuple[ResolvedSource | None, ToolResult | None]: + """解析源代码来源:source_path 优先于 code。 + + Returns: + 成功时返回 (ResolvedSource, None),失败时返回 (None, ToolResult.fail(...)) + """ + original_source_path = None + include_dir = None + from_source_path = False + + if source_path: + from_source_path = True + if not os.path.isabs(source_path): + source_path = os.path.join(problem_dir, source_path) + if not os.path.exists(source_path): + return None, ToolResult.fail(f"Source file not found: {source_path}") + try: + with open(source_path, encoding="utf-8") as f: + code = f.read() + except UnicodeDecodeError: + try: + with open(source_path, encoding="latin-1") as f: + code = f.read() + except Exception as e: + return None, ToolResult.fail(f"Failed to read source file: {e}") + original_source_path = os.path.abspath(source_path) + include_dir = os.path.dirname(original_source_path) + elif code is None: + return None, ToolResult.fail("Either 'code' or 'source_path' must be provided") + + return ResolvedSource( + code=code, + original_source_path=original_source_path, + include_dir=include_dir, + from_source_path=from_source_path, + ), None class BuildToolMixin: diff --git a/src/autocode_mcp/tools/problem.py b/src/autocode_mcp/tools/problem.py index fa2efd8..b98b377 100644 --- a/src/autocode_mcp/tools/problem.py +++ b/src/autocode_mcp/tools/problem.py @@ -158,7 +158,7 @@ def input_schema(self) -> dict: }, "constraints": { "type": "object", - "description": "题目约束条件,用于生成极限数据", + "description": "题目约束条件。不提供 test_configs 时,系统将根据 constraints 自动生成覆盖边界、随机、极限、TLE 诱导等策略的测试配置(smart mode)", "properties": { "n_max": { "type": "integer", @@ -194,10 +194,23 @@ def input_schema(self) -> dict: "n_max": {"type": "integer", "description": "N 最大值"}, "t_min": {"type": "integer", "description": "T 最小值"}, "t_max": {"type": "integer", "description": "T 最大值"}, + "extra_args": { + "type": "array", + "items": {"type": "string"}, + "description": "附加命令行参数,追加在标准 6 参数之后传递给 generator", + }, }, "required": ["type", "n_min", "n_max", "t_min", "t_max"], }, }, + "output_dir": { + "type": "string", + "description": "测试数据输出目录路径,默认为 problem_dir/tests。必须位于 problem_dir 下,且不能是题目根目录或 files/solutions/statements 等保留目录", + }, + "sol_name": { + "type": "string", + "description": "标准解法文件名(不含扩展名),默认 'sol'", + }, "enable_dedup": { "type": "boolean", "description": "启用去重(基于 MD5 signature)", @@ -229,6 +242,8 @@ async def execute( timeout: int = 60, constraints: dict | None = None, test_configs: list[dict] | None = None, + output_dir: str | None = None, + sol_name: str | None = None, enable_dedup: bool = True, enable_validator_filter: bool = True, enable_balance: bool = True, @@ -296,33 +311,38 @@ async def execute( ) exe_ext = get_exe_extension() + effective_sol_name = sol_name or "sol" # 检查必要文件 gen_exe = os.path.join(problem_dir, "files", f"gen{exe_ext}") - sol_exe = os.path.join(problem_dir, "solutions", f"sol{exe_ext}") + sol_exe = os.path.join(problem_dir, "solutions", f"{effective_sol_name}{exe_ext}") val_exe = os.path.join(problem_dir, "files", f"val{exe_ext}") - tests_dir = os.path.join(problem_dir, "tests") + + # 解析输出目录 + tests_dir, tests_dir_error = self._resolve_tests_dir(problem_dir, output_dir) + if tests_dir_error: + return tests_dir_error # 如果 files 目录下没有,检查根目录 if not os.path.exists(gen_exe): gen_exe = os.path.join(problem_dir, f"gen{exe_ext}") if not os.path.exists(sol_exe): - sol_exe = os.path.join(problem_dir, f"sol{exe_ext}") + sol_exe = os.path.join(problem_dir, f"{effective_sol_name}{exe_ext}") if not os.path.exists(val_exe): val_exe = os.path.join(problem_dir, f"val{exe_ext}") if not os.path.exists(gen_exe): return ToolResult.fail("Generator not found. Run generator_build first.") if not os.path.exists(sol_exe): - return ToolResult.fail("sol not found. Run solution_build first.") + return ToolResult.fail(f"{effective_sol_name} not found. Run solution_build first.") # Validator 是否可用 validator_available = enable_validator_filter and os.path.exists(val_exe) - # 创建/清空 tests 目录 - if os.path.exists(tests_dir): - shutil.rmtree(tests_dir) - os.makedirs(tests_dir) + # 创建/清空 tests 目录。只移除旧的测试数据,避免误删用户源码或其他文件。 + clear_error = self._clear_generated_tests(tests_dir) + if clear_error: + return clear_error # 获取测试配置 if test_configs: @@ -334,6 +354,7 @@ async def execute( str(c["n_max"]), str(c["t_min"]), str(c["t_max"]), + [str(a) for a in c.get("extra_args", [])], ) for c in test_configs ] @@ -354,7 +375,7 @@ async def execute( cfg_idx = (seed - 1) % len(test_configs_list) test_cfg = test_configs_list[cfg_idx] - seed_offset, type_param, n_min, n_max, t_min, t_max = test_cfg + seed_offset, type_param, n_min, n_max, t_min, t_max, extra_args = test_cfg cmd_args = [ str(seed + int(seed_offset)), type_param, @@ -362,7 +383,7 @@ async def execute( str(n_max), str(t_min), str(t_max), - ] + ] + extra_args try: # 生成输入 @@ -455,6 +476,7 @@ async def execute( validator_filter_enabled=validator_available, balance_enabled=enable_balance, candidates_generated=len(candidates), + sol_name=effective_sol_name, message=f"Generated {len(generated_tests)} test cases (from {len(candidates)} candidates)", ) else: @@ -462,8 +484,62 @@ async def execute( f"Partial generation: {len(generated_tests)}/{test_count}", generated_tests=generated_tests, errors=errors, + sol_name=effective_sol_name, ) + def _resolve_tests_dir( + self, + problem_dir: str, + output_dir: str | None, + ) -> tuple[str | None, ToolResult | None]: + """解析并校验测试输出目录,防止清理时误删题目文件或外部目录。""" + problem_root = os.path.realpath(problem_dir) + raw_output_dir = output_dir or "tests" + tests_dir = raw_output_dir + if not os.path.isabs(tests_dir): + tests_dir = os.path.join(problem_root, tests_dir) + tests_dir = os.path.abspath(tests_dir) + resolved_tests_dir = os.path.realpath(tests_dir) + + try: + common = os.path.commonpath([problem_root, resolved_tests_dir]) + except ValueError: + common = "" + if os.path.normcase(common) != os.path.normcase(problem_root): + return None, ToolResult.fail("output_dir must be inside problem_dir") + + if os.path.normcase(resolved_tests_dir) == os.path.normcase(problem_root): + return None, ToolResult.fail("output_dir cannot be the problem_dir root") + + reserved_dirs = {"files", "solutions", "statements"} + for reserved in reserved_dirs: + reserved_path = os.path.realpath(os.path.join(problem_root, reserved)) + try: + reserved_common = os.path.commonpath([reserved_path, resolved_tests_dir]) + except ValueError: + reserved_common = "" + if os.path.normcase(reserved_common) == os.path.normcase(reserved_path): + return None, ToolResult.fail(f"output_dir cannot be reserved directory: {reserved}") + + if os.path.exists(tests_dir) and os.path.islink(tests_dir): + return None, ToolResult.fail(f"output_dir cannot be a symlink: {tests_dir}") + + if os.path.exists(tests_dir) and not os.path.isdir(tests_dir): + return None, ToolResult.fail(f"output_dir exists and is not a directory: {tests_dir}") + + return tests_dir, None + + def _clear_generated_tests(self, tests_dir: str) -> ToolResult | None: + """创建测试目录并清理旧的 .in/.ans 文件。""" + os.makedirs(tests_dir, exist_ok=True) + for filename in os.listdir(tests_dir): + if not (filename.endswith(".in") or filename.endswith(".ans")): + continue + path = os.path.join(tests_dir, filename) + if os.path.isfile(path): + os.remove(path) + return None + def _balance_and_sample( self, candidates: list[CandidateTest], target_count: int ) -> list[CandidateTest]: @@ -505,14 +581,14 @@ def _balance_and_sample( def _get_default_configs( self, constraints: dict | None = None - ) -> list[tuple[str, str, str, str, str, str]]: + ) -> list[tuple[str, str, str, str, str, str, list[str]]]: """获取默认测试配置。 Args: constraints: 题目约束条件,包含 n_max, t_max, sum_n_max 等 Returns: - 配置列表,每项为 (seed_offset, type, n_min, n_max, t_min, t_max) + 配置列表,每项为 (seed_offset, type, n_min, n_max, t_min, t_max, extra_args) """ # 从约束中获取极限值 n_limit = constraints.get("n_max", 100000) if constraints else 100000 @@ -524,17 +600,17 @@ def _get_default_configs( # 1. 边界情况 (type=1 tiny) - 最小值和极小值 configs.extend( [ - ("0", "1", "1", "1", "1", "1"), # N=1, T=1 - ("1", "1", "1", "1", str(t_limit), str(t_limit)), # N=1, T=max - ("2", "1", "2", "2", "1", "1"), # N=2 + ("0", "1", "1", "1", "1", "1", []), # N=1, T=1 + ("1", "1", "1", "1", str(t_limit), str(t_limit), []), # N=1, T=max + ("2", "1", "2", "2", "1", "1", []), # N=2 ] ) # 2. 小数据随机 (type=2 random) configs.extend( [ - ("3", "2", "1", "10", "1", str(min(3, t_limit))), - ("4", "2", "10", "100", "1", str(min(3, t_limit))), + ("3", "2", "1", "10", "1", str(min(3, t_limit)), []), + ("4", "2", "10", "100", "1", str(min(3, t_limit)), []), ] ) @@ -542,8 +618,8 @@ def _get_default_configs( mid_n = n_limit // 2 configs.extend( [ - ("5", "2", "100", str(mid_n // 10), "1", str(min(3, t_limit))), - ("6", "2", str(mid_n // 10), str(mid_n), "1", str(min(2, t_limit))), + ("5", "2", "100", str(mid_n // 10), "1", str(min(3, t_limit)), []), + ("6", "2", str(mid_n // 10), str(mid_n), "1", str(min(2, t_limit)), []), ] ) @@ -551,17 +627,17 @@ def _get_default_configs( if n_limit >= 10000: configs.extend( [ - ("7", "2", str(mid_n), str(n_limit), "1", "1"), - ("8", "2", str(int(n_limit * 0.8)), str(n_limit), "1", "1"), + ("7", "2", str(mid_n), str(n_limit), "1", "1", []), + ("8", "2", str(int(n_limit * 0.8)), str(n_limit), "1", "1", []), ] ) # 5. 极限数据 (type=3 extreme) - 接近上限 configs.extend( [ - ("9", "3", str(n_limit), str(n_limit), "1", "1"), # N=max - ("10", "3", str(n_limit - 1), str(n_limit), "1", "1"), # N=max-1 - ("11", "3", str(int(n_limit * 0.99)), str(n_limit), "1", "1"), # 接近极限 + ("9", "3", str(n_limit), str(n_limit), "1", "1", []), # N=max + ("10", "3", str(n_limit - 1), str(n_limit), "1", "1", []), # N=max-1 + ("11", "3", str(int(n_limit * 0.99)), str(n_limit), "1", "1", []), # 接近极限 ] ) @@ -570,15 +646,15 @@ def _get_default_configs( # T=max, N 根据sum约束调整 n_per_test = min(n_limit, sum_n_limit // t_limit) if sum_n_limit else n_limit configs.append( - ("12", "3", str(max(1, n_per_test // 2)), str(n_per_test), str(t_limit), str(t_limit)) + ("12", "3", str(max(1, n_per_test // 2)), str(n_per_test), str(t_limit), str(t_limit), []) ) # 7. TLE 诱导数据 (type=4) if n_limit >= 100: configs.extend( [ - ("13", "4", str(n_limit), str(n_limit), "1", "1"), - ("14", "4", str(int(n_limit * 0.9)), str(n_limit), "1", "1"), + ("13", "4", str(n_limit), str(n_limit), "1", "1", []), + ("14", "4", str(int(n_limit * 0.9)), str(n_limit), "1", "1", []), ] ) diff --git a/src/autocode_mcp/tools/solution.py b/src/autocode_mcp/tools/solution.py index 72bee64..70b7ca1 100644 --- a/src/autocode_mcp/tools/solution.py +++ b/src/autocode_mcp/tools/solution.py @@ -3,11 +3,12 @@ """ import os +import shutil from typing import Literal from ..utils.platform import get_exe_extension from .base import Tool, ToolResult -from .mixins import BuildToolMixin, RunToolMixin +from .mixins import BuildToolMixin, RunToolMixin, resolve_source class SolutionBuildTool(Tool, BuildToolMixin): @@ -47,6 +48,10 @@ def input_schema(self) -> dict: "enum": ["sol", "brute"], "description": "解法类型:sol(标准解法)或 brute(暴力解法)", }, + "name": { + "type": "string", + "description": "自定义文件名(不含扩展名),默认使用 solution_type。例如 'brute_force' 替代 'brute'", + }, "code": { "type": "string", "description": "C++ 源代码(与 source_path 二选一)", @@ -72,69 +77,69 @@ async def execute( self, problem_dir: str, solution_type: Literal["sol", "brute"], + name: str | None = None, code: str | None = None, source_path: str | None = None, compiler: str = "g++", ) -> ToolResult: """执行解法构建。""" - # 解析源代码:source_path 优先于 code - source_dir = None - if source_path: - if not os.path.isabs(source_path): - source_path = os.path.join(problem_dir, source_path) - if not os.path.exists(source_path): - return ToolResult.fail(f"Source file not found: {source_path}") - try: - with open(source_path, encoding="utf-8") as f: - code = f.read() - except UnicodeDecodeError: - try: - with open(source_path, encoding="latin-1") as f: - code = f.read() - except Exception as e: - return ToolResult.fail(f"Failed to read source file: {e}") - source_dir = os.path.dirname(os.path.abspath(source_path)) - elif code is None: - return ToolResult.fail("Either 'code' or 'source_path' must be provided") + resolved, err = resolve_source(problem_dir, code, source_path) + if err is not None: + return err + assert resolved is not None # 确保目录存在 os.makedirs(problem_dir, exist_ok=True) - - # 保存到 solutions/ 子目录 solutions_dir = os.path.join(problem_dir, "solutions") os.makedirs(solutions_dir, exist_ok=True) # 确定文件名 - source_name = f"{solution_type}.cpp" - source_path = os.path.join(solutions_dir, source_name) + effective_name = name or solution_type + exe_ext = get_exe_extension() + canonical_path = os.path.join(solutions_dir, f"{effective_name}.cpp") + binary_path = os.path.join(solutions_dir, f"{effective_name}{exe_ext}") + standard_source_path = os.path.join(solutions_dir, f"{solution_type}.cpp") + standard_binary_path = os.path.join(solutions_dir, f"{solution_type}{exe_ext}") - # 保存代码 + # 保存自定义命名文件,并保留 sol.cpp/brute.cpp 供打包和默认流程使用。 try: - with open(source_path, "w", encoding="utf-8") as f: - f.write(code) + with open(canonical_path, "w", encoding="utf-8") as f: + f.write(resolved.code) + if os.path.normcase(canonical_path) != os.path.normcase(standard_source_path): + shutil.copy2(canonical_path, standard_source_path) except Exception as e: return ToolResult.fail(f"Failed to save code: {str(e)}") - # 编译 - exe_ext = get_exe_extension() - binary_name = f"{solution_type}{exe_ext}" - binary_path = os.path.join(solutions_dir, binary_name) - - include_dirs = [source_dir] if source_dir else None - result = await self.build(source_path, binary_path, compiler=compiler, include_dirs=include_dirs) + # 编译:source_path 时从原始文件编译,否则从标准位置编译 + compile_source = resolved.original_source_path or canonical_path + include_dirs = [resolved.include_dir] if resolved.include_dir else None + result = await self.build(compile_source, binary_path, compiler=compiler, include_dirs=include_dirs) if not result.success: return ToolResult.fail( f"Compilation failed: {result.error}", - source_path=source_path, + source_path=compile_source, + canonical_path=canonical_path, compile_log=result.stderr, ) + binary_size = os.path.getsize(binary_path) if os.path.exists(binary_path) else 0 + if os.path.normcase(binary_path) != os.path.normcase(standard_binary_path): + try: + shutil.copy2(binary_path, standard_binary_path) + except Exception as e: + return ToolResult.fail(f"Failed to save standard binary: {str(e)}") + return ToolResult.ok( - source_path=source_path, + source_path=compile_source, + canonical_path=canonical_path, + standard_source_path=standard_source_path, binary_path=binary_path, + standard_binary_path=standard_binary_path, + binary_size=binary_size, compile_log=result.stderr, - message=f"Successfully built {solution_type}", + effective_name=effective_name, + message=f"Successfully built {effective_name}", ) @@ -174,6 +179,10 @@ def input_schema(self) -> dict: "enum": ["sol", "brute"], "description": "解法类型:sol 或 brute", }, + "name": { + "type": "string", + "description": "自定义文件名(不含扩展名),默认使用 solution_type", + }, "input_data": { "type": "string", "description": "输入数据", @@ -192,20 +201,21 @@ async def execute( problem_dir: str, solution_type: Literal["sol", "brute"], input_data: str, + name: str | None = None, timeout: int = 30, ) -> ToolResult: """执行解法运行。""" - # 确定二进制文件路径 - 优先查找 solutions/ 子目录 + effective_name = name or solution_type exe_ext = get_exe_extension() - binary_path = os.path.join(problem_dir, "solutions", f"{solution_type}{exe_ext}") + binary_path = os.path.join(problem_dir, "solutions", f"{effective_name}{exe_ext}") # 如果子目录没有,检查根目录(向后兼容) if not os.path.exists(binary_path): - binary_path = os.path.join(problem_dir, f"{solution_type}{exe_ext}") + binary_path = os.path.join(problem_dir, f"{effective_name}{exe_ext}") if not os.path.exists(binary_path): return ToolResult.fail( - f"Binary not found: {solution_type}. Please run solution_build first." + f"Binary not found: {effective_name}. Please run solution_build first." ) # 运行 diff --git a/src/autocode_mcp/tools/stress_test.py b/src/autocode_mcp/tools/stress_test.py index 659d57f..bfd8b93 100644 --- a/src/autocode_mcp/tools/stress_test.py +++ b/src/autocode_mcp/tools/stress_test.py @@ -62,9 +62,25 @@ def input_schema(self) -> dict: "description": "单次执行超时(秒)", "default": 30, }, + "sol_name": { + "type": "string", + "description": "标准解法文件名(不含扩展名),默认 'sol'", + }, + "brute_name": { + "type": "string", + "description": "暴力解法文件名(不含扩展名),默认 'brute'", + }, + "types": { + "type": "array", + "items": { + "type": "string", + "enum": ["1", "2", "3", "4"], + }, + "description": "生成策略类型列表,轮次之间循环使用。例如 ['1','2','3','4'] 表示依次使用 tiny, random, extreme, tle。未指定时使用 generator_args.type 或默认 '2'", + }, "generator_args": { "type": "object", - "description": "Generator 命令行参数。调用协议: gen.exe 。seed 由系统自动填充为当前轮次,其余参数在此指定", + "description": "Generator 命令行参数。调用协议: gen.exe [extra_args...]。seed 由系统自动填充为当前轮次,其余参数在此指定", "properties": { "type": { "type": "string", @@ -90,6 +106,12 @@ def input_schema(self) -> dict: "default": 1, "description": "测试组数 T 的最大值", }, + "extra_args": { + "type": "array", + "items": {"type": "string"}, + "description": "附加命令行参数,追加在标准 6 参数之后传递给 generator", + "default": [], + }, }, }, }, @@ -102,30 +124,35 @@ async def execute( trials: int = 1000, n_max: int = 100, timeout: int = 30, + sol_name: str | None = None, + brute_name: str | None = None, + types: list[str] | None = None, generator_args: dict | None = None, ) -> ToolResult: """执行对拍测试。""" exe_ext = get_exe_extension() + effective_sol_name = sol_name or "sol" + effective_brute_name = brute_name or "brute" # 检查必要文件 - 优先查找子目录,回退到根目录 gen_exe = os.path.join(problem_dir, "files", f"gen{exe_ext}") if not os.path.exists(gen_exe): gen_exe = os.path.join(problem_dir, f"gen{exe_ext}") - sol_exe = os.path.join(problem_dir, "solutions", f"sol{exe_ext}") + sol_exe = os.path.join(problem_dir, "solutions", f"{effective_sol_name}{exe_ext}") if not os.path.exists(sol_exe): - sol_exe = os.path.join(problem_dir, f"sol{exe_ext}") + sol_exe = os.path.join(problem_dir, f"{effective_sol_name}{exe_ext}") - brute_exe = os.path.join(problem_dir, "solutions", f"brute{exe_ext}") + brute_exe = os.path.join(problem_dir, "solutions", f"{effective_brute_name}{exe_ext}") if not os.path.exists(brute_exe): - brute_exe = os.path.join(problem_dir, f"brute{exe_ext}") + brute_exe = os.path.join(problem_dir, f"{effective_brute_name}{exe_ext}") if not os.path.exists(gen_exe): return ToolResult.fail("Generator not found. Run generator_build first.") if not os.path.exists(sol_exe): - return ToolResult.fail("sol not found. Run solution_build first.") + return ToolResult.fail(f"{effective_sol_name} not found. Run solution_build first.") if not os.path.exists(brute_exe): - return ToolResult.fail("brute not found. Run solution_build first.") + return ToolResult.fail(f"{effective_brute_name} not found. Run solution_build first.") # 可选的 validator val_exe = os.path.join(problem_dir, "files", f"val{exe_ext}") @@ -141,20 +168,24 @@ async def execute( brute_output = None validator_failed = False + # 统计信息收集 + round_stats: list[dict] = [] + with tempfile.TemporaryDirectory(dir=problem_dir) as temp_dir: input_path = os.path.join(temp_dir, "input.txt") for i in range(1, trials + 1): # 1. 生成输入数据 gen_result = await self._generate_input( - gen_exe, input_path, i, seed=i, timeout=timeout, n_max=n_max, generator_args=generator_args + gen_exe, input_path, i, seed=i, timeout=timeout, n_max=n_max, + generator_args=generator_args, types=types, ) if not gen_result["success"]: error_detail = gen_result.get("error", "Unknown error") if "timed out" in error_detail: hint = "Generator may contain an infinite loop or be too slow. Try increasing the timeout parameter." elif "no output" in error_detail: - hint = "Check that the generator follows the protocol: gen.exe " + hint = "Check that the generator follows the protocol: gen.exe [extra_args...]" else: hint = "Generator crashed unexpectedly. Check stderr for details." return ToolResult.fail( @@ -165,6 +196,7 @@ async def execute( stdout=gen_result.get("stdout", ""), cmd_args=gen_result.get("cmd_args", []), last_input=last_input, + statistics=self._compute_summary(round_stats), ) # 2. 验证输入(如果有 validator) @@ -187,10 +219,11 @@ async def execute( sol_result = await run_binary(sol_exe, input_data, timeout=timeout) if sol_result.timed_out or not sol_result.success: return ToolResult.fail( - f"sol failed at round {i}", + f"{effective_sol_name} failed at round {i}", round=i, input_data=input_data, stderr=sol_result.stderr, + statistics=self._compute_summary(round_stats), ) sol_output = sol_result.stdout @@ -198,20 +231,31 @@ async def execute( brute_result = await run_binary(brute_exe, input_data, timeout=timeout) if brute_result.timed_out: return ToolResult.fail( - f"brute timed out at round {i} (N may be too large)", + f"{effective_brute_name} timed out at round {i} (N may be too large)", round=i, input_data=input_data, suggestion="Try reducing n_max parameter", + statistics=self._compute_summary(round_stats), ) if not brute_result.success: return ToolResult.fail( - f"brute failed at round {i}", + f"{effective_brute_name} failed at round {i}", round=i, input_data=input_data, stderr=brute_result.stderr, + statistics=self._compute_summary(round_stats), ) brute_output = brute_result.stdout + # 收集统计信息 + round_stats.append({ + "round": i, + "sol_time_ms": sol_result.time_ms, + "brute_time_ms": brute_result.time_ms, + "input_size": len(input_data), + "n_value": self._extract_n_value(input_data), + }) + # 5. 比较输出 if sol_output.strip() != brute_output.strip(): last_input = input_data @@ -226,6 +270,7 @@ async def execute( brute_output, trials, effective_n_max, + round_stats, ) async def _generate_input( @@ -237,43 +282,37 @@ async def _generate_input( timeout: int, n_max: int = 100, generator_args: dict | None = None, + types: list[str] | None = None, ) -> dict: - """ - 生成输入数据。 - - Args: - gen_exe: generator 可执行文件路径 - input_path: 输入文件保存路径 - round_num: 当前轮次 - seed: 随机种子 - timeout: 超时时间(秒) - n_max: N 最大值(用于默认协议) - generator_args: Generator 完整参数(可选) - - Returns: - dict: {"success": bool, "error": str | None} - """ + """生成输入数据。""" try: + # 确定 type 参数 + if types: + type_param = types[(round_num - 1) % len(types)] + elif generator_args: + type_param = generator_args.get("type", "2") + else: + type_param = "2" + # 构建命令参数 + # gen.exe [extra_args...] if generator_args: - # 完整协议: gen.exe cmd_args = [ str(seed), - generator_args.get("type", "2"), + type_param, str(generator_args.get("n_min", 1)), str(generator_args.get("n_max", n_max)), str(generator_args.get("t_min", 1)), str(generator_args.get("t_max", 1)), - ] + ] + generator_args.get("extra_args", []) else: - # 默认使用完整协议,与 generator_run 和 problem_generate_tests 保持一致 cmd_args = [ str(seed), - "2", # type=random - "1", # n_min - str(n_max), # n_max 使用参数 - "1", # t_min - "1", # t_max + type_param, + "1", + str(n_max), + "1", + "1", ] gen_result = await run_binary_with_args( @@ -281,8 +320,6 @@ async def _generate_input( cmd_args, timeout=timeout, ) - # Generator 可能因 testlib.h 优化问题崩溃,但输出仍有效 - # 只要没有超时且有输出,就认为成功 if gen_result.timed_out: return { "success": False, @@ -318,6 +355,70 @@ async def _generate_input( "seed": seed, } + def _extract_n_value(self, input_data: str) -> int | None: + """尝试从输入数据的第一行解析 N 值。""" + try: + first_line = input_data.strip().split("\n")[0] + # 尝试解析为整数(常见竞赛编程格式) + n = int(first_line.strip()) + return n if n > 0 else None + except (ValueError, IndexError): + return None + + def _compute_n_distribution(self, round_stats: list[dict]) -> dict[str, int]: + """计算 N 值分布。""" + buckets = {"1": 0, "2-10": 0, "11-50": 0, "51-100": 0, "101+": 0} + for stat in round_stats: + n = stat.get("n_value") + if n is None: + continue + if n == 1: + buckets["1"] += 1 + elif n <= 10: + buckets["2-10"] += 1 + elif n <= 50: + buckets["11-50"] += 1 + elif n <= 100: + buckets["51-100"] += 1 + else: + buckets["101+"] += 1 + return {k: v for k, v in buckets.items() if v > 0} + + def _compute_summary(self, round_stats: list[dict]) -> dict | None: + """计算统计摘要。""" + if not round_stats: + return None + + sol_times = [s["sol_time_ms"] for s in round_stats] + brute_times = [s["brute_time_ms"] for s in round_stats] + + summary = { + "rounds_completed": len(round_stats), + "sol_time": { + "min_ms": min(sol_times), + "max_ms": max(sol_times), + "avg_ms": sum(sol_times) // len(sol_times), + "total_ms": sum(sol_times), + }, + "brute_time": { + "min_ms": min(brute_times), + "max_ms": max(brute_times), + "avg_ms": sum(brute_times) // len(brute_times), + "total_ms": sum(brute_times), + }, + "n_distribution": self._compute_n_distribution(round_stats), + "slowest_round": max(round_stats, key=lambda s: s["sol_time_ms"]), + } + + # 计算最大时间比 + ratios = [] + for s in round_stats: + bt = max(s["brute_time_ms"], 1) + ratios.append(s["sol_time_ms"] / bt) + summary["max_ratio"] = max(ratios) + + return summary + def _format_result( self, failed_round: int | None, @@ -327,10 +428,11 @@ def _format_result( brute_output: str | None, total_rounds: int, effective_n_max: int = 100, + round_stats: list[dict] | None = None, ) -> ToolResult: - """ - 格式化测试结果。 - """ + """格式化测试结果。""" + statistics = self._compute_summary(round_stats or []) + if failed_round: return ToolResult.fail( f"Output mismatch at round {failed_round}" @@ -342,11 +444,13 @@ def _format_result( brute_output=brute_output, completed_rounds=failed_round - 1, total_rounds=total_rounds, + statistics=statistics, ) return ToolResult.ok( completed_rounds=total_rounds, total_rounds=total_rounds, effective_n_max=effective_n_max, + statistics=statistics, message=f"All {total_rounds} rounds passed", ) diff --git a/src/autocode_mcp/tools/test_verify.py b/src/autocode_mcp/tools/test_verify.py new file mode 100644 index 0000000..147a52e --- /dev/null +++ b/src/autocode_mcp/tools/test_verify.py @@ -0,0 +1,328 @@ +""" +Test Verification 工具 - 验证生成的测试数据。 + +检查文件完整性、答案一致性、约束覆盖等。 +""" + +from __future__ import annotations + +import os +from pathlib import Path + +from ..utils.compiler import run_binary +from ..utils.platform import get_exe_extension +from .base import Tool, ToolResult + + +class ProblemVerifyTestsTool(Tool): + """验证生成的测试数据。""" + + @property + def name(self) -> str: + return "problem_verify_tests" + + @property + def description(self) -> str: + return """验证生成的测试数据质量。 + + 自动执行以下检查: + 1. file_count: 每个 .in 有对应的 .ans,文件名连续 + 2. answer_consistency: 用 sol 重新运行 .in,对比输出与 .ans + 3. validator: 用 val 检查每个 .in 是否满足约束(如有 val.exe) + 4. no_empty: 没有空文件 + + 前置条件: + 1. 已运行 problem_generate_tests 生成测试数据 + 2. 已运行 solution_build 构建 sol + + 建议下一步: + - 如果验证通过:运行 problem_pack_polygon 打包 + - 如果验证失败:根据失败信息修复问题 + """ + + @property + def input_schema(self) -> dict: + return { + "type": "object", + "properties": { + "problem_dir": { + "type": "string", + "description": "题目目录路径", + }, + "tests_dir": { + "type": "string", + "description": "测试数据目录路径,默认为 problem_dir/tests", + }, + "verify_types": { + "type": "array", + "items": { + "type": "string", + "enum": ["file_count", "answer_consistency", "validator", "no_empty"], + }, + "description": "要执行的验证类型,默认全部执行", + }, + "sol_name": { + "type": "string", + "description": "标准解法文件名(不含扩展名),默认 'sol'", + }, + "timeout": { + "type": "integer", + "description": "单次执行超时(秒)", + "default": 60, + }, + }, + "required": ["problem_dir"], + } + + async def execute( + self, + problem_dir: str, + tests_dir: str | None = None, + verify_types: list[str] | None = None, + sol_name: str | None = None, + timeout: int = 60, + ) -> ToolResult: + """执行测试数据验证。""" + effective_sol_name = sol_name or "sol" + + # 解析测试目录 + if tests_dir: + if not os.path.isabs(tests_dir): + tests_dir = os.path.join(problem_dir, tests_dir) + else: + tests_dir = os.path.join(problem_dir, "tests") + + if not os.path.exists(tests_dir): + return ToolResult.fail(f"Tests directory not found: {tests_dir}") + + # 默认执行所有验证 + if not verify_types: + verify_types = ["file_count", "answer_consistency", "validator", "no_empty"] + + results = {} + all_passed = True + + # 1. 文件完整性检查 + if "file_count" in verify_types: + result = self._check_file_count(tests_dir) + results["file_count"] = result + if not result["passed"]: + all_passed = False + + # 2. 空文件检查 + if "no_empty" in verify_types: + result = self._check_no_empty(tests_dir) + results["no_empty"] = result + if not result["passed"]: + all_passed = False + + # 3. 答案一致性检查 + if "answer_consistency" in verify_types: + result = await self._check_answer_consistency( + problem_dir, + tests_dir, + effective_sol_name, + timeout, + ) + results["answer_consistency"] = result + if not result["passed"]: + all_passed = False + + # 4. Validator 检查 + if "validator" in verify_types: + result = await self._check_validator(problem_dir, tests_dir, timeout) + results["validator"] = result + if not result["passed"]: + all_passed = False + + # 汇总 + total_checks = len(results) + passed_checks = sum(1 for r in results.values() if r["passed"]) + + if all_passed: + return ToolResult.ok( + passed=True, + results=results, + total_checks=total_checks, + passed_checks=passed_checks, + tests_dir=tests_dir, + sol_name=effective_sol_name, + message=f"All {total_checks} verification checks passed", + ) + else: + return ToolResult.fail( + f"{passed_checks}/{total_checks} checks passed", + passed=False, + results=results, + total_checks=total_checks, + passed_checks=passed_checks, + tests_dir=tests_dir, + sol_name=effective_sol_name, + ) + + def _check_file_count(self, tests_dir: str) -> dict: + """检查文件完整性:每个 .in 有对应的 .ans。""" + tests_path = Path(tests_dir) + in_files = sorted(p.name for p in tests_path.iterdir() if p.is_file() and p.suffix == ".in") + ans_files = sorted(p.name for p in tests_path.iterdir() if p.is_file() and p.suffix == ".ans") + ans_file_set = set(ans_files) + in_file_set = set(in_files) + + missing_ans = [] + for in_file in in_files: + ans_file = Path(in_file).with_suffix(".ans").name + if ans_file not in ans_file_set: + missing_ans.append(in_file) + + orphan_ans = [] + for ans_file in ans_files: + in_file = Path(ans_file).with_suffix(".in").name + if in_file not in in_file_set: + orphan_ans.append(ans_file) + + non_numeric = [f for f in in_files if not Path(f).stem.isdigit()] + numeric_indices = sorted(int(Path(f).stem) for f in in_files if Path(f).stem.isdigit()) + numeric_index_set = set(numeric_indices) + expected_indices = list(range(1, max(numeric_indices) + 1)) if numeric_indices else [] + missing_indices = [ + idx for idx in expected_indices if idx not in numeric_index_set + ] + duplicate_indices = sorted( + idx for idx in numeric_index_set if numeric_indices.count(idx) > 1 + ) + + passed = ( + not missing_ans + and not orphan_ans + and not non_numeric + and not missing_indices + and not duplicate_indices + ) + return { + "passed": passed, + "total": len(in_files), + "missing_ans": missing_ans, + "orphan_ans": orphan_ans, + "missing_indices": missing_indices, + "duplicate_indices": duplicate_indices, + "non_numeric": non_numeric, + } + + def _check_no_empty(self, tests_dir: str) -> dict: + """检查没有空文件。""" + empty_files = [] + for f in os.listdir(tests_dir): + filepath = os.path.join(tests_dir, f) + if os.path.isfile(filepath) and os.path.getsize(filepath) == 0: + empty_files.append(f) + + return { + "passed": len(empty_files) == 0, + "total": len(os.listdir(tests_dir)), + "empty_files": empty_files, + } + + async def _check_answer_consistency( + self, problem_dir: str, tests_dir: str, sol_name: str, timeout: int + ) -> dict: + """用 sol 重新运行 .in,对比输出与 .ans。""" + exe_ext = get_exe_extension() + sol_exe = os.path.join(problem_dir, "solutions", f"{sol_name}{exe_ext}") + if not os.path.exists(sol_exe): + sol_exe = os.path.join(problem_dir, f"{sol_name}{exe_ext}") + + if not os.path.exists(sol_exe): + return { + "passed": False, + "total": 0, + "mismatches": [], + "error": f"{sol_name}{exe_ext} not found, run solution_build first", + } + + in_files = sorted( + f for f in os.listdir(tests_dir) if f.endswith(".in") + ) + + mismatches = [] + timed_out = [] + errors = [] + + for in_file in in_files: + in_path = os.path.join(tests_dir, in_file) + ans_file = Path(in_file).with_suffix(".ans").name + ans_path = os.path.join(tests_dir, ans_file) + + if not os.path.exists(ans_path): + continue + + with open(in_path, encoding="utf-8") as f: + input_data = f.read() + + with open(ans_path, encoding="utf-8") as f: + expected = f.read() + + result = await run_binary(sol_exe, input_data, timeout=timeout) + + if result.timed_out: + timed_out.append(in_file) + continue + + if not result.success: + errors.append({"file": in_file, "stderr": result.stderr}) + continue + + if result.stdout.strip() != expected.strip(): + mismatches.append({ + "file": in_file, + "expected": expected[:200], + "actual": result.stdout[:200], + }) + + passed = not mismatches and not timed_out and not errors + return { + "passed": passed, + "total": len(in_files), + "mismatches": mismatches, + "timed_out": timed_out, + "errors": errors, + } + + async def _check_validator( + self, problem_dir: str, tests_dir: str, timeout: int + ) -> dict: + """用 val 检查每个 .in 是否满足约束。""" + exe_ext = get_exe_extension() + val_exe = os.path.join(problem_dir, "files", f"val{exe_ext}") + + if not os.path.exists(val_exe): + return { + "passed": True, + "total": 0, + "skipped": True, + "message": "val.exe not found, validator check skipped", + } + + in_files = sorted( + f for f in os.listdir(tests_dir) if f.endswith(".in") + ) + + invalid = [] + for in_file in in_files: + in_path = os.path.join(tests_dir, in_file) + + with open(in_path, encoding="utf-8") as f: + input_data = f.read() + + result = await run_binary(val_exe, input_data, timeout=timeout) + + if result.return_code != 0: + invalid.append({ + "file": in_file, + "stderr": result.stderr[:200] if result.stderr else "", + }) + + return { + "passed": len(invalid) == 0, + "total": len(in_files), + "invalid": invalid, + } diff --git a/src/autocode_mcp/tools/validator.py b/src/autocode_mcp/tools/validator.py index 9ae6236..06cf24a 100644 --- a/src/autocode_mcp/tools/validator.py +++ b/src/autocode_mcp/tools/validator.py @@ -11,7 +11,7 @@ from ..utils.compiler import run_binary from ..utils.platform import get_exe_extension from .base import Tool, ToolResult -from .mixins import BuildToolMixin +from .mixins import BuildToolMixin, resolve_source class ValidatorBuildTool(Tool, BuildToolMixin): @@ -90,59 +90,46 @@ async def execute( compiler: str = "g++", ) -> ToolResult: """执行 Validator 构建。""" - # 解析源代码:source_path 优先于 code - source_dir = None - if source_path: - if not os.path.isabs(source_path): - source_path = os.path.join(problem_dir, source_path) - if not os.path.exists(source_path): - return ToolResult.fail(f"Source file not found: {source_path}") - try: - with open(source_path, encoding="utf-8") as f: - code = f.read() - except UnicodeDecodeError: - try: - with open(source_path, encoding="latin-1") as f: - code = f.read() - except Exception as e: - return ToolResult.fail(f"Failed to read source file: {e}") - source_dir = os.path.dirname(os.path.abspath(source_path)) - elif code is None: - return ToolResult.fail("Either 'code' or 'source_path' must be provided") + resolved, err = resolve_source(problem_dir, code, source_path) + if err is not None: + return err + assert resolved is not None # 确保目录存在 os.makedirs(problem_dir, exist_ok=True) - - # 保存到 files/ 子目录 files_dir = os.path.join(problem_dir, "files") os.makedirs(files_dir, exist_ok=True) - # 保存代码 - source_path = os.path.join(files_dir, "val.cpp") + canonical_path = os.path.join(files_dir, "val.cpp") try: - with open(source_path, "w", encoding="utf-8") as f: - f.write(code) + with open(canonical_path, "w", encoding="utf-8") as f: + f.write(resolved.code) except Exception as e: return ToolResult.fail(f"Failed to save code: {str(e)}") - # 编译 binary_path = os.path.join(files_dir, f"val{get_exe_extension()}") - include_dirs = [source_dir] if source_dir else None - compile_result = await self.build(source_path, binary_path, compiler=compiler, include_dirs=include_dirs) + compile_source = resolved.original_source_path or canonical_path + include_dirs = [resolved.include_dir] if resolved.include_dir else None + compile_result = await self.build(compile_source, binary_path, compiler=compiler, include_dirs=include_dirs) if not compile_result.success: return ToolResult.fail( f"Compilation failed: {compile_result.error}", - source_path=source_path, + source_path=compile_source, + canonical_path=canonical_path, compile_log=compile_result.stderr, ) + binary_size = os.path.getsize(binary_path) if os.path.exists(binary_path) else 0 + # 如果没有测试用例,直接返回成功 if not test_cases: return ToolResult.ok( - source_path=source_path, + source_path=compile_source, + canonical_path=canonical_path, binary_path=binary_path, + binary_size=binary_size, compile_log=compile_result.stderr, message="Validator built successfully (no test cases provided)", ) @@ -178,8 +165,10 @@ async def execute( total = len(test_cases) return ToolResult.ok( - source_path=source_path, + source_path=compile_source, + canonical_path=canonical_path, binary_path=binary_path, + binary_size=binary_size, compile_log=compile_result.stderr, test_results=test_results, score=score, diff --git a/tests/test_e2e_mcp.py b/tests/test_e2e_mcp.py index 1be1679..6207197 100644 --- a/tests/test_e2e_mcp.py +++ b/tests/test_e2e_mcp.py @@ -126,7 +126,7 @@ async def test_mcp_list_tools(mcp_client: MCPClient): tools = await mcp_client.list_tools() - assert len(tools) == 16 + assert len(tools) == 17 tool_names = {t["name"] for t in tools} expected_tools = { @@ -273,7 +273,7 @@ async def test_packaged_console_script_list_tools(packaged_mcp_client: MCPClient tools = await packaged_mcp_client.list_tools() - assert len(tools) == 16 + assert len(tools) == 17 tool_names = {t["name"] for t in tools} assert "solution_build" in tool_names assert "validator_build" in tool_names diff --git a/tests/test_packaging.py b/tests/test_packaging.py index 7718342..0437a67 100644 --- a/tests/test_packaging.py +++ b/tests/test_packaging.py @@ -11,7 +11,7 @@ def test_import(): """测试模块导入。""" from autocode_mcp import __version__ - assert __version__ == "0.6.0" + assert __version__ == "0.7.0" def test_tool_result(): diff --git a/tests/test_plugin_manifest.py b/tests/test_plugin_manifest.py index 7b635df..4c3adf3 100644 --- a/tests/test_plugin_manifest.py +++ b/tests/test_plugin_manifest.py @@ -11,7 +11,7 @@ def test_claude_plugin_manifest_links_mcp_config(): manifest = json.loads(Path(".claude-plugin/plugin.json").read_text(encoding="utf-8")) assert manifest["name"] == "autocode" - assert manifest["version"] == "0.6.0" + assert manifest["version"] == "0.7.0" def test_claude_plugin_manifest_has_interface_metadata(): diff --git a/tests/test_tools/test_problem.py b/tests/test_tools/test_problem.py index e2a7d43..0c31a66 100644 --- a/tests/test_tools/test_problem.py +++ b/tests/test_tools/test_problem.py @@ -14,6 +14,9 @@ ProblemPackPolygonTool, ) from autocode_mcp.tools.solution import SolutionBuildTool +from autocode_mcp.tools.test_verify import ProblemVerifyTestsTool +from autocode_mcp.utils.compiler import RunResult +from autocode_mcp.utils.platform import get_exe_extension @pytest.mark.asyncio @@ -293,6 +296,157 @@ async def test_problem_generate_tests_test_configs_validation(): assert "Generator not found" in result.error # 验证通过,但找不到 generator +@pytest.mark.asyncio +async def test_problem_generate_tests_rejects_unsafe_output_dir(): + """测试拒绝危险的测试输出目录。""" + tool = ProblemGenerateTestsTool() + + with tempfile.TemporaryDirectory() as tmpdir: + problem_dir = os.path.join(tmpdir, "unsafe_output") + os.makedirs(os.path.join(problem_dir, "files")) + os.makedirs(os.path.join(problem_dir, "solutions")) + + result = await tool.execute(problem_dir=problem_dir, output_dir=".") + assert not result.success + assert "output_dir cannot be the problem_dir root" in result.error + + result = await tool.execute(problem_dir=problem_dir, output_dir="solutions") + assert not result.success + assert "reserved directory" in result.error + + result = await tool.execute(problem_dir=problem_dir, output_dir="solutions/generated") + assert not result.success + assert "reserved directory" in result.error + + outside_dir = os.path.join(tmpdir, "outside") + result = await tool.execute(problem_dir=problem_dir, output_dir=outside_dir) + assert not result.success + assert "output_dir must be inside problem_dir" in result.error + + +@pytest.mark.asyncio +async def test_problem_generate_tests_rejects_symlinked_output_dir(): + """测试拒绝指向题目目录外部的符号链接输出目录。""" + tool = ProblemGenerateTestsTool() + + with tempfile.TemporaryDirectory() as tmpdir: + problem_dir = os.path.join(tmpdir, "symlink_output") + outside_dir = os.path.join(tmpdir, "outside_tests") + link_dir = os.path.join(problem_dir, "tests_link") + os.makedirs(os.path.join(problem_dir, "files")) + os.makedirs(os.path.join(problem_dir, "solutions")) + os.makedirs(outside_dir) + + try: + os.symlink(outside_dir, link_dir, target_is_directory=True) + except (OSError, NotImplementedError): + pytest.skip("symlink creation is not available") + + result = await tool.execute(problem_dir=problem_dir, output_dir="tests_link") + + assert not result.success + assert "output_dir must be inside problem_dir" in result.error + + +def test_problem_generate_tests_clear_only_generated_files(): + """测试清理输出目录时只删除旧的 .in/.ans 文件。""" + tool = ProblemGenerateTestsTool() + + with tempfile.TemporaryDirectory() as tmpdir: + tests_dir = os.path.join(tmpdir, "tests") + os.makedirs(tests_dir) + keep_path = os.path.join(tests_dir, "notes.txt") + old_in_path = os.path.join(tests_dir, "01.in") + old_ans_path = os.path.join(tests_dir, "01.ans") + + with open(keep_path, "w", encoding="utf-8") as f: + f.write("keep me") + with open(old_in_path, "w", encoding="utf-8") as f: + f.write("old input") + with open(old_ans_path, "w", encoding="utf-8") as f: + f.write("old answer") + + result = tool._clear_generated_tests(tests_dir) + + assert result is None + assert os.path.exists(keep_path) + assert not os.path.exists(old_in_path) + assert not os.path.exists(old_ans_path) + + +@pytest.mark.asyncio +async def test_problem_generate_tests_uses_custom_sol_name(monkeypatch): + """测试生成答案时使用自定义 sol_name。""" + tool = ProblemGenerateTestsTool() + + async def fake_run_binary_with_args(*args, **kwargs): + return RunResult(success=True, stdout="7\n") + + async def fake_run_binary(binary_path, stdin="", timeout=5, memory_mb=256): + assert os.path.basename(binary_path) == f"accepted{get_exe_extension()}" + assert stdin == "7\n" + return RunResult(success=True, stdout="7\n") + + monkeypatch.setattr("autocode_mcp.tools.problem.run_binary_with_args", fake_run_binary_with_args) + monkeypatch.setattr("autocode_mcp.tools.problem.run_binary", fake_run_binary) + + with tempfile.TemporaryDirectory() as tmpdir: + problem_dir = os.path.join(tmpdir, "custom_sol") + files_dir = os.path.join(problem_dir, "files") + solutions_dir = os.path.join(problem_dir, "solutions") + os.makedirs(files_dir) + os.makedirs(solutions_dir) + + exe_ext = get_exe_extension() + open(os.path.join(files_dir, f"gen{exe_ext}"), "w").close() + open(os.path.join(solutions_dir, f"accepted{exe_ext}"), "w").close() + + result = await tool.execute( + problem_dir=problem_dir, + test_count=1, + sol_name="accepted", + enable_dedup=False, + oversample_ratio=1.0, + ) + + assert result.success + assert result.data["sol_name"] == "accepted" + assert os.path.exists(os.path.join(problem_dir, "tests", "01.in")) + assert os.path.exists(os.path.join(problem_dir, "tests", "01.ans")) + + +def test_problem_verify_tests_file_count_requires_contiguous_numeric_names(): + """测试 file_count 会检查数字文件名连续性。""" + tool = ProblemVerifyTestsTool() + + with tempfile.TemporaryDirectory() as tmpdir: + for name in ["01.in", "01.ans", "03.in", "03.ans"]: + with open(os.path.join(tmpdir, name), "w", encoding="utf-8") as f: + f.write("x\n") + + result = tool._check_file_count(tmpdir) + + assert not result["passed"] + assert result["missing_indices"] == [2] + + +def test_problem_verify_tests_file_count_reports_large_gaps(): + """测试跳到大编号时会报告完整缺失区间。""" + tool = ProblemVerifyTestsTool() + + with tempfile.TemporaryDirectory() as tmpdir: + for name in ["01.in", "01.ans", "100.in", "100.ans"]: + with open(os.path.join(tmpdir, name), "w", encoding="utf-8") as f: + f.write("x\n") + + result = tool._check_file_count(tmpdir) + + assert not result["passed"] + assert result["missing_indices"][0] == 2 + assert result["missing_indices"][-1] == 99 + assert len(result["missing_indices"]) == 98 + + @pytest.mark.asyncio async def test_problem_pack_polygon_dynamic_test_count(): """测试 Polygon 打包使用动态 test-count。""" diff --git a/tests/test_tools/test_solution.py b/tests/test_tools/test_solution.py index 55fa36a..b064e63 100644 --- a/tests/test_tools/test_solution.py +++ b/tests/test_tools/test_solution.py @@ -277,6 +277,34 @@ async def test_solution_build(): assert os.path.exists(result.data["binary_path"]) +@pytest.mark.asyncio +async def test_solution_build_custom_name_keeps_standard_files(monkeypatch): + """测试自定义命名构建时仍保留 sol.cpp/sol 可供默认流程使用。""" + tool = SolutionBuildTool() + + async def fake_build(source_path, binary_path, compiler="g++", include_dirs=None): + with open(binary_path, "w", encoding="utf-8") as f: + f.write("binary") + return CompileResult(success=True, binary_path=binary_path) + + monkeypatch.setattr(tool, "build", fake_build) + + with tempfile.TemporaryDirectory() as tmpdir: + result = await tool.execute( + problem_dir=tmpdir, + solution_type="sol", + name="accepted", + code=SIMPLE_CPP, + ) + + assert result.success + assert os.path.exists(os.path.join(tmpdir, "solutions", "accepted.cpp")) + assert os.path.exists(os.path.join(tmpdir, "solutions", "sol.cpp")) + assert os.path.exists(result.data["binary_path"]) + assert os.path.exists(result.data["standard_binary_path"]) + assert result.data["effective_name"] == "accepted" + + @pytest.mark.asyncio async def test_solution_build_brute(): """测试暴力解法构建。""" diff --git a/tests/test_workflow_guard.py b/tests/test_workflow_guard.py index de0846b..3595990 100644 --- a/tests/test_workflow_guard.py +++ b/tests/test_workflow_guard.py @@ -72,3 +72,92 @@ def test_post_tool_marks_stress_passed(tmp_path): assert exit_code == 0 assert state["stress_passed"] is True + + +def test_pre_tool_denies_pack_before_tests_verified(tmp_path, capsys): + module = load_module() + problem_dir = tmp_path / "problem" + (problem_dir / "files").mkdir(parents=True) + (problem_dir / "solutions").mkdir(parents=True) + state = { + "problem_dir": str(problem_dir), + "created": True, + "sol_built": True, + "brute_built": True, + "validator_ready": True, + "validator_accuracy": 1.0, + "generator_built": True, + "stress_passed": True, + "checker_ready": False, + "validation_passed": True, + "tests_generated": True, + "generated_test_count": 3, + "tests_verified": False, + "packaged": False, + } + module.save_state(str(problem_dir), state) + + payload = { + "tool_name": "mcp__autocode__problem_pack_polygon", + "tool_input": {"problem_dir": str(problem_dir)}, + } + + exit_code = module.pre_tool(payload) + captured = capsys.readouterr().out + + assert exit_code == 0 + parsed = json.loads(captured) + assert parsed["hookSpecificOutput"]["permissionDecision"] == "deny" + assert "problem_verify_tests" in parsed["hookSpecificOutput"]["permissionDecisionReason"] + + +def test_post_tool_marks_tests_verified(tmp_path): + module = load_module() + problem_dir = tmp_path / "problem" + (problem_dir / "files").mkdir(parents=True) + (problem_dir / "solutions").mkdir(parents=True) + + payload = { + "tool_name": "mcp__autocode__problem_verify_tests", + "tool_input": {"problem_dir": str(problem_dir)}, + "tool_response": { + "structuredContent": { + "success": True, + "data": {"passed": True}, + } + }, + } + + exit_code = module.post_tool(payload) + state = module.load_state(str(problem_dir)) + + assert exit_code == 0 + assert state["tests_verified"] is True + + +def test_post_tool_clears_tests_verified_after_regeneration(tmp_path): + module = load_module() + problem_dir = tmp_path / "problem" + (problem_dir / "files").mkdir(parents=True) + (problem_dir / "solutions").mkdir(parents=True) + state = module.infer_state(str(problem_dir)) + state["tests_verified"] = True + module.save_state(str(problem_dir), state) + + payload = { + "tool_name": "mcp__autocode__problem_generate_tests", + "tool_input": {"problem_dir": str(problem_dir)}, + "tool_response": { + "structuredContent": { + "success": True, + "data": {"generated_tests": [1, 2]}, + } + }, + } + + exit_code = module.post_tool(payload) + state = module.load_state(str(problem_dir)) + + assert exit_code == 0 + assert state["tests_generated"] is True + assert state["tests_verified"] is False diff --git a/uv.lock b/uv.lock index d641e5c..df8efbd 100644 --- a/uv.lock +++ b/uv.lock @@ -36,7 +36,7 @@ wheels = [ [[package]] name = "autocode-mcp" -version = "0.6.0" +version = "0.7.0" source = { editable = "." } dependencies = [ { name = "mcp" },