Add fit_predict / fit_predict_proba to TabularCloudPredictor #253

Draft
shchur wants to merge 1 commit into autogluon:master from shchur:tabular-fit-predict

@shchur shchur commented Jun 26, 2026

Collaborator

Fuse fit and batch predict into a single SageMaker training job: fit on train_data, then predict on a separate test_data inside the training container (predict_after_fit=True), mirroring the existing TimeSeries fit_predict path. This avoids the second cold start, the second data upload, and the predictor-tarball round-trip that a separate batch-transform job would incur.

  • train.py: branch the predict_after_fit block on predictor_type; tabular reads a new test_data channel and writes the full [pred, proba...] frame (regression: just [pred]). Raises NotImplementedError on image columns.
  • TabularSagemakerBackend: fit override that validates, client-side and before launch, that test_data covers the train feature columns, then injects it as a data channel.
  • TabularCloudPredictor: fit_predict -> pd.Series and fit_predict_proba(include_predict) -> (pred, proba) | proba, matching the predict / predict_proba split. With wait=False both return None; fetch the results later via get_fit_predict_results / get_fit_predict_proba_results.
  • Tests: pure-unit wiring/validation tests plus a parametrized classification/regression integration test.
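
For readers skimming the bullets above, here is a minimal, dependency-free sketch of two of the ideas: the client-side check that test_data covers the train feature columns, and the split of the `[pred, proba...]` frame into the `fit_predict_proba` return shapes. It is not the code in this PR. The helper names (`check_test_covers_train`, `split_output`), the dict-of-lists stand-in for a DataFrame, the `include_predict=True` default, and the assumption that the prediction is the first column are all hypothetical and chosen only for illustration.

```python
from typing import Dict, Iterable, List

# Hypothetical stand-in for a DataFrame: column name -> list of values.
Frame = Dict[str, List]


def check_test_covers_train(
    train_columns: Iterable[str], test_columns: Iterable[str], label: str
) -> None:
    """Raise ValueError if test data lacks any train feature column.

    The label column is skipped: test data does not need to contain it.
    Extra columns in test data are tolerated in this sketch.
    """
    test_set = set(test_columns)
    missing = [c for c in train_columns if c != label and c not in test_set]
    if missing:
        raise ValueError(f"test_data is missing feature columns: {missing}")


def split_output(frame: Frame, include_predict: bool = True):
    """Split a [pred, proba...] frame into its two parts.

    Assumes (for illustration) that the first column holds the prediction
    and every remaining column holds one class probability.
    """
    columns = list(frame)  # dicts preserve insertion order
    pred = frame[columns[0]]
    proba = {c: frame[c] for c in columns[1:]}
    if include_predict:
        return pred, proba
    return proba


if __name__ == "__main__":
    # Passes silently: every train feature ("a", "b") is present in test.
    check_test_covers_train(["a", "b", "y"], ["a", "b"], label="y")

    frame = {"pred": [1, 0], "0": [0.2, 0.9], "1": [0.8, 0.1]}
    pred, proba = split_output(frame)
    print(pred, sorted(proba))
```

Running the check before the job launches means a column mismatch fails in seconds on the client rather than after the training container has already started.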

Issue #, if available:

Description of changes:

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

@shchur shchur marked this pull request as draft June 26, 2026 19:25