Skip to content

Commit 949aa34

Browse files
committed
bins for sortering
1 parent 0602deb commit 949aa34

1 file changed

Lines changed: 32 additions & 2 deletions

File tree

patched_yolo_infer/nodes/CombineDetections.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ class CombineDetections:
1313
match_metric (str): Matching metric, either 'IOU' or 'IOS'.
1414
intelligent_sorter (bool): Enable sorting by area and rounded confidence parameter.
1515
If False, sorting will be done only by confidence (usual nms). (Dafault True)
16+
sorter_bins (int): Number of bins to use for intelligent_sorter. A smaller number of bins makes
17+
the NMS more reliant on object sizes rather than confidence scores. Defaults to 10.
1618
1719
Attributes:
1820
conf_treshold (float): Confidence threshold of yolov8.
@@ -22,6 +24,7 @@ class CombineDetections:
2224
nms_threshold (float): IOU/IOS threshold for non-maximum suppression.
2325
match_metric (str): Matching metric (IOU/IOS).
2426
intelligent_sorter (bool): Flag indicating whether sorting by area and confidence parameter is enabled.
27+
sorter_bins (int): Number of bins to use for intelligent_sorter.
2528
detected_conf_list_full (list): List of detected confidences.
2629
detected_xyxy_list_full (list): List of detected bounding boxes.
2730
detected_masks_list_full (list): List of detected masks.
@@ -42,7 +45,8 @@ def __init__(
4245
element_crops: MakeCropsDetectThem,
4346
nms_threshold=0.3,
4447
match_metric='IOS',
45-
intelligent_sorter=True
48+
intelligent_sorter=True,
49+
sorter_bins=10
4650
) -> None:
4751
self.conf_treshold = element_crops.conf
4852
self.class_names = element_crops.class_names_dict
@@ -55,6 +59,7 @@ def __init__(
5559
self.nms_threshold = nms_threshold # IOU or IOS treshold for NMS
5660
self.match_metric = match_metric
5761
self.intelligent_sorter = intelligent_sorter # enable sorting by area and confidence parameter
62+
self.sorter_bins = sorter_bins
5863

5964
# Combinate detections of all patches
6065
(
@@ -134,6 +139,31 @@ def combinate_detections(self, crops):
134139

135140
return detected_conf, detected_xyxy, detected_masks, detected_cls, detected_polygons
136141

142+
@staticmethod
143+
def average_to_bound(confidences, N=10):
144+
"""
145+
Bins the given confidences into N equal intervals between 0 and 1,
146+
and rounds each confidence to the left boundary of the corresponding bin.
147+
148+
Parameters:
149+
confidences (list or np.array): List of confidence values to be binned.
150+
N (int, optional): Number of bins to use. Defaults to 10.
151+
152+
Returns:
153+
list: List of rounded confidence values, each bound to the left boundary of its bin.
154+
"""
155+
# Create the bounds
156+
step = 1 / N
157+
bounds = np.arange(0, 1 + step, step)
158+
159+
# Use np.digitize to determine the corresponding bin for each value
160+
indices = np.digitize(confidences, bounds, right=True) - 1
161+
162+
# Bind values to the left boundary of the corresponding bin
163+
averaged_confidences = np.round(bounds[indices], 2)
164+
165+
return averaged_confidences.tolist()
166+
137167
@staticmethod
138168
def intersect_over_union(mask, masks_list):
139169
"""
@@ -224,7 +254,7 @@ def nms(
224254
order = torch.tensor(
225255
sorted(
226256
range(len(confidences)),
227-
key=lambda k: (round(confidences[k].item(), 1), areas[k]),
257+
key=lambda k: (self.average_to_bound(confidences[k].item(), self.sorter_bins), areas[k]),
228258
reverse=False,
229259
)
230260
)

0 commit comments

Comments
 (0)