Skip to content

Commit ae796c6

Browse files
authored
Merge pull request #23 from curtis18/main
Update functions_extra.py to support color_class_background using list of tuple
2 parents bb837c0 + e5dd7c8 commit ae796c6

1 file changed

Lines changed: 24 additions & 6 deletions

File tree

patched_yolo_infer/functions_extra.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def visualize_results_usual_yolo_inference(
4646
show_class (bool): Whether to show class labels. Default is True.
4747
fill_mask (bool): Whether to fill the segmented regions with color. Default is False.
4848
alpha (float): The transparency of filled masks. Default is 0.3.
49-
color_class_background (tuple): The background BGR color for class labels. Default is (0, 0, 255) (red).
49+
color_class_background (tuple / list of tuple): The background BGR color for class labels. Default is (0, 0, 255) (red).
5050
color_class_text (tuple): The text BGR color for class labels. Default is (255, 255, 255) (white).
5151
thickness (int): The thickness of bounding box and text. Default is 4.
5252
font: The font type for class labels. Default is cv2.FONT_HERSHEY_SIMPLEX.
@@ -79,6 +79,10 @@ def visualize_results_usual_yolo_inference(
7979

8080
class_names = model.names
8181

82+
# Map class IDs to indices in the color list
83+
all_classes = set(cls for pred in predictions for cls in pred.boxes.cls.cpu().int().tolist())
84+
class_to_color_index = {cls_id: idx for idx, cls_id in enumerate(all_classes)}
85+
8286
# Process each prediction
8387
for pred in predictions:
8488

@@ -116,7 +120,7 @@ def visualize_results_usual_yolo_inference(
116120
random.seed(int(classes[i] + delta_colors))
117121
color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
118122
else:
119-
color = list_of_class_colors[classes[i]]
123+
color = list_of_class_colors[class_to_color_index[class_index]]
120124

121125
box = boxes[i]
122126
x_min, y_min, x_max, y_max = box
@@ -143,11 +147,16 @@ def visualize_results_usual_yolo_inference(
143147
else:
144148
label = str(class_name)
145149
(text_width, text_height), _ = cv2.getTextSize(label, font, font_scale, thickness)
150+
background_color = (
151+
color_class_background[class_to_color_index[class_index]]
152+
if isinstance(color_class_background, list)
153+
else color_class_background
154+
)
146155
cv2.rectangle(
147156
labeled_image,
148157
(x_min, y_min),
149158
(x_min + text_width + 5, y_min + text_height + 5),
150-
color_class_background,
159+
background_color,
151160
-1,
152161
)
153162
cv2.putText(
@@ -303,7 +312,7 @@ def visualize_results(
303312
show_class (bool): Whether to show class labels. Default is True.
304313
fill_mask (bool): Whether to fill the segmented regions with color. Default is False.
305314
alpha (float): The transparency of filled masks. Default is 0.3.
306-
color_class_background (tuple): The background BGR color for class labels. Default is (0, 0, 255) (red).
315+
color_class_background (tuple / list of tuple): The background BGR color for class labels. Default is (0, 0, 255) (red).
307316
color_class_text (tuple): The text BGR color for class labels. Default is (255, 255, 255) (white).
308317
thickness (int): The thickness of bounding box and text. Default is 4.
309318
font: The font type for class labels. Default is cv2.FONT_HERSHEY_SIMPLEX.
@@ -330,6 +339,10 @@ def visualize_results(
330339
if random_object_colors:
331340
random.seed(int(delta_colors))
332341

342+
# Map class IDs to indices in the color list
343+
unique_classes = set(classes_ids)
344+
class_to_color_index = {cls_id: idx for idx, cls_id in enumerate(unique_classes)}
345+
333346
# Process each prediction
334347
for i in range(len(classes_ids)):
335348
# Get the class for the current detection
@@ -348,7 +361,7 @@ def visualize_results(
348361
random.seed(int(classes_ids[i] + delta_colors))
349362
color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
350363
else:
351-
color = list_of_class_colors[classes_ids[i]]
364+
color = list_of_class_colors[class_to_color_index[classes_ids[i]]]
352365

353366
box = boxes[i]
354367
x_min, y_min, x_max, y_max = box
@@ -395,11 +408,16 @@ def visualize_results(
395408
else:
396409
label = str(class_name)
397410
(text_width, text_height), _ = cv2.getTextSize(label, font, font_scale, thickness)
411+
background_color = (
412+
color_class_background[class_to_color_index[classes_ids[i]]]
413+
if isinstance(color_class_background, list)
414+
else color_class_background
415+
)
398416
cv2.rectangle(
399417
labeled_image,
400418
(x_min, y_min),
401419
(x_min + text_width + 5, y_min + text_height + 5),
402-
color_class_background,
420+
background_color,
403421
-1,
404422
)
405423
cv2.putText(

0 commit comments

Comments
 (0)