| """ |
| task_executor.py — REST API task executor for the macOS AI Agent environment. |
| """ |
|
|
| import difflib |
| import logging |
| import os |
| import re |
| import shutil |
| import signal |
| import subprocess |
| import threading |
| import time |
| import uuid |
| from http import HTTPStatus |
|
|
| from flask import Flask, jsonify, request |
| from waitress import serve |
|
|
| |
| |
| |
|
|
| TASK_BASE_DIR = os.environ.get("TASK_BASE_DIR", "/Users/AgentUser/tasks") |
| API_PORT = int(os.environ.get("API_PORT", "9090")) |
| API_TOKEN = os.environ.get("API_TOKEN", "") |
|
|
| |
| |
| |
| os.makedirs(TASK_BASE_DIR, exist_ok=True) |
| LOG_FILE = os.path.join(TASK_BASE_DIR, "task_executor.log") |
|
|
| logging.basicConfig( |
| level=logging.INFO, |
| format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", |
| handlers=[ |
| logging.FileHandler(LOG_FILE, encoding="utf-8"), |
| logging.StreamHandler(), |
| ], |
| ) |
| log = logging.getLogger("task_executor") |
|
|
| |
| |
| |
| _tasks: dict[str, dict] = {} |
| _tasks_lock = threading.Lock() |
| _TASK_MAX_AGE = int(os.environ.get("TASK_MAX_AGE", "3600")) |
|
|
|
|
| def _evict_old_tasks() -> None: |
| """Drop completed/failed tasks older than TASK_MAX_AGE seconds.""" |
| cutoff = time.monotonic() - _TASK_MAX_AGE |
| with _tasks_lock: |
| stale = [ |
| tid for tid, t in _tasks.items() |
| if t["status"] not in ("pending", "running") |
| and t.get("_created", 0) < cutoff |
| ] |
| for tid in stale: |
| _tasks.pop(tid) |
|
|
| app = Flask(__name__) |
|
|
| def _check_auth() -> bool: |
| """Return True if the request is authorised (or auth is disabled).""" |
| if not API_TOKEN: |
| return True |
| auth = request.headers.get("Authorization", "") |
| return auth == f"Bearer {API_TOKEN}" |
|
|
| |
| |
| |
|
|
| class _TaskTimeoutError(RuntimeError): |
| """Raised by _run() when a subprocess exceeds its allotted time.""" |
|
|
|
|
| |
| |
| |
|
|
| def _run( |
| command: list[str] | str, |
| cwd: str | None = None, |
| timeout: int = 120, |
| shell: bool = False, |
| ) -> tuple[int, str, str]: |
| """ |
| Run a command in a new process group so the entire child tree can be |
| killed on timeout via os.killpg + SIGKILL (POSIX). |
| |
| • list[str] → all internal commands (git clone/checkout/apply/diff). |
| argv passed directly to execvp; no shell, no injection. |
| • shell=True → user-supplied test_command and lint_command only. |
| |
| Raises _TaskTimeoutError on timeout. |
| Returns (exit_code, stdout, stderr). |
| """ |
| proc = subprocess.Popen( |
| command, |
| cwd=cwd, |
| shell=shell, |
| stdout=subprocess.PIPE, |
| stderr=subprocess.PIPE, |
| text=True, |
| env=os.environ.copy(), |
| start_new_session=True, |
| ) |
| try: |
| out, err = proc.communicate(timeout=timeout) |
| return proc.returncode, out, err |
| except subprocess.TimeoutExpired: |
| try: |
| os.killpg(os.getpgid(proc.pid), signal.SIGKILL) |
| except ProcessLookupError: |
| pass |
| proc.wait() |
| raise _TaskTimeoutError(f"Command timed out after {timeout}s") |
|
|
|
|
| |
| |
| |
|
|
| def _parse_pytest(text: str) -> tuple[int, int]: |
| """pytest: '5 passed, 2 failed, 1 error in 3.14s'""" |
| passed, failed = 0, 0 |
| m = re.search(r"(\d+)\s+passed", text) |
| if m: |
| passed = int(m.group(1)) |
| m = re.search(r"(\d+)\s+failed", text) |
| if m: |
| failed = int(m.group(1)) |
| m = re.search(r"(\d+)\s+error", text) |
| if m: |
| failed += int(m.group(1)) |
| return passed, failed |
|
|
|
|
| def _parse_cargo(text: str) -> tuple[int, int]: |
| """cargo test: 'test result: ok. 5 passed; 0 failed; ...' — sums across binaries.""" |
| passed, failed = 0, 0 |
| for m in re.finditer(r"test result:.*?(\d+)\s+passed;\s*(\d+)\s+failed", text): |
| passed += int(m.group(1)) |
| failed += int(m.group(2)) |
| return passed, failed |
|
|
|
|
| def _parse_go(text: str) -> tuple[int, int]: |
| """go test: '--- PASS/FAIL:' lines; falls back to package-level ok/FAIL.""" |
| passed = len(re.findall(r"^--- PASS:", text, re.MULTILINE)) |
| failed = len(re.findall(r"^--- FAIL:", text, re.MULTILINE)) |
| if passed == 0 and failed == 0: |
| passed = len(re.findall(r"^ok\s+\S+", text, re.MULTILINE)) |
| failed = len(re.findall(r"^FAIL\s+\S+", text, re.MULTILINE)) |
| return passed, failed |
|
|
|
|
| def _parse_jest(text: str) -> tuple[int, int]: |
| """Jest: 'Tests: 2 failed, 5 passed, 7 total'""" |
| passed, failed = 0, 0 |
| m = re.search(r"^Tests:\s+(.+)$", text, re.MULTILINE) |
| if m: |
| summary = m.group(1) |
| p = re.search(r"(\d+)\s+passed", summary) |
| f = re.search(r"(\d+)\s+failed", summary) |
| if p: |
| passed = int(p.group(1)) |
| if f: |
| failed = int(f.group(1)) |
| return passed, failed |
|
|
|
|
| def _parse_dotnet(text: str) -> tuple[int, int]: |
| """dotnet test: 'Failed: 2, Passed: 3, Skipped: 0, Total: 5'""" |
| passed, failed = 0, 0 |
| m = re.search(r"Failed:\s*(\d+),\s*Passed:\s*(\d+)", text) |
| if m: |
| failed = int(m.group(1)) |
| passed = int(m.group(2)) |
| return passed, failed |
|
|
|
|
| def _parse_junit(text: str) -> tuple[int, int]: |
| """Maven / Gradle / sbt: 'Tests run: 7, Failures: 2, Errors: 0, Skipped: 0'""" |
| passed_total = failed_total = 0 |
| for m in re.finditer( |
| r"Tests run:\s*(\d+),\s*Failures:\s*(\d+),\s*Errors:\s*(\d+)", text |
| ): |
| run = int(m.group(1)) |
| failures = int(m.group(2)) |
| errors = int(m.group(3)) |
| failed_total += failures + errors |
| passed_total += max(run - failures - errors, 0) |
| return passed_total, failed_total |
|
|
|
|
| def _dispatch_test_parser(test_command: str, text: str) -> tuple[int, int]: |
| """Route to the correct test parser; falls back to trying all.""" |
| cmd = test_command.lower() |
| if "pytest" in cmd or "py.test" in cmd: |
| return _parse_pytest(text) |
| if "cargo" in cmd: |
| return _parse_cargo(text) |
| if "go test" in cmd: |
| return _parse_go(text) |
| if ( |
| "jest" in cmd |
| or ("npm" in cmd and "test" in cmd) |
| or ("yarn" in cmd and "test" in cmd) |
| or ("pnpm" in cmd and "test" in cmd) |
| ): |
| return _parse_jest(text) |
| if "dotnet" in cmd: |
| return _parse_dotnet(text) |
| if "mvn" in cmd or "gradle" in cmd or "sbt" in cmd or "junit" in cmd: |
| return _parse_junit(text) |
| for parser in ( |
| _parse_pytest, _parse_cargo, _parse_go, |
| _parse_jest, _parse_dotnet, _parse_junit, |
| ): |
| p, f = parser(text) |
| if p or f: |
| return p, f |
| return 0, 0 |
|
|
|
|
| |
| |
| |
|
|
| def _parse_lint_errors(lint_command: str, text: str, exit_code: int) -> int: |
| """ |
| Extract an error count from linter output. |
| |
| Each branch targets the canonical output format of the most common CLI |
| linters. The fallback counts lines that contain the word 'error' |
| (case-insensitive), which covers the long tail of less common tools. |
| |
| Soft scoring only — the result is stored but never changes task status. |
| """ |
| cmd = lint_command.lower() |
|
|
| |
| |
| if "ruff" in cmd: |
| m = re.search(r"Found\s+(\d+)\s+error", text) |
| if m: |
| return int(m.group(1)) |
| |
| if "--output-format json" in cmd or "-o json" in cmd: |
| try: |
| import json |
| return len(json.loads(text)) |
| except Exception: |
| pass |
|
|
| |
| |
| if "flake8" in cmd: |
| return len([l for l in text.splitlines() if re.match(r".+:\d+:\d+:\s+[EWF]", l)]) |
|
|
| |
| |
| if "mypy" in cmd: |
| m = re.search(r"Found\s+(\d+)\s+error", text) |
| if m: |
| return int(m.group(1)) |
| return text.count(": error:") |
|
|
| |
| |
| if "pylint" in cmd: |
| return len(re.findall(r"^\S+:\d+:\d+:\s+[EF]\d{4}:", text, re.MULTILINE)) |
|
|
| |
| |
| if "clippy" in cmd or ("cargo" in cmd and "check" in cmd): |
| return len(re.findall(r"^error\[", text, re.MULTILINE)) |
|
|
| |
| |
| |
| if "eslint" in cmd: |
| if "--format json" in cmd or "-f json" in cmd: |
| try: |
| import json |
| data = json.loads(text) |
| return sum( |
| sum(1 for msg in f.get("messages", []) if msg.get("severity") == 2) |
| for f in data |
| ) |
| except Exception: |
| pass |
| m = re.search(r"(\d+)\s+error", text) |
| return int(m.group(1)) if m else 0 |
|
|
| |
| |
| if "go vet" in cmd or "staticcheck" in cmd: |
| return len([l for l in text.splitlines() if l.strip()]) |
|
|
| |
| if "clang-tidy" in cmd or "cppcheck" in cmd: |
| return len(re.findall(r"\berror\b", text, re.IGNORECASE)) |
|
|
| |
| if "dotnet" in cmd and "build" in cmd: |
| m = re.search(r"(\d+)\s+Error\(s\)", text) |
| return int(m.group(1)) if m else 0 |
|
|
| |
| |
| if exit_code != 0: |
| return len(re.findall(r"\berror\b", text, re.IGNORECASE)) |
| return 0 |
|
|
|
|
| |
| |
| |
|
|
| def _normalise_patch(patch: str) -> list[str]: |
| """ |
| Strip unified-diff metadata lines and return only the content lines |
| (lines starting with +, -, or space) with leading +/- preserved. |
| |
| Lines stripped: |
| diff --git ... |
| index ... |
| --- a/... |
| +++ b/... |
| @@ ... @@ (hunk headers) |
| |
| This lets difflib compare the actual code changes regardless of line |
| numbers, file paths, or git object hashes — so minor reformatting of |
| the patch header doesn't penalise an otherwise identical solution. |
| """ |
| kept: list[str] = [] |
| for line in patch.splitlines(): |
| if ( |
| line.startswith("diff ") |
| or line.startswith("index ") |
| or line.startswith("--- ") |
| or line.startswith("+++ ") |
| or line.startswith("@@ ") |
| ): |
| continue |
| kept.append(line) |
| return kept |
|
|
|
|
| def _patch_similarity(agent_patch: str, reference_patch: str) -> float: |
| """ |
| Return a similarity ratio in [0.0, 1.0] between two unified diffs after |
| normalisation. Uses difflib.SequenceMatcher (Ratcliff/Obershelp algorithm). |
| |
| 1.0 = identical changes after stripping metadata. |
| 0.0 = completely different changes. |
| |
| This is an informational signal — the caller decides what threshold |
| constitutes an acceptable match. |
| """ |
| a = _normalise_patch(agent_patch) |
| b = _normalise_patch(reference_patch) |
| if not a and not b: |
| return 1.0 |
| if not a or not b: |
| return 0.0 |
| return difflib.SequenceMatcher(None, a, b).ratio() |
|
|
|
|
| |
| |
| |
|
|
| def _execute( |
| task_id: str, |
| repo_url: str, |
| base_commit: str, |
| patch: str, |
| test_command: str, |
| timeout: int, |
| lint_command: str, |
| capture_diff: bool, |
| reference_patch: str, |
| ) -> None: |
| """ |
| Full task lifecycle: |
| |
| 1. Create isolated workspace |
| 2. git clone |
| 3. git checkout <base_commit> |
| 4. git apply <agent patch> (if patch provided) |
| 5. Run test_command → pass/fail counts |
| 6. Run lint_command (if provided; soft score) |
| 7. git diff <base_commit> (if capture_diff or reference_patch) |
| 8. Compute patch_similarity (if reference_patch provided) |
| 9. Single atomic _update (status + all signals) |
| 10. shutil.rmtree cleanup |
| """ |
| task_dir = os.path.join(TASK_BASE_DIR, task_id) |
| repo_dir = os.path.join(task_dir, "repo") |
| patch_file = os.path.join(task_dir, "task.patch") |
|
|
| stdout_parts: list[str] = [] |
| stderr_parts: list[str] = [] |
| start = time.monotonic() |
|
|
| |
| final_update: dict = { |
| "status": "failed", |
| "exit_code": -1, |
| "stdout": "", |
| "stderr": "", |
| "tests_passed": 0, |
| "tests_failed": 0, |
| "lint_errors": None, |
| "lint_output": None, |
| "patch_diff": None, |
| "patch_similarity": None, |
| "execution_time": 0.0, |
| } |
|
|
| def _update(**kw: object) -> None: |
| with _tasks_lock: |
| _tasks[task_id].update(kw) |
|
|
| _update(status="running") |
|
|
| try: |
| os.makedirs(task_dir, exist_ok=True) |
|
|
| |
| rc, out, err = _run(["git", "clone", repo_url, repo_dir], timeout=120) |
| stdout_parts.append(out); stderr_parts.append(err) |
| if rc != 0: |
| raise RuntimeError(f"git clone failed (rc={rc}): {err.strip()}") |
|
|
| |
| rc, out, err = _run(["git", "checkout", base_commit], cwd=repo_dir, timeout=60) |
| stdout_parts.append(out); stderr_parts.append(err) |
| if rc != 0: |
| raise RuntimeError(f"git checkout failed (rc={rc}): {err.strip()}") |
|
|
| |
| if patch and patch.strip(): |
| with open(patch_file, "w", encoding="utf-8") as fh: |
| fh.write(patch) |
| rc, out, err = _run(["git", "apply", patch_file], cwd=repo_dir, timeout=30) |
| stdout_parts.append(out); stderr_parts.append(err) |
| if rc != 0: |
| raise RuntimeError(f"git apply failed (rc={rc}): {err.strip()}") |
|
|
| |
| rc, out, err = _run(test_command, cwd=repo_dir, timeout=timeout, shell=True) |
| stdout_parts.append(out); stderr_parts.append(err) |
| test_exit_code = rc |
|
|
| combined_stdout = "\n".join(filter(None, stdout_parts)) |
| combined_stderr = "\n".join(filter(None, stderr_parts)) |
| passed, failed = _dispatch_test_parser( |
| test_command, combined_stdout + "\n" + combined_stderr |
| ) |
|
|
| |
| lint_errors_count: int | None = None |
| lint_out: str | None = None |
|
|
| if lint_command and lint_command.strip(): |
| try: |
| lint_rc, l_out, l_err = _run( |
| lint_command, cwd=repo_dir, timeout=120, shell=True |
| ) |
| lint_out = (l_out + "\n" + l_err).strip() or None |
| lint_errors_count = _parse_lint_errors( |
| lint_command, lint_out or "", lint_rc |
| ) |
| log.info( |
| "Task %s lint finished — rc=%d errors=%s", |
| task_id, lint_rc, lint_errors_count, |
| ) |
| except _TaskTimeoutError: |
| lint_out = "Lint timed out after 120s" |
| lint_errors_count = None |
| log.warning("Task %s lint timed out", task_id) |
| except Exception as exc: |
| lint_out = f"Lint error: {exc}" |
| lint_errors_count = None |
| log.warning("Task %s lint exception: %s", task_id, exc) |
|
|
| |
| |
| |
| patch_diff_text: str | None = None |
|
|
| if capture_diff or (reference_patch and reference_patch.strip()): |
| try: |
| _, diff_out, _ = _run( |
| ["git", "diff", base_commit], cwd=repo_dir, timeout=30 |
| ) |
| patch_diff_text = diff_out.strip() or None |
| except Exception as exc: |
| log.warning("Task %s git diff failed: %s", task_id, exc) |
|
|
| |
| similarity: float | None = None |
|
|
| if reference_patch and reference_patch.strip(): |
| try: |
| agent_diff = patch_diff_text or (patch if patch and patch.strip() else "") |
| if agent_diff: |
| similarity = round( |
| _patch_similarity(agent_diff, reference_patch), 4 |
| ) |
| log.info("Task %s patch_similarity=%.4f", task_id, similarity) |
| except Exception as exc: |
| log.warning("Task %s similarity computation failed: %s", task_id, exc) |
|
|
| |
| final_update = { |
| "status": "completed", |
| "exit_code": test_exit_code, |
| "stdout": combined_stdout, |
| "stderr": combined_stderr, |
| "tests_passed": passed, |
| "tests_failed": failed, |
| "lint_errors": lint_errors_count, |
| "lint_output": lint_out, |
| "patch_diff": patch_diff_text, |
| "patch_similarity": similarity, |
| "execution_time": round(time.monotonic() - start, 3), |
| } |
|
|
| except _TaskTimeoutError as exc: |
| stderr_parts.append(str(exc)) |
| log.error("Task %s timed out after %ds", task_id, timeout) |
| final_update.update({ |
| "stdout": "\n".join(filter(None, stdout_parts)), |
| "stderr": "\n".join(filter(None, stderr_parts)), |
| "execution_time": round(time.monotonic() - start, 3), |
| }) |
|
|
| except Exception as exc: |
| stderr_parts.append(str(exc)) |
| log.exception("Task %s failed: %s", task_id, exc) |
| final_update.update({ |
| "stdout": "\n".join(filter(None, stdout_parts)), |
| "stderr": "\n".join(filter(None, stderr_parts)), |
| "execution_time": round(time.monotonic() - start, 3), |
| }) |
|
|
| finally: |
| _update(**final_update) |
| try: |
| shutil.rmtree(task_dir, ignore_errors=True) |
| except Exception: |
| pass |
|
|
|
|
| |
| |
| |
|
|
| @app.route("/task/submit", methods=["POST"]) |
| def submit(): |
| """ |
| POST /task/submit |
| |
| Body (JSON): |
| repo_url str required |
| base_commit str optional (default: HEAD) |
| patch str optional — agent's unified diff |
| test_command str required — e.g. "python3 -m pytest tests/ -x" |
| timeout int optional (default: 300) |
| lint_command str optional — e.g. "ruff check . --output-format json" |
| capture_diff bool optional (default: false) |
| reference_patch str optional — ground-truth unified diff |
| |
| Returns 202: { "task_id": "<uuid>", "status": "pending" } |
| """ |
| if not _check_auth(): |
| return jsonify(error="Unauthorized"), HTTPStatus.UNAUTHORIZED |
|
|
| _evict_old_tasks() |
|
|
| body = request.get_json(force=True, silent=True) |
| if not body: |
| return jsonify(error="Request body must be valid JSON"), HTTPStatus.BAD_REQUEST |
|
|
| missing = [f for f in ("repo_url", "test_command") if not body.get(f)] |
| if missing: |
| return jsonify(error=f"Missing required fields: {missing}"), HTTPStatus.BAD_REQUEST |
|
|
| task_id = str(uuid.uuid4()) |
| record: dict = { |
| "task_id": task_id, |
| "status": "pending", |
| "_created": time.monotonic(), |
| "repo_url": body["repo_url"], |
| "base_commit": body.get("base_commit", "HEAD"), |
| "test_command": body["test_command"], |
| "timeout": int(body.get("timeout", 300)), |
| "exit_code": None, |
| "stdout": None, |
| "stderr": None, |
| "tests_passed": None, |
| "tests_failed": None, |
| "lint_errors": None, |
| "lint_output": None, |
| "patch_diff": None, |
| "patch_similarity": None, |
| "execution_time": None, |
| } |
|
|
| with _tasks_lock: |
| _tasks[task_id] = record |
|
|
| threading.Thread( |
| target=_execute, |
| args=( |
| task_id, |
| body["repo_url"], |
| body.get("base_commit", "HEAD"), |
| body.get("patch", ""), |
| body["test_command"], |
| int(body.get("timeout", 300)), |
| body.get("lint_command", ""), |
| bool(body.get("capture_diff", False)), |
| body.get("reference_patch", ""), |
| ), |
| daemon=True, |
| ).start() |
|
|
| log.info("Task %s submitted — repo=%s", task_id, body["repo_url"]) |
| return jsonify(task_id=task_id, status="pending"), HTTPStatus.ACCEPTED |
|
|
|
|
| @app.route("/task/<task_id>", methods=["GET"]) |
| def status(task_id: str): |
| """ |
| GET /task/<task_id> |
| Returns: { "task_id": "...", "status": "pending|running|completed|failed" } |
| """ |
| if not _check_auth(): |
| return jsonify(error="Unauthorized"), HTTPStatus.UNAUTHORIZED |
| with _tasks_lock: |
| t = _tasks.get(task_id) |
| if t is None: |
| return jsonify(error="Task not found"), HTTPStatus.NOT_FOUND |
| return jsonify(task_id=t["task_id"], status=t["status"]) |
|
|
|
|
| @app.route("/task/<task_id>/result", methods=["GET"]) |
| def result(task_id: str): |
| """ |
| GET /task/<task_id>/result |
| |
| Returns 202 while running. |
| Returns 200 with full record on completion: |
| { |
| "task_id": "...", |
| "status": "completed|failed", |
| "exit_code": 0, |
| "stdout": "...", |
| "stderr": "...", |
| "tests_passed": 5, |
| "tests_failed": 1, |
| "lint_errors": 3, |
| "lint_output": "...", |
| "patch_diff": "...", |
| "patch_similarity": 0.9412, |
| "execution_time": 14.2 |
| } |
| """ |
| if not _check_auth(): |
| return jsonify(error="Unauthorized"), HTTPStatus.UNAUTHORIZED |
| with _tasks_lock: |
| t = _tasks.get(task_id) |
| if t is None: |
| return jsonify(error="Task not found"), HTTPStatus.NOT_FOUND |
| if t["status"] in ("pending", "running"): |
| return jsonify( |
| task_id=task_id, |
| status=t["status"], |
| message="Task not yet complete — poll again shortly", |
| ), HTTPStatus.ACCEPTED |
| return jsonify(t) |
|
|
|
|
| @app.route("/task/<task_id>", methods=["DELETE"]) |
| def delete(task_id: str): |
| """ |
| DELETE /task/<task_id> |
| Returns: { "task_id": "...", "deleted": true } |
| """ |
| if not _check_auth(): |
| return jsonify(error="Unauthorized"), HTTPStatus.UNAUTHORIZED |
| with _tasks_lock: |
| if task_id not in _tasks: |
| return jsonify(error="Task not found"), HTTPStatus.NOT_FOUND |
| _tasks.pop(task_id) |
| log.info("Task %s deleted", task_id) |
| return jsonify(task_id=task_id, deleted=True) |
|
|
|
|
| |
| |
| |
|
|
| if __name__ == "__main__": |
| log.info("Task executor starting on 0.0.0.0:%d", API_PORT) |
| serve(app, host="0.0.0.0", port=API_PORT, threads=16) |
|
|