@@ -79,10 +79,6 @@ def visualize_results_usual_yolo_inference(
7979
8080 class_names = model .names
8181
82- # Map class IDs to indices in the color list
83- all_classes = set (cls for pred in predictions for cls in pred .boxes .cls .cpu ().int ().tolist ())
84- class_to_color_index = {cls_id : idx for idx , cls_id in enumerate (all_classes )}
85-
8682 # Process each prediction
8783 for pred in predictions :
8884
@@ -120,7 +116,7 @@ def visualize_results_usual_yolo_inference(
120116 random .seed (int (classes [i ] + delta_colors ))
121117 color = (random .randint (0 , 255 ), random .randint (0 , 255 ), random .randint (0 , 255 ))
122118 else :
123- color = list_of_class_colors [class_to_color_index [ class_index ] ]
119+ color = list_of_class_colors [class_index ]
124120
125121 box = boxes [i ]
126122 x_min , y_min , x_max , y_max = box
@@ -148,7 +144,7 @@ def visualize_results_usual_yolo_inference(
148144 label = str (class_name )
149145 (text_width , text_height ), _ = cv2 .getTextSize (label , font , font_scale , thickness )
150146 background_color = (
151- color_class_background [class_to_color_index [ class_index ] ]
147+ color_class_background [class_index ]
152148 if isinstance (color_class_background , list )
153149 else color_class_background
154150 )
@@ -339,10 +335,6 @@ def visualize_results(
339335 if random_object_colors :
340336 random .seed (int (delta_colors ))
341337
342- # Map class IDs to indices in the color list
343- unique_classes = set (classes_ids )
344- class_to_color_index = {cls_id : idx for idx , cls_id in enumerate (unique_classes )}
345-
346338 # Process each prediction
347339 for i in range (len (classes_ids )):
348340 # Get the class for the current detection
@@ -361,7 +353,7 @@ def visualize_results(
361353 random .seed (int (classes_ids [i ] + delta_colors ))
362354 color = (random .randint (0 , 255 ), random .randint (0 , 255 ), random .randint (0 , 255 ))
363355 else :
364- color = list_of_class_colors [class_to_color_index [ classes_ids [i ] ]]
356+ color = list_of_class_colors [classes_ids [i ]]
365357
366358 box = boxes [i ]
367359 x_min , y_min , x_max , y_max = box
@@ -409,7 +401,7 @@ def visualize_results(
409401 label = str (class_name )
410402 (text_width , text_height ), _ = cv2 .getTextSize (label , font , font_scale , thickness )
411403 background_color = (
412- color_class_background [class_to_color_index [ classes_ids [i ] ]]
404+ color_class_background [classes_ids [i ]]
413405 if isinstance (color_class_background , list )
414406 else color_class_background
415407 )
0 commit comments