Skip to content

Commit 056822f

Browse files
committed
dockstrings and function naming
1 parent c048160 commit 056822f

1 file changed

Lines changed: 29 additions & 10 deletions

File tree

patched_yolo_infer/nodes/CombineDetections.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -79,25 +79,26 @@ def __init__(
7979
] # make str list
8080

8181
# Invoke the NMS:
82-
if not self.class_agnostic_nms:
83-
self.filtered_indices = self.agnostic_nms(
84-
torch.tensor(self.detected_cls_id_list_full),
82+
if self.class_agnostic_nms:
83+
self.filtered_indices = self.nms(
8584
torch.tensor(self.detected_conf_list_full),
8685
torch.tensor(self.detected_xyxy_list_full),
8786
self.match_metric,
8887
self.nms_threshold,
8988
self.detected_masks_list_full,
9089
intelligent_sorter=self.intelligent_sorter
91-
)
90+
)
91+
9292
else:
93-
self.filtered_indices = self.nms(
93+
self.filtered_indices = self.not_agnostic_nms(
94+
torch.tensor(self.detected_cls_id_list_full),
9495
torch.tensor(self.detected_conf_list_full),
9596
torch.tensor(self.detected_xyxy_list_full),
9697
self.match_metric,
9798
self.nms_threshold,
9899
self.detected_masks_list_full,
99100
intelligent_sorter=self.intelligent_sorter
100-
)
101+
)
101102

102103
# Apply filtering (nms output indeces) to the prediction lists
103104
self.filtered_confidences = [self.detected_conf_list_full[i] for i in self.filtered_indices]
@@ -227,13 +228,13 @@ def nms(
227228
overlapping bounding boxes for a given object.
228229
229230
Args:
230-
confidences (list): List of confidence scores.
231-
boxes (list): List of bounding boxes.
231+
confidences (torch.Tensor): List of confidence scores.
232+
boxes (torch.Tensor): List of bounding boxes.
232233
match_metric (str): Matching metric, either 'IOU' or 'IOS'.
233234
nms_threshold (float): The threshold for match metric.
234235
masks (list): List of masks.
235236
intelligent_sorter (bool, optional): intelligent sorter
236-
237+
cls_indexes (torch.Tensor): indexes from network detections corresponding to the defined class, uses in case of not agnostic nms
237238
Returns:
238239
list: List of filtered indexes.
239240
"""
@@ -351,7 +352,7 @@ def nms(
351352
keep = [cls_indexes[i] for i in keep]
352353
return keep
353354

354-
def agnostic_nms(
355+
def not_agnostic_nms(
355356
self,
356357
detected_cls_id_list_full,
357358
detected_conf_list_full,
@@ -361,6 +362,24 @@ def agnostic_nms(
361362
detected_masks_list_full,
362363
intelligent_sorter
363364
):
365+
'''
366+
Performs Non-Maximum Suppression (NMS) in a non-agnostic manner, where NMS is applied separately to each class.
367+
368+
Args:
369+
detected_cls_id_list_full (torch.Tensor): tensor containing the class IDs for each detected object.
370+
detected_conf_list_full (torch.Tensor): tensor of confidence scores.
371+
detected_xyxy_list_full (torch.Tensor): tensor of bounding boxes.
372+
match_metric (str): Matching metric, either 'IOU' or 'IOS'.
373+
nms_threshold (float): the threshold for match metric.
374+
detected_masks_list_full (torch.Tensor): List of masks.
375+
intelligent_sorter (bool, optional): intelligent sorter
376+
Returns:
377+
List[int]: A list of indices representing the detections that are kept after applying NMS for each class separately.
378+
379+
Notes:
380+
- This method performs NMS separately for each class, which helps in reducing false positives within each class.
381+
- The `nms` function is assumed to be defined elsewhere in the class and is responsible for performing the actual NMS operation.
382+
'''
364383
all_keeps = []
365384
for cls in torch.unique(detected_cls_id_list_full):
366385
cls_indexes = torch.where(detected_cls_id_list_full==cls)[0]

0 commit comments

Comments
 (0)