From b21b4f68279308eb781aa78b6136cdc193a3a272 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 22 Jun 2026 13:45:34 +0200 Subject: [PATCH 1/3] expose latencies with the speedup in OnnxDiscrepancyCheck --- olive/passes/onnx/discrepancy_check.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/olive/passes/onnx/discrepancy_check.py b/olive/passes/onnx/discrepancy_check.py index 661100d1c..f4ce31210 100644 --- a/olive/passes/onnx/discrepancy_check.py +++ b/olive/passes/onnx/discrepancy_check.py @@ -293,7 +293,7 @@ def _run_for_config( # Measure inference speedup (ONNX vs PyTorch) on the target device if config.timing_iterations > 0: - self._measure_speedup( + timing = self._measure_speedup( ref_model, session, dataloader, @@ -302,6 +302,11 @@ def _run_for_config( config.warmup_iterations, config.timing_iterations, ) + if timing is not None: + pytorch_time, onnx_time, speedup = timing + results["pytorch_latency_s"] = pytorch_time + results["onnx_latency_s"] = onnx_time + results["speedup"] = speedup else: logger.info( "OnnxDiscrepancyCheck speedup measurement skipped because timing_iterations=%d.", @@ -357,7 +362,12 @@ def _run_for_config( def _measure_speedup( self, ref_model, session, dataloader, io_config, torch_device, warmup_iterations, timing_iterations ): - """Measure inference speedup of ONNX over PyTorch on the target device.""" + """Measure inference latencies and speedup of ONNX over PyTorch on the target device. + + Returns a tuple ``(pytorch_time, onnx_time, speedup)`` of the average PyTorch and ONNX + per-iteration latencies (in seconds) and the ONNX-over-PyTorch speedup, or ``None`` when + measurement is skipped. + """ if timing_iterations <= 0: logger.info( "OnnxDiscrepancyCheck speedup measurement skipped because timing_iterations=%d.", @@ -419,7 +429,7 @@ def _measure_speedup( torch_device, ) - return speedup + return pytorch_time, onnx_time, speedup def compare_generation(self, config: type[BasePassConfig], ref_model) -> int: """Run generation on both transformers and GenAI, return longest common token sequence length.""" From 72c5d4e40584ef808c977694dfd5da542cae8ef0 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 22 Jun 2026 11:58:34 +0000 Subject: [PATCH 2/3] Extend discrepancy check unit test for latency tuple --- test/passes/onnx/test_discrepancy_check.py | 29 ++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/test/passes/onnx/test_discrepancy_check.py b/test/passes/onnx/test_discrepancy_check.py index c6a2f83eb..31e738f6f 100644 --- a/test/passes/onnx/test_discrepancy_check.py +++ b/test/passes/onnx/test_discrepancy_check.py @@ -192,3 +192,32 @@ def test_measure_speedup_skips_when_timing_iterations_is_zero(self): assert result is None ref_model.assert_not_called() session.run.assert_not_called() + + def test_measure_speedup_returns_latencies_and_speedup(self): + import torch + + from olive.passes.onnx.discrepancy_check import OnnxDiscrepancyCheck + + pass_instance = OnnxDiscrepancyCheck.__new__(OnnxDiscrepancyCheck) + ref_model = MagicMock() + session = MagicMock() + input_data = {"input_ids": torch.tensor([[1, 2, 3]], dtype=torch.int64)} + dataloader = [(input_data, None)] + + with ( + patch("olive.common.utils.format_data", return_value={"input_ids": [1, 2, 3]}), + patch("olive.passes.onnx.discrepancy_check.time.perf_counter", side_effect=[10.0, 14.0, 20.0, 22.0]), + ): + result = pass_instance._measure_speedup( + ref_model=ref_model, + session=session, + dataloader=dataloader, + io_config=MagicMock(), + torch_device=torch.device("cpu"), + warmup_iterations=1, + timing_iterations=2, + ) + + assert result == (2.0, 1.0, 2.0) + assert ref_model.call_count == 3 + assert session.run.call_count == 3 From 804bb92762c4d8de35064e084cf16c9af09e272a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 22 Jun 2026 12:43:04 +0000 Subject: [PATCH 3/3] Add return type annotation to _measure_speedup --- olive/passes/onnx/discrepancy_check.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/olive/passes/onnx/discrepancy_check.py b/olive/passes/onnx/discrepancy_check.py index f4ce31210..4e0a603f4 100644 --- a/olive/passes/onnx/discrepancy_check.py +++ b/olive/passes/onnx/discrepancy_check.py @@ -361,7 +361,7 @@ def _run_for_config( def _measure_speedup( self, ref_model, session, dataloader, io_config, torch_device, warmup_iterations, timing_iterations - ): + ) -> tuple[float, float, float] | None: """Measure inference latencies and speedup of ONNX over PyTorch on the target device. Returns a tuple ``(pytorch_time, onnx_time, speedup)`` of the average PyTorch and ONNX