Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 84 additions & 10 deletions olive/passes/onnx/discrepancy_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,11 @@ def _longest_common_token_sequence(seq_a: list[int], seq_b: list[int]) -> int:
return length


def _format_seconds(value: Optional[float]) -> str:
"""Format an optional latency value (in seconds) for logging."""
return "n/a" if value is None else f"{value:.4f}s"


class OnnxDiscrepancyCheck(Pass):
"""Validates ONNX model outputs against a reference PyTorch model.

Expand All @@ -122,6 +127,8 @@ class OnnxDiscrepancyCheck(Pass):
- Inference speedup of ONNX over PyTorch on the target device (or CPU fallback)
- Longest common token sequence from the beginning between transformers
generate and ONNX Runtime GenAI generate (when enabled)
- Time-to-first-token and time-to-first-N-tokens latencies for both transformers
and ONNX Runtime GenAI generation (when enabled)

The pass status is marked as failed if any configured threshold is exceeded.
"""
Expand Down Expand Up @@ -205,6 +212,14 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon
default_value=32,
description="Maximum number of new tokens to generate for the token sequence comparison.",
),
"time_to_first_n_tokens": PassConfigParam(
type_=int,
default_value=5,
description=(
"Number of leading generated tokens used for the time-to-first-N-tokens latency "
"measurement reported for both transformers and ONNX Runtime GenAI."
),
),
"min_longest_common_tokens": PassConfigParam(
type_=Optional[int],
default_value=None,
Expand Down Expand Up @@ -406,8 +421,9 @@ def _run_for_config(

# Generation token sequence comparison (transformers vs ONNX Runtime GenAI)
if config.genai_model_path:
longest_common = self.compare_generation(config, ref_model)
results["longest_common_token_sequence"] = longest_common
gen_results = self.compare_generation(config, ref_model)
longest_common = gen_results["longest_common_token_sequence"]
results.update(gen_results)
results["genai_model_path"] = config.genai_model_path
if config.min_longest_common_tokens is not None and longest_common < config.min_longest_common_tokens:
results["status"] = "failed"
Expand Down Expand Up @@ -496,27 +512,63 @@ def _measure_speedup(

return speedup

def compare_generation(self, config: type[BasePassConfig], ref_model) -> int:
"""Run generation on both transformers and GenAI, return longest common token sequence length."""
def compare_generation(self, config: type[BasePassConfig], ref_model) -> dict:
"""Run generation on both transformers and GenAI and compare them.

Returns a dict with the longest common token sequence length and the time-to-first-token
and time-to-first-N-tokens latencies (in seconds) for both transformers and ONNX Runtime
GenAI, where N is ``config.time_to_first_n_tokens``.
"""
try:
import onnxruntime_genai as og
except ImportError as exc:
raise ImportError("Please install `onnxruntime-genai` to enable generation comparison.") from exc
from transformers import AutoTokenizer
from transformers import AutoTokenizer, StoppingCriteria, StoppingCriteriaList

tokenizer = AutoTokenizer.from_pretrained(config.reference_model_path)

max_new_tokens = config.generate_max_new_tokens
first_n = max(1, min(config.time_to_first_n_tokens, max_new_tokens)) if max_new_tokens > 0 else 0

# Transformers generation
input_ids = tokenizer(config.generate_prompt, return_tensors="pt").input_ids
input_ids = input_ids.to(ref_model.device)
import torch

input_ids = input_ids.to(ref_model.device)
use_cuda_sync = ref_model.device.type == "cuda"

prompt_token_count = input_ids.shape[-1]
transformers_latency = {"start": None, "ttft": None, "ttfn": None}

class _TransformersLatencyStopCriteria(StoppingCriteria):
def __call__(self, generated_ids, scores, **kwargs) -> bool:
generated_token_count = generated_ids.shape[-1] - prompt_token_count
if generated_token_count >= 1 and transformers_latency["ttft"] is None:
transformers_latency["ttft"] = time.perf_counter() - transformers_latency["start"]
if generated_token_count >= first_n and transformers_latency["ttfn"] is None:
transformers_latency["ttfn"] = time.perf_counter() - transformers_latency["start"]
return False

with torch.no_grad():
if use_cuda_sync:
torch.cuda.synchronize()
start = time.perf_counter()
transformers_latency["start"] = start
transformers_output = ref_model.generate(
input_ids,
max_new_tokens=config.generate_max_new_tokens,
max_new_tokens=max_new_tokens,
do_sample=False,
stopping_criteria=StoppingCriteriaList([_TransformersLatencyStopCriteria()]),
)
if use_cuda_sync:
torch.cuda.synchronize()
transformers_elapsed = time.perf_counter() - start
if max_new_tokens > 0:
transformers_ttft = transformers_latency["ttft"] or transformers_elapsed
transformers_ttfn = transformers_latency["ttfn"] or transformers_elapsed
else:
transformers_ttft = None
transformers_ttfn = None
transformers_tokens = transformers_output[0].cpu().tolist()

# ONNX Runtime GenAI generation
Expand All @@ -525,26 +577,48 @@ def compare_generation(self, config: type[BasePassConfig], ref_model) -> int:
genai_input_ids = genai_tokenizer.encode(config.generate_prompt)

params = og.GeneratorParams(genai_model)
params.set_search_options(max_length=len(genai_input_ids) + config.generate_max_new_tokens, do_sample=False)
params.set_search_options(max_length=len(genai_input_ids) + max_new_tokens, do_sample=False)

generator = og.Generator(genai_model, params)
generator.append_tokens([genai_input_ids])
genai_tokens = list(genai_input_ids)
genai_ttft = None
genai_ttfn = None
num_generated = 0
start = time.perf_counter()
while not generator.is_done():
generator.generate_next_token()
genai_tokens.append(generator.get_next_tokens()[0])
num_generated += 1
if num_generated == 1:
genai_ttft = time.perf_counter() - start
if num_generated == first_n:
genai_ttfn = time.perf_counter() - start
Comment thread
xadupre marked this conversation as resolved.
del generator

longest_common = _longest_common_token_sequence(transformers_tokens, genai_tokens)

gen_results = {
"longest_common_token_sequence": longest_common,
"time_to_first_n_tokens": first_n,
"transformers_time_to_first_token_s": transformers_ttft,
"transformers_time_to_first_n_tokens_s": transformers_ttfn,
"genai_time_to_first_token_s": genai_ttft,
"genai_time_to_first_n_tokens_s": genai_ttfn,
}

gen_summary = (
f"OnnxDiscrepancyCheck generation comparison: "
f"transformers_len={len(transformers_tokens)}, genai_len={len(genai_tokens)}, "
f"longest_common_token_sequence={longest_common}"
f"longest_common_token_sequence={longest_common}, "
f"transformers_ttft={_format_seconds(transformers_ttft)}, "
f"transformers_time_to_first_{first_n}_tokens={_format_seconds(transformers_ttfn)}, "
f"genai_ttft={_format_seconds(genai_ttft)}, "
f"genai_time_to_first_{first_n}_tokens={_format_seconds(genai_ttfn)}"
)
logger.info(gen_summary)

return longest_common
return gen_results

def _export_reference_model(self, ref_model, output_model_path: str):
"""Save the reference PyTorch model weights for direct comparison."""
Expand Down
73 changes: 71 additions & 2 deletions test/passes/onnx/test_discrepancy_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def test_compare_generation_returns_common_prefix_length(self):
config.genai_model_path = "mock_genai_model"
config.generate_prompt = "Hello world"
config.generate_max_new_tokens = 10
config.time_to_first_n_tokens = 5

# Mock transformers tokenizer and model
mock_tokenizer = MagicMock()
Expand Down Expand Up @@ -107,7 +108,17 @@ def get_next_tokens_side_effect():

mock_generator.append_tokens.assert_called_once_with([[1, 2, 3]])
# Common prefix: [1, 2, 3, 10, 11] = 5 tokens before divergence
assert result == 5
assert result["longest_common_token_sequence"] == 5
# Latency metrics are exposed for both transformers and ONNX Runtime GenAI.
assert result["time_to_first_n_tokens"] == 5
for key in (
"transformers_time_to_first_token_s",
"transformers_time_to_first_n_tokens_s",
):
assert key in result
assert isinstance(result[key], float)
for key in ("genai_time_to_first_token_s", "genai_time_to_first_n_tokens_s"):
assert key in result

def test_compare_generation_fully_matching(self):
"""Test when both outputs are identical."""
Expand All @@ -120,6 +131,7 @@ def test_compare_generation_fully_matching(self):
config.genai_model_path = "mock_genai_model"
config.generate_prompt = "Test"
config.generate_max_new_tokens = 5
config.time_to_first_n_tokens = 5

mock_tokenizer = MagicMock()
mock_tokenizer.return_value = MagicMock(input_ids=torch.tensor([[10, 20]]))
Expand Down Expand Up @@ -162,7 +174,64 @@ def get_next_tokens_side_effect():

mock_generator.append_tokens.assert_called_once_with([[10, 20]])
# All 5 tokens match
assert result == 5
assert result["longest_common_token_sequence"] == 5
Comment thread
xadupre marked this conversation as resolved.
assert result["time_to_first_n_tokens"] == 5
for key in (
"transformers_time_to_first_token_s",
"transformers_time_to_first_n_tokens_s",
):
assert key in result
assert isinstance(result[key], float)
assert "genai_time_to_first_token_s" in result
assert isinstance(result["genai_time_to_first_token_s"], float)
assert "genai_time_to_first_n_tokens_s" in result
assert result["genai_time_to_first_n_tokens_s"] is None or isinstance(
result["genai_time_to_first_n_tokens_s"], float
)

def test_compare_generation_with_zero_max_new_tokens(self):
    """Test that latency metrics are skipped when max_new_tokens is zero.

    With ``generate_max_new_tokens == 0`` the reported ``time_to_first_n_tokens``
    must be 0 and every latency metric, for transformers and for ONNX Runtime
    GenAI alike, must be ``None``. The mocked GenAI generator reports
    ``is_done()`` immediately, so no decoding step is expected to run.
    """
    import torch

    from olive.passes.onnx.discrepancy_check import OnnxDiscrepancyCheck

    config = MagicMock()
    config.reference_model_path = "mock_model"
    config.genai_model_path = "mock_genai_model"
    config.generate_prompt = "Test"
    config.generate_max_new_tokens = 0
    config.time_to_first_n_tokens = 5

    # Transformers side: the tokenizer yields a two-token prompt and generate()
    # returns the prompt unchanged (nothing is generated).
    mock_tokenizer = MagicMock()
    mock_tokenizer.return_value = MagicMock(input_ids=torch.tensor([[10, 20]]))

    mock_ref_model = MagicMock()
    mock_ref_model.device = torch.device("cpu")
    mock_ref_model.generate.return_value = torch.tensor([[10, 20]])

    # ONNX Runtime GenAI side: same prompt tokens, generator already finished.
    mock_og = MagicMock()
    mock_og.Model.return_value = MagicMock()
    mock_genai_tokenizer = MagicMock()
    mock_og.Tokenizer.return_value = mock_genai_tokenizer
    mock_genai_tokenizer.encode.return_value = [10, 20]
    mock_og.GeneratorParams.return_value = MagicMock()

    mock_generator = MagicMock()
    mock_generator.is_done.return_value = True
    mock_og.Generator.return_value = mock_generator

    with (
        patch.dict(sys.modules, {"onnxruntime_genai": mock_og}),
        patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer),
    ):
        # __new__ bypasses Pass.__init__, which is not needed by compare_generation.
        pass_instance = OnnxDiscrepancyCheck.__new__(OnnxDiscrepancyCheck)
        result = pass_instance.compare_generation(config, mock_ref_model)

    # Transformers: generate() is still invoked once, with a zero token budget.
    assert mock_ref_model.generate.call_count == 1
    assert mock_ref_model.generate.call_args.kwargs["max_new_tokens"] == 0
    assert result["time_to_first_n_tokens"] == 0
    assert result["transformers_time_to_first_token_s"] is None
    assert result["transformers_time_to_first_n_tokens_s"] is None

    # GenAI: the prompt is appended, but no token is decoded, so no latency
    # can be measured on that side either.
    mock_generator.append_tokens.assert_called_once_with([[10, 20]])
    mock_generator.generate_next_token.assert_not_called()
    assert result["genai_time_to_first_token_s"] is None
    assert result["genai_time_to_first_n_tokens_s"] is None


class TestWeightDtypeInference:
Expand Down
Loading