Skip to content

Commit 9e0c79e

Browse files
committed
intelegence_sorter added
1 parent 96ecd1f commit 9e0c79e

1 file changed

Lines changed: 18 additions & 11 deletions

File tree

patched_yolo_infer/nodes/CombineDetections.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@ class CombineDetections:
1717
class_names (dict): Dictionary containing class names pf yolov8 model.
1818
crops (list): List to store the CropElement objects.
1919
image (np.ndarray): Source image in BGR.
20-
nms_threshold (float): IoU/IoS threshold for non-maximum suppression.
20+
nms_threshold (float): IOU/IOS threshold for non-maximum suppression.
2121
match_metric (str): Matching metric (IOU/IOS).
22+
intelegence_sorter (bool): Corting by area and confederation parameter
2223
detected_conf_list_full (list): List of detected confidences.
2324
detected_xyxy_list_full (list): List of detected bounding boxes.
2425
detected_masks_list_full (list): List of detected masks.
@@ -36,7 +37,8 @@ def __init__(
3637
self,
3738
element_crops: MakeCropsDetectThem,
3839
nms_threshold=0.3,
39-
match_metric='IOS'
40+
match_metric='IOS',
41+
intelegence_sorter=False
4042
) -> None:
4143
self.conf_treshold = element_crops.conf
4244
self.class_names = element_crops.class_names_dict
@@ -46,8 +48,10 @@ def __init__(
4648
else:
4749
self.image = element_crops.crops[0].source_image_resized
4850

49-
self.nms_threshold = nms_threshold # IoU treshold for NMS
50-
self.match_metric = match_metric
51+
self.nms_threshold = nms_threshold # IOU or IOS treshold for NMS
52+
self.match_metric = match_metric
53+
self.intelegence_sorter = intelegence_sorter # enable sorting by area and confederation parameter
54+
5155
# seg mode
5256
(
5357
self.detected_conf_list_full,
@@ -69,7 +73,8 @@ def __init__(
6973
self.detected_xyxy_list_full,
7074
self.match_metric,
7175
self.nms_threshold,
72-
self.detected_masks_list_full
76+
self.detected_masks_list_full,
77+
intelegence_sorter=self.intelegence_sorter
7378

7479
)
7580
else:
@@ -79,7 +84,8 @@ def __init__(
7984
self.detected_conf_list_full,
8085
self.detected_xyxy_list_full,
8186
self.match_metric,
82-
self.nms_threshold
87+
self.nms_threshold,
88+
intelegence_sorter=self.intelegence_sorter
8389
)
8490

8591
# Apply filtering to the prediction lists
@@ -161,7 +167,7 @@ def intersect_over_smaller(mask, masks_list):
161167
iou_scores.append(iou)
162168
return torch.tensor(iou_scores)
163169

164-
def nms(self, confidences: list, boxes: list, match_metric, nms_threshold, masks=None):
170+
def nms(self, confidences: list,boxes: list, match_metric, nms_threshold, masks=None, intelegence_sorter=False):
165171
"""
166172
Apply non-maximum suppression to avoid detecting too many
167173
overlapping bounding boxes for a given object.
@@ -192,9 +198,10 @@ def nms(self, confidences: list, boxes: list, match_metric, nms_threshold, masks
192198
# Calculate area of every box
193199
areas = (x2 - x1) * (y2 - y1)
194200

195-
# Sort the prediction boxes according to their confidence scores
196-
order = confidences.argsort()
197-
201+
# Sort the prediction boxes according to their confidence scores and intelegence_sorter mode
202+
if intelegence_sorter: order = torch.tensor(sorted(range(len(confidences)),
203+
key=lambda k: (round(confidences[k].item(), 1), areas[k]), reverse=False))
204+
else: order = confidences.argsort()
198205
# Initialise an empty list for filtered prediction boxes
199206
keep = []
200207

@@ -285,7 +292,7 @@ def nms(self, confidences: list, boxes: list, match_metric, nms_threshold, masks
285292
elif match_metric == "IOS":
286293
mask_iou = self.intersect_over_smaller(masks[idx], filtered_masks)
287294
mask_mask = mask_iou > nms_threshold
288-
295+
#create a tensor of indences to delete in tensor order
289296
order_2 = order_2[mask_mask]
290297
inverse_mask = ~torch.isin(order, order_2)
291298

0 commit comments

Comments
 (0)