diff --git a/graph_net/agent/graph_net_agent.py b/graph_net/agent/graph_net_agent.py
index c3cac10a5..c25928e86 100644
--- a/graph_net/agent/graph_net_agent.py
+++ b/graph_net/agent/graph_net_agent.py
@@ -332,12 +332,30 @@ def _extract_graph(self, script_path: Path, model_id: str) -> Path:
         self.logger.info(f"Graph extracted to: {sample_dir}")
         return sample_dir
 
+    def _get_subgraph_dirs(self, sample_dir: Path) -> list[Path]:
+        """Return the directories that hold this sample's graph files.
+
+        For single-graph models, returns [sample_dir].
+        For multi-subgraph models, returns the subgraph_* directories in
+        lexicographic order (subgraph_10 sorts before subgraph_2).
+        """
+        # Filter with is_dir(): before Python 3.11 a trailing "/" in the glob
+        # pattern did not restrict the matches to directories.
+        subgraph_dirs = sorted(
+            p for p in sample_dir.glob("subgraph_*") if p.is_dir()
+        )
+        return subgraph_dirs or [sample_dir]
+
     def _fix_model_name(self, sample_dir: Path, model_id: str) -> None:
         """Update model_name in graph_net.json to the original HuggingFace model_id (org/model)."""
-        for json_path in [
-            sample_dir / "graph_net.json",
-            *sample_dir.glob("subgraph_*/graph_net.json"),
-        ]:
+        # Visit sample_dir itself as well: the previous implementation also
+        # updated a top-level graph_net.json when subgraph_* dirs exist.
+        # dict.fromkeys de-duplicates while keeping order.
+        target_dirs = dict.fromkeys(
+            [sample_dir, *self._get_subgraph_dirs(sample_dir)]
+        )
+        for target_dir in target_dirs:
+            json_path = target_dir / "graph_net.json"
             if not json_path.exists():
                 continue
             try:
@@ -349,24 +367,26 @@ def _fix_model_name(self, sample_dir: Path, model_id: str) -> None:
                 self.logger.warning(f"Failed to fix model_name in {json_path}: {e}")
 
     def _generate_graph_hash(self, sample_dir: Path) -> None:
-        """Generate graph_hash.txt from model.py if it doesn't exist"""
-        graph_hash_path = sample_dir / "graph_hash.txt"
-        model_py_path = sample_dir / "model.py"
-
-        if graph_hash_path.exists():
-            return
-
-        if not model_py_path.exists():
-            self.logger.warning(f"model.py not found at {model_py_path}")
-            return
-
-        try:
-            model_code = model_py_path.read_text()
-            graph_hash = get_sha256_hash(model_code)
-            graph_hash_path.write_text(graph_hash)
-            self.logger.info(f"Generated graph_hash.txt: {graph_hash[:16]}...")
-        except (OSError, IOError) as e:
-            self.logger.warning(f"Failed to generate graph_hash.txt: {e}")
+        """Generate graph_hash.txt from model.py for each subgraph.
+
+        Directories that already have a graph_hash.txt are left untouched.
+        """
+        for target_dir in self._get_subgraph_dirs(sample_dir):
+            model_py = target_dir / "model.py"
+            hash_path = target_dir / "graph_hash.txt"
+            if hash_path.exists():
+                continue
+            if not model_py.exists():
+                # Keep the warning the single-graph implementation emitted.
+                self.logger.warning(f"model.py not found at {model_py}")
+                continue
+            rel = hash_path.relative_to(sample_dir)
+            try:
+                graph_hash = get_sha256_hash(model_py.read_text())
+                hash_path.write_text(graph_hash)
+                self.logger.info(f"Generated {rel}: {graph_hash[:16]}...")
+            except (OSError, IOError) as e:
+                self.logger.warning(f"Failed to generate {rel}: {e}")
 
     def _move_sample(self, sample_dir: Path, dest_parent: Path) -> Path:
         """Move sample_dir into dest_parent/, overwriting if destination exists"""
@@ -378,31 +398,45 @@ def _move_sample(self, sample_dir: Path, dest_parent: Path) -> Path:
         return dest
 
     def is_duplicate_sample(self, sample_dir: Path) -> bool:
-        """Check if the extracted sample is a duplicate of an existing sample"""
-        graph_hash_path = sample_dir / "graph_hash.txt"
+        """Check if the extracted sample is a duplicate of an existing sample.
 
-        if not graph_hash_path.exists():
-            return False
+        Collects all subgraph graph_hash.txt hashes into a frozenset and compares
+        against existing samples. Works for both single-graph and multi-subgraph.
+        Only the direct children of success_dir and samples_dir are searched.
+        """
 
-        try:
-            current_hash = graph_hash_path.read_text().strip()
+        def _collect_hashes(path: Path) -> frozenset[str]:
+            hashes = set()
+            for target_dir in self._get_subgraph_dirs(path):
+                hash_path = target_dir / "graph_hash.txt"
+                if not hash_path.exists():
+                    continue
+                try:
+                    hashes.add(hash_path.read_text().strip())
+                except (OSError, IOError) as e:
+                    # A partial set could equal a smaller sample's set and
+                    # report a false duplicate, so treat it as "no hashes".
+                    self.logger.warning(f"Failed to read {hash_path}: {e}")
+                    return frozenset()
+            return frozenset(hashes)
 
-            # Search for duplicates in success_dir (where past successful samples live)
-            for search_root in [self.workspace.success_dir, self.workspace.samples_dir]:
+        try:
+            current_hashes = _collect_hashes(sample_dir)
+            if not current_hashes:
+                return False
+
+            for search_root in [
+                self.workspace.success_dir,
+                self.workspace.samples_dir,
+            ]:
                 if not search_root.exists():
                     continue
-                for hash_file in search_root.rglob("graph_hash.txt"):
-                    if hash_file == graph_hash_path:
+                for existing_dir in search_root.iterdir():
+                    if not existing_dir.is_dir() or existing_dir == sample_dir:
                         continue
-                    try:
-                        existing_hash = hash_file.read_text().strip()
-                        if existing_hash == current_hash:
-                            self.logger.info(f"Duplicate found: {hash_file.parent}")
-                            return True
-                    except (OSError, IOError):
-                        continue
-
-            return False
+                    if _collect_hashes(existing_dir) == current_hashes:
+                        self.logger.info(f"Duplicate found: {existing_dir}")
+                        return True
         except (OSError, IOError) as e:
             self.logger.warning(f"Failed to check duplicate: {e}")
-            return False
+        return False
diff --git a/graph_net/agent/scripts/gen_hash_and_dedup.py b/graph_net/agent/scripts/gen_hash_and_dedup.py
index a5b628e60..722d912e8 100644
--- a/graph_net/agent/scripts/gen_hash_and_dedup.py
+++ b/graph_net/agent/scripts/gen_hash_and_dedup.py
@@ -57,16 +57,18 @@
 # 依赖:Python 3.6+
 # =============================================================================
 
-import hashlib
 import os
 import sys
 from collections import defaultdict
 
+# Ensure graph_net is importable when running this script standalone
+# Normalize the path so the membership test below can match existing entries
+_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
+_GRAPHNET_ROOT = os.path.abspath(os.path.join(_SCRIPT_DIR, "..", "..", ".."))
+if _GRAPHNET_ROOT not in sys.path:
+    sys.path.insert(0, _GRAPHNET_ROOT)
 
-def get_sha256_hash(content):
-    m = hashlib.sha256()
-    m.update(content.encode("utf-8"))
-    return m.hexdigest()
+from graph_net.hash_util import get_sha256_hash  # noqa: E402
 
 
 def find_model_files(workspace):