Add sigmoid/softmax activation to AsymmetricUnifiedFocalLoss#8863
Add sigmoid/softmax activation to AsymmetricUnifiedFocalLoss#8863AlexanderSanin wants to merge 2 commits into
Conversation
) Catching BaseException inadvertently suppresses KeyboardInterrupt, SystemExit, and GeneratorExit, which should nearly always propagate. All 17 occurrences across monai/ and tests/ are replaced with Exception, which is the appropriate base class for catchable errors. Signed-off-by: Oleksandr Sanin <alexaaander.sanin@gmail.com>
…-MONAI#8603) Add use_softmax and use_sigmoid parameters so users can pass raw logits directly. When both are False (default), the input is assumed to already be probabilities, preserving backward compatibility. Also removes the stale TODO comment about multi-class support and adds proper docstrings for the reduction parameter. Signed-off-by: Oleksandr Sanin <alexaaander.sanin@gmail.com>
📝 WalkthroughWalkthroughThis PR makes two changes: (1) systematically narrowing exception handlers from Estimated code review effort🎯 3 (Moderate) | ⏱️ ~25 minutes 🚥 Pre-merge checks | ✅ 3 | ❌ 2❌ Failed checks (2 warnings)
✅ Passed checks (3 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Warning Review ran into problems🔥 ProblemsGit: Failed to clone repository. Please run the Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 3
♻️ Duplicate comments (1)
tests/losses/test_unified_focal_loss.py (1)
71-76:⚠️ Potential issue | 🟠 Major | ⚡ Quick winSame issues as
test_use_sigmoid.Single-channel logits will fail one-hot conversion, and the assertion is too weak to verify softmax application.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/losses/test_unified_focal_loss.py` around lines 71 - 76, test_use_softmax fails because single-channel logits break one-hot conversion and the current assertion is too weak; update the test to feed multi-channel logits (C=2) and matching one-hot y_true so AsymmetricUnifiedFocalLoss(use_softmax=True) can perform softmax, then strengthen the assertion by comparing the softmax-enabled loss to a baseline (e.g., compute loss_softmax = loss(y_pred, y_true) with AsymmetricUnifiedFocalLoss(use_softmax=True) and loss_sigmoid = AsymmetricUnifiedFocalLoss(use_softmax=False)(same y_pred, y_true)) and assert both are finite and that loss_softmax differs from loss_sigmoid (or is less than, depending on expected behavior) to verify softmax was applied; reference test_use_softmax, AsymmetricUnifiedFocalLoss, use_softmax, y_pred and y_true when making the changes.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@monai/losses/unified_focal_loss.py`:
- Around line 241-244: The activation is applied after converting single-channel
y_pred to one-hot causing integer-index errors for single-channel logits; update
the UnifiedFocalLoss forward (or the method handling y_pred) to either apply
softmax/sigmoid before the one_hot conversion or add an explicit input
validation that raises a clear error when y_pred has a single channel while
use_softmax or use_sigmoid is True (referencing variables y_pred, use_softmax,
use_sigmoid and the one_hot conversion block) so users get a helpful message
rather than a silent failure.
- Around line 241-244: The code enables softmax via use_softmax but the
component losses AsymmetricFocalLoss and AsymmetricFocalTverskyLoss index
hardcoded class channels ([:, 0] and [:, 1]) so softmax with >2 channels will be
broken; update the functions to validate and reject multi-class inputs (e.g.,
check y_pred.shape[1] or num_classes and raise a clear ValueError if != 2) or
explicitly document in the use_softmax docstring that only binary (2-class)
predictions are supported, and ensure the error message references use_softmax,
AsymmetricFocalLoss, and AsymmetricFocalTverskyLoss so callers know why their
multi-channel input is unsupported.
In `@tests/losses/test_unified_focal_loss.py`:
- Around line 64-69: The test crashes because AsymmetricUnifiedFocalLoss's
internal one_hot conversion is called when y_pred is single-channel float logits
([2,1,2,2]) and those logits are being cast to long, producing out-of-range
indices; change the test to provide multi-channel logits (e.g. shape [2,2,2,2])
when use_sigmoid=True or else supply integer class indices for y_true so one_hot
isn't fed raw logits; update the test_use_sigmoid to create y_pred with two
channels and matching y_true (or use class indices) so one_hot/scatter_ receives
valid class indices.
---
Duplicate comments:
In `@tests/losses/test_unified_focal_loss.py`:
- Around line 71-76: test_use_softmax fails because single-channel logits break
one-hot conversion and the current assertion is too weak; update the test to
feed multi-channel logits (C=2) and matching one-hot y_true so
AsymmetricUnifiedFocalLoss(use_softmax=True) can perform softmax, then
strengthen the assertion by comparing the softmax-enabled loss to a baseline
(e.g., compute loss_softmax = loss(y_pred, y_true) with
AsymmetricUnifiedFocalLoss(use_softmax=True) and loss_sigmoid =
AsymmetricUnifiedFocalLoss(use_softmax=False)(same y_pred, y_true)) and assert
both are finite and that loss_softmax differs from loss_sigmoid (or is less
than, depending on expected behavior) to verify softmax was applied; reference
test_use_softmax, AsymmetricUnifiedFocalLoss, use_softmax, y_pred and y_true
when making the changes.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 0e334f54-a695-4d0b-bc39-462797493bc0
📒 Files selected for processing (15)
monai/__init__.pymonai/apps/auto3dseg/data_analyzer.pymonai/apps/auto3dseg/ensemble_builder.pymonai/apps/auto3dseg/utils.pymonai/apps/detection/metrics/coco.pymonai/apps/nnunet/nnunetv2_runner.pymonai/config/deviceconfig.pymonai/data/__init__.pymonai/inferers/inferer.pymonai/losses/unified_focal_loss.pymonai/utils/tf32.pytests/apps/detection/networks/test_retinanet.pytests/losses/test_unified_focal_loss.pytests/networks/nets/test_resnet.pytests/test_utils.py
| if self.use_softmax: | ||
| y_pred = torch.softmax(y_pred, dim=1) | ||
| elif self.use_sigmoid: | ||
| y_pred = torch.sigmoid(y_pred) |
There was a problem hiding this comment.
Activation applied after one-hot conversion breaks single-channel logit inputs.
Lines 227-229 convert single-channel y_pred to one-hot before activation is applied here. This means:
- Single-channel inputs are treated as class indices (discrete values), not logits
- Passing single-channel logits with
use_sigmoid=Truewill fail becauseone_hot()expects integers - Users must pass 2-channel inputs to use the activation flags
Document this requirement or add validation to raise a clear error for single-channel inputs when activation flags are set.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@monai/losses/unified_focal_loss.py` around lines 241 - 244, The activation is
applied after converting single-channel y_pred to one-hot causing integer-index
errors for single-channel logits; update the UnifiedFocalLoss forward (or the
method handling y_pred) to either apply softmax/sigmoid before the one_hot
conversion or add an explicit input validation that raises a clear error when
y_pred has a single channel while use_softmax or use_sigmoid is True
(referencing variables y_pred, use_softmax, use_sigmoid and the one_hot
conversion block) so users get a helpful message rather than a silent failure.
Multi-class limitation: softmax won't work beyond binary despite being added.
The component losses (AsymmetricFocalLoss lines 135-138 and AsymmetricFocalTverskyLoss lines 79-80) hardcode indices [:, 0] and [:, 1], limiting support to exactly 2 classes. Even with use_softmax=True, inputs with >2 channels will fail or produce incorrect results.
Consider documenting this limitation in the use_softmax docstring or adding validation to reject num_classes != 2.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@monai/losses/unified_focal_loss.py` around lines 241 - 244, The code enables
softmax via use_softmax but the component losses AsymmetricFocalLoss and
AsymmetricFocalTverskyLoss index hardcoded class channels ([:, 0] and [:, 1]) so
softmax with >2 channels will be broken; update the functions to validate and
reject multi-class inputs (e.g., check y_pred.shape[1] or num_classes and raise
a clear ValueError if != 2) or explicitly document in the use_softmax docstring
that only binary (2-class) predictions are supported, and ensure the error
message references use_softmax, AsymmetricFocalLoss, and
AsymmetricFocalTverskyLoss so callers know why their multi-channel input is
unsupported.
| def test_use_sigmoid(self): | ||
| loss = AsymmetricUnifiedFocalLoss(use_sigmoid=True) | ||
| y_pred = torch.tensor([[[[10.0, -10], [-10, 10.0]]], [[[10.0, -10], [-10, 10.0]]]]) | ||
| y_true = torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]) | ||
| result = loss(y_pred, y_true) | ||
| self.assertTrue(result.item() >= 0) |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Run the specific test to check if single-channel logits work
python -m pytest tests/losses/test_unified_focal_loss.py::TestAsymmetricUnifiedFocalLoss::test_use_sigmoid -vRepository: Project-MONAI/MONAI
Length of output: 103
🏁 Script executed:
# Find the AsymmetricUnifiedFocalLoss implementation
find . -type f -name "*.py" | xargs grep -l "class AsymmetricUnifiedFocalLoss"Repository: Project-MONAI/MONAI
Length of output: 100
🏁 Script executed:
# Get the test file to see all three tests mentioned
cat -n tests/losses/test_unified_focal_loss.pyRepository: Project-MONAI/MONAI
Length of output: 3818
🏁 Script executed:
# Search for one_hot usage in the loss implementation
rg "one_hot" --type pyRepository: Project-MONAI/MONAI
Length of output: 10644
🏁 Script executed:
cat -n monai/losses/unified_focal_loss.py | head -300Repository: Project-MONAI/MONAI
Length of output: 12927
🏁 Script executed:
# Find and read the one_hot function definition
rg -A 20 "^def one_hot" monai/networks/utils.pyRepository: Project-MONAI/MONAI
Length of output: 1250
🏁 Script executed:
# Double check by examining what happens with out-of-bounds indices in one_hot
python3 << 'EOF'
import torch
import torch.nn.functional as F
# Simulate what monai.networks.one_hot does
labels = torch.tensor([10, -10]) # Out of bounds for 2 classes
try:
# This is essentially what one_hot does internally
result = F.one_hot(labels.long(), num_classes=2)
print("Result:", result)
except Exception as e:
print(f"Error type: {type(e).__name__}")
print(f"Error message: {e}")
EOFRepository: Project-MONAI/MONAI
Length of output: 181
🏁 Script executed:
# Get the full one_hot implementation
rg -A 30 "^def one_hot" monai/networks/utils.pyRepository: Project-MONAI/MONAI
Length of output: 1575
🏁 Script executed:
# Get the complete one_hot function implementation
rg -A 50 "^def one_hot" monai/networks/utils.py | head -60Repository: Project-MONAI/MONAI
Length of output: 2157
Test crashes with out-of-bounds indices in one-hot conversion.
Single-channel y_pred (shape [2, 1, 2, 2]) triggers one_hot() conversion at line 228, which internally calls scatter_() with indices from .long(). The logit values 10.0 and -10 become indices 10 and -10, both out of bounds for num_classes=2. This causes an IndexError before use_sigmoid is even applied.
Recommendation: Use multi-channel inputs ([B, 2, H, W]) to bypass one-hot conversion, or use integer class indices if single-channel is intended.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tests/losses/test_unified_focal_loss.py` around lines 64 - 69, The test
crashes because AsymmetricUnifiedFocalLoss's internal one_hot conversion is
called when y_pred is single-channel float logits ([2,1,2,2]) and those logits
are being cast to long, producing out-of-range indices; change the test to
provide multi-channel logits (e.g. shape [2,2,2,2]) when use_sigmoid=True or
else supply integer class indices for y_true so one_hot isn't fed raw logits;
update the test_use_sigmoid to create y_pred with two channels and matching
y_true (or use class indices) so one_hot/scatter_ receives valid class indices.
Summary
Fixes #8603
Adds `use_softmax` and `use_sigmoid` parameters to `AsymmetricUnifiedFocalLoss`, following the same pattern as `FocalLoss`. This allows users to pass raw logits directly without manually applying activations beforehand.
Test plan
Signed-off-by: Oleksandr Sanin alexaaander.sanin@gmail.com