Skip to content

Adding training functionalities to Toolkit#108

Merged
laserkelvin merged 395 commits into
NVIDIA:mainfrom
laserkelvin:training-epic
Jun 26, 2026
Merged

Adding training functionalities to Toolkit#108
laserkelvin merged 395 commits into
NVIDIA:mainfrom
laserkelvin:training-epic

Conversation

@laserkelvin

@laserkelvin laserkelvin commented Jun 9, 2026

Copy link
Copy Markdown
Collaborator

ALCHEMI Toolkit Pull Request

Description

This PR introduces the core functionalities required to support training and fine-tuning of models in nvalchemi-toolkit.

Type of Change

  • Bug fix (non-breaking change that fixes an issue)
  • New feature (non-breaking change that adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Performance improvement
  • Documentation update
  • Refactoring (no functional changes)
  • CI/CD or infrastructure change

Related Issues

Changes Made

  • create_model_spec methods and dynamic pydantic model creation for pickle-less serialization of configuration
  • Adds a few base loss functions, the general loss abstraction including individual losses and a composed loss function. The latter can be adjusted with weight scheduling, allowing the relative weighting of different losses to be adjusted over the course of training
  • Adds a TrainingStrategy pydantic model as a recipe validation and loop executor. The execution is highly modular and extendible, allowing for (hopefully) arbitrarily complex training workflows to be built, and not limited to MLIPs
  • Adds a FineTuningStrategy that specializes TrainingStrategy for...fine-tuning workflows by making pre-existing checkpoints and layer addition/modification integral to the workflow
  • Adds data loading optimizations; the main changes is addition of "batched" pre-fetching, which amortizes I/O for non-contiguous data samples. This is crucial for Zarr performance when shuffling data
  • Adds multidataset support, with a "meta" sampler that allows users to implement different cross-dataset sampling strategies (e.g. to account for dataset size imbalances)
  • Adds several training-related hooks, such as model averaging, mixed precision, checkpointing
  • Adds a CLI for training and fine-tuning: the intended use of this CLI is to provide a relatively straightforward on-ramp for users looking to get fine-tune (or train a model from scratch) quickly without needing to know the full training API

Testing

  • Unit tests pass locally (make pytest)
  • Linting passes (make lint)
  • New tests added for new functionality meets coverage expectations?

Checklist

  • I have read and understand the Contributing Guidelines
  • I have updated the CHANGELOG.md
  • I have performed a self-review of my code
  • I have added docstrings to new functions/classes
  • I have updated the documentation (if applicable)

Additional Notes

Tip

This repository uses Greptile, an AI code review service, to help conduct
pull request reviews. We encourage contributors to read and consider suggestions
made by Greptile, but note that human maintainers will provide the necessary
reviews for merging: Greptile's comments are not a qualitative judgement
of your code, nor is it an indication that the PR will be accepted/rejected.
We encourage the use of emoji reactions to Greptile comments, depending on
their usefulness and accuracy.

Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
…t-loading

Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
Rename EnergyLoss -> EnergyMSELoss, ForceLoss -> ForceMSELoss,
StressLoss -> StressMSELoss for naming consistency with EnergyMAELoss
and ForceL2NormLoss.

Replace ignore_nan with ignore_nonfinite in all three MSE losses,
switching masking from isnan() to torch.isfinite() to also exclude
inf targets, matching the convention in the MAE/L2 terms.

Add missing EnergyMAELoss and ForceL2NormLoss to API docs.
… weighting

EnergyMAELoss(per_atom=True) now uses atom-count-weighted reduction,
matching EnergyMSELoss semantics: larger graphs contribute in proportion
to their atom count. Previously it used a simple mean over graphs.

Also fixes SyntaxWarning from unescaped LaTeX in the class docstring.
…tion

BaseLossFunction.forward() is now a concrete template orchestrating five
overridable hooks: validate, normalize, mask, compute_residual, reduce.
Subclasses override only what they need — at minimum compute_residual().

Add ReductionContext (dict subclass) for passing reduction metadata such
as atom-count weights between hooks. Dynamo-safe (no TypedDict).

All five leaf losses (EnergyMSELoss, EnergyMAELoss, ForceMSELoss,
ForceL2NormLoss, StressMSELoss) refactored to use the template hooks.
No behavioral changes — all 186 existing tests pass unchanged.
Add three new sections to the losses user guide:
- Example 3: custom mask override (isfinite, padded layouts)
- Example 4: custom reduce override (graph-balanced reduction)
- Layout dispatch with plum (reference to ForceMSELoss/ForceL2NormLoss)
Covers built-in loss terms, the BaseLossFunction template-method pattern,
and how to implement custom losses with normalize, mask, reduce overrides
and plum dispatch for multi-layout forces.
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
Add training strategy checkpoint restarts
New test files from training-epic (conftest, test_strategy, test_checkpoint,
test_mixed_precision, test_training_update_orchestrator, test_losses spec
tests) referenced old names EnergyLoss/ForceLoss/ignore_nan. Updated to
EnergyMSELoss/ForceMSELoss/ignore_nonfinite.
…support

Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>

# Conflicts:
#	README.md
#	docs/userguide/about/install.md
#	pyproject.toml
#	uv.lock
Add MAE energy and L2-norm force loss terms
Restructure the CLI from a single Click command to a Click group with
two subcommands:

- roundtrip: existing generate+write+read benchmark (no behavior change)
- read: benchmark read performance against a pre-existing Zarr store

The read subcommand accepts a store path and the same read-tuning
options (--read-mode, --read-order, --read-batch-size, etc.) and
reports samples/s throughput via a Rich table.

Add _run_read_benchmark and _print_read_results helpers, plus three
tests for the new functionality.
Comment thread docs/modules/dynamics/hooks.rst Outdated
Comment on lines +260 to +269
Use :class:`~nvalchemi.dynamics.hooks.StageTimingHook` for lightweight stage
timing and optional NVTX ranges.

.. code-block:: python

from nvalchemi.dynamics.hooks import ProfilerHook
from nvalchemi.dynamics.hooks import StageTimingHook

hook = ProfilerHook(enable_nvtx=True, enable_timer=True, frequency=10)
hook = StageTimingHook("step", frequency=10, log_path="stage_timing.csv")
dynamics = DemoDynamics(model=model, n_steps=1_000, dt=0.5, hooks=[hook])
dynamics.run(batch)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This example doesn't really tell me what the heck this hook is doing, what "step" refers to, what "frequency" means. We don't need full api doc here but I would expect just another sentence with sufficient exposition explaining what is going on here.

laserkelvin and others added 10 commits June 24, 2026 16:13
…eloper focus

Restructure the fine-tuning guide to follow a progressive disclosure model:
- Rewrite the intro to frame the CLI vs. API split clearly for both audiences
- Add a Fine-tuning API parent section with a simple-first ordering:
  simple full-model → modifications overview → inspect names → freeze →
  freeze mode → module patches → multi-model → checkpoints → hooks
- Demote all API subsections to ### so the heading hierarchy reflects depth
- Rewrite each section opening to flow naturally from the previous one,
  replacing terse reference-style prose with motivated narrative
- Add developer extension points throughout: programmatic pattern generation,
  progressive unfreezing via from_pretrained_checkpoint, freeze_mode as a
  gradient-hook seam, create_model_spec with custom nn.Module subclasses,
  per-model optimizer_configs for differential learning rates
- Add a multi-model fine-tuning section covering the dict-key naming
  requirement and the partial-coverage validation gap
- Collapse operational notes into targeted inline callouts; defer hook
  mechanics entirely to the hooks guide

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…er focus

Restructure the loss computation guide to match the style of the
fine-tuning documentation: motivating context before API, progressive
disclosure from simple to complex, and explicit developer extension
points.

Key changes:
- Rewrite intro to motivate the two-layer design (leaf + composition)
  as the natural answer to multi-task MLIP training objectives
- Add how-to-choose framing before the built-in losses table (MSE vs
  Huber vs MAE/L2-norm trade-offs)
- Reframe LossWeightSchedule as a named extension seam with a minimum
  viable implementation (per_epoch + __call__), then add to_spec() as
  the serialization requirement
- Add bridge sentence from schedule section to writing-your-own-loss
- Restructure "Writing your own loss" section: lead with a decision
  tree (compute_residual → normalize → mask → reduce), then show the
  minimum viable override before adding each additional hook
- Flag plum dispatch section as advanced, explicitly skippable

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…sion-seam focus

Restructure the training guide to match the style of the fine-tuning
documentation: lead with conclusions, explicit extension seams, and
section bridges that guide the reader through the lifecycle.

Key changes:
- Rewrite intro to be concrete: TrainingStrategy as a workflow engine
  with named lifecycle stages, not an abstract "flexibility" statement
- Add explicit framing of training_fn and loss_target_assembler as the
  two primary forward-pass extension seams before their detailed
  explanation
- Name TrainingUpdateHook as the named extension seam for gradient and
  optimizer customization at the start of the Optimizer Orchestration
  section, removing the redundant second introduction
- Add section bridges: setup → counters, optimizer orchestration →
  validation
- Add tip block for ValidationConfig per-batch callback as the
  developer extension point for custom evaluation metrics

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Further restructure the losses guide to consistently address developers
and ML engineers building on top of the API:

- Restructure "Writing your own loss" from numbered examples into one
  subsection per hook (compute_residual, normalize, mask, reduce,
  validate), each opening with what the base provides and when you'd
  override it — woven as prose rather than mechanical bold labels
- Consolidate "Composition weights and schedules" as nested subsections
  (### Weights, ### Weight schedules) under ## Composition, eliminating
  the duplicate operator-sugar intro and merging weight normalization
  and operator constraints into one coherent Weights subsection
- Motivate "The call signature" and "The return type" with why the
  design choices matter (keyed-mapping routing, per-component fields
  for debugging schedule behavior) before presenting the API
- Motivate "Per-sample loss diagnostics" with the use case (hard-sample
  identification, curriculum strategies) before the table
- Add an opening to "Routing errors" explaining why eager validation
  matters for debugging training_fn / loss mismatches
- Add "Bring your own schedule" as flowing prose (not bold-label format)
  with the minimum protocol implementation shown inline
- Add extension-pointer closing sentences to "Ignoring missing labels"
  and "MAE and force-L2 reductions" pointing to the relevant hook
  override sections

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>

# Conflicts:
#	docs/userguide/models.md
#	uv.lock
@laserkelvin

Copy link
Copy Markdown
Collaborator Author

/ok to test e3a04cb

Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
@laserkelvin

Copy link
Copy Markdown
Collaborator Author

/ok to test e3a04cb

@copy-pr-bot

copy-pr-bot Bot commented Jun 25, 2026

Copy link
Copy Markdown

/ok to test e3a04cb

@laserkelvin, there was an error processing your request: E2

See the following link for more information: https://docs.gha-runners.nvidia.com/cpr/e/2/

@laserkelvin

Copy link
Copy Markdown
Collaborator Author

/ok to test 4ee2131

Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
@laserkelvin

Copy link
Copy Markdown
Collaborator Author

/ok to test 18399f3

@ys-teh ys-teh mentioned this pull request Jun 25, 2026
15 tasks
physicsnemo 2.1.0 conflicts with fairchem

Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
@laserkelvin

Copy link
Copy Markdown
Collaborator Author

/ok to test 85a93f5

Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
@laserkelvin

Copy link
Copy Markdown
Collaborator Author

/ok to test 2b25deb

@laserkelvin laserkelvin enabled auto-merge June 26, 2026 00:37
@laserkelvin laserkelvin added this pull request to the merge queue Jun 26, 2026
Merged via the queue into NVIDIA:main with commit 1d1c2d3 Jun 26, 2026
5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants