diff --git a/tools/triton_kernel_extractor/cache_pruner.py b/tools/triton_kernel_extractor/cache_pruner.py new file mode 100644 index 000000000..b9104b4bb --- /dev/null +++ b/tools/triton_kernel_extractor/cache_pruner.py @@ -0,0 +1,165 @@ +"""Prune large TorchInductor cache artifacts while preserving later workflows.""" + +from __future__ import annotations + +import json +import logging +import re +from pathlib import Path + +from .config import is_sample_dir + +logger = logging.getLogger(__name__) +_MAKE_RANGE_END_PATTERN = re.compile(r"tt\.make_range\s*\{end = (\d+)") + +# Files needed to keep later extraction and analysis working. +_KEEP_FILENAMES = { + "test_compiler_log.log", + "graph_hash.txt", + "model.py", + "output_code.py", +} + + +def _should_keep_file(path: Path) -> bool: + """Return True if *path* is required by later extract/analyze steps.""" + name = path.name + if name in _KEEP_FILENAMES: + return True + if name.endswith(".best_config"): + return True + if name.endswith(".ptx"): + return True + if name.endswith(".json"): + return True + return False + + +def _write_pruned_triton_metadata(sample_dir: Path) -> None: + """Persist compact metadata needed after deleting bulky Triton sources.""" + for source_path in sample_dir.rglob("*.source"): + ptx_path = source_path.with_suffix(".ptx") + if not ptx_path.is_file(): + continue + try: + content = source_path.read_text(encoding="utf-8", errors="replace") + except OSError: + continue + block_values = [int(m) for m in _MAKE_RANGE_END_PATTERN.findall(content)] + if not block_values: + continue + meta_path = source_path.with_suffix(".pruned_meta.json") + try: + meta_path.write_text( + json.dumps({"block_values": block_values}, separators=(",", ":")), + encoding="utf-8", + ) + except OSError: + logger.debug("Cannot write pruned Triton metadata: %s", meta_path) + + +def _looks_like_sample_cache_dir(directory: Path) -> bool: + """Heuristic for directories that contain a compiled sample cache.""" + if not directory.is_dir(): + return False + if (directory / "test_compiler_log.log").is_file(): + return True + if (directory / "original_graph").is_dir(): + return True + if (directory / "triton").is_dir(): + return True + return False + + +def _prune_one_sample(sample_dir: Path) -> tuple[int, int]: + """Delete unneeded files from a single sample cache directory.""" + removed_files = 0 + removed_dirs = 0 + + _write_pruned_triton_metadata(sample_dir) + + for path in sorted(sample_dir.rglob("*"), reverse=True): + if path.is_file(): + if not _should_keep_file(path): + try: + path.unlink() + removed_files += 1 + except FileNotFoundError: + pass + + # Remove now-empty directories bottom-up, but keep the sample root itself. + for path in sorted(sample_dir.rglob("*"), reverse=True): + if path.is_dir(): + try: + path.rmdir() + removed_dirs += 1 + except OSError: + pass + + return removed_files, removed_dirs + + +def prune_compilation_cache(cache_dir: Path) -> tuple[int, int]: + """Prune large intermediate artifacts from compiled sample caches. + + This keeps only the files required for later extraction and analysis: + ``test_compiler_log.log``, ``original_graph/{model.py,graph_hash.txt}``, + ``output_code.py``, ``*.best_config``, ``*.ptx``, and lightweight Triton + metadata files (``*.json`` / ``*.pruned_meta.json``) used to disambiguate + autotuning candidates when ``triton_cache_hash`` is unavailable. + """ + if not cache_dir.is_dir(): + logger.warning("Cache directory does not exist, skipping: %s", cache_dir) + return 0, 0 + + total_files = 0 + total_dirs = 0 + + # Support pruning a single sample cache directory directly. + if _looks_like_sample_cache_dir(cache_dir): + removed_files, removed_dirs = _prune_one_sample(cache_dir) + logger.info( + "Pruned %s: -%d files, -%d dirs", + cache_dir, + removed_files, + removed_dirs, + ) + logger.info( + "Prune complete: removed %d files and %d directories under %s", + removed_files, + removed_dirs, + cache_dir, + ) + return removed_files, removed_dirs + + # Prune sample caches in the root, kept/, and discarded/ trees. + roots = [cache_dir, cache_dir / "kept", cache_dir / "discarded"] + for root in roots: + if not root.is_dir(): + continue + for child in sorted(root.iterdir()): + if not child.is_dir(): + continue + if root == cache_dir and not _looks_like_sample_cache_dir(child): + continue + if root != cache_dir and not is_sample_dir(child.name): + continue + + removed_files, removed_dirs = _prune_one_sample(child) + if removed_files or removed_dirs: + logger.info( + "Pruned %s: -%d files, -%d dirs", + child.relative_to(cache_dir), + removed_files, + removed_dirs, + ) + total_files += removed_files + total_dirs += removed_dirs + + logger.info( + "Prune complete: removed %d files and %d directories under %s", + total_files, + total_dirs, + cache_dir, + ) + return total_files, total_dirs diff --git a/tools/triton_kernel_extractor/compiler.py b/tools/triton_kernel_extractor/compiler.py index a97f44654..bda4dd981 100644 --- a/tools/triton_kernel_extractor/compiler.py +++ b/tools/triton_kernel_extractor/compiler.py @@ -12,6 +12,7 @@ from concurrent.futures import Future, ProcessPoolExecutor, as_completed from pathlib import Path +from .cache_pruner import prune_compilation_cache from .config import PipelineConfig from .sample_enumerator import compute_unique_dir @@ -63,6 +64,10 @@ def _compile_one_sample( """Compile a single graph sample on a specific GPU. Returns one of ``"compiled"``, ``"skipped"``, or ``"failed"``. + + After each compile attempt, prune the sample-local TorchInductor cache + immediately so that large intermediate artifacts do not accumulate across + many samples. """ unique_dir = compute_unique_dir(sample_path, graph_dir) @@ -124,6 +129,8 @@ def _compile_one_sample( else: shutil.copy2(str(item), str(dest)) + prune_compilation_cache(sample_cache_dir) + if result.returncode != 0: return "failed" return "compiled" diff --git a/tools/triton_kernel_extractor/kernel_extractor.py b/tools/triton_kernel_extractor/kernel_extractor.py index 5e7a34691..4c04dd0c1 100644 --- a/tools/triton_kernel_extractor/kernel_extractor.py +++ b/tools/triton_kernel_extractor/kernel_extractor.py @@ -7,13 +7,16 @@ import re import shutil from pathlib import Path +from typing import Any logger = logging.getLogger(__name__) # Compiled regex that replaces the original perl one-liner: # # perl -0777 -ne ' -# while (/async_compile\.triton\(\x27([^\x27]+)\x27,\s*\x27\x27\x27(.*?)\x27\x27\x27/gs) { +# while ( +# /async_compile\.triton\(\x27([^\x27]+)\x27,\s*\x27\x27\x27(.*?)\x27\x27\x27/gs +# ) { # print "===KERNEL_NAME===$1\n$2\n===KERNEL_END===\n"; # }' # @@ -22,35 +25,96 @@ r"async_compile\.triton\('([^']+)',\s*'''(.*?)'''", re.DOTALL, ) +_KERNEL_PATH_PATTERN = re.compile(r"^# kernel path:\s*(.+)$", re.MULTILINE) +_MAKE_RANGE_END_PATTERN = re.compile(r"tt\.make_range\s*\{end = (\d+)") +_BLOCK_KEY_ORDER = ("XBLOCK", "YBLOCK", "ZBLOCK", "RBLOCK") -def _collect_best_config_hashes(sample_cache_dir: Path) -> set[str]: - """Gather all autotuning-selected cache hashes from a sample directory. +def _collect_best_configs( + sample_cache_dir: Path, +) -> tuple[set[str], dict[str, dict[str, Any]]]: + """Gather autotuning-selected configs from a sample directory. TorchInductor writes ``.best_config`` JSON files (one per autotuned kernel) - in 2-char prefix subdirectories of the sample cache. Each file contains a - ``triton_cache_hash`` field identifying the winning configuration among - multiple compiled candidates in ``triton/0/``. + in 2-char prefix subdirectories of the sample cache. Older versions expose + a ``triton_cache_hash`` field identifying the winning candidate under + ``triton/0/``. Newer TorchInductor/Triton combinations may leave that field + as ``null``; in that case the prefix directory from ``# kernel path:`` in + ``output_code.py`` is used to find the matching best-config and compare the + recorded tuning parameters against candidate metadata. This function is called once per sample and the result is reused for every kernel in that sample. """ hashes: set[str] = set() + by_prefix: dict[str, dict[str, Any]] = {} for bc_path in sample_cache_dir.rglob("*.best_config"): try: data = json.loads(bc_path.read_text(encoding="utf-8")) cache_hash = data.get("triton_cache_hash") if cache_hash: hashes.add(cache_hash) + by_prefix[bc_path.parent.name] = data except (OSError, json.JSONDecodeError): logger.debug("Skipping malformed .best_config: %s", bc_path) - return hashes + return hashes, by_prefix + + +def _extract_block_values_from_source(source_path: Path) -> list[int]: + """Extract Triton block dimensions from compact metadata or ``*.source``.""" + meta_path = source_path.with_suffix(".pruned_meta.json") + if meta_path.is_file(): + try: + data = json.loads(meta_path.read_text(encoding="utf-8")) + block_values = data.get("block_values") + if isinstance(block_values, list): + return [int(v) for v in block_values] + except (OSError, json.JSONDecodeError, TypeError, ValueError): + return [] + + try: + content = source_path.read_text(encoding="utf-8", errors="replace") + except OSError: + return [] + return [int(m) for m in _MAKE_RANGE_END_PATTERN.findall(content)] + + +def _candidate_matches_best_config(ptx_path: Path, best_config: dict[str, Any]) -> bool: + """Return True when a PTX candidate matches a ``.best_config`` record.""" + json_path = ptx_path.with_suffix(".json") + if json_path.is_file(): + try: + metadata = json.loads(json_path.read_text(encoding="utf-8")) + except (OSError, json.JSONDecodeError): + metadata = {} + else: + metadata = {} + + for key in ("num_warps", "num_stages"): + if key in best_config and key in metadata and metadata[key] != best_config[key]: + return False + + block_values = _extract_block_values_from_source(ptx_path.with_suffix(".source")) + best_block_keys = [key for key in _BLOCK_KEY_ORDER if key in best_config] + if best_block_keys: + if len(block_values) < len(best_block_keys): + return False + for idx, key in enumerate(best_block_keys): + if block_values[idx] != best_config[key]: + return False + + # If only some metadata files survived, accept the candidate only when at + # least one meaningful field was available and matched. + return bool( + any(key in metadata for key in ("num_warps", "num_stages")) or block_values + ) def _find_best_ptx( sample_cache_dir: Path, kernel_name: str, best_hashes: set[str], + best_config: dict[str, Any] | None = None, ) -> str | None: """Locate the corresponding PTX for a given kernel via autotuning results. @@ -62,8 +126,9 @@ def _find_best_ptx( Resolution strategy: - 0 candidates → return ``None`` (no PTX compiled for this kernel). - 1 candidate → return its PTX (no disambiguation needed). - - N candidates → intersect directory names with *best_hashes*; the match - identifies the autotuning winner. + - N candidates → first intersect directory names with *best_hashes*. + - If no hash is available, compare candidate ``*.json`` / ``*.source`` + metadata against the kernel's ``.best_config``. """ triton_base = sample_cache_dir / "triton" / "0" if not triton_base.is_dir(): @@ -101,7 +166,27 @@ def _find_best_ptx( logger.warning("Cannot read PTX file: %s", ptx_path) return None - # Fallback: no .best_config match (should not happen based on validation). + if best_config: + metadata_matches = [ + ptx_path + for ptx_path in candidates + if _candidate_matches_best_config(ptx_path, best_config) + ] + if len(metadata_matches) == 1: + try: + return metadata_matches[0].read_text(encoding="utf-8", errors="replace") + except OSError: + logger.warning("Cannot read PTX file: %s", metadata_matches[0]) + return None + if len(metadata_matches) > 1: + logger.warning( + "Multiple metadata-matched PTX candidates for %s in %s", + kernel_name, + sample_cache_dir.name, + ) + return None + + # Fallback: no .best_config match. logger.warning( "Multiple PTX candidates for %s but no .best_config match in %s", kernel_name, @@ -112,8 +197,8 @@ def _find_best_ptx( def extract_kernels_from_file( output_code_path: Path, -) -> list[tuple[str, str]]: - """Parse an ``output_code.py`` and return ``(name, source)`` pairs. +) -> list[tuple[str, str, str | None]]: + """Parse an ``output_code.py`` and return ``(name, source, prefix)`` tuples. The file is read entirely into memory (``output_code.py`` files produced by TorchInductor are typically well under 1 MB). Returns an empty list if the @@ -124,14 +209,24 @@ def extract_kernels_from_file( except OSError: logger.warning("Cannot read output_code.py: %s", output_code_path) return [] - return _TRITON_KERNEL_PATTERN.findall(content) + + kernels: list[tuple[str, str, str | None]] = [] + kernel_path_matches = list(_KERNEL_PATH_PATTERN.finditer(content)) + for match in _TRITON_KERNEL_PATTERN.finditer(content): + prefix: str | None = None + preceding_paths = [m for m in kernel_path_matches if m.start() < match.start()] + if preceding_paths: + kernel_path = preceding_paths[-1].group(1).strip() + prefix = Path(kernel_path).parent.name or None + kernels.append((match.group(1), match.group(2), prefix)) + return kernels def extract_triton_kernels( cache_dir: Path, output_dir: Path, ) -> tuple[int, int, int, int, int]: - """Walk kept samples, extract autotuning-selected triton kernels and corresponding PTX. + """Walk kept samples and extract selected triton kernels with PTX. For every kept sample that contains ``original_graph/graph_hash.txt``: @@ -197,8 +292,8 @@ def extract_triton_kernels( og_dir.mkdir() shutil.copy2(str(model_src), str(og_dir / "model.py")) - # Pre-collect autotuning best-config hashes once per sample. - best_hashes = _collect_best_config_hashes(sample_cache_dir) + # Pre-collect autotuning best-config data once per sample. + best_hashes, best_configs_by_prefix = _collect_best_configs(sample_cache_dir) # Track kernel names already written for this sample to detect # duplicates across multiple output_code.py files. @@ -214,7 +309,7 @@ def extract_triton_kernels( triton_dir = tmp_dir / "triton_kernel" triton_dir.mkdir(exist_ok=True) - for name, source in kernels: + for name, source, kernel_prefix in kernels: if name in seen_kernels: logger.debug( "Duplicate kernel %s in %s, skipping", name, sample_name @@ -229,7 +324,17 @@ def extract_triton_kernels( total_kernels += 1 # Locate and write the corresponding PTX for this kernel. - ptx_content = _find_best_ptx(sample_cache_dir, name, best_hashes) + best_config = ( + best_configs_by_prefix.get(kernel_prefix) + if kernel_prefix is not None + else None + ) + ptx_content = _find_best_ptx( + sample_cache_dir, + name, + best_hashes, + best_config=best_config, + ) if ptx_content is not None: ptx_dir = tmp_dir / "ptx" ptx_dir.mkdir(exist_ok=True)