diff --git a/graph_net/torch/extractor.py b/graph_net/torch/extractor.py index 568ad995a..1e6e700e1 100644 --- a/graph_net/torch/extractor.py +++ b/graph_net/torch/extractor.py @@ -3,16 +3,21 @@ import json import shutil import glob + +from graph_net.hash_util import get_sha256_hash + +from pathlib import Path from graph_net.torch import utils from graph_net.torch.fx_graph_serialize_util import serialize_graph_module_to_str torch._dynamo.config.capture_scalar_outputs = True torch._dynamo.config.capture_dynamic_output_shape_ops = True -torch._dynamo.config.capture_sparse_compute = True +try: + torch._dynamo.config.capture_sparse_compute = True +except AttributeError: + pass torch._dynamo.config.raise_on_ctx_manager_usage = False torch._dynamo.config.allow_rnn = True - - # used as configuration of python3 -m graph_net.torch.run_model class RunModelDecorator: def __init__(self, config): @@ -39,8 +44,6 @@ def make_config( "custom_extractor_config": custom_extractor_config, }, } - - class GraphExtractor: def __init__( self, @@ -162,6 +165,14 @@ def try_rename_placeholder(node): with open(os.path.join(subgraph_path, "model.py"), "w") as fp: fp.write(write_code) + # 5.1 Generate graph_hash.txt from model.py + + graph_hash = hashlib.sha256(write_code.encode("utf-8")).hexdigest() + + graph_hash = get_sha256_hash(write_code) + + (Path(subgraph_path) / "graph_hash.txt").write_text(graph_hash) + # 6. Save metadata LAST — graph_net.json serves as the # completion marker: if it exists, all other files are guaranteed # to be fully written. 
@@ -180,8 +191,6 @@ def try_rename_placeholder(node): ) return gm.forward - - def extract( name, dynamic=True, @@ -219,8 +228,6 @@ def extract( class GraphModule(torch.nn.Module): - - def forward(self, s0 : torch.SymInt, L_x_ : torch.Tensor): l_x_ = L_x_ mul = l_x_ * 2; l_x_ = None @@ -246,8 +253,6 @@ def forward(self, s0 : torch.SymInt, L_x_ : torch.Tensor): class GraphModule(torch.nn.Module): - - def forward(self, s0 : torch.SymInt, L_x_ : torch.Tensor): l_x_ = L_x_ mul = l_x_ * 2; l_x_ = None @@ -304,13 +309,9 @@ def decorator_or_wrapper(obj): ) return decorator_or_wrapper - - def make_extractor_config(extractor_config): kwargs = extractor_config if extractor_config is not None else {} return make_extractor_config_impl(**kwargs) - - def make_extractor_config_impl( custom_extractor_path: str = None, custom_extractor_config: dict = None ): diff --git a/graph_net/torch/sample_pass/backward_graph_extractor.py b/graph_net/torch/sample_pass/backward_graph_extractor.py index e50d3b7e9..25ca3b81f 100644 --- a/graph_net/torch/sample_pass/backward_graph_extractor.py +++ b/graph_net/torch/sample_pass/backward_graph_extractor.py @@ -27,7 +27,7 @@ def __call__(self): module, forward_inputs = get_torch_module_and_inputs( self.model_path, use_dummy_inputs=False, device=self.device ) - module.train() + module.eval() eval_forward_dir = os.path.join( self.output_dir, "eval_forward", self.rel_model_path @@ -35,6 +35,10 @@ def __call__(self): if not os.path.exists(eval_forward_dir): shutil.copytree(self.model_path, eval_forward_dir) + forward_inputs = [ + inp.detach().clone() if isinstance(inp, torch.Tensor) else inp + for inp in forward_inputs + ] forward_inputs = self.set_requires_grad_for_forward_inputs( self.model_path, module, forward_inputs ) diff --git a/graph_net_bench/torch/test_compiler.py b/graph_net_bench/torch/test_compiler.py index d341dece4..1dde4fdc7 100755 --- a/graph_net_bench/torch/test_compiler.py +++ b/graph_net_bench/torch/test_compiler.py @@ -267,7 +267,9 @@ 
def _align_output_device(outs, device):
    """Recursively move every tensor found in *outs* onto *device*.

    Args:
        outs: A ``torch.Tensor``, a (possibly nested) ``list``/``tuple`` of
            outputs, or any other object.
        device: Target device; compared against ``tensor.device`` and passed
            to ``Tensor.to``.

    Returns:
        A structure with the same container types as *outs*. Tensors whose
        ``device`` already equals *device* are returned as-is (no copy).
        Values that are neither tensors nor lists/tuples are returned
        unchanged.
    """
    if isinstance(outs, torch.Tensor):
        return outs.to(device) if outs.device != device else outs
    if isinstance(outs, (list, tuple)):
        aligned = [_align_output_device(item, device) for item in outs]
        if hasattr(outs, "_fields"):
            # namedtuple constructors take one positional argument per field,
            # so they cannot be rebuilt from a single iterable.
            return type(outs)(*aligned)
        return type(outs)(aligned)
    return outs
"compute dedup statistics. Two kernels from different subgraphs " + "are considered duplicates when their normalized source code is " + "identical. This is kernel-level dedup (by compiled Triton " + "kernel content), distinct from graph-level dedup via graph_hash.txt." + ), + ) + parser.add_argument( + "--input-dir", + type=Path, + required=True, + help=( + "Root directory containing per-sample subdirectories with " + "triton_kernel/*.py files (output of the extract pipeline step)." + ), + ) + parser.add_argument( + "--output", + type=Path, + required=True, + help="Output JSON path for the dedup report.", + ) + parser.set_defaults(func=_run_dedup) + + +def _run_dedup(args: argparse.Namespace) -> None: + from .kernel_dedup import dedup_kernels + import json + + logger.info("Scanning: %s", args.input_dir) + report = dedup_kernels(args.input_dir) + if not report: + logger.warning("No data to report.") + return + + # Print summary. + logger.info("") + logger.info("=== Kernel Dedup Report ===") + logger.info("Total samples scanned: %d", report["total_samples"]) + logger.info("Total kernel instances: %d", report["total_kernel_instances"]) + logger.info("Unique kernel hashes: %d", report["unique_kernel_hashes"]) + logger.info("Dedup rate: %.1f%%", report["dedup_rate_percent"]) + logger.info("Avg kernels per sample: %.2f", report["avg_kernels_per_sample"]) + logger.info("") + logger.info("Top kernel types by frequency:") + for name, count in list(report["kernel_name_freq"].items())[:20]: + bar = "#" * min(count, 60) + logger.info(" %-50s %4d %s", name, count, bar) + + # Write JSON. 
+ args.output.parent.mkdir(parents=True, exist_ok=True) + args.output.write_text( + json.dumps(report, indent=2, ensure_ascii=False), encoding="utf-8" + ) + logger.info("") + logger.info("Report saved to: %s", args.output) + + # --------------------------------------------------------------------------- # Backward-compatible argument detection # --------------------------------------------------------------------------- @@ -225,7 +293,7 @@ def _needs_implicit_extract(argv: list[str]) -> bool: """ if not argv: return False - known_subcommands = {"extract", "analyze"} + known_subcommands = {"extract", "analyze", "dedup"} first = argv[0] if first in known_subcommands: return False @@ -263,6 +331,7 @@ def main(argv: list[str] | None = None) -> None: subparsers = parser.add_subparsers(dest="command") _add_extract_parser(subparsers) _add_analyze_parser(subparsers) + _add_dedup_parser(subparsers) args = parser.parse_args(argv) diff --git a/tools/triton_kernel_extractor/kernel_dedup.py b/tools/triton_kernel_extractor/kernel_dedup.py new file mode 100644 index 000000000..baf1a788f --- /dev/null +++ b/tools/triton_kernel_extractor/kernel_dedup.py @@ -0,0 +1,194 @@ +"""Deduplicate Triton kernels by source content across extracted samples. + +Reads the output of the ``extract`` pipeline step (``triton_kernel/*.py`` +files under each sample directory) and identifies duplicate kernels — kernels +with identical source code that appear in multiple samples. + +This is distinct from ``graph_hash.txt``-based dedup which identifies +identical *subgraphs*. Kernel-level dedup identifies identical *compiled +kernels* produced by TorchInductor, which may come from different subgraphs +that share the same operator patterns (e.g., two different backward graphs +both containing ``layer_norm_bwd``). 
logger = logging.getLogger(__name__)


def _normalize_kernel_source(source: str) -> str:
    """Return *source* reduced to a canonical form suitable for hashing.

    Blank lines, comment-only lines and trailing whitespace are removed, and
    the indentation shared by all remaining lines is stripped. Relative
    indentation is kept because it is syntactically significant in Python:
    dropping it would make kernels that differ only in block structure (for
    example a statement inside versus after an ``if``) hash identically.

    NOTE: this is a line-based filter, so a line that starts with ``#``
    inside a multi-line string literal is dropped as well.
    """
    kept: list[str] = []
    for line in source.splitlines():
        content = line.strip()
        if not content or content.startswith("#"):
            continue
        kept.append(line.rstrip())
    if not kept:
        return ""
    # Remove only the common margin so a uniformly shifted copy of the same
    # kernel still normalizes to identical text.
    margin = min(len(line) - len(line.lstrip()) for line in kept)
    return "\n".join(line[margin:] for line in kept)


def _hash_source(normalized: str) -> str:
    """Return the hex MD5 digest of the UTF-8 encoded *normalized* source.

    MD5 is used here purely as a content fingerprint for grouping identical
    kernels.
    """
    return hashlib.md5(normalized.encode("utf-8")).hexdigest()


def dedup_kernels(input_dir: Path) -> dict:
    """Walk *input_dir* for ``triton_kernel/*.py`` files and compute dedup stats.

    Parameters
    ----------
    input_dir:
        Root directory containing per-sample subdirectories. Every directory
        under it that has a ``triton_kernel/`` child is treated as a sample.

    Returns
    -------
    An empty dict when *input_dir* is not a directory or no sample is found;
    otherwise a dict with keys:
        - ``total_samples``: number of sample directories scanned
        - ``total_kernel_instances``: number of ``.py`` files read
        - ``unique_kernel_hashes``: number of distinct (normalized) source hashes
        - ``dedup_rate_percent``: ``(1 - unique / total) * 100``, rounded to 2
          decimals (``0.0`` when no kernel was read)
        - ``avg_kernels_per_sample``: mean kernel count per sample
        - ``kernel_name_freq``: ``{kernel_file_stem: occurrence_count}``,
          most frequent first
        - ``per_sample``: list of ``{sample, kernel_count, hashes}`` per sample
    """
    if not input_dir.is_dir():
        logger.error("Input directory does not exist: %s", input_dir)
        return {}

    # Enumerate sample directories (any dir containing triton_kernel/).
    samples: list[Path] = [
        Path(dirpath)
        for dirpath, dirnames, _ in os.walk(str(input_dir))
        if "triton_kernel" in dirnames
    ]

    if not samples:
        logger.warning("No samples with triton_kernel/ found under %s", input_dir)
        return {}

    all_hashes: list[str] = []
    kernel_name_counter: Counter[str] = Counter()
    per_sample: list[dict] = []

    for sample_dir in sorted(samples):
        kernel_dir = sample_dir / "triton_kernel"
        hashes: list[str] = []
        for kf in sorted(kernel_dir.glob("*.py")):
            try:
                source = kf.read_text(encoding="utf-8", errors="replace")
            except OSError:
                # Best effort: an unreadable file is skipped, not fatal.
                logger.warning("Cannot read kernel file: %s", kf)
                continue
            h = _hash_source(_normalize_kernel_source(source))
            hashes.append(h)
            all_hashes.append(h)
            kernel_name_counter[kf.stem] += 1

        per_sample.append(
            {
                "sample": str(sample_dir),
                "kernel_count": len(hashes),
                "hashes": hashes,
            }
        )

    total = len(all_hashes)
    unique = len(set(all_hashes))
    dedup_rate = round((1 - unique / total) * 100, 2) if total > 0 else 0.0
    avg = round(total / len(samples), 2) if samples else 0.0

    return {
        "total_samples": len(samples),
        "total_kernel_instances": total,
        "unique_kernel_hashes": unique,
        "dedup_rate_percent": dedup_rate,
        "avg_kernels_per_sample": avg,
        "kernel_name_freq": dict(kernel_name_counter.most_common()),
        "per_sample": per_sample,
    }


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------


def main(argv: list[str] | None = None) -> None:
    """Standalone CLI: scan ``--input-dir``, print a summary, write ``--output``.

    Parameters
    ----------
    argv:
        Argument list to parse; ``None`` lets ``argparse`` read ``sys.argv``.
    """
    parser = argparse.ArgumentParser(
        description="Deduplicate Triton kernels by source content.",
    )
    parser.add_argument(
        "--input-dir",
        type=Path,
        required=True,
        help=(
            "Root directory containing per-sample subdirectories with "
            "triton_kernel/*.py files (output of the extract pipeline step)."
        ),
    )
    parser.add_argument(
        "--output",
        type=Path,
        required=True,
        help="Output JSON path for the dedup report.",
    )
    args = parser.parse_args(argv)

    logging.basicConfig(
        format="%(message)s",
        level=logging.INFO,
    )

    logger.info("Scanning: %s", args.input_dir)
    report = dedup_kernels(args.input_dir)
    if not report:
        logger.warning("No data to report.")
        return

    # Print summary.
    print("\n=== Kernel Dedup Report ===")
    print(f"Total samples scanned: {report['total_samples']}")
    print(f"Total kernel instances: {report['total_kernel_instances']}")
    print(f"Unique kernel hashes: {report['unique_kernel_hashes']}")
    print(f"Dedup rate: {report['dedup_rate_percent']}%")
    print(f"Avg kernels per sample: {report['avg_kernels_per_sample']}")
    print("\nTop kernel types by frequency:")
    for name, count in list(report["kernel_name_freq"].items())[:20]:
        bar = "#" * min(count, 40)
        print(f" {name:50s} {count:4d} {bar}")

    # Write JSON.
    args.output.parent.mkdir(parents=True, exist_ok=True)
    args.output.write_text(
        json.dumps(report, indent=2, ensure_ascii=False), encoding="utf-8"
    )
    print(f"\nReport saved to: {args.output}")


if __name__ == "__main__":
    main()