diff --git a/olive/passes/onnx/discrepancy_check.py b/olive/passes/onnx/discrepancy_check.py index a9bb7e1ac..bfeec6cbc 100644 --- a/olive/passes/onnx/discrepancy_check.py +++ b/olive/passes/onnx/discrepancy_check.py @@ -110,6 +110,11 @@ def _longest_common_token_sequence(seq_a: list[int], seq_b: list[int]) -> int: return length +def _format_seconds(value: Optional[float]) -> str: + """Format an optional latency value (in seconds) for logging.""" + return "n/a" if value is None else f"{value:.4f}s" + + class OnnxDiscrepancyCheck(Pass): """Validates ONNX model outputs against a reference PyTorch model. @@ -122,6 +127,8 @@ class OnnxDiscrepancyCheck(Pass): - Inference speedup of ONNX over PyTorch on the target device (or CPU fallback) - Longest common token sequence from the beginning between transformers generate and ONNX Runtime GenAI generate (when enabled) + - Time-to-first-token and time-to-first-N-tokens latencies for both transformers + and ONNX Runtime GenAI generation (when enabled) The pass status is marked as failed if any configured threshold is exceeded. """ @@ -205,6 +212,14 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon default_value=32, description="Maximum number of new tokens to generate for the token sequence comparison.", ), + "time_to_first_n_tokens": PassConfigParam( + type_=int, + default_value=5, + description=( + "Number of leading generated tokens used for the time-to-first-N-tokens latency " + "measurement reported for both transformers and ONNX Runtime GenAI." + ), + ), "min_longest_common_tokens": PassConfigParam( type_=Optional[int], default_value=None, @@ -406,8 +421,9 @@ def _run_for_config( # Generation token sequence comparison (transformers vs ONNX Runtime GenAI) if config.genai_model_path: - longest_common = self.compare_generation(config, ref_model) - results["longest_common_token_sequence"] = longest_common + gen_results = self.compare_generation(config, ref_model) + longest_common = gen_results["longest_common_token_sequence"] + results.update(gen_results) results["genai_model_path"] = config.genai_model_path if config.min_longest_common_tokens is not None and longest_common < config.min_longest_common_tokens: results["status"] = "failed" @@ -496,27 +512,63 @@ def _measure_speedup( return speedup - def compare_generation(self, config: type[BasePassConfig], ref_model) -> int: - """Run generation on both transformers and GenAI, return longest common token sequence length.""" + def compare_generation(self, config: type[BasePassConfig], ref_model) -> dict: + """Run generation on both transformers and GenAI and compare them. + + Returns a dict with the longest common token sequence length and the time-to-first-token + and time-to-first-N-tokens latencies (in seconds) for both transformers and ONNX Runtime + GenAI, where N is ``config.time_to_first_n_tokens``. + """ try: import onnxruntime_genai as og except ImportError as exc: raise ImportError("Please install `onnxruntime-genai` to enable generation comparison.") from exc - from transformers import AutoTokenizer + from transformers import AutoTokenizer, StoppingCriteria, StoppingCriteriaList tokenizer = AutoTokenizer.from_pretrained(config.reference_model_path) + max_new_tokens = config.generate_max_new_tokens + first_n = max(1, min(config.time_to_first_n_tokens, max_new_tokens)) if max_new_tokens > 0 else 0 + # Transformers generation input_ids = tokenizer(config.generate_prompt, return_tensors="pt").input_ids - input_ids = input_ids.to(ref_model.device) import torch + input_ids = input_ids.to(ref_model.device) + use_cuda_sync = ref_model.device.type == "cuda" + + prompt_token_count = input_ids.shape[-1] + transformers_latency = {"start": None, "ttft": None, "ttfn": None} + + class _TransformersLatencyStopCriteria(StoppingCriteria): + def __call__(self, generated_ids, scores, **kwargs) -> bool: + generated_token_count = generated_ids.shape[-1] - prompt_token_count + if generated_token_count >= 1 and transformers_latency["ttft"] is None: + transformers_latency["ttft"] = time.perf_counter() - transformers_latency["start"] + if generated_token_count >= first_n and transformers_latency["ttfn"] is None: + transformers_latency["ttfn"] = time.perf_counter() - transformers_latency["start"] + return False + with torch.no_grad(): + if use_cuda_sync: + torch.cuda.synchronize() + start = time.perf_counter() + transformers_latency["start"] = start transformers_output = ref_model.generate( input_ids, - max_new_tokens=config.generate_max_new_tokens, + max_new_tokens=max_new_tokens, do_sample=False, + stopping_criteria=StoppingCriteriaList([_TransformersLatencyStopCriteria()]), ) + if use_cuda_sync: + torch.cuda.synchronize() + transformers_elapsed = time.perf_counter() - start + if max_new_tokens > 0: + transformers_ttft = transformers_latency["ttft"] or transformers_elapsed + transformers_ttfn = transformers_latency["ttfn"] or transformers_elapsed + else: + transformers_ttft = None + transformers_ttfn = None transformers_tokens = transformers_output[0].cpu().tolist() # ONNX Runtime GenAI generation @@ -525,26 +577,48 @@ def compare_generation(self, config: type[BasePassConfig], ref_model) -> int: genai_input_ids = genai_tokenizer.encode(config.generate_prompt) params = og.GeneratorParams(genai_model) - params.set_search_options(max_length=len(genai_input_ids) + config.generate_max_new_tokens, do_sample=False) + params.set_search_options(max_length=len(genai_input_ids) + max_new_tokens, do_sample=False) generator = og.Generator(genai_model, params) generator.append_tokens([genai_input_ids]) genai_tokens = list(genai_input_ids) + genai_ttft = None + genai_ttfn = None + num_generated = 0 + start = time.perf_counter() while not generator.is_done(): generator.generate_next_token() genai_tokens.append(generator.get_next_tokens()[0]) + num_generated += 1 + if num_generated == 1: + genai_ttft = time.perf_counter() - start + if num_generated == first_n: + genai_ttfn = time.perf_counter() - start del generator longest_common = _longest_common_token_sequence(transformers_tokens, genai_tokens) + gen_results = { + "longest_common_token_sequence": longest_common, + "time_to_first_n_tokens": first_n, + "transformers_time_to_first_token_s": transformers_ttft, + "transformers_time_to_first_n_tokens_s": transformers_ttfn, + "genai_time_to_first_token_s": genai_ttft, + "genai_time_to_first_n_tokens_s": genai_ttfn, + } + gen_summary = ( f"OnnxDiscrepancyCheck generation comparison: " f"transformers_len={len(transformers_tokens)}, genai_len={len(genai_tokens)}, " - f"longest_common_token_sequence={longest_common}" + f"longest_common_token_sequence={longest_common}, " + f"transformers_ttft={_format_seconds(transformers_ttft)}, " + f"transformers_time_to_first_{first_n}_tokens={_format_seconds(transformers_ttfn)}, " + f"genai_ttft={_format_seconds(genai_ttft)}, " + f"genai_time_to_first_{first_n}_tokens={_format_seconds(genai_ttfn)}" ) logger.info(gen_summary) - return longest_common + return gen_results def _export_reference_model(self, ref_model, output_model_path: str): """Save the reference PyTorch model weights for direct comparison.""" diff --git a/test/passes/onnx/test_discrepancy_check.py b/test/passes/onnx/test_discrepancy_check.py index 8bb53a966..36ad470f0 100644 --- a/test/passes/onnx/test_discrepancy_check.py +++ b/test/passes/onnx/test_discrepancy_check.py @@ -60,6 +60,7 @@ def test_compare_generation_returns_common_prefix_length(self): config.genai_model_path = "mock_genai_model" config.generate_prompt = "Hello world" config.generate_max_new_tokens = 10 + config.time_to_first_n_tokens = 5 # Mock transformers tokenizer and model mock_tokenizer = MagicMock() @@ -107,7 +108,17 @@ def get_next_tokens_side_effect(): mock_generator.append_tokens.assert_called_once_with([[1, 2, 3]]) # Common prefix: [1, 2, 3, 10, 11] = 5 tokens before divergence - assert result == 5 + assert result["longest_common_token_sequence"] == 5 + # Latency metrics are exposed for both transformers and ONNX Runtime GenAI. + assert result["time_to_first_n_tokens"] == 5 + for key in ( + "transformers_time_to_first_token_s", + "transformers_time_to_first_n_tokens_s", + ): + assert key in result + assert isinstance(result[key], float) + for key in ("genai_time_to_first_token_s", "genai_time_to_first_n_tokens_s"): + assert key in result def test_compare_generation_fully_matching(self): """Test when both outputs are identical.""" @@ -120,6 +131,7 @@ def test_compare_generation_fully_matching(self): config.genai_model_path = "mock_genai_model" config.generate_prompt = "Test" config.generate_max_new_tokens = 5 + config.time_to_first_n_tokens = 5 mock_tokenizer = MagicMock() mock_tokenizer.return_value = MagicMock(input_ids=torch.tensor([[10, 20]])) @@ -162,7 +174,64 @@ def get_next_tokens_side_effect(): mock_generator.append_tokens.assert_called_once_with([[10, 20]]) # All 5 tokens match - assert result == 5 + assert result["longest_common_token_sequence"] == 5 + assert result["time_to_first_n_tokens"] == 5 + for key in ( + "transformers_time_to_first_token_s", + "transformers_time_to_first_n_tokens_s", + ): + assert key in result + assert isinstance(result[key], float) + assert "genai_time_to_first_token_s" in result + assert isinstance(result["genai_time_to_first_token_s"], float) + assert "genai_time_to_first_n_tokens_s" in result + assert result["genai_time_to_first_n_tokens_s"] is None or isinstance( + result["genai_time_to_first_n_tokens_s"], float + ) + + def test_compare_generation_with_zero_max_new_tokens(self): + """Test that latency metrics are skipped when max_new_tokens is zero.""" + import torch + + from olive.passes.onnx.discrepancy_check import OnnxDiscrepancyCheck + + config = MagicMock() + config.reference_model_path = "mock_model" + config.genai_model_path = "mock_genai_model" + config.generate_prompt = "Test" + config.generate_max_new_tokens = 0 + config.time_to_first_n_tokens = 5 + + mock_tokenizer = MagicMock() + mock_tokenizer.return_value = MagicMock(input_ids=torch.tensor([[10, 20]])) + + mock_ref_model = MagicMock() + mock_ref_model.device = torch.device("cpu") + mock_ref_model.generate.return_value = torch.tensor([[10, 20]]) + + mock_og = MagicMock() + mock_og.Model.return_value = MagicMock() + mock_genai_tokenizer = MagicMock() + mock_og.Tokenizer.return_value = mock_genai_tokenizer + mock_genai_tokenizer.encode.return_value = [10, 20] + mock_og.GeneratorParams.return_value = MagicMock() + + mock_generator = MagicMock() + mock_generator.is_done.return_value = True + mock_og.Generator.return_value = mock_generator + + with ( + patch.dict(sys.modules, {"onnxruntime_genai": mock_og}), + patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer), + ): + pass_instance = OnnxDiscrepancyCheck.__new__(OnnxDiscrepancyCheck) + result = pass_instance.compare_generation(config, mock_ref_model) + + assert mock_ref_model.generate.call_count == 1 + assert mock_ref_model.generate.call_args.kwargs["max_new_tokens"] == 0 + assert result["time_to_first_n_tokens"] == 0 + assert result["transformers_time_to_first_token_s"] is None + assert result["transformers_time_to_first_n_tokens_s"] is None class TestWeightDtypeInference: