Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions graph_net/agent/graph_extractor/subprocess_graph_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,13 @@ def extract(self, code_path: Path, model_id: str) -> Path:
else:
env["PYTHONPATH"] = str(graphnet_root)

# Ensure GRAPH_NET_EXTRACT_WORKSPACE points to our workspace
if "GRAPH_NET_EXTRACT_WORKSPACE" not in env:
env["GRAPH_NET_EXTRACT_WORKSPACE"] = str(self.workspace)
# Ensure GRAPH_NET_EXTRACT_WORKSPACE points to samples dir
# so extraction output goes to workspace/samples/ instead of root
samples_dir = self.workspace / "samples"
samples_dir.mkdir(parents=True, exist_ok=True)
env["GRAPH_NET_EXTRACT_WORKSPACE"] = str(samples_dir)
# Also set in current process env so _get_workspace_path() can find it
os.environ["GRAPH_NET_EXTRACT_WORKSPACE"] = str(samples_dir)

# Run script in subprocess via Popen so we can kill on timeout
proc = subprocess.Popen(
Expand Down Expand Up @@ -225,7 +229,12 @@ def _find_output_dir_robust(self, model_id: str) -> Optional[Path]:
def _get_workspace_path(self) -> Optional[Path]:
"""Get workspace path from environment or instance variable"""
workspace_env = os.environ.get("GRAPH_NET_EXTRACT_WORKSPACE")
return Path(workspace_env) if workspace_env else self.workspace
if workspace_env:
return Path(workspace_env)
# Default to samples/ subdir to avoid cluttering workspace root
samples_dir = self.workspace / "samples"
samples_dir.mkdir(parents=True, exist_ok=True)
return samples_dir

def _find_dir_by_pattern(
self, workspace_path: Path, model_id: str, safe_model_id: str
Expand Down
46 changes: 33 additions & 13 deletions graph_net/agent/graph_net_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import json
import os
import shutil
from enum import Enum
from pathlib import Path
from typing import Optional
Expand Down Expand Up @@ -118,6 +119,7 @@ def extract_sample(self, model_id: str) -> ExtractionStatus:
ExtractionStatus.ERROR – unexpected error
"""
self.last_timeout_success = False
sample_dir: Optional[Path] = None
try:
self.logger.info(f"Starting extraction for model: {model_id}")

Expand All @@ -144,6 +146,7 @@ def extract_sample(self, model_id: str) -> ExtractionStatus:

if self.is_duplicate_sample(sample_dir):
self.logger.info("Duplicate sample detected, skipping verification")
self._move_sample(sample_dir, self.workspace.success_dir)
return ExtractionStatus.OK

if not self.sample_verifier.verify(sample_dir):
Expand All @@ -152,6 +155,7 @@ def extract_sample(self, model_id: str) -> ExtractionStatus:
model_id,
Exception("Sample verification failed"),
)
self._move_sample(sample_dir, self.workspace.failed_dir)
return ExtractionStatus.VERIFY_FAILED

if getattr(self.sample_verifier, "last_timeout_success", False):
Expand All @@ -160,12 +164,15 @@ def extract_sample(self, model_id: str) -> ExtractionStatus:
f"Sample verification for {model_id} passed via timeout skip"
)

self._move_sample(sample_dir, self.workspace.success_dir)
self.logger.info(f"Successfully extracted sample for {model_id}")
return ExtractionStatus.OK

except SampleVerificationError as e:
self.logger.error(f"Extraction failed for {model_id}: {e}")
self.error_classifier.classify_and_record(model_id, e)
if sample_dir and sample_dir.exists():
self._move_sample(sample_dir, self.workspace.failed_dir)
return ExtractionStatus.VERIFY_FAILED
except (
ModelFetchError,
Expand All @@ -175,10 +182,14 @@ def extract_sample(self, model_id: str) -> ExtractionStatus:
) as e:
self.logger.error(f"Extraction failed for {model_id}: {e}")
self.error_classifier.classify_and_record(model_id, e)
if sample_dir and sample_dir.exists():
self._move_sample(sample_dir, self.workspace.failed_dir)
return ExtractionStatus.EXTRACT_FAILED
except Exception as e:
self.logger.error(f"Unexpected error for {model_id}: {e}", exc_info=True)
self.error_classifier.classify_and_record(model_id, e)
if sample_dir and sample_dir.exists():
self._move_sample(sample_dir, self.workspace.failed_dir)
return ExtractionStatus.ERROR

@staticmethod
Expand Down Expand Up @@ -357,6 +368,15 @@ def _generate_graph_hash(self, sample_dir: Path) -> None:
except (OSError, IOError) as e:
self.logger.warning(f"Failed to generate graph_hash.txt: {e}")

def _move_sample(self, sample_dir: Path, dest_parent: Path) -> Path:
"""Move sample_dir into dest_parent/, overwriting if destination exists"""
dest = dest_parent / sample_dir.name
if dest.exists():
shutil.rmtree(dest)
shutil.move(str(sample_dir), str(dest))
self.logger.info(f"Moved sample to: {dest}")
return dest

def is_duplicate_sample(self, sample_dir: Path) -> bool:
"""Check if the extracted sample is a duplicate of an existing sample"""
graph_hash_path = sample_dir / "graph_hash.txt"
Expand All @@ -366,21 +386,21 @@ def is_duplicate_sample(self, sample_dir: Path) -> bool:

try:
current_hash = graph_hash_path.read_text().strip()
samples_root = self.workspace.samples_dir

if not samples_root.exists():
return False

for hash_file in samples_root.rglob("graph_hash.txt"):
if hash_file == graph_hash_path:
continue
try:
existing_hash = hash_file.read_text().strip()
if existing_hash == current_hash:
self.logger.info(f"Duplicate found: {hash_file.parent}")
return True
except (OSError, IOError):
# Search for duplicates in success_dir (where past successful samples live)
for search_root in [self.workspace.success_dir, self.workspace.samples_dir]:
if not search_root.exists():
continue
for hash_file in search_root.rglob("graph_hash.txt"):
if hash_file == graph_hash_path:
continue
try:
existing_hash = hash_file.read_text().strip()
if existing_hash == current_hash:
self.logger.info(f"Duplicate found: {hash_file.parent}")
return True
except (OSError, IOError):
continue

return False
except (OSError, IOError) as e:
Expand Down
10 changes: 8 additions & 2 deletions graph_net/agent/parallel_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,8 +573,14 @@ def main() -> int:
)

# --- Save results ---
output_file = (
args.output or f"parallel_extract_{start_time.strftime('%Y%m%d_%H%M%S')}.json"
from graph_net.agent.utils.workspace_manager import (
WorkspaceManager as _WorkspaceManager,
)

_ws = _WorkspaceManager(workspace)
output_file = args.output or str(
_ws.logs_and_lists_dir
/ f"parallel_extract_{start_time.strftime('%Y%m%d_%H%M%S')}.json"
)
_save_results(results, output_file)

Expand Down
18 changes: 18 additions & 0 deletions graph_net/agent/utils/workspace_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ def _ensure_directories(self):
self.generated_dir,
self.samples_dir,
self.logs_dir,
self.success_dir,
self.failed_dir,
self.logs_and_lists_dir,
]
for dir_path in dirs:
dir_path.mkdir(parents=True, exist_ok=True)
Expand All @@ -46,6 +49,21 @@ def logs_dir(self) -> Path:
"""Directory for logs"""
return self.workspace_root / "logs"

@property
def success_dir(self) -> Path:
"""Directory for successfully extracted samples"""
return self.workspace_root / "success"

@property
def failed_dir(self) -> Path:
"""Directory for failed extraction artifacts"""
return self.workspace_root / "failed"

@property
def logs_and_lists_dir(self) -> Path:
"""Directory for result JSONs, model lists, and run logs"""
return self.workspace_root / "logs_and_lists"

def get_model_dir(self, model_id: str) -> Path:
"""Get directory path for a specific model"""
return self.models_dir / model_id.replace("/", "_")
Expand Down
Loading