Load training dataset tensors with weights_only#14333

Open
CFionaBF wants to merge 2 commits into Comfy-Org:master from CFionaBF:harden-dataset-load

@CFionaBF CFionaBF commented Jun 7, 2026

Load training dataset shards with weights_only=True, consistent with how the rest of the codebase loads checkpoints (e.g. comfy.utils.load_torch_file).

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 1b877e583f


  with open(shard_path, "rb") as f:
-     shard_data = torch.load(f)
+     shard_data = torch.load(f, weights_only=True)

P2: Allow hook-bearing dataset shards to load

This breaks saved datasets whose conditioning includes CLIP hooks: Set CLIP Hooks can set clip.apply_hooks_to_conds when apply_to_conds is enabled (comfy_extras/nodes_hooks.py:249-254), and CLIP.add_hooks_to_dict then stores a comfy.hooks.HookGroup in the conditioning dict (comfy/sd.py:315-318). SaveTrainingDataset serializes that conditioning unchanged, but weights_only=True rejects non-allowlisted custom classes/functions during unpickling, so those shards now fail to load even though they were produced by this node. Either keep unsafe loading for this trusted output-directory dataset format or strip/allowlist the hook objects explicitly.
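
One of the two options named in this comment, stripping the hook objects before the conditioning is serialized, might look roughly like the sketch below. It assumes conditioning is a list of `[tensor, options_dict]` pairs with the hook group stored under the `"hooks"` key; `strip_hooks` is a hypothetical helper, not an existing ComfyUI function.

```python
def strip_hooks(conditioning):
    """Return a copy of `conditioning` with the "hooks" entry dropped from each options dict.

    Assumes every item is a [tensor, options_dict] pair. The input is not mutated.
    """
    stripped = []
    for tensor, options in conditioning:
        clean = {key: value for key, value in options.items() if key != "hooks"}
        stripped.append([tensor, clean])
    return stripped


# Stand-in values: a real entry would hold a torch.Tensor and a comfy.hooks.HookGroup.
example = [["<tensor>", {"pooled_output": "<tensor>", "hooks": object()}]]
cleaned = strip_hooks(example)
```

The trade-off is that hooks do not survive a save/load round trip with this approach, whereas allowlisting keeps them at the cost of maintaining the list of registered classes.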

coderabbitai Bot commented Jun 7, 2026

📝 Walkthrough

The PR updates shard deserialization in LoadTrainingDataset to load training dataset shards in safe mode. A torch.serialization.add_safe_globals([...]) block registers the comfy.hooks classes (HookGroup, Hook, WeightHook, and the related enums) as safe serialization globals. The shard loading call changes from torch.load(f) to torch.load(f, weights_only=True), which restricts deserialized content to tensors, primitive Python types, and the explicitly registered classes, and prevents arbitrary object reconstruction from untrusted shard files during dataset import.
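
The allowlist idea behind `weights_only=True` combined with `add_safe_globals` can be illustrated with the standard library alone. The sketch below is an analogy built on `pickle.Unpickler.find_class`, not PyTorch's actual implementation; `AllowlistUnpickler`, `ALLOWED_GLOBALS`, and `safe_loads` are hypothetical names, and `OrderedDict` and `Fraction` are stand-ins for an allowlisted and a non-allowlisted class.

```python
import io
import pickle
from collections import OrderedDict
from fractions import Fraction

# Stand-in allowlist: in the PR the registered entries are comfy.hooks classes.
ALLOWED_GLOBALS = {("collections", "OrderedDict"): OrderedDict}


class AllowlistUnpickler(pickle.Unpickler):
    """Unpickler that resolves only the globals listed in ALLOWED_GLOBALS."""

    def find_class(self, module, name):
        key = (module, name)
        if key not in ALLOWED_GLOBALS:
            raise pickle.UnpicklingError(f"global {module}.{name} is not allowlisted")
        return ALLOWED_GLOBALS[key]


def safe_loads(data):
    """Unpickle `data`, rejecting any global that is not allowlisted."""
    return AllowlistUnpickler(io.BytesIO(data)).load()


# Plain containers and an allowlisted class load normally.
loaded = safe_loads(pickle.dumps({"hooks": OrderedDict(a=1)}))

# A class that is not on the allowlist is rejected during unpickling.
try:
    safe_loads(pickle.dumps(Fraction(1, 2)))
    rejected = False
except pickle.UnpicklingError:
    rejected = True
```

Built-in containers, strings, and numbers never reach `find_class`, which is why only the custom classes need registering.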

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 50.00% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The title clearly and concisely summarizes the main change: loading training dataset tensors with the weights_only parameter set to True. |
| Description check | ✅ Passed | The description is directly related to the changeset, explaining the motivation for using weights_only=True and contextualizing it within the codebase's checkpoint loading patterns. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |

Datasets built with CLIP hooks store comfy.hooks objects in conditioning;
register them as safe globals so they still load with weights_only=True.

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@comfy_extras/nodes_dataset.py`:
- Around line 13-36: The torch.serialization.add_safe_globals() call in the code
requires PyTorch 2.4.0 or later, but requirements.txt currently has no version
constraint on torch. Fix this by either adding a torch>=2.4.0 constraint to
requirements.txt, or by wrapping the add_safe_globals() call with a version
check that gracefully handles older PyTorch versions (following the precedent
established in main.py). Choose one approach: update the dependency
specification or add conditional logic that only registers safe globals when the
API is available.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 29de4071-4fbc-4cbe-8fc1-2a0c9eab271f

📥 Commits

Reviewing files that changed from the base of the PR and between 1b877e5 and 0878ec8.

📒 Files selected for processing (1)
  • comfy_extras/nodes_dataset.py

Comment on lines +13 to +36
import comfy.hooks

# Datasets produced by MakeTrainingDataset can embed comfy.hooks objects in their
# conditioning (clip.encode_from_tokens_scheduled stores a HookGroup under "hooks"
# when the CLIP has hooks applied). LoadTrainingDataset loads shards with
# weights_only=True, so register these known classes as safe globals; otherwise such
# datasets fail to load. Only data-bearing hooks (e.g. WeightHook) round-trip this way.
torch.serialization.add_safe_globals([
    comfy.hooks.HookGroup,
    comfy.hooks.Hook,
    comfy.hooks.WeightHook,
    comfy.hooks.ObjectPatchHook,
    comfy.hooks.AdditionalModelsHook,
    comfy.hooks.TransformerOptionsHook,
    comfy.hooks.InjectionsHook,
    comfy.hooks._HookRef,
    comfy.hooks.HookKeyframe,
    comfy.hooks.HookKeyframeGroup,
    comfy.hooks.EnumHookMode,
    comfy.hooks.EnumHookType,
    comfy.hooks.EnumWeightTarget,
    comfy.hooks.EnumHookScope,
    comfy.hooks.default_should_register,
])

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check if there's a minimum PyTorch version specified in the repository
echo "=== Checking requirements files for torch version constraints ==="
fd -t f -e txt -e toml -e cfg -e in . | xargs grep -l -i "torch" 2>/dev/null | head -5 | xargs cat 2>/dev/null | grep -i "torch"

echo ""
echo "=== Checking pyproject.toml for dependencies ==="
cat pyproject.toml 2>/dev/null | grep -A 20 "dependencies"

echo ""
echo "=== Verifying all registered classes exist in comfy.hooks ==="
rg -n "^class (HookGroup|Hook|WeightHook|ObjectPatchHook|AdditionalModelsHook|TransformerOptionsHook|InjectionsHook|HookKeyframe|HookKeyframeGroup)\b" --type py

echo ""
echo "=== Verifying enums and _HookRef exist ==="
rg -n "^class (EnumHookMode|EnumHookType|EnumWeightTarget|EnumHookScope|_HookRef)\b" --type py

echo ""
echo "=== Verifying default_should_register function exists ==="
rg -n "^def default_should_register" --type py

Repository: Comfy-Org/ComfyUI

Length of output: 1217


🏁 Script executed:

# Check for requirements.txt and setup.py files
fd -t f \( -name "requirements*.txt" -o -name "setup.py" -o -name "setup.cfg" -o -name "environment*.yml" \) | head -10

echo "=== Contents of found requirements files ==="
fd -t f -name "requirements*.txt" | xargs cat 2>/dev/null | grep -i "^torch"

Repository: Comfy-Org/ComfyUI

Length of output: 586


🏁 Script executed:

#!/bin/bash
# Find requirements files more directly
echo "=== Looking for requirements files ==="
find . -maxdepth 2 -name "requirements*.txt" -o -name "setup.py" -o -name "setup.cfg" | head -20

echo ""
echo "=== Checking for any torch version specifications ==="
find . -maxdepth 2 -type f \( -name "*.txt" -o -name "*.py" \) -exec grep -l "torch.*2\.[0-9]" {} \; 2>/dev/null | head -10

Repository: Comfy-Org/ComfyUI

Length of output: 483


🏁 Script executed:

#!/bin/bash
echo "=== Contents of requirements.txt ==="
cat ./requirements.txt | grep -i "torch"

echo ""
echo "=== Full requirements.txt for context ==="
cat ./requirements.txt

Repository: Comfy-Org/ComfyUI

Length of output: 695


🏁 Script executed:

#!/bin/bash
echo "=== Checking if there's any version checking or conditional imports for add_safe_globals ==="
rg -n "add_safe_globals|torch.serialization" comfy_extras/nodes_dataset.py

echo ""
echo "=== Checking PyTorch version handling in the codebase ==="
rg -n "torch\.__version__|torch version" --type py | head -20

Repository: Comfy-Org/ComfyUI

Length of output: 829


Safe globals registration is functionally correct, but add version requirement for PyTorch 2.4.0+.

All 15 registered classes and functions (Hook, WeightHook, HookGroup, enums, etc.) exist in comfy/hooks.py. The registration logic itself is sound and aligns with the security pattern used elsewhere in the codebase.

However, torch.serialization.add_safe_globals() was introduced in PyTorch 2.4.0. This PR introduces a hard dependency on this API, but requirements.txt currently specifies torch without a version constraint. This will fail at import time for any user with PyTorch < 2.4.0.

Either add torch>=2.4.0 to requirements.txt or add a version check in the code (the codebase has precedent for this—see main.py line 220).
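
The second option could be sketched as below. This is a minimal illustration that uses feature detection (`getattr`) rather than a version comparison, so it is not necessarily the pattern used in main.py; `register_safe_globals` is a hypothetical helper, and the `SimpleNamespace` objects are stubs standing in for `torch.serialization` so the snippet runs without PyTorch installed.

```python
from types import SimpleNamespace


def register_safe_globals(serialization, safe_globals):
    """Register `safe_globals` if the API exists; return True when registration happened."""
    add_safe_globals = getattr(serialization, "add_safe_globals", None)
    if add_safe_globals is None:
        # No allowlist API in this PyTorch build: skip registration instead of
        # raising AttributeError at import time.
        return False
    add_safe_globals(list(safe_globals))
    return True


# Stubs standing in for torch.serialization with and without the API.
registered = []
new_api = SimpleNamespace(add_safe_globals=registered.extend)
old_api = SimpleNamespace()

ok_new = register_safe_globals(new_api, [dict])
ok_old = register_safe_globals(old_api, [dict])

# Intended real call (assumption, shown as a comment only):
# register_safe_globals(torch.serialization, [comfy.hooks.HookGroup, ...])
```

With a guard like this the module still imports on older PyTorch, but hook-bearing shards would remain unloadable there because nothing was registered, so pinning `torch>=2.4.0` is the stricter of the two choices.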
