Skip to content

[cuda] tolerance-based gain tie-break in best-split reductions#13

Open
maxwbuckley wants to merge 24 commits into
BelixRogner:masterfrom
maxwbuckley:cuda/gain-tie-break
Open

[cuda] tolerance-based gain tie-break in best-split reductions#13
maxwbuckley wants to merge 24 commits into
BelixRogner:masterfrom
maxwbuckley:cuda/gain-tie-break

Conversation

@maxwbuckley

@maxwbuckley maxwbuckley commented May 10, 2026

Copy link
Copy Markdown
Collaborator

Summary

CPU's best-split finder scans bins ascending and uses strict >, so on a true gain plateau (multiple bins with mathematically equal split gain) the lowest-index bin wins. CUDA's parallel reduction has the same intent, but its histogram aggregation introduces ~1e-15 relative FP noise from the order of atomic adds. That noise flips which bin from the plateau has the slightly-higher numerical gain on CUDA, so the parallel argmax adopts a different bin than CPU's exact computation.

The threshold-encoding mismatch is cosmetic at round 1 — training data routes identically — but compounds through score updates and surfaces as structural tree divergence around round 3 in cases like reg_bagging and multi_dense from the parity sweep. It also affects TreeSHAP and out-of-sample prediction, which depend on the encoded threshold value rather than the routing decision.

The fix

ReduceBestGainWarp and ReduceBestGainBlock in cuda_best_split_finder.cu now use a small relative tolerance (1e-12, ~5000× fp64 epsilon) when comparing gains. Two gains within the tolerance band are treated as tied, and the reduction prefers the lower thread index — which corresponds to the lower bin index, matching CPU's "first bin in scan order wins" behaviour.

The tolerance is sized to absorb the observed ~1e-15 reduction-order noise without crossing any genuine gain difference seen in practice. The condition for adopting the other thread's value becomes:

const double tol = fmax(fabs(gain), fabs(other_gain)) * 1e-12;
const bool other_better = (other_gain > gain + tol)
                       || (fabs(other_gain - gain) <= tol && other_thread_index < thread_index);

Impact (CPU/CUDA parity sweep, 53 cases)

The sweep applies tolerances RAW_TOL = 1e-5 for predictions, THRESH_TOL = 1e-6 for tree thresholds, LEAF_TOL = 1e-5 for leaf values.

metric before after
Predictions exactly bit-identical (max|Δ| == 0) 8/53 8/53
Predictions within 1e-5 37/53 40/53
Tree structure within 1e-6/1e-5 tolerance 5/53 38/53
Both 5/53 38/53

The dominant effect is on tree structure: 33 more cases now have CPU and CUDA picking the same bin from the gain plateau, so the encoded threshold values agree (within THRESH_TOL). The fix does not change the bit-identical-predictions count — remaining FP-epsilon-magnitude prediction differences come from parallel-reduction noise in leaf-level grad/hess sums, which is a separate phenomenon untouched by this PR.

Three cases that previously diverged structurally now match within FP epsilon (rather than 1e-1 magnitude):

case before max|Δ| after max|Δ|
reg_bagging 3.87e-1 4.4e-16
reg_pred_contrib 4.83e-2 1.1e-16
reg_oos 4.49e-1 1.4e-17

The 13 "cosmetic threshold-encoding" cases identified in the prior parity investigation (e.g. reg_dense, reg_huber, bin_dense, multi_softmax_lr, multi_w, reg_dart, reg_goss, bin_goss, rank_xendcg and others) all move from "predictions match at FP epsilon, trees encode different thresholds" to "predictions still match at FP epsilon, trees match within THRESH_TOL."

The remaining 13 prediction-divergent cases have unrelated root causes (weighted percentile open from prior PRs, quantized training covered by Felix's PR #1, LambdaRank atomicAdd_block ordering, hybrid CPU objectives like gamma/tweedie/fair).

Test plan

  • Parametrized regression test added in tests/python_package_test/test_dual.py::test_cuda_split_gain_tie_break_matches_cpu covering bagging, plain dense, max_depth, and lambda_l2 configurations. Each was known to produce a different bin from a gain plateau prior to this fix.
  • Test gated on TASK=cuda per the existing convention.
  • All 4 cases pass on RTX 5090 / CUDA 13.2 / sm_120.
  • No regressions on the 5 cases that were already clean.
  • Verified that 13 cosmetic-threshold cases from the broader parity sweep now produce trees that agree within THRESH_TOL.

🤖 Generated with Claude Code

maxwbuckley and others added 21 commits May 10, 2026 11:40
CUDA 13.0 removed offline-compilation support for Maxwell (sm_50/52/53),
Pascal (sm_60/61/62), and Volta (sm_70/72). With nvcc 13.x, the
unconditional inclusion of sm_60/61/62/70 in CUDA_ARCHS causes the
build to fail with:

    nvcc fatal : Unsupported gpu architecture 'compute_60'

Gate those architectures behind a CUDAToolkit_VERSION VERSION_LESS
"13.0" check. With CUDA >= 13.0 the initial list starts at "75"
(Turing); the existing version-conditional appends below add 80, 86,
87, 89, 90, 100, 120 as appropriate.

Verified locally with CUDA 13.2 + RTX 5090 (sm_120): builds and
installs cleanly without any other changes.

Reference for the dropped capabilities:
https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
cpplint's --root=.. workaround derived the expected header-guard prefix
from the parent directory name. After renaming the repo to ExaBoost,
that prefix changed from LIGHTGBM_INCLUDE_*_H_ to EXABOOST_INCLUDE_*_H_
and every header now fails build/header_guard.

We deliberately did not rename the C/C++ symbols (still LightGBM,
LGBM_*, import lightgbm) to keep ExaBoost binary-compatible. Disable
the header-guard check in the cpplint pre-commit hook to match the
existing setup in .ci/lint-cpp.sh.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Builds with -DCMAKE_CUDA_ARCHITECTURES (e.g. "120-real" for a single-GPU
local iteration on RTX 5090) currently get overwritten unconditionally
by the toolkit-version-driven CUDA_ARCHS list, producing a multi-arch
build that takes much longer to compile and isn't what the user asked
for.

Wrap the existing toolkit-version logic in a check that only applies it
when CMAKE_CUDA_ARCHITECTURES is unset or empty. When the user passes
it explicitly, use their value verbatim.

No behavior change for users who don't pass the flag.

Composes with lightgbm-org#5 (the toolkit-version gating for CUDA 13.x dropped
archs) — both branches together give a sane default that adapts to the
toolkit, plus an escape hatch for fast local iteration.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
CPU's best-split finder scans bins ascending and uses strict >, so on a
true gain plateau (multiple bins with mathematically equal split gain)
the lowest-index bin wins. CUDA's parallel reduction has the same
intent, but its histogram aggregation introduces ~1e-15 relative FP
noise from the order of atomic adds. That noise flips which bin from
the plateau has the slightly-higher numerical gain on CUDA, so the
parallel argmax adopts a different bin than CPU's exact computation.

The threshold-encoding mismatch is cosmetic at round 1 — training data
routes identically — but compounds through score updates and surfaces
as structural tree divergence around round 3 in cases like reg_bagging
and multi_dense. It also affects TreeSHAP and out-of-sample prediction,
which depend on the encoded threshold value rather than the routing.

Add a small tolerance band (1e-12 relative, ~5000x fp64 epsilon) to
the bin-level reductions ReduceBestGainWarp / ReduceBestGainBlock:
when |g1 - g2| <= tol, treat as tied and prefer the lower thread
index (= lower bin index). Sized to absorb the ~1e-15 reduction noise
without crossing any genuine gain difference seen in practice.

Impact on the broader CPU/CUDA parity sweep (53 cases):
- 5/53 clean -> 38/53 clean
- All 13 cosmetic threshold-encoding cases now produce bit-identical
  trees (previously matched at fp epsilon for predictions but encoded
  different thresholds)
- reg_bagging max|Δ| drops from 0.39 to 4.4e-16 (round-3 structural
  divergence eliminated)
- reg_pred_contrib (TreeSHAP) drops from 0.048 to 1.1e-16
- reg_oos (out-of-sample predict) drops from 0.45 to 1.4e-17

Adds a parametrized regression test in test_dual.py covering bagging,
plain dense, max_depth, and lambda_l2 configurations. All four were
known to produce different bins from a gain plateau prior to this fix.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
sklearn>=1.9 dev routes check_classification_targets and LabelEncoder
through narwhals, which raises TypeError on a bare pyarrow Array /
ChunkedArray ("Please set `allow_series=True` or `series_only=True`")
because sklearn does not pass that flag. The Python - latest versions
(manylinux_2_28) CI job has been failing for 18 test variants of
test_classification_and_regression_minimally_work_with_all_accepted_data_types
on every PR for this reason.

We advertise pyarrow Array / ChunkedArray as accepted label types
(_LGBM_LabelType), so the user-facing contract should be preserved.
Convert eagerly to numpy at the top of LGBMClassifier.fit, before
calling into sklearn — _LGBMAssertAllFinite, _LGBMCheckClassificationTargets,
and _LGBMLabelEncoder all see a familiar 1-D array.

No behavior change for non-pyarrow y. Regression tests (LGBMRegressor)
don't hit this path because they don't call check_classification_targets;
they were already passing.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Compares matched configurations on CPU and CUDA at tight tolerance
(1e-5 raw_score, exact tree structure). Initial run on 19 tiny configs
finds 6 with real prediction divergence (reg_quantile, reg_categorical,
reg_l1, reg_bagging, reg_max_depth, multi_dense) and 13 where
predictions match at FP epsilon despite tree-dump threshold differences.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The CUDA categorical split-finder kernels accepted min_data_per_group as
a function parameter but never referenced it in the function body, so
the constraint had zero effect on CUDA training. CPU correctly enforces
it via FindBestThresholdCategoricalInner in feature_histogram.cpp.

Add the missing left/right count check to the candidate-acceptance
condition in both the shared-memory and global-memory variants of the
categorical kernel, in both the left-to-right and right-to-left scans.

Verified with scratch/probe_categorical3.py: across min_data_per_group
values from 1 to 1,000,000, CPU and CUDA now produce identical splits
or both correctly decline to split. Also closes the reg_categorical
case in the broader CPU/CUDA parity sweep.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Felix asked for a real CI-runnable regression test that locks in the
categorical-kernel fix. Mirrors the scratch/probe_categorical3.py probe:
on a 200-row, 5-category dataset (~40 rows per group), train one round
on CPU and CUDA at min_data_per_group in {10, 41, 100, 1000} and assert
both produce the same split decision.

Before the fix, CUDA accepted the split at mdpg in {100, 1000, 1_000_000}
while CPU correctly refused; the assertion (None, None) != (0, 44.910)
trips loudly.

Gated on TASK=cuda to match the existing CUDA-only test pattern in
test_engine.py.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The reg_categorical case now lives in test_engine.py as a real
regression test, so the dev-only parity script no longer needs
to ship in the production tree. Removing it also clears the lint
errors (T201 print, F841 unused var) that were blocking CI.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The CUDA PercentileDevice (used by L1 and quantile leaf-value renewal)
computed the percentile position against `len` instead of `len - 1`,
and indexed it as 0-based instead of CPU's 1-based-with-+1 offset.
For alpha=0.5 (median), this returned the upper-middle element on
even-length arrays and the average of the upper-middle and median on
odd-length arrays - i.e., systematically biased upward in the
descending-sort convention that PercentileDevice uses.

CPU PercentileFun (src/objective/regression_objective.hpp:28-29):

    const double float_pos = static_cast<double>(cnt_data - 1) * (1.0 - alpha);
    const data_size_t pos = static_cast<data_size_t>(float_pos) + 1;
    ...
    const double bias = float_pos - (pos - 1);

This matches the standard Type-7 interpolated quantile (numpy.median,
R's quantile() default).

Verified against numpy:
  reg_l1     leaf-value max delta vs np.median:    0.5 -> 0.0 (after fix)
  reg_quantile leaf-value max delta vs np.quantile: 0.6 -> 0.0 (after fix)

After this fix every leaf in the parity benchmark reproducer matches
its numpy counterpart to FP epsilon. There is a residual structural
divergence on reg_l1 (CPU and CUDA disagree on a few splits) which
will be investigated separately - this PR fixes only the leaf-value
calculation.

The weighted-percentile path uses different conventions on CPU and
CUDA (ascending vs descending sort, alpha vs 1-alpha threshold) and
is left untouched here. None of our parity tests exercise it.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Regression coverage for the unweighted PercentileDevice formula fix
(prior commit). Three parametrized tests, all gated on
LIGHTGBM_TEST_CUDA=1 so they only run on a CUDA-enabled build:

- test_cuda_l1_leaf_renewal_matches_numpy_median: across 3 random
  seeds, asserts every leaf value on both CPU and CUDA matches
  numpy.median over the leaf's data points.
- test_cuda_quantile_leaf_renewal_matches_numpy_quantile: same shape
  but parametrized over alpha = 0.1, 0.25, 0.5, 0.7, 0.9 to cover
  every even/odd leaf-size combination of the percentile bias.
- test_cuda_l1_median_handles_small_even_and_odd_leaves: targets the
  exact failure mode of the old formula (even-length leaves returned
  sorted[1] instead of avg(sorted[1], sorted[2])) by sweeping leaves
  of size 2, 3, 4, 5, 8, 9.

Tolerance is 1e-6 - well below the ~0.3 bias the old formula
produced, but loose enough to absorb label_t float32 quantization
inside the renewal kernel.

Verified: with the prior commit reverted, 13 of 14 cases fail with
bias > 1e-6; with the fix applied, all 14 pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Aligns with the existing convention used by test_engine.py's CUDA-only
tests. Addresses Felix's review note (same change going on lightgbm-org#7/lightgbm-org#8/lightgbm-org#9/lightgbm-org#10).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Same bug as PR lightgbm-org#6 fixed for the in-block PercentileDevice, but in the
global-memory kernel used for init-score computation. The unweighted
branch of PercentileGlobalKernel computed the percentile position
against `len` instead of `len - 1`, biasing alpha=0.5 toward the
upper-middle element on descending-sort layouts.

Reproducer (with the Python wrapper's optimization that drops uniform
weights, this is the path actually executed by `objective=regression_l1`
or `quantile` when sample weights aren't supplied or are all 1):

  y = [1, 2, 3, 4, 5]
  init_score (numpy median): 3.0
  CPU init_score:            3.0  (correct)
  CUDA init_score (before):  3.5  (biased toward upper)
  CUDA init_score (after):   3.0  (correct)

This fix mirrors PR lightgbm-org#6 in PercentileDevice and uses the same Type-7
interpolated-quantile formula:

  float_pos = (1 - alpha) * (len - 1)
  pos       = floor(float_pos) + 1
  bias      = float_pos - (pos - 1)

Parity-sweep impact:

  reg_l1     max|Δ|: 0.25  -> 0.000e+00
  reg_quantile max|Δ|: 0.54  -> 0.000e+00

The weighted branch of PercentileGlobalKernel uses different
conventions and is not touched by this PR. There appears to be an
unrelated bug in the CPU `WeightedPercentileFun` macro (off-by-one in
which cdf delta is used in the interpolation), but that affects only
non-uniform-weight workloads and is out of scope here - the Python
wrapper drops uniform weights, so this PR's unweighted-formula fix
already covers the common path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Regression coverage for the prior commit. 24 parametrized cases
across (objective, alpha, n) verifying the init score logged by
'Start training from score' matches between CPU and CUDA at FP epsilon.

Without the fix, regression_l1 (alpha=0.5) and quantile failed for
small n where the formula bias landed on a different element.

Gated on LIGHTGBM_TEST_CUDA=1.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Aligns with the existing convention used by test_engine.py's CUDA-only
tests. Addresses Felix's review note (same change going on lightgbm-org#6/lightgbm-org#7/lightgbm-org#8/lightgbm-org#10).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@BelixRogner

Copy link
Copy Markdown
Owner

Thanks Max — and Claude Code. Sharp diagnosis: CPU's sequential scan with strict > gives a deterministic "first bin in scan order wins" on a true gain plateau, and CUDA's parallel reduction picks up ~1e-15 ULP noise from atomic-add order in histogram aggregation that flips which bin from the plateau has the slightly-higher numerical gain. Cosmetic at round 1, but compounds via score updates into structural tree divergence by round 3 — and affects TreeSHAP / OOS predictions since they encode the threshold itself.

The fix is correct in structure: OtherIsBetterWithTieBreak compares with relative 1e-12 tolerance (~5000× fp64 epsilon, well above the ~1e-15 noise floor, well below any plausible real gain difference) and tie-breaks to the lower thread index. With the monotone thread→bin mapping in this reduction, that recovers CPU's "first bin wins" semantics. Applied symmetrically to ReduceBestGainWarp and ReduceBestGainBlock.

Impact (tree-structure match 5/53 → 38/53, three structurally-divergent cases drop from 0.05–0.4 max|Δ| to fp64-epsilon) is the right kind of result.

Tests cover the failure modes; gated correctly on TASK=cuda. Comment block on the tolerance choice will help future maintainers.

Merging once CI lint completes (still in flight at time of comment).

@BelixRogner BelixRogner marked this pull request as ready for review May 19, 2026 04:42
@BelixRogner

Copy link
Copy Markdown
Owner

Quick rebase nudge — #7 and #8 just landed on master and touched files this PR also modifies (cuda_algorithms.hpp / cuda_best_split_finder.cu / etc.), so this branch now shows a merge conflict on GitHub. One more git merge master && git push round should clear it; CI is otherwise green (the apparent failures yesterday were all environmental — dask socket flakes, a cancelled job rolling up, and a Boost-headers wheel-build issue, none of which were touching the actual PR content).

Ready to merge as soon as the conflict's resolved.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants