Skip to content

Commit b4e0f38

Browse files
committed
multiple elements can be processed
1 parent ac112c6 commit b4e0f38

1 file changed

Lines changed: 45 additions & 9 deletions

File tree

patched_yolo_infer/nodes/CombineDetections.py

Lines changed: 45 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing import Union, List
12
import torch
23
import numpy as np
34
from .MakeCropsDetectThem import MakeCropsDetectThem
@@ -46,26 +47,61 @@ class CombineDetections:
4647

4748
def __init__(
4849
self,
49-
element_crops: MakeCropsDetectThem,
50+
element_crops: Union[MakeCropsDetectThem, List[MakeCropsDetectThem]],
5051
nms_threshold=0.3,
5152
match_metric='IOS',
5253
intelligent_sorter=True,
5354
sorter_bins=5,
5455
class_agnostic_nms=True
5556
) -> None:
56-
self.class_names = element_crops.class_names_dict
57-
self.crops = element_crops.crops # List to store the CropElement objects
58-
if element_crops.resize_initial_size:
59-
self.image = element_crops.crops[0].source_image
60-
else:
61-
self.image = element_crops.crops[0].source_image_resized
6257

6358
self.nms_threshold = nms_threshold # IOU or IOS treshold for NMS
6459
self.match_metric = match_metric
6560
self.intelligent_sorter = intelligent_sorter # enable sorting by area and confidence parameter
6661
self.sorter_bins = sorter_bins
6762
self.class_agnostic_nms = class_agnostic_nms
6863

64+
# Check if element_crops is a list
65+
if isinstance(element_crops, list):
66+
# Ensure all elements in the list have the same source_image and other params
67+
first_image = element_crops[0].crops[0].source_image
68+
first_element_segment_status = element_crops[0].segment
69+
first_element_memory_optimize_status = element_crops[0].memory_optimize
70+
for element in element_crops:
71+
if not np.array_equal(element.crops[0].source_image, first_image):
72+
raise ValueError(
73+
"The source images in element_crops differ, "
74+
"so combining results from these objects is not possible."
75+
)
76+
if not element.resize_initial_size:
77+
raise ValueError(
78+
"When working with a list of element_crops, "
79+
"resize_initial_size should be True everywhere."
80+
)
81+
if (
82+
first_element_segment_status != element.segment
83+
or first_element_memory_optimize_status != element.memory_optimize
84+
):
85+
raise ValueError(
86+
"The segment or memory_optimize attributes of element_crops differ, "
87+
"so processing cannot be performed."
88+
)
89+
90+
self.class_names = element_crops[0].class_names_dict
91+
self.crops = [crop for element in element_crops for crop in element.crops]
92+
self.image = element_crops[0].crops[0].source_image
93+
self.segment = element_crops[0].segment
94+
self.memory_optimize = element_crops[0].memory_optimize
95+
else:
96+
self.class_names = element_crops.class_names_dict
97+
self.crops = element_crops.crops # List to store the CropElement objects
98+
if element_crops.resize_initial_size:
99+
self.image = element_crops.crops[0].source_image
100+
else:
101+
self.image = element_crops.crops[0].source_image_resized
102+
self.segment = element_crops.segment
103+
self.memory_optimize = element_crops.memory_optimize
104+
69105
# Combinate detections of all patches
70106
(
71107
self.detected_conf_list_full,
@@ -108,13 +144,13 @@ def __init__(
108144
self.filtered_classes_names = [self.detected_cls_names_list_full[i] for i in self.filtered_indices]
109145

110146
# Masks filtering:
111-
if element_crops.segment and not element_crops.memory_optimize:
147+
if self.segment and not self.memory_optimize:
112148
self.filtered_masks = [self.detected_masks_list_full[i] for i in self.filtered_indices]
113149
else:
114150
self.filtered_masks = []
115151

116152
# Polygons filtering:
117-
if element_crops.segment and element_crops.memory_optimize:
153+
if self.segment and self.memory_optimize:
118154
self.filtered_polygons = [self.detected_polygons_list_full[i] for i in self.filtered_indices]
119155
else:
120156
self.filtered_polygons = []

0 commit comments

Comments
 (0)