diff --git a/olive/passes/qairt/pipeline.py b/olive/passes/qairt/pipeline.py index ca630774e..d012072cc 100644 --- a/olive/passes/qairt/pipeline.py +++ b/olive/passes/qairt/pipeline.py @@ -154,4 +154,12 @@ def _run_for_config( if src.exists() and not dst.exists(): shutil.copy2(src, dst) - return QairtModelHandler(model_path=output_model_path) + # Forward AR decode lengths to QairtEncapsulation so it can compute max_length correctly. + # arn in the pipeline recipe is the equivalent of sequence_lengths in QairtGenAIBuilder. + arn = recipe_data.get("stages", {}).get("genai_builder", {}).get("transform_options", {}).get("arn") + if arn is not None and (not isinstance(arn, list) or not arn or not all(isinstance(v, int) for v in arn)): + raise ValueError( + f"stages.genai_builder.transform_options.arn must be a non-empty list of integers, got: {arn!r}" + ) + model_attrs = {"sequence_lengths": arn} if arn else {} + return QairtModelHandler(model_path=output_model_path, model_attributes=model_attrs) diff --git a/test/passes/qairt/test_pipeline_pass.py b/test/passes/qairt/test_pipeline_pass.py index 09ab4c096..b4ae1b60a 100644 --- a/test/passes/qairt/test_pipeline_pass.py +++ b/test/passes/qairt/test_pipeline_pass.py @@ -228,6 +228,79 @@ def test_pipeline_pass_missing_recipe_file(tmp_path, mock_hf_model, mock_pipelin pipeline_pass.run(mock_hf_model, str(output_path)) +def test_pipeline_pass_forwards_arn_as_sequence_lengths(tmp_path, mock_hf_model, recipe_file, mock_pipeline_modules): + """ARN from genai_builder.transform_options is forwarded as sequence_lengths in model_attributes.""" + output_path = tmp_path / "output" + + mock_pipeline_modules["Recipe"].from_file.return_value = { + "stages": { + "genai_builder": { + "transform_options": { + "arn": [128, 1024, 4096], + } + } + } + } + + pipeline_pass = create_pass_from_dict( + QairtPipelinePass, + {"recipe": str(recipe_file)}, + disable_search=True, + ) + + result = pipeline_pass.run(mock_hf_model, str(output_path)) + + assert isinstance(result, QairtModelHandler) + assert result.model_attributes == {"sequence_lengths": [128, 1024, 4096]} + + +@pytest.mark.parametrize( + "bad_arn", + [ + "128,1024", # string instead of list + 1024, # bare int + [], # empty list + [128, "1024"], # list with non-int element + ], +) +def test_pipeline_pass_invalid_arn_raises(tmp_path, mock_hf_model, recipe_file, mock_pipeline_modules, bad_arn): + """A malformed arn value raises ValueError with a clear message before propagating downstream.""" + output_path = tmp_path / "output" + + mock_pipeline_modules["Recipe"].from_file.return_value = { + "stages": {"genai_builder": {"transform_options": {"arn": bad_arn}}} + } + + pipeline_pass = create_pass_from_dict( + QairtPipelinePass, + {"recipe": str(recipe_file)}, + disable_search=True, + ) + + with pytest.raises(ValueError, match="must be a non-empty list of integers"): + pipeline_pass.run(mock_hf_model, str(output_path)) + + +def test_pipeline_pass_no_arn_yields_empty_model_attributes( + tmp_path, mock_hf_model, recipe_file, mock_pipeline_modules +): + """When genai_builder.transform_options has no arn, sequence_lengths is not set in model_attributes.""" + output_path = tmp_path / "output" + + mock_pipeline_modules["Recipe"].from_file.return_value = {"stages": {"genai_builder": {"transform_options": {}}}} + + pipeline_pass = create_pass_from_dict( + QairtPipelinePass, + {"recipe": str(recipe_file)}, + disable_search=True, + ) + + result = pipeline_pass.run(mock_hf_model, str(output_path)) + + assert isinstance(result, QairtModelHandler) + assert "sequence_lengths" not in (result.model_attributes or {}) + + def test_pipeline_pass_import_error(tmp_path, mock_hf_model, recipe_file): """Test that ImportError is raised if qairt cannot be imported."""