Skip to content

Commit a825da9

Browse files
committed
yolo-pose visual
1 parent ffac767 commit a825da9

1 file changed

Lines changed: 46 additions & 11 deletions

File tree

patched_yolo_infer/functions_extra.py

Lines changed: 46 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -184,13 +184,14 @@ def visualize_results_yolo_pose_inference(
184184
conf=0.25,
185185
iou=0.7,
186186
thickness=4,
187-
point_radius=4,
187+
point_radius=5,
188188
connection_schema=None,
189-
min_landmark_visibility=0.25,
189+
min_landmarks_visibility=0.25,
190190
show_boxes=True,
191-
show_class=True,
191+
show_class=False,
192192
color_class_background=(0, 0, 255),
193193
color_class_text=(255, 255, 255),
194+
point_color=None,
194195
font=cv2.FONT_HERSHEY_SIMPLEX,
195196
font_scale=1.5,
196197
delta_colors=3,
@@ -218,24 +219,26 @@ def visualize_results_yolo_pose_inference(
218219
connection_schema (list): A list of tuples defining how landmarks should be connected to form a skeleton.
219220
Each tuple contains two indices representing the landmarks to be connected.
220221
If None or empty, only landmarks will be drawn without any connections.
221-
min_landmark_visibility (float): The minimum confidence threshold for a landmark's visibility to be drawn.
222+
min_landmarks_visibility (float): The minimum visibility confidence a landmark must have to be drawn; a skeleton connection is drawn only when both of its landmarks meet this threshold.
222223
show_boxes (bool): Whether to show bounding boxes. Default is True.
223-
show_class (bool): Whether to show class labels. Default is True.
224+
show_class (bool): Whether to show class labels. Default is False.
224225
color_class_background (tuple / list of tuple): The background BGR color for class labels. Default is (0, 0, 255) (red).
225226
color_class_text (tuple): The text BGR color for class labels. Default is (255, 255, 255) (white).
227+
delta_colors (int): The random seed offset for color variation. Default is 3.
228+
list_of_class_colors (list / None): A list of tuples representing the colors for each class in BGR format.
229+
If provided, these colors will be used for displaying the classes instead of random colors.
230+
The number of tuples in the list must match the number of possible classes in the network.
231+
random_object_colors (bool): If True, colors for each object are selected randomly.
232+
point_color (tuple / None): If None, then the point color is chosen to be the same as the box and skeleton;
233+
otherwise, the specified color is used for every point.
226234
font: The font type for class labels. Default is cv2.FONT_HERSHEY_SIMPLEX.
227235
font_scale (float): The scale factor for font size. Default is 1.5.
228-
delta_colors (int): The random seed offset for color variation. Default is 3.
229236
dpi (int): Final visualization size (plot is bigger when dpi is higher).
230-
random_object_colors (bool): If True, colors for each object are selected randomly.
231237
show_confidences (bool): If True and show_class=True, confidences near class are visualized.
232238
axis_off (bool): If True, axis is turned off in the final visualization.
233239
show_classes_list (list): If empty, visualize all classes. Otherwise, visualize only classes in the list.
234240
show_points_list (list): If empty, visualize all points. Otherwise, visualize only points in the list.
235-
inference_extra_args (dict/None): Dictionary with extra ultralytics inference parameters.
236-
list_of_class_colors (list/None): A list of tuples representing the colors for each class in BGR format.
237-
If provided, these colors will be used for displaying the classes instead of random colors.
238-
The number of tuples in the list must match the number of possible classes in the network.
241+
inference_extra_args (dict / None): Dictionary with extra ultralytics inference parameters.
239242
return_image_array (bool): If True, the function returns the image bgr array instead of displaying it.
240243
Default is False.
241244
@@ -266,6 +269,10 @@ def visualize_results_yolo_pose_inference(
266269
# Get the mask confidence scores
267270
confidences = pred.boxes.conf.cpu().numpy()
268271

272+
landmarks_visibility = pred.keypoints.conf.cpu().tolist()
273+
274+
landmarks_xy = pred.keypoints.xy.cpu().int().tolist()
275+
269276
num_objects = len(classes)
270277

271278
# Visualization
@@ -321,6 +328,34 @@ def visualize_results_yolo_pose_inference(
321328
thickness=thickness,
322329
)
323330

331+
if connection_schema is not None:
332+
for pair in connection_schema:
333+
start, end = pair
334+
if (
335+
landmarks_xy[i][start][0] > 0
336+
and landmarks_xy[i][start][1] > 0
337+
and landmarks_xy[i][end][0] > 0
338+
and landmarks_xy[i][end][1] > 0
339+
and landmarks_visibility[i][start] >= min_landmarks_visibility
340+
and landmarks_visibility[i][end] >= min_landmarks_visibility
341+
):
342+
x1, y1 = landmarks_xy[i][start][0], landmarks_xy[i][start][1]
343+
x2, y2 = landmarks_xy[i][end][0], landmarks_xy[i][end][1]
344+
cv2.line(labeled_image, (x1, y1), (x2, y2), color, thickness)
345+
346+
if point_radius > 0:
347+
num_point = -1
348+
for point, landmark_visibility in zip(landmarks_xy[i], landmarks_visibility[i]):
349+
num_point += 1
350+
if show_points_list and num_point not in show_points_list:
351+
continue
352+
x, y = point
353+
if x > 0 and y > 0 and landmark_visibility >= min_landmarks_visibility:
354+
if point_color is None:
355+
cv2.circle(labeled_image, (x, y), point_radius, color, -1)
356+
else:
357+
cv2.circle(labeled_image, (x, y), point_radius, point_color, -1)
358+
324359
if return_image_array:
325360
return labeled_image
326361
else:

0 commit comments

Comments
 (0)