Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions docs/source/how-to/cli/cli-fast-test.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,23 @@ This is a quick way to confirm that:

If you omit the folder and just pass `--test`, `olive run` will save the reduced model under `<output_path>/test_model`.

### Optional: choose which `--test` metrics to run

`--test` supports the following metrics. By default only `mae` is evaluated; `speedup` must be requested explicitly:

- `mae`: maximum absolute error between the ONNX and reference model outputs
- `speedup`: ONNX-vs-PyTorch latency measurement

Use `--test_metrics` to choose which of these metrics are evaluated. For example, to run only the speedup measurement:

```bash
olive run \
--config out/qwen/config.json \
--test out/qwen-test-model \
--test_metrics speedup \
--output_path out/qwen-test-run
```

## Step 3: run the full conversion

Once the smoke test succeeds, rerun the conversion on the full Qwen checkpoint by removing `--test`.
Expand Down
46 changes: 41 additions & 5 deletions olive/cli/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@

# File name of the JSON marker that is written inside an output directory
# (see _get_test_output_marker_path / mark_test_output_path).
TEST_OUTPUT_MARKER_FILE = "olive_test_output.json"

# Metrics that --test can evaluate via the injected OnnxDiscrepancyCheck pass.
# Also used as the allowed choices of the --test_metrics CLI option.
TEST_METRICS = ("mae", "speedup")

Comment thread
xadupre marked this conversation as resolved.

def _get_test_output_marker_path(output_path: str) -> Path:
    """Return the path of the test-output marker file located directly under ``output_path``."""
    output_dir = Path(output_path)
    return output_dir / TEST_OUTPUT_MARKER_FILE
Expand Down Expand Up @@ -67,8 +70,20 @@ def mark_test_output_path(output_path: Optional[str]) -> None:
_get_test_output_marker_path(output_path).write_text(json.dumps({"type": "olive_hf_test_output"}, indent=2))


def add_discrepancy_check_pass(run_config: dict) -> dict:
"""Inject OnnxDiscrepancyCheck pass when --test is active and not already configured."""
def warn_unused_test_metrics(test, metrics: Optional[list]) -> None:
    """Log a warning when test metrics were requested but test mode is off.

    Args:
        test: Value of the ``--test`` option; ``None`` or ``False`` means test mode is disabled.
        metrics: Value of the ``--test_metrics`` option, or ``None``/empty when not provided.
    """
    # Nothing was requested, so there is nothing that could be ignored.
    if not metrics:
        return
    if test in (None, False):
        logger.warning("--test_metrics is ignored because --test is not enabled.")


def add_discrepancy_check_pass(run_config: dict, metrics: Optional[list] = None) -> dict:
"""Inject OnnxDiscrepancyCheck pass when --test is active and not already configured.

``metrics`` selects which test metrics to evaluate. Supported values are defined in
``TEST_METRICS`` (``"mae"`` for the max-absolute-error accuracy check and ``"speedup"`` for the
ONNX-vs-PyTorch latency measurement). When ``None``, only ``"mae"`` is evaluated; pass
``["speedup"]`` or ``["mae", "speedup"]`` explicitly to enable timing.
"""
passes = run_config.get("passes", {})
# Skip if already configured
for pass_config in passes.values():
Expand All @@ -86,12 +101,21 @@ def add_discrepancy_check_pass(run_config: dict) -> dict:
if report_dir and Path(report_dir).suffix and not Path(report_dir).is_dir():
report_dir = str(Path(report_dir).parent)
logger.debug("Adding OnnxDiscrepancyCheck pass with reference_model_path=%s", reference_model_path)
passes["discrepancy_check"] = {

selected_metrics = set(metrics) if metrics else {"mae"}
pass_config = {
"type": "OnnxDiscrepancyCheck",
"reference_model_path": reference_model_path,
"max_mae": 0.1,
"report_output_dir": report_dir,
}
# Enforce the max-absolute-error threshold only when the accuracy metric is requested.
if "mae" in selected_metrics:
pass_config["max_mae"] = 0.1
# Disable the latency/speedup measurement when the speedup metric is not requested.
if "speedup" not in selected_metrics:
pass_config["timing_iterations"] = 0

passes["discrepancy_check"] = pass_config
run_config["passes"] = passes
return run_config

Expand Down Expand Up @@ -135,12 +159,13 @@ def _run_workflow(self):
from olive.workflows import run as olive_run

validate_test_output_path(self.args.output_path, getattr(self.args, "test", None))
warn_unused_test_metrics(getattr(self.args, "test", None), getattr(self.args, "test_metrics", None))
Path(self.args.output_path).mkdir(parents=True, exist_ok=True)

with tempfile.TemporaryDirectory(prefix="olive-cli-tmp-", dir=self.args.output_path) as tempdir:
run_config = self._get_run_config(tempdir)
if getattr(self.args, "test", None) not in (None, False):
run_config = add_discrepancy_check_pass(run_config)
run_config = add_discrepancy_check_pass(run_config, getattr(self.args, "test_metrics", None))
if self.args.save_config_file or self.args.dry_run:
self._save_config_file(run_config)
if self.args.dry_run:
Expand Down Expand Up @@ -505,6 +530,17 @@ def add_input_model_options(
"Optionally provide a folder where the generated test model should be saved and reused."
),
)
# NOTE: keep the documented default in sync with add_discrepancy_check_pass, which
# falls back to {"mae"} when no metrics are supplied (timing is then disabled).
model_group.add_argument(
    "--test_metrics",
    type=str,
    nargs="+",
    choices=list(TEST_METRICS),
    help=(
        "Metrics to evaluate during a --test run: 'mae' enforces the max absolute error between the "
        "ONNX and reference model outputs, and 'speedup' measures ONNX-vs-PyTorch inference latency. "
        "Defaults to 'mae' only; pass 'speedup' explicitly to enable the latency measurement. "
        "Only used together with --test."
    ),
)

if enable_hf_adapter:
assert enable_hf, "enable_hf must be True when enable_hf_adapter is True."
Expand Down
4 changes: 3 additions & 1 deletion olive/cli/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
mark_test_output_path,
save_discrepancy_check_results,
validate_test_output_path,
warn_unused_test_metrics,
)
from olive.telemetry import action

Expand Down Expand Up @@ -83,8 +84,9 @@ def run(self):

output_path = run_config.get("output_dir") or run_config.get("engine", {}).get("output_dir")
validate_test_output_path(output_path, self.args.test)
warn_unused_test_metrics(self.args.test, getattr(self.args, "test_metrics", None))
if self.args.test not in (None, False):
run_config = add_discrepancy_check_pass(run_config)
run_config = add_discrepancy_check_pass(run_config, getattr(self.args, "test_metrics", None))
workflow_output = olive_run(
run_config,
list_required_packages=self.args.list_required_packages,
Expand Down
59 changes: 59 additions & 0 deletions test/cli/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,3 +338,62 @@ def test_get_input_model_config_no_crash_without_onnx_file_name(tmp_path):

# model_path should remain unchanged since no onnx_file_name to guide rewriting
assert config["config"]["model_path"] == stale_model_path


def _discrepancy_run_config():
return {
"input_model": {"type": "HfModel", "test_model_path": "ref_model"},
"output_dir": "out_dir",
}


def test_add_discrepancy_check_pass_default_enables_mae_only():
    """Without explicit metrics the injected pass enforces the MAE threshold and disables timing."""
    from olive.cli.base import add_discrepancy_check_pass

    injected = add_discrepancy_check_pass(_discrepancy_run_config())["passes"]["discrepancy_check"]

    expected_entries = {
        "type": "OnnxDiscrepancyCheck",
        "reference_model_path": "ref_model",
        # default: mae only -> threshold enforced, timing disabled
        "max_mae": 0.1,
        "timing_iterations": 0,
    }
    for key, expected_value in expected_entries.items():
        assert injected[key] == expected_value


def test_add_discrepancy_check_pass_speedup_only_disables_mae():
    """Requesting only ``speedup`` sets neither the MAE threshold nor the timing override."""
    from olive.cli.base import add_discrepancy_check_pass

    updated_config = add_discrepancy_check_pass(_discrepancy_run_config(), metrics=["speedup"])
    injected = updated_config["passes"]["discrepancy_check"]

    # No threshold (mae not requested) and no timing override (speedup requested).
    for absent_key in ("max_mae", "timing_iterations"):
        assert absent_key not in injected


def test_add_discrepancy_check_pass_mae_only_disables_speedup():
    """Requesting only ``mae`` keeps the threshold and switches the timing measurement off."""
    from olive.cli.base import add_discrepancy_check_pass

    updated_config = add_discrepancy_check_pass(_discrepancy_run_config(), metrics=["mae"])
    injected = updated_config["passes"]["discrepancy_check"]

    for key, expected_value in (("max_mae", 0.1), ("timing_iterations", 0)):
        assert injected[key] == expected_value


def test_warn_unused_test_metrics_logs_when_test_disabled():
    """A single warning is logged when metrics are given while ``--test`` is off."""
    from olive.cli.base import warn_unused_test_metrics

    with patch("olive.cli.base.logger") as mock_logger:
        warn_unused_test_metrics(test=None, metrics=["speedup"])

    warning_mock = mock_logger.warning
    warning_mock.assert_called_once()
    # First positional argument of the recorded call is the log message.
    logged_message = warning_mock.call_args[0][0]
    assert "--test_metrics is ignored" in logged_message


def test_warn_unused_test_metrics_silent_when_test_enabled():
    """No warning is logged when metrics are given together with an enabled ``--test``."""
    from olive.cli.base import warn_unused_test_metrics

    with patch("olive.cli.base.logger") as mock_logger:
        warn_unused_test_metrics(test=True, metrics=["speedup"])

    warning_mock = mock_logger.warning
    warning_mock.assert_not_called()
3 changes: 2 additions & 1 deletion test/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,9 @@ def test_workflow_run_command_with_test_override(mock_run, tmp_path):
"discrepancy_check": {
"type": "OnnxDiscrepancyCheck",
"reference_model_path": test_model_path,
"max_mae": 0.1,
"report_output_dir": output_dir,
"max_mae": 0.1,
"timing_iterations": 0,
}
},
},
Expand Down
Loading