Created GRPOTrainerWithEval subclass for different evaluation reward functions #8

Closed
jamesbraza wants to merge 3 commits into working-grpo-2025-03-08 from grpo-with-eval

Conversation

@jamesbraza

Member

This PR creates a GRPOTrainer subclass GRPOTrainerWithEval that adds support for optional eval_reward_processing_classes.

It should be backwards compatible with GRPOTrainer.

The only caveat here is I didn't comprehensively think about args.reward_weights.

@jamesbraza jamesbraza added the enhancement New feature or request label Mar 9, 2025
@jamesbraza jamesbraza self-assigned this Mar 9, 2025

Copilot AI left a comment

PR Overview

This PR introduces a subclass GRPOTrainerWithEval that extends GRPOTrainer to support evaluation reward functions and their associated processing classes while maintaining backward compatibility.

  • Refactors model initialization to use an instance attribute (_model_init_kwargs) for improved consistency.
  • Extracts reward processing class creation into a new helper method (_make_reward_processing_classes).
  • Implements GRPOTrainerWithEval for handling separate training and evaluation reward functions and processing.
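The subclassing pattern summarized above can be sketched with toy classes. This is a minimal illustration only, not the PR's code and not TRL's API: `ToyTrainer`, `ToyTrainerWithEval`, `eval_reward_funcs`, `_active_reward_funcs`, and `score` are all hypothetical names, and the fallback to the training reward functions is an assumption about how backward compatibility might be kept.

```python
from typing import Callable, Optional, Sequence

# Hypothetical signature: a reward function maps completions to one score each.
RewardFunc = Callable[[Sequence[str]], list]


class ToyTrainer:
    """Hypothetical stand-in for a trainer that scores completions."""

    def __init__(self, reward_funcs: list) -> None:
        self.reward_funcs = reward_funcs
        self.training = True  # flipped to False while evaluating

    def _active_reward_funcs(self) -> list:
        return self.reward_funcs

    def score(self, completions: Sequence[str]) -> list:
        # Sum the per-function rewards for each completion.
        per_func = [func(completions) for func in self._active_reward_funcs()]
        return [sum(values) for values in zip(*per_func)]


class ToyTrainerWithEval(ToyTrainer):
    """Adds optional evaluation-only reward functions."""

    def __init__(self, reward_funcs: list, eval_reward_funcs: Optional[list] = None) -> None:
        super().__init__(reward_funcs)
        self.eval_reward_funcs = eval_reward_funcs

    def _active_reward_funcs(self) -> list:
        # Assumed fallback: without eval functions, behave like the base class.
        if not self.training and self.eval_reward_funcs is not None:
            return self.eval_reward_funcs
        return self.reward_funcs


def length_reward(completions: Sequence[str]) -> list:
    return [float(len(c)) for c in completions]


def exact_match_reward(completions: Sequence[str]) -> list:
    return [1.0 if c == "42" else 0.0 for c in completions]


trainer = ToyTrainerWithEval([length_reward], eval_reward_funcs=[exact_match_reward])
train_scores = trainer.score(["42", "hello"])
trainer.training = False
eval_scores = trainer.score(["42", "hello"])
```

Overriding a single selection hook keeps the base class untouched, which is one way a subclass like this could remain a drop-in replacement when no evaluation reward functions are passed.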

Reviewed Changes

| File | Description |
| trl/trainer/grpo_trainer.py | Added logging, refactored model initialization kwargs, created a helper for reward processing classes, and introduced a new subclass for evaluation reward support |

Copilot reviewed 1 out of 1 changed files in this pull request and generated no comments.

Comments suppressed due to low confidence (1)

trl/trainer/grpo_trainer.py:1257

  • [nitpick] Consider handling cases where reward_func may not have a __name__ attribute (e.g., when using functools.partial objects or other callables that are not plain functions) to avoid potential AttributeErrors. A possible solution is to use getattr(reward_func, '__name__', repr(reward_func)).
reward_func_name = reward_func.__name__
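
A minimal sketch of the suggested fallback, assuming reward functions are arbitrary callables. The helper `get_reward_func_name` and the example reward functions are hypothetical and not part of the PR:

```python
from functools import partial


def get_reward_func_name(reward_func) -> str:
    """Return a display name, falling back to repr() when __name__ is missing."""
    return getattr(reward_func, "__name__", repr(reward_func))


def length_reward(completions, scale=1.0):
    # Toy reward: scaled length of each completion.
    return [scale * len(c) for c in completions]


# functools.partial objects do not define __name__, so direct access would raise.
scaled_reward = partial(length_reward, scale=0.5)

plain_name = get_reward_func_name(length_reward)
lambda_name = get_reward_func_name(lambda completions: [0.0])
partial_name = get_reward_func_name(scaled_reward)
```

Plain functions and lambdas do carry `__name__` (a lambda reports `<lambda>`), so the fallback matters mainly for `functools.partial` objects and instances of callable classes.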

@jamesbraza jamesbraza deleted the branch working-grpo-2025-03-08 March 10, 2025 20:18
@jamesbraza jamesbraza closed this Mar 10, 2025
@jamesbraza

Member Author

Closed in favor of #9 after a rebase onto working-grpo-2025-03-10
