@@ -13,6 +13,8 @@ class CombineDetections:
1313 match_metric (str): Matching metric, either 'IOU' or 'IOS'.
1414 intelligent_sorter (bool): Enable sorting by area and rounded confidence parameter.
1515 If False, sorting will be done only by confidence (usual nms). (Dafault True)
16+ sorter_bins (int): Number of bins to use for intelligent_sorter. A smaller number of bins makes
17+ the NMS more reliant on object sizes rather than confidence scores. Defaults to 10.
1618
1719 Attributes:
1820 conf_treshold (float): Confidence threshold of yolov8.
@@ -22,6 +24,7 @@ class CombineDetections:
2224 nms_threshold (float): IOU/IOS threshold for non-maximum suppression.
2325 match_metric (str): Matching metric (IOU/IOS).
2426 intelligent_sorter (bool): Flag indicating whether sorting by area and confidence parameter is enabled.
27+ sorter_bins (int): Number of bins to use for intelligent_sorter.
2528 detected_conf_list_full (list): List of detected confidences.
2629 detected_xyxy_list_full (list): List of detected bounding boxes.
2730 detected_masks_list_full (list): List of detected masks.
@@ -42,7 +45,8 @@ def __init__(
4245 element_crops : MakeCropsDetectThem ,
4346 nms_threshold = 0.3 ,
4447 match_metric = 'IOS' ,
45- intelligent_sorter = True
48+ intelligent_sorter = True ,
49+ sorter_bins = 10
4650 ) -> None :
4751 self .conf_treshold = element_crops .conf
4852 self .class_names = element_crops .class_names_dict
@@ -55,6 +59,7 @@ def __init__(
5559 self .nms_threshold = nms_threshold # IOU or IOS treshold for NMS
5660 self .match_metric = match_metric
5761 self .intelligent_sorter = intelligent_sorter # enable sorting by area and confidence parameter
62+ self .sorter_bins = sorter_bins
5863
5964 # Combinate detections of all patches
6065 (
@@ -134,6 +139,31 @@ def combinate_detections(self, crops):
134139
135140 return detected_conf , detected_xyxy , detected_masks , detected_cls , detected_polygons
136141
142+ @staticmethod
143+ def average_to_bound (confidences , N = 10 ):
144+ """
145+ Bins the given confidences into N equal intervals between 0 and 1,
146+ and rounds each confidence to the left boundary of the corresponding bin.
147+
148+ Parameters:
149+ confidences (list or np.array): List of confidence values to be binned.
150+ N (int, optional): Number of bins to use. Defaults to 10.
151+
152+ Returns:
153+ list: List of rounded confidence values, each bound to the left boundary of its bin.
154+ """
155+ # Create the bounds
156+ step = 1 / N
157+ bounds = np .arange (0 , 1 + step , step )
158+
159+ # Use np.digitize to determine the corresponding bin for each value
160+ indices = np .digitize (confidences , bounds , right = True ) - 1
161+
162+ # Bind values to the left boundary of the corresponding bin
163+ averaged_confidences = np .round (bounds [indices ], 2 )
164+
165+ return averaged_confidences .tolist ()
166+
137167 @staticmethod
138168 def intersect_over_union (mask , masks_list ):
139169 """
@@ -224,7 +254,7 @@ def nms(
224254 order = torch .tensor (
225255 sorted (
226256 range (len (confidences )),
227- key = lambda k : (round (confidences [k ].item (), 1 ), areas [k ]),
257+ key = lambda k : (self . average_to_bound (confidences [k ].item (), self . sorter_bins ), areas [k ]),
228258 reverse = False ,
229259 )
230260 )
0 commit comments