Skip to content

Commit 6bf90c1

Browse files
committed
added nms using masks
1 parent 74fdc54 commit 6bf90c1

9 files changed

Lines changed: 200 additions & 318 deletions

GettyImages-1323764948.jpg

715 KB
Loading
948 KB
Loading

examples/example_extra_functions.ipynb

Lines changed: 30 additions & 16 deletions
Large diffs are not rendered by default.

examples/example_patch_based_inference.ipynb

Lines changed: 65 additions & 281 deletions
Large diffs are not rendered by default.

foto.jpg

526 KB
Loading

image-test.jpg

9.45 MB
Loading

patched_yolo_infer/nodes/CombineDetections.py

Lines changed: 105 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import torch
2-
2+
import numpy as np
33
from .MakeCropsDetectThem import MakeCropsDetectThem
44

55

@@ -48,6 +48,7 @@ def __init__(
4848

4949
self.nms_threshold = nms_threshold # IoU treshold for NMS
5050
self.match_metric = match_metric
51+
# seg mode
5152
(
5253
self.detected_conf_list_full,
5354
self.detected_xyxy_list_full,
@@ -58,14 +59,28 @@ def __init__(
5859
self.detected_cls_names_list_full = [
5960
self.class_names[value] for value in self.detected_cls_id_list_full
6061
] # make str list
61-
62-
# Invoke the NMS method for filtering predictions
63-
self.filtered_indices = self.nms(
64-
self.detected_conf_list_full,
65-
self.detected_xyxy_list_full,
66-
self.match_metric,
67-
self.nms_threshold
68-
)
62+
63+
# Invoke the NMS for segmentation masks method for filtering predictions
64+
if len(self.detected_masks_list_full)>0:
65+
66+
self.filtered_indices = self.nms(
67+
68+
self.detected_conf_list_full,
69+
self.detected_xyxy_list_full,
70+
self.match_metric,
71+
self.nms_threshold,
72+
self.detected_masks_list_full
73+
74+
)
75+
else:
76+
# Invoke the NMS method for filtering prediction
77+
self.filtered_indices = self.nms(
78+
79+
self.detected_conf_list_full,
80+
self.detected_xyxy_list_full,
81+
self.match_metric,
82+
self.nms_threshold
83+
)
6984

7085
# Apply filtering to the prediction lists
7186
self.filtered_confidences = [self.detected_conf_list_full[i] for i in self.filtered_indices]
@@ -102,12 +117,53 @@ def combinate_detections(self, crops):
102117

103118
return detected_conf, detected_xyxy, detected_masks, detected_cls
104119

105-
def nms(self,
106-
confidences: list,
107-
boxes: list,
108-
match_metric,
109-
nms_threshold,
110-
):
120+
121+
122+
@staticmethod
123+
def intersect_over_union(mask, masks_list):
124+
"""
125+
Compute Intersection over Union (IoU) scores for a given mask against a list of masks.
126+
127+
Args:
128+
mask (np.ndarray): Binary mask to compare.
129+
masks_list (list of np.ndarray): List of binary masks for comparison.
130+
131+
Returns:
132+
torch.Tensor: IoU scores for each mask in masks_list compared to the input mask.
133+
"""
134+
iou_scores = []
135+
for other_mask in masks_list:
136+
# Compute intersection and union
137+
intersection = np.logical_and(mask, other_mask).sum()
138+
union = np.logical_or(mask, other_mask).sum()
139+
# Compute IoU score, avoiding division by zero
140+
iou = intersection / union if union != 0 else 0
141+
iou_scores.append(iou)
142+
return torch.tensor(iou_scores)
143+
144+
@staticmethod
145+
def intersect_over_smaller(mask, masks_list):
146+
"""
147+
Compute Intersection over Smaller area scores for a given mask against a list of masks.
148+
149+
Args:
150+
mask (np.ndarray): Binary mask to compare.
151+
masks_list (list of np.ndarray): List of binary masks for comparison.
152+
153+
Returns:
154+
torch.Tensor: IoU scores for each mask in masks_list compared to the input mask, calculated over the smaller area.
155+
"""
156+
iou_scores = []
157+
for other_mask in masks_list:
158+
# Compute intersection and area of smaller mask
159+
intersection = np.logical_and(mask, other_mask).sum()
160+
smaller_area = min(mask.sum(), other_mask.sum())
161+
# Compute IoU score over smaller area, avoiding division by zero
162+
iou = intersection / smaller_area if smaller_area != 0 else 0
163+
iou_scores.append(iou)
164+
return torch.tensor(iou_scores)
165+
166+
def nms(self, confidences: list, boxes: list, match_metric, nms_threshold, masks=None):
111167
"""
112168
Apply non-maximum suppression to avoid detecting too many
113169
overlapping bounding boxes for a given object.
@@ -117,6 +173,7 @@ def nms(self,
117173
boxes (list): List of bounding boxes.
118174
match_metric (str): Matching metric, either 'IOU' or 'IOS'.
119175
nms_threshold (float): The threshold for match metric.
176+
masks (list, optional): List of masks. Defaults to None.
120177
121178
Returns:
122179
list: List of filtered indexes.
@@ -139,7 +196,7 @@ def nms(self,
139196

140197
# Sort the prediction boxes according to their confidence scores
141198
order = confidences.argsort()
142-
199+
143200
# Initialise an empty list for filtered prediction boxes
144201
keep = []
145202

@@ -182,7 +239,7 @@ def nms(self,
182239

183240
# Find the areas of BBoxes
184241
rem_areas = torch.index_select(areas, dim=0, index=order)
185-
242+
186243
# Calculate the distance between centers of the boxes
187244
cx = (x1[idx] + x2[idx]) / 2
188245
cy = (y1[idx] + y2[idx]) / 2
@@ -195,7 +252,7 @@ def nms(self,
195252
union = (rem_areas - inter) + areas[idx]
196253
# Find the IoU of every prediction
197254
match_metric_value = inter / union
198-
255+
199256
elif match_metric == "IOS":
200257
# Find the smaller area of every prediction with the prediction
201258
smaller = torch.min(rem_areas, areas[idx])
@@ -214,8 +271,35 @@ def nms(self,
214271
else:
215272
raise ValueError("Unknown matching metric")
216273

217-
# Keep the boxes with IoU/IoS less than threshold
218-
mask = match_metric_value < nms_threshold
219-
order = order[mask]
274+
# If masks are provided and IoU based on bounding boxes is greater than 0,
275+
# calculate IoU for masks and keep the ones with IoU < nms_threshold
276+
if masks is not None and torch.any(match_metric_value > 0):
277+
278+
mask_mask = match_metric_value > 0
279+
280+
order_2 = order[mask_mask]
281+
filtered_masks = [masks[i] for i in order_2]
282+
283+
if match_metric == "IOU":
284+
mask_iou = self.intersect_over_union(masks[idx], filtered_masks)
285+
mask_mask = mask_iou > nms_threshold
286+
287+
elif match_metric == "IOS":
288+
mask_iou = self.intersect_over_smaller(masks[idx], filtered_masks)
289+
mask_mask = mask_iou > nms_threshold
290+
291+
order_2 = order_2[mask_mask]
292+
inverse_mask = ~torch.isin(order, order_2)
293+
294+
# Оставить только те значения order, которые не содержатся в order_2
295+
order = order[inverse_mask]
296+
297+
else:
298+
# Keep the boxes with IoU/IoS less than threshold
299+
mask = match_metric_value < nms_threshold
300+
301+
order = order[mask]
220302

221303
return keep
304+
305+

stones.jpg

329 KB
Loading

stones2.jpg

267 KB
Loading

0 commit comments

Comments
 (0)