Skip to content

Commit 153cdee

Browse files
committed
fix mask error
1 parent 9bce31b commit 153cdee

1 file changed

Lines changed: 6 additions & 2 deletions

File tree

patched_yolo_infer/nodes/CombineDetections.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ def nms(
221221
boxes: torch.tensor,
222222
match_metric,
223223
nms_threshold,
224-
masks=None,
224+
masks=[],
225225
intelligent_sorter=False,
226226
cls_indexes=None
227227
):
@@ -395,12 +395,16 @@ def not_agnostic_nms(
395395
all_keeps = []
396396
for cls in torch.unique(detected_cls_id_list_full):
397397
cls_indexes = torch.where(detected_cls_id_list_full==cls)[0]
398+
if len(detected_masks_list_full) > 0:
399+
masks_of_class = [detected_masks_list_full[i] for i in cls_indexes]
400+
else:
401+
masks_of_class = []
398402
keep_indexes = self.nms(
399403
detected_conf_list_full[cls_indexes],
400404
detected_xyxy_list_full[cls_indexes],
401405
match_metric,
402406
nms_threshold,
403-
detected_masks_list_full,
407+
masks_of_class,
404408
intelligent_sorter,
405409
cls_indexes
406410
)

0 commit comments

Comments
 (0)