Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 127 additions & 18 deletions triton_kernel_agent/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""Verification Worker for testing and refining individual kernels."""

import ast
import json
import logging
import multiprocessing as mp
Expand Down Expand Up @@ -261,6 +262,27 @@ def _extract_code_from_response(
self.logger.warning("No code block found in LLM response")
return None

def _validate_kernel_candidate(self, kernel_code: str | None) -> str | None:
"""Return a reason when extracted kernel code is structurally malformed."""
if not kernel_code or not kernel_code.strip():
return "no Python kernel code was extracted from the model response"

try:
module = ast.parse(kernel_code)
except SyntaxError as exc:
line = f"line {exc.lineno}" if exc.lineno else "unknown line"
return f"invalid Python syntax ({line}: {exc.msg})"

has_kernel_function = any(
isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
and node.name == "kernel_function"
for node in module.body
)
if not has_kernel_function:
return "missing required top-level kernel_function definition"

return None

def _write_kernel(self, kernel_code: str):
    """Write only the kernel code to file.

    Passes ``kernel_code`` unchanged to ``self.kernel_file.write_text``.
    """
    self.kernel_file.write_text(kernel_code)
Expand Down Expand Up @@ -413,6 +435,12 @@ def _refine_kernel(
response_text,
prefer_kernel_function=getattr(self, "_has_multiple_tests", False),
)
malformed_reason = self._validate_kernel_candidate(refined_kernel)

if malformed_reason:
raise ValueError(
f"Malformed LLM kernel response: {malformed_reason}"
)

if refined_kernel:
self.logger.info(
Expand All @@ -425,6 +453,8 @@ def _refine_kernel(
return kernel_code

except Exception as e:
if "Malformed LLM kernel response" in str(e):
raise
self.logger.error(f"Error refining kernel with LLM API: {e}")
# Fall back to mock refinement

Expand Down Expand Up @@ -482,6 +512,17 @@ def run(
self._has_multiple_tests = len(test_code) > 1

current_kernel = kernel_code
malformed_reason = self._validate_kernel_candidate(current_kernel)
if malformed_reason:
error_feedback = f"Malformed LLM kernel response: {malformed_reason}"
self.logger.warning(f"❌ {error_feedback}")
return {
"worker_id": self.worker_id,
"success": False,
"rounds": 0,
"error": error_feedback,
"history": list(self.history),
}

for round_num in range(self.max_rounds):
# Check if another worker has succeeded
Expand Down Expand Up @@ -516,12 +557,38 @@ def run(
"stderr": violation,
"history": list(self.history),
}
current_kernel = self._refine_kernel(
current_kernel,
error_info,
problem_description,
format_test_code_for_llm(test_code),
)
try:
current_kernel = self._refine_kernel(
current_kernel,
error_info,
problem_description,
format_test_code_for_llm(test_code),
)
except ValueError as exc:
if "Malformed LLM kernel response" not in str(exc):
raise
error_feedback = str(exc)
self.logger.warning(f"❌ {error_feedback}")
return {
"worker_id": self.worker_id,
"success": False,
"rounds": round_num + 1,
"error": error_feedback,
"history": list(self.history),
}
malformed_reason = self._validate_kernel_candidate(current_kernel)
if malformed_reason:
error_feedback = (
f"Malformed LLM kernel response: {malformed_reason}"
)
self.logger.warning(f"❌ {error_feedback}")
return {
"worker_id": self.worker_id,
"success": False,
"rounds": round_num + 1,
"error": error_feedback,
"history": list(self.history),
}
continue

# Log round
Expand All @@ -546,12 +613,36 @@ def run(
"history": list(self.history),
}

current_kernel = self._refine_kernel(
current_kernel,
error_info,
problem_description,
format_test_code_for_llm(test_code),
)
try:
current_kernel = self._refine_kernel(
current_kernel,
error_info,
problem_description,
format_test_code_for_llm(test_code),
)
except ValueError as exc:
if "Malformed LLM kernel response" not in str(exc):
raise
error_feedback = str(exc)
self.logger.warning(f"❌ {error_feedback}")
return {
"worker_id": self.worker_id,
"success": False,
"rounds": round_num + 1,
"error": error_feedback,
"history": list(self.history),
}
malformed_reason = self._validate_kernel_candidate(current_kernel)
if malformed_reason:
error_feedback = f"Malformed LLM kernel response: {malformed_reason}"
self.logger.warning(f"❌ {error_feedback}")
return {
"worker_id": self.worker_id,
"success": False,
"rounds": round_num + 1,
"error": error_feedback,
"history": list(self.history),
}

# Max rounds reached without success
self.logger.warning(f"Max rounds ({self.max_rounds}) reached without success")
Expand Down Expand Up @@ -619,6 +710,12 @@ def verify_with_refinement(
current_kernel = kernel_code
self._has_multiple_tests = len(test_code) > 1

malformed_reason = self._validate_kernel_candidate(current_kernel)
if malformed_reason:
error_feedback = f"Malformed LLM kernel response: {malformed_reason}"
self.logger.warning(f"❌ {error_feedback}")
return False, current_kernel, error_feedback

# Write files for testing (primary + additional tests)
self._write_files(current_kernel, test_code)

Expand Down Expand Up @@ -653,12 +750,24 @@ def verify_with_refinement(
}

# Refine kernel
refined_kernel = self._refine_kernel(
current_kernel,
error_info,
problem_description,
format_test_code_for_llm(test_code),
)
try:
refined_kernel = self._refine_kernel(
current_kernel,
error_info,
problem_description,
format_test_code_for_llm(test_code),
)
except ValueError as exc:
if "Malformed LLM kernel response" not in str(exc):
raise
error_feedback = str(exc)
self.logger.warning(f"❌ {error_feedback}")
return False, current_kernel, error_feedback
malformed_reason = self._validate_kernel_candidate(refined_kernel)
if malformed_reason:
error_feedback = f"Malformed LLM kernel response: {malformed_reason}"
self.logger.warning(f"❌ {error_feedback}")
return False, current_kernel, error_feedback

# Write and test refined kernel
self._write_kernel(refined_kernel)
Expand Down
Loading