|
| 1 | +from typing import Union, List |
1 | 2 | import torch |
2 | 3 | import numpy as np |
3 | 4 | from .MakeCropsDetectThem import MakeCropsDetectThem |
@@ -46,26 +47,61 @@ class CombineDetections: |
46 | 47 |
|
47 | 48 | def __init__( |
48 | 49 | self, |
49 | | - element_crops: MakeCropsDetectThem, |
| 50 | + element_crops: Union[MakeCropsDetectThem, List[MakeCropsDetectThem]], |
50 | 51 | nms_threshold=0.3, |
51 | 52 | match_metric='IOS', |
52 | 53 | intelligent_sorter=True, |
53 | 54 | sorter_bins=5, |
54 | 55 | class_agnostic_nms=True |
55 | 56 | ) -> None: |
56 | | - self.class_names = element_crops.class_names_dict |
57 | | - self.crops = element_crops.crops # List to store the CropElement objects |
58 | | - if element_crops.resize_initial_size: |
59 | | - self.image = element_crops.crops[0].source_image |
60 | | - else: |
61 | | - self.image = element_crops.crops[0].source_image_resized |
62 | 57 |
|
63 | 58 | self.nms_threshold = nms_threshold # IOU or IOS treshold for NMS |
64 | 59 | self.match_metric = match_metric |
65 | 60 | self.intelligent_sorter = intelligent_sorter # enable sorting by area and confidence parameter |
66 | 61 | self.sorter_bins = sorter_bins |
67 | 62 | self.class_agnostic_nms = class_agnostic_nms |
68 | 63 |
|
| 64 | + # Check if element_crops is a list |
| 65 | + if isinstance(element_crops, list): |
| 66 | + # Ensure all elements in the list have the same source_image and other params |
| 67 | + first_image = element_crops[0].crops[0].source_image |
| 68 | + first_element_segment_status = element_crops[0].segment |
| 69 | + first_element_memory_optimize_status = element_crops[0].memory_optimize |
| 70 | + for element in element_crops: |
| 71 | + if not np.array_equal(element.crops[0].source_image, first_image): |
| 72 | + raise ValueError( |
| 73 | + "The source images in element_crops differ, " |
| 74 | + "so combining results from these objects is not possible." |
| 75 | + ) |
| 76 | + if not element.resize_initial_size: |
| 77 | + raise ValueError( |
| 78 | + "When working with a list of element_crops, " |
| 79 | + "resize_initial_size should be True everywhere." |
| 80 | + ) |
| 81 | + if ( |
| 82 | + first_element_segment_status != element.segment |
| 83 | + or first_element_memory_optimize_status != element.memory_optimize |
| 84 | + ): |
| 85 | + raise ValueError( |
| 86 | + "The segment or memory_optimize attributes of element_crops differ, " |
| 87 | + "so processing cannot be performed." |
| 88 | + ) |
| 89 | + |
| 90 | + self.class_names = element_crops[0].class_names_dict |
| 91 | + self.crops = [crop for element in element_crops for crop in element.crops] |
| 92 | + self.image = element_crops[0].crops[0].source_image |
| 93 | + self.segment = element_crops[0].segment |
| 94 | + self.memory_optimize = element_crops[0].memory_optimize |
| 95 | + else: |
| 96 | + self.class_names = element_crops.class_names_dict |
| 97 | + self.crops = element_crops.crops # List to store the CropElement objects |
| 98 | + if element_crops.resize_initial_size: |
| 99 | + self.image = element_crops.crops[0].source_image |
| 100 | + else: |
| 101 | + self.image = element_crops.crops[0].source_image_resized |
| 102 | + self.segment = element_crops.segment |
| 103 | + self.memory_optimize = element_crops.memory_optimize |
| 104 | + |
69 | 105 | # Combinate detections of all patches |
70 | 106 | ( |
71 | 107 | self.detected_conf_list_full, |
@@ -108,13 +144,13 @@ def __init__( |
108 | 144 | self.filtered_classes_names = [self.detected_cls_names_list_full[i] for i in self.filtered_indices] |
109 | 145 |
|
110 | 146 | # Masks filtering: |
111 | | - if element_crops.segment and not element_crops.memory_optimize: |
| 147 | + if self.segment and not self.memory_optimize: |
112 | 148 | self.filtered_masks = [self.detected_masks_list_full[i] for i in self.filtered_indices] |
113 | 149 | else: |
114 | 150 | self.filtered_masks = [] |
115 | 151 |
|
116 | 152 | # Polygons filtering: |
117 | | - if element_crops.segment and element_crops.memory_optimize: |
| 153 | + if self.segment and self.memory_optimize: |
118 | 154 | self.filtered_polygons = [self.detected_polygons_list_full[i] for i in self.filtered_indices] |
119 | 155 | else: |
120 | 156 | self.filtered_polygons = [] |
|
0 commit comments