Skip to content

Commit 450fccb

Browse files
committed
default tqdm progress callback function
1 parent 43fe87f commit 450fccb

1 file changed

Lines changed: 49 additions & 3 deletions

File tree

patched_yolo_infer/nodes/MakeCropsDetectThem.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,15 @@ class MakeCropsDetectThem:
2727
overlap_x (int): Percentage of overlap along the x-axis.
2828
overlap_y (int): Percentage of overlap along the y-axis.
2929
show_crops (bool): Whether to visualize the cropping.
30+
show_processing_status (bool): Whether to show the processing status using tqdm.
3031
resize_initial_size (bool): Whether to resize the results to the original
3132
image size (ps: slow operation).
3233
model: Pre-initialized model object. If provided, the model will be used directly
3334
instead of loading from model_path.
3435
memory_optimize (bool): Memory optimization option for segmentation (less accurate results)
3536
batch_inference (bool): Batch inference of image crops through a neural network instead of
3637
sequential passes of crops (ps: Faster inference, higher memory use)
37-
progress_callback (function): Optional callback function, (task: str, current: int, total: int)
38+
progress_callback (function): Optional custom callback function, (task: str, current: int, total: int)
3839
inference_extra_args (dict): Dictionary with extra ultralytics inference parameters
3940
4041
Attributes:
@@ -52,13 +53,14 @@ class MakeCropsDetectThem:
5253
overlap_y (int): Percentage of overlap along the y-axis.
5354
crops (list): List to store the CropElement objects.
5455
show_crops (bool): Whether to visualize the cropping.
56+
show_processing_status (bool): Whether to show the processing status using tqdm.
5557
resize_initial_size (bool): Whether to resize the results to the original
5658
image size (ps: slow operation).
5759
class_names_dict (dict): Dictionary containing class names of the YOLO model.
5860
memory_optimize (bool): Memory optimization option for segmentation (less accurate results)
5961
batch_inference (bool): Batch inference of image crops through a neural network instead of
6062
sequential passes of crops (ps: Faster inference, higher memory use)
61-
progress_callback (function): Optional callback function, (task: str, current: int, total: int)
63+
progress_callback (function): Optional custom callback function, (task: str, current: int, total: int)
6264
inference_extra_args (dict): Dictionary with extra ultralytics inference parameters
6365
"""
6466
def __init__(
@@ -75,13 +77,27 @@ def __init__(
7577
overlap_x=25,
7678
overlap_y=25,
7779
show_crops=False,
80+
show_processing_status=True,
7881
resize_initial_size=True,
7982
model=None,
8083
memory_optimize=True,
8184
inference_extra_args=None,
8285
batch_inference=False,
8386
progress_callback=None,
8487
) -> None:
88+
89+
# Add show_process_status parameter and initialize progress bars dict
90+
self.show_process_status = show_processing_status
91+
self._progress_bars = {}
92+
93+
# Set up the progress callback based on parameters
94+
if progress_callback is not None:
95+
self.progress_callback = progress_callback
96+
elif show_processing_status:
97+
self.progress_callback = self._tqdm_callback
98+
else:
99+
self.progress_callback = None
100+
85101
if model is None:
86102
self.model = YOLO(model_path) # Load the model from the specified path
87103
else:
@@ -103,7 +119,6 @@ def __init__(
103119
self.class_names_dict = self.model.names # dict with human-readable class names
104120
self.inference_extra_args = inference_extra_args # dict with extra ultralytics inference parameters
105121
self.batch_inference = batch_inference # batch inference of image crops through a neural network
106-
self.progress_callback = progress_callback # callback function to report progress of the inference
107122

108123
self.crops = self.get_crops_xy(
109124
self.image,
@@ -117,6 +132,31 @@ def __init__(
117132
self._detect_objects_batch()
118133
else:
119134
self._detect_objects()
135+
136+
def _tqdm_callback(self, task, current, total):
137+
"""Internal callback function that uses tqdm for progress tracking
138+
139+
Args:
140+
task (str): The name of the task being tracked
141+
current (int): The current progress value
142+
total (int): The total number of steps in the task
143+
144+
"""
145+
if task not in self._progress_bars:
146+
self._progress_bars[task] = tqdm(
147+
total=total,
148+
desc=task,
149+
unit='items'
150+
)
151+
152+
# Update progress
153+
self._progress_bars[task].n = current
154+
self._progress_bars[task].refresh()
155+
156+
# Close and cleanup if task is complete
157+
if current >= total:
158+
self._progress_bars[task].close()
159+
del self._progress_bars[task]
120160

121161
def get_crops_xy(
122162
self,
@@ -342,4 +382,10 @@ def patches_info(self):
342382
# Append the formatted string to the patch_info list
343383
output += f"\nOn patch № {i}, nothing was detected"
344384
print(output)
385+
386+
def __del__(self):
387+
"""Cleanup method to ensure all progress bars are closed"""
388+
for pbar in self._progress_bars.values():
389+
pbar.close()
390+
self._progress_bars.clear()
345391

0 commit comments

Comments
 (0)