@@ -9,20 +9,23 @@ class CombineDetections:
99
1010 Args:
1111 element_crops (MakeCropsDetectThem): Object containing crop information.
12- nms_threshold (float): IoU/IoS threshold for non-maximum suppression.
13- match_metric (str): Matching metric, either 'IOU' or 'IOS'.
12+ nms_threshold (float): IoU/IoS threshold for non-maximum suppression. Default is 0.3.
13+ match_metric (str): Matching metric, either 'IOU' or 'IOS'. Default is 'IOS'.
14+ class_agnostic_nms (bool): Determines the NMS mode in object detection. When set to True, NMS
15+ operates across all classes, ignoring class distinctions and suppressing less confident
16+ bounding boxes globally. Otherwise, NMS is applied separately for each class. Default is True.
1417 intelligent_sorter (bool): Enable sorting by area and rounded confidence parameter.
15- If False, sorting will be done only by confidence (usual nms). ( Dafault True)
18+ If False, sorting will be done only by confidence (usual nms). Default is True.
1619 sorter_bins (int): Number of bins to use for intelligent_sorter. A smaller number of bins makes
1720 the NMS more reliant on object sizes rather than confidence scores. Defaults to 10.
1821
1922 Attributes:
20- conf_treshold (float): Confidence threshold of yolov8.
21- class_names (dict): Dictionary containing class names pf yolov8 model.
23+ class_names (dict): Dictionary containing class names of yolo model.
2224 crops (list): List to store the CropElement objects.
2325 image (np.ndarray): Source image in BGR.
2426 nms_threshold (float): IOU/IOS threshold for non-maximum suppression.
2527 match_metric (str): Matching metric (IOU/IOS).
28+ class_agnostic_nms (bool): Determines the NMS mode in object detection.
2629 intelligent_sorter (bool): Flag indicating whether sorting by area and confidence parameter is enabled.
2730 sorter_bins (int): Number of bins to use for intelligent_sorter.
2831 detected_conf_list_full (list): List of detected confidences.
@@ -46,9 +49,9 @@ def __init__(
4649 nms_threshold = 0.3 ,
4750 match_metric = 'IOS' ,
4851 intelligent_sorter = True ,
49- sorter_bins = 10
52+ sorter_bins = 10 ,
53+ class_agnostic_nms = True
5054 ) -> None :
51- self .conf_treshold = element_crops .conf
5255 self .class_names = element_crops .class_names_dict
5356 self .crops = element_crops .crops # List to store the CropElement objects
5457 if element_crops .resize_initial_size :
@@ -60,6 +63,7 @@ def __init__(
6063 self .match_metric = match_metric
6164 self .intelligent_sorter = intelligent_sorter # enable sorting by area and confidence parameter
6265 self .sorter_bins = sorter_bins
66+ self .class_agnostic_nms = class_agnostic_nms
6367
6468 # Combine detections of all patches
6569 (
@@ -74,27 +78,27 @@ def __init__(
7478 self .class_names [value ] for value in self .detected_cls_id_list_full
7579 ] # make str list
7680
77- # Invoke the NMS for segmentation masks method for filtering predictions
78- if len (self .detected_masks_list_full ) > 0 :
79-
80- self .filtered_indices = self .nms (
81+ # Invoke the NMS:
82+ if self .class_agnostic_nms :
83+ self .filtered_indices = self .agnostic_nms (
8184 self .detected_conf_list_full ,
8285 self .detected_xyxy_list_full ,
8386 self .match_metric ,
8487 self .nms_threshold ,
8588 self .detected_masks_list_full ,
8689 intelligent_sorter = self .intelligent_sorter
87- ) # for instance segmentation
90+ )
8891 else :
89- # Invoke the NMS method for filtering prediction
9092 self .filtered_indices = self .nms (
9193 self .detected_conf_list_full ,
9294 self .detected_xyxy_list_full ,
93- self .match_metric ,
95+ self .detected_cls_id_list_full ,
96+ self .match_metric ,
9497 self .nms_threshold ,
98+ self .detected_masks_list_full ,
9599 intelligent_sorter = self .intelligent_sorter
96- ) # for detection
97-
100+ )
101+
98102 # Apply filtering (nms output indices) to the prediction lists
99103 self .filtered_confidences = [self .detected_conf_list_full [i ] for i in self .filtered_indices ]
100104 self .filtered_boxes = [self .detected_xyxy_list_full [i ] for i in self .filtered_indices ]
@@ -208,7 +212,7 @@ def intersect_over_smaller(mask, masks_list):
208212 ios_scores .append (ios )
209213 return torch .tensor (ios_scores )
210214
211- def nms (
215+ def agnostic_nms (
212216 self ,
213217 confidences : list ,
214218 boxes : list ,
@@ -218,15 +222,15 @@ def nms(
218222 intelligent_sorter = False ,
219223 ):
220224 """
221- Apply non-maximum suppression to avoid detecting too many
225+ Apply class-agnostic non-maximum suppression to avoid detecting too many
222226 overlapping bounding boxes for a given object.
223227
224228 Args:
225229 confidences (list): List of confidence scores.
226230 boxes (list): List of bounding boxes.
227231 match_metric (str): Matching metric, either 'IOU' or 'IOS'.
228232 nms_threshold (float): The threshold for match metric.
229- masks (list, optional ): List of masks. Defaults to None.
233+ masks (list): List of masks.
230234 intelligent_sorter (bool, optional): intelligent sorter
231235
232236 Returns:
@@ -320,7 +324,7 @@ def nms(
320324
321325 # If masks are provided and IoU based on bounding boxes is greater than 0,
322326 # calculate IoU for masks and keep the ones with IoU < nms_threshold
323- if masks is not None and torch .any (match_metric_value > 0 ):
327+ if len ( masks ) > 0 and torch .any (match_metric_value > 0 ):
324328
325329 mask_mask = match_metric_value > 0
326330
0 commit comments