Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 61 additions & 32 deletions olive/passes/qairt/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import logging
import shutil
from pathlib import Path
from typing import Union

from olive.common.config_utils import ParamCategory
from olive.hardware.accelerator import AcceleratorSpec
Expand All @@ -18,16 +19,21 @@


class QairtPipelinePass(Pass):
"""Run a QairtPipeline from a YAML recipe on a HuggingFace model.
"""Run a QairtPipeline from a YAML recipe on a HuggingFace or previously-built QAIRT model.

Executes the full LLMPipeline workflow (model loading, quantization, compilation)
defined by the recipe and exports the result as a QairtModelHandler. This pass
is intended to replace the QairtPreparation -> QairtGenAIBuilder workflow.

The input HfModelHandler is the authoritative source for the model identity.
If the recipe also specifies model_id_or_path and it differs from the handler's
path, an error is raised. If the recipe omits model_id_or_path, the handler's
path is used.
Accepts either an HfModelHandler or a QairtModelHandler as input:

- HfModelHandler: the handler's model_path is the authoritative model identity.
If the recipe also specifies model_id_or_path and it matches, no error is raised;
if it conflicts, a ValueError is raised.

- QairtModelHandler: used when chaining two QairtPipelinePass instances. The
handler's model_path (an HF checkpoint exported by the upstream pass) is injected
as model_id_or_path. The recipe must NOT specify model_id_or_path in this case.
"""

@classmethod
Expand Down Expand Up @@ -78,7 +84,7 @@

def _run_for_config(
self,
model: HfModelHandler,
model: Union[HfModelHandler, QairtModelHandler],
config: type[BasePassConfig],
output_model_path: str,
) -> QairtModelHandler:
Expand All @@ -92,8 +98,15 @@
"If already installed, please run `qairt-vm -i` for help troubleshooting issues."
) from exc

if not isinstance(model, HfModelHandler):
raise ValueError(f"QairtPipelinePass requires HfModelHandler as input, got {type(model).__name__}")
if isinstance(model, HfModelHandler):

Check warning

Code scanning / lintrunner

RUFF/SIM114 Warning

Combine if branches using logical or operator.
See https://docs.astral.sh/ruff/rules/if-with-same-arms
model_id = model.model_path
elif isinstance(model, QairtModelHandler):
model_id = model.model_path
else:
raise ValueError(
f"QairtPipelinePass requires HfModelHandler or QairtModelHandler as input, "
f"got {type(model).__name__}"
)

recipe_path = Path(config.recipe).resolve()
if not recipe_path.exists():
Expand All @@ -102,42 +115,58 @@
recipe_data = dict(Recipe.from_file(recipe_path))

recipe_model_id = recipe_data.get("model_id_or_path")
if recipe_model_id and recipe_model_id != model.model_path:
raise ValueError(
f"Conflict between recipe model_id_or_path '{recipe_model_id}' and input model "
f"path '{model.model_path}'. Remove model_id_or_path from the recipe or ensure "
"it matches the input model path."
)
if isinstance(model, HfModelHandler):
if recipe_model_id and recipe_model_id != model_id:
raise ValueError(
f"Conflict between recipe model_id_or_path '{recipe_model_id}' and input model "
f"path '{model_id}'. Remove model_id_or_path from the recipe or ensure "
"it matches the input model path."
)
else:
# QairtModelHandler: model_path is already the resolved local checkpoint;
# the recipe must not override it.
if recipe_model_id:
raise ValueError(
f"Recipe specifies model_id_or_path '{recipe_model_id}', but input is a "
f"QairtModelHandler with model_path '{model_id}'. Remove model_id_or_path "
"from the recipe when chaining QairtPipelinePass instances."
)

if config.cache_dir is not None:
recipe_data["cache_dir"] = config.cache_dir
if config.log_level is not None:
recipe_data["log_level"] = config.log_level

pipe = LLMPipeline.from_pretrained(model.model_path, recipe=recipe_data)
pipe = LLMPipeline.from_pretrained(model_id, recipe=recipe_data)
pipe.construct()

Path(output_model_path).mkdir(parents=True, exist_ok=True)
pipe.export(output_model_path)

# QairtEncapsulation needs config.json and generation_config.json to generate
# genai_config.json. Resolve the local HF cache path (model.model_path may be a
# HuggingFace repo ID rather than a local directory) and copy if not already present.
try:
from huggingface_hub import snapshot_download

local_model_path = snapshot_download(
model.model_path,
local_files_only=True,
ignore_patterns=["*.pt", "*.bin", "*.safetensors"],
)
except Exception as e:
logger.warning(
"Failed to resolve local HF cache for '%s': %s. File copy will be skipped.",
model.model_path,
e,
)
local_model_path = model.model_path
# genai_config.json. Resolve the source directory and copy files if not already present.
if isinstance(model, HfModelHandler):
# model_path may be a HuggingFace repo ID rather than a local directory;
# use snapshot_download to get the local cache path.
try:
from huggingface_hub import snapshot_download

local_model_path = snapshot_download(
model_id,
local_files_only=True,
ignore_patterns=["*.pt", "*.bin", "*.safetensors"],
)
except Exception as e:
logger.warning(
"Failed to resolve local HF cache for '%s': %s. File copy will be skipped.",
model_id,
e,
)
local_model_path = model_id
else:
# QairtModelHandler.model_path is an HF checkpoint written by save_pretrained;
# it already contains config.json / generation_config.json locally.
local_model_path = model_id

for fname in ("config.json", "generation_config.json"):
src = Path(local_model_path) / fname
Expand Down
74 changes: 72 additions & 2 deletions test/passes/qairt/test_pipeline_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def test_pipeline_pass_log_level_override(tmp_path, mock_hf_model, recipe_file,


def test_pipeline_pass_invalid_input_model(tmp_path, mock_qairt_prepared_model, recipe_file, mock_pipeline_modules):
"""Test that ValueError is raised when input is not HfModelHandler."""
"""Test that ValueError is raised when input is not HfModelHandler or QairtModelHandler."""
output_path = tmp_path / "output"

mock_pipeline_modules["Recipe"].from_file.return_value = {"stages": {}}
Expand All @@ -208,7 +208,7 @@ def test_pipeline_pass_invalid_input_model(tmp_path, mock_qairt_prepared_model,
disable_search=True,
)

with pytest.raises(ValueError, match="QairtPipelinePass requires HfModelHandler"):
with pytest.raises(ValueError, match="QairtPipelinePass requires HfModelHandler or QairtModelHandler"):
pipeline_pass.run(mock_qairt_prepared_model, str(output_path))


Expand Down Expand Up @@ -247,3 +247,73 @@ def import_side_effect(name, *args, **kwargs):

with pytest.raises(ImportError, match="Failed to import QAIRT Pipeline API"):
pipeline_pass.run(mock_hf_model, str(tmp_path / "output"))


def test_pipeline_pass_qairt_model_handler_success(tmp_path, mock_qairt_model, recipe_file, mock_pipeline_modules):
    """Run the pass with a QairtModelHandler input and check the pipeline is driven correctly.

    Verifies that the result is a QairtModelHandler pointing at the output
    directory, that LLMPipeline.from_pretrained receives the input handler's
    model_path together with the recipe contents, and that construct/export
    are each invoked exactly once.
    """
    recipe_mock = mock_pipeline_modules["Recipe"]
    llm_pipeline_cls = mock_pipeline_modules["LLMPipeline"]
    pipeline_instance = mock_pipeline_modules["pipeline"]

    # Recipe returned by the mocked loader; contains no model_id_or_path entry.
    recipe_mock.from_file.return_value = {
        "cache_dir": "./pipeline_cache",
        "backend": "HTP",
        "stages": {},
    }

    destination = tmp_path / "output"
    pass_config = {"recipe": str(recipe_file)}
    qairt_pass = create_pass_from_dict(QairtPipelinePass, pass_config, disable_search=True)

    produced = qairt_pass.run(mock_qairt_model, str(destination))

    assert isinstance(produced, QairtModelHandler)
    assert produced.model_path == str(destination)

    expected_recipe = {"cache_dir": "./pipeline_cache", "backend": "HTP", "stages": {}}
    llm_pipeline_cls.from_pretrained.assert_called_once_with(
        mock_qairt_model.model_path,
        recipe=expected_recipe,
    )
    pipeline_instance.construct.assert_called_once()
    pipeline_instance.export.assert_called_once_with(str(destination))


def test_pipeline_pass_qairt_model_handler_config_files_copied(
    tmp_path, mock_qairt_model, recipe_file, mock_pipeline_modules
):
    """Check that config.json and generation_config.json exist in the output dir after the pass runs.

    The input is a QairtModelHandler; presumably the mock_qairt_model fixture
    provides both files under its model_path (verify against the fixture).
    """
    mock_pipeline_modules["Recipe"].from_file.return_value = {"stages": {}}

    # The output directory is created up front, before the pass runs.
    destination = tmp_path / "output"
    destination.mkdir(parents=True, exist_ok=True)

    qairt_pass = create_pass_from_dict(
        QairtPipelinePass,
        {"recipe": str(recipe_file)},
        disable_search=True,
    )
    qairt_pass.run(mock_qairt_model, str(destination))

    config_file = destination / "config.json"
    generation_config_file = destination / "generation_config.json"
    assert config_file.exists()
    assert generation_config_file.exists()


def test_pipeline_pass_qairt_model_handler_recipe_model_id_raises(
    tmp_path, mock_qairt_model, recipe_file, mock_pipeline_modules
):
    """Check that a recipe carrying model_id_or_path is rejected for a QairtModelHandler input.

    The pass is expected to raise ValueError with a message telling the user
    to remove model_id_or_path from the recipe when chaining passes.
    """
    # Recipe that (incorrectly, for this input type) names its own model.
    conflicting_recipe = {
        "model_id_or_path": "Qwen/Qwen3-7B",
        "stages": {},
    }
    mock_pipeline_modules["Recipe"].from_file.return_value = conflicting_recipe

    destination = tmp_path / "output"
    pass_config = {"recipe": str(recipe_file)}
    qairt_pass = create_pass_from_dict(QairtPipelinePass, pass_config, disable_search=True)

    expected_message = "Remove model_id_or_path from the recipe when chaining"
    with pytest.raises(ValueError, match=expected_message):
        qairt_pass.run(mock_qairt_model, str(destination))
Loading