11import torch
2-
2+ import numpy as np
33from .MakeCropsDetectThem import MakeCropsDetectThem
44
55
@@ -48,6 +48,7 @@ def __init__(
4848
4949 self .nms_threshold = nms_threshold # IoU treshold for NMS
5050 self .match_metric = match_metric
51+ # seg mode
5152 (
5253 self .detected_conf_list_full ,
5354 self .detected_xyxy_list_full ,
@@ -58,14 +59,28 @@ def __init__(
5859 self .detected_cls_names_list_full = [
5960 self .class_names [value ] for value in self .detected_cls_id_list_full
6061 ] # make str list
61-
62- # Invoke the NMS method for filtering predictions
63- self .filtered_indices = self .nms (
64- self .detected_conf_list_full ,
65- self .detected_xyxy_list_full ,
66- self .match_metric ,
67- self .nms_threshold
68- )
62+
63+ # Invoke the NMS for segmentation masks method for filtering predictions
64+ if len (self .detected_masks_list_full )> 0 :
65+
66+ self .filtered_indices = self .nms (
67+
68+ self .detected_conf_list_full ,
69+ self .detected_xyxy_list_full ,
70+ self .match_metric ,
71+ self .nms_threshold ,
72+ self .detected_masks_list_full
73+
74+ )
75+ else :
76+ # Invoke the NMS method for filtering prediction
77+ self .filtered_indices = self .nms (
78+
79+ self .detected_conf_list_full ,
80+ self .detected_xyxy_list_full ,
81+ self .match_metric ,
82+ self .nms_threshold
83+ )
6984
7085 # Apply filtering to the prediction lists
7186 self .filtered_confidences = [self .detected_conf_list_full [i ] for i in self .filtered_indices ]
@@ -102,12 +117,53 @@ def combinate_detections(self, crops):
102117
103118 return detected_conf , detected_xyxy , detected_masks , detected_cls
104119
105- def nms (self ,
106- confidences : list ,
107- boxes : list ,
108- match_metric ,
109- nms_threshold ,
110- ):
120+
121+
122+ @staticmethod
123+ def intersect_over_union (mask , masks_list ):
124+ """
125+ Compute Intersection over Union (IoU) scores for a given mask against a list of masks.
126+
127+ Args:
128+ mask (np.ndarray): Binary mask to compare.
129+ masks_list (list of np.ndarray): List of binary masks for comparison.
130+
131+ Returns:
132+ torch.Tensor: IoU scores for each mask in masks_list compared to the input mask.
133+ """
134+ iou_scores = []
135+ for other_mask in masks_list :
136+ # Compute intersection and union
137+ intersection = np .logical_and (mask , other_mask ).sum ()
138+ union = np .logical_or (mask , other_mask ).sum ()
139+ # Compute IoU score, avoiding division by zero
140+ iou = intersection / union if union != 0 else 0
141+ iou_scores .append (iou )
142+ return torch .tensor (iou_scores )
143+
144+ @staticmethod
145+ def intersect_over_smaller (mask , masks_list ):
146+ """
147+ Compute Intersection over Smaller area scores for a given mask against a list of masks.
148+
149+ Args:
150+ mask (np.ndarray): Binary mask to compare.
151+ masks_list (list of np.ndarray): List of binary masks for comparison.
152+
153+ Returns:
154+ torch.Tensor: IoU scores for each mask in masks_list compared to the input mask, calculated over the smaller area.
155+ """
156+ iou_scores = []
157+ for other_mask in masks_list :
158+ # Compute intersection and area of smaller mask
159+ intersection = np .logical_and (mask , other_mask ).sum ()
160+ smaller_area = min (mask .sum (), other_mask .sum ())
161+ # Compute IoU score over smaller area, avoiding division by zero
162+ iou = intersection / smaller_area if smaller_area != 0 else 0
163+ iou_scores .append (iou )
164+ return torch .tensor (iou_scores )
165+
166+ def nms (self , confidences : list , boxes : list , match_metric , nms_threshold , masks = None ):
111167 """
112168 Apply non-maximum suppression to avoid detecting too many
113169 overlapping bounding boxes for a given object.
@@ -117,6 +173,7 @@ def nms(self,
117173 boxes (list): List of bounding boxes.
118174 match_metric (str): Matching metric, either 'IOU' or 'IOS'.
119175 nms_threshold (float): The threshold for match metric.
176+ masks (list, optional): List of masks. Defaults to None.
120177
121178 Returns:
122179 list: List of filtered indexes.
@@ -139,7 +196,7 @@ def nms(self,
139196
140197 # Sort the prediction boxes according to their confidence scores
141198 order = confidences .argsort ()
142-
199+
143200 # Initialise an empty list for filtered prediction boxes
144201 keep = []
145202
@@ -182,7 +239,7 @@ def nms(self,
182239
183240 # Find the areas of BBoxes
184241 rem_areas = torch .index_select (areas , dim = 0 , index = order )
185-
242+
186243 # Calculate the distance between centers of the boxes
187244 cx = (x1 [idx ] + x2 [idx ]) / 2
188245 cy = (y1 [idx ] + y2 [idx ]) / 2
@@ -195,7 +252,7 @@ def nms(self,
195252 union = (rem_areas - inter ) + areas [idx ]
196253 # Find the IoU of every prediction
197254 match_metric_value = inter / union
198-
255+
199256 elif match_metric == "IOS" :
200257 # Find the smaller area of every prediction with the prediction
201258 smaller = torch .min (rem_areas , areas [idx ])
@@ -214,8 +271,35 @@ def nms(self,
214271 else :
215272 raise ValueError ("Unknown matching metric" )
216273
217- # Keep the boxes with IoU/IoS less than threshold
218- mask = match_metric_value < nms_threshold
219- order = order [mask ]
274+ # If masks are provided and IoU based on bounding boxes is greater than 0,
275+ # calculate IoU for masks and keep the ones with IoU < nms_threshold
276+ if masks is not None and torch .any (match_metric_value > 0 ):
277+
278+ mask_mask = match_metric_value > 0
279+
280+ order_2 = order [mask_mask ]
281+ filtered_masks = [masks [i ] for i in order_2 ]
282+
283+ if match_metric == "IOU" :
284+ mask_iou = self .intersect_over_union (masks [idx ], filtered_masks )
285+ mask_mask = mask_iou > nms_threshold
286+
287+ elif match_metric == "IOS" :
288+ mask_iou = self .intersect_over_smaller (masks [idx ], filtered_masks )
289+ mask_mask = mask_iou > nms_threshold
290+
291+ order_2 = order_2 [mask_mask ]
292+ inverse_mask = ~ torch .isin (order , order_2 )
293+
294+ # Оставить только те значения order, которые не содержатся в order_2
295+ order = order [inverse_mask ]
296+
297+ else :
298+ # Keep the boxes with IoU/IoS less than threshold
299+ mask = match_metric_value < nms_threshold
300+
301+ order = order [mask ]
220302
221303 return keep
304+
305+
0 commit comments