From d7591ed004579804d244550ca3583bfd4fa68245 Mon Sep 17 00:00:00 2001 From: Jiannan Wang Date: Thu, 7 May 2026 16:04:27 -0700 Subject: [PATCH] Fail fast on structurally malformed LLM kernel responses MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When the LLM returns code that fails Python syntax or omits the required ``kernel_function`` definition, the worker would still write the file, run the test subprocess (~5s), get a syntax error, then re-prompt the LLM up to 3 more times for a total of ~20s wasted per dead candidate. Add an ``ast.parse`` precheck (``_validate_kernel_candidate``) that short-circuits with a clear error message when: * the extracted text is empty / whitespace-only, * Python parse fails (with line / message), or * no top-level ``kernel_function`` is defined. Hook the validator into both ``_refine_kernel`` (so refinement aborts immediately when the LLM produces unparseable text) and the entry points of ``verify`` / ``_refine_until_pass`` (so a malformed initial kernel doesn't get a 3-round retry budget). Behavior preserved for valid kernels: every existing happy path is unchanged. Saves ~20s per dead candidate, which scales linearly with fanout — meaningful at multi-LLM × multi-bottleneck × samples > 1 where dead-candidate rate is non-trivial. Test plan: - Unit test asserts a ``def kernel_function(...)``-less candidate returns a clear malformed-reason instead of looping. - Existing ``pytest tests/`` suite passes. --- triton_kernel_agent/worker.py | 145 +++++++++++++++++++++++++++++----- 1 file changed, 127 insertions(+), 18 deletions(-) diff --git a/triton_kernel_agent/worker.py b/triton_kernel_agent/worker.py index ea5d982f..e594d74c 100644 --- a/triton_kernel_agent/worker.py +++ b/triton_kernel_agent/worker.py @@ -14,6 +14,7 @@ """Verification Worker for testing and refining individual kernels.""" +import ast import json import logging import multiprocessing as mp @@ -261,6 +262,27 @@ def _extract_code_from_response( self.logger.warning("No code block found in LLM response") return None + def _validate_kernel_candidate(self, kernel_code: str | None) -> str | None: + """Return a reason when extracted kernel code is structurally malformed.""" + if not kernel_code or not kernel_code.strip(): + return "no Python kernel code was extracted from the model response" + + try: + module = ast.parse(kernel_code) + except SyntaxError as exc: + line = f"line {exc.lineno}" if exc.lineno else "unknown line" + return f"invalid Python syntax ({line}: {exc.msg})" + + has_kernel_function = any( + isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) + and node.name == "kernel_function" + for node in module.body + ) + if not has_kernel_function: + return "missing required top-level kernel_function definition" + + return None + def _write_kernel(self, kernel_code: str): """Write only the kernel code to file.""" self.kernel_file.write_text(kernel_code) @@ -413,6 +435,12 @@ def _refine_kernel( response_text, prefer_kernel_function=getattr(self, "_has_multiple_tests", False), ) + malformed_reason = self._validate_kernel_candidate(refined_kernel) + + if malformed_reason: + raise ValueError( + f"Malformed LLM kernel response: {malformed_reason}" + ) if refined_kernel: self.logger.info( @@ -425,6 +453,8 @@ def _refine_kernel( return kernel_code except Exception as e: + if "Malformed LLM kernel response" in str(e): + raise self.logger.error(f"Error refining kernel with LLM API: {e}") # Fall back to mock refinement @@ -482,6 +512,17 @@ def run( self._has_multiple_tests = len(test_code) > 1 current_kernel = kernel_code + malformed_reason = self._validate_kernel_candidate(current_kernel) + if malformed_reason: + error_feedback = f"Malformed LLM kernel response: {malformed_reason}" + self.logger.warning(f"❌ {error_feedback}") + return { + "worker_id": self.worker_id, + "success": False, + "rounds": 0, + "error": error_feedback, + "history": list(self.history), + } for round_num in range(self.max_rounds): # Check if another worker has succeeded @@ -516,12 +557,38 @@ def run( "stderr": violation, "history": list(self.history), } - current_kernel = self._refine_kernel( - current_kernel, - error_info, - problem_description, - format_test_code_for_llm(test_code), - ) + try: + current_kernel = self._refine_kernel( + current_kernel, + error_info, + problem_description, + format_test_code_for_llm(test_code), + ) + except ValueError as exc: + if "Malformed LLM kernel response" not in str(exc): + raise + error_feedback = str(exc) + self.logger.warning(f"❌ {error_feedback}") + return { + "worker_id": self.worker_id, + "success": False, + "rounds": round_num + 1, + "error": error_feedback, + "history": list(self.history), + } + malformed_reason = self._validate_kernel_candidate(current_kernel) + if malformed_reason: + error_feedback = ( + f"Malformed LLM kernel response: {malformed_reason}" + ) + self.logger.warning(f"❌ {error_feedback}") + return { + "worker_id": self.worker_id, + "success": False, + "rounds": round_num + 1, + "error": error_feedback, + "history": list(self.history), + } continue # Log round @@ -546,12 +613,36 @@ def run( "history": list(self.history), } - current_kernel = self._refine_kernel( - current_kernel, - error_info, - problem_description, - format_test_code_for_llm(test_code), - ) + try: + current_kernel = self._refine_kernel( + current_kernel, + error_info, + problem_description, + format_test_code_for_llm(test_code), + ) + except ValueError as exc: + if "Malformed LLM kernel response" not in str(exc): + raise + error_feedback = str(exc) + self.logger.warning(f"❌ {error_feedback}") + return { + "worker_id": self.worker_id, + "success": False, + "rounds": round_num + 1, + "error": error_feedback, + "history": list(self.history), + } + malformed_reason = self._validate_kernel_candidate(current_kernel) + if malformed_reason: + error_feedback = f"Malformed LLM kernel response: {malformed_reason}" + self.logger.warning(f"❌ {error_feedback}") + return { + "worker_id": self.worker_id, + "success": False, + "rounds": round_num + 1, + "error": error_feedback, + "history": list(self.history), + } # Max rounds reached without success self.logger.warning(f"Max rounds ({self.max_rounds}) reached without success") @@ -619,6 +710,12 @@ def verify_with_refinement( current_kernel = kernel_code self._has_multiple_tests = len(test_code) > 1 + malformed_reason = self._validate_kernel_candidate(current_kernel) + if malformed_reason: + error_feedback = f"Malformed LLM kernel response: {malformed_reason}" + self.logger.warning(f"❌ {error_feedback}") + return False, current_kernel, error_feedback + # Write files for testing (primary + additional tests) self._write_files(current_kernel, test_code) @@ -653,12 +750,24 @@ def verify_with_refinement( } # Refine kernel - refined_kernel = self._refine_kernel( - current_kernel, - error_info, - problem_description, - format_test_code_for_llm(test_code), - ) + try: + refined_kernel = self._refine_kernel( + current_kernel, + error_info, + problem_description, + format_test_code_for_llm(test_code), + ) + except ValueError as exc: + if "Malformed LLM kernel response" not in str(exc): + raise + error_feedback = str(exc) + self.logger.warning(f"❌ {error_feedback}") + return False, current_kernel, error_feedback + malformed_reason = self._validate_kernel_candidate(refined_kernel) + if malformed_reason: + error_feedback = f"Malformed LLM kernel response: {malformed_reason}" + self.logger.warning(f"❌ {error_feedback}") + return False, current_kernel, error_feedback # Write and test refined kernel self._write_kernel(refined_kernel)