
Add DPO and ORPO preference data preprocessing pipeline utils #3895

Draft
igorts-git wants to merge 1 commit into main from igorts/dpo-input-processing


Conversation

@igorts-git
Collaborator

@igorts-git igorts-git commented May 13, 2026

Description

To simplify code review, I am splitting the Tunix-based DPO implementation into smaller PRs.
This one adds the data reading and preprocessing required by DPO.

The classic DPO input consists of three data columns: ["prompt", "chosen_response", "rejected_response"].
However, some DPO datasets use a two-column format in which the prompt is a shared prefix of the chosen and rejected strings.
When a 2-column dataset is used, our implementation extracts the common prefix into the "prompt" field, which is then fed to the model separately.
The column names in a dataset can vary, for example ["input", "chosen", "rejected"]. Our implementation lets the user supply the dataset column names via the train_data_columns and eval_data_columns parameters. A minimal sketch of the remapping and prefix extraction is shown below.
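
For illustration, here is a minimal sketch of the column remapping and common-prefix extraction described above. The function name, the canonical output keys, and the use of os.path.commonprefix are assumptions for this sketch, not the PR's actual code:

```python
import os

def to_preference_example(example, columns=("prompt", "chosen_response", "rejected_response")):
  """Map one dataset row to canonical prompt/chosen/rejected fields (hypothetical sketch)."""
  if len(columns) == 3:
    prompt_col, chosen_col, rejected_col = columns
    return {"prompt": example[prompt_col],
            "chosen": example[chosen_col],
            "rejected": example[rejected_col]}
  # 2-column format: the prompt is the shared prefix of the chosen and rejected strings.
  chosen_col, rejected_col = columns
  chosen, rejected = example[chosen_col], example[rejected_col]
  prompt = os.path.commonprefix([chosen, rejected])  # character-level common prefix
  return {"prompt": prompt,
          "chosen": chosen[len(prompt):],
          "rejected": rejected[len(prompt):]}

# A 2-column row in the style of Anthropic/hh-rlhf:
row = {"chosen": "Human: Hi\n\nAssistant: Hello!",
       "rejected": "Human: Hi\n\nAssistant: Go away."}
print(to_preference_example(row, columns=("chosen", "rejected")))
# {'prompt': 'Human: Hi\n\nAssistant: ', 'chosen': 'Hello!', 'rejected': 'Go away.'}
```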

Tunix requires left-padded prompts and right-padded responses. Our code implements this padding (and truncation when needed) and also provides Tunix with the corresponding masks; a sketch of the padding behavior follows.
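
Below is a minimal sketch of the padding and masking behavior described above, using NumPy; the function name and signature are illustrative assumptions rather than the PR's actual API:

```python
import numpy as np

def pad_and_mask(ids, length, pad_id=0, left=False):
  """Truncate/pad a 1-D token array to `length` and return (tokens, mask).

  Prompts are left-padded (truncated from the left, keeping the suffix);
  responses are right-padded (truncated from the right).
  """
  ids = np.asarray(ids)
  ids = (ids[-length:] if length > 0 else ids[:0]) if left else ids[:length]
  n_real, n_pad = ids.shape[0], length - ids.shape[0]
  pad = np.full(n_pad, pad_id, dtype=ids.dtype)
  if left:
    tokens = np.concatenate([pad, ids])
    mask = np.concatenate([np.zeros(n_pad, np.int32), np.ones(n_real, np.int32)])
  else:
    tokens = np.concatenate([ids, pad])
    mask = np.concatenate([np.ones(n_real, np.int32), np.zeros(n_pad, np.int32)])
  return tokens, mask

prompt_tokens, prompt_mask = pad_and_mask([5, 6, 7], length=5, left=True)
# prompt_tokens -> [0 0 5 6 7], prompt_mask -> [0 0 1 1 1]
response_tokens, response_mask = pad_and_mask([8, 9], length=4, left=False)
# response_tokens -> [8 9 0 0], response_mask -> [1 1 0 0]
```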

NOTE: once this PR is merged, the legacy DPO will stop working correctly. Follow-up PRs will enable Tunix-based DPO.

Tests

Added unit tests. Ran DPO/ORPO and performed a logits comparison against the legacy implementation.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov

codecov Bot commented May 13, 2026

Codecov Report

❌ Patch coverage is 20.00000% with 48 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/maxtext/input_pipeline/dpo_utils.py | 20.75% | 42 Missing ⚠️ |
| src/maxtext/input_pipeline/hf_data_processing.py | 14.28% | 6 Missing ⚠️ |


@igorts-git igorts-git force-pushed the igorts/dpo-input-processing branch 2 times, most recently from 30d3c25 to b8ae239 Compare May 14, 2026 21:56
@github-actions

🤖 Hi @igorts-git, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.


@github-actions github-actions Bot left a comment


## 📋 Review Summary

The Pull Request introduces important utilities for DPO and ORPO preference data preprocessing, which is a key component for the upcoming Tunix-based alignment implementation. The core logic for handling 2-column and 3-column datasets is well-structured, but I identified a high-severity bug in the common prefix extraction and some opportunities for more flexible truncation strategies.

🔍 General Feedback

  • Logic Bug: The common prefix extraction logic using enumerate(zip(...)) is flawed for edge cases such as identical strings or strings where one is a prefix of the other (see the illustration after this list). I have provided a more robust implementation in the inline comments.
  • Truncation Strategy: The current 50/50 split for prompt/response lengths and the prefix-based truncation for prompts might lead to information loss in long-context scenarios.
  • Test Coverage: The new unit tests are quite thorough, but adding the suggested edge cases for prefix extraction would make them even better.
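
To make the prefix-extraction edge case concrete, here is a generic, hypothetical mismatch-based loop (not the PR's actual code): zip stops at the shorter string, so a loop that only records the prefix length on the first mismatch never fires when the strings are identical or when one is a prefix of the other.

```python
def naive_prefix_len(a: str, b: str) -> int:
  """Flawed sketch: relies on finding a mismatching character."""
  for i, (ca, cb) in enumerate(zip(a, b)):
    if ca != cb:
      return i
  return 0  # fall-through: no mismatch found, so no prefix is reported

print(naive_prefix_len("Hello A", "Hello B"))  # 6 -- works when a mismatch exists
print(naive_prefix_len("Hello", "Hello"))      # 0 -- identical strings: prefix missed
print(naive_prefix_len("Hello", "Hello!!"))    # 0 -- one string is a prefix of the other
```

A robust version handles the fall-through explicitly, e.g. by returning min(len(a), len(b)) when the loop finds no mismatch.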

Comment threads:
  • src/maxtext/input_pipeline/dpo_utils.py
  • src/maxtext/input_pipeline/dpo_utils.py (Outdated)
  • src/maxtext/input_pipeline/dpo_utils.py (Outdated)
  • src/maxtext/input_pipeline/dpo_utils.py (Outdated)
  • tests/post_training/unit/dpo_data_processing_test.py
…ities

Includes robust common prefix extraction for 2-column datasets, prompt suffix truncation, customizable max_prompt_length with validation against max_target_length, and complete integration unit test coverage.
@igorts-git igorts-git force-pushed the igorts/dpo-input-processing branch from b8ae239 to 2d7b6e0 Compare May 15, 2026 00:02
@github-actions

🤖 Hi @igorts-git, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.


@github-actions github-actions Bot left a comment


## 📋 Review Summary

This PR introduces necessary data preprocessing utilities for DPO and ORPO, including a new Grain transform DPOTunixPrep that handles column remapping, prefix extraction, and DPO-aware padding. The implementation is well-tested and integrated into the existing Hugging Face data pipeline.

🔍 General Feedback

  • Robustness: The prefix extraction logic for 2-column datasets is a great addition for supporting popular preference datasets like Anthropic/hh-rlhf.
  • Breaking Change: As noted in the description, moving DPO parameters into a nested config block is a breaking change for existing DPO configurations.
  • Logic Correction: A fix is suggested for the slicing logic in _pad to correctly handle cases where the requested length is 0.
  • Validation: Added a suggestion for non-negativity validation on max_prompt_length to align with project standards.

```python
pad_amount = max(length - x.shape[0], 0)
if left:
  pad_width = ((pad_amount, 0),)
  x_trimmed = x[-length:]
```

🟠 In Python, `x[-0:]` returns the entire array `x` instead of an empty array. If `max_prompt_length` is ever 0 (or set to `max_target_length` such that `max_response_length` is 0), this logic will return the un-truncated input.
Suggested change:

```diff
-   x_trimmed = x[-length:]
+ if left:
+   pad_width = ((pad_amount, 0),)
+   x_trimmed = x[-length:] if length > 0 else x[:0]
+ else:
```
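
As a quick, standalone illustration of the slicing behavior described above (plain NumPy, independent of the PR's code):

```python
import numpy as np

x = np.arange(5)
print(x[-3:])  # [2 3 4]      -- last 3 elements
print(x[-0:])  # [0 1 2 3 4]  -- -0 == 0, so this is x[0:], i.e. the whole array
print(x[:0])   # []           -- an explicitly empty slice
```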

```python
orpo_lambda: float = Field(0.1, description="Weight for preference loss in ORPO.")
dpo_label_smoothing: float = Field(0.0, ge=0.0, le=1.0, description="Label smoothing for DPO.")
max_prompt_length: int | None = Field(
    None,
```

🟡 It would be safer to ensure that `max_prompt_length` is non-negative to avoid unexpected behavior in the padding logic.
Suggested change:

```diff
-     None,
+ max_prompt_length: int | None = Field(
+     None,
+     ge=0,
+     description="Maximum length for prompt. If None, defaults to half of max_target_length.",
+ )
```
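
For context, Pydantic's Field(ge=...) enforces the bound at validation time. A minimal, self-contained example (the DpoConfig model below is illustrative, and the field is simplified to a plain int, whereas the PR's field is Optional):

```python
from pydantic import BaseModel, Field, ValidationError

class DpoConfig(BaseModel):
  max_prompt_length: int = Field(512, ge=0, description="Maximum length for prompt.")

print(DpoConfig(max_prompt_length=1024).max_prompt_length)  # 1024

try:
  DpoConfig(max_prompt_length=-1)  # violates ge=0
except ValidationError:
  print("rejected negative max_prompt_length")
```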
