Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions olive/passes/onnx/discrepancy_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def _run_for_config(

# Measure inference speedup (ONNX vs PyTorch) on the target device
if config.timing_iterations > 0:
self._measure_speedup(
timing = self._measure_speedup(
ref_model,
session,
dataloader,
Expand All @@ -377,6 +377,11 @@ def _run_for_config(
config.warmup_iterations,
config.timing_iterations,
)
if timing is not None:
pytorch_time, onnx_time, speedup = timing
results["pytorch_latency_s"] = pytorch_time
results["onnx_latency_s"] = onnx_time
results["speedup"] = speedup
Comment thread
xadupre marked this conversation as resolved.
else:
logger.info(
"OnnxDiscrepancyCheck speedup measurement skipped because timing_iterations=%d.",
Expand Down Expand Up @@ -431,8 +436,13 @@ def _run_for_config(

def _measure_speedup(
self, ref_model, session, dataloader, io_config, torch_device, warmup_iterations, timing_iterations
):
"""Measure inference speedup of ONNX over PyTorch on the target device."""
) -> tuple[float, float, float] | None:
"""Measure inference latencies and speedup of ONNX over PyTorch on the target device.

Returns a tuple ``(pytorch_time, onnx_time, speedup)`` of the average PyTorch and ONNX
per-iteration latencies (in seconds) and the ONNX-over-PyTorch speedup, or ``None`` when
measurement is skipped.
"""
Comment thread
xadupre marked this conversation as resolved.
if timing_iterations <= 0:
logger.info(
"OnnxDiscrepancyCheck speedup measurement skipped because timing_iterations=%d.",
Expand Down Expand Up @@ -494,7 +504,7 @@ def _measure_speedup(
torch_device,
)

return speedup
return pytorch_time, onnx_time, speedup

def compare_generation(self, config: type[BasePassConfig], ref_model) -> int:
"""Run generation on both transformers and GenAI, return longest common token sequence length."""
Expand Down
29 changes: 29 additions & 0 deletions test/passes/onnx/test_discrepancy_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,3 +313,32 @@ def test_measure_speedup_skips_when_timing_iterations_is_zero(self):
assert result is None
ref_model.assert_not_called()
session.run.assert_not_called()

def test_measure_speedup_returns_latencies_and_speedup(self):
import torch

from olive.passes.onnx.discrepancy_check import OnnxDiscrepancyCheck

pass_instance = OnnxDiscrepancyCheck.__new__(OnnxDiscrepancyCheck)
ref_model = MagicMock()
session = MagicMock()
input_data = {"input_ids": torch.tensor([[1, 2, 3]], dtype=torch.int64)}
dataloader = [(input_data, None)]

with (
patch("olive.common.utils.format_data", return_value={"input_ids": [1, 2, 3]}),
patch("olive.passes.onnx.discrepancy_check.time.perf_counter", side_effect=[10.0, 14.0, 20.0, 22.0]),
):
result = pass_instance._measure_speedup(
ref_model=ref_model,
session=session,
dataloader=dataloader,
io_config=MagicMock(),
torch_device=torch.device("cpu"),
warmup_iterations=1,
timing_iterations=2,
)

assert result == (2.0, 1.0, 2.0)
Comment thread
xadupre marked this conversation as resolved.
assert ref_model.call_count == 3
assert session.run.call_count == 3
Loading