@@ -40,7 +40,8 @@ class CombineDetections:
4040 filtered_classes_id (list): List of class IDs after non-maximum suppression.
4141 filtered_classes_names (list): List of class names after non-maximum suppression.
4242 filtered_masks (list): List of filtered (after nms) masks if segmentation is enabled.
43- filtered_polygons (list): List of filtered (after nms) polygons if segmentation and memory optimization are enabled.
43+ filtered_polygons (list): List of filtered (after nms) polygons if segmentation and
44+ memory optimization are enabled.
4445 """
4546
4647 def __init__ (
@@ -64,7 +65,7 @@ def __init__(
6465 self .intelligent_sorter = intelligent_sorter # enable sorting by area and confidence parameter
6566 self .sorter_bins = sorter_bins
6667 self .class_agnostic_nms = class_agnostic_nms
67-
68+
6869 # Combinate detections of all patches
6970 (
7071 self .detected_conf_list_full ,
@@ -99,7 +100,7 @@ def __init__(
99100 self .detected_masks_list_full ,
100101 intelligent_sorter = self .intelligent_sorter
101102 )
102-
103+
103104 # Apply filtering (nms output indeces) to the prediction lists
104105 self .filtered_confidences = [self .detected_conf_list_full [i ] for i in self .filtered_indices ]
105106 self .filtered_boxes = [self .detected_xyxy_list_full [i ] for i in self .filtered_indices ]
@@ -111,7 +112,7 @@ def __init__(
111112 self .filtered_masks = [self .detected_masks_list_full [i ] for i in self .filtered_indices ]
112113 else :
113114 self .filtered_masks = []
114-
115+
115116 # Polygons filtering:
116117 if element_crops .segment and element_crops .memory_optimize :
117118 self .filtered_polygons = [self .detected_polygons_list_full [i ] for i in self .filtered_indices ]
@@ -160,13 +161,13 @@ def average_to_bound(confidences, N=10):
160161 # Create the bounds
161162 step = 1 / N
162163 bounds = np .arange (0 , 1 + step , step )
163-
164+
164165 # Use np.digitize to determine the corresponding bin for each value
165166 indices = np .digitize (confidences , bounds , right = True ) - 1
166-
167+
167168 # Bind values to the left boundary of the corresponding bin
168169 averaged_confidences = np .round (bounds [indices ], 2 )
169-
170+
170171 return averaged_confidences .tolist ()
171172
172173 @staticmethod
@@ -201,7 +202,8 @@ def intersect_over_smaller(mask, masks_list):
201202 masks_list (list of np.ndarray): List of binary masks for comparison.
202203
203204 Returns:
204- torch.Tensor: IoU scores for each mask in masks_list compared to the input mask, calculated over the smaller area.
205+ torch.Tensor: IoU scores for each mask in masks_list compared to the input mask,
206+ calculated over the smaller area.
205207 """
206208 ios_scores = []
207209 for other_mask in masks_list :
@@ -234,7 +236,9 @@ def nms(
234236 nms_threshold (float): The threshold for match metric.
235237 masks (list): List of masks.
236238 intelligent_sorter (bool, optional): intelligent sorter
237- cls_indexes (torch.Tensor): indexes from network detections corresponding to the defined class, uses in case of not agnostic nms
239+ cls_indexes (torch.Tensor): indexes from network detections corresponding
240+ to the defined class, uses in case of not agnostic nms
241+
238242 Returns:
239243 list: List of filtered indexes.
240244 """
@@ -256,7 +260,10 @@ def nms(
256260 order = torch .tensor (
257261 sorted (
258262 range (len (confidences )),
259- key = lambda k : (self .average_to_bound (confidences [k ].item (), self .sorter_bins ), areas [k ]),
263+ key = lambda k : (
264+ self .average_to_bound (confidences [k ].item (), self .sorter_bins ),
265+ areas [k ],
266+ ),
260267 reverse = False ,
261268 )
262269 )
@@ -351,7 +358,7 @@ def nms(
351358 if cls_indexes is not None :
352359 keep = [cls_indexes [i ] for i in keep ]
353360 return keep
354-
361+
355362 def not_agnostic_nms (
356363 self ,
357364 detected_cls_id_list_full ,
@@ -362,8 +369,9 @@ def not_agnostic_nms(
362369 detected_masks_list_full ,
363370 intelligent_sorter
364371 ):
365- '''
366- Performs Non-Maximum Suppression (NMS) in a non-agnostic manner, where NMS is applied separately to each class.
372+ '''
373+ Performs Non-Maximum Suppression (NMS) in a non-agnostic manner, where NMS
374+ is applied separately to each class.
367375
368376 Args:
369377 detected_cls_id_list_full (torch.Tensor): tensor containing the class IDs for each detected object.
@@ -373,17 +381,21 @@ def not_agnostic_nms(
373381 nms_threshold (float): the threshold for match metric.
374382 detected_masks_list_full (torch.Tensor): List of masks.
375383 intelligent_sorter (bool, optional): intelligent sorter
384+
376385 Returns:
377- List[int]: A list of indices representing the detections that are kept after applying NMS for each class separately.
386+ List[int]: A list of indices representing the detections that are kept after applying
387+ NMS for each class separately.
378388
379389 Notes:
380- - This method performs NMS separately for each class, which helps in reducing false positives within each class.
381- - The `nms` function is assumed to be defined elsewhere in the class and is responsible for performing the actual NMS operation.
390+ - This method performs NMS separately for each class, which helps in
391+ reducing false positives within each class.
392+ - If in your scenario, an object of one class can physically be inside
393+ an object of another class, you should definitely use this non-agnostic nms
382394 '''
383- all_keeps = []
384- for cls in torch .unique (detected_cls_id_list_full ):
385- cls_indexes = torch .where (detected_cls_id_list_full == cls )[0 ]
386- keep_indexes = self .nms (
395+ all_keeps = []
396+ for cls in torch .unique (detected_cls_id_list_full ):
397+ cls_indexes = torch .where (detected_cls_id_list_full == cls )[0 ]
398+ keep_indexes = self .nms (
387399 detected_conf_list_full [cls_indexes ],
388400 detected_xyxy_list_full [cls_indexes ],
389401 match_metric ,
@@ -392,9 +404,5 @@ def not_agnostic_nms(
392404 intelligent_sorter ,
393405 cls_indexes
394406 )
395- all_keeps .extend (keep_indexes )
396- return all_keeps
397-
398-
399-
400-
407+ all_keeps .extend (keep_indexes )
408+ return all_keeps
0 commit comments