diff --git a/olive/passes/onnx/discrepancy_check.py b/olive/passes/onnx/discrepancy_check.py index a9bb7e1ac..321758f03 100644 --- a/olive/passes/onnx/discrepancy_check.py +++ b/olive/passes/onnx/discrepancy_check.py @@ -368,7 +368,7 @@ def _run_for_config( # Measure inference speedup (ONNX vs PyTorch) on the target device if config.timing_iterations > 0: - self._measure_speedup( + timing = self._measure_speedup( ref_model, session, dataloader, @@ -377,6 +377,11 @@ def _run_for_config( config.warmup_iterations, config.timing_iterations, ) + if timing is not None: + pytorch_time, onnx_time, speedup = timing + results["pytorch_latency_s"] = pytorch_time + results["onnx_latency_s"] = onnx_time + results["speedup"] = speedup else: logger.info( "OnnxDiscrepancyCheck speedup measurement skipped because timing_iterations=%d.", @@ -431,8 +436,13 @@ def _run_for_config( def _measure_speedup( self, ref_model, session, dataloader, io_config, torch_device, warmup_iterations, timing_iterations - ): - """Measure inference speedup of ONNX over PyTorch on the target device.""" + ) -> tuple[float, float, float] | None: + """Measure inference latencies and speedup of ONNX over PyTorch on the target device. + + Returns a tuple ``(pytorch_time, onnx_time, speedup)`` of the average PyTorch and ONNX + per-iteration latencies (in seconds) and the ONNX-over-PyTorch speedup, or ``None`` when + measurement is skipped. + """ if timing_iterations <= 0: logger.info( "OnnxDiscrepancyCheck speedup measurement skipped because timing_iterations=%d.", @@ -494,7 +504,7 @@ def _measure_speedup( torch_device, ) - return speedup + return pytorch_time, onnx_time, speedup def compare_generation(self, config: type[BasePassConfig], ref_model) -> int: """Run generation on both transformers and GenAI, return longest common token sequence length.""" diff --git a/test/passes/onnx/test_discrepancy_check.py b/test/passes/onnx/test_discrepancy_check.py index 8bb53a966..79dde628d 100644 --- a/test/passes/onnx/test_discrepancy_check.py +++ b/test/passes/onnx/test_discrepancy_check.py @@ -313,3 +313,32 @@ def test_measure_speedup_skips_when_timing_iterations_is_zero(self): assert result is None ref_model.assert_not_called() session.run.assert_not_called() + + def test_measure_speedup_returns_latencies_and_speedup(self): + import torch + + from olive.passes.onnx.discrepancy_check import OnnxDiscrepancyCheck + + pass_instance = OnnxDiscrepancyCheck.__new__(OnnxDiscrepancyCheck) + ref_model = MagicMock() + session = MagicMock() + input_data = {"input_ids": torch.tensor([[1, 2, 3]], dtype=torch.int64)} + dataloader = [(input_data, None)] + + with ( + patch("olive.common.utils.format_data", return_value={"input_ids": [1, 2, 3]}), + patch("olive.passes.onnx.discrepancy_check.time.perf_counter", side_effect=[10.0, 14.0, 20.0, 22.0]), + ): + result = pass_instance._measure_speedup( + ref_model=ref_model, + session=session, + dataloader=dataloader, + io_config=MagicMock(), + torch_device=torch.device("cpu"), + warmup_iterations=1, + timing_iterations=2, + ) + + assert result == (2.0, 1.0, 2.0) + assert ref_model.call_count == 3 + assert session.run.call_count == 3