Skip to content

Commit 4d2e809

Browse files
committed
class agnostic
1 parent c204fc9 commit 4d2e809

1 file changed

Lines changed: 42 additions & 15 deletions

File tree

patched_yolo_infer/nodes/CombineDetections.py

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def __init__(
6464
self.intelligent_sorter = intelligent_sorter # enable sorting by area and confidence parameter
6565
self.sorter_bins = sorter_bins
6666
self.class_agnostic_nms = class_agnostic_nms
67-
67+
6868
# Combinate detections of all patches
6969
(
7070
self.detected_conf_list_full,
@@ -81,18 +81,18 @@ def __init__(
8181
# Invoke the NMS:
8282
if self.class_agnostic_nms:
8383
self.filtered_indices = self.agnostic_nms(
84-
self.detected_conf_list_full,
85-
self.detected_xyxy_list_full,
84+
torch.tensor(self.detected_cls_id_list_full),
85+
torch.tensor(self.detected_conf_list_full),
86+
torch.tensor(self.detected_xyxy_list_full),
8687
self.match_metric,
8788
self.nms_threshold,
8889
self.detected_masks_list_full,
8990
intelligent_sorter=self.intelligent_sorter
9091
)
9192
else:
9293
self.filtered_indices = self.nms(
93-
self.detected_conf_list_full,
94-
self.detected_xyxy_list_full,
95-
self.detected_cls_id_list_full,
94+
torch.tensor(self.detected_conf_list_full),
95+
torch.tensor(self.detected_xyxy_list_full),
9696
self.match_metric,
9797
self.nms_threshold,
9898
self.detected_masks_list_full,
@@ -212,14 +212,15 @@ def intersect_over_smaller(mask, masks_list):
212212
ios_scores.append(ios)
213213
return torch.tensor(ios_scores)
214214

215-
def agnostic_nms(
215+
def nms(
216216
self,
217-
confidences: list,
218-
boxes: list,
217+
confidences: torch.tensor,
218+
boxes: torch.tensor,
219219
match_metric,
220220
nms_threshold,
221221
masks=None,
222-
intelligent_sorter=False,
222+
intelligent_sorter=False,
223+
cls_indexes=None
223224
):
224225
"""
225226
Apply class-agnostic non-maximum suppression to avoid detecting too many
@@ -239,10 +240,6 @@ def agnostic_nms(
239240
if len(boxes) == 0:
240241
return []
241242

242-
# Convert lists to tensors
243-
boxes = torch.tensor(boxes)
244-
confidences = torch.tensor(confidences)
245-
246243
# Extract coordinates for every prediction box present
247244
x1 = boxes[:, 0]
248245
y1 = boxes[:, 1]
@@ -350,5 +347,35 @@ def agnostic_nms(
350347
mask = match_metric_value < nms_threshold
351348

352349
order = order[mask]
353-
350+
if cls_indexes is not None:
351+
keep = [cls_indexes[i] for i in keep]
354352
return keep
353+
354+
def agnostic_nms(
355+
self,
356+
detected_cls_id_list_full,
357+
detected_conf_list_full,
358+
detected_xyxy_list_full,
359+
match_metric,
360+
nms_threshold,
361+
detected_masks_list_full,
362+
intelligent_sorter
363+
):
364+
all_keeps = []
365+
for cls in torch.unique(detected_cls_id_list_full):
366+
cls_indexes = torch.where(detected_cls_id_list_full==cls)[0]
367+
keep_indexes = self.nms(
368+
detected_conf_list_full[cls_indexes],
369+
detected_xyxy_list_full[cls_indexes],
370+
match_metric,
371+
nms_threshold,
372+
detected_masks_list_full,
373+
intelligent_sorter,
374+
cls_indexes
375+
)
376+
all_keeps.extend(keep_indexes)
377+
return all_keeps
378+
379+
380+
381+

0 commit comments

Comments
 (0)