Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 55 additions & 42 deletions graph_net/agent/graph_net_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,12 +332,19 @@ def _extract_graph(self, script_path: Path, model_id: str) -> Path:
self.logger.info(f"Graph extracted to: {sample_dir}")
return sample_dir

def _get_subgraph_dirs(self, sample_dir: Path) -> list[Path]:
    """Return the directories that hold the graph(s) of a sample.

    For single-graph models (no ``subgraph_*`` child directory), returns
    ``[sample_dir]``. For multi-subgraph models, returns the ``subgraph_*``
    child directories ordered by their numeric suffix, so ``subgraph_2``
    comes before ``subgraph_10``. Directories whose suffix is not a number
    are placed after the numbered ones, ordered by name.

    Args:
        sample_dir: Root directory of one extracted sample.

    Returns:
        A non-empty list of directories to process.
    """
    prefix = "subgraph_"

    def _order(path: Path) -> tuple[int, int, str]:
        suffix = path.name[len(prefix):]
        if suffix.isdecimal():
            return (0, int(suffix), path.name)
        return (1, 0, path.name)

    # Filter with is_dir() explicitly: a trailing "/" in the glob pattern only
    # restricts matches to directories starting with Python 3.11, so on older
    # interpreters plain files named "subgraph_*" would be returned as well.
    subgraph_dirs = sorted(
        (p for p in sample_dir.glob(f"{prefix}*") if p.is_dir()),
        key=_order,
    )
    return subgraph_dirs if subgraph_dirs else [sample_dir]

def _fix_model_name(self, sample_dir: Path, model_id: str) -> None:
"""Update model_name in graph_net.json to the original HuggingFace model_id (org/model)."""
for json_path in [
sample_dir / "graph_net.json",
*sample_dir.glob("subgraph_*/graph_net.json"),
]:
for target_dir in self._get_subgraph_dirs(sample_dir):
json_path = target_dir / "graph_net.json"
if not json_path.exists():
continue
try:
Expand All @@ -349,24 +356,20 @@ def _fix_model_name(self, sample_dir: Path, model_id: str) -> None:
self.logger.warning(f"Failed to fix model_name in {json_path}: {e}")

def _generate_graph_hash(self, sample_dir: Path) -> None:
    """Generate graph_hash.txt from model.py for each subgraph.

    Every directory returned by ``self._get_subgraph_dirs(sample_dir)`` is
    visited. A directory is skipped when it already contains graph_hash.txt
    or has no model.py. Otherwise the text of model.py is passed to
    ``get_sha256_hash`` and the result is written to graph_hash.txt.

    Failures are logged as warnings and do not stop the remaining
    directories from being processed.

    Args:
        sample_dir: Root directory of one extracted sample.
    """
    for target_dir in self._get_subgraph_dirs(sample_dir):
        model_py = target_dir / "model.py"
        hash_path = target_dir / "graph_hash.txt"
        if hash_path.exists() or not model_py.exists():
            continue
        # Path shown in log messages, relative to the sample root.
        rel = hash_path.relative_to(sample_dir)
        try:
            # Decode explicitly as UTF-8 so the hashed text does not depend on
            # the platform's locale encoding.
            # NOTE(review): assumes get_sha256_hash hashes the UTF-8 bytes of
            # the string it receives - confirm against graph_net.hash_util.
            graph_hash = get_sha256_hash(model_py.read_text(encoding="utf-8"))
            hash_path.write_text(graph_hash)
        except (OSError, UnicodeDecodeError) as e:
            self.logger.warning(f"Failed to generate {rel}: {e}")
        else:
            self.logger.info(f"Generated {rel}: {graph_hash[:16]}...")

def _move_sample(self, sample_dir: Path, dest_parent: Path) -> Path:
"""Move sample_dir into dest_parent/, overwriting if destination exists"""
Expand All @@ -378,31 +381,41 @@ def _move_sample(self, sample_dir: Path, dest_parent: Path) -> Path:
return dest

def is_duplicate_sample(self, sample_dir: Path) -> bool:
    """Check if the extracted sample is a duplicate of an existing sample.

    The identity of a sample is the set of hashes read from graph_hash.txt
    in every directory returned by ``self._get_subgraph_dirs``. This works
    for both single-graph and multi-subgraph samples. The sample is a
    duplicate when a direct child directory of ``workspace.success_dir`` or
    ``workspace.samples_dir`` (other than ``sample_dir`` itself) has exactly
    the same set of hashes.

    Args:
        sample_dir: Directory of the sample to check.

    Returns:
        True if an existing sample has an identical hash set. False
        otherwise, including when the sample has no usable hash or a
        filesystem error interrupts the scan.
    """

    def _collect_hashes(path: Path) -> frozenset[str]:
        hashes = set()
        for target_dir in self._get_subgraph_dirs(path):
            hash_path = target_dir / "graph_hash.txt"
            try:
                value = hash_path.read_text().strip()
            except FileNotFoundError:
                # No hash generated for this directory; nothing to compare.
                continue
            except (OSError, UnicodeDecodeError) as e:
                self.logger.warning(f"Failed to read {hash_path}: {e}")
                continue
            # An empty hash file carries no identity; without this check two
            # unrelated samples with blank files would compare as equal.
            if value:
                hashes.add(value)
        return frozenset(hashes)

    try:
        current_hashes = _collect_hashes(sample_dir)
        if not current_hashes:
            return False

        for search_root in (
            self.workspace.success_dir,
            self.workspace.samples_dir,
        ):
            if not search_root.exists():
                continue
            for existing_dir in search_root.iterdir():
                if not existing_dir.is_dir() or existing_dir == sample_dir:
                    continue
                # current_hashes is non-empty here, so equality also implies
                # the existing sample has at least one hash.
                if _collect_hashes(existing_dir) == current_hashes:
                    self.logger.info(f"Duplicate found: {existing_dir}")
                    return True
    except OSError as e:
        self.logger.warning(f"Failed to check duplicate: {e}")
        return False
    return False
11 changes: 6 additions & 5 deletions graph_net/agent/scripts/gen_hash_and_dedup.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,17 @@
# Requirements: Python 3.6+
# =============================================================================

import hashlib
import os
import sys
from collections import defaultdict

# Ensure graph_net is importable when running this script standalone.
# The repository root is three directory levels above this script. It is
# normalized so that the entry added to sys.path contains no ".." components
# and the membership check can match an already-present normalized entry.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
_GRAPHNET_ROOT = os.path.normpath(os.path.join(_SCRIPT_DIR, "..", "..", ".."))
if _GRAPHNET_ROOT not in sys.path:
    sys.path.insert(0, _GRAPHNET_ROOT)

def get_sha256_hash(content):
    """Return the hexadecimal SHA-256 digest of *content* encoded as UTF-8."""
    return hashlib.sha256(content.encode("utf-8")).hexdigest()
from graph_net.hash_util import get_sha256_hash # noqa: E402


def find_model_files(workspace):
Expand Down
Loading