diff --git a/olive/data/component/pre_process_data.py b/olive/data/component/pre_process_data.py index 40ebff582..273ef0e18 100644 --- a/olive/data/component/pre_process_data.py +++ b/olive/data/component/pre_process_data.py @@ -298,6 +298,7 @@ def speech_transcription_pre_process( dataset, audio_col: str = "audio", text_col: str = "text", + id_col: str = "", sample_rate: int = 16000, max_samples: Optional[int] = None, limit: Optional[float] = None, @@ -307,12 +308,17 @@ def speech_transcription_pre_process( """Pre-process data for speech transcription (ASR) evaluation. Loads audio arrays and reference transcription text from a HuggingFace dataset. - Returns a dataset of (audio_array, reference_text) pairs suitable for WER evaluation. + Returns a dataset of ({"audio": audio_array, "file_name": name}, reference_text) pairs + suitable for WER evaluation. Args: dataset: HuggingFace dataset with audio and text columns. audio_col: Name of the audio column. Defaults to "audio". text_col: Name of the reference text column. Defaults to "text". + id_col: Name of a column to use as a per-sample identifier (e.g. an audio file name or + sample id). When set and present, its value is surfaced as ``file_name`` so it can be + included in the evaluation sample log. Falls back to the HuggingFace Audio feature's + ``path`` and then the dataset row index. Defaults to "". sample_rate: Target sample rate for audio. Defaults to 16000. max_samples: Maximum number of samples (deprecated, use limit). Defaults to None. limit: Sampling limit following Olive convention: @@ -343,26 +349,39 @@ def speech_transcription_pre_process( dataset = dataset.select(range(n)) class SpeechTranscriptionDataset: - """Dataset that returns (audio_array, reference_text) pairs. + """Dataset that returns ({"audio": audio_array, "file_name": name}, reference_text) pairs. Note: Use batch_size=1 in dataloader config as audio samples have variable lengths. """ - def __init__(self, hf_dataset, audio_column, text_column): + def __init__(self, hf_dataset, audio_column, text_column, id_column=""): self.dataset = hf_dataset self.audio_column = audio_column self.text_column = text_column + self.id_column = id_column def __len__(self): return len(self.dataset) def __getitem__(self, idx): - item = self.dataset[idx] + import os + import numpy as np - audio_array = np.array(item[self.audio_column]["array"], dtype=np.float32) + item = self.dataset[idx] + audio_item = item[self.audio_column] + audio_array = np.array(audio_item["array"], dtype=np.float32) reference_text = item[self.text_column] - return audio_array, reference_text + + path = audio_item.get("path") if isinstance(audio_item, dict) else None + if self.id_column and self.id_column in item and item[self.id_column] is not None: + file_name = str(item[self.id_column]) + elif path: + file_name = os.path.basename(str(path)) + else: + file_name = str(idx) + + return {"audio": audio_array, "file_name": file_name}, reference_text @staticmethod def collate_fn(batch): @@ -371,14 +390,15 @@ def collate_fn(batch): # batch_size=1 is expected for speech evaluation (variable-length audio) if len(batch) == 1: - audio, text = batch[0] - return (np.expand_dims(audio, 0), [text]) + input_dict, text = batch[0] + batched = {**input_dict, "audio": np.expand_dims(input_dict["audio"], 0)} + return (batched, [text]) # For batch_size > 1, return as lists (no padding) - audios = [item[0] for item in batch] + inputs = [item[0] for item in batch] texts = [item[1] for item in batch] - return (audios, texts) + return (inputs, texts) - return SpeechTranscriptionDataset(dataset, audio_col, text_col) + return SpeechTranscriptionDataset(dataset, audio_col, text_col, id_col) @Registry.register_pre_process() @@ -389,6 +409,7 @@ def vision_vqa_pre_process( answer_col: str = "answer", options_col: str = "", system_prompt: str = "", + id_col: str = "", max_length: int = 4096, max_samples: Optional[int] = None, limit: Optional[float] = None, @@ -415,6 +436,9 @@ def vision_vqa_pre_process( options are formatted as numbered choices and appended to the question. Defaults to "". system_prompt: System prompt to guide model responses (e.g., "Reply with only the option number"). Passed through to the evaluator. Defaults to "". + id_col: Name of a column to use as a per-sample identifier (e.g. an image file name or + sample id). When set and present, its value is surfaced as ``file_name`` so it can be + included in the evaluation sample log. Falls back to the dataset row index. Defaults to "". max_length: Maximum generation length (input + output tokens) for the VLM. Vision prompts with large images can exceed 3000 tokens due to vision patches. Defaults to 4096. max_samples: Maximum number of samples (deprecated, use limit). Defaults to None. @@ -455,6 +479,7 @@ def __init__( answer_column, options_column="", sys_prompt="", + id_column="", max_length=4096, ): self.dataset = hf_dataset @@ -463,6 +488,7 @@ def __init__( self.answer_column = answer_column self.options_column = options_column self.system_prompt = sys_prompt + self.id_column = id_column self.max_length = max_length def __len__(self): @@ -493,8 +519,8 @@ def __getitem__(self, idx): # Convert 0-based answer index to 1-based to match the option numbering if num_choices > 0: try: - idx = int(answer) - answer = str(idx + 1) + answer_idx = int(answer) + answer = str(answer_idx + 1) except (ValueError, TypeError): pass # answer is already a non-numeric string (e.g., text label) @@ -504,6 +530,11 @@ def __getitem__(self, idx): "system_prompt": self.system_prompt, "num_choices": num_choices, "max_length": self.max_length, + "file_name": ( + str(item[self.id_column]) + if self.id_column and self.id_column in item and item[self.id_column] is not None + else str(idx) + ), } return input_dict, str(answer) @@ -521,4 +552,6 @@ def collate_fn(batch): answers = [item[1] for item in batch] return (inputs, answers) - return VisionVQADataset(dataset, image_col, question_col, answer_col, options_col, system_prompt, max_length) + return VisionVQADataset( + dataset, image_col, question_col, answer_col, options_col, system_prompt, id_col, max_length + ) diff --git a/olive/evaluator/metric.py b/olive/evaluator/metric.py index 5d19b132e..b1c77bc8e 100644 --- a/olive/evaluator/metric.py +++ b/olive/evaluator/metric.py @@ -124,6 +124,17 @@ class Metric(NestedConfig): sub_types: list[SubMetric] = Field(default=[], validate_default=True) user_config: Optional[ConfigBase] = Field(default=None, validate_default=True) data_config: Optional[DataConfig] = Field(default=None, validate_default=True) + sample_log_num: int = Field( + default=0, + description=( + "Number of sample predictions to log alongside ground truth. " + "When > 0, saves a JSONL file with the first N sample results for debugging." + ), + ) + sample_log_dir: Optional[str] = Field( + default=None, + description="Directory to save the sample log file. Defaults to the current working directory.", + ) def get_inference_settings(self, framework): if self.user_config is None: diff --git a/olive/evaluator/olive_evaluator.py b/olive/evaluator/olive_evaluator.py index f6b7fb841..9b25e5c26 100644 --- a/olive/evaluator/olive_evaluator.py +++ b/olive/evaluator/olive_evaluator.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- import collections +import json import logging import time from abc import ABC, abstractmethod @@ -57,6 +58,7 @@ class OliveModelOutput(NamedTuple): preds: Any logits: Any + extras: Any = None # Text-based accuracy sub-types that work with string predictions/targets @@ -257,6 +259,64 @@ def compute_accuracy(metric: Metric, model_outputs: Union[tuple, NamedTuple], ta evaluate_backend_cls = MetricBackend.registry[metric.backend] return evaluate_backend_cls().measure(model_outputs, targets, metric) + @staticmethod + def save_sample_log(metric: Metric, inference_output: "OliveModelOutput", targets: Any, num_samples: int) -> None: + """Save top N sample predictions and ground truth to a JSONL file. + + Each line in the output file is a JSON object with 'index', 'prediction', and 'target' + fields. When the inference output carries per-sample ``extras`` (e.g. the input prompt and + the vision/audio file name), those key/value pairs are merged into each record as well. + For tensor data, values are converted to Python scalars or lists. + This is best-effort: filesystem or serialization errors are logged as warnings. + """ + if num_samples <= 0: + return + + try: + preds = inference_output.preds + extras = getattr(inference_output, "extras", None) + output_dir = Path(metric.sample_log_dir) if metric.sample_log_dir else Path.cwd() + # Sanitize metric name to prevent path traversal + safe_name = Path(metric.name).name.replace("/", "_").replace("\\", "_") or "metric" + output_dir.mkdir(parents=True, exist_ok=True) + log_path = output_dir / f"{safe_name}_samples.jsonl" + + def _to_serializable(val): + """Convert tensor/ndarray/numpy scalar values to JSON-serializable Python objects.""" + if isinstance(val, (torch.Tensor, np.ndarray)): + return val.tolist() + if isinstance(val, np.generic): + return val.item() + return val + + total_samples = len(preds) if hasattr(preds, "__len__") else num_samples + n = min(num_samples, total_samples) + if num_samples > total_samples: + logger.warning( + "sample_log_num (%d) exceeds available samples (%d), capping to %d.", + num_samples, + total_samples, + n, + ) + with log_path.open("w", encoding="utf-8") as f: + for i in range(n): + pred_val = preds[i] if hasattr(preds, "__getitem__") else preds + target_val = targets[i] if hasattr(targets, "__getitem__") else targets + record = {"index": i} + # Merge per-sample metadata (e.g. prompt, image/audio file name) when available. + if extras is not None and hasattr(extras, "__getitem__") and i < len(extras): + extra = extras[i] + if isinstance(extra, dict): + for key, value in extra.items(): + record[key] = _to_serializable(value) + record["prediction"] = _to_serializable(pred_val) + record["target"] = _to_serializable(target_val) + f.write(json.dumps(record, ensure_ascii=False) + "\n") + + logger.info("Saved %d sample predictions to %s", n, log_path) + except Exception as e: + logger.warning("Failed to save sample log for metric '%s': %s", metric.name, e) + @staticmethod def latency_helper(latencies) -> dict: return { @@ -511,6 +571,61 @@ def _get_genai_model_dir(model: ONNXModelHandler) -> str: return str(Path(model.model_path).parent) +def _normalize_audio_batch(input_data) -> tuple[list, list]: + """Return (audio_arrays, file_names) from the various speech input shapes. + + Supports the ``{"audio": array, "file_name": name}`` dict produced by + ``speech_transcription_pre_process`` (single or batched), as well as legacy raw + arrays / lists of arrays (in which case file names are ``None``). + """ + arrays: list = [] + names: list = [] + + def _add_array(arr, name): + arr = np.array(arr) if not isinstance(arr, np.ndarray) else arr + if arr.ndim <= 1: + arrays.append(arr) + names.append(name) + else: + for i in range(arr.shape[0]): + arrays.append(arr[i]) + names.append(name) + + dict_items = None + if isinstance(input_data, dict): + dict_items = [input_data] + elif isinstance(input_data, list) and input_data and all(isinstance(d, dict) for d in input_data): + dict_items = input_data + + if dict_items is not None: + for item in dict_items: + _add_array(item.get("audio"), item.get("file_name")) + return arrays, names + + # Legacy shapes: raw array/tensor or list of arrays. + if isinstance(input_data, (np.ndarray, torch.Tensor)): + arr = np.array(input_data) if isinstance(input_data, torch.Tensor) else input_data + _add_array(arr, None) + elif isinstance(input_data, list): + for a in input_data: + _add_array(a, None) + return arrays, names + + +def _unwrap_audio_input(input_data): + """Strip the speech metadata dict down to the raw audio array(s). + + Keeps the non-genai inference/latency paths (which feed ``format_data``) working with the + ``{"audio": array, "file_name": name}`` shape produced by ``speech_transcription_pre_process``. + Other input shapes are returned unchanged. + """ + if isinstance(input_data, dict) and "audio" in input_data: + return input_data["audio"] + if isinstance(input_data, list) and input_data and all(isinstance(d, dict) and "audio" in d for d in input_data): + return [d["audio"] for d in input_data] + return input_data + + @Registry.register(str(Framework.ONNX)) @Registry.register("OnnxEvaluator") class OnnxEvaluator(_OliveEvaluator, OnnxEvaluatorMixin): @@ -622,7 +737,6 @@ def _load_genai_config(model: ONNXModelHandler) -> Optional[dict]: genai_config_path = _find_genai_config(model) if genai_config_path is None: return None - import json try: with genai_config_path.open(encoding="utf-8") as f: @@ -679,6 +793,7 @@ def _evaluate_onnx_accuracy( inference_output, targets = self._inference( model, metric, dataloader, post_func, device, execution_providers ) + OliveEvaluator.save_sample_log(metric, inference_output, targets, metric.sample_log_num) return OliveEvaluator.compute_accuracy(metric, inference_output, targets) def _inference_text( @@ -716,6 +831,8 @@ def _inference_text( for batch in dataloader: input_data, labels = OliveEvaluator.unpack_batch_for_accuracy(batch) + # Drop any speech metadata (e.g. file_name) so format_data receives raw audio arrays. + input_data = _unwrap_audio_input(input_data) # Track audio duration from input data if isinstance(input_data, (np.ndarray, torch.Tensor)): audio_samples = input_data.shape[-1] if len(input_data.shape) > 1 else input_data.shape[0] @@ -870,7 +987,6 @@ def _inference_vision_genai( "Install it with: pip install onnxruntime-genai" ) from e - import json import re import tempfile @@ -898,6 +1014,7 @@ def _inference_vision_genai( all_preds = [] all_targets = [] + all_extras = [] # Use a temporary directory for image files to avoid per-file create/delete overhead with tempfile.TemporaryDirectory() as tmp_dir: @@ -917,10 +1034,12 @@ def _inference_vision_genai( sys_prompt = item.get("system_prompt", "") num_choices = item.get("num_choices", 0) max_length = item.get("max_length", default_max_length) + file_name = item.get("file_name", str(sample_idx)) if pil_image is None: # Append empty pred to maintain alignment with targets all_preds.append("") + all_extras.append({"prompt": question, "image": file_name}) sample_idx += 1 continue @@ -992,6 +1111,7 @@ def _inference_vision_genai( pred = ch break all_preds.append(pred) + all_extras.append({"prompt": question, "image": file_name}) # Collect reference texts (aligned with preds including empty ones for None images) if isinstance(labels, (list, tuple)): @@ -1001,7 +1121,7 @@ def _inference_vision_genai( del og_model - return OliveModelOutput(preds=all_preds, logits=None), all_targets + return OliveModelOutput(preds=all_preds, logits=None, extras=all_extras), all_targets def _inference_text_genai( self, @@ -1026,7 +1146,6 @@ def _inference_text_genai( ) from None import io - import json import soundfile as sf @@ -1099,31 +1218,25 @@ def _transcribe_chunks(audio_arr: np.ndarray, genai_model) -> str: all_preds = [] all_targets = [] + all_extras = [] total_audio_duration = 0.0 total_inference_time = 0.0 for batch in dataloader: input_data, labels = OliveEvaluator.unpack_batch_for_accuracy(batch) - # Convert input to list of audio arrays - audio_arrays = [] - if isinstance(input_data, (np.ndarray, torch.Tensor)): - arr = np.array(input_data) if isinstance(input_data, torch.Tensor) else input_data - if arr.ndim == 1: - audio_arrays = [arr] - else: - audio_arrays = [arr[i] for i in range(arr.shape[0])] - elif isinstance(input_data, list): - audio_arrays = [np.array(a) if not isinstance(a, np.ndarray) else a for a in input_data] + # Convert input to list of audio arrays (with optional file names) + audio_arrays, audio_names = _normalize_audio_batch(input_data) if not audio_arrays: continue start_time = time.perf_counter() - for arr in audio_arrays: + for arr, name in zip(audio_arrays, audio_names): total_audio_duration += len(arr) / sample_rate transcription = _transcribe_chunks(arr, og_model) all_preds.append(transcription) + all_extras.append({"audio": name if name is not None else str(len(all_extras))}) total_inference_time += time.perf_counter() - start_time # Collect reference texts @@ -1138,7 +1251,7 @@ def _transcribe_chunks(audio_arr: np.ndarray, genai_model) -> str: "total_audio_duration": total_audio_duration, "total_inference_time": total_inference_time, } - return OliveModelOutput(preds=all_preds, logits=timing_metadata), all_targets + return OliveModelOutput(preds=all_preds, logits=timing_metadata, extras=all_extras), all_targets def _inference_text_genai_streaming( self, @@ -1162,8 +1275,6 @@ def _inference_text_genai_streaming( "Install it with: pip install onnxruntime-genai" ) from None - import json - model_dir = _get_genai_model_dir(model) with (Path(model_dir) / "genai_config.json").open() as f: @@ -1229,31 +1340,25 @@ def decode_tokens(): all_preds = [] all_targets = [] + all_extras = [] total_audio_duration = 0.0 total_inference_time = 0.0 for batch in dataloader: input_data, labels = OliveEvaluator.unpack_batch_for_accuracy(batch) - # Convert input to list of audio arrays - audio_arrays = [] - if isinstance(input_data, (np.ndarray, torch.Tensor)): - arr = np.array(input_data) if isinstance(input_data, torch.Tensor) else input_data - if arr.ndim == 1: - audio_arrays = [arr] - else: - audio_arrays = [arr[i] for i in range(arr.shape[0])] - elif isinstance(input_data, list): - audio_arrays = [np.array(a) if not isinstance(a, np.ndarray) else a for a in input_data] + # Convert input to list of audio arrays (with optional file names) + audio_arrays, audio_names = _normalize_audio_batch(input_data) if not audio_arrays: continue start_time = time.perf_counter() - for arr in audio_arrays: + for arr, name in zip(audio_arrays, audio_names): total_audio_duration += len(arr) / sample_rate transcription = _transcribe_streaming(arr, og_model) all_preds.append(transcription) + all_extras.append({"audio": name if name is not None else str(len(all_extras))}) total_inference_time += time.perf_counter() - start_time # Collect reference texts @@ -1268,7 +1373,7 @@ def decode_tokens(): "total_audio_duration": total_audio_duration, "total_inference_time": total_inference_time, } - return OliveModelOutput(preds=all_preds, logits=timing_metadata), all_targets + return OliveModelOutput(preds=all_preds, logits=timing_metadata, extras=all_extras), all_targets def _evaluate_onnx_latency( self, @@ -1287,6 +1392,8 @@ def _evaluate_onnx_latency( batch = next(iter(dataloader)) input_data = OliveEvaluator.extract_input_data(batch) + # Strip speech metadata (e.g. file_name) so format_data receives raw audio arrays. + input_data = _unwrap_audio_input(input_data) input_feed = format_data(input_data, io_config) latencies = session.time_run( @@ -1383,6 +1490,7 @@ def _evaluate_distributed_accuracy( targets = [x for _, t, _ in results for x in t] logits = [x for _, _, logit in results for x in logit] model_output = OliveModelOutput(preds, logits) + OliveEvaluator.save_sample_log(metric, model_output, targets, metric.sample_log_num) return OliveEvaluator.compute_accuracy(metric, model_output, targets) @staticmethod @@ -1586,6 +1694,7 @@ def _evaluate_accuracy( inference_output, targets = self._inference( model, metric, dataloader, post_func, device, execution_providers ) + OliveEvaluator.save_sample_log(metric, inference_output, targets, metric.sample_log_num) return OliveEvaluator.compute_accuracy(metric, inference_output, targets) @torch.no_grad() @@ -1820,6 +1929,7 @@ def _evaluate_accuracy( execution_providers: Optional[Union[str, list[str]]] = None, ) -> MetricResult: inference_output, targets = self._inference(model, metric, dataloader, post_func, device, execution_providers) + OliveEvaluator.save_sample_log(metric, inference_output, targets, metric.sample_log_num) return OliveEvaluator.compute_accuracy(metric, inference_output, targets) def _evaluate_raw_latency( @@ -1894,6 +2004,7 @@ def _evaluate_accuracy( execution_providers: Optional[Union[str, list[str]]] = None, ) -> MetricResult: inference_output, targets = self._inference(model, metric, dataloader, post_func, device, execution_providers) + OliveEvaluator.save_sample_log(metric, inference_output, targets, metric.sample_log_num) return OliveEvaluator.compute_accuracy(metric, inference_output, targets) def _evaluate_raw_latency( diff --git a/test/data_container/test_pre_process_data.py b/test/data_container/test_pre_process_data.py index 8fcefdc4f..55647df79 100644 --- a/test/data_container/test_pre_process_data.py +++ b/test/data_container/test_pre_process_data.py @@ -8,7 +8,12 @@ import pytest from datasets import Dataset -from olive.data.component.pre_process_data import huggingface_pre_process, tokenizer_pre_process +from olive.data.component.pre_process_data import ( + huggingface_pre_process, + speech_transcription_pre_process, + tokenizer_pre_process, + vision_vqa_pre_process, +) class TestPreProcessData: @@ -71,3 +76,59 @@ def test_huggingface_pre_process_with_label(self, mock_get_tokenizer, mock_datas mock_get_tokenizer.assert_called_with("bert-base-uncased", trust_remote_code=None) assert result.label_col == "label" assert result.effective_len == 2 + + +class TestVisionAndAudioFileName: + """Tests for the file_name surfaced by vision/audio preprocessors for sample logging.""" + + def test_vision_vqa_pre_process_uses_id_col_as_file_name(self): + data = { + "image": ["img_a", "img_b"], + "question": ["Q1", "Q2"], + "answer": ["1", "2"], + "image_id": ["a.png", "b.png"], + } + dataset = Dataset.from_dict(data) + vqa = vision_vqa_pre_process( + dataset, image_col="image", question_col="question", answer_col="answer", id_col="image_id" + ) + input_dict, answer = vqa[0] + assert input_dict["file_name"] == "a.png" + assert input_dict["question"] == "Q1" + assert answer == "1" + + def test_vision_vqa_pre_process_falls_back_to_index(self): + # Includes options so the num_choices answer-conversion path runs; file_name must still + # be the dataset row index (regression: the conversion previously shadowed the index var). + data = { + "image": ["img_a", "img_b"], + "question": ["Q1", "Q2"], + "answer": ["0", "1"], + "options": [["a", "b"], ["c", "d"]], + } + dataset = Dataset.from_dict(data) + vqa = vision_vqa_pre_process( + dataset, image_col="image", question_col="question", answer_col="answer", options_col="options" + ) + input_dict0, answer0 = vqa[0] + assert input_dict0["file_name"] == "0" + assert answer0 == "1" # 0-based answer converted to 1-based + input_dict1, _ = vqa[1] + assert input_dict1["file_name"] == "1" + + @patch("datasets.Dataset.cast_column", autospec=True) + def test_speech_transcription_pre_process_returns_dict_with_file_name(self, mock_cast): + import numpy as np + + # cast_column is patched to a no-op so we can supply raw audio dicts directly. + mock_cast.side_effect = lambda self, *args, **kwargs: self + data = { + "audio": [{"array": np.zeros(16000, dtype=np.float32), "sampling_rate": 16000, "path": "/data/clip_0.wav"}], + "text": ["hello"], + } + dataset = Dataset.from_dict(data) + speech = speech_transcription_pre_process(dataset, audio_col="audio", text_col="text") + input_dict, text = speech[0] + assert input_dict["file_name"] == "clip_0.wav" + assert input_dict["audio"].shape == (16000,) + assert text == "hello" diff --git a/test/evaluator/test_olive_evaluator.py b/test/evaluator/test_olive_evaluator.py index 1e1977cf6..1b5528e79 100644 --- a/test/evaluator/test_olive_evaluator.py +++ b/test/evaluator/test_olive_evaluator.py @@ -664,6 +664,8 @@ def _make_vision_accuracy_metric(self): metric.user_config.input_names = None metric.user_config.input_shapes = None metric.backend = "huggingface_metrics" + metric.sample_log_num = 0 + metric.sample_log_dir = None return metric def test_genai_vision_detected_when_vision_field_present(self, tmp_path): @@ -869,3 +871,187 @@ def test_find_genai_config_ignores_directory(self, tmp_path): # Should not find the directory, should return None result = _find_genai_config(model) assert result is None + + +class TestSaveSampleLog: + """Tests for OliveEvaluator.save_sample_log.""" + + @staticmethod + def _make_metric(sample_log_num=0, sample_log_dir=None, name="test_metric"): + metric = MagicMock() + metric.name = name + metric.sample_log_num = sample_log_num + metric.sample_log_dir = sample_log_dir + return metric + + def test_save_sample_log_disabled_when_zero(self, tmp_path): + """No file should be created when sample_log_num=0.""" + import torch + + from olive.evaluator.olive_evaluator import OliveModelOutput + + metric = self._make_metric(sample_log_num=0, sample_log_dir=str(tmp_path), name="m") + output = OliveModelOutput(preds=torch.tensor([1, 2, 3]), logits=None) + targets = torch.tensor([1, 2, 3]) + + OliveEvaluator.save_sample_log(metric, output, targets, 0) + assert not list(tmp_path.iterdir()) + + def test_save_sample_log_with_tensor_data(self, tmp_path): + """Should write a JSONL file with tensor preds/targets converted to Python values.""" + import json + + import torch + + from olive.evaluator.olive_evaluator import OliveModelOutput + + metric = self._make_metric(sample_log_num=3, sample_log_dir=str(tmp_path), name="accuracy") + preds = torch.tensor([0, 1, 1, 0, 1]) + targets = torch.tensor([0, 1, 0, 0, 1]) + output = OliveModelOutput(preds=preds, logits=None) + + OliveEvaluator.save_sample_log(metric, output, targets, 3) + + log_path = tmp_path / "accuracy_samples.jsonl" + assert log_path.exists() + + lines = log_path.read_text().strip().split("\n") + assert len(lines) == 3 + + for i, line in enumerate(lines): + record = json.loads(line) + assert record["index"] == i + assert record["prediction"] == preds[i].item() + assert record["target"] == targets[i].item() + + def test_save_sample_log_with_string_data(self, tmp_path): + """Should handle string predictions and targets (text-based metrics).""" + import json + + from olive.evaluator.olive_evaluator import OliveModelOutput + + metric = self._make_metric(sample_log_num=2, sample_log_dir=str(tmp_path), name="wer") + preds = ["hello world", "foo bar"] + targets = ["hello world", "foo baz"] + output = OliveModelOutput(preds=preds, logits=None) + + OliveEvaluator.save_sample_log(metric, output, targets, 2) + + log_path = tmp_path / "wer_samples.jsonl" + assert log_path.exists() + + lines = log_path.read_text().strip().split("\n") + assert len(lines) == 2 + + record0 = json.loads(lines[0]) + assert record0["prediction"] == "hello world" + assert record0["target"] == "hello world" + + record1 = json.loads(lines[1]) + assert record1["prediction"] == "foo bar" + assert record1["target"] == "foo baz" + + def test_save_sample_log_caps_at_available_samples(self, tmp_path): + """When sample_log_num > len(preds), should write only available samples.""" + import torch + + from olive.evaluator.olive_evaluator import OliveModelOutput + + metric = self._make_metric(sample_log_num=100, sample_log_dir=str(tmp_path), name="acc") + preds = torch.tensor([1, 2]) + targets = torch.tensor([1, 0]) + output = OliveModelOutput(preds=preds, logits=None) + + OliveEvaluator.save_sample_log(metric, output, targets, 100) + + log_path = tmp_path / "acc_samples.jsonl" + lines = log_path.read_text().strip().split("\n") + assert len(lines) == 2 + + def test_save_sample_log_merges_extras(self, tmp_path): + """Per-sample extras (e.g. prompt and image/audio file name) should be merged into records.""" + import json + + from olive.evaluator.olive_evaluator import OliveModelOutput + + metric = self._make_metric(sample_log_num=2, sample_log_dir=str(tmp_path), name="vision_accuracy") + preds = ["1", "3"] + targets = ["3", "3"] + extras = [ + {"prompt": "What is shown?\n1. cat\n2. dog", "image": "img_0.png"}, + {"prompt": "Which arrow?\n1. up\n2. down", "image": "img_1.png"}, + ] + output = OliveModelOutput(preds=preds, logits=None, extras=extras) + + OliveEvaluator.save_sample_log(metric, output, targets, 2) + + log_path = tmp_path / "vision_accuracy_samples.jsonl" + lines = log_path.read_text().strip().split("\n") + assert len(lines) == 2 + + record0 = json.loads(lines[0]) + # index first, then merged extras, then prediction/target + assert list(record0.keys()) == ["index", "prompt", "image", "prediction", "target"] + assert record0["prompt"] == "What is shown?\n1. cat\n2. dog" + assert record0["image"] == "img_0.png" + assert record0["prediction"] == "1" + assert record0["target"] == "3" + + record1 = json.loads(lines[1]) + assert record1["image"] == "img_1.png" + + def test_save_sample_log_without_extras_is_unchanged(self, tmp_path): + """When extras is None, records should only contain index/prediction/target.""" + import json + + from olive.evaluator.olive_evaluator import OliveModelOutput + + metric = self._make_metric(sample_log_num=1, sample_log_dir=str(tmp_path), name="acc") + output = OliveModelOutput(preds=["a"], logits=None) + + OliveEvaluator.save_sample_log(metric, output, ["a"], 1) + + record = json.loads((tmp_path / "acc_samples.jsonl").read_text().strip()) + assert list(record.keys()) == ["index", "prediction", "target"] + + +class TestAudioInputHelpers: + """Tests for the speech input normalization/unwrap helpers.""" + + def test_normalize_audio_batch_dict_with_file_name(self): + import numpy as np + + from olive.evaluator.olive_evaluator import _normalize_audio_batch + + arr = np.zeros(16000, dtype=np.float32) + arrays, names = _normalize_audio_batch({"audio": np.expand_dims(arr, 0), "file_name": "a.wav"}) + assert len(arrays) == 1 + assert arrays[0].shape == (16000,) + assert names == ["a.wav"] + + def test_normalize_audio_batch_legacy_array(self): + import numpy as np + + from olive.evaluator.olive_evaluator import _normalize_audio_batch + + arr = np.zeros((1, 16000), dtype=np.float32) + arrays, names = _normalize_audio_batch(arr) + assert len(arrays) == 1 + assert names == [None] + + def test_unwrap_audio_input_dict(self): + import numpy as np + + from olive.evaluator.olive_evaluator import _unwrap_audio_input + + arr = np.zeros((1, 8), dtype=np.float32) + unwrapped = _unwrap_audio_input({"audio": arr, "file_name": "a.wav"}) + assert unwrapped is arr + + def test_unwrap_audio_input_passthrough(self): + import numpy as np + + from olive.evaluator.olive_evaluator import _unwrap_audio_input + + arr = np.zeros((1, 8), dtype=np.float32) + assert _unwrap_audio_input(arr) is arr