Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 47 additions & 14 deletions olive/data/component/pre_process_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ def speech_transcription_pre_process(
dataset,
audio_col: str = "audio",
text_col: str = "text",
id_col: str = "",
sample_rate: int = 16000,
max_samples: Optional[int] = None,
limit: Optional[float] = None,
Expand All @@ -307,12 +308,17 @@ def speech_transcription_pre_process(
"""Pre-process data for speech transcription (ASR) evaluation.

Loads audio arrays and reference transcription text from a HuggingFace dataset.
Returns a dataset of (audio_array, reference_text) pairs suitable for WER evaluation.
Returns a dataset of ({"audio": audio_array, "file_name": name}, reference_text) pairs
suitable for WER evaluation.

Args:
dataset: HuggingFace dataset with audio and text columns.
audio_col: Name of the audio column. Defaults to "audio".
text_col: Name of the reference text column. Defaults to "text".
id_col: Name of a column to use as a per-sample identifier (e.g. an audio file name or
sample id). When set and present, its value is surfaced as ``file_name`` so it can be
included in the evaluation sample log. Falls back to the HuggingFace Audio feature's
``path`` and then the dataset row index. Defaults to "".
sample_rate: Target sample rate for audio. Defaults to 16000.
max_samples: Maximum number of samples (deprecated, use limit). Defaults to None.
limit: Sampling limit following Olive convention:
Expand Down Expand Up @@ -343,26 +349,39 @@ def speech_transcription_pre_process(
dataset = dataset.select(range(n))

class SpeechTranscriptionDataset:
"""Dataset that returns (audio_array, reference_text) pairs.
"""Dataset that returns ({"audio": audio_array, "file_name": name}, reference_text) pairs.

Note: Use batch_size=1 in dataloader config as audio samples have variable lengths.
"""

def __init__(self, hf_dataset, audio_column, text_column):
def __init__(self, hf_dataset, audio_column, text_column, id_column=""):
self.dataset = hf_dataset
self.audio_column = audio_column
self.text_column = text_column
self.id_column = id_column

def __len__(self):
return len(self.dataset)

def __getitem__(self, idx):
item = self.dataset[idx]
import os

import numpy as np

audio_array = np.array(item[self.audio_column]["array"], dtype=np.float32)
item = self.dataset[idx]
audio_item = item[self.audio_column]
audio_array = np.array(audio_item["array"], dtype=np.float32)
reference_text = item[self.text_column]
return audio_array, reference_text

path = audio_item.get("path") if isinstance(audio_item, dict) else None
if self.id_column and self.id_column in item and item[self.id_column] is not None:
file_name = str(item[self.id_column])
elif path:
file_name = os.path.basename(str(path))
else:
file_name = str(idx)

return {"audio": audio_array, "file_name": file_name}, reference_text

@staticmethod
def collate_fn(batch):
Expand All @@ -371,14 +390,15 @@ def collate_fn(batch):

# batch_size=1 is expected for speech evaluation (variable-length audio)
if len(batch) == 1:
audio, text = batch[0]
return (np.expand_dims(audio, 0), [text])
input_dict, text = batch[0]
batched = {**input_dict, "audio": np.expand_dims(input_dict["audio"], 0)}
return (batched, [text])
# For batch_size > 1, return as lists (no padding)
audios = [item[0] for item in batch]
inputs = [item[0] for item in batch]
texts = [item[1] for item in batch]
return (audios, texts)
return (inputs, texts)

return SpeechTranscriptionDataset(dataset, audio_col, text_col)
return SpeechTranscriptionDataset(dataset, audio_col, text_col, id_col)


@Registry.register_pre_process()
Expand All @@ -389,6 +409,7 @@ def vision_vqa_pre_process(
answer_col: str = "answer",
options_col: str = "",
system_prompt: str = "",
id_col: str = "",
max_length: int = 4096,
max_samples: Optional[int] = None,
limit: Optional[float] = None,
Expand All @@ -415,6 +436,9 @@ def vision_vqa_pre_process(
options are formatted as numbered choices and appended to the question. Defaults to "".
system_prompt: System prompt to guide model responses (e.g., "Reply with only the
option number"). Passed through to the evaluator. Defaults to "".
id_col: Name of a column to use as a per-sample identifier (e.g. an image file name or
sample id). When set and present, its value is surfaced as ``file_name`` so it can be
included in the evaluation sample log. Falls back to the dataset row index. Defaults to "".
max_length: Maximum generation length (input + output tokens) for the VLM. Vision prompts
with large images can exceed 3000 tokens due to vision patches. Defaults to 4096.
max_samples: Maximum number of samples (deprecated, use limit). Defaults to None.
Expand Down Expand Up @@ -455,6 +479,7 @@ def __init__(
answer_column,
options_column="",
sys_prompt="",
id_column="",
max_length=4096,
):
self.dataset = hf_dataset
Expand All @@ -463,6 +488,7 @@ def __init__(
self.answer_column = answer_column
self.options_column = options_column
self.system_prompt = sys_prompt
self.id_column = id_column
self.max_length = max_length

def __len__(self):
Expand Down Expand Up @@ -493,8 +519,8 @@ def __getitem__(self, idx):
# Convert 0-based answer index to 1-based to match the option numbering
if num_choices > 0:
try:
idx = int(answer)
answer = str(idx + 1)
answer_idx = int(answer)
answer = str(answer_idx + 1)
except (ValueError, TypeError):
pass # answer is already a non-numeric string (e.g., text label)

Expand All @@ -504,6 +530,11 @@ def __getitem__(self, idx):
"system_prompt": self.system_prompt,
"num_choices": num_choices,
"max_length": self.max_length,
"file_name": (
str(item[self.id_column])
if self.id_column and self.id_column in item and item[self.id_column] is not None
else str(idx)
),
}
return input_dict, str(answer)

Expand All @@ -521,4 +552,6 @@ def collate_fn(batch):
answers = [item[1] for item in batch]
return (inputs, answers)

return VisionVQADataset(dataset, image_col, question_col, answer_col, options_col, system_prompt, max_length)
return VisionVQADataset(
dataset, image_col, question_col, answer_col, options_col, system_prompt, id_col, max_length
)
11 changes: 11 additions & 0 deletions olive/evaluator/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,17 @@ class Metric(NestedConfig):
sub_types: list[SubMetric] = Field(default=[], validate_default=True)
user_config: Optional[ConfigBase] = Field(default=None, validate_default=True)
data_config: Optional[DataConfig] = Field(default=None, validate_default=True)
sample_log_num: int = Field(
default=0,
description=(
"Number of sample predictions to log alongside ground truth. "
"When > 0, saves a JSONL file with the first N sample results for debugging."
),
)
sample_log_dir: Optional[str] = Field(
default=None,
description="Directory to save the sample log file. Defaults to the current working directory.",
)

def get_inference_settings(self, framework):
if self.user_config is None:
Expand Down
Loading
Loading