@@ -59,10 +59,10 @@ def __init__(
5959 self .detected_cls_names_list_full = [
6060 self .class_names [value ] for value in self .detected_cls_id_list_full
6161 ] # make str list
62-
62+
6363 # Invoke the NMS for segmentation masks method for filtering predictions
6464 if len (self .detected_masks_list_full )> 0 :
65-
65+
6666 self .filtered_indices = self .nms (
6767
6868 self .detected_conf_list_full ,
@@ -117,8 +117,6 @@ def combinate_detections(self, crops):
117117
118118 return detected_conf , detected_xyxy , detected_masks , detected_cls
119119
120-
121-
122120 @staticmethod
123121 def intersect_over_union (mask , masks_list ):
124122 """
@@ -162,7 +160,7 @@ def intersect_over_smaller(mask, masks_list):
162160 iou = intersection / smaller_area if smaller_area != 0 else 0
163161 iou_scores .append (iou )
164162 return torch .tensor (iou_scores )
165-
163+
166164 def nms (self , confidences : list , boxes : list , match_metric , nms_threshold , masks = None ):
167165 """
168166 Apply non-maximum suppression to avoid detecting too many
@@ -196,7 +194,7 @@ def nms(self, confidences: list, boxes: list, match_metric, nms_threshold, masks
196194
197195 # Sort the prediction boxes according to their confidence scores
198196 order = confidences .argsort ()
199-
197+
200198 # Initialise an empty list for filtered prediction boxes
201199 keep = []
202200
@@ -239,7 +237,7 @@ def nms(self, confidences: list, boxes: list, match_metric, nms_threshold, masks
239237
240238 # Find the areas of BBoxes
241239 rem_areas = torch .index_select (areas , dim = 0 , index = order )
242-
240+
243241 # Calculate the distance between centers of the boxes
244242 cx = (x1 [idx ] + x2 [idx ]) / 2
245243 cy = (y1 [idx ] + y2 [idx ]) / 2
@@ -252,7 +250,7 @@ def nms(self, confidences: list, boxes: list, match_metric, nms_threshold, masks
252250 union = (rem_areas - inter ) + areas [idx ]
253251 # Find the IoU of every prediction
254252 match_metric_value = inter / union
255-
253+
256254 elif match_metric == "IOS" :
257255 # Find the smaller area of every prediction with the prediction
258256 smaller = torch .min (rem_areas , areas [idx ])
@@ -276,30 +274,28 @@ def nms(self, confidences: list, boxes: list, match_metric, nms_threshold, masks
276274 if masks is not None and torch .any (match_metric_value > 0 ):
277275
278276 mask_mask = match_metric_value > 0
279-
277+
280278 order_2 = order [mask_mask ]
281279 filtered_masks = [masks [i ] for i in order_2 ]
282-
280+
283281 if match_metric == "IOU" :
284282 mask_iou = self .intersect_over_union (masks [idx ], filtered_masks )
285283 mask_mask = mask_iou > nms_threshold
286284
287285 elif match_metric == "IOS" :
288286 mask_iou = self .intersect_over_smaller (masks [idx ], filtered_masks )
289287 mask_mask = mask_iou > nms_threshold
290-
288+
291289 order_2 = order_2 [mask_mask ]
292290 inverse_mask = ~ torch .isin (order , order_2 )
293291
294- # Оставить только те значения order, которые не содержатся в order_2
292+ # Keep only those order values that are not contained in order_2
295293 order = order [inverse_mask ]
296294
297295 else :
298296 # Keep the boxes with IoU/IoS less than threshold
299297 mask = match_metric_value < nms_threshold
300-
298+
301299 order = order [mask ]
302300
303301 return keep
304-
305-
0 commit comments