Skip to content

Commit 9bce31b

Browse files
committed
not agnostic nms
1 parent 056822f commit 9bce31b

2 files changed

Lines changed: 44 additions & 35 deletions

File tree

patched_yolo_infer/functions_extra.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,9 @@ def visualize_results_usual_yolo_inference(
5858
axis_off (bool): If True, axis is turned off in the final visualization.
5959
show_classes_list (list): If empty, visualize all classes. Otherwise, visualize only classes in the list.
6060
inference_extra_args (dict/None): Dictionary with extra ultralytics inference parameters.
61-
list_of_class_colors (list/None): A list of tuples representing the colors for each class in BGR format. If provided,
62-
these colors will be used for displaying the classes instead of random colors. The number of tuples
63-
in the list must match the number of possible classes in the network.
61+
list_of_class_colors (list/None): A list of tuples representing the colors for each class in BGR format.
62+
If provided, these colors will be used for displaying the classes instead of random colors.
63+
The number of tuples in the list must match the number of possible classes in the network.
6464
return_image_array (bool): If True, the function returns the image bgr array instead of displaying it.
6565
Default is False.
6666
@@ -314,10 +314,11 @@ def visualize_results(
314314
show_confidences (bool): If true and show_class=True, confidences near class are visualized. Default is False.
315315
axis_off (bool): If true, axis is turned off in the final visualization. Default is True.
316316
show_classes_list (list): If empty, visualize all classes. Otherwise, visualize only classes in the list.
317-
list_of_class_colors (list/None): A list of tuples representing the colors for each class in BGR format. If provided,
318-
these colors will be used for displaying the classes instead of random colors. The number of tuples
319-
in the list must match the number of possible classes in the network.
320-
return_image_array (bool): If True, the function returns the image bgr array instead of displaying it. Default is False.
317+
list_of_class_colors (list/None): A list of tuples representing the colors for each class in BGR format.
318+
If provided, these colors will be used for displaying the classes instead of random colors.
319+
The number of tuples in the list must match the number of possible classes in the network.
320+
return_image_array (bool): If True, the function returns the image bgr array instead of displaying it.
321+
Default is False.
321322
322323
Returns:
323324
None/np.array
@@ -476,8 +477,8 @@ def basic_crop_size_calculation(width, height):
476477
height (int): The height of the image in pixels.
477478
478479
Returns:
479-
tuple: A tuple containing the crop size in the x direction (crop_shape_x), crop size in the y direction (crop_shape_y),
480-
overlap in the x direction (crop_overlap_x), and overlap in the y direction (crop_overlap_y).
480+
tuple: A tuple containing the crop size in the x direction (crop_shape_x), crop size in the y direction
481+
(crop_shape_y), overlap in the x direction (crop_overlap_x), and overlap in the y direction (crop_overlap_y).
481482
"""
482483
total_pixels = width * height
483484

patched_yolo_infer/nodes/CombineDetections.py

Lines changed: 34 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ class CombineDetections:
4040
filtered_classes_id (list): List of class IDs after non-maximum suppression.
4141
filtered_classes_names (list): List of class names after non-maximum suppression.
4242
filtered_masks (list): List of filtered (after nms) masks if segmentation is enabled.
43-
filtered_polygons (list): List of filtered (after nms) polygons if segmentation and memory optimization are enabled.
43+
filtered_polygons (list): List of filtered (after nms) polygons if segmentation and
44+
memory optimization are enabled.
4445
"""
4546

4647
def __init__(
@@ -64,7 +65,7 @@ def __init__(
6465
self.intelligent_sorter = intelligent_sorter # enable sorting by area and confidence parameter
6566
self.sorter_bins = sorter_bins
6667
self.class_agnostic_nms = class_agnostic_nms
67-
68+
6869
# Combinate detections of all patches
6970
(
7071
self.detected_conf_list_full,
@@ -99,7 +100,7 @@ def __init__(
99100
self.detected_masks_list_full,
100101
intelligent_sorter=self.intelligent_sorter
101102
)
102-
103+
103104
# Apply filtering (nms output indeces) to the prediction lists
104105
self.filtered_confidences = [self.detected_conf_list_full[i] for i in self.filtered_indices]
105106
self.filtered_boxes = [self.detected_xyxy_list_full[i] for i in self.filtered_indices]
@@ -111,7 +112,7 @@ def __init__(
111112
self.filtered_masks = [self.detected_masks_list_full[i] for i in self.filtered_indices]
112113
else:
113114
self.filtered_masks = []
114-
115+
115116
# Polygons filtering:
116117
if element_crops.segment and element_crops.memory_optimize:
117118
self.filtered_polygons = [self.detected_polygons_list_full[i] for i in self.filtered_indices]
@@ -160,13 +161,13 @@ def average_to_bound(confidences, N=10):
160161
# Create the bounds
161162
step = 1 / N
162163
bounds = np.arange(0, 1 + step, step)
163-
164+
164165
# Use np.digitize to determine the corresponding bin for each value
165166
indices = np.digitize(confidences, bounds, right=True) - 1
166-
167+
167168
# Bind values to the left boundary of the corresponding bin
168169
averaged_confidences = np.round(bounds[indices], 2)
169-
170+
170171
return averaged_confidences.tolist()
171172

172173
@staticmethod
@@ -201,7 +202,8 @@ def intersect_over_smaller(mask, masks_list):
201202
masks_list (list of np.ndarray): List of binary masks for comparison.
202203
203204
Returns:
204-
torch.Tensor: IoU scores for each mask in masks_list compared to the input mask, calculated over the smaller area.
205+
torch.Tensor: IoU scores for each mask in masks_list compared to the input mask,
206+
calculated over the smaller area.
205207
"""
206208
ios_scores = []
207209
for other_mask in masks_list:
@@ -234,7 +236,9 @@ def nms(
234236
nms_threshold (float): The threshold for match metric.
235237
masks (list): List of masks.
236238
intelligent_sorter (bool, optional): intelligent sorter
237-
cls_indexes (torch.Tensor): indexes from network detections corresponding to the defined class, uses in case of not agnostic nms
239+
cls_indexes (torch.Tensor): indexes from network detections corresponding
240+
to the defined class, uses in case of not agnostic nms
241+
238242
Returns:
239243
list: List of filtered indexes.
240244
"""
@@ -256,7 +260,10 @@ def nms(
256260
order = torch.tensor(
257261
sorted(
258262
range(len(confidences)),
259-
key=lambda k: (self.average_to_bound(confidences[k].item(), self.sorter_bins), areas[k]),
263+
key=lambda k: (
264+
self.average_to_bound(confidences[k].item(), self.sorter_bins),
265+
areas[k],
266+
),
260267
reverse=False,
261268
)
262269
)
@@ -351,7 +358,7 @@ def nms(
351358
if cls_indexes is not None:
352359
keep = [cls_indexes[i] for i in keep]
353360
return keep
354-
361+
355362
def not_agnostic_nms(
356363
self,
357364
detected_cls_id_list_full,
@@ -362,8 +369,9 @@ def not_agnostic_nms(
362369
detected_masks_list_full,
363370
intelligent_sorter
364371
):
365-
'''
366-
Performs Non-Maximum Suppression (NMS) in a non-agnostic manner, where NMS is applied separately to each class.
372+
'''
373+
Performs Non-Maximum Suppression (NMS) in a non-agnostic manner, where NMS
374+
is applied separately to each class.
367375
368376
Args:
369377
detected_cls_id_list_full (torch.Tensor): tensor containing the class IDs for each detected object.
@@ -373,17 +381,21 @@ def not_agnostic_nms(
373381
nms_threshold (float): the threshold for match metric.
374382
detected_masks_list_full (torch.Tensor): List of masks.
375383
intelligent_sorter (bool, optional): intelligent sorter
384+
376385
Returns:
377-
List[int]: A list of indices representing the detections that are kept after applying NMS for each class separately.
386+
List[int]: A list of indices representing the detections that are kept after applying
387+
NMS for each class separately.
378388
379389
Notes:
380-
- This method performs NMS separately for each class, which helps in reducing false positives within each class.
381-
- The `nms` function is assumed to be defined elsewhere in the class and is responsible for performing the actual NMS operation.
390+
- This method performs NMS separately for each class, which helps in
391+
reducing false positives within each class.
392+
- If in your scenario, an object of one class can physically be inside
393+
an object of another class, you should definitely use this non-agnostic nms
382394
'''
383-
all_keeps = []
384-
for cls in torch.unique(detected_cls_id_list_full):
385-
cls_indexes = torch.where(detected_cls_id_list_full==cls)[0]
386-
keep_indexes = self.nms(
395+
all_keeps = []
396+
for cls in torch.unique(detected_cls_id_list_full):
397+
cls_indexes = torch.where(detected_cls_id_list_full==cls)[0]
398+
keep_indexes = self.nms(
387399
detected_conf_list_full[cls_indexes],
388400
detected_xyxy_list_full[cls_indexes],
389401
match_metric,
@@ -392,9 +404,5 @@ def not_agnostic_nms(
392404
intelligent_sorter,
393405
cls_indexes
394406
)
395-
all_keeps.extend(keep_indexes)
396-
return all_keeps
397-
398-
399-
400-
407+
all_keeps.extend(keep_indexes)
408+
return all_keeps

0 commit comments

Comments
 (0)