Skip to content

[cuda] implement max_delta_step output cap on the CUDA tree learner#20

Open
maxwbuckley wants to merge 36 commits into
BelixRogner:masterfrom
maxwbuckley:cuda/max-delta-step-error
Open

[cuda] implement max_delta_step output cap on the CUDA tree learner#20
maxwbuckley wants to merge 36 commits into
BelixRogner:masterfrom
maxwbuckley:cuda/max-delta-step-error

Conversation

@maxwbuckley

@maxwbuckley maxwbuckley commented May 31, 2026

Copy link
Copy Markdown
Collaborator

Problem: the delta

When training with device_type="cuda", max_delta_step was silently ignored. The leaf-output cap

if (max_delta_step > 0 && std::fabs(ret) > max_delta_step) {
  ret = Common::Sign(ret) * max_delta_step;
}

exists only in the CPU FeatureHistogram::CalculateSplittedLeafOutput (the USE_MAX_OUTPUT path). The CUDA leaf-output device functions had no max_delta_step parameter and no cap.

Measured delta before this PR (400×6 synthetic data, learning_rate=1.0, num_leaves=15, 5 rounds, gpu_use_dp=true; tree-0 leaf-value spread should be ≤ 2*max_delta_step):

objective max_delta_step leaf spread CPU leaf spread CUDA cap (2·mds) max pred delta
binary 0.05 0.100 4.150 0.1 17.7
binary 0.1 0.200 4.150 0.2 17.7
binary 0.5 1.000 4.150 1.0 18.9
regression 0.05 0.100 4.017 0.1 2.48
regression 0.1 0.200 4.017 0.2 2.23
regression 0.5 1.000 4.017 1.0 0.66

CPU enforces the cap; CUDA's leaves are unbounded — predictions diverge by up to 17.7 (raw score).

Fix

Implements the cap on the CUDA tree learner, identical to the CPU formula, threaded as a runtime double (0 = inactive) through:

  • CUDALeafSplits::CalculateSplittedLeafOutput / GetLeafGain / GetSplitGains (the cap sits between the L1 step and the smoothing step, exactly like CPU; the closed-form gain shortcut is bypassed when the cap is active, matching CPU's !USE_MAX_OUTPUT && !USE_SMOOTHING condition)
  • the root-leaf init kernels (cuda_leaf_splits.cu)
  • every best-split-finder kernel — numerical, categorical, global-memory, and discretized variants (cuda_best_split_finder.cu)
  • the root SetLeafOutput and the refit kernels (cuda_single_gpu_tree_learner.cpp/.cu)

The first commit's Log::Fatal guard is removed: max_delta_step now works on CUDA instead of being rejected.

Result: the delta after the fix

Same measurement, post-fix:

objective max_delta_step leaf spread CPU leaf spread CUDA cap (2·mds) max pred delta
binary 0.05 0.100 0.100 0.1 0.40 †
binary 0.1 0.200 0.200 0.2 0.40 †
binary 0.5 1.000 1.000 1.0 1.27 †
regression 0.05 0.100 0.100 0.1 0.15 †
regression 0.1 0.200 0.200 0.2 0.18 †
regression 0.5 1.000 1.000 1.0 4.4e-16

And in the non-degenerate regime (regression, learning_rate=0.1, 10 rounds, min_data_in_leaf=5):

max_delta_step max pred delta
1.0 0.0 (bit-identical)
2.0 3.3e-16
5.0 4.4e-16

Why some configs still differ structurally: with learning_rate=1.0 + min_data_in_leaf=1 + a tiny cap, every leaf saturates at ±max_delta_step, so all candidate split gains collapse onto a plateau of equal values. CPU and CUDA then pick different — but equally optimal — splits from the plateau (ULP-level FP tie-breaking, the known lightgbm-org#6055 family; same root cause PR #13 addresses). The cap itself is enforced identically (spread columns match exactly), and training loss agrees within 0.6% in those configs. This is a property of gain plateaus, not of the max_delta_step implementation.

Tests (in test_dual.py, gated on TASK=cuda)

Test Cases What it pins down
test_cuda_max_delta_step_caps_outputs_like_cpu 6 leaf spread ≤ 2·mds on both backends, spreads equal — fails on the old build (CUDA spread 4.15 vs cap 0.1)
test_cuda_max_delta_step_matches_cpu_exactly 3 bit-level prediction parity (atol=1e-10) in non-saturating configs
test_cuda_max_delta_step_loss_matches_cpu_when_saturated 2 cap enforced on every tree + training loss within 2% in the plateau regime

All 11 cases pass with the fix; the cap-enforcement test fails on the unpatched build. Full test_dual.py suite (65 tests) passes.

🤖 Generated with Claude Code

maxwbuckley and others added 30 commits May 10, 2026 01:44
Two related bugs caused CUDA to ignore the `max_depth` parameter:

1. CUDABestSplitFinder::FindBestSplitsForLeaf had no max_depth check.
   CPU's SerialTreeLearner::BeforeFindBestSplit invalidates a leaf's
   gain when its depth has reached config_->max_depth, but the CUDA
   path never did the equivalent.

2. CUDATree::Split / SplitCategorical updated the GPU-side
   cuda_leaf_depth_ via the launch kernel but never updated the
   host-side leaf_depth_ vector, so tree->leaf_depth(idx) always
   returned 0 on CUDA. Without (2), even adding the check at (1)
   would have done nothing.

Symptom (max_depth=2, varying num_leaves):

  num_leaves= 4: cpu depth=2 leaves=4 | cuda depth=2 leaves=4
  num_leaves= 7: cpu depth=2 leaves=4 | cuda depth=3 leaves=7
  num_leaves=15: cpu depth=2 leaves=4 | cuda depth=5 leaves=15
  num_leaves=31: cpu depth=2 leaves=4 | cuda depth=7 leaves=31

After fix, CUDA caps at the requested depth (2) for every num_leaves.

Fix:

* Mirror the host-side leaf_depth_ update in CUDATree::Split and
  CUDATree::SplitCategorical (matching CPU Tree::Split's behavior in
  include/LightGBM/tree.h).
* Plumb a `smaller_leaf_below_max_depth` / `larger_leaf_below_max_depth`
  flag pair into FindBestSplitsForLeaf and AND them into the
  is_*_leaf_valid checks. The caller in
  cuda_single_gpu_tree_learner.cpp computes them as
  `config_->max_depth <= 0 || tree->leaf_depth(idx) < config_->max_depth`.

Verified with the cpu/cuda parity sweep: reg_max_depth case (which used
max_depth=3 with num_leaves=7) now matches CPU at FP epsilon, down from
max|Δ|=0.25 raw_score.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Regression coverage for the prior commit (CUDA tree learner now
enforces max_depth). Two parametrized tests, gated on
LIGHTGBM_TEST_CUDA=1:

- test_cuda_respects_max_depth: across (max_depth, num_leaves)
  combinations from {1,2,3,5} x {2,4,7,31}, asserts CUDA tree depth
  is at most max_depth and matches CPU depth + leaf count exactly.
- test_cuda_max_depth_matches_cpu_predictions: end-to-end check
  that 5 boosting rounds with max_depth=3 produce CPU/CUDA
  predictions matching at FP epsilon. Without the fix, this
  diverged by max|Δ|=0.47.

Verified: with the prior commit reverted, 5 of 9 cases fail
(those where num_leaves > 2^max_depth, i.e. where the bug actually
triggered). With the fix applied, all 9 pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
CUDA 13.0 removed offline-compilation support for Maxwell (sm_50/52/53),
Pascal (sm_60/61/62), and Volta (sm_70/72). With nvcc 13.x, the
unconditional inclusion of sm_60/61/62/70 in CUDA_ARCHS causes the
build to fail with:

    nvcc fatal : Unsupported gpu architecture 'compute_60'

Gate those architectures behind a CUDAToolkit_VERSION VERSION_LESS
"13.0" check. With CUDA >= 13.0 the initial list starts at "75"
(Turing); the existing version-conditional appends below add 80, 86,
87, 89, 90, 100, 120 as appropriate.

Verified locally with CUDA 13.2 + RTX 5090 (sm_120): builds and
installs cleanly without any other changes.

Reference for the dropped capabilities:
https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
cpplint's --root=.. workaround derived the expected header-guard prefix
from the parent directory name. After renaming the repo to ExaBoost,
that prefix changed from LIGHTGBM_INCLUDE_*_H_ to EXABOOST_INCLUDE_*_H_
and every header now fails build/header_guard.

We deliberately did not rename the C/C++ symbols (still LightGBM,
LGBM_*, import lightgbm) to keep ExaBoost binary-compatible. Disable
the header-guard check in the cpplint pre-commit hook to match the
existing setup in .ci/lint-cpp.sh.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Builds with -DCMAKE_CUDA_ARCHITECTURES (e.g. "120-real" for a single-GPU
local iteration on RTX 5090) currently get overwritten unconditionally
by the toolkit-version-driven CUDA_ARCHS list, producing a multi-arch
build that takes much longer to compile and isn't what the user asked
for.

Wrap the existing toolkit-version logic in a check that only applies it
when CMAKE_CUDA_ARCHITECTURES is unset or empty. When the user passes
it explicitly, use their value verbatim.

No behavior change for users who don't pass the flag.

Composes with lightgbm-org#5 (the toolkit-version gating for CUDA 13.x dropped
archs) — both branches together give a sane default that adapts to the
toolkit, plus an escape hatch for fast local iteration.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Aligns with the existing convention used by test_engine.py's CUDA-only
tests. Addresses Felix's review note (same change going on lightgbm-org#6/lightgbm-org#8/lightgbm-org#9/lightgbm-org#10).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
sklearn>=1.9 dev routes check_classification_targets and LabelEncoder
through narwhals, which raises TypeError on a bare pyarrow Array /
ChunkedArray ("Please set `allow_series=True` or `series_only=True`")
because sklearn does not pass that flag. The Python - latest versions
(manylinux_2_28) CI job has been failing for 18 test variants of
test_classification_and_regression_minimally_work_with_all_accepted_data_types
on every PR for this reason.

We advertise pyarrow Array / ChunkedArray as accepted label types
(_LGBM_LabelType), so the user-facing contract should be preserved.
Convert eagerly to numpy at the top of LGBMClassifier.fit, before
calling into sklearn — _LGBMAssertAllFinite, _LGBMCheckClassificationTargets,
and _LGBMLabelEncoder all see a familiar 1-D array.

No behavior change for non-pyarrow y. Regression tests (LGBMRegressor)
don't hit this path because they don't call check_classification_targets;
they were already passing.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Compares matched configurations on CPU and CUDA at tight tolerance
(1e-5 raw_score, exact tree structure). Initial run on 19 tiny configs
finds 6 with real prediction divergence (reg_quantile, reg_categorical,
reg_l1, reg_bagging, reg_max_depth, multi_dense) and 13 where
predictions match at FP epsilon despite tree-dump threshold differences.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The CUDA categorical split-finder kernels accepted min_data_per_group as
a function parameter but never referenced it in the function body, so
the constraint had zero effect on CUDA training. CPU correctly enforces
it via FindBestThresholdCategoricalInner in feature_histogram.cpp.

Add the missing left/right count check to the candidate-acceptance
condition in both the shared-memory and global-memory variants of the
categorical kernel, in both the left-to-right and right-to-left scans.

Verified with scratch/probe_categorical3.py: across min_data_per_group
values from 1 to 1,000,000, CPU and CUDA now produce identical splits
or both correctly decline to split. Also closes the reg_categorical
case in the broader CPU/CUDA parity sweep.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Felix asked for a real CI-runnable regression test that locks in the
categorical-kernel fix. Mirrors the scratch/probe_categorical3.py probe:
on a 200-row, 5-category dataset (~40 rows per group), train one round
on CPU and CUDA at min_data_per_group in {10, 41, 100, 1000} and assert
both produce the same split decision.

Before the fix, CUDA accepted the split at mdpg in {100, 1000, 1_000_000}
while CPU correctly refused; the assertion (None, None) != (0, 44.910)
trips loudly.

Gated on TASK=cuda to match the existing CUDA-only test pattern in
test_engine.py.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The reg_categorical case now lives in test_engine.py as a real
regression test, so the dev-only parity script no longer needs
to ship in the production tree. Removing it also clears the lint
errors (T201 print, F841 unused var) that were blocking CI.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The CUDA PercentileDevice (used by L1 and quantile leaf-value renewal)
computed the percentile position against `len` instead of `len - 1`,
and indexed it as 0-based instead of CPU's 1-based-with-+1 offset.
For alpha=0.5 (median), this returned the upper-middle element on
even-length arrays and the average of the upper-middle and median on
odd-length arrays - i.e., systematically biased upward in the
descending-sort convention that PercentileDevice uses.

CPU PercentileFun (src/objective/regression_objective.hpp:28-29):

    const double float_pos = static_cast<double>(cnt_data - 1) * (1.0 - alpha);
    const data_size_t pos = static_cast<data_size_t>(float_pos) + 1;
    ...
    const double bias = float_pos - (pos - 1);

This matches the standard Type-7 interpolated quantile (numpy.median,
R's quantile() default).

Verified against numpy:
  reg_l1     leaf-value max delta vs np.median:    0.5 -> 0.0 (after fix)
  reg_quantile leaf-value max delta vs np.quantile: 0.6 -> 0.0 (after fix)

After this fix every leaf in the parity benchmark reproducer matches
its numpy counterpart to FP epsilon. There is a residual structural
divergence on reg_l1 (CPU and CUDA disagree on a few splits) which
will be investigated separately - this PR fixes only the leaf-value
calculation.

The weighted-percentile path uses different conventions on CPU and
CUDA (ascending vs descending sort, alpha vs 1-alpha threshold) and
is left untouched here. None of our parity tests exercise it.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Regression coverage for the unweighted PercentileDevice formula fix
(prior commit). Three parametrized tests, all gated on
LIGHTGBM_TEST_CUDA=1 so they only run on a CUDA-enabled build:

- test_cuda_l1_leaf_renewal_matches_numpy_median: across 3 random
  seeds, asserts every leaf value on both CPU and CUDA matches
  numpy.median over the leaf's data points.
- test_cuda_quantile_leaf_renewal_matches_numpy_quantile: same shape
  but parametrized over alpha = 0.1, 0.25, 0.5, 0.7, 0.9 to cover
  every even/odd leaf-size combination of the percentile bias.
- test_cuda_l1_median_handles_small_even_and_odd_leaves: targets the
  exact failure mode of the old formula (even-length leaves returned
  sorted[1] instead of avg(sorted[1], sorted[2])) by sweeping leaves
  of size 2, 3, 4, 5, 8, 9.

Tolerance is 1e-6 - well below the ~0.3 bias the old formula
produced, but loose enough to absorb label_t float32 quantization
inside the renewal kernel.

Verified: with the prior commit reverted, 13 of 14 cases fail with
bias > 1e-6; with the fix applied, all 14 pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Aligns with the existing convention used by test_engine.py's CUDA-only
tests. Addresses Felix's review note (same change going on lightgbm-org#7/lightgbm-org#8/lightgbm-org#9/lightgbm-org#10).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Same bug as PR lightgbm-org#6 fixed for the in-block PercentileDevice, but in the
global-memory kernel used for init-score computation. The unweighted
branch of PercentileGlobalKernel computed the percentile position
against `len` instead of `len - 1`, biasing alpha=0.5 toward the
upper-middle element on descending-sort layouts.

Reproducer (with the Python wrapper's optimization that drops uniform
weights, this is the path actually executed by `objective=regression_l1`
or `quantile` when sample weights aren't supplied or are all 1):

  y = [1, 2, 3, 4, 5]
  init_score (numpy median): 3.0
  CPU init_score:            3.0  (correct)
  CUDA init_score (before):  3.5  (biased toward upper)
  CUDA init_score (after):   3.0  (correct)

This fix mirrors PR lightgbm-org#6 in PercentileDevice and uses the same Type-7
interpolated-quantile formula:

  float_pos = (1 - alpha) * (len - 1)
  pos       = floor(float_pos) + 1
  bias      = float_pos - (pos - 1)

Parity-sweep impact:

  reg_l1     max|Δ|: 0.25  -> 0.000e+00
  reg_quantile max|Δ|: 0.54  -> 0.000e+00

The weighted branch of PercentileGlobalKernel uses different
conventions and is not touched by this PR. There appears to be an
unrelated bug in the CPU `WeightedPercentileFun` macro (off-by-one in
which cdf delta is used in the interpolation), but that affects only
non-uniform-weight workloads and is out of scope here - the Python
wrapper drops uniform weights, so this PR's unweighted-formula fix
already covers the common path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Regression coverage for the prior commit. 24 parametrized cases
across (objective, alpha, n) verifying the init score logged by
'Start training from score' matches between CPU and CUDA at FP epsilon.

Without the fix, regression_l1 (alpha=0.5) and quantile failed for
small n where the formula bias landed on a different element.

Gated on LIGHTGBM_TEST_CUDA=1.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Aligns with the existing convention used by test_engine.py's CUDA-only
tests. Addresses Felix's review note (same change going on lightgbm-org#6/lightgbm-org#7/lightgbm-org#8/lightgbm-org#10).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…PT006 fix

Squashes four local iterations: drop the prediction-parity test and
num_leaves parity assertion (keep only the two depth assertions), drop
redundant objective=regression (default value), use tuple for
parametrize argnames (ruff PT006), and shrink fixture to n=64 / 4
features / min_data_in_leaf=1 — cuts runtime ~6x.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
# Conflicts:
#	tests/python_package_test/test_dual.py
The function had two related bugs:

1. shared_buffer is declared __shared__ REDUCE_VAL_T shared_buffer[WARPSIZE]
   (32 entries), but the line `const REDUCE_VAL_T thread_base =
   shared_buffer[threadIdx.x]` reads at threadIdx.x in [0, blockDim.x).
   When blockDim.x > WARPSIZE (e.g. 256 for the L1/quantile renewal
   kernels), threadIdx.x in [WARPSIZE, blockDim.x) reads out-of-bounds
   shared memory.

2. The loop body `out_values[index] = thread_base + in_values[...]`
   does not cumulate within the per-thread chunk. It is correct only
   when num_data_per_thread == 1.

Together these manifest as an "illegal memory access" crash on weighted
L1 / weighted quantile training with n >= ~100 samples. Symptom:

    [LightGBM] [Fatal] [CUDA] an illegal memory access was encountered
    .../cuda_regression_objective.cu 225 (SynchronizeCUDADevice after
    RenewTreeOutputCUDAKernel_RegressionL1<USE_WEIGHT=true>)

Fix: use the per-thread exclusive prefix sum already returned by
ShufflePrefixSumExclusive (matching the existing correct usage in
GlobalMemoryPrefixSum at line 183), and cumulate inclusively across
the chunk.

Verified: weighted L1 and weighted quantile now train successfully on
n in {100, 200, 500, 1000} on RTX 5090 / CUDA 13.2. Predictions match
CPU within the typical L1/quantile FP-precision range.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Verifies CUDA weighted L1 / weighted quantile training does not raise
"illegal memory access" for n in {100, 200, 500, 1000}. Without the
prior fix, these all crashed in ShuffleSortedPrefixSumDevice.

Gated on LIGHTGBM_TEST_CUDA=1.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Aligns with the existing convention used by test_engine.py's CUDA-only
tests (getenv("TASK", "") != "cuda"). Addresses Felix's review note on
PR lightgbm-org#8 (and the matching note on lightgbm-org#6, lightgbm-org#7, lightgbm-org#9, lightgbm-org#10).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…orcement

[cuda] enforce max_depth on CUDA tree learner
@maxwbuckley maxwbuckley force-pushed the cuda/max-delta-step-error branch from a118816 to 977f51b Compare May 31, 2026 23:16
The max_delta_step leaf-output clamp is implemented only in the CPU
FeatureHistogram (CalculateSplittedLeafOutput with USE_MAX_OUTPUT). The CUDA
leaf-output kernels have no equivalent term, so on CUDA the parameter was
silently ignored and leaf values / split gains diverged from CPU whenever the
clamp would bind. Fail fast in Config::CheckParamConflict until on-GPU
enforcement is implemented, matching the house convention of Log::Fatal for
unsupported CUDA features. max_delta_step defaults to 0 (inactive), so only
users who explicitly set it are affected. CPU and GPU(OpenCL) behavior is
unchanged.

Adds a parametrized regression test in test_dual.py asserting CPU trains while
CUDA raises.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@maxwbuckley maxwbuckley force-pushed the cuda/max-delta-step-error branch from 977f51b to 0d440a1 Compare May 31, 2026 23:17
Implements the max_delta_step leaf-output cap in the CUDA tree learner,
matching the CPU FeatureHistogram USE_MAX_OUTPUT formula:

    if (max_delta_step > 0 && fabs(ret) > max_delta_step)
      ret = sign(ret) * max_delta_step;

applied between the L1-threshold step and the path-smoothing step, in:
- CUDALeafSplits::CalculateSplittedLeafOutput / GetLeafGain / GetSplitGains
  (cuda_leaf_splits.hpp), with the closed-form gain shortcut bypassed
  whenever the cap is active (matching CPU's !USE_MAX_OUTPUT && !USE_SMOOTHING
  condition)
- the root-leaf init kernels (cuda_leaf_splits.cu)
- every best-split-finder kernel: numerical, categorical, global-memory and
  discretized variants (cuda_best_split_finder.cu), threaded as a runtime
  double through the kernel signatures and launch macros
- the root SetLeafOutput and the refit kernels
  (cuda_single_gpu_tree_learner.cpp/.cu)

This replaces the previous Log::Fatal guard: max_delta_step now works on
CUDA instead of being rejected.

Before this change CUDA silently ignored max_delta_step: with
max_delta_step=0.05, CPU capped the tree-0 leaf spread at 0.10 while CUDA
produced 4.15 (unbounded), with prediction divergence up to 17.7 (binary,
raw score). After this change the cap is enforced identically on both
backends and predictions match at FP epsilon (<= 4.4e-16) in non-degenerate
configurations.

Known limitation: when a very small max_delta_step saturates every leaf,
all candidate split gains collapse to ~0 and CPU/CUDA may pick different
but equally-optimal splits from the gain plateau (FP tie-breaking, the
pre-existing lightgbm-org#6055 family). The cap itself is still enforced and training
loss matches within 0.6%; tree structure can differ. This is covered by a
dedicated loss-equivalence test rather than an exact-match test.

Adds three regression tests to test_dual.py:
- test_cuda_max_delta_step_caps_outputs_like_cpu (cap enforcement, 6 cases)
- test_cuda_max_delta_step_matches_cpu_exactly (bit parity, 3 cases)
- test_cuda_max_delta_step_loss_matches_cpu_when_saturated (plateau regime, 2 cases)

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@maxwbuckley maxwbuckley changed the title [cuda] reject max_delta_step instead of silently ignoring it [cuda] implement max_delta_step output cap on the CUDA tree learner Jun 1, 2026
@maxwbuckley maxwbuckley marked this pull request as ready for review June 1, 2026 20:36
@BelixRogner

Copy link
Copy Markdown
Owner

Thank you, Max — and thank you, Claude Code (independently 🙂). Verdict: SOLID, no bug found.

Reviewed against the CPU USE_MAX_OUTPUT cap in CalculateSplittedLeafOutput:

  • Cap formula is equivalent, and the Common::Sign(0)=0 vs CUDA ret>=0?+md:-md difference is provably unreachable — the branch is gated on fabs(ret) > max_delta_step > 0, so ret==0 never enters it. (Cosmetically a real Sign() would read more faithfully, but behavior is identical.)
  • Cap sits in the correct position: L1 step → cap → smoothing, matching CPU.
  • The closed-form gain shortcut is correctly bypassed when the cap is active — the CPU compile-time !USE_MAX_OUTPUT && !USE_SMOOTHING guard is faithfully turned into a runtime !USE_SMOOTHING && max_delta_step <= 0.0 guard.
  • Every site is threaded: leaf-output/gain functions, root-init kernels, all four best-split-finder families + their global-memory variants, root SetLeafOutput, and the refit kernels. I grepped every CalculateSplittedLeafOutput/GetLeafGain/GetSplitGains call — no site takes the param without applying it, none missed. min_gain_shift propagates cap-aware, consistent with CPU.
  • Default 0 is a true no-op (every application guarded by max_delta_step > 0).

Your gain-plateau scope note is honest: the structural divergence at learning_rate=1.0 is the expected ULP-tie signature, not a cap defect — the spread-exact and bit-parity assertions would break if it were a cap bug.

Non-blocking: no path_smooth>0 + cap, categorical, or lambda_l1>0 + cap test; core cap behavior is well covered.

Blocker is only the required ruff format + git merge master (these five overlap). macOS red = dask socket flake.

P.S. — you threaded max_delta_step through ~9 files and a dozen kernels without missing one, then capped your output at "ran the linter: no". 😄 ruff format has a step cap of exactly one command.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants