Skip to content

Commit 074e62c

Browse files
committed
agnostic nms
1 parent 7dbb5e0 commit 074e62c

2 files changed

Lines changed: 26 additions & 22 deletions

File tree

patched_yolo_infer/nodes/CombineDetections.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,23 @@ class CombineDetections:
99
1010
Args:
1111
element_crops (MakeCropsDetectThem): Object containing crop information.
12-
nms_threshold (float): IoU/IoS threshold for non-maximum suppression.
13-
match_metric (str): Matching metric, either 'IOU' or 'IOS'.
12+
nms_threshold (float): IoU/IoS threshold for non-maximum suppression. Default is 0.3.
13+
match_metric (str): Matching metric, either 'IOU' or 'IOS'. Default is 'IOS'.
14+
class_agnostic_nms (bool): Determines the NMS mode in object detection. When set to True, NMS
15+
operates across all classes, ignoring class distinctions and suppressing less confident
16+
bounding boxes globally. Otherwise, NMS is applied separately for each class. Default is True.
1417
intelligent_sorter (bool): Enable sorting by area and rounded confidence parameter.
15-
If False, sorting will be done only by confidence (usual nms). (Dafault True)
18+
If False, sorting will be done only by confidence (usual nms). Default is True.
1619
sorter_bins (int): Number of bins to use for intelligent_sorter. A smaller number of bins makes
1720
the NMS more reliant on object sizes rather than confidence scores. Defaults to 10.
1821
1922
Attributes:
20-
conf_treshold (float): Confidence threshold of yolov8.
21-
class_names (dict): Dictionary containing class names pf yolov8 model.
23+
class_names (dict): Dictionary containing class names of yolo model.
2224
crops (list): List to store the CropElement objects.
2325
image (np.ndarray): Source image in BGR.
2426
nms_threshold (float): IOU/IOS threshold for non-maximum suppression.
2527
match_metric (str): Matching metric (IOU/IOS).
28+
class_agnostic_nms (bool): Determines the NMS mode in object detection.
2629
intelligent_sorter (bool): Flag indicating whether sorting by area and confidence parameter is enabled.
2730
sorter_bins (int): Number of bins to use for intelligent_sorter.
2831
detected_conf_list_full (list): List of detected confidences.
@@ -46,9 +49,9 @@ def __init__(
4649
nms_threshold=0.3,
4750
match_metric='IOS',
4851
intelligent_sorter=True,
49-
sorter_bins=10
52+
sorter_bins=10,
53+
class_agnostic_nms=True
5054
) -> None:
51-
self.conf_treshold = element_crops.conf
5255
self.class_names = element_crops.class_names_dict
5356
self.crops = element_crops.crops # List to store the CropElement objects
5457
if element_crops.resize_initial_size:
@@ -60,6 +63,7 @@ def __init__(
6063
self.match_metric = match_metric
6164
self.intelligent_sorter = intelligent_sorter # enable sorting by area and confidence parameter
6265
self.sorter_bins = sorter_bins
66+
self.class_agnostic_nms = class_agnostic_nms
6367

6468
# Combine detections of all patches
6569
(
@@ -74,27 +78,27 @@ def __init__(
7478
self.class_names[value] for value in self.detected_cls_id_list_full
7579
] # make str list
7680

77-
# Invoke the NMS for segmentation masks method for filtering predictions
78-
if len(self.detected_masks_list_full) > 0:
79-
80-
self.filtered_indices = self.nms(
81+
# Invoke the NMS:
82+
if self.class_agnostic_nms:
83+
self.filtered_indices = self.agnostic_nms(
8184
self.detected_conf_list_full,
8285
self.detected_xyxy_list_full,
8386
self.match_metric,
8487
self.nms_threshold,
8588
self.detected_masks_list_full,
8689
intelligent_sorter=self.intelligent_sorter
87-
) # for instance segmentation
90+
)
8891
else:
89-
# Invoke the NMS method for filtering prediction
9092
self.filtered_indices = self.nms(
9193
self.detected_conf_list_full,
9294
self.detected_xyxy_list_full,
93-
self.match_metric,
95+
self.detected_cls_id_list_full,
96+
self.match_metric,
9497
self.nms_threshold,
98+
self.detected_masks_list_full,
9599
intelligent_sorter=self.intelligent_sorter
96-
) # for detection
97-
100+
)
101+
98102
# Apply filtering (nms output indices) to the prediction lists
99103
self.filtered_confidences = [self.detected_conf_list_full[i] for i in self.filtered_indices]
100104
self.filtered_boxes = [self.detected_xyxy_list_full[i] for i in self.filtered_indices]
@@ -208,7 +212,7 @@ def intersect_over_smaller(mask, masks_list):
208212
ios_scores.append(ios)
209213
return torch.tensor(ios_scores)
210214

211-
def nms(
215+
def agnostic_nms(
212216
self,
213217
confidences: list,
214218
boxes: list,
@@ -218,15 +222,15 @@ def nms(
218222
intelligent_sorter=False,
219223
):
220224
"""
221-
Apply non-maximum suppression to avoid detecting too many
225+
Apply class-agnostic non-maximum suppression to avoid detecting too many
222226
overlapping bounding boxes for a given object.
223227
224228
Args:
225229
confidences (list): List of confidence scores.
226230
boxes (list): List of bounding boxes.
227231
match_metric (str): Matching metric, either 'IOU' or 'IOS'.
228232
nms_threshold (float): The threshold for match metric.
229-
masks (list, optional): List of masks. Defaults to None.
233+
masks (list): List of masks.
230234
intelligent_sorter (bool, optional): intelligent sorter
231235
232236
Returns:
@@ -320,7 +324,7 @@ def nms(
320324

321325
# If masks are provided and IoU based on bounding boxes is greater than 0,
322326
# calculate IoU for masks and keep the ones with IoU < nms_threshold
323-
if masks is not None and torch.any(match_metric_value > 0):
327+
if len(masks) > 0 and torch.any(match_metric_value > 0):
324328

325329
mask_mask = match_metric_value > 0
326330

patched_yolo_infer/nodes/MakeCropsDetectThem.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class MakeCropsDetectThem:
1919
iou (float): IoU threshold for non-maximum suppression YOLOv8 of single crop.
2020
classes_list (List[int] or None): List of classes to filter detections. If None,
2121
all classes are considered. Defaults to None.
22-
segment (bool): Whether to perform segmentation (YOLOv8-seg).
22+
segment (bool): Whether to perform segmentation (YOLO-seg).
2323
shape_x (int): Size of the crop in the x-coordinate.
2424
shape_y (int): Size of the crop in the y-coordinate.
2525
overlap_x (int): Percentage of overlap along the x-axis.
@@ -42,7 +42,7 @@ class MakeCropsDetectThem:
4242
iou (float): IoU threshold for non-maximum suppression.
4343
classes_list (List[int] or None): List of classes to filter detections. If None,
4444
all classes are considered. Defaults to None.
45-
segment (bool): Whether to perform segmentation (YOLOv8-seg).
45+
segment (bool): Whether to perform segmentation (YOLO-seg).
4646
shape_x (int): Size of the crop in the x-coordinate.
4747
shape_y (int): Size of the crop in the y-coordinate.
4848
overlap_x (int): Percentage of overlap along the x-axis.

0 commit comments

Comments
 (0)