Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 16 additions & 15 deletions graph_net/torch/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,21 @@
import json
import shutil
import glob

from graph_net.hash_util import get_sha256_hash

from pathlib import Path
from graph_net.torch import utils
from graph_net.torch.fx_graph_serialize_util import serialize_graph_module_to_str

torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.capture_dynamic_output_shape_ops = True
torch._dynamo.config.capture_sparse_compute = True
try:
torch._dynamo.config.capture_sparse_compute = True
except AttributeError:
pass
torch._dynamo.config.raise_on_ctx_manager_usage = False
torch._dynamo.config.allow_rnn = True


# used as configuration of python3 -m graph_net.torch.run_model
class RunModelDecorator:
def __init__(self, config):
Expand All @@ -39,8 +44,6 @@ def make_config(
"custom_extractor_config": custom_extractor_config,
},
}


class GraphExtractor:
def __init__(
self,
Expand Down Expand Up @@ -162,6 +165,14 @@ def try_rename_placeholder(node):
with open(os.path.join(subgraph_path, "model.py"), "w") as fp:
fp.write(write_code)

# 5.1 Generate graph_hash.txt from model.py

graph_hash = hashlib.sha256(write_code.encode("utf-8")).hexdigest()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

调用hash_util.py中的函数。


graph_hash = get_sha256_hash(write_code)

(Path(subgraph_path) / "graph_hash.txt").write_text(graph_hash)

# 6. Save metadata LAST — graph_net.json serves as the
# completion marker: if it exists, all other files are guaranteed
# to be fully written.
Expand All @@ -180,8 +191,6 @@ def try_rename_placeholder(node):
)

return gm.forward


def extract(
name,
dynamic=True,
Expand Down Expand Up @@ -219,8 +228,6 @@ def extract(

class GraphModule(torch.nn.Module):



def forward(self, s0 : torch.SymInt, L_x_ : torch.Tensor):
l_x_ = L_x_
mul = l_x_ * 2; l_x_ = None
Expand All @@ -246,8 +253,6 @@ def forward(self, s0 : torch.SymInt, L_x_ : torch.Tensor):

class GraphModule(torch.nn.Module):



def forward(self, s0 : torch.SymInt, L_x_ : torch.Tensor):
l_x_ = L_x_
mul = l_x_ * 2; l_x_ = None
Expand Down Expand Up @@ -304,13 +309,9 @@ def decorator_or_wrapper(obj):
)

return decorator_or_wrapper


def make_extractor_config(extractor_config):
    """Build an extractor config from an optional mapping of keyword arguments.

    ``None`` means "no arguments"; any other value is unpacked as keyword
    arguments into ``make_extractor_config_impl``, whose result is returned.
    """
    if extractor_config is None:
        return make_extractor_config_impl()
    return make_extractor_config_impl(**extractor_config)


def make_extractor_config_impl(
custom_extractor_path: str = None, custom_extractor_config: dict = None
):
Expand Down
6 changes: 5 additions & 1 deletion graph_net/torch/sample_pass/backward_graph_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,18 @@ def __call__(self):
module, forward_inputs = get_torch_module_and_inputs(
self.model_path, use_dummy_inputs=False, device=self.device
)
module.train()
module.eval()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

eval模式下不会生成反向图吧?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

model.eval() 不会禁用梯度计算,只有 torch.no_grad() / torch.inference_mode() 才会。eval 仅改变特定层的前向行为(dropout → identity,BatchNorm → 用 running stats 而非 batch stats),反向传播完全正常。而且使用 eval 模式反而更好

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不行,反向图生成时,这些算子就应该用train模式

Comment thread
Dayuxiaoshui marked this conversation as resolved.

eval_forward_dir = os.path.join(
self.output_dir, "eval_forward", self.rel_model_path
)
if not os.path.exists(eval_forward_dir):
shutil.copytree(self.model_path, eval_forward_dir)

forward_inputs = [
inp.detach().clone() if isinstance(inp, torch.Tensor) else inp
for inp in forward_inputs
]
forward_inputs = self.set_requires_grad_for_forward_inputs(
self.model_path, module, forward_inputs
)
Expand Down
14 changes: 11 additions & 3 deletions graph_net_bench/torch/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,9 @@ def eager_model_call():
)

torch.manual_seed(runtime_seed)
if not isinstance(expected_out, tuple):
if isinstance(expected_out, list):
expected_out = tuple(expected_out)
elif not isinstance(expected_out, tuple):
expected_out = (expected_out,)
except (TypeError, RuntimeError) as e:
print(f"Eager model execution failed: {str(e)}", file=sys.stderr)
Expand All @@ -288,7 +290,9 @@ def compiled_model_call():
compiled_model_call, args, compiler
)

if not isinstance(compiled_out, tuple):
if isinstance(compiled_out, list):
compiled_out = tuple(compiled_out)
elif not isinstance(compiled_out, tuple):
compiled_out = (compiled_out,)
if args.compiler == "xla":
compiled_out = tuple(item.to("cpu").to("cuda") for item in compiled_out)
Expand Down Expand Up @@ -352,7 +356,11 @@ def _get_output_dtypes(outs):
]

def _align_output_device(outs, device):
return [x.to(device) if x.device != device else x for x in outs]
if isinstance(outs, torch.Tensor):
return outs.to(device) if outs.device != device else outs
if isinstance(outs, (list, tuple)):
return type(outs)(_align_output_device(x, device) for x in outs)
return outs

eager_dtypes = _get_output_dtypes(expected_out)
compiled_dtypes = _get_output_dtypes(compiled_out)
Expand Down
71 changes: 70 additions & 1 deletion tools/triton_kernel_extractor/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,74 @@ def _run_analyze(args: argparse.Namespace) -> None:
analyze_cache(cache_dir, output_dir)


# ---------------------------------------------------------------------------
# Subcommand: dedup
# ---------------------------------------------------------------------------


def _add_dedup_parser(subparsers: argparse._SubParsersAction) -> None:
    """Register the ``dedup`` subcommand on ``subparsers``.

    The subcommand takes two required options, ``--input-dir`` and
    ``--output`` (both parsed as ``Path``), and dispatches to ``_run_dedup``
    through the ``func`` default.
    """
    dedup_description = (
        "Walk extracted triton_kernel/*.py files across samples and "
        "compute dedup statistics. Two kernels from different subgraphs "
        "are considered duplicates when their normalized source code is "
        "identical. This is kernel-level dedup (by compiled Triton "
        "kernel content), distinct from graph-level dedup via graph_hash.txt."
    )
    input_dir_help = (
        "Root directory containing per-sample subdirectories with "
        "triton_kernel/*.py files (output of the extract pipeline step)."
    )

    dedup_parser = subparsers.add_parser(
        "dedup",
        help="Deduplicate extracted Triton kernels by source content.",
        description=dedup_description,
    )
    dedup_parser.add_argument(
        "--input-dir", type=Path, required=True, help=input_dir_help
    )
    dedup_parser.add_argument(
        "--output",
        type=Path,
        required=True,
        help="Output JSON path for the dedup report.",
    )
    dedup_parser.set_defaults(func=_run_dedup)


def _run_dedup(args: argparse.Namespace) -> None:
    """Run the ``dedup`` subcommand: scan, log a summary, write a JSON report.

    Reads ``args.input_dir`` and ``args.output``. When ``dedup_kernels``
    returns a falsy report, a warning is logged and nothing is written.
    """
    import json

    from .kernel_dedup import dedup_kernels

    logger.info("Scanning: %s", args.input_dir)
    report = dedup_kernels(args.input_dir)
    if not report:
        logger.warning("No data to report.")
        return

    # Summary: one log line per headline statistic, in a fixed order.
    summary_lines = (
        ("Total samples scanned: %d", "total_samples"),
        ("Total kernel instances: %d", "total_kernel_instances"),
        ("Unique kernel hashes: %d", "unique_kernel_hashes"),
        ("Dedup rate: %.1f%%", "dedup_rate_percent"),
        ("Avg kernels per sample: %.2f", "avg_kernels_per_sample"),
    )
    logger.info("")
    logger.info("=== Kernel Dedup Report ===")
    for line_format, report_key in summary_lines:
        logger.info(line_format, report[report_key])

    # Frequency table: only the first 20 entries, in the report's own order.
    logger.info("")
    logger.info("Top kernel types by frequency:")
    leading_entries = list(report["kernel_name_freq"].items())[:20]
    for kernel_name, occurrences in leading_entries:
        # Cap the bar width so very frequent kernels do not flood the log.
        histogram_bar = "#" * min(occurrences, 60)
        logger.info(" %-50s %4d %s", kernel_name, occurrences, histogram_bar)

    # Write JSON, creating the destination directory first if needed.
    report_path = args.output
    report_path.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(report, indent=2, ensure_ascii=False)
    report_path.write_text(serialized, encoding="utf-8")
    logger.info("")
    logger.info("Report saved to: %s", report_path)


# ---------------------------------------------------------------------------
# Backward-compatible argument detection
# ---------------------------------------------------------------------------
Expand All @@ -225,7 +293,7 @@ def _needs_implicit_extract(argv: list[str]) -> bool:
"""
if not argv:
return False
known_subcommands = {"extract", "analyze"}
known_subcommands = {"extract", "analyze", "dedup"}
first = argv[0]
if first in known_subcommands:
return False
Expand Down Expand Up @@ -263,6 +331,7 @@ def main(argv: list[str] | None = None) -> None:
subparsers = parser.add_subparsers(dest="command")
_add_extract_parser(subparsers)
_add_analyze_parser(subparsers)
_add_dedup_parser(subparsers)

args = parser.parse_args(argv)

Expand Down
Loading
Loading