From bb404cb6fec6cd11378e6e04d3aa21f4380b5065 Mon Sep 17 00:00:00 2001 From: xjtu-L <2701938983@qq.com> Date: Mon, 18 May 2026 03:18:57 +0000 Subject: [PATCH] fix: separate model weights from inputs in extraction output Previously, `extractor.py` passed all placeholder params (including lifted parameters, buffers, SymInt symbols, and real inputs) to `convert_state_and_inputs(params, [])`, causing all of them to be written to `weight_meta.py`. This led to SymInt placeholders like `Program_weight_tensor_meta_s0` appearing in `weight_meta.py`, which are not real model weights. This change: - Collects `id()` of original model parameters and buffers in `wrapper()` - Passes them to `GraphExtractor` via `param_buffer_ids` - Splits `params` into `weights` (params/buffers by identity) and `example_inputs` (real inputs including SymInt) - Writes weights to `weight_meta.py` and inputs to `input_meta.py` - Updates `utils.py` to handle dict-style `example_inputs` --- graph_net/torch/extractor.py | 34 ++++++++++++++++++++++++++++++---- graph_net/torch/utils.py | 27 ++++++++++++++++++++++----- 2 files changed, 52 insertions(+), 9 deletions(-) diff --git a/graph_net/torch/extractor.py b/graph_net/torch/extractor.py index 568ad995ad..fe87ba3325 100644 --- a/graph_net/torch/extractor.py +++ b/graph_net/torch/extractor.py @@ -49,6 +49,7 @@ def __init__( mut_graph_codes=None, placeholder_auto_rename=False, workspace_path=None, + param_buffer_ids=None, ): self.subgraph_counter = 0 self.name = name @@ -64,6 +65,7 @@ def __init__( raise EnvironmentError( "Environment variable 'GRAPH_NET_EXTRACT_WORKSPACE' is not set." ) + self.param_buffer_ids = param_buffer_ids or set() def move_files(self, source_dir, target_dir): os.makedirs(target_dir, exist_ok=True) @@ -150,7 +152,16 @@ def try_rename_placeholder(node): self.mut_graph_codes.append(base_code) # 4. Save tensor metadata - converted = utils.convert_state_and_inputs(params, []) + # Separate model weights (parameters + buffers) from real inputs in params + weights = {} + example_inputs = {} + for name, value in params.items(): + if id(value) in self.param_buffer_ids: + weights[name] = value + else: + example_inputs[name] = value + + converted = utils.convert_state_and_inputs(weights, example_inputs) utils.save_converted_to_text(converted, file_path=subgraph_path) utils.save_constraints_text( converted, @@ -280,9 +291,24 @@ def wrapper(model: torch.nn.Module): model_path = None if hasattr(model, "__graph_net_file_path__"): model_path = os.path.dirname(model.__graph_net_file_path__) - extractor = get_graph_extractor_maker(model_path)( - name, dynamic, mut_graph_codes, placeholder_auto_rename - ) + + # Collect parameter and buffer ids from the original model for distinguishing weights and inputs in __call__ + param_buffer_ids = set() + for _, p in model.named_parameters(): + param_buffer_ids.add(id(p)) + for _, b in model.named_buffers(): + param_buffer_ids.add(id(b)) + + maker = get_graph_extractor_maker(model_path) + if maker is GraphExtractor: + extractor = maker( + name, dynamic, mut_graph_codes, placeholder_auto_rename, + param_buffer_ids=param_buffer_ids, + ) + else: + extractor = maker( + name, dynamic, mut_graph_codes, placeholder_auto_rename + ) # return torch.compile(backend=extractor, dynamic=dynamic) compiled_model = torch.compile(model, backend=extractor, dynamic=dynamic) return compiled_model diff --git a/graph_net/torch/utils.py b/graph_net/torch/utils.py index 658ade37fe..efc41aeac6 100644 --- a/graph_net/torch/utils.py +++ b/graph_net/torch/utils.py @@ -75,6 +75,8 @@ def process_tensor(tensor): processed_inputs = process_tensor(example_inputs) elif isinstance(example_inputs, (list, tuple)): processed_inputs = [process_tensor(t) for t in example_inputs] + elif isinstance(example_inputs, dict): + processed_inputs = {k: process_tensor(v) for k, v in example_inputs.items()} else: processed_inputs = {"type": "unknown", "value": example_inputs} @@ -181,13 +183,28 @@ def process_tensor_info(tensor_info, name_prefix="example_input"): return lines input_infos = converted["input_info"] - if isinstance(input_infos, dict): - input_infos = [input_infos] input_lines = [] - for idx, input_info in enumerate(input_infos): - input_info["name"] = f"input_{idx}" - input_lines.extend(process_tensor_info(input_info, name_prefix="Program_input")) + if isinstance(input_infos, dict) and input_infos: + # Check if it's a dict of named tensor infos (e.g., placeholder inputs) + first_val = next(iter(input_infos.values())) + if isinstance(first_val, dict) and "type" in first_val: + # Named inputs: {name: tensor_info} + for name, input_info in input_infos.items(): + input_info["name"] = name + input_lines.extend(process_tensor_info(input_info, name_prefix="Program_input")) + else: + # Single input info dict (e.g., a single tensor's info) + input_infos = [input_infos] + for idx, input_info in enumerate(input_infos): + input_info["name"] = f"input_{idx}" + input_lines.extend(process_tensor_info(input_info, name_prefix="Program_input")) + else: + if isinstance(input_infos, dict): + input_infos = [input_infos] + for idx, input_info in enumerate(input_infos): + input_info["name"] = f"input_{idx}" + input_lines.extend(process_tensor_info(input_info, name_prefix="Program_input")) with open(f"{file_path}/input_meta.py", "w") as f: f.write("\n".join(input_lines))