Skip to content

arguments: add --enable-r3-correctness-check CLI flag#25

Open
DavidBellamy wants to merge 5 commits into
prodfrom
agentic-rl/r3-correctness-check-flag
Open

arguments: add --enable-r3-correctness-check CLI flag#25
DavidBellamy wants to merge 5 commits into
prodfrom
agentic-rl/r3-correctness-check-flag

Conversation

@DavidBellamy
Copy link
Copy Markdown
Collaborator

@DavidBellamy DavidBellamy commented May 18, 2026

What this does

Three small, targeted miles changes to enable direct end-to-end validation of R3 (Rollout Router Replay) without dragging in unrelated --ci-test invariants.

A. New CLI flag --enable-r3-correctness-check (miles/utils/arguments.py) that flips RoutingReplayManager.enable_check_replay_result = True. The flag also writes the value at module level in train_async.py after parse_args() (line 78-81).

B. Make the flag actually take effect (miles/backends/megatron_utils/actor.py). Previously actor.py unconditionally overwrote the value to self.args.ci_test for every replay manager (single-line assignment at L112), so the new flag was a no-op. Extends that condition to also honor --enable-r3-correctness-check:

m.enable_check_replay_result = m.enabled and (
    self.args.ci_test or getattr(self.args, "enable_r3_correctness_check", False)
)

This lets callers turn on just the R3 overlap check without enabling --ci-test's other strict-equality invariants — in particular, the log_probs == ref_log_probs check that trips on routine floating-point precision differences (~1e-3 gap) and was breaking the E2E run before this fix.

C. Add direct-evidence logs in RoutingReplayManager (miles/utils/replay_base.py + check). Two unconditional logger.info calls, gated on enable_check_replay_result so production training stays quiet:

  • One in new_topk_fn's replay_forward and replay_backward branches that logs R3 wrapper: replay_{forward,backward} branch taken (rank ..., n_tokens=..., topk=..., replay_idx_sum=...). Without this, there is no log evidence that megatron actually called through the wrapper (vs falling through to old_topk_fn).
  • One in check_replay_result that always logs R3 check (rank ..., stage ...): n_tokens=... mismatch=... (...%). Previously the check returned silently when mismatch_count == 0, making "check passed" indistinguishable from "check never ran."

These three logs together let the LLM360/RL360 E2E daemon prove that megatron's MoE forward and backward both used the rollout indices, not just that no assertion fired. See LLM360/RL360#317 validation evidence section.

Why this matters

Without this PR, the R3 regression E2E on M2 has no way to assert direct correctness of R3. The previous replay_base.py:178-219 check existed but was only callable via --ci-test, which trips a separate log_probs == ref_log_probs strict-equality check on the same run and crashes before backward fires.

Commits in this PR (substantive vs formatting)

Sorted by what kind of change they make:

Substantive (please review)

Commit Files Lines
0431dbf5arguments: add --enable-r3-correctness-check CLI flag miles/utils/arguments.py, train_async.py +16
0854adccreplay_base: direct-evidence logs for R3 wrapper + overlap check miles/utils/replay_base.py +33, -2
f019a625actor: make --enable-r3-correctness-check independent of --ci-test miles/backends/megatron_utils/actor.py +8, -1

Formatting only (mechanical, no behavior change)

Commit What
db437d22prod: apply black drift cleanup black==24.3.0 (the version pinned in .pre-commit-config.yaml) wanted to reformat 7 files when CI ran pre-commit run --all-files. Six were pre-existing drift on the prod base; the seventh is one blank line in train_async.py from this PR. 64 lines touched (+30/-34) across log_utils.py, loss.py, rollout.py, openai_endpoint_utils.py, linear_trajectory.py, replay_base.py, train_async.py. No symbol added or removed.
e06a6b3factor: apply black to the new condition (CI fix, no logic change) Single line: self.args.ci_test\n or getattr(...) collapsed onto one line per black's preference.

Validation

LLM360/RL360#317 section "Validation evidence" walks through the full forward + backward proof using miles at this PR's head (f019a625 snapshot, behaviorally identical to current head e06a6b3f). SLURM job 1654622 (manual) and 1655337 (daemon-driven, posted by llm360-deploy-bot on RL360 radixark#319) both COMPLETED with zero non-zero mismatches across thousands of per-rank-per-layer R3 wrapper calls in both forward and backward.

When set, flips RoutingReplayManager.enable_check_replay_result = True
so the per-step overlap check (replay_base.py:178-219) fires for every
training step. Off by default because the check roughly doubles the
cost of routing.

Intended for the R3 regression E2E on LLM360/RL360, which runs a small
GPU sbatch on M2 every time a submodule-pin bump PR opens. With this
flag, miles will raise AssertionError("R3 mismatch tokens ...") if the
overlap drops below MILES_TEST_R3_THRESHOLD (default 1e-2), giving the
E2E a hard pass/fail signal.

The R3 master switch (--use-rollout-routing-replay) is still required;
this flag has no effect without it.
@DavidBellamy DavidBellamy requested a review from a team as a code owner May 18, 2026 22:36
Six files on the prod base had black-non-compliant formatting that
pre-commit on PR #25 flagged as failures. Applying `black==24.3.0`
(matches .pre-commit-config.yaml) brings them in line so CI passes.

Also fixes the single line in train_async.py from this PR that black
wants (blank line after the import).

No behavioral changes; pure whitespace + line breaks.
The previous --enable-r3-correctness-check flag turned on the overlap
check but produced no log output unless an actual mismatch happened,
making it impossible to distinguish "check passed" from "check never
ran." Add two unconditional logs gated on enable_check_replay_result:

1. get_topk_fn / new_topk_fn replay_forward + replay_backward branches:
   log when the wrapper actually returns replay indices rather than
   falling through to old_topk_fn. Direct evidence megatron's MoE
   forward used the rollout indices (vs recomputing them).

2. check_replay_result: log n_tokens and mismatch_count on every call,
   including the mismatch_count==0 case (which previously returned
   silently). Direct evidence the check ran, plus the actual overlap
   number for cross-step / cross-rank comparison.

Both logs gated on enable_check_replay_result so production training
runs (which leave it False) stay quiet. Adds no overhead when off.

Intended to make the LLM360/RL360 R3 regression E2E able to assert
directly that R3 worked end-to-end, rather than inferring from absence
of failure messages.
actor.py:111-112 unconditionally set
  m.enable_check_replay_result = m.enabled and self.args.ci_test
which overrode the value we set in train_async.py from
--enable-r3-correctness-check. The flag was effectively a no-op.

This change keeps backward-compat for --ci-test and ALSO honors
--enable-r3-correctness-check on its own, so callers can enable the R3
overlap check without enabling the rest of --ci-test's invariants. In
particular --ci-test also enables a strict log_probs == ref_log_probs
equality check that trips on routine floating-point precision
differences (~1e-3 gap), so R3 callers need a way to opt into ONLY the
replay check.

Found during R3 E2E pre-merge validation: with --ci-test on, the R3
overlap check fired cleanly (1976+ checks, all mismatch=0%) but the
job then failed at the unrelated log_probs assertion before the
backward pass. With --enable-r3-correctness-check now wired through
properly, the same run reaches backward and can show
replay_backward branch evidence too.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant