@@ -64,7 +64,7 @@ def __init__(
6464 self .intelligent_sorter = intelligent_sorter # enable sorting by area and confidence parameter
6565 self .sorter_bins = sorter_bins
6666 self .class_agnostic_nms = class_agnostic_nms
67-
67+
6868 # Combinate detections of all patches
6969 (
7070 self .detected_conf_list_full ,
@@ -81,18 +81,18 @@ def __init__(
8181 # Invoke the NMS:
8282 if self .class_agnostic_nms :
8383 self .filtered_indices = self .agnostic_nms (
84- self .detected_conf_list_full ,
85- self .detected_xyxy_list_full ,
84+ torch .tensor (self .detected_cls_id_list_full ),
85+ torch .tensor (self .detected_conf_list_full ),
86+ torch .tensor (self .detected_xyxy_list_full ),
8687 self .match_metric ,
8788 self .nms_threshold ,
8889 self .detected_masks_list_full ,
8990 intelligent_sorter = self .intelligent_sorter
9091 )
9192 else :
9293 self .filtered_indices = self .nms (
93- self .detected_conf_list_full ,
94- self .detected_xyxy_list_full ,
95- self .detected_cls_id_list_full ,
94+ torch .tensor (self .detected_conf_list_full ),
95+ torch .tensor (self .detected_xyxy_list_full ),
9696 self .match_metric ,
9797 self .nms_threshold ,
9898 self .detected_masks_list_full ,
@@ -212,14 +212,15 @@ def intersect_over_smaller(mask, masks_list):
212212 ios_scores .append (ios )
213213 return torch .tensor (ios_scores )
214214
215- def agnostic_nms (
215+ def nms (
216216 self ,
217- confidences : list ,
218- boxes : list ,
217+ confidences : torch . tensor ,
218+ boxes : torch . tensor ,
219219 match_metric ,
220220 nms_threshold ,
221221 masks = None ,
222- intelligent_sorter = False ,
222+ intelligent_sorter = False ,
223+ cls_indexes = None
223224 ):
224225 """
225226 Apply class-agnostic non-maximum suppression to avoid detecting too many
@@ -239,10 +240,6 @@ def agnostic_nms(
239240 if len (boxes ) == 0 :
240241 return []
241242
242- # Convert lists to tensors
243- boxes = torch .tensor (boxes )
244- confidences = torch .tensor (confidences )
245-
246243 # Extract coordinates for every prediction box present
247244 x1 = boxes [:, 0 ]
248245 y1 = boxes [:, 1 ]
@@ -350,5 +347,35 @@ def agnostic_nms(
350347 mask = match_metric_value < nms_threshold
351348
352349 order = order [mask ]
353-
350+ if cls_indexes is not None :
351+ keep = [cls_indexes [i ] for i in keep ]
354352 return keep
353+
354+ def agnostic_nms (
355+ self ,
356+ detected_cls_id_list_full ,
357+ detected_conf_list_full ,
358+ detected_xyxy_list_full ,
359+ match_metric ,
360+ nms_threshold ,
361+ detected_masks_list_full ,
362+ intelligent_sorter
363+ ):
364+ all_keeps = []
365+ for cls in torch .unique (detected_cls_id_list_full ):
366+ cls_indexes = torch .where (detected_cls_id_list_full == cls )[0 ]
367+ keep_indexes = self .nms (
368+ detected_conf_list_full [cls_indexes ],
369+ detected_xyxy_list_full [cls_indexes ],
370+ match_metric ,
371+ nms_threshold ,
372+ detected_masks_list_full ,
373+ intelligent_sorter ,
374+ cls_indexes
375+ )
376+ all_keeps .extend (keep_indexes )
377+ return all_keeps
378+
379+
380+
381+
0 commit comments