diff --git a/docs/source/how-to/cli/cli-fast-test.md b/docs/source/how-to/cli/cli-fast-test.md index 49fa13055..9557e5253 100644 --- a/docs/source/how-to/cli/cli-fast-test.md +++ b/docs/source/how-to/cli/cli-fast-test.md @@ -50,6 +50,23 @@ This is a quick way to confirm that: If you omit the folder and just pass `--test`, `olive run` will save the reduced model under `/test_model`. +### Optional: choose which `--test` metrics to run + +By default, `--test` evaluates only `mae`. The available metrics are: + +- `mae`: maximum absolute error between the ONNX and reference model outputs +- `speedup`: ONNX-vs-PyTorch latency measurement + +Use `--test_metrics` to choose which of these metrics run. For example, to run only speedup checks: + +```bash +olive run \ + --config out/qwen/config.json \ + --test out/qwen-test-model \ + --test_metrics speedup \ + --output_path out/qwen-test-run +``` + ## Step 3: run the full conversion Once the smoke test succeeds, rerun the conversion on the full Qwen checkpoint by removing `--test`. diff --git a/olive/cli/base.py b/olive/cli/base.py index 50f1e55bf..82632ca1d 100644 --- a/olive/cli/base.py +++ b/olive/cli/base.py @@ -22,6 +22,9 @@ TEST_OUTPUT_MARKER_FILE = "olive_test_output.json" +# Metrics that --test can evaluate via the injected OnnxDiscrepancyCheck pass. 
+TEST_METRICS = ("mae", "speedup") + def _get_test_output_marker_path(output_path: str) -> Path: return Path(output_path) / TEST_OUTPUT_MARKER_FILE @@ -67,8 +70,20 @@ def mark_test_output_path(output_path: Optional[str]) -> None: _get_test_output_marker_path(output_path).write_text(json.dumps({"type": "olive_hf_test_output"}, indent=2)) -def add_discrepancy_check_pass(run_config: dict) -> dict: - """Inject OnnxDiscrepancyCheck pass when --test is active and not already configured.""" +def warn_unused_test_metrics(test, metrics: Optional[list]) -> None: + """Warn when --test_metrics is provided without --test, since it has no effect.""" + if metrics and test in (None, False): + logger.warning("--test_metrics is ignored because --test is not enabled.") + + +def add_discrepancy_check_pass(run_config: dict, metrics: Optional[list] = None) -> dict: + """Inject OnnxDiscrepancyCheck pass when --test is active and not already configured. + + ``metrics`` selects which test metrics to evaluate. Supported values are defined in + ``TEST_METRICS`` (``"mae"`` for the max-absolute-error accuracy check and ``"speedup"`` for the + ONNX-vs-PyTorch latency measurement). When ``None``, only ``"mae"`` is evaluated; pass + ``["speedup"]`` or ``["mae", "speedup"]`` explicitly to enable timing. 
+ """ passes = run_config.get("passes", {}) # Skip if already configured for pass_config in passes.values(): @@ -86,12 +101,21 @@ def add_discrepancy_check_pass(run_config: dict) -> dict: if report_dir and Path(report_dir).suffix and not Path(report_dir).is_dir(): report_dir = str(Path(report_dir).parent) logger.debug("Adding OnnxDiscrepancyCheck pass with reference_model_path=%s", reference_model_path) - passes["discrepancy_check"] = { + + selected_metrics = set(metrics) if metrics else {"mae"} + pass_config = { "type": "OnnxDiscrepancyCheck", "reference_model_path": reference_model_path, - "max_mae": 0.1, "report_output_dir": report_dir, } + # Enforce the max-absolute-error threshold only when the accuracy metric is requested. + if "mae" in selected_metrics: + pass_config["max_mae"] = 0.1 + # Disable the latency/speedup measurement when the speedup metric is not requested. + if "speedup" not in selected_metrics: + pass_config["timing_iterations"] = 0 + + passes["discrepancy_check"] = pass_config run_config["passes"] = passes return run_config @@ -135,12 +159,13 @@ def _run_workflow(self): from olive.workflows import run as olive_run validate_test_output_path(self.args.output_path, getattr(self.args, "test", None)) + warn_unused_test_metrics(getattr(self.args, "test", None), getattr(self.args, "test_metrics", None)) Path(self.args.output_path).mkdir(parents=True, exist_ok=True) with tempfile.TemporaryDirectory(prefix="olive-cli-tmp-", dir=self.args.output_path) as tempdir: run_config = self._get_run_config(tempdir) if getattr(self.args, "test", None) not in (None, False): - run_config = add_discrepancy_check_pass(run_config) + run_config = add_discrepancy_check_pass(run_config, getattr(self.args, "test_metrics", None)) if self.args.save_config_file or self.args.dry_run: self._save_config_file(run_config) if self.args.dry_run: @@ -505,6 +530,17 @@ def add_input_model_options( "Optionally provide a folder where the generated test model should be saved and reused." 
), ) + model_group.add_argument( + "--test_metrics", + type=str, + nargs="+", + choices=list(TEST_METRICS), + help=( + "Metrics to evaluate during a --test run: 'mae' enforces the max absolute error between the " + "ONNX and reference model outputs, and 'speedup' measures ONNX-vs-PyTorch inference latency. " + "Defaults to all metrics. Only used together with --test." + ), + ) if enable_hf_adapter: assert enable_hf, "enable_hf must be True when enable_hf_adapter is True." diff --git a/olive/cli/run.py b/olive/cli/run.py index 8554ddbe4..7d173f3c0 100644 --- a/olive/cli/run.py +++ b/olive/cli/run.py @@ -15,6 +15,7 @@ mark_test_output_path, save_discrepancy_check_results, validate_test_output_path, + warn_unused_test_metrics, ) from olive.telemetry import action @@ -83,8 +84,9 @@ def run(self): output_path = run_config.get("output_dir") or run_config.get("engine", {}).get("output_dir") validate_test_output_path(output_path, self.args.test) + warn_unused_test_metrics(self.args.test, getattr(self.args, "test_metrics", None)) if self.args.test not in (None, False): - run_config = add_discrepancy_check_pass(run_config) + run_config = add_discrepancy_check_pass(run_config, getattr(self.args, "test_metrics", None)) workflow_output = olive_run( run_config, list_required_packages=self.args.list_required_packages, diff --git a/test/cli/test_base.py b/test/cli/test_base.py index bb34cef3f..d4ddffa89 100644 --- a/test/cli/test_base.py +++ b/test/cli/test_base.py @@ -338,3 +338,62 @@ def test_get_input_model_config_no_crash_without_onnx_file_name(tmp_path): # model_path should remain unchanged since no onnx_file_name to guide rewriting assert config["config"]["model_path"] == stale_model_path + + +def _discrepancy_run_config(): + return { + "input_model": {"type": "HfModel", "test_model_path": "ref_model"}, + "output_dir": "out_dir", + } + + +def test_add_discrepancy_check_pass_default_enables_mae_only(): + from olive.cli.base import add_discrepancy_check_pass + + run_config = 
add_discrepancy_check_pass(_discrepancy_run_config()) + + pass_config = run_config["passes"]["discrepancy_check"] + assert pass_config["type"] == "OnnxDiscrepancyCheck" + assert pass_config["reference_model_path"] == "ref_model" + # default: mae only -> threshold enforced, timing disabled + assert pass_config["max_mae"] == 0.1 + assert pass_config["timing_iterations"] == 0 + + +def test_add_discrepancy_check_pass_speedup_only_disables_mae(): + from olive.cli.base import add_discrepancy_check_pass + + run_config = add_discrepancy_check_pass(_discrepancy_run_config(), metrics=["speedup"]) + + pass_config = run_config["passes"]["discrepancy_check"] + assert "max_mae" not in pass_config + assert "timing_iterations" not in pass_config + + +def test_add_discrepancy_check_pass_mae_only_disables_speedup(): + from olive.cli.base import add_discrepancy_check_pass + + run_config = add_discrepancy_check_pass(_discrepancy_run_config(), metrics=["mae"]) + + pass_config = run_config["passes"]["discrepancy_check"] + assert pass_config["max_mae"] == 0.1 + assert pass_config["timing_iterations"] == 0 + + +def test_warn_unused_test_metrics_logs_when_test_disabled(): + from olive.cli.base import warn_unused_test_metrics + + with patch("olive.cli.base.logger") as mock_logger: + warn_unused_test_metrics(test=None, metrics=["speedup"]) + + mock_logger.warning.assert_called_once() + assert "--test_metrics is ignored" in mock_logger.warning.call_args[0][0] + + +def test_warn_unused_test_metrics_silent_when_test_enabled(): + from olive.cli.base import warn_unused_test_metrics + + with patch("olive.cli.base.logger") as mock_logger: + warn_unused_test_metrics(test=True, metrics=["speedup"]) + + mock_logger.warning.assert_not_called() diff --git a/test/cli/test_cli.py b/test/cli/test_cli.py index 59817d830..5e666b94f 100644 --- a/test/cli/test_cli.py +++ b/test/cli/test_cli.py @@ -189,8 +189,9 @@ def test_workflow_run_command_with_test_override(mock_run, tmp_path): "discrepancy_check": { 
"type": "OnnxDiscrepancyCheck", "reference_model_path": test_model_path, - "max_mae": 0.1, "report_output_dir": output_dir, + "max_mae": 0.1, + "timing_iterations": 0, } }, },