From d20497e93455110b75c0afa3c5fac3035f963035 Mon Sep 17 00:00:00 2001 From: Kyle Romero Date: Thu, 25 Jun 2026 14:44:08 -0700 Subject: [PATCH] Add QairtModelHandler input support to QairtPipelinePass MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Allow QairtPipelinePass to accept a QairtModelHandler as input in addition to HfModelHandler. This enables chaining two pipeline passes in an Olive recipe — the first pass exports an HF checkpoint via save_pretrained, and the second pass loads from that checkpoint as its model_id_or_path. When input is QairtModelHandler: model.model_path is used directly as model_id_or_path; the recipe must not also specify model_id_or_path (raises ValueError if it does); config.json and generation_config.json are copied from model.model_path rather than via snapshot_download. --- olive/passes/qairt/pipeline.py | 93 ++++++++++++++++--------- test/passes/qairt/test_pipeline_pass.py | 74 +++++++++++++++++++- 2 files changed, 133 insertions(+), 34 deletions(-) diff --git a/olive/passes/qairt/pipeline.py b/olive/passes/qairt/pipeline.py index ca630774e..279e10072 100644 --- a/olive/passes/qairt/pipeline.py +++ b/olive/passes/qairt/pipeline.py @@ -6,6 +6,7 @@ import logging import shutil from pathlib import Path +from typing import Union from olive.common.config_utils import ParamCategory from olive.hardware.accelerator import AcceleratorSpec @@ -18,16 +19,21 @@ class QairtPipelinePass(Pass): - """Run a QairtPipeline from a YAML recipe on a HuggingFace model. + """Run a QairtPipeline from a YAML recipe on a HuggingFace or previously-built QAIRT model. Executes the full LLMPipeline workflow (model loading, quantization, compilation) defined by the recipe and exports the result as a QairtModelHandler. This pass is intended to replace the QairtPreparation -> QairtGenAIBuilder workflow. - The input HfModelHandler is the authoritative source for the model identity. - If the recipe also specifies model_id_or_path and it differs from the handler's - path, an error is raised. If the recipe omits model_id_or_path, the handler's - path is used. + Accepts either an HfModelHandler or a QairtModelHandler as input: + + - HfModelHandler: the handler's model_path is the authoritative model identity. + If the recipe also specifies model_id_or_path and it matches, no error is raised; + if it conflicts, a ValueError is raised. + + - QairtModelHandler: used when chaining two QairtPipelinePass instances. The + handler's model_path (an HF checkpoint exported by the upstream pass) is injected + as model_id_or_path. The recipe must NOT specify model_id_or_path in this case. """ @classmethod @@ -78,7 +84,7 @@ def validate_config( def _run_for_config( self, - model: HfModelHandler, + model: Union[HfModelHandler, QairtModelHandler], config: type[BasePassConfig], output_model_path: str, ) -> QairtModelHandler: @@ -92,8 +98,15 @@ def _run_for_config( "If already installed, please run `qairt-vm -i` for help troubleshooting issues." ) from exc - if not isinstance(model, HfModelHandler): - raise ValueError(f"QairtPipelinePass requires HfModelHandler as input, got {type(model).__name__}") + if isinstance(model, HfModelHandler): + model_id = model.model_path + elif isinstance(model, QairtModelHandler): + model_id = model.model_path + else: + raise ValueError( + f"QairtPipelinePass requires HfModelHandler or QairtModelHandler as input, " + f"got {type(model).__name__}" + ) recipe_path = Path(config.recipe).resolve() if not recipe_path.exists(): @@ -102,42 +115,58 @@ def _run_for_config( recipe_data = dict(Recipe.from_file(recipe_path)) recipe_model_id = recipe_data.get("model_id_or_path") - if recipe_model_id and recipe_model_id != model.model_path: - raise ValueError( - f"Conflict between recipe model_id_or_path '{recipe_model_id}' and input model " - f"path '{model.model_path}'. Remove model_id_or_path from the recipe or ensure " - "it matches the input model path." - ) + if isinstance(model, HfModelHandler): + if recipe_model_id and recipe_model_id != model_id: + raise ValueError( + f"Conflict between recipe model_id_or_path '{recipe_model_id}' and input model " + f"path '{model_id}'. Remove model_id_or_path from the recipe or ensure " + "it matches the input model path." + ) + else: + # QairtModelHandler: model_path is already the resolved local checkpoint; + # the recipe must not override it. + if recipe_model_id: + raise ValueError( + f"Recipe specifies model_id_or_path '{recipe_model_id}', but input is a " + f"QairtModelHandler with model_path '{model_id}'. Remove model_id_or_path " + "from the recipe when chaining QairtPipelinePass instances." + ) if config.cache_dir is not None: recipe_data["cache_dir"] = config.cache_dir if config.log_level is not None: recipe_data["log_level"] = config.log_level - pipe = LLMPipeline.from_pretrained(model.model_path, recipe=recipe_data) + pipe = LLMPipeline.from_pretrained(model_id, recipe=recipe_data) pipe.construct() Path(output_model_path).mkdir(parents=True, exist_ok=True) pipe.export(output_model_path) # QairtEncapsulation needs config.json and generation_config.json to generate - # genai_config.json. Resolve the local HF cache path (model.model_path may be a - # HuggingFace repo ID rather than a local directory) and copy if not already present. - try: - from huggingface_hub import snapshot_download - - local_model_path = snapshot_download( - model.model_path, - local_files_only=True, - ignore_patterns=["*.pt", "*.bin", "*.safetensors"], - ) - except Exception as e: - logger.warning( - "Failed to resolve local HF cache for '%s': %s. File copy will be skipped.", - model.model_path, - e, - ) - local_model_path = model.model_path + # genai_config.json. Resolve the source directory and copy files if not already present. + if isinstance(model, HfModelHandler): + # model_path may be a HuggingFace repo ID rather than a local directory; + # use snapshot_download to get the local cache path. + try: + from huggingface_hub import snapshot_download + + local_model_path = snapshot_download( + model_id, + local_files_only=True, + ignore_patterns=["*.pt", "*.bin", "*.safetensors"], + ) + except Exception as e: + logger.warning( + "Failed to resolve local HF cache for '%s': %s. File copy will be skipped.", + model_id, + e, + ) + local_model_path = model_id + else: + # QairtModelHandler.model_path is an HF checkpoint written by save_pretrained; + # it already contains config.json / generation_config.json locally. + local_model_path = model_id for fname in ("config.json", "generation_config.json"): src = Path(local_model_path) / fname diff --git a/test/passes/qairt/test_pipeline_pass.py b/test/passes/qairt/test_pipeline_pass.py index 09ab4c096..6518291c1 100644 --- a/test/passes/qairt/test_pipeline_pass.py +++ b/test/passes/qairt/test_pipeline_pass.py @@ -197,7 +197,7 @@ def test_pipeline_pass_log_level_override(tmp_path, mock_hf_model, recipe_file, def test_pipeline_pass_invalid_input_model(tmp_path, mock_qairt_prepared_model, recipe_file, mock_pipeline_modules): - """Test that ValueError is raised when input is not HfModelHandler.""" + """Test that ValueError is raised when input is not HfModelHandler or QairtModelHandler.""" output_path = tmp_path / "output" mock_pipeline_modules["Recipe"].from_file.return_value = {"stages": {}} @@ -208,7 +208,7 @@ def test_pipeline_pass_invalid_input_model(tmp_path, mock_qairt_prepared_model, disable_search=True, ) - with pytest.raises(ValueError, match="QairtPipelinePass requires HfModelHandler"): + with pytest.raises(ValueError, match="QairtPipelinePass requires HfModelHandler or QairtModelHandler"): pipeline_pass.run(mock_qairt_prepared_model, str(output_path)) @@ -247,3 +247,73 @@ def import_side_effect(name, *args, **kwargs): with pytest.raises(ImportError, match="Failed to import QAIRT Pipeline API"): pipeline_pass.run(mock_hf_model, str(tmp_path / "output")) + + +def test_pipeline_pass_qairt_model_handler_success(tmp_path, mock_qairt_model, recipe_file, mock_pipeline_modules): + """Test successful pass execution with QairtModelHandler as input.""" + output_path = tmp_path / "output" + + mock_pipeline_modules["Recipe"].from_file.return_value = { + "cache_dir": "./pipeline_cache", + "backend": "HTP", + "stages": {}, + } + + pipeline_pass = create_pass_from_dict( + QairtPipelinePass, + {"recipe": str(recipe_file)}, + disable_search=True, + ) + + result = pipeline_pass.run(mock_qairt_model, str(output_path)) + + assert isinstance(result, QairtModelHandler) + assert result.model_path == str(output_path) + mock_pipeline_modules["LLMPipeline"].from_pretrained.assert_called_once_with( + mock_qairt_model.model_path, + recipe={"cache_dir": "./pipeline_cache", "backend": "HTP", "stages": {}}, + ) + mock_pipeline_modules["pipeline"].construct.assert_called_once() + mock_pipeline_modules["pipeline"].export.assert_called_once_with(str(output_path)) + + +def test_pipeline_pass_qairt_model_handler_config_files_copied( + tmp_path, mock_qairt_model, recipe_file, mock_pipeline_modules +): + """Test that config.json and generation_config.json are copied from QairtModelHandler.model_path.""" + output_path = tmp_path / "output" + output_path.mkdir(parents=True, exist_ok=True) + + mock_pipeline_modules["Recipe"].from_file.return_value = {"stages": {}} + + pipeline_pass = create_pass_from_dict( + QairtPipelinePass, + {"recipe": str(recipe_file)}, + disable_search=True, + ) + + pipeline_pass.run(mock_qairt_model, str(output_path)) + + assert (output_path / "config.json").exists() + assert (output_path / "generation_config.json").exists() + + +def test_pipeline_pass_qairt_model_handler_recipe_model_id_raises( + tmp_path, mock_qairt_model, recipe_file, mock_pipeline_modules +): + """Test that ValueError is raised when QairtModelHandler input and recipe specifies model_id_or_path.""" + output_path = tmp_path / "output" + + mock_pipeline_modules["Recipe"].from_file.return_value = { + "model_id_or_path": "Qwen/Qwen3-7B", + "stages": {}, + } + + pipeline_pass = create_pass_from_dict( + QairtPipelinePass, + {"recipe": str(recipe_file)}, + disable_search=True, + ) + + with pytest.raises(ValueError, match="Remove model_id_or_path from the recipe when chaining"): + pipeline_pass.run(mock_qairt_model, str(output_path))