@@ -79,25 +79,26 @@ def __init__(
7979 ] # make str list
8080
8181 # Invoke the NMS:
82- if not self .class_agnostic_nms :
83- self .filtered_indices = self .agnostic_nms (
84- torch .tensor (self .detected_cls_id_list_full ),
82+ if self .class_agnostic_nms :
83+ self .filtered_indices = self .nms (
8584 torch .tensor (self .detected_conf_list_full ),
8685 torch .tensor (self .detected_xyxy_list_full ),
8786 self .match_metric ,
8887 self .nms_threshold ,
8988 self .detected_masks_list_full ,
9089 intelligent_sorter = self .intelligent_sorter
91- )
90+ )
91+
9292 else :
93- self .filtered_indices = self .nms (
93+ self .filtered_indices = self .not_agnostic_nms (
94+ torch .tensor (self .detected_cls_id_list_full ),
9495 torch .tensor (self .detected_conf_list_full ),
9596 torch .tensor (self .detected_xyxy_list_full ),
9697 self .match_metric ,
9798 self .nms_threshold ,
9899 self .detected_masks_list_full ,
99100 intelligent_sorter = self .intelligent_sorter
100- )
101+ )
101102
102103 # Apply filtering (nms output indeces) to the prediction lists
103104 self .filtered_confidences = [self .detected_conf_list_full [i ] for i in self .filtered_indices ]
@@ -227,13 +228,13 @@ def nms(
227228 overlapping bounding boxes for a given object.
228229
229230 Args:
230- confidences (list ): List of confidence scores.
231- boxes (list ): List of bounding boxes.
231+ confidences (torch.Tensor ): List of confidence scores.
232+ boxes (torch.Tensor ): List of bounding boxes.
232233 match_metric (str): Matching metric, either 'IOU' or 'IOS'.
233234 nms_threshold (float): The threshold for match metric.
234235 masks (list): List of masks.
235236 intelligent_sorter (bool, optional): intelligent sorter
236-
237+ cls_indexes (torch.Tensor): indexes from network detections corresponding to the defined class, uses in case of not agnostic nms
237238 Returns:
238239 list: List of filtered indexes.
239240 """
@@ -351,7 +352,7 @@ def nms(
351352 keep = [cls_indexes [i ] for i in keep ]
352353 return keep
353354
354- def agnostic_nms (
355+ def not_agnostic_nms (
355356 self ,
356357 detected_cls_id_list_full ,
357358 detected_conf_list_full ,
@@ -361,6 +362,24 @@ def agnostic_nms(
361362 detected_masks_list_full ,
362363 intelligent_sorter
363364 ):
365+ '''
366+ Performs Non-Maximum Suppression (NMS) in a non-agnostic manner, where NMS is applied separately to each class.
367+
368+ Args:
369+ detected_cls_id_list_full (torch.Tensor): tensor containing the class IDs for each detected object.
370+ detected_conf_list_full (torch.Tensor): tensor of confidence scores.
371+ detected_xyxy_list_full (torch.Tensor): tensor of bounding boxes.
372+ match_metric (str): Matching metric, either 'IOU' or 'IOS'.
373+ nms_threshold (float): the threshold for match metric.
374+ detected_masks_list_full (torch.Tensor): List of masks.
375+ intelligent_sorter (bool, optional): intelligent sorter
376+ Returns:
377+ List[int]: A list of indices representing the detections that are kept after applying NMS for each class separately.
378+
379+ Notes:
380+ - This method performs NMS separately for each class, which helps in reducing false positives within each class.
381+ - The `nms` function is assumed to be defined elsewhere in the class and is responsible for performing the actual NMS operation.
382+ '''
364383 all_keeps = []
365384 for cls in torch .unique (detected_cls_id_list_full ):
366385 cls_indexes = torch .where (detected_cls_id_list_full == cls )[0 ]
0 commit comments