11from collections import Counter
2+ from tqdm import tqdm
23import cv2
34import numpy as np
45import matplotlib .pyplot as plt
@@ -33,6 +34,7 @@ class MakeCropsDetectThem:
3334 memory_optimize (bool): Memory optimization option for segmentation (less accurate results)
3435 batch_inference (bool): Batch inference of image crops through a neural network instead of
3536 sequential passes of crops (ps: Faster inference, higher memory use)
37+ progress_callback (function): Optional callback function, (task: str, current: int, total: int)
3638 inference_extra_args (dict): Dictionary with extra ultralytics inference parameters
3739
3840 Attributes:
@@ -56,6 +58,7 @@ class MakeCropsDetectThem:
5658 memory_optimize (bool): Memory optimization option for segmentation (less accurate results)
5759 batch_inference (bool): Batch inference of image crops through a neural network instead of
5860 sequential passes of crops (ps: Faster inference, higher memory use)
61+ progress_callback (function): Optional callback function, (task: str, current: int, total: int)
5962 inference_extra_args (dict): Dictionary with extra ultralytics inference parameters
6063 """
6164 def __init__ (
@@ -77,6 +80,7 @@ def __init__(
7780 memory_optimize = True ,
7881 inference_extra_args = None ,
7982 batch_inference = False ,
83+ progress_callback = None ,
8084 ) -> None :
8185 if model is None :
8286 self .model = YOLO (model_path ) # Load the model from the specified path
@@ -99,6 +103,7 @@ def __init__(
99103 self .class_names_dict = self .model .names # dict with human-readable class names
100104 self .inference_extra_args = inference_extra_args # dict with extra ultralytics inference parameters
101105 self .batch_inference = batch_inference # batch inference of image crops through a neural network
106+ self .progress_callback = progress_callback # callback function to report progress of the inference
102107
103108 self .crops = self .get_crops_xy (
104109 self .image ,
@@ -158,6 +163,7 @@ def get_crops_xy(
158163 plt .figure (figsize = [x_steps * 0.9 , y_steps * 0.9 ])
159164
160165 count = 0
166+ total_steps = y_steps * x_steps # Total number of crops
161167 for i in range (y_steps ):
162168 for j in range (x_steps ):
163169 x_start = int (shape_x * j * cross_koef_x )
@@ -179,6 +185,10 @@ def get_crops_xy(
179185 plt .imshow (cv2 .cvtColor (im_temp .copy (), cv2 .COLOR_BGR2RGB ))
180186 plt .axis ('off' )
181187 count += 1
188+
189+ # Call the progress callback function if provided
190+ if self .progress_callback is not None :
191+ self .progress_callback ("Getting crops" , count , total_steps )
182192
183193 data_all_crops .append (CropElement (
184194 source_image = image_innitial ,
@@ -210,7 +220,8 @@ def _detect_objects(self):
210220 Returns:
211221 None
212222 """
213- for crop in self .crops :
223+ total_crops = len (self .crops ) # Total number of crops
224+ for index , crop in enumerate (self .crops ):
214225 crop .calculate_inference (
215226 self .model ,
216227 imgsz = self .imgsz ,
@@ -224,6 +235,11 @@ def _detect_objects(self):
224235 crop .calculate_real_values ()
225236 if self .resize_initial_size :
226237 crop .resize_results ()
238+
239+ # Call the progress callback function if provided
240+ if self .progress_callback is not None :
241+ self .progress_callback ("Detecting objects" , (index + 1 ), total_crops )
242+
227243
228244 def _detect_objects_batch (self ):
229245 """
0 commit comments