Skip to content

Commit a81680a

Browse files
authored
Merge pull request #31 from Jordan-Pierce/main
Progress Status and Optional Progress Callback Function for MakeCropsDetectThem
2 parents 5c3e90d + 450fccb commit a81680a

3 files changed

Lines changed: 65 additions & 1 deletion

File tree

patched_yolo_infer/nodes/MakeCropsDetectThem.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from collections import Counter
2+
from tqdm import tqdm
23
import cv2
34
import numpy as np
45
import matplotlib.pyplot as plt
@@ -26,13 +27,15 @@ class MakeCropsDetectThem:
2627
overlap_x (int): Percentage of overlap along the x-axis.
2728
overlap_y (int): Percentage of overlap along the y-axis.
2829
show_crops (bool): Whether to visualize the cropping.
30+
show_processing_status (bool): Whether to show the processing status using tqdm.
2931
resize_initial_size (bool): Whether to resize the results to the original
3032
image size (ps: slow operation).
3133
model: Pre-initialized model object. If provided, the model will be used directly
3234
instead of loading from model_path.
3335
memory_optimize (bool): Memory optimization option for segmentation (less accurate results)
3436
batch_inference (bool): Batch inference of image crops through a neural network instead of
3537
sequential passes of crops (ps: Faster inference, higher memory use)
38+
progress_callback (function): Optional custom callback function, (task: str, current: int, total: int)
3639
inference_extra_args (dict): Dictionary with extra ultralytics inference parameters
3740
3841
Attributes:
@@ -50,12 +53,14 @@ class MakeCropsDetectThem:
5053
overlap_y (int): Percentage of overlap along the y-axis.
5154
crops (list): List to store the CropElement objects.
5255
show_crops (bool): Whether to visualize the cropping.
56+
show_processing_status (bool): Whether to show the processing status using tqdm.
5357
resize_initial_size (bool): Whether to resize the results to the original
5458
image size (ps: slow operation).
5559
class_names_dict (dict): Dictionary containing class names of the YOLO model.
5660
memory_optimize (bool): Memory optimization option for segmentation (less accurate results)
5761
batch_inference (bool): Batch inference of image crops through a neural network instead of
5862
sequential passes of crops (ps: Faster inference, higher memory use)
63+
progress_callback (function): Optional custom callback function, (task: str, current: int, total: int)
5964
inference_extra_args (dict): Dictionary with extra ultralytics inference parameters
6065
"""
6166
def __init__(
@@ -72,12 +77,27 @@ def __init__(
7277
overlap_x=25,
7378
overlap_y=25,
7479
show_crops=False,
80+
show_processing_status=True,
7581
resize_initial_size=True,
7682
model=None,
7783
memory_optimize=True,
7884
inference_extra_args=None,
7985
batch_inference=False,
86+
progress_callback=None,
8087
) -> None:
88+
89+
# Add show_process_status parameter and initialize progress bars dict
90+
self.show_process_status = show_processing_status
91+
self._progress_bars = {}
92+
93+
# Set up the progress callback based on parameters
94+
if progress_callback is not None:
95+
self.progress_callback = progress_callback
96+
elif show_processing_status:
97+
self.progress_callback = self._tqdm_callback
98+
else:
99+
self.progress_callback = None
100+
81101
if model is None:
82102
self.model = YOLO(model_path) # Load the model from the specified path
83103
else:
@@ -112,6 +132,31 @@ def __init__(
112132
self._detect_objects_batch()
113133
else:
114134
self._detect_objects()
135+
136+
def _tqdm_callback(self, task, current, total):
137+
"""Internal callback function that uses tqdm for progress tracking
138+
139+
Args:
140+
task (str): The name of the task being tracked
141+
current (int): The current progress value
142+
total (int): The total number of steps in the task
143+
144+
"""
145+
if task not in self._progress_bars:
146+
self._progress_bars[task] = tqdm(
147+
total=total,
148+
desc=task,
149+
unit='items'
150+
)
151+
152+
# Update progress
153+
self._progress_bars[task].n = current
154+
self._progress_bars[task].refresh()
155+
156+
# Close and cleanup if task is complete
157+
if current >= total:
158+
self._progress_bars[task].close()
159+
del self._progress_bars[task]
115160

116161
def get_crops_xy(
117162
self,
@@ -158,6 +203,7 @@ def get_crops_xy(
158203
plt.figure(figsize=[x_steps*0.9, y_steps*0.9])
159204

160205
count = 0
206+
total_steps = y_steps * x_steps # Total number of crops
161207
for i in range(y_steps):
162208
for j in range(x_steps):
163209
x_start = int(shape_x * j * cross_koef_x)
@@ -179,6 +225,10 @@ def get_crops_xy(
179225
plt.imshow(cv2.cvtColor(im_temp.copy(), cv2.COLOR_BGR2RGB))
180226
plt.axis('off')
181227
count += 1
228+
229+
# Call the progress callback function if provided
230+
if self.progress_callback is not None:
231+
self.progress_callback("Getting crops", count, total_steps)
182232

183233
data_all_crops.append(CropElement(
184234
source_image=image_innitial,
@@ -210,7 +260,8 @@ def _detect_objects(self):
210260
Returns:
211261
None
212262
"""
213-
for crop in self.crops:
263+
total_crops = len(self.crops) # Total number of crops
264+
for index, crop in enumerate(self.crops):
214265
crop.calculate_inference(
215266
self.model,
216267
imgsz=self.imgsz,
@@ -224,6 +275,11 @@ def _detect_objects(self):
224275
crop.calculate_real_values()
225276
if self.resize_initial_size:
226277
crop.resize_results()
278+
279+
# Call the progress callback function if provided
280+
if self.progress_callback is not None:
281+
self.progress_callback("Detecting objects", (index + 1), total_crops)
282+
227283

228284
def _detect_objects_batch(self):
229285
"""
@@ -326,4 +382,10 @@ def patches_info(self):
326382
# Append the formatted string to the patch_info list
327383
output += f"\nOn patch № {i}, nothing was detected"
328384
print(output)
385+
386+
def __del__(self):
387+
"""Cleanup method to ensure all progress bars are closed"""
388+
for pbar in self._progress_bars.values():
389+
pbar.close()
390+
self._progress_bars.clear()
329391

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
numpy<2.0
2+
tqdm
23
opencv-python
34
matplotlib
45
ultralytics

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
python_requires=">=3.8",
2525
install_requires=[
2626
'numpy<2.0',
27+
'tqdm',
2728
'opencv-python',
2829
'matplotlib',
2930
'ultralytics'

0 commit comments

Comments
 (0)