Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion monai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def filter(self, record):
# workaround related to https://github.com/Project-MONAI/MONAI/issues/7575
if hasattr(torch.cuda.device_count, "cache_clear"):
torch.cuda.device_count.cache_clear()
except BaseException:
except Exception:
from .utils.misc import MONAIEnvVars

if MONAIEnvVars.debug():
Expand Down
4 changes: 2 additions & 2 deletions monai/apps/auto3dseg/data_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def _get_all_case_stats(
_label_argmax = True # track if label is argmaxed
batch_data[self.label_key] = label.to(device)
d = summarizer(batch_data)
except BaseException as err:
except Exception as err:
if "image_meta_dict" in batch_data.keys():
filename = batch_data["image_meta_dict"][ImageMetaKey.FILENAME_OR_OBJ]
else:
Expand All @@ -357,7 +357,7 @@ def _get_all_case_stats(
label = torch.argmax(label, dim=0) if label.shape[0] > 1 else label[0]
batch_data[self.label_key] = label.to("cpu")
d = summarizer(batch_data)
except BaseException as err:
except Exception as err:
logger.info(f"Unable to process data {filename} on {device}. {err}")
continue
else:
Expand Down
2 changes: 1 addition & 1 deletion monai/apps/auto3dseg/ensemble_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def __call__(self, pred_param: dict | None = None) -> list:
if "image_save_func" in param:
try:
ensemble_preds = self.ensemble_pred(preds, sigmoid=sigmoid)
except BaseException:
except Exception:
ensemble_preds = self.ensemble_pred([_.to("cpu") for _ in preds], sigmoid=sigmoid)
res = img_saver(ensemble_preds)
# res is the path to the saved results
Expand Down
2 changes: 1 addition & 1 deletion monai/apps/auto3dseg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def import_bundle_algo_history(
if best_metric is None:
try:
best_metric = algo.get_score()
except BaseException:
except Exception:
pass

is_trained = best_metric is not None
Expand Down
2 changes: 1 addition & 1 deletion monai/apps/detection/metrics/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,7 @@ def _compute_stats_single_threshold(
for save_idx, array_index in enumerate(inds):
precision[save_idx] = pr[array_index]
th_scores[save_idx] = dt_scores_sorted[array_index]
except BaseException:
except Exception:
pass

return recall, np.array(precision), np.array(th_scores)
4 changes: 2 additions & 2 deletions monai/apps/nnunet/nnunetv2_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def __init__(
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name

self.dataset_name = maybe_convert_to_dataset_name(int(self.dataset_name_or_id))
except BaseException:
except Exception:
logger.warning(
f"Dataset with name/ID: {self.dataset_name_or_id} cannot be found in the record. "
"Please ignore the message above if you are running the pipeline from a fresh start. "
Expand Down Expand Up @@ -278,7 +278,7 @@ def convert_dataset(self):
num_input_channels=num_input_channels,
output_datafolder=raw_data_foldername,
)
except BaseException as err:
except Exception as err:
logger.warning(f"Input config may be incorrect. Detail info: error/exception message is:\n {err}")
return

Expand Down
2 changes: 1 addition & 1 deletion monai/config/deviceconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def print_config(file=sys.stdout):
def _dict_append(in_dict, key, fn):
try:
in_dict[key] = fn() if callable(fn) else fn
except BaseException:
except Exception:
in_dict[key] = "UNKNOWN for given OS"


Expand Down
2 changes: 1 addition & 1 deletion monai/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@
from .wsi_datasets import MaskedPatchWSIDataset, PatchWSIDataset, SlidingPatchWSIDataset
from .wsi_reader import BaseWSIReader, CuCIMWSIReader, OpenSlideWSIReader, TiffFileWSIReader, WSIReader

with contextlib.suppress(BaseException):
with contextlib.suppress(Exception):
from multiprocessing.reduction import ForkingPickler

def _rebuild_meta(cls, storage, dtype, metadata):
Expand Down
2 changes: 1 addition & 1 deletion monai/inferers/inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ def __init__(
)
if cache_roi_weight_map and self.roi_weight_map is None:
warnings.warn("cache_roi_weight_map=True, but cache is not created. (dynamic roi_size?)")
except BaseException as e:
except Exception as e:
raise RuntimeError(
f"roi size {self.roi_size}, mode={mode}, sigma_scale={sigma_scale}, device={device}\n"
"Seems to be OOM. Please try smaller patch size or mode='constant' instead of mode='gaussian'."
Expand Down
26 changes: 21 additions & 5 deletions monai/losses/unified_focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ def __init__(
gamma: float = 0.5,
delta: float = 0.7,
reduction: LossReduction | str = LossReduction.MEAN,
use_softmax: bool = False,
use_sigmoid: bool = False,
):
"""
Args:
Expand All @@ -170,8 +172,14 @@ def __init__(
weight : weight for each loss function. Defaults to 0.5.
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.5.
delta : weight of the background. Defaults to 0.7.


reduction: {``"none"``, ``"mean"``, ``"sum"``}
Specifies the reduction to apply to the output. Defaults to ``"mean"``.
use_softmax: if True, use softmax to transform the input logits into probabilities.
Defaults to False. Mutually exclusive with ``use_sigmoid``.
use_sigmoid: if True, use sigmoid to transform the input logits into probabilities.
Defaults to False. Mutually exclusive with ``use_softmax``.
When both ``use_softmax`` and ``use_sigmoid`` are False, the input is assumed
to already be probabilities.

Example:
>>> import torch
Expand All @@ -182,22 +190,25 @@ def __init__(
>>> fl(pred, grnd)
"""
super().__init__(reduction=LossReduction(reduction).value)
if use_softmax and use_sigmoid:
raise ValueError("use_softmax and use_sigmoid are mutually exclusive.")
self.to_onehot_y = to_onehot_y
self.num_classes = num_classes
self.gamma = gamma
self.delta = delta
self.weight: float = weight
self.use_softmax = use_softmax
self.use_sigmoid = use_sigmoid
self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta)
self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta)

# TODO: Implement this function to support multiple classes segmentation
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
"""
Args:
y_pred : the shape should be BNH[WD], where N is the number of classes.
It only supports binary segmentation.
The input should be the original logits since it will be transformed by
a sigmoid in the forward function.
The input can be raw logits or probabilities depending on ``use_softmax``
and ``use_sigmoid`` settings.
y_true : the shape should be BNH[WD], where N is the number of classes.
It only supports binary segmentation.

Expand Down Expand Up @@ -227,6 +238,11 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
else:
y_true = one_hot(y_true, num_classes=n_pred_ch)

if self.use_softmax:
y_pred = torch.softmax(y_pred, dim=1)
elif self.use_sigmoid:
y_pred = torch.sigmoid(y_pred)
Comment on lines +241 to +244
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Activation applied after one-hot conversion breaks single-channel logit inputs.

Lines 227-229 convert single-channel y_pred to one-hot before activation is applied here. This means:

  • Single-channel inputs are treated as class indices (discrete values), not logits
  • Passing single-channel logits with use_sigmoid=True will fail because one_hot() expects integers
  • Users must pass 2-channel inputs to use the activation flags

Document this requirement or add validation to raise a clear error for single-channel inputs when activation flags are set.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/losses/unified_focal_loss.py` around lines 241 - 244, The activation is
applied after converting single-channel y_pred to one-hot causing integer-index
errors for single-channel logits; update the UnifiedFocalLoss forward (or the
method handling y_pred) to either apply softmax/sigmoid before the one_hot
conversion or add an explicit input validation that raises a clear error when
y_pred has a single channel while use_softmax or use_sigmoid is True
(referencing variables y_pred, use_softmax, use_sigmoid and the one_hot
conversion block) so users get a helpful message rather than a silent failure.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Multi-class limitation: softmax won't work beyond binary despite being added.

The component losses (AsymmetricFocalLoss lines 135-138 and AsymmetricFocalTverskyLoss lines 79-80) hardcode indices [:, 0] and [:, 1], limiting support to exactly 2 classes. Even with use_softmax=True, inputs with >2 channels will fail or produce incorrect results.

Consider documenting this limitation in the use_softmax docstring or adding validation to reject num_classes != 2.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/losses/unified_focal_loss.py` around lines 241 - 244, The code enables
softmax via use_softmax but the component losses AsymmetricFocalLoss and
AsymmetricFocalTverskyLoss index hardcoded class channels ([:, 0] and [:, 1]) so
softmax with >2 channels will be broken; update the functions to validate and
reject multi-class inputs (e.g., check y_pred.shape[1] or num_classes and raise
a clear ValueError if != 2) or explicitly document in the use_softmax docstring
that only binary (2-class) predictions are supported, and ensure the error
message references use_softmax, AsymmetricFocalLoss, and
AsymmetricFocalTverskyLoss so callers know why their multi-channel input is
unsupported.


asy_focal_loss = self.asy_focal_loss(y_pred, y_true)
asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true)

Expand Down
4 changes: 2 additions & 2 deletions monai/utils/tf32.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def has_ampere_or_later() -> bool:
major, _ = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
if major >= 8:
return True
except BaseException:
except Exception:
pass
finally:
pynvml.nvmlShutdown()
Expand Down Expand Up @@ -71,7 +71,7 @@ def detect_default_tf32() -> bool:
may_enable_tf32 = True

return may_enable_tf32
except BaseException:
except Exception:
from monai.utils.misc import MONAIEnvVars

if MONAIEnvVars.debug():
Expand Down
4 changes: 2 additions & 2 deletions tests/apps/detection/networks/test_retinanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def test_retina_shape(self, model, input_param, input_shape):
def test_script(self, model, input_param, input_shape):
try:
idx = int(self.id().split("test_script_")[-1])
except BaseException:
except Exception:
idx = 0
idx %= 3
# test whether support torchscript
Expand Down Expand Up @@ -174,7 +174,7 @@ def test_script(self, model, input_param, input_shape):
def test_onnx(self, model, input_param, input_shape):
try:
idx = int(self.id().split("test_onnx_")[-1])
except BaseException:
except Exception:
idx = 0
idx %= 3
# test whether support torchscript
Expand Down
18 changes: 18 additions & 0 deletions tests/losses/test_unified_focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,24 @@ def test_with_cuda(self):
print(output)
np.testing.assert_allclose(output.detach().cpu().numpy(), 0.0, atol=1e-4, rtol=1e-4)

def test_use_sigmoid(self):
loss = AsymmetricUnifiedFocalLoss(use_sigmoid=True)
y_pred = torch.tensor([[[[10.0, -10], [-10, 10.0]]], [[[10.0, -10], [-10, 10.0]]]])
y_true = torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]])
result = loss(y_pred, y_true)
self.assertTrue(result.item() >= 0)
Comment on lines +64 to +69
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Run the specific test to check if single-channel logits work
python -m pytest tests/losses/test_unified_focal_loss.py::TestAsymmetricUnifiedFocalLoss::test_use_sigmoid -v

Repository: Project-MONAI/MONAI

Length of output: 103


🏁 Script executed:

# Find the AsymmetricUnifiedFocalLoss implementation
find . -type f -name "*.py" | xargs grep -l "class AsymmetricUnifiedFocalLoss"

Repository: Project-MONAI/MONAI

Length of output: 100


🏁 Script executed:

# Get the test file to see all three tests mentioned
cat -n tests/losses/test_unified_focal_loss.py

Repository: Project-MONAI/MONAI

Length of output: 3818


🏁 Script executed:

# Search for one_hot usage in the loss implementation
rg "one_hot" --type py

Repository: Project-MONAI/MONAI

Length of output: 10644


🏁 Script executed:

cat -n monai/losses/unified_focal_loss.py | head -300

Repository: Project-MONAI/MONAI

Length of output: 12927


🏁 Script executed:

# Find and read the one_hot function definition
rg -A 20 "^def one_hot" monai/networks/utils.py

Repository: Project-MONAI/MONAI

Length of output: 1250


🏁 Script executed:

# Double check by examining what happens with out-of-bounds indices in one_hot
python3 << 'EOF'
import torch
import torch.nn.functional as F

# Simulate what monai.networks.one_hot does
labels = torch.tensor([10, -10])  # Out of bounds for 2 classes
try:
    # This is essentially what one_hot does internally
    result = F.one_hot(labels.long(), num_classes=2)
    print("Result:", result)
except Exception as e:
    print(f"Error type: {type(e).__name__}")
    print(f"Error message: {e}")
EOF

Repository: Project-MONAI/MONAI

Length of output: 181


🏁 Script executed:

# Get the full one_hot implementation
rg -A 30 "^def one_hot" monai/networks/utils.py

Repository: Project-MONAI/MONAI

Length of output: 1575


🏁 Script executed:

# Get the complete one_hot function implementation
rg -A 50 "^def one_hot" monai/networks/utils.py | head -60

Repository: Project-MONAI/MONAI

Length of output: 2157


Test crashes with out-of-bounds indices in one-hot conversion.

Single-channel y_pred (shape [2, 1, 2, 2]) triggers one_hot() conversion at line 228, which internally calls scatter_() with indices from .long(). The logit values 10.0 and -10 become indices 10 and -10, both out of bounds for num_classes=2. This causes an IndexError before use_sigmoid is even applied.

Recommendation: Use multi-channel inputs ([B, 2, H, W]) to bypass one-hot conversion, or use integer class indices if single-channel is intended.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/losses/test_unified_focal_loss.py` around lines 64 - 69, The test
crashes because AsymmetricUnifiedFocalLoss's internal one_hot conversion is
called when y_pred is single-channel float logits ([2,1,2,2]) and those logits
are being cast to long, producing out-of-range indices; change the test to
provide multi-channel logits (e.g. shape [2,2,2,2]) when use_sigmoid=True or
else supply integer class indices for y_true so one_hot isn't fed raw logits;
update the test_use_sigmoid to create y_pred with two channels and matching
y_true (or use class indices) so one_hot/scatter_ receives valid class indices.


def test_use_softmax(self):
loss = AsymmetricUnifiedFocalLoss(use_softmax=True)
y_pred = torch.tensor([[[[10.0, -10], [-10, 10.0]]], [[[10.0, -10], [-10, 10.0]]]])
y_true = torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]])
result = loss(y_pred, y_true)
self.assertTrue(result.item() >= 0)

def test_mutually_exclusive(self):
with self.assertRaises(ValueError):
AsymmetricUnifiedFocalLoss(use_softmax=True, use_sigmoid=True)


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion tests/networks/nets/test_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def tearDown(self):
if os.path.exists(self.tmp_ckpt_filename):
try:
os.remove(self.tmp_ckpt_filename)
except BaseException:
except Exception:
pass

@parameterized.expand(TEST_CASES)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def is_tf32_env():
a_full = torch.randn(1024, 1024, dtype=torch.double, device="cuda", generator=g_gpu)
b_full = torch.randn(1024, 1024, dtype=torch.double, device="cuda", generator=g_gpu)
_tf32_enabled = (a_full.float() @ b_full.float() - a_full @ b_full).abs().max().item() > 0.001 # 0.1713
except BaseException:
except Exception:
pass
print(f"tf32 enabled: {_tf32_enabled}")
return _tf32_enabled
Expand Down
Loading