Add sigmoid/softmax activation to AsymmetricUnifiedFocalLoss #8863
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| | | @@ -61,6 +61,24 @@ def test_with_cuda(self): |
| 61 | 61 |         print(output) |
| 62 | 62 |         np.testing.assert_allclose(output.detach().cpu().numpy(), 0.0, atol=1e-4, rtol=1e-4) |
| | 63 | |
| | 64 | +    def test_use_sigmoid(self): |
| | 65 | +        loss = AsymmetricUnifiedFocalLoss(use_sigmoid=True) |
| | 66 | +        y_pred = torch.tensor([[[[10.0, -10], [-10, 10.0]]], [[[10.0, -10], [-10, 10.0]]]]) |
| | 67 | +        y_true = torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]) |
| | 68 | +        result = loss(y_pred, y_true) |
| | 69 | +        self.assertTrue(result.item() >= 0) |
Comment on lines +64 to +69 (Contributor)
Analysis chain

Script executed:
#!/bin/bash
# Run the specific test to check if single-channel logits work
python -m pytest tests/losses/test_unified_focal_loss.py::TestAsymmetricUnifiedFocalLoss::test_use_sigmoid -v
Repository: Project-MONAI/MONAI. Length of output: 103

Script executed:
# Find the AsymmetricUnifiedFocalLoss implementation
find . -type f -name "*.py" | xargs grep -l "class AsymmetricUnifiedFocalLoss"
Repository: Project-MONAI/MONAI. Length of output: 100

Script executed:
# Get the test file to see all three tests mentioned
cat -n tests/losses/test_unified_focal_loss.py
Repository: Project-MONAI/MONAI. Length of output: 3818

Script executed:
# Search for one_hot usage in the loss implementation
rg "one_hot" --type py
Repository: Project-MONAI/MONAI. Length of output: 10644

Script executed:
cat -n monai/losses/unified_focal_loss.py | head -300
Repository: Project-MONAI/MONAI. Length of output: 12927

Script executed:
# Find and read the one_hot function definition
rg -A 20 "^def one_hot" monai/networks/utils.py
Repository: Project-MONAI/MONAI. Length of output: 1250

Script executed:
# Double check by examining what happens with out-of-bounds indices in one_hot
python3 << 'EOF'
import torch
import torch.nn.functional as F
# Simulate what monai.networks.one_hot does
labels = torch.tensor([10, -10])  # Out of bounds for 2 classes
try:
    # This is essentially what one_hot does internally
    result = F.one_hot(labels.long(), num_classes=2)
    print("Result:", result)
except Exception as e:
    print(f"Error type: {type(e).__name__}")
    print(f"Error message: {e}")
EOF
Repository: Project-MONAI/MONAI. Length of output: 181

Script executed:
# Get the full one_hot implementation
rg -A 30 "^def one_hot" monai/networks/utils.py
Repository: Project-MONAI/MONAI. Length of output: 1575

Script executed:
# Get the complete one_hot function implementation
rg -A 50 "^def one_hot" monai/networks/utils.py | head -60
Repository: Project-MONAI/MONAI. Length of output: 2157

Test crashes with out-of-bounds indices in one-hot conversion.
Single-channel [...]
Recommendation: Use multi-channel inputs ([...])
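The failure mode reported in this comment can be illustrated without MONAI or PyTorch. The sketch below is a simplified, hypothetical stand-in (`one_hot_index` is not MONAI's `one_hot`), assuming only that a one-hot conversion treats each incoming value as a class index. Under that assumption, label values such as 1.0 and 0 encode cleanly, while logits such as 10.0 and -10 fall outside the valid range for two classes.

```python
def one_hot_index(value, num_classes):
    """Hypothetical stand-in for a one-hot conversion: treats `value` as a class index."""
    index = int(value)
    if not 0 <= index < num_classes:
        raise ValueError(f"class index {index} is out of range for {num_classes} classes")
    return [1.0 if i == index else 0.0 for i in range(num_classes)]


labels = [1.0, 0.0]     # valid class indices, as in y_true
logits = [10.0, -10.0]  # raw network outputs, as in y_pred

# Labels convert without trouble.
encoded_labels = [one_hot_index(v, 2) for v in labels]

# Logits are not class indices, so every conversion attempt is rejected.
errors = []
for v in logits:
    try:
        one_hot_index(v, 2)
    except ValueError as exc:
        errors.append(str(exc))
```

The point of the sketch is the ordering problem: if a one-hot step runs on predictions before any activation, it sees raw logits rather than class indices.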
| | 70 | |
| | 71 | +    def test_use_softmax(self): |
| | 72 | +        loss = AsymmetricUnifiedFocalLoss(use_softmax=True) |
| | 73 | +        y_pred = torch.tensor([[[[10.0, -10], [-10, 10.0]]], [[[10.0, -10], [-10, 10.0]]]]) |
| | 74 | +        y_true = torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]) |
| | 75 | +        result = loss(y_pred, y_true) |
| | 76 | +        self.assertTrue(result.item() >= 0) |
| | 77 | |
| | 78 | +    def test_mutually_exclusive(self): |
| | 79 | +        with self.assertRaises(ValueError): |
| | 80 | +            AsymmetricUnifiedFocalLoss(use_softmax=True, use_sigmoid=True) |
| | 81 | |
| | 82 | |
| 65 | 83 | if __name__ == "__main__": |
| 66 | 84 |     unittest.main() |
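The constructor behaviour that `test_mutually_exclusive` expects can be sketched with a small stand-in class. `ActivationFlags` is hypothetical and exists only for illustration; it is not the PR's implementation, and the error message is an assumption.

```python
class ActivationFlags:
    """Hypothetical stand-in holding the two activation switches."""

    def __init__(self, use_sigmoid=False, use_softmax=False):
        # Only one activation can be applied to the predictions.
        if use_sigmoid and use_softmax:
            raise ValueError("use_sigmoid and use_softmax are mutually exclusive")
        self.use_sigmoid = use_sigmoid
        self.use_softmax = use_softmax


# Either flag on its own is accepted.
sigmoid_only = ActivationFlags(use_sigmoid=True)
softmax_only = ActivationFlags(use_softmax=True)

# Setting both is rejected at construction time.
try:
    ActivationFlags(use_sigmoid=True, use_softmax=True)
    both_rejected = False
except ValueError:
    both_rejected = True
```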
Activation applied after one-hot conversion breaks single-channel logit inputs.
Lines 227-229 convert single-channel `y_pred` to one-hot before activation is applied here. This means:
- `use_sigmoid=True` will fail because `one_hot()` expects integers

Document this requirement or add validation to raise a clear error for single-channel inputs when activation flags are set.
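One way to act on this suggestion is an explicit shape check that runs before any one-hot conversion. The function below is a hypothetical sketch: the name `validate_activation_input`, its arguments, and the error wording are assumptions and not part of MONAI. It works on a plain shape tuple in channel-first layout, so it needs no tensor library.

```python
def validate_activation_input(pred_shape, use_sigmoid=False, use_softmax=False):
    """Return the channel count, or raise if an activation flag meets single-channel input.

    `pred_shape` is assumed to be channel-first, e.g. (batch, channels, height, width).
    """
    if len(pred_shape) < 2:
        raise ValueError(f"expected a channel-first shape with at least 2 dims, got {pred_shape}")
    num_channels = pred_shape[1]
    if (use_sigmoid or use_softmax) and num_channels == 1:
        raise ValueError(
            "single-channel y_pred cannot be combined with use_sigmoid/use_softmax; "
            "pass one channel per class instead"
        )
    return num_channels


# The tensors in the new tests have shape (2, 1, 2, 2): batch 2, one channel, 2x2 spatial.
try:
    validate_activation_input((2, 1, 2, 2), use_sigmoid=True)
    single_channel_rejected = False
except ValueError:
    single_channel_rejected = True

# A two-channel prediction passes the check.
two_channel_count = validate_activation_input((2, 2, 2, 2), use_sigmoid=True)
```

A check like this turns an obscure indexing failure into a message that names the actual constraint.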
Multi-class limitation: softmax won't work beyond binary despite being added.
The component losses (`AsymmetricFocalLoss` lines 135-138 and `AsymmetricFocalTverskyLoss` lines 79-80) hardcode indices `[:, 0]` and `[:, 1]`, limiting support to exactly 2 classes. Even with `use_softmax=True`, inputs with >2 channels will fail or produce incorrect results.

Consider documenting this limitation in the `use_softmax` docstring or adding validation to reject `num_classes != 2`.
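The effect of reading only channels 0 and 1 can be shown with a framework-free sketch. The `softmax` helper is a plain-Python version written for this illustration, and `require_two_classes` is a hypothetical guard, not an existing MONAI function. With three channels, a pixel that confidently belongs to the third class contributes almost no probability mass to the two channels the component losses would read.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a flat list of logits."""
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def require_two_classes(num_classes):
    """Hypothetical guard: reject anything other than exactly two classes."""
    if num_classes != 2:
        raise ValueError(f"this loss supports exactly 2 classes, got {num_classes}")


# A three-channel prediction dominated by the third class.
probs = softmax([0.0, 0.0, 10.0])

# Mass seen by code that reads only indices 0 and 1.
covered = probs[0] + probs[1]

try:
    require_two_classes(len(probs))
    three_classes_rejected = False
except ValueError:
    three_classes_rejected = True
```

Either the docstring note or a guard of this kind would make the two-class assumption visible to callers.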