Load training dataset tensors with weights_only#14333
Conversation
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 1b877e583f
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
|
|
||
| with open(shard_path, "rb") as f: | ||
| shard_data = torch.load(f) | ||
| shard_data = torch.load(f, weights_only=True) |
There was a problem hiding this comment.
Allow hook-bearing dataset shards to load
This breaks saved datasets whose conditioning includes CLIP hooks: Set CLIP Hooks can set clip.apply_hooks_to_conds when apply_to_conds is enabled (comfy_extras/nodes_hooks.py:249-254), and CLIP.add_hooks_to_dict then stores a comfy.hooks.HookGroup in the conditioning dict (comfy/sd.py:315-318). SaveTrainingDataset serializes that conditioning unchanged, but weights_only=True rejects non-allowlisted custom classes/functions during unpickling, so those shards now fail to load even though they were produced by this node. Either keep unsafe loading for this trusted output-directory dataset format or strip/allowlist the hook objects explicitly.
Useful? React with 👍 / 👎.
📝 WalkthroughWalkthroughThe PR updates the shard deserialization in 🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
Datasets built with CLIP hooks store comfy.hooks objects in conditioning; register them as safe globals so they still load with weights_only=True.
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@comfy_extras/nodes_dataset.py`:
- Around line 13-36: The torch.serialization.add_safe_globals() call in the code
requires PyTorch 2.4.0 or later, but requirements.txt currently has no version
constraint on torch. Fix this by either adding a torch>=2.4.0 constraint to
requirements.txt, or by wrapping the add_safe_globals() call with a version
check that gracefully handles older PyTorch versions (following the precedent
established in main.py). Choose one approach: update the dependency
specification or add conditional logic that only registers safe globals when the
API is available.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 29de4071-4fbc-4cbe-8fc1-2a0c9eab271f
📒 Files selected for processing (1)
comfy_extras/nodes_dataset.py
| import comfy.hooks | ||
|
|
||
| # Datasets produced by MakeTrainingDataset can embed comfy.hooks objects in their | ||
| # conditioning (clip.encode_from_tokens_scheduled stores a HookGroup under "hooks" | ||
| # when the CLIP has hooks applied). LoadTrainingDataset loads shards with | ||
| # weights_only=True, so register these known classes as safe globals; otherwise such | ||
| # datasets fail to load. Only data-bearing hooks (e.g. WeightHook) round-trip this way. | ||
| torch.serialization.add_safe_globals([ | ||
| comfy.hooks.HookGroup, | ||
| comfy.hooks.Hook, | ||
| comfy.hooks.WeightHook, | ||
| comfy.hooks.ObjectPatchHook, | ||
| comfy.hooks.AdditionalModelsHook, | ||
| comfy.hooks.TransformerOptionsHook, | ||
| comfy.hooks.InjectionsHook, | ||
| comfy.hooks._HookRef, | ||
| comfy.hooks.HookKeyframe, | ||
| comfy.hooks.HookKeyframeGroup, | ||
| comfy.hooks.EnumHookMode, | ||
| comfy.hooks.EnumHookType, | ||
| comfy.hooks.EnumWeightTarget, | ||
| comfy.hooks.EnumHookScope, | ||
| comfy.hooks.default_should_register, | ||
| ]) |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Check if there's a minimum PyTorch version specified in the repository
echo "=== Checking requirements files for torch version constraints ==="
fd -t f -e txt -e toml -e cfg -e in . | xargs grep -l -i "torch" 2>/dev/null | head -5 | xargs cat 2>/dev/null | grep -i "torch"
echo ""
echo "=== Checking pyproject.toml for dependencies ==="
cat pyproject.toml 2>/dev/null | grep -A 20 "dependencies"
echo ""
echo "=== Verifying all registered classes exist in comfy.hooks ==="
rg -n "^class (HookGroup|Hook|WeightHook|ObjectPatchHook|AdditionalModelsHook|TransformerOptionsHook|InjectionsHook|HookKeyframe|HookKeyframeGroup)\b" --type py
echo ""
echo "=== Verifying enums and _HookRef exist ==="
rg -n "^class (EnumHookMode|EnumHookType|EnumWeightTarget|EnumHookScope|_HookRef)\b" --type py
echo ""
echo "=== Verifying default_should_register function exists ==="
rg -n "^def default_should_register" --type pyRepository: Comfy-Org/ComfyUI
Length of output: 1217
🏁 Script executed:
# Check for requirements.txt and setup.py files
fd -t f \( -name "requirements*.txt" -o -name "setup.py" -o -name "setup.cfg" -o -name "environment*.yml" \) | head -10
echo "=== Contents of found requirements files ==="
fd -t f -name "requirements*.txt" | xargs cat 2>/dev/null | grep -i "^torch"Repository: Comfy-Org/ComfyUI
Length of output: 586
🏁 Script executed:
#!/bin/bash
# Find requirements files more directly
echo "=== Looking for requirements files ==="
find . -maxdepth 2 -name "requirements*.txt" -o -name "setup.py" -o -name "setup.cfg" | head -20
echo ""
echo "=== Checking for any torch version specifications ==="
find . -maxdepth 2 -type f \( -name "*.txt" -o -name "*.py" \) -exec grep -l "torch.*2\.[0-9]" {} \; 2>/dev/null | head -10Repository: Comfy-Org/ComfyUI
Length of output: 483
🏁 Script executed:
#!/bin/bash
echo "=== Contents of requirements.txt ==="
cat ./requirements.txt | grep -i "torch"
echo ""
echo "=== Full requirements.txt for context ==="
cat ./requirements.txtRepository: Comfy-Org/ComfyUI
Length of output: 695
🏁 Script executed:
#!/bin/bash
echo "=== Checking if there's any version checking or conditional imports for add_safe_globals ==="
rg -n "add_safe_globals|torch.serialization" comfy_extras/nodes_dataset.py
echo ""
echo "=== Checking PyTorch version handling in the codebase ==="
rg -n "torch\.__version__|torch version" --type py | head -20Repository: Comfy-Org/ComfyUI
Length of output: 829
Safe globals registration is functionally correct, but add version requirement for PyTorch 2.4.0+.
All 16 registered classes and functions (Hook, WeightHook, HookGroup, enums, etc.) exist in comfy/hooks.py. The registration logic itself is sound and aligns with the security pattern used elsewhere in the codebase.
However, torch.serialization.add_safe_globals() was introduced in PyTorch 2.4.0. This PR introduces a hard dependency on this API, but requirements.txt currently specifies torch without a version constraint. This will fail at import time for any user with PyTorch < 2.4.0.
Either add torch>=2.4.0 to requirements.txt or add a version check in the code (the codebase has precedent for this—see main.py line 220).
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@comfy_extras/nodes_dataset.py` around lines 13 - 36, The
torch.serialization.add_safe_globals() call in the code requires PyTorch 2.4.0
or later, but requirements.txt currently has no version constraint on torch. Fix
this by either adding a torch>=2.4.0 constraint to requirements.txt, or by
wrapping the add_safe_globals() call with a version check that gracefully
handles older PyTorch versions (following the precedent established in main.py).
Choose one approach: update the dependency specification or add conditional
logic that only registers safe globals when the API is available.
Load training dataset shards with
weights_only=True, consistent with how the rest of the codebase loads checkpoints (e.g.comfy.utils.load_torch_file).