KernelBench Format Support #63

@sandlbn

Problem

The Fuser outputs kernel_function(...) with arguments that don't have a 1:1 mapping to the original Model's parameters. Static/heuristic mapping fails for:

  • Fused operations (Conv+BN → fused_weight, fused_bias)
  • Reordered arguments
  • Computed intermediates
  • Non-obvious naming conventions
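To make the first bullet concrete: folding a BatchNorm layer into the preceding convolution replaces the convolution's weight and bias with values computed from several tensors, so no single Model attribute corresponds to a fused_weight or fused_bias kernel argument. The sketch below shows the standard inference-time folding arithmetic per output channel using plain Python floats. The function name and the list-based representation are illustrative assumptions, not part of the Fuser.

```python
import math

def fold_bn_per_channel(conv_bias, gamma, beta, running_mean, running_var, eps=1e-5):
    """Return (scales, fused_bias) for folding BatchNorm into a preceding conv.

    Each argument is a list with one float per output channel. The fused
    weight for channel c is the original weight slice multiplied by scales[c].
    Hypothetical helper for illustration only.
    """
    scales, fused_bias = [], []
    for b, g, bt, mu, var in zip(conv_bias, gamma, beta, running_mean, running_var):
        s = g / math.sqrt(var + eps)          # BN scale: gamma / sqrt(var + eps)
        scales.append(s)
        fused_bias.append((b - mu) * s + bt)  # push the conv bias through BN
    return scales, fused_bias
```

A static mapper would have to recognize this pattern and know which five tensors feed it, which is the kind of case the mapping step below is meant to cover.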

Solution: LLM Mapping Step

Add a new pipeline step that uses an LLM to generate the argument mapping by inspecting:

  1. Original KernelBench problem (Model class, get_inputs, get_init_inputs)
  2. Generated composed kernel (kernel_function signature)
  3. Subgraphs metadata (operation fusion info)

Pipeline

extract → dispatch → compose → map_kernelbench (LLM) → format_kernelbench

LLM Mapping Step

Input to LLM:

  • Original Model code
  • kernel_function signature + implementation
  • Subgraph fusion information

Output from LLM:

{
  "args": [
    {"kernel_arg": "x", "source": "input", "expr": "inputs[0]"},
    {"kernel_arg": "weight", "source": "param", "expr": "model.conv.weight"},
    {"kernel_arg": "fused_bias", "source": "computed", "expr": "fuse_conv_bn_bias(model)"}
  ],
  "helpers_needed": ["fuse_conv_bn_bias"],
  "notes": "BN folded into conv, need to fuse weights at load time"
}
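Since the mapping comes from an LLM, the mapper probably needs to validate it before the format step consumes it. A minimal sketch of such a check, assuming the reply may arrive wrapped in a Markdown code fence and that only the three per-argument keys shown above are required; `parse_mapping` and `ALLOWED_SOURCES` are hypothetical names:

```python
import json
import re

ALLOWED_SOURCES = {"input", "param", "buffer", "computed"}
FENCE = "`" * 3  # Markdown code fence marker

def parse_mapping(raw: str) -> dict:
    """Parse an LLM reply into a mapping dict; raise ValueError on a bad shape."""
    # Tolerate a reply wrapped in a Markdown code fence.
    match = re.search(FENCE + r"(?:json)?\s*(.*?)" + FENCE, raw, re.DOTALL)
    text = match.group(1) if match else raw
    try:
        mapping = json.loads(text)
    except json.JSONDecodeError as exc:
        raise ValueError(f"mapping is not valid JSON: {exc}") from exc
    args = mapping.get("args") if isinstance(mapping, dict) else None
    if not isinstance(args, list) or not args:
        raise ValueError("mapping must contain a non-empty 'args' list")
    for i, arg in enumerate(args):
        if not isinstance(arg, dict):
            raise ValueError(f"args[{i}] must be an object")
        missing = {"kernel_arg", "source", "expr"} - set(arg)
        if missing:
            raise ValueError(f"args[{i}] is missing keys: {sorted(missing)}")
        if arg["source"] not in ALLOWED_SOURCES:
            raise ValueError(f"args[{i}] has unknown source {arg['source']!r}")
    return mapping
```

Raising a descriptive ValueError keeps the failure message usable as feedback if the mapping step is re-run.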

Format Step

Takes the LLM-generated mapping and produces final kernelbench_model.py:

  • Generates ModelNew.__init__ with correct parameter registration
  • Generates forward() that calls kernel_function with mapped args
  • Includes any helper functions (weight fusion, etc.)
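One way the formatter could turn a mapping into a forward() method is sketched below. It assumes the mapping carries the `forward_inputs` list from the prompt template's schema and that each `expr` is already valid inside a method body (for example `x` or `self.conv.weight`). `render_forward` is a hypothetical name, and arguments are passed by keyword so that reordered kernel arguments do not matter.

```python
def render_forward(mapping: dict) -> str:
    """Render forward() source that calls kernel_function with mapped args."""
    # Fall back to a single input named "x" if forward_inputs is absent (assumption).
    params = ", ".join(mapping.get("forward_inputs", ["x"]))
    # Keyword arguments make the call independent of kernel argument order.
    call_args = ", ".join(
        f"{arg['kernel_arg']}={arg['expr']}" for arg in mapping["args"]
    )
    return (
        f"def forward(self, {params}):\n"
        f"    return kernel_function({call_args})\n"
    )
```

For a mapping with `x` and `self.conv.weight`, this yields a two-line method whose body is `return kernel_function(x=x, weight=self.conv.weight)`.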

Why This Should Work

  1. The LLM understands the transformation - it can read both sides and reason about the correspondence, which would be extremely hard for a static method
  2. Handles edge cases - fused weights, reordered args, computed values
  3. Produces explicit mapping - debuggable, auditable
  4. Reuses existing infrastructure - same LLM dispatch pattern as kernel generation

Verification

After generation, run:

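import torch
torch.manual_seed(0)  # seed before each init; assumes ModelNew initializes parameters in the same order as Model
original_out = Model(*init_inputs)(*inputs)
torch.manual_seed(0)
new_out = ModelNew(*init_inputs)(*inputs)
assert torch.allclose(original_out, new_out, rtol=1e-3)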

If verification fails, the mapping step can be re-run with the error message as feedback.
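The retry idea could look roughly like the loop below. It is a sketch under assumptions: `generate_mapping` stands in for the LLM call and `verify` for the check above, both injected as callables, and the attempt limit of 3 is arbitrary.

```python
from typing import Callable, Optional

def map_with_retries(
    generate_mapping: Callable[[Optional[str]], dict],
    verify: Callable[[dict], Optional[str]],
    max_attempts: int = 3,
) -> dict:
    """Call generate_mapping until verify returns None or attempts run out.

    generate_mapping receives the previous error message (None on the first
    attempt) so it can include that feedback in the next prompt.
    """
    error: Optional[str] = None
    for _ in range(max_attempts):
        mapping = generate_mapping(error)
        error = verify(mapping)  # None means the outputs matched
        if error is None:
            return mapping
    raise RuntimeError(f"mapping failed after {max_attempts} attempts: {error}")
```

Bounding the number of attempts keeps a persistently wrong mapping from looping on LLM calls indefinitely.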

Implementation Scope

  1. kernelbench_mapping.j2 - prompt template for LLM mapping
  2. Fuser/kernelbench_mapper.py - calls LLM, parses mapping JSON
  3. Fuser/kernelbench_formatter.py - applies mapping to generate final code
  4. Pipeline integration via --output-format kernelbench flag
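For item 4, the flag could be wired in with argparse along these lines. This is only a sketch: the existing Fuser.pipeline argument parser is not shown in this issue, so the parser below is standalone, and the `"default"` format name is a placeholder assumption.

```python
import argparse

def build_parser() -> argparse.ArgumentParser:
    """Build a standalone parser illustrating the proposed flags."""
    parser = argparse.ArgumentParser(prog="Fuser.pipeline")
    parser.add_argument("--problem", required=True,
                        help="path to the KernelBench problem file")
    parser.add_argument(
        "--output-format",
        choices=["default", "kernelbench"],  # "default" is a placeholder name
        default="default",
        help="emit a KernelBench-compatible ModelNew when set to 'kernelbench'",
    )
    parser.add_argument("--verify", action="store_true",
                        help="compare ModelNew output against the original Model")
    return parser
```

With this shape, the map_kernelbench and format_kernelbench steps would run only when `args.output_format == "kernelbench"`.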

Prompt Template

You are analyzing a Triton kernel generated from a PyTorch model to create an argument mapping.

## Original PyTorch Model

```python
{{ original_model_code }}
```

## Generated Triton Kernel

{{ kernel_function_code }}

## Subgraph Info

{{ subgraph_info }}

## Task

Analyze how kernel_function arguments map to the original Model's parameters and inputs.

Output a JSON mapping:

{
  "args": [
    {
      "kernel_arg": "<argument name in kernel_function>",
      "source": "input|param|buffer|computed",
      "expr": "<Python expression to get value from Model instance or inputs>"
    }
  ],
  "weight_fusion": {
    "needed": true|false,
    "description": "<what fusion is needed, e.g., Conv+BN weight folding>"
  },
  "forward_inputs": ["<list of forward() parameter names>"]
}

Rules:

  • source: "input" → comes from forward() arguments, expr like x or inputs[0]
  • source: "param" → comes from model parameter, expr like self.conv.weight
  • source: "buffer" → comes from model buffer, expr like self.bn.running_mean
  • source: "computed" → requires computation, expr like self.fused_weight (will be precomputed)

For fused operations (e.g., Conv+BN), identify which weights need to be fused at load time.

Output only valid JSON, no explanation.

Example Usage

```bash
# Full pipeline with KernelBench output
python -m Fuser.pipeline \
    --problem /path/to/kernelbench/level2/problem.py \
    --output-format kernelbench \
    --verify

# Standalone mapping step
python -m Fuser.kernelbench_mapper \
    --problem /path/to/problem.py \
    --composed-kernel /path/to/composed.py \
    --subgraphs /path/to/subgraphs.json \
    --output /path/to/mapping.json
```
