From 823a0223a8c86bf0e0dc48188ce82eed1db904be Mon Sep 17 00:00:00 2001 From: Thiago Rocha Date: Fri, 19 Jun 2026 18:05:58 -0400 Subject: [PATCH] [AMD] Update QuarkQuantization pass for Quark 0.12 Fix the ONNX calibration path: Quark 0.12 rejects extra inputs in the calibration data reader, so pass model_path/io_config to filter out dataset labels (e.g. `class`) that are not model inputs. Update the Torch path for the Quark 0.12 API: replace the removed `revert_model_patching` and the deprecated `prepare_for_moe_quant` (which now raises on transformers>=5) with `preprocess_for_quantization`. Bump the required amd-quark version to >=0.12.0. --- olive/olive_config.json | 4 +- .../quark_quantizer/quark_quantization.py | 17 +++-- .../torch/quark_torch_quantization.py | 17 ++--- .../test_quark_onnx_quantization.py | 73 +++++++++++++++++++ 4 files changed, 93 insertions(+), 18 deletions(-) diff --git a/olive/olive_config.json b/olive/olive_config.json index 1978e61dc..edec8a9ea 100644 --- a/olive/olive_config.json +++ b/olive/olive_config.json @@ -606,7 +606,7 @@ "supported_algorithms": [ "awq" ], "supported_quantization_encodings": [ ], "run_on_target": true, - "extra_dependencies": [ "amd-quark" ] + "extra_dependencies": [ "amd-quark>=0.12.0" ] }, "QuarkQuantizationVitisAI": { "module_path": "olive.passes.quark_vitisai.quark_quantization_vitisai.QuarkQuantizationVitisAI", @@ -616,7 +616,7 @@ "supported_algorithms": [ "awq" ], "supported_quantization_encodings": [ ], "run_on_target": true, - "extra_dependencies": [ "amd-quark" ] + "extra_dependencies": [ "amd-quark>=0.12.0" ] }, "KQuant": { "module_path": "olive.passes.pytorch.kquant.KQuant", diff --git a/olive/passes/quark_quantizer/quark_quantization.py b/olive/passes/quark_quantizer/quark_quantization.py index 165d21d93..f15a2d587 100644 --- a/olive/passes/quark_quantizer/quark_quantization.py +++ b/olive/passes/quark_quantizer/quark_quantization.py @@ -30,7 +30,7 @@ class QuarkQuantization(Pass): Routes to the appropriate backend based on input model type: - ONNXModelHandler -> Quark-ONNX quantization (quark.onnx) - - HfModelHandler -> Quark-Torch quantization (quark.torch, Quark 0.11 API) + - HfModelHandler -> Quark-Torch quantization (quark.torch, Quark 0.12 API) """ @classmethod @@ -188,7 +188,7 @@ def _run_for_config( logger.info("[INFO] Running QuarkQuantization using Quark-ONNX API") return self._run_quark_onnx(model, config, output_model_path) else: - logger.info("[INFO] Running QuarkQuantization using Quark-Torch 0.11 API") + logger.info("[INFO] Running QuarkQuantization using Quark-Torch 0.12 API") return self._run_quark_torch(model, config, output_model_path) # ── ONNX path ─────────────────────────────────────────── @@ -201,8 +201,8 @@ def _run_quark_onnx( ) -> ONNXModelHandler: from quark import __version__ as QuarkVersion - if version.parse(QuarkVersion) < version.parse("0.11.0"): - raise ValueError("Quark ONNX Quantization is only supported for amd-quark>=0.11.0") + if version.parse(QuarkVersion) < version.parse("0.12.0"): + raise ValueError("Quark ONNX Quantization is only supported for amd-quark>=0.12.0") from olive.passes.quark_quantizer.onnx.quantize_quark import run_quark_quantization @@ -217,7 +217,9 @@ def _run_quark_onnx( data_reader = None if config.data_config: data_config = validate_config(config.data_config, DataConfig) - data_reader = data_config.to_data_container().create_calibration_dataloader() + data_reader = data_config.to_data_container().create_calibration_dataloader( + model_path=model.model_path, io_config=model.io_config + ) run_config = config.model_dump() to_delete = [ @@ -253,6 +255,11 @@ def _run_quark_torch( config: BasePassConfig, output_model_path: str, ) -> HfModelHandler: + from quark import __version__ as QuarkVersion + + if version.parse(QuarkVersion) < version.parse("0.12.0"): + raise ValueError("Quark Torch Quantization is only supported for amd-quark>=0.12.0") + from olive.passes.quark_quantizer.torch.quark_torch_quantization import run_quark_torch_quantization return run_quark_torch_quantization(model, config, output_model_path) diff --git a/olive/passes/quark_quantizer/torch/quark_torch_quantization.py b/olive/passes/quark_quantizer/torch/quark_torch_quantization.py index de313ad1c..3bd1c4568 100644 --- a/olive/passes/quark_quantizer/torch/quark_torch_quantization.py +++ b/olive/passes/quark_quantizer/torch/quark_torch_quantization.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: MIT # -"""Quark 0.11 Torch quantization for LLMs. +"""Quark 0.12 Torch quantization for LLMs. Uses LLMTemplate + ModelQuantizer from the Quark public API. """ @@ -27,7 +27,7 @@ def run_quark_torch_quantization( config: BasePassConfig, output_model_path: str, ) -> HfModelHandler: - """Run Quark 0.11 torch quantization on a HuggingFace model. + """Run Quark 0.12 torch quantization on a HuggingFace model. Args: model: Olive HfModelHandler pointing to the source model. @@ -53,8 +53,7 @@ def run_quark_torch_quantization( get_calib_dataloader, get_model, get_tokenizer, - prepare_for_moe_quant, - revert_model_patching, + preprocess_for_quantization, ) output_dir = Path(output_model_path) @@ -75,7 +74,7 @@ def run_quark_torch_quantization( trust_remote_code=config.trust_remote_code, ) - prepare_for_moe_quant(torch_model) + preprocess_for_quantization(torch_model) model_type = ( torch_model.config.model_type @@ -145,15 +144,11 @@ def run_quark_torch_quantization( logger.info("[INFO] Freezing quantized model") torch_model = quantizer.freeze(torch_model) - # 6. Revert model patching - logger.info("[INFO] Reverting model patching") - revert_model_patching(torch_model) - - # 7. Validate export configuration + # 6. Validate export configuration if config.custom_mode != "quark" and config.export_weight_format == "fake_quantized": raise ValueError("'fake_quantized' export is only supported with custom_mode='quark'") - # 8. Export model + # 7. Export model logger.info("[INFO] Exporting quantized model to: %s", output_dir) export_formats = config.model_export diff --git a/test/passes/quark_quantizer/test_quark_onnx_quantization.py b/test/passes/quark_quantizer/test_quark_onnx_quantization.py index cbab88de1..9d70d160a 100644 --- a/test/passes/quark_quantizer/test_quark_onnx_quantization.py +++ b/test/passes/quark_quantizer/test_quark_onnx_quantization.py @@ -3,6 +3,9 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- +from unittest.mock import patch + +import numpy as np import pytest from onnxruntime.quantization.calibrate import CalibrationDataReader @@ -194,3 +197,73 @@ def test_static_qdq_u8s8_with_layerwise_mixed_precision_quantization(tmp_path): p = create_pass_from_dict(QuarkQuantization, config, disable_search=True) out = p.run(input_model, tmp_path) assert out is not None + + +@Registry.register_dataloader() +def _test_quant_dataloader_with_label(dataset, batch_size, **kwargs): + """Yield model input and a label column as a real dataset pipeline would. + + The label must be filtered out before being passed to the ONNX calibration reader. + """ + + class _ReaderWithLabel(CalibrationDataReader): + # pylint: disable=W0223 + def __init__(self): + super().__init__() + self.samples = [{"input": np.random.randn(1, 1).astype(np.float32), "label": 0} for _ in range(4)] + self._iter = iter(self.samples) + + def get_next(self): + return next(self._iter, None) + + return _ReaderWithLabel() + + +def test_calibration_dataloader_filters_label_columns(tmp_path): + """Regression: Quark 0.12 rejects calibration inputs that are not model inputs. + + The pass must pass model_path/io_config to create_calibration_dataloader() so + non-input columns (e.g. a label column) are stripped before reaching Quark. + """ + input_model = get_onnx_model() + config = { + "quant_mode": "static", + "quant_format": "QDQ", + "global_config": { + "activation": {"symmetric": False, "calibration_method": "MinMax", "data_type": "UInt8"}, + "weight": {"symmetric": True, "calibration_method": "MinMax", "data_type": "Int8"}, + }, + "data_config": DataConfig( + name="test_quant_dc_config_with_label", + load_dataset_config=DataComponentConfig(type="simple_dataset"), + dataloader_config=DataComponentConfig(type="_test_quant_dataloader_with_label"), + ), + } + p = create_pass_from_dict(QuarkQuantization, config, disable_search=True) + # Should complete without "Invalid input name: label" from Quark + out = p.run(input_model, tmp_path) + assert out is not None + + +def test_onnx_path_raises_clear_error_for_old_quark(tmp_path): + """ONNX path must raise a clear ValueError, not an ImportError, for amd-quark < 0.12.0.""" + input_model = get_onnx_model() + config = {"quant_mode": "static", "quant_format": "QDQ"} + p = create_pass_from_dict(QuarkQuantization, config, disable_search=True) + with patch("quark.__version__", "0.11.2"), pytest.raises(ValueError, match=r"amd-quark>=0\.12\.0"): + p.run(input_model, tmp_path) + + +def test_torch_path_raises_clear_error_for_old_quark(tmp_path): + """Torch path must raise a clear ValueError, not an ImportError, for amd-quark < 0.12.0.""" + from unittest.mock import MagicMock + + from olive.model import HfModelHandler + + # Use a MagicMock so handler construction/validation is bypassed. + # The version gate fires inside _run_quark_torch before any model loading. + input_model = MagicMock(spec=HfModelHandler) + config = {"quant_scheme": "uint4_wo_128"} + p = create_pass_from_dict(QuarkQuantization, config, disable_search=True) + with patch("quark.__version__", "0.11.2"), pytest.raises(ValueError, match=r"amd-quark>=0\.12\.0"): + p.run(input_model, tmp_path)