Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions olive/olive_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,7 @@
"supported_algorithms": [ "awq" ],
"supported_quantization_encodings": [ ],
"run_on_target": true,
"extra_dependencies": [ "amd-quark" ]
"extra_dependencies": [ "amd-quark>=0.12.0" ]
},
"QuarkQuantizationVitisAI": {
"module_path": "olive.passes.quark_vitisai.quark_quantization_vitisai.QuarkQuantizationVitisAI",
Expand All @@ -616,7 +616,7 @@
"supported_algorithms": [ "awq" ],
"supported_quantization_encodings": [ ],
"run_on_target": true,
"extra_dependencies": [ "amd-quark" ]
"extra_dependencies": [ "amd-quark>=0.12.0" ]
},
"KQuant": {
"module_path": "olive.passes.pytorch.kquant.KQuant",
Expand Down
17 changes: 12 additions & 5 deletions olive/passes/quark_quantizer/quark_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class QuarkQuantization(Pass):

Routes to the appropriate backend based on input model type:
- ONNXModelHandler -> Quark-ONNX quantization (quark.onnx)
- HfModelHandler -> Quark-Torch quantization (quark.torch, Quark 0.11 API)
- HfModelHandler -> Quark-Torch quantization (quark.torch, Quark 0.12 API)
"""

@classmethod
Expand Down Expand Up @@ -188,7 +188,7 @@ def _run_for_config(
logger.info("[INFO] Running QuarkQuantization using Quark-ONNX API")
return self._run_quark_onnx(model, config, output_model_path)
else:
logger.info("[INFO] Running QuarkQuantization using Quark-Torch 0.11 API")
logger.info("[INFO] Running QuarkQuantization using Quark-Torch 0.12 API")
return self._run_quark_torch(model, config, output_model_path)
Comment thread
thpereir marked this conversation as resolved.

# ── ONNX path ───────────────────────────────────────────
Expand All @@ -201,8 +201,8 @@ def _run_quark_onnx(
) -> ONNXModelHandler:
from quark import __version__ as QuarkVersion

if version.parse(QuarkVersion) < version.parse("0.11.0"):
raise ValueError("Quark ONNX Quantization is only supported for amd-quark>=0.11.0")
if version.parse(QuarkVersion) < version.parse("0.12.0"):
raise ValueError("Quark ONNX Quantization is only supported for amd-quark>=0.12.0")

from olive.passes.quark_quantizer.onnx.quantize_quark import run_quark_quantization

Expand All @@ -217,7 +217,9 @@ def _run_quark_onnx(
data_reader = None
if config.data_config:
data_config = validate_config(config.data_config, DataConfig)
data_reader = data_config.to_data_container().create_calibration_dataloader()
data_reader = data_config.to_data_container().create_calibration_dataloader(
model_path=model.model_path, io_config=model.io_config
)

run_config = config.model_dump()
to_delete = [
Expand Down Expand Up @@ -253,6 +255,11 @@ def _run_quark_torch(
config: BasePassConfig,
output_model_path: str,
) -> HfModelHandler:
from quark import __version__ as QuarkVersion

if version.parse(QuarkVersion) < version.parse("0.12.0"):
raise ValueError("Quark Torch Quantization is only supported for amd-quark>=0.12.0")

from olive.passes.quark_quantizer.torch.quark_torch_quantization import run_quark_torch_quantization

return run_quark_torch_quantization(model, config, output_model_path)
17 changes: 6 additions & 11 deletions olive/passes/quark_quantizer/torch/quark_torch_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: MIT
#

"""Quark 0.11 Torch quantization for LLMs.
"""Quark 0.12 Torch quantization for LLMs.

Uses LLMTemplate + ModelQuantizer from the Quark public API.
"""
Expand All @@ -27,7 +27,7 @@ def run_quark_torch_quantization(
config: BasePassConfig,
output_model_path: str,
) -> HfModelHandler:
"""Run Quark 0.11 torch quantization on a HuggingFace model.
"""Run Quark 0.12 torch quantization on a HuggingFace model.

Args:
model: Olive HfModelHandler pointing to the source model.
Expand All @@ -53,8 +53,7 @@ def run_quark_torch_quantization(
get_calib_dataloader,
get_model,
get_tokenizer,
prepare_for_moe_quant,
revert_model_patching,
preprocess_for_quantization,
)

output_dir = Path(output_model_path)
Expand All @@ -75,7 +74,7 @@ def run_quark_torch_quantization(
trust_remote_code=config.trust_remote_code,
)

prepare_for_moe_quant(torch_model)
preprocess_for_quantization(torch_model)

model_type = (
torch_model.config.model_type
Expand Down Expand Up @@ -145,15 +144,11 @@ def run_quark_torch_quantization(
logger.info("[INFO] Freezing quantized model")
torch_model = quantizer.freeze(torch_model)

# 6. Revert model patching
logger.info("[INFO] Reverting model patching")
revert_model_patching(torch_model)

# 7. Validate export configuration
# 6. Validate export configuration
if config.custom_mode != "quark" and config.export_weight_format == "fake_quantized":
raise ValueError("'fake_quantized' export is only supported with custom_mode='quark'")

# 8. Export model
# 7. Export model
logger.info("[INFO] Exporting quantized model to: %s", output_dir)

export_formats = config.model_export
Expand Down
73 changes: 73 additions & 0 deletions test/passes/quark_quantizer/test_quark_onnx_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------

from unittest.mock import patch

import numpy as np
import pytest
from onnxruntime.quantization.calibrate import CalibrationDataReader

Expand Down Expand Up @@ -194,3 +197,73 @@ def test_static_qdq_u8s8_with_layerwise_mixed_precision_quantization(tmp_path):
p = create_pass_from_dict(QuarkQuantization, config, disable_search=True)
out = p.run(input_model, tmp_path)
assert out is not None


@Registry.register_dataloader()
def _test_quant_dataloader_with_label(dataset, batch_size, **kwargs):
"""Yield model input and a label column as a real dataset pipeline would.

The label must be filtered out before being passed to the ONNX calibration reader.
"""

class _ReaderWithLabel(CalibrationDataReader):
# pylint: disable=W0223
def __init__(self):
super().__init__()
self.samples = [{"input": np.random.randn(1, 1).astype(np.float32), "label": 0} for _ in range(4)]
self._iter = iter(self.samples)

def get_next(self):
return next(self._iter, None)

return _ReaderWithLabel()


def test_calibration_dataloader_filters_label_columns(tmp_path):
"""Regression: Quark 0.12 rejects calibration inputs that are not model inputs.

The pass must pass model_path/io_config to create_calibration_dataloader() so
non-input columns (e.g. a label column) are stripped before reaching Quark.
"""
input_model = get_onnx_model()
config = {
"quant_mode": "static",
"quant_format": "QDQ",
"global_config": {
"activation": {"symmetric": False, "calibration_method": "MinMax", "data_type": "UInt8"},
"weight": {"symmetric": True, "calibration_method": "MinMax", "data_type": "Int8"},
},
"data_config": DataConfig(
name="test_quant_dc_config_with_label",
load_dataset_config=DataComponentConfig(type="simple_dataset"),
dataloader_config=DataComponentConfig(type="_test_quant_dataloader_with_label"),
),
}
p = create_pass_from_dict(QuarkQuantization, config, disable_search=True)
# Should complete without "Invalid input name: label" from Quark
out = p.run(input_model, tmp_path)
assert out is not None


def test_onnx_path_raises_clear_error_for_old_quark(tmp_path):
"""ONNX path must raise a clear ValueError, not an ImportError, for amd-quark < 0.12.0."""
input_model = get_onnx_model()
config = {"quant_mode": "static", "quant_format": "QDQ"}
p = create_pass_from_dict(QuarkQuantization, config, disable_search=True)
with patch("quark.__version__", "0.11.2"), pytest.raises(ValueError, match=r"amd-quark>=0\.12\.0"):
p.run(input_model, tmp_path)


def test_torch_path_raises_clear_error_for_old_quark(tmp_path):
"""Torch path must raise a clear ValueError, not an ImportError, for amd-quark < 0.12.0."""
from unittest.mock import MagicMock

from olive.model import HfModelHandler

# Use a MagicMock so handler construction/validation is bypassed.
# The version gate fires inside _run_quark_torch before any model loading.
input_model = MagicMock(spec=HfModelHandler)
config = {"quant_scheme": "uint4_wo_128"}
p = create_pass_from_dict(QuarkQuantization, config, disable_search=True)
with patch("quark.__version__", "0.11.2"), pytest.raises(ValueError, match=r"amd-quark>=0\.12\.0"):
p.run(input_model, tmp_path)