From bf0a9781d113cdc2b088f305f123ab3e8f4f3a4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 22 Jun 2026 14:02:13 +0200 Subject: [PATCH 1/4] add time to first token in OnnxDiscrepancyCheck --- olive/passes/onnx/discrepancy_check.py | 89 ++++++++++++++++++---- test/passes/onnx/test_discrepancy_check.py | 16 +++- 2 files changed, 89 insertions(+), 16 deletions(-) diff --git a/olive/passes/onnx/discrepancy_check.py b/olive/passes/onnx/discrepancy_check.py index 661100d1c..d4fd2e194 100644 --- a/olive/passes/onnx/discrepancy_check.py +++ b/olive/passes/onnx/discrepancy_check.py @@ -56,6 +56,11 @@ def _longest_common_token_sequence(seq_a: list[int], seq_b: list[int]) -> int: return length +def _format_seconds(value: Optional[float]) -> str: + """Format an optional latency value (in seconds) for logging.""" + return "n/a" if value is None else f"{value:.4f}s" + + class OnnxDiscrepancyCheck(Pass): """Validates ONNX model outputs against a reference PyTorch model. @@ -68,6 +73,8 @@ class OnnxDiscrepancyCheck(Pass): - Inference speedup of ONNX over PyTorch on the target device (or CPU fallback) - Longest common token sequence from the beginning between transformers generate and ONNX Runtime GenAI generate (when enabled) + - Time-to-first-token and time-to-first-N-tokens latencies for both transformers + and ONNX Runtime GenAI generation (when enabled) The pass status is marked as failed if any configured threshold is exceeded. """ @@ -151,6 +158,14 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon default_value=32, description="Maximum number of new tokens to generate for the token sequence comparison.", ), + "time_to_first_n_tokens": PassConfigParam( + type_=int, + default_value=5, + description=( + "Number of leading generated tokens used for the time-to-first-N-tokens latency " + "measurement reported for both transformers and ONNX Runtime GenAI." + ), + ), "min_longest_common_tokens": PassConfigParam( type_=Optional[int], default_value=None, @@ -331,8 +346,9 @@ def _run_for_config( # Generation token sequence comparison (transformers vs ONNX Runtime GenAI) if config.genai_model_path: - longest_common = self.compare_generation(config, ref_model) - results["longest_common_token_sequence"] = longest_common + gen_results = self.compare_generation(config, ref_model) + longest_common = gen_results["longest_common_token_sequence"] + results.update(gen_results) results["genai_model_path"] = config.genai_model_path if config.min_longest_common_tokens is not None and longest_common < config.min_longest_common_tokens: results["status"] = "failed" @@ -421,8 +437,13 @@ def _measure_speedup( return speedup - def compare_generation(self, config: type[BasePassConfig], ref_model) -> int: - """Run generation on both transformers and GenAI, return longest common token sequence length.""" + def compare_generation(self, config: type[BasePassConfig], ref_model) -> dict: + """Run generation on both transformers and GenAI and compare them. + + Returns a dict with the longest common token sequence length and the time-to-first-token + and time-to-first-N-tokens latencies (in seconds) for both transformers and ONNX Runtime + GenAI, where N is ``config.time_to_first_n_tokens``. + """ try: import onnxruntime_genai as og except ImportError as exc: @@ -431,17 +452,35 @@ def compare_generation(self, config: type[BasePassConfig], ref_model) -> int: tokenizer = AutoTokenizer.from_pretrained(config.reference_model_path) + max_new_tokens = config.generate_max_new_tokens + first_n = max(1, min(config.time_to_first_n_tokens, max_new_tokens)) + # Transformers generation input_ids = tokenizer(config.generate_prompt, return_tensors="pt").input_ids - input_ids = input_ids.to(ref_model.device) import torch - with torch.no_grad(): - transformers_output = ref_model.generate( - input_ids, - max_new_tokens=config.generate_max_new_tokens, - do_sample=False, - ) + input_ids = input_ids.to(ref_model.device) + use_cuda_sync = ref_model.device.type == "cuda" + + def _time_transformers_generate(num_new_tokens): + with torch.no_grad(): + if use_cuda_sync: + torch.cuda.synchronize() + start = time.perf_counter() + output = ref_model.generate( + input_ids, + max_new_tokens=num_new_tokens, + do_sample=False, + ) + if use_cuda_sync: + torch.cuda.synchronize() + elapsed = time.perf_counter() - start + return output, elapsed + + # Time to first token and time to first N tokens (separate timed runs). + _, transformers_ttft = _time_transformers_generate(1) + _, transformers_ttfn = _time_transformers_generate(first_n) + transformers_output, _ = _time_transformers_generate(max_new_tokens) transformers_tokens = transformers_output[0].cpu().tolist() # ONNX Runtime GenAI generation @@ -450,26 +489,48 @@ def compare_generation(self, config: type[BasePassConfig], ref_model) -> int: genai_input_ids = genai_tokenizer.encode(config.generate_prompt) params = og.GeneratorParams(genai_model) - params.set_search_options(max_length=len(genai_input_ids) + config.generate_max_new_tokens, do_sample=False) + params.set_search_options(max_length=len(genai_input_ids) + max_new_tokens, do_sample=False) generator = og.Generator(genai_model, params) generator.append_tokens([genai_input_ids]) genai_tokens = list(genai_input_ids) + genai_ttft = None + genai_ttfn = None + num_generated = 0 + start = time.perf_counter() while not generator.is_done(): generator.generate_next_token() genai_tokens.append(generator.get_next_tokens()[0]) + num_generated += 1 + if num_generated == 1: + genai_ttft = time.perf_counter() - start + if num_generated == first_n: + genai_ttfn = time.perf_counter() - start del generator longest_common = _longest_common_token_sequence(transformers_tokens, genai_tokens) + gen_results = { + "longest_common_token_sequence": longest_common, + "time_to_first_n_tokens": first_n, + "transformers_time_to_first_token_s": transformers_ttft, + "transformers_time_to_first_n_tokens_s": transformers_ttfn, + "genai_time_to_first_token_s": genai_ttft, + "genai_time_to_first_n_tokens_s": genai_ttfn, + } + gen_summary = ( f"OnnxDiscrepancyCheck generation comparison: " f"transformers_len={len(transformers_tokens)}, genai_len={len(genai_tokens)}, " - f"longest_common_token_sequence={longest_common}" + f"longest_common_token_sequence={longest_common}, " + f"transformers_ttft={transformers_ttft:.4f}s, " + f"transformers_time_to_first_{first_n}_tokens={transformers_ttfn:.4f}s, " + f"genai_ttft={_format_seconds(genai_ttft)}, " + f"genai_time_to_first_{first_n}_tokens={_format_seconds(genai_ttfn)}" ) logger.info(gen_summary) - return longest_common + return gen_results def _export_reference_model(self, ref_model, output_model_path: str): """Save the reference PyTorch model weights for direct comparison.""" diff --git a/test/passes/onnx/test_discrepancy_check.py b/test/passes/onnx/test_discrepancy_check.py index c6a2f83eb..59a1867fd 100644 --- a/test/passes/onnx/test_discrepancy_check.py +++ b/test/passes/onnx/test_discrepancy_check.py @@ -60,6 +60,7 @@ def test_compare_generation_returns_common_prefix_length(self): config.genai_model_path = "mock_genai_model" config.generate_prompt = "Hello world" config.generate_max_new_tokens = 10 + config.time_to_first_n_tokens = 5 # Mock transformers tokenizer and model mock_tokenizer = MagicMock() @@ -107,7 +108,17 @@ def get_next_tokens_side_effect(): mock_generator.append_tokens.assert_called_once_with([[1, 2, 3]]) # Common prefix: [1, 2, 3, 10, 11] = 5 tokens before divergence - assert result == 5 + assert result["longest_common_token_sequence"] == 5 + # Latency metrics are exposed for both transformers and ONNX Runtime GenAI. + assert result["time_to_first_n_tokens"] == 5 + for key in ( + "transformers_time_to_first_token_s", + "transformers_time_to_first_n_tokens_s", + ): + assert key in result + assert isinstance(result[key], float) + for key in ("genai_time_to_first_token_s", "genai_time_to_first_n_tokens_s"): + assert key in result def test_compare_generation_fully_matching(self): """Test when both outputs are identical.""" @@ -120,6 +131,7 @@ def test_compare_generation_fully_matching(self): config.genai_model_path = "mock_genai_model" config.generate_prompt = "Test" config.generate_max_new_tokens = 5 + config.time_to_first_n_tokens = 5 mock_tokenizer = MagicMock() mock_tokenizer.return_value = MagicMock(input_ids=torch.tensor([[10, 20]])) @@ -162,7 +174,7 @@ def get_next_tokens_side_effect(): mock_generator.append_tokens.assert_called_once_with([[10, 20]]) # All 5 tokens match - assert result == 5 + assert result["longest_common_token_sequence"] == 5 class TestSpeedupSettings: From 1bdee25ff539c74d8af4298a2c217f1f0476aeb6 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 22 Jun 2026 13:02:53 +0000 Subject: [PATCH 2/4] Add latency key assertions to fully matching discrepancy test --- test/passes/onnx/test_discrepancy_check.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/test/passes/onnx/test_discrepancy_check.py b/test/passes/onnx/test_discrepancy_check.py index 59a1867fd..cad91f434 100644 --- a/test/passes/onnx/test_discrepancy_check.py +++ b/test/passes/onnx/test_discrepancy_check.py @@ -175,6 +175,19 @@ def get_next_tokens_side_effect(): mock_generator.append_tokens.assert_called_once_with([[10, 20]]) # All 5 tokens match assert result["longest_common_token_sequence"] == 5 + assert result["time_to_first_n_tokens"] == 5 + for key in ( + "transformers_time_to_first_token_s", + "transformers_time_to_first_n_tokens_s", + ): + assert key in result + assert isinstance(result[key], float) + assert "genai_time_to_first_token_s" in result + assert isinstance(result["genai_time_to_first_token_s"], float) + assert "genai_time_to_first_n_tokens_s" in result + assert result["genai_time_to_first_n_tokens_s"] is None or isinstance( + result["genai_time_to_first_n_tokens_s"], float + ) class TestSpeedupSettings: From 142ddea487a74d8cfb7c87f504767433af806e5c Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 22 Jun 2026 13:43:51 +0000 Subject: [PATCH 3/4] Handle zero max_new_tokens in generation metrics --- olive/passes/onnx/discrepancy_check.py | 14 ++++--- test/passes/onnx/test_discrepancy_check.py | 44 ++++++++++++++++++++++ 2 files changed, 53 insertions(+), 5 deletions(-) diff --git a/olive/passes/onnx/discrepancy_check.py b/olive/passes/onnx/discrepancy_check.py index d4fd2e194..01c501539 100644 --- a/olive/passes/onnx/discrepancy_check.py +++ b/olive/passes/onnx/discrepancy_check.py @@ -453,7 +453,7 @@ def compare_generation(self, config: type[BasePassConfig], ref_model) -> dict: tokenizer = AutoTokenizer.from_pretrained(config.reference_model_path) max_new_tokens = config.generate_max_new_tokens - first_n = max(1, min(config.time_to_first_n_tokens, max_new_tokens)) + first_n = max(1, min(config.time_to_first_n_tokens, max_new_tokens)) if max_new_tokens > 0 else 0 # Transformers generation input_ids = tokenizer(config.generate_prompt, return_tensors="pt").input_ids @@ -478,8 +478,12 @@ def _time_transformers_generate(num_new_tokens): return output, elapsed # Time to first token and time to first N tokens (separate timed runs). - _, transformers_ttft = _time_transformers_generate(1) - _, transformers_ttfn = _time_transformers_generate(first_n) + if max_new_tokens > 0: + _, transformers_ttft = _time_transformers_generate(1) + _, transformers_ttfn = _time_transformers_generate(first_n) + else: + transformers_ttft = None + transformers_ttfn = None transformers_output, _ = _time_transformers_generate(max_new_tokens) transformers_tokens = transformers_output[0].cpu().tolist() @@ -523,8 +527,8 @@ def _time_transformers_generate(num_new_tokens): f"OnnxDiscrepancyCheck generation comparison: " f"transformers_len={len(transformers_tokens)}, genai_len={len(genai_tokens)}, " f"longest_common_token_sequence={longest_common}, " - f"transformers_ttft={transformers_ttft:.4f}s, " - f"transformers_time_to_first_{first_n}_tokens={transformers_ttfn:.4f}s, " + f"transformers_ttft={_format_seconds(transformers_ttft)}, " + f"transformers_time_to_first_{first_n}_tokens={_format_seconds(transformers_ttfn)}, " f"genai_ttft={_format_seconds(genai_ttft)}, " f"genai_time_to_first_{first_n}_tokens={_format_seconds(genai_ttfn)}" ) diff --git a/test/passes/onnx/test_discrepancy_check.py b/test/passes/onnx/test_discrepancy_check.py index cad91f434..9608182ff 100644 --- a/test/passes/onnx/test_discrepancy_check.py +++ b/test/passes/onnx/test_discrepancy_check.py @@ -189,6 +189,50 @@ def get_next_tokens_side_effect(): result["genai_time_to_first_n_tokens_s"], float ) + def test_compare_generation_with_zero_max_new_tokens(self): + """Test that latency metrics are skipped when max_new_tokens is zero.""" + import torch + + from olive.passes.onnx.discrepancy_check import OnnxDiscrepancyCheck + + config = MagicMock() + config.reference_model_path = "mock_model" + config.genai_model_path = "mock_genai_model" + config.generate_prompt = "Test" + config.generate_max_new_tokens = 0 + config.time_to_first_n_tokens = 5 + + mock_tokenizer = MagicMock() + mock_tokenizer.return_value = MagicMock(input_ids=torch.tensor([[10, 20]])) + + mock_ref_model = MagicMock() + mock_ref_model.device = torch.device("cpu") + mock_ref_model.generate.return_value = torch.tensor([[10, 20]]) + + mock_og = MagicMock() + mock_og.Model.return_value = MagicMock() + mock_genai_tokenizer = MagicMock() + mock_og.Tokenizer.return_value = mock_genai_tokenizer + mock_genai_tokenizer.encode.return_value = [10, 20] + mock_og.GeneratorParams.return_value = MagicMock() + + mock_generator = MagicMock() + mock_generator.is_done.return_value = True + mock_og.Generator.return_value = mock_generator + + with ( + patch.dict(sys.modules, {"onnxruntime_genai": mock_og}), + patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer), + ): + pass_instance = OnnxDiscrepancyCheck.__new__(OnnxDiscrepancyCheck) + result = pass_instance.compare_generation(config, mock_ref_model) + + assert mock_ref_model.generate.call_count == 1 + assert mock_ref_model.generate.call_args.kwargs["max_new_tokens"] == 0 + assert result["time_to_first_n_tokens"] == 0 + assert result["transformers_time_to_first_token_s"] is None + assert result["transformers_time_to_first_n_tokens_s"] is None + class TestSpeedupSettings: def test_timing_iterations_default_is_5(self): From 39cac1cebb4910a9c586cc6b5f62292c28499998 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 22 Jun 2026 13:45:07 +0000 Subject: [PATCH 4/4] Use single measured transformers generation for latency metrics --- olive/passes/onnx/discrepancy_check.py | 47 +++++++++++++++----------- 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/olive/passes/onnx/discrepancy_check.py b/olive/passes/onnx/discrepancy_check.py index 01c501539..ea9a9bffe 100644 --- a/olive/passes/onnx/discrepancy_check.py +++ b/olive/passes/onnx/discrepancy_check.py @@ -448,7 +448,7 @@ def compare_generation(self, config: type[BasePassConfig], ref_model) -> dict: import onnxruntime_genai as og except ImportError as exc: raise ImportError("Please install `onnxruntime-genai` to enable generation comparison.") from exc - from transformers import AutoTokenizer + from transformers import AutoTokenizer, StoppingCriteria, StoppingCriteriaList tokenizer = AutoTokenizer.from_pretrained(config.reference_model_path) @@ -462,29 +462,38 @@ def compare_generation(self, config: type[BasePassConfig], ref_model) -> dict: input_ids = input_ids.to(ref_model.device) use_cuda_sync = ref_model.device.type == "cuda" - def _time_transformers_generate(num_new_tokens): - with torch.no_grad(): - if use_cuda_sync: - torch.cuda.synchronize() - start = time.perf_counter() - output = ref_model.generate( - input_ids, - max_new_tokens=num_new_tokens, - do_sample=False, - ) - if use_cuda_sync: - torch.cuda.synchronize() - elapsed = time.perf_counter() - start - return output, elapsed + prompt_token_count = input_ids.shape[-1] + transformers_latency = {"start": None, "ttft": None, "ttfn": None} + + class _TransformersLatencyStopCriteria(StoppingCriteria): + def __call__(self, generated_ids, scores, **kwargs) -> bool: + generated_token_count = generated_ids.shape[-1] - prompt_token_count + if generated_token_count >= 1 and transformers_latency["ttft"] is None: + transformers_latency["ttft"] = time.perf_counter() - transformers_latency["start"] + if generated_token_count >= first_n and transformers_latency["ttfn"] is None: + transformers_latency["ttfn"] = time.perf_counter() - transformers_latency["start"] + return False - # Time to first token and time to first N tokens (separate timed runs). + with torch.no_grad(): + if use_cuda_sync: + torch.cuda.synchronize() + start = time.perf_counter() + transformers_latency["start"] = start + transformers_output = ref_model.generate( + input_ids, + max_new_tokens=max_new_tokens, + do_sample=False, + stopping_criteria=StoppingCriteriaList([_TransformersLatencyStopCriteria()]), + ) + if use_cuda_sync: + torch.cuda.synchronize() + transformers_elapsed = time.perf_counter() - start if max_new_tokens > 0: - _, transformers_ttft = _time_transformers_generate(1) - _, transformers_ttfn = _time_transformers_generate(first_n) + transformers_ttft = transformers_latency["ttft"] or transformers_elapsed + transformers_ttfn = transformers_latency["ttfn"] or transformers_elapsed else: transformers_ttft = None transformers_ttfn = None - transformers_output, _ = _time_transformers_generate(max_new_tokens) transformers_tokens = transformers_output[0].cpu().tolist() # ONNX Runtime GenAI generation