Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion olive/passes/qairt/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,4 +154,12 @@ def _run_for_config(
if src.exists() and not dst.exists():
shutil.copy2(src, dst)

return QairtModelHandler(model_path=output_model_path)
# Forward AR decode lengths to QairtEncapsulation so it can compute max_length correctly.
# arn in the pipeline recipe is the equivalent of sequence_lengths in QairtGenAIBuilder.
arn = recipe_data.get("stages", {}).get("genai_builder", {}).get("transform_options", {}).get("arn")
if arn is not None and (not isinstance(arn, list) or not arn or not all(isinstance(v, int) for v in arn)):
raise ValueError(
f"stages.genai_builder.transform_options.arn must be a non-empty list of integers, got: {arn!r}"
)
model_attrs = {"sequence_lengths": arn} if arn else {}
return QairtModelHandler(model_path=output_model_path, model_attributes=model_attrs)
73 changes: 73 additions & 0 deletions test/passes/qairt/test_pipeline_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,79 @@ def test_pipeline_pass_missing_recipe_file(tmp_path, mock_hf_model, mock_pipelin
pipeline_pass.run(mock_hf_model, str(output_path))


def test_pipeline_pass_forwards_arn_as_sequence_lengths(tmp_path, mock_hf_model, recipe_file, mock_pipeline_modules):
    """Check that the recipe's ``arn`` list surfaces as ``sequence_lengths`` in the handler's model_attributes."""
    # Recipe contents that the stubbed Recipe.from_file hands back to the pass.
    transform_options = {"arn": [128, 1024, 4096]}
    fake_recipe = {"stages": {"genai_builder": {"transform_options": transform_options}}}
    mock_pipeline_modules["Recipe"].from_file.return_value = fake_recipe

    pass_config = {"recipe": str(recipe_file)}
    qairt_pass = create_pass_from_dict(QairtPipelinePass, pass_config, disable_search=True)

    destination = tmp_path / "output"
    handler = qairt_pass.run(mock_hf_model, str(destination))

    assert isinstance(handler, QairtModelHandler)
    assert handler.model_attributes == {"sequence_lengths": [128, 1024, 4096]}


@pytest.mark.parametrize(
    "bad_arn",
    [
        "128,1024",  # a string, not a list
        1024,  # a lone integer
        [],  # a list with no entries
        [128, "1024"],  # a list holding a non-integer entry
    ],
)
def test_pipeline_pass_invalid_arn_raises(tmp_path, mock_hf_model, recipe_file, mock_pipeline_modules, bad_arn):
    """Verify each malformed ``arn`` makes ``run`` raise ValueError carrying the validation message."""
    # Feed the malformed value through the stubbed recipe loader.
    transform_options = {"arn": bad_arn}
    mock_pipeline_modules["Recipe"].from_file.return_value = {
        "stages": {"genai_builder": {"transform_options": transform_options}},
    }

    qairt_pass = create_pass_from_dict(QairtPipelinePass, {"recipe": str(recipe_file)}, disable_search=True)
    destination = str(tmp_path / "output")

    expected_message = "must be a non-empty list of integers"
    with pytest.raises(ValueError, match=expected_message):
        qairt_pass.run(mock_hf_model, destination)


def test_pipeline_pass_no_arn_yields_empty_model_attributes(
    tmp_path, mock_hf_model, recipe_file, mock_pipeline_modules
):
    """Check that a recipe without ``arn`` leaves ``sequence_lengths`` out of the handler's model_attributes."""
    recipe_without_arn = {"stages": {"genai_builder": {"transform_options": {}}}}
    mock_pipeline_modules["Recipe"].from_file.return_value = recipe_without_arn

    qairt_pass = create_pass_from_dict(
        QairtPipelinePass,
        {"recipe": str(recipe_file)},
        disable_search=True,
    )
    handler = qairt_pass.run(mock_hf_model, str(tmp_path / "output"))

    assert isinstance(handler, QairtModelHandler)
    # The `or {}` fallback keeps the membership check valid when model_attributes is falsy (e.g. None).
    attributes = handler.model_attributes or {}
    assert "sequence_lengths" not in attributes


def test_pipeline_pass_import_error(tmp_path, mock_hf_model, recipe_file):
"""Test that ImportError is raised if qairt cannot be imported."""

Expand Down
Loading