@@ -17,8 +17,9 @@ class CombineDetections:
1717 class_names (dict): Dictionary containing class names pf yolov8 model.
1818 crops (list): List to store the CropElement objects.
1919 image (np.ndarray): Source image in BGR.
20- nms_threshold (float): IoU/IoS threshold for non-maximum suppression.
20+ nms_threshold (float): IOU/IOS threshold for non-maximum suppression.
2121 match_metric (str): Matching metric (IOU/IOS).
22+ intelegence_sorter (bool): Corting by area and confederation parameter
2223 detected_conf_list_full (list): List of detected confidences.
2324 detected_xyxy_list_full (list): List of detected bounding boxes.
2425 detected_masks_list_full (list): List of detected masks.
@@ -36,7 +37,8 @@ def __init__(
3637 self ,
3738 element_crops : MakeCropsDetectThem ,
3839 nms_threshold = 0.3 ,
39- match_metric = 'IOS'
40+ match_metric = 'IOS' ,
41+ intelegence_sorter = False
4042 ) -> None :
4143 self .conf_treshold = element_crops .conf
4244 self .class_names = element_crops .class_names_dict
@@ -46,8 +48,10 @@ def __init__(
4648 else :
4749 self .image = element_crops .crops [0 ].source_image_resized
4850
49- self .nms_threshold = nms_threshold # IoU treshold for NMS
50- self .match_metric = match_metric
51+ self .nms_threshold = nms_threshold # IOU or IOS treshold for NMS
52+ self .match_metric = match_metric
53+ self .intelegence_sorter = intelegence_sorter # enable sorting by area and confederation parameter
54+
5155 # seg mode
5256 (
5357 self .detected_conf_list_full ,
@@ -69,7 +73,8 @@ def __init__(
6973 self .detected_xyxy_list_full ,
7074 self .match_metric ,
7175 self .nms_threshold ,
72- self .detected_masks_list_full
76+ self .detected_masks_list_full ,
77+ intelegence_sorter = self .intelegence_sorter
7378
7479 )
7580 else :
@@ -79,7 +84,8 @@ def __init__(
7984 self .detected_conf_list_full ,
8085 self .detected_xyxy_list_full ,
8186 self .match_metric ,
82- self .nms_threshold
87+ self .nms_threshold ,
88+ intelegence_sorter = self .intelegence_sorter
8389 )
8490
8591 # Apply filtering to the prediction lists
@@ -161,7 +167,7 @@ def intersect_over_smaller(mask, masks_list):
161167 iou_scores .append (iou )
162168 return torch .tensor (iou_scores )
163169
164- def nms (self , confidences : list , boxes : list , match_metric , nms_threshold , masks = None ):
170+ def nms (self , confidences : list ,boxes : list , match_metric , nms_threshold , masks = None , intelegence_sorter = False ):
165171 """
166172 Apply non-maximum suppression to avoid detecting too many
167173 overlapping bounding boxes for a given object.
@@ -192,9 +198,10 @@ def nms(self, confidences: list, boxes: list, match_metric, nms_threshold, masks
192198 # Calculate area of every box
193199 areas = (x2 - x1 ) * (y2 - y1 )
194200
195- # Sort the prediction boxes according to their confidence scores
196- order = confidences .argsort ()
197-
201+ # Sort the prediction boxes according to their confidence scores and intelegence_sorter mode
202+ if intelegence_sorter : order = torch .tensor (sorted (range (len (confidences )),
203+ key = lambda k : (round (confidences [k ].item (), 1 ), areas [k ]), reverse = False ))
204+ else : order = confidences .argsort ()
198205 # Initialise an empty list for filtered prediction boxes
199206 keep = []
200207
@@ -285,7 +292,7 @@ def nms(self, confidences: list, boxes: list, match_metric, nms_threshold, masks
285292 elif match_metric == "IOS" :
286293 mask_iou = self .intersect_over_smaller (masks [idx ], filtered_masks )
287294 mask_mask = mask_iou > nms_threshold
288-
295+ #create a tensor of indences to delete in tensor order
289296 order_2 = order_2 [mask_mask ]
290297 inverse_mask = ~ torch .isin (order , order_2 )
291298
0 commit comments