Add sigmoid/softmax activation to AsymmetricUnifiedFocalLoss#8863

Open
AlexanderSanin wants to merge 2 commits into Project-MONAI:dev from AlexanderSanin:feat/unified-focal-loss-activation-8603

Conversation

@AlexanderSanin

Summary

Fixes #8603

Adds `use_softmax` and `use_sigmoid` parameters to `AsymmetricUnifiedFocalLoss`, following the same pattern as `FocalLoss`. This allows users to pass raw logits directly without manually applying activations beforehand.

  • `use_softmax=True`: applies softmax along channel dim (for multi-class)
  • `use_sigmoid=True`: applies sigmoid (for binary)
  • Both `False` (default): input assumed to be probabilities — fully backward compatible
  • Mutually exclusive validation with clear error message
  • Removed stale TODO comment and added missing docstrings for `reduction` parameter
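The flag semantics listed above can be sketched without MONAI or PyTorch. This is only an illustration: `apply_activation` is a hypothetical helper that works on the logits of a single voxel held in a plain list, not the code in this PR.

```python
import math
from typing import List


def apply_activation(
    logits: List[float], use_softmax: bool = False, use_sigmoid: bool = False
) -> List[float]:
    """Hypothetical helper: map per-class values for one voxel to probabilities."""
    if use_softmax and use_sigmoid:
        # Mirrors the mutual-exclusivity validation described in the summary.
        raise ValueError("use_softmax and use_sigmoid are mutually exclusive.")
    if use_softmax:
        # Subtract the maximum before exponentiating for numerical stability.
        peak = max(logits)
        exps = [math.exp(v - peak) for v in logits]
        total = sum(exps)
        return [e / total for e in exps]
    if use_sigmoid:
        # Each channel is squashed to (0, 1) independently (moderate logits assumed).
        return [1.0 / (1.0 + math.exp(-v)) for v in logits]
    # Both flags False: the input is assumed to already hold probabilities.
    return list(logits)
```

With both flags left at `False` the values pass through unchanged, which is the backward-compatible default the summary describes.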

Test plan

  • Existing tests pass unchanged (backward compatible defaults)
  • New `test_use_sigmoid`: passes logits with sigmoid activation
  • New `test_use_softmax`: passes logits with softmax activation
  • New `test_mutually_exclusive`: validates that setting both raises `ValueError`

Signed-off-by: Oleksandr Sanin <alexaaander.sanin@gmail.com>

)

Catching BaseException inadvertently suppresses KeyboardInterrupt,
SystemExit, and GeneratorExit, which should nearly always propagate.
All 17 occurrences across monai/ and tests/ are replaced with
Exception, which is the appropriate base class for catchable errors.

Signed-off-by: Oleksandr Sanin <alexaaander.sanin@gmail.com>
…-MONAI#8603)

Add use_softmax and use_sigmoid parameters so users can pass raw
logits directly. When both are False (default), the input is assumed
to already be probabilities, preserving backward compatibility.

Also removes the stale TODO comment about multi-class support and
adds proper docstrings for the reduction parameter.

Signed-off-by: Oleksandr Sanin <alexaaander.sanin@gmail.com>
@coderabbitai
Contributor

coderabbitai Bot commented May 18, 2026

📝 Walkthrough

This PR makes two changes: (1) systematically narrowing exception handlers from BaseException to Exception across 12 library and test files to prevent catching control-flow events like SystemExit and KeyboardInterrupt, and (2) adding use_sigmoid and use_softmax parameters to AsymmetricUnifiedFocalLoss to support logit inputs with enforced mutual exclusivity and corresponding test coverage.
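The effect of the first change can be shown with a small standalone sketch. The names `run_guarded`, `raise_value_error`, and `raise_interrupt` are made up for illustration and do not appear in the PR; the point is only that `except Exception` handles ordinary errors while letting `KeyboardInterrupt` propagate.

```python
def run_guarded(action):
    """Run `action` and report how it ended, catching only Exception."""
    try:
        action()
        return "ok"
    except Exception as err:
        # Ordinary errors (ValueError, RuntimeError, ...) are handled here.
        # KeyboardInterrupt, SystemExit and GeneratorExit derive from
        # BaseException directly, so they are not caught by this clause.
        return f"handled {type(err).__name__}"


def raise_value_error():
    raise ValueError("bad input")


def raise_interrupt():
    raise KeyboardInterrupt()
```

Calling `run_guarded(raise_value_error)` returns a string, whereas `run_guarded(raise_interrupt)` lets the interrupt reach the caller.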

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 2

❌ Failed checks (2 warnings)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Out of Scope Changes check | ⚠️ Warning | Most changes replace BaseException with Exception across unrelated files; only AsymmetricUnifiedFocalLoss changes and tests directly address #8603 scope. | Remove BaseException→Exception replacements from files unrelated to #8603 (__init__.py, data_analyzer.py, ensemble_builder.py, utils.py, coco.py, nnunetv2_runner.py, deviceconfig.py, data/__init__.py, inferer.py, tf32.py, test_retinanet.py, test_resnet.py, test_utils.py) or establish separate tracking for that refactoring. |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 47.62% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (3 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | Title accurately summarizes main change: adding sigmoid/softmax activation to AsymmetricUnifiedFocalLoss. |
| Description check | ✅ Passed | Description covers main changes, test plan, and follows template structure, though some checklist items unmarked. |
| Linked Issues check | ✅ Passed | Changes fully implement #8603: adds use_softmax/use_sigmoid parameters, validates mutual exclusivity, maintains backward compatibility, removes stale TODO, and adds docstrings. |

Warning

Review ran into problems

🔥 Problems

Git: Failed to clone repository. Please run the @coderabbitai full review command to re-trigger a full review. If the issue persists, set path_filters to include or exclude specific files.


Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 3

♻️ Duplicate comments (1)
tests/losses/test_unified_focal_loss.py (1)

71-76: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Same issues as test_use_sigmoid.

Single-channel logits will fail one-hot conversion, and the assertion is too weak to verify softmax application.
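One way to make such an assertion stronger is to compare the flag-enabled result on logits against the same loss evaluated on manually activated probabilities. The sketch below shows the pattern with a stand-in loss (negative log-probability of the target class), not the unified focal loss itself; `toy_loss` and `toy_loss_from_logits` are hypothetical names.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over a list of logits."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def toy_loss(probs: List[float], target_index: int) -> float:
    """Stand-in loss: negative log-probability of the target class."""
    return -math.log(probs[target_index])


def toy_loss_from_logits(
    logits: List[float], target_index: int, use_softmax: bool = False
) -> float:
    """Apply softmax first when the flag is set, then evaluate the stand-in loss."""
    probs = softmax(logits) if use_softmax else logits
    return toy_loss(probs, target_index)


logits = [10.0, -10.0]  # two channels, so softmax is meaningful
with_flag = toy_loss_from_logits(logits, 0, use_softmax=True)
reference = toy_loss(softmax(logits), 0)  # activation applied by hand
```

Asserting that `with_flag` is close to `reference` would fail if the activation were silently skipped, which a bare `>= 0` check cannot detect.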

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/losses/test_unified_focal_loss.py` around lines 71 - 76,
test_use_softmax fails because single-channel logits break one-hot conversion
and the current assertion is too weak; update the test to feed multi-channel
logits (C=2) and matching one-hot y_true so
AsymmetricUnifiedFocalLoss(use_softmax=True) can perform softmax, then
strengthen the assertion by comparing the softmax-enabled loss to a baseline
(e.g., compute loss_softmax = loss(y_pred, y_true) with
AsymmetricUnifiedFocalLoss(use_softmax=True) and loss_sigmoid =
AsymmetricUnifiedFocalLoss(use_softmax=False)(same y_pred, y_true)) and assert
both are finite and that loss_softmax differs from loss_sigmoid (or is less
than, depending on expected behavior) to verify softmax was applied; reference
test_use_softmax, AsymmetricUnifiedFocalLoss, use_softmax, y_pred and y_true
when making the changes.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@monai/losses/unified_focal_loss.py`:
- Around line 241-244: The activation is applied after converting single-channel
y_pred to one-hot causing integer-index errors for single-channel logits; update
the UnifiedFocalLoss forward (or the method handling y_pred) to either apply
softmax/sigmoid before the one_hot conversion or add an explicit input
validation that raises a clear error when y_pred has a single channel while
use_softmax or use_sigmoid is True (referencing variables y_pred, use_softmax,
use_sigmoid and the one_hot conversion block) so users get a helpful message
rather than a silent failure.
- Around line 241-244: The code enables softmax via use_softmax but the
component losses AsymmetricFocalLoss and AsymmetricFocalTverskyLoss index
hardcoded class channels ([:, 0] and [:, 1]) so softmax with >2 channels will be
broken; update the functions to validate and reject multi-class inputs (e.g.,
check y_pred.shape[1] or num_classes and raise a clear ValueError if != 2) or
explicitly document in the use_softmax docstring that only binary (2-class)
predictions are supported, and ensure the error message references use_softmax,
AsymmetricFocalLoss, and AsymmetricFocalTverskyLoss so callers know why their
multi-channel input is unsupported.

In `@tests/losses/test_unified_focal_loss.py`:
- Around line 64-69: The test crashes because AsymmetricUnifiedFocalLoss's
internal one_hot conversion is called when y_pred is single-channel float logits
([2,1,2,2]) and those logits are being cast to long, producing out-of-range
indices; change the test to provide multi-channel logits (e.g. shape [2,2,2,2])
when use_sigmoid=True or else supply integer class indices for y_true so one_hot
isn't fed raw logits; update the test_use_sigmoid to create y_pred with two
channels and matching y_true (or use class indices) so one_hot/scatter_ receives
valid class indices.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 0e334f54-a695-4d0b-bc39-462797493bc0

📥 Commits

Reviewing files that changed from the base of the PR and between ef2acfb and ec11cc3.

📒 Files selected for processing (15)
  • monai/__init__.py
  • monai/apps/auto3dseg/data_analyzer.py
  • monai/apps/auto3dseg/ensemble_builder.py
  • monai/apps/auto3dseg/utils.py
  • monai/apps/detection/metrics/coco.py
  • monai/apps/nnunet/nnunetv2_runner.py
  • monai/config/deviceconfig.py
  • monai/data/__init__.py
  • monai/inferers/inferer.py
  • monai/losses/unified_focal_loss.py
  • monai/utils/tf32.py
  • tests/apps/detection/networks/test_retinanet.py
  • tests/losses/test_unified_focal_loss.py
  • tests/networks/nets/test_resnet.py
  • tests/test_utils.py

Comment on lines +241 to +244
if self.use_softmax:
    y_pred = torch.softmax(y_pred, dim=1)
elif self.use_sigmoid:
    y_pred = torch.sigmoid(y_pred)

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Activation applied after one-hot conversion breaks single-channel logit inputs.

Lines 227-229 convert single-channel y_pred to one-hot before activation is applied here. This means:

  • Single-channel inputs are treated as class indices (discrete values), not logits
  • Passing single-channel logits with use_sigmoid=True will fail because one_hot() expects integers
  • Users must pass 2-channel inputs to use the activation flags

Document this requirement or add validation to raise a clear error for single-channel inputs when activation flags are set.
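A minimal sketch of the validation option, assuming a `(B, C, ...)` prediction shape. `check_activation_input` is a hypothetical guard written for illustration; it takes a plain shape tuple rather than a tensor and is not part of MONAI.

```python
from typing import Sequence


def check_activation_input(
    shape: Sequence[int], use_softmax: bool, use_sigmoid: bool
) -> None:
    """Hypothetical guard: reject single-channel predictions when an activation flag is set."""
    if len(shape) < 2:
        raise ValueError(f"expected a (B, C, ...) shape, got {tuple(shape)}")
    channels = shape[1]
    if channels == 1 and (use_softmax or use_sigmoid):
        # Single-channel input is interpreted as class indices and one-hot
        # encoded, so raw logits cannot be passed in this layout.
        raise ValueError(
            "y_pred has a single channel, which is treated as class indices; "
            "pass 2-channel logits when use_softmax or use_sigmoid is True."
        )
```

Running such a check before the one-hot conversion would turn the indexing failure into an explicit message.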

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/losses/unified_focal_loss.py` around lines 241 - 244, The activation is
applied after converting single-channel y_pred to one-hot causing integer-index
errors for single-channel logits; update the UnifiedFocalLoss forward (or the
method handling y_pred) to either apply softmax/sigmoid before the one_hot
conversion or add an explicit input validation that raises a clear error when
y_pred has a single channel while use_softmax or use_sigmoid is True
(referencing variables y_pred, use_softmax, use_sigmoid and the one_hot
conversion block) so users get a helpful message rather than a silent failure.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Multi-class limitation: softmax won't work beyond binary despite being added.

The component losses (AsymmetricFocalLoss lines 135-138 and AsymmetricFocalTverskyLoss lines 79-80) hardcode indices [:, 0] and [:, 1], limiting support to exactly 2 classes. Even with use_softmax=True, inputs with >2 channels will fail or produce incorrect results.

Consider documenting this limitation in the use_softmax docstring or adding validation to reject num_classes != 2.
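To illustrate why more than two channels is a problem, the sketch below applies softmax to three logits and then looks only at channels 0 and 1, as code with hardcoded indices would. `require_two_classes` is a hypothetical guard, not an existing MONAI function.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over a list of logits."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def require_two_classes(num_channels: int) -> None:
    """Hypothetical guard: only binary (2-channel) predictions are accepted."""
    if num_channels != 2:
        raise ValueError(
            f"expected exactly 2 channels, got {num_channels}; "
            "the component losses index channels 0 and 1 only."
        )


three_class = softmax([1.0, 2.0, 3.0])
# Probability mass visible to code that reads only channels 0 and 1.
covered = three_class[0] + three_class[1]
```

Here `covered` is well below 1 because the largest logit sits in channel 2, so a loss that ignores that channel would be computed on an incomplete distribution.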

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/losses/unified_focal_loss.py` around lines 241 - 244, The code enables
softmax via use_softmax but the component losses AsymmetricFocalLoss and
AsymmetricFocalTverskyLoss index hardcoded class channels ([:, 0] and [:, 1]) so
softmax with >2 channels will be broken; update the functions to validate and
reject multi-class inputs (e.g., check y_pred.shape[1] or num_classes and raise
a clear ValueError if != 2) or explicitly document in the use_softmax docstring
that only binary (2-class) predictions are supported, and ensure the error
message references use_softmax, AsymmetricFocalLoss, and
AsymmetricFocalTverskyLoss so callers know why their multi-channel input is
unsupported.

Comment on lines +64 to +69
def test_use_sigmoid(self):
    loss = AsymmetricUnifiedFocalLoss(use_sigmoid=True)
    y_pred = torch.tensor([[[[10.0, -10], [-10, 10.0]]], [[[10.0, -10], [-10, 10.0]]]])
    y_true = torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]])
    result = loss(y_pred, y_true)
    self.assertTrue(result.item() >= 0)

⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Run the specific test to check if single-channel logits work
python -m pytest tests/losses/test_unified_focal_loss.py::TestAsymmetricUnifiedFocalLoss::test_use_sigmoid -v

Repository: Project-MONAI/MONAI

Length of output: 103


🏁 Script executed:

# Find the AsymmetricUnifiedFocalLoss implementation
find . -type f -name "*.py" | xargs grep -l "class AsymmetricUnifiedFocalLoss"

Repository: Project-MONAI/MONAI

Length of output: 100


🏁 Script executed:

# Get the test file to see all three tests mentioned
cat -n tests/losses/test_unified_focal_loss.py

Repository: Project-MONAI/MONAI

Length of output: 3818


🏁 Script executed:

# Search for one_hot usage in the loss implementation
rg "one_hot" --type py

Repository: Project-MONAI/MONAI

Length of output: 10644


🏁 Script executed:

cat -n monai/losses/unified_focal_loss.py | head -300

Repository: Project-MONAI/MONAI

Length of output: 12927


🏁 Script executed:

# Find and read the one_hot function definition
rg -A 20 "^def one_hot" monai/networks/utils.py

Repository: Project-MONAI/MONAI

Length of output: 1250


🏁 Script executed:

# Double check by examining what happens with out-of-bounds indices in one_hot
python3 << 'EOF'
import torch
import torch.nn.functional as F

# Simulate what monai.networks.one_hot does
labels = torch.tensor([10, -10])  # Out of bounds for 2 classes
try:
    # This is essentially what one_hot does internally
    result = F.one_hot(labels.long(), num_classes=2)
    print("Result:", result)
except Exception as e:
    print(f"Error type: {type(e).__name__}")
    print(f"Error message: {e}")
EOF

Repository: Project-MONAI/MONAI

Length of output: 181


🏁 Script executed:

# Get the full one_hot implementation
rg -A 30 "^def one_hot" monai/networks/utils.py

Repository: Project-MONAI/MONAI

Length of output: 1575


🏁 Script executed:

# Get the complete one_hot function implementation
rg -A 50 "^def one_hot" monai/networks/utils.py | head -60

Repository: Project-MONAI/MONAI

Length of output: 2157


Test crashes with out-of-bounds indices in one-hot conversion.

Single-channel y_pred (shape [2, 1, 2, 2]) triggers one_hot() conversion at line 228, which internally calls scatter_() with indices from .long(). The logit values 10.0 and -10 become indices 10 and -10, both out of bounds for num_classes=2. This raises an out-of-bounds index error before use_sigmoid is even applied.

Recommendation: Use multi-channel inputs ([B, 2, H, W]) to bypass one-hot conversion, or use integer class indices if single-channel is intended.
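The failure mode can be reproduced in plain Python. `one_hot_index` below is a toy encoder for a single label with an explicit bounds check; it stands in for the cast-then-scatter behaviour described above and is not MONAI's `one_hot`.

```python
from typing import List


def one_hot_index(value: float, num_classes: int) -> List[int]:
    """Toy one-hot encoder for one label (illustrative stand-in only)."""
    # Truncation here plays the role of casting a float tensor to long.
    index = int(value)
    if not 0 <= index < num_classes:
        raise ValueError(
            f"label {value!r} maps to index {index}, outside [0, {num_classes})"
        )
    encoded = [0] * num_classes
    encoded[index] = 1
    return encoded
```

Valid class indices such as `0.0` and `1.0` encode cleanly, while the test's logits `10.0` and `-10` map to indices 10 and -10 and are rejected.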

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/losses/test_unified_focal_loss.py` around lines 64 - 69, The test
crashes because AsymmetricUnifiedFocalLoss's internal one_hot conversion is
called when y_pred is single-channel float logits ([2,1,2,2]) and those logits
are being cast to long, producing out-of-range indices; change the test to
provide multi-channel logits (e.g. shape [2,2,2,2]) when use_sigmoid=True or
else supply integer class indices for y_true so one_hot isn't fed raw logits;
update the test_use_sigmoid to create y_pred with two channels and matching
y_true (or use class indices) so one_hot/scatter_ receives valid class indices.

Development

Successfully merging this pull request may close these issues.

Add sigmoid/softmax interface for AsymmetricUnifiedFocalLoss

1 participant