@@ -46,7 +46,7 @@ def visualize_results_usual_yolo_inference(
4646 show_class (bool): Whether to show class labels. Default is True.
4747 fill_mask (bool): Whether to fill the segmented regions with color. Default is False.
4848 alpha (float): The transparency of filled masks. Default is 0.3.
49- color_class_background (tuple): The background BGR color for class labels. Default is (0, 0, 255) (red).
49+ color_class_background (tuple / list of tuple ): The background BGR color for class labels. Default is (0, 0, 255) (red).
5050 color_class_text (tuple): The text BGR color for class labels. Default is (255, 255, 255) (white).
5151 thickness (int): The thickness of bounding box and text. Default is 4.
5252 font: The font type for class labels. Default is cv2.FONT_HERSHEY_SIMPLEX.
@@ -79,6 +79,10 @@ def visualize_results_usual_yolo_inference(
7979
8080 class_names = model .names
8181
82+ # Map class IDs to indices in the color list
83+ all_classes = set (cls for pred in predictions for cls in pred .boxes .cls .cpu ().int ().tolist ())
84+ class_to_color_index = {cls_id : idx for idx , cls_id in enumerate (all_classes )}
85+
8286 # Process each prediction
8387 for pred in predictions :
8488
@@ -116,7 +120,7 @@ def visualize_results_usual_yolo_inference(
116120 random .seed (int (classes [i ] + delta_colors ))
117121 color = (random .randint (0 , 255 ), random .randint (0 , 255 ), random .randint (0 , 255 ))
118122 else :
119- color = list_of_class_colors [classes [ i ]]
123+ color = list_of_class_colors [class_to_color_index [ class_index ]]
120124
121125 box = boxes [i ]
122126 x_min , y_min , x_max , y_max = box
@@ -143,11 +147,16 @@ def visualize_results_usual_yolo_inference(
143147 else :
144148 label = str (class_name )
145149 (text_width , text_height ), _ = cv2 .getTextSize (label , font , font_scale , thickness )
150+ background_color = (
151+ color_class_background [class_to_color_index [class_index ]]
152+ if isinstance (color_class_background , list )
153+ else color_class_background
154+ )
146155 cv2 .rectangle (
147156 labeled_image ,
148157 (x_min , y_min ),
149158 (x_min + text_width + 5 , y_min + text_height + 5 ),
150- color_class_background ,
159+ background_color ,
151160 - 1 ,
152161 )
153162 cv2 .putText (
@@ -303,7 +312,7 @@ def visualize_results(
303312 show_class (bool): Whether to show class labels. Default is True.
304313 fill_mask (bool): Whether to fill the segmented regions with color. Default is False.
305314 alpha (float): The transparency of filled masks. Default is 0.3.
306- color_class_background (tuple): The background BGR color for class labels. Default is (0, 0, 255) (red).
315+ color_class_background (tuple / list of tuple ): The background BGR color for class labels. Default is (0, 0, 255) (red).
307316 color_class_text (tuple): The text BGR color for class labels. Default is (255, 255, 255) (white).
308317 thickness (int): The thickness of bounding box and text. Default is 4.
309318 font: The font type for class labels. Default is cv2.FONT_HERSHEY_SIMPLEX.
@@ -330,6 +339,10 @@ def visualize_results(
330339 if random_object_colors :
331340 random .seed (int (delta_colors ))
332341
342+ # Map class IDs to indices in the color list
343+ unique_classes = set (classes_ids )
344+ class_to_color_index = {cls_id : idx for idx , cls_id in enumerate (unique_classes )}
345+
333346 # Process each prediction
334347 for i in range (len (classes_ids )):
335348 # Get the class for the current detection
@@ -348,7 +361,7 @@ def visualize_results(
348361 random .seed (int (classes_ids [i ] + delta_colors ))
349362 color = (random .randint (0 , 255 ), random .randint (0 , 255 ), random .randint (0 , 255 ))
350363 else :
351- color = list_of_class_colors [classes_ids [i ]]
364+ color = list_of_class_colors [class_to_color_index [ classes_ids [i ] ]]
352365
353366 box = boxes [i ]
354367 x_min , y_min , x_max , y_max = box
@@ -395,11 +408,16 @@ def visualize_results(
395408 else :
396409 label = str (class_name )
397410 (text_width , text_height ), _ = cv2 .getTextSize (label , font , font_scale , thickness )
411+ background_color = (
412+ color_class_background [class_to_color_index [classes_ids [i ]]]
413+ if isinstance (color_class_background , list )
414+ else color_class_background
415+ )
398416 cv2 .rectangle (
399417 labeled_image ,
400418 (x_min , y_min ),
401419 (x_min + text_width + 5 , y_min + text_height + 5 ),
402- color_class_background ,
420+ background_color ,
403421 - 1 ,
404422 )
405423 cv2 .putText (
0 commit comments