From e3a2bb255a40cd10dab5ddd25185765622503e54 Mon Sep 17 00:00:00 2001 From: Oleksandr Yizchak Sanin Date: Tue, 19 May 2026 00:19:34 +0200 Subject: [PATCH 1/2] Replace BaseException with Exception across codebase (#7401) Catching BaseException inadvertently suppresses KeyboardInterrupt, SystemExit, and GeneratorExit, which should nearly always propagate. All 17 occurrences across monai/ and tests/ are replaced with Exception, which is the appropriate base class for catchable errors. Signed-off-by: Oleksandr Sanin --- monai/__init__.py | 2 +- monai/apps/auto3dseg/data_analyzer.py | 4 ++-- monai/apps/auto3dseg/ensemble_builder.py | 2 +- monai/apps/auto3dseg/utils.py | 2 +- monai/apps/detection/metrics/coco.py | 2 +- monai/apps/nnunet/nnunetv2_runner.py | 4 ++-- monai/config/deviceconfig.py | 2 +- monai/data/__init__.py | 2 +- monai/inferers/inferer.py | 2 +- monai/utils/tf32.py | 4 ++-- tests/apps/detection/networks/test_retinanet.py | 4 ++-- tests/networks/nets/test_resnet.py | 2 +- tests/test_utils.py | 2 +- 13 files changed, 17 insertions(+), 17 deletions(-) diff --git a/monai/__init__.py b/monai/__init__.py index d92557a8e1..45e15ddcd9 100644 --- a/monai/__init__.py +++ b/monai/__init__.py @@ -131,7 +131,7 @@ def filter(self, record): # workaround related to https://github.com/Project-MONAI/MONAI/issues/7575 if hasattr(torch.cuda.device_count, "cache_clear"): torch.cuda.device_count.cache_clear() -except BaseException: +except Exception: from .utils.misc import MONAIEnvVars if MONAIEnvVars.debug(): diff --git a/monai/apps/auto3dseg/data_analyzer.py b/monai/apps/auto3dseg/data_analyzer.py index 15e56abfea..30824e200f 100644 --- a/monai/apps/auto3dseg/data_analyzer.py +++ b/monai/apps/auto3dseg/data_analyzer.py @@ -341,7 +341,7 @@ def _get_all_case_stats( _label_argmax = True # track if label is argmaxed batch_data[self.label_key] = label.to(device) d = summarizer(batch_data) - except BaseException as err: + except Exception as err: if "image_meta_dict" in batch_data.keys(): filename = batch_data["image_meta_dict"][ImageMetaKey.FILENAME_OR_OBJ] else: @@ -357,7 +357,7 @@ def _get_all_case_stats( label = torch.argmax(label, dim=0) if label.shape[0] > 1 else label[0] batch_data[self.label_key] = label.to("cpu") d = summarizer(batch_data) - except BaseException as err: + except Exception as err: logger.info(f"Unable to process data {filename} on {device}. {err}") continue else: diff --git a/monai/apps/auto3dseg/ensemble_builder.py b/monai/apps/auto3dseg/ensemble_builder.py index e574baf7c8..eaf1f14c7f 100644 --- a/monai/apps/auto3dseg/ensemble_builder.py +++ b/monai/apps/auto3dseg/ensemble_builder.py @@ -216,7 +216,7 @@ def __call__(self, pred_param: dict | None = None) -> list: if "image_save_func" in param: try: ensemble_preds = self.ensemble_pred(preds, sigmoid=sigmoid) - except BaseException: + except Exception: ensemble_preds = self.ensemble_pred([_.to("cpu") for _ in preds], sigmoid=sigmoid) res = img_saver(ensemble_preds) # res is the path to the saved results diff --git a/monai/apps/auto3dseg/utils.py b/monai/apps/auto3dseg/utils.py index fbf9dc101c..1946bb718a 100644 --- a/monai/apps/auto3dseg/utils.py +++ b/monai/apps/auto3dseg/utils.py @@ -59,7 +59,7 @@ def import_bundle_algo_history( if best_metric is None: try: best_metric = algo.get_score() - except BaseException: + except Exception: pass is_trained = best_metric is not None diff --git a/monai/apps/detection/metrics/coco.py b/monai/apps/detection/metrics/coco.py index d1347d76b8..b8de0eb104 100644 --- a/monai/apps/detection/metrics/coco.py +++ b/monai/apps/detection/metrics/coco.py @@ -541,7 +541,7 @@ def _compute_stats_single_threshold( for save_idx, array_index in enumerate(inds): precision[save_idx] = pr[array_index] th_scores[save_idx] = dt_scores_sorted[array_index] - except BaseException: + except Exception: pass return recall, np.array(precision), np.array(th_scores) diff --git a/monai/apps/nnunet/nnunetv2_runner.py b/monai/apps/nnunet/nnunetv2_runner.py index 98b265cbbb..4828ccb56e 100644 --- a/monai/apps/nnunet/nnunetv2_runner.py +++ b/monai/apps/nnunet/nnunetv2_runner.py @@ -200,7 +200,7 @@ def __init__( from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name self.dataset_name = maybe_convert_to_dataset_name(int(self.dataset_name_or_id)) - except BaseException: + except Exception: logger.warning( f"Dataset with name/ID: {self.dataset_name_or_id} cannot be found in the record. " "Please ignore the message above if you are running the pipeline from a fresh start. " @@ -278,7 +278,7 @@ def convert_dataset(self): num_input_channels=num_input_channels, output_datafolder=raw_data_foldername, ) - except BaseException as err: + except Exception as err: logger.warning(f"Input config may be incorrect. Detail info: error/exception message is:\n {err}") return diff --git a/monai/config/deviceconfig.py b/monai/config/deviceconfig.py index aa1f2a0b53..5800651e59 100644 --- a/monai/config/deviceconfig.py +++ b/monai/config/deviceconfig.py @@ -120,7 +120,7 @@ def print_config(file=sys.stdout): def _dict_append(in_dict, key, fn): try: in_dict[key] = fn() if callable(fn) else fn - except BaseException: + except Exception: in_dict[key] = "UNKNOWN for given OS" diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 5e367cc297..971d5121f7 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -118,7 +118,7 @@ from .wsi_datasets import MaskedPatchWSIDataset, PatchWSIDataset, SlidingPatchWSIDataset from .wsi_reader import BaseWSIReader, CuCIMWSIReader, OpenSlideWSIReader, TiffFileWSIReader, WSIReader -with contextlib.suppress(BaseException): +with contextlib.suppress(Exception): from multiprocessing.reduction import ForkingPickler def _rebuild_meta(cls, storage, dtype, metadata): diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index eea573609c..ee94b1ebdb 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -545,7 +545,7 @@ def __init__( ) if cache_roi_weight_map and self.roi_weight_map is None: warnings.warn("cache_roi_weight_map=True, but cache is not created. (dynamic roi_size?)") - except BaseException as e: + except Exception as e: raise RuntimeError( f"roi size {self.roi_size}, mode={mode}, sigma_scale={sigma_scale}, device={device}\n" "Seems to be OOM. Please try smaller patch size or mode='constant' instead of mode='gaussian'." diff --git a/monai/utils/tf32.py b/monai/utils/tf32.py index ad5918a34a..db5e05279e 100644 --- a/monai/utils/tf32.py +++ b/monai/utils/tf32.py @@ -41,7 +41,7 @@ def has_ampere_or_later() -> bool: major, _ = pynvml.nvmlDeviceGetCudaComputeCapability(handle) if major >= 8: return True - except BaseException: + except Exception: pass finally: pynvml.nvmlShutdown() @@ -71,7 +71,7 @@ def detect_default_tf32() -> bool: may_enable_tf32 = True return may_enable_tf32 - except BaseException: + except Exception: from monai.utils.misc import MONAIEnvVars if MONAIEnvVars.debug(): diff --git a/tests/apps/detection/networks/test_retinanet.py b/tests/apps/detection/networks/test_retinanet.py index 3f4721a755..eb491361ea 100644 --- a/tests/apps/detection/networks/test_retinanet.py +++ b/tests/apps/detection/networks/test_retinanet.py @@ -140,7 +140,7 @@ def test_retina_shape(self, model, input_param, input_shape): def test_script(self, model, input_param, input_shape): try: idx = int(self.id().split("test_script_")[-1]) - except BaseException: + except Exception: idx = 0 idx %= 3 # test whether support torchscript @@ -174,7 +174,7 @@ def test_script(self, model, input_param, input_shape): def test_onnx(self, model, input_param, input_shape): try: idx = int(self.id().split("test_onnx_")[-1]) - except BaseException: + except Exception: idx = 0 idx %= 3 # test whether support torchscript diff --git a/tests/networks/nets/test_resnet.py b/tests/networks/nets/test_resnet.py index 371ec89682..241f57c78d 100644 --- a/tests/networks/nets/test_resnet.py +++ b/tests/networks/nets/test_resnet.py @@ -249,7 +249,7 @@ def tearDown(self): if os.path.exists(self.tmp_ckpt_filename): try: os.remove(self.tmp_ckpt_filename) - except BaseException: + except Exception: pass @parameterized.expand(TEST_CASES) diff --git a/tests/test_utils.py b/tests/test_utils.py index 27af61cefe..05f7cb88d9 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -241,7 +241,7 @@ def is_tf32_env(): a_full = torch.randn(1024, 1024, dtype=torch.double, device="cuda", generator=g_gpu) b_full = torch.randn(1024, 1024, dtype=torch.double, device="cuda", generator=g_gpu) _tf32_enabled = (a_full.float() @ b_full.float() - a_full @ b_full).abs().max().item() > 0.001 # 0.1713 - except BaseException: + except Exception: pass print(f"tf32 enabled: {_tf32_enabled}") return _tf32_enabled From ec11cc3fe9a79f7298d1a8b5f9b066e17dee5c9d Mon Sep 17 00:00:00 2001 From: Oleksandr Yizchak Sanin Date: Tue, 19 May 2026 01:00:19 +0200 Subject: [PATCH 2/2] Add sigmoid/softmax activation to AsymmetricUnifiedFocalLoss (#8603) Add use_softmax and use_sigmoid parameters so users can pass raw logits directly. When both are False (default), the input is assumed to already be probabilities, preserving backward compatibility. Also removes the stale TODO comment about multi-class support and adds proper docstrings for the reduction parameter. Signed-off-by: Oleksandr Sanin --- monai/losses/unified_focal_loss.py | 26 ++++++++++++++++++++----- tests/losses/test_unified_focal_loss.py | 18 +++++++++++++++++ 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py index 745513fec0..ecd308a699 100644 --- a/monai/losses/unified_focal_loss.py +++ b/monai/losses/unified_focal_loss.py @@ -162,6 +162,8 @@ def __init__( gamma: float = 0.5, delta: float = 0.7, reduction: LossReduction | str = LossReduction.MEAN, + use_softmax: bool = False, + use_sigmoid: bool = False, ): """ Args: @@ -170,8 +172,14 @@ def __init__( weight : weight for each loss function. Defaults to 0.5. gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.5. delta : weight of the background. Defaults to 0.7. - - + reduction: {``"none"``, ``"mean"``, ``"sum"``} + Specifies the reduction to apply to the output. Defaults to ``"mean"``. + use_softmax: if True, use softmax to transform the input logits into probabilities. + Defaults to False. Mutually exclusive with ``use_sigmoid``. + use_sigmoid: if True, use sigmoid to transform the input logits into probabilities. + Defaults to False. Mutually exclusive with ``use_softmax``. + When both ``use_softmax`` and ``use_sigmoid`` are False, the input is assumed + to already be probabilities. Example: >>> import torch @@ -182,22 +190,25 @@ def __init__( >>> fl(pred, grnd) """ super().__init__(reduction=LossReduction(reduction).value) + if use_softmax and use_sigmoid: + raise ValueError("use_softmax and use_sigmoid are mutually exclusive.") self.to_onehot_y = to_onehot_y self.num_classes = num_classes self.gamma = gamma self.delta = delta self.weight: float = weight + self.use_softmax = use_softmax + self.use_sigmoid = use_sigmoid self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta) self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta) - # TODO: Implement this function to support multiple classes segmentation def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: """ Args: y_pred : the shape should be BNH[WD], where N is the number of classes. It only supports binary segmentation. - The input should be the original logits since it will be transformed by - a sigmoid in the forward function. + The input can be raw logits or probabilities depending on ``use_softmax`` + and ``use_sigmoid`` settings. y_true : the shape should be BNH[WD], where N is the number of classes. It only supports binary segmentation. @@ -227,6 +238,11 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: else: y_true = one_hot(y_true, num_classes=n_pred_ch) + if self.use_softmax: + y_pred = torch.softmax(y_pred, dim=1) + elif self.use_sigmoid: + y_pred = torch.sigmoid(y_pred) + asy_focal_loss = self.asy_focal_loss(y_pred, y_true) asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true) diff --git a/tests/losses/test_unified_focal_loss.py b/tests/losses/test_unified_focal_loss.py index 3b868a560e..3fa7354cf2 100644 --- a/tests/losses/test_unified_focal_loss.py +++ b/tests/losses/test_unified_focal_loss.py @@ -61,6 +61,24 @@ def test_with_cuda(self): print(output) np.testing.assert_allclose(output.detach().cpu().numpy(), 0.0, atol=1e-4, rtol=1e-4) + def test_use_sigmoid(self): + loss = AsymmetricUnifiedFocalLoss(use_sigmoid=True) + y_pred = torch.tensor([[[[10.0, -10], [-10, 10.0]]], [[[10.0, -10], [-10, 10.0]]]]) + y_true = torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]) + result = loss(y_pred, y_true) + self.assertTrue(result.item() >= 0) + + def test_use_softmax(self): + loss = AsymmetricUnifiedFocalLoss(use_softmax=True) + y_pred = torch.tensor([[[[10.0, -10], [-10, 10.0]]], [[[10.0, -10], [-10, 10.0]]]]) + y_true = torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]) + result = loss(y_pred, y_true) + self.assertTrue(result.item() >= 0) + + def test_mutually_exclusive(self): + with self.assertRaises(ValueError): + AsymmetricUnifiedFocalLoss(use_softmax=True, use_sigmoid=True) + if __name__ == "__main__": unittest.main()