diff --git a/graph_net/torch/extractor.py b/graph_net/torch/extractor.py index 568ad995a..fe87ba332 100644 --- a/graph_net/torch/extractor.py +++ b/graph_net/torch/extractor.py @@ -49,6 +49,7 @@ def __init__( mut_graph_codes=None, placeholder_auto_rename=False, workspace_path=None, + param_buffer_ids=None, ): self.subgraph_counter = 0 self.name = name @@ -64,6 +65,7 @@ def __init__( raise EnvironmentError( "Environment variable 'GRAPH_NET_EXTRACT_WORKSPACE' is not set." ) + self.param_buffer_ids = param_buffer_ids or set() def move_files(self, source_dir, target_dir): os.makedirs(target_dir, exist_ok=True) @@ -150,7 +152,16 @@ def try_rename_placeholder(node): self.mut_graph_codes.append(base_code) # 4. Save tensor metadata - converted = utils.convert_state_and_inputs(params, []) + # Separate model weights (parameters + buffers) from real inputs in params + weights = {} + example_inputs = {} + for name, value in params.items(): + if id(value) in self.param_buffer_ids: + weights[name] = value + else: + example_inputs[name] = value + + converted = utils.convert_state_and_inputs(weights, example_inputs) utils.save_converted_to_text(converted, file_path=subgraph_path) utils.save_constraints_text( converted, @@ -280,9 +291,24 @@ def wrapper(model: torch.nn.Module): model_path = None if hasattr(model, "__graph_net_file_path__"): model_path = os.path.dirname(model.__graph_net_file_path__) - extractor = get_graph_extractor_maker(model_path)( - name, dynamic, mut_graph_codes, placeholder_auto_rename - ) + + # Collect parameter and buffer ids from the original model for distinguishing weights and inputs in __call__ + param_buffer_ids = set() + for _, p in model.named_parameters(): + param_buffer_ids.add(id(p)) + for _, b in model.named_buffers(): + param_buffer_ids.add(id(b)) + + maker = get_graph_extractor_maker(model_path) + if maker is GraphExtractor: + extractor = maker( + name, dynamic, mut_graph_codes, placeholder_auto_rename, + param_buffer_ids=param_buffer_ids, + ) + else: + extractor = maker( + name, dynamic, mut_graph_codes, placeholder_auto_rename + ) # return torch.compile(backend=extractor, dynamic=dynamic) compiled_model = torch.compile(model, backend=extractor, dynamic=dynamic) return compiled_model diff --git a/graph_net/torch/utils.py b/graph_net/torch/utils.py index 658ade37f..efc41aeac 100644 --- a/graph_net/torch/utils.py +++ b/graph_net/torch/utils.py @@ -75,6 +75,8 @@ def process_tensor(tensor): processed_inputs = process_tensor(example_inputs) elif isinstance(example_inputs, (list, tuple)): processed_inputs = [process_tensor(t) for t in example_inputs] + elif isinstance(example_inputs, dict): + processed_inputs = {k: process_tensor(v) for k, v in example_inputs.items()} else: processed_inputs = {"type": "unknown", "value": example_inputs} @@ -181,13 +183,28 @@ def process_tensor_info(tensor_info, name_prefix="example_input"): return lines input_infos = converted["input_info"] - if isinstance(input_infos, dict): - input_infos = [input_infos] input_lines = [] - for idx, input_info in enumerate(input_infos): - input_info["name"] = f"input_{idx}" - input_lines.extend(process_tensor_info(input_info, name_prefix="Program_input")) + if isinstance(input_infos, dict) and input_infos: + # Check if it's a dict of named tensor infos (e.g., placeholder inputs) + first_val = next(iter(input_infos.values())) + if isinstance(first_val, dict) and "type" in first_val: + # Named inputs: {name: tensor_info} + for name, input_info in input_infos.items(): + input_info["name"] = name + input_lines.extend(process_tensor_info(input_info, name_prefix="Program_input")) + else: + # Single input info dict (e.g., a single tensor's info) + input_infos = [input_infos] + for idx, input_info in enumerate(input_infos): + input_info["name"] = f"input_{idx}" + input_lines.extend(process_tensor_info(input_info, name_prefix="Program_input")) + else: + if isinstance(input_infos, dict): + input_infos = [input_infos] + for idx, input_info in enumerate(input_infos): + input_info["name"] = f"input_{idx}" + input_lines.extend(process_tensor_info(input_info, name_prefix="Program_input")) with open(f"{file_path}/input_meta.py", "w") as f: f.write("\n".join(input_lines))