11from collections import Counter
2+ from tqdm import tqdm
23import cv2
34import numpy as np
45import matplotlib .pyplot as plt
@@ -26,13 +27,15 @@ class MakeCropsDetectThem:
2627 overlap_x (int): Percentage of overlap along the x-axis.
2728 overlap_y (int): Percentage of overlap along the y-axis.
2829 show_crops (bool): Whether to visualize the cropping.
30+ show_processing_status (bool): Whether to show the processing status using tqdm.
2931 resize_initial_size (bool): Whether to resize the results to the original
3032 image size (ps: slow operation).
3133 model: Pre-initialized model object. If provided, the model will be used directly
3234 instead of loading from model_path.
3335 memory_optimize (bool): Memory optimization option for segmentation (less accurate results)
3436 batch_inference (bool): Batch inference of image crops through a neural network instead of
3537 sequential passes of crops (ps: Faster inference, higher memory use)
38+ progress_callback (function): Optional custom callback function, (task: str, current: int, total: int)
3639 inference_extra_args (dict): Dictionary with extra ultralytics inference parameters
3740
3841 Attributes:
@@ -50,12 +53,14 @@ class MakeCropsDetectThem:
5053 overlap_y (int): Percentage of overlap along the y-axis.
5154 crops (list): List to store the CropElement objects.
5255 show_crops (bool): Whether to visualize the cropping.
56+ show_processing_status (bool): Whether to show the processing status using tqdm.
5357 resize_initial_size (bool): Whether to resize the results to the original
5458 image size (ps: slow operation).
5559 class_names_dict (dict): Dictionary containing class names of the YOLO model.
5660 memory_optimize (bool): Memory optimization option for segmentation (less accurate results)
5761 batch_inference (bool): Batch inference of image crops through a neural network instead of
5862 sequential passes of crops (ps: Faster inference, higher memory use)
63+ progress_callback (function): Optional custom callback function, (task: str, current: int, total: int)
5964 inference_extra_args (dict): Dictionary with extra ultralytics inference parameters
6065 """
6166 def __init__ (
@@ -72,12 +77,27 @@ def __init__(
7277 overlap_x = 25 ,
7378 overlap_y = 25 ,
7479 show_crops = False ,
80+ show_processing_status = True ,
7581 resize_initial_size = True ,
7682 model = None ,
7783 memory_optimize = True ,
7884 inference_extra_args = None ,
7985 batch_inference = False ,
86+ progress_callback = None ,
8087 ) -> None :
88+
89+ # Add show_process_status parameter and initialize progress bars dict
90+ self .show_process_status = show_processing_status
91+ self ._progress_bars = {}
92+
93+ # Set up the progress callback based on parameters
94+ if progress_callback is not None :
95+ self .progress_callback = progress_callback
96+ elif show_processing_status :
97+ self .progress_callback = self ._tqdm_callback
98+ else :
99+ self .progress_callback = None
100+
81101 if model is None :
82102 self .model = YOLO (model_path ) # Load the model from the specified path
83103 else :
@@ -112,6 +132,31 @@ def __init__(
112132 self ._detect_objects_batch ()
113133 else :
114134 self ._detect_objects ()
135+
136+ def _tqdm_callback (self , task , current , total ):
137+ """Internal callback function that uses tqdm for progress tracking
138+
139+ Args:
140+ task (str): The name of the task being tracked
141+ current (int): The current progress value
142+ total (int): The total number of steps in the task
143+
144+ """
145+ if task not in self ._progress_bars :
146+ self ._progress_bars [task ] = tqdm (
147+ total = total ,
148+ desc = task ,
149+ unit = 'items'
150+ )
151+
152+ # Update progress
153+ self ._progress_bars [task ].n = current
154+ self ._progress_bars [task ].refresh ()
155+
156+ # Close and cleanup if task is complete
157+ if current >= total :
158+ self ._progress_bars [task ].close ()
159+ del self ._progress_bars [task ]
115160
116161 def get_crops_xy (
117162 self ,
@@ -158,6 +203,7 @@ def get_crops_xy(
158203 plt .figure (figsize = [x_steps * 0.9 , y_steps * 0.9 ])
159204
160205 count = 0
206+ total_steps = y_steps * x_steps # Total number of crops
161207 for i in range (y_steps ):
162208 for j in range (x_steps ):
163209 x_start = int (shape_x * j * cross_koef_x )
@@ -179,6 +225,10 @@ def get_crops_xy(
179225 plt .imshow (cv2 .cvtColor (im_temp .copy (), cv2 .COLOR_BGR2RGB ))
180226 plt .axis ('off' )
181227 count += 1
228+
229+ # Call the progress callback function if provided
230+ if self .progress_callback is not None :
231+ self .progress_callback ("Getting crops" , count , total_steps )
182232
183233 data_all_crops .append (CropElement (
184234 source_image = image_innitial ,
@@ -210,7 +260,8 @@ def _detect_objects(self):
210260 Returns:
211261 None
212262 """
213- for crop in self .crops :
263+ total_crops = len (self .crops ) # Total number of crops
264+ for index , crop in enumerate (self .crops ):
214265 crop .calculate_inference (
215266 self .model ,
216267 imgsz = self .imgsz ,
@@ -224,6 +275,11 @@ def _detect_objects(self):
224275 crop .calculate_real_values ()
225276 if self .resize_initial_size :
226277 crop .resize_results ()
278+
279+ # Call the progress callback function if provided
280+ if self .progress_callback is not None :
281+ self .progress_callback ("Detecting objects" , (index + 1 ), total_crops )
282+
227283
228284 def _detect_objects_batch (self ):
229285 """
@@ -326,4 +382,10 @@ def patches_info(self):
326382 # Append the formatted string to the patch_info list
327383 output += f"\n On patch № { i } , nothing was detected"
328384 print (output )
385+
386+ def __del__ (self ):
387+ """Cleanup method to ensure all progress bars are closed"""
388+ for pbar in self ._progress_bars .values ():
389+ pbar .close ()
390+ self ._progress_bars .clear ()
329391
0 commit comments