From 6b5b65236fc6c1b5681dfb72c9427d6815f63db3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 22 Jun 2026 16:35:33 +0200 Subject: [PATCH 1/5] extend command line --test to trigger speedup measure --- olive/cli/base.py | 45 +++++++++++++++++++++++++++++---- olive/cli/run.py | 4 ++- test/cli/test_base.py | 59 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 102 insertions(+), 6 deletions(-) diff --git a/olive/cli/base.py b/olive/cli/base.py index 50f1e55bf..fda595367 100644 --- a/olive/cli/base.py +++ b/olive/cli/base.py @@ -22,6 +22,9 @@ TEST_OUTPUT_MARKER_FILE = "olive_test_output.json" +# Metrics that --test can evaluate via the injected OnnxDiscrepancyCheck pass. +TEST_METRICS = ("mae", "speedup") + def _get_test_output_marker_path(output_path: str) -> Path: return Path(output_path) / TEST_OUTPUT_MARKER_FILE @@ -67,8 +70,19 @@ def mark_test_output_path(output_path: Optional[str]) -> None: _get_test_output_marker_path(output_path).write_text(json.dumps({"type": "olive_hf_test_output"}, indent=2)) -def add_discrepancy_check_pass(run_config: dict) -> dict: - """Inject OnnxDiscrepancyCheck pass when --test is active and not already configured.""" +def warn_unused_test_metrics(test, metrics: Optional[list]) -> None: + """Warn when --test_metrics is provided without --test, since it has no effect.""" + if metrics and test in (None, False): + logger.warning("--test_metrics is ignored because --test is not enabled.") + + +def add_discrepancy_check_pass(run_config: dict, metrics: Optional[list] = None) -> dict: + """Inject OnnxDiscrepancyCheck pass when --test is active and not already configured. + + ``metrics`` selects which test metrics to evaluate. Supported values are defined in + ``TEST_METRICS`` (``"mae"`` for the max-absolute-error accuracy check and ``"speedup"`` for the + ONNX-vs-PyTorch latency measurement). When ``None``, all metrics are evaluated. 
+ """ passes = run_config.get("passes", {}) # Skip if already configured for pass_config in passes.values(): @@ -86,12 +100,21 @@ def add_discrepancy_check_pass(run_config: dict) -> dict: if report_dir and Path(report_dir).suffix and not Path(report_dir).is_dir(): report_dir = str(Path(report_dir).parent) logger.debug("Adding OnnxDiscrepancyCheck pass with reference_model_path=%s", reference_model_path) - passes["discrepancy_check"] = { + + selected_metrics = set(metrics) if metrics else set(TEST_METRICS) + pass_config = { "type": "OnnxDiscrepancyCheck", "reference_model_path": reference_model_path, - "max_mae": 0.1, "report_output_dir": report_dir, } + # Enforce the max-absolute-error threshold only when the accuracy metric is requested. + if "mae" in selected_metrics: + pass_config["max_mae"] = 0.1 + # Disable the latency/speedup measurement when the speedup metric is not requested. + if "speedup" not in selected_metrics: + pass_config["timing_iterations"] = 0 + + passes["discrepancy_check"] = pass_config run_config["passes"] = passes return run_config @@ -135,12 +158,13 @@ def _run_workflow(self): from olive.workflows import run as olive_run validate_test_output_path(self.args.output_path, getattr(self.args, "test", None)) + warn_unused_test_metrics(getattr(self.args, "test", None), getattr(self.args, "test_metrics", None)) Path(self.args.output_path).mkdir(parents=True, exist_ok=True) with tempfile.TemporaryDirectory(prefix="olive-cli-tmp-", dir=self.args.output_path) as tempdir: run_config = self._get_run_config(tempdir) if getattr(self.args, "test", None) not in (None, False): - run_config = add_discrepancy_check_pass(run_config) + run_config = add_discrepancy_check_pass(run_config, getattr(self.args, "test_metrics", None)) if self.args.save_config_file or self.args.dry_run: self._save_config_file(run_config) if self.args.dry_run: @@ -505,6 +529,17 @@ def add_input_model_options( "Optionally provide a folder where the generated test model should be saved and 
reused." ), ) + model_group.add_argument( + "--test_metrics", + type=str, + nargs="+", + choices=list(TEST_METRICS), + help=( + "Metrics to evaluate during a --test run: 'mae' enforces the max absolute error between the " + "ONNX and reference model outputs, and 'speedup' measures ONNX-vs-PyTorch inference latency. " + "Defaults to all metrics. Only used together with --test." + ), + ) if enable_hf_adapter: assert enable_hf, "enable_hf must be True when enable_hf_adapter is True." diff --git a/olive/cli/run.py b/olive/cli/run.py index 8554ddbe4..7d173f3c0 100644 --- a/olive/cli/run.py +++ b/olive/cli/run.py @@ -15,6 +15,7 @@ mark_test_output_path, save_discrepancy_check_results, validate_test_output_path, + warn_unused_test_metrics, ) from olive.telemetry import action @@ -83,8 +84,9 @@ def run(self): output_path = run_config.get("output_dir") or run_config.get("engine", {}).get("output_dir") validate_test_output_path(output_path, self.args.test) + warn_unused_test_metrics(self.args.test, getattr(self.args, "test_metrics", None)) if self.args.test not in (None, False): - run_config = add_discrepancy_check_pass(run_config) + run_config = add_discrepancy_check_pass(run_config, getattr(self.args, "test_metrics", None)) workflow_output = olive_run( run_config, list_required_packages=self.args.list_required_packages, diff --git a/test/cli/test_base.py b/test/cli/test_base.py index bb34cef3f..b09fa96d5 100644 --- a/test/cli/test_base.py +++ b/test/cli/test_base.py @@ -338,3 +338,62 @@ def test_get_input_model_config_no_crash_without_onnx_file_name(tmp_path): # model_path should remain unchanged since no onnx_file_name to guide rewriting assert config["config"]["model_path"] == stale_model_path + + +def _discrepancy_run_config(): + return { + "input_model": {"type": "HfModel", "test_model_path": "ref_model"}, + "output_dir": "out_dir", + } + + +def test_add_discrepancy_check_pass_default_enables_all_metrics(): + from olive.cli.base import add_discrepancy_check_pass + + 
run_config = add_discrepancy_check_pass(_discrepancy_run_config()) + + pass_config = run_config["passes"]["discrepancy_check"] + assert pass_config["type"] == "OnnxDiscrepancyCheck" + assert pass_config["reference_model_path"] == "ref_model" + # mae metric -> threshold enforced; speedup metric -> timing not disabled + assert pass_config["max_mae"] == 0.1 + assert "timing_iterations" not in pass_config + + +def test_add_discrepancy_check_pass_speedup_only_disables_mae(): + from olive.cli.base import add_discrepancy_check_pass + + run_config = add_discrepancy_check_pass(_discrepancy_run_config(), metrics=["speedup"]) + + pass_config = run_config["passes"]["discrepancy_check"] + assert "max_mae" not in pass_config + assert "timing_iterations" not in pass_config + + +def test_add_discrepancy_check_pass_mae_only_disables_speedup(): + from olive.cli.base import add_discrepancy_check_pass + + run_config = add_discrepancy_check_pass(_discrepancy_run_config(), metrics=["mae"]) + + pass_config = run_config["passes"]["discrepancy_check"] + assert pass_config["max_mae"] == 0.1 + assert pass_config["timing_iterations"] == 0 + + +def test_warn_unused_test_metrics_logs_when_test_disabled(): + from olive.cli.base import warn_unused_test_metrics + + with patch("olive.cli.base.logger") as mock_logger: + warn_unused_test_metrics(test=None, metrics=["speedup"]) + + mock_logger.warning.assert_called_once() + assert "--test_metrics is ignored" in mock_logger.warning.call_args[0][0] + + +def test_warn_unused_test_metrics_silent_when_test_enabled(): + from olive.cli.base import warn_unused_test_metrics + + with patch("olive.cli.base.logger") as mock_logger: + warn_unused_test_metrics(test=True, metrics=["speedup"]) + + mock_logger.warning.assert_not_called() From 15287d831edc19e7e1529340cbcdb2f53fbc7de9 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 22 Jun 2026 14:47:04 +0000 Subject: [PATCH 2/5] Document --test_metrics 
speedup usage --- docs/source/how-to/cli/cli-fast-test.md | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/docs/source/how-to/cli/cli-fast-test.md b/docs/source/how-to/cli/cli-fast-test.md index 49fa13055..9557e5253 100644 --- a/docs/source/how-to/cli/cli-fast-test.md +++ b/docs/source/how-to/cli/cli-fast-test.md @@ -50,6 +50,23 @@ This is a quick way to confirm that: If you omit the folder and just pass `--test`, `olive run` will save the reduced model under `/test_model`. +### Optional: choose which `--test` metrics to run + +By default, `--test` evaluates only `mae`; `speedup` is opt-in. The available metrics are: + +- `mae`: maximum absolute error between the ONNX and reference model outputs +- `speedup`: ONNX-vs-PyTorch latency measurement + +You can choose which metrics run with `--test_metrics`. For example, to run only speedup checks: + +```bash +olive run \ + --config out/qwen/config.json \ + --test out/qwen-test-model \ + --test_metrics speedup \ + --output_path out/qwen-test-run +``` + ## Step 3: run the full conversion Once the smoke test succeeds, rerun the conversion on the full Qwen checkpoint by removing `--test`. From 89d98c43fbadbe1fa5cda302789ba0f4802ebc77 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 22 Jun 2026 15:24:19 +0000 Subject: [PATCH 3/5] Fix default test metrics to be mae-only, make speedup opt-in --- olive/cli/base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/olive/cli/base.py b/olive/cli/base.py index fda595367..82632ca1d 100644 --- a/olive/cli/base.py +++ b/olive/cli/base.py @@ -81,7 +81,8 @@ def add_discrepancy_check_pass(run_config: dict, metrics: Optional[list] = None) ``metrics`` selects which test metrics to evaluate. Supported values are defined in ``TEST_METRICS`` (``"mae"`` for the max-absolute-error accuracy check and ``"speedup"`` for the ONNX-vs-PyTorch latency measurement). When ``None``, all metrics are evaluated. + ONNX-vs-PyTorch latency measurement). 
When ``None``, only ``"mae"`` is evaluated; pass + ``["speedup"]`` or ``["mae", "speedup"]`` explicitly to enable timing. """ passes = run_config.get("passes", {}) # Skip if already configured @@ -101,7 +102,7 @@ def add_discrepancy_check_pass(run_config: dict, metrics: Optional[list] = None) report_dir = str(Path(report_dir).parent) logger.debug("Adding OnnxDiscrepancyCheck pass with reference_model_path=%s", reference_model_path) - selected_metrics = set(metrics) if metrics else set(TEST_METRICS) + selected_metrics = set(metrics) if metrics else {"mae"} pass_config = { "type": "OnnxDiscrepancyCheck", "reference_model_path": reference_model_path, From 7490e559c01f4aad3e8511c9e4acb6dfd897d085 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 22 Jun 2026 15:58:38 +0000 Subject: [PATCH 4/5] Fix test to match new default mae-only behavior --- test/cli/test_base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/cli/test_base.py b/test/cli/test_base.py index b09fa96d5..d4ddffa89 100644 --- a/test/cli/test_base.py +++ b/test/cli/test_base.py @@ -347,7 +347,7 @@ def _discrepancy_run_config(): } -def test_add_discrepancy_check_pass_default_enables_all_metrics(): +def test_add_discrepancy_check_pass_default_enables_mae_only(): from olive.cli.base import add_discrepancy_check_pass run_config = add_discrepancy_check_pass(_discrepancy_run_config()) @@ -355,9 +355,9 @@ def test_add_discrepancy_check_pass_default_enables_all_metrics(): pass_config = run_config["passes"]["discrepancy_check"] assert pass_config["type"] == "OnnxDiscrepancyCheck" assert pass_config["reference_model_path"] == "ref_model" - # mae metric -> threshold enforced; speedup metric -> timing not disabled + # default: mae only -> threshold enforced, timing disabled assert pass_config["max_mae"] == 0.1 - assert "timing_iterations" not in pass_config + assert pass_config["timing_iterations"] == 0 def 
test_add_discrepancy_check_pass_speedup_only_disables_mae(): From bf96e3f5f6b3961e1f2a284a47e1e004a160bc1f Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 22 Jun 2026 17:15:56 +0000 Subject: [PATCH 5/5] Fix test_cli.py expected pass config to include timing_iterations=0 --- test/cli/test_cli.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/cli/test_cli.py b/test/cli/test_cli.py index 59817d830..5e666b94f 100644 --- a/test/cli/test_cli.py +++ b/test/cli/test_cli.py @@ -189,8 +189,9 @@ def test_workflow_run_command_with_test_override(mock_run, tmp_path): "discrepancy_check": { "type": "OnnxDiscrepancyCheck", "reference_model_path": test_model_path, - "max_mae": 0.1, "report_output_dir": output_dir, + "max_mae": 0.1, + "timing_iterations": 0, } }, },