import os
import gradio as gr
import numpy as np
import spaces
import supervision as sv
import torch
from render import draw_links, draw_points, keypoint_colors, link_colors
from tqdm import tqdm
from transformers import (
AutoProcessor,
RTDetrForObjectDetection,
VitPoseForPoseEstimation,
)
css = """
.feedback textarea {font-size: 24px !important}
"""
device = "cuda"
def calculate_end_frame_index(source_video_path):
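    """Return the total number of frames in the source video."""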
video_info = sv.VideoInfo.from_video_path(source_video_path)
return video_info.total_frames
@spaces.GPU
def process_image(
input_image,
model_variant,
progress=gr.Progress(track_tqdm=True),
):
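    """Run two-stage pose estimation on a single image.

    Stage 1 detects people with RT-DETR; stage 2 runs the selected SynthPose
    (ViTPose) checkpoint on each detected person and draws the keypoints and
    skeleton links onto the frame.
    """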
    # Any person detector can be used here; RT-DETR is used by default.
person_image_processor = AutoProcessor.from_pretrained(
"PekingU/rtdetr_r50vd_coco_o365"
)
person_model = RTDetrForObjectDetection.from_pretrained(
"PekingU/rtdetr_r50vd_coco_o365", device_map=device
)
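    # SynthPose ships as two ViTPose checkpoints: a Base and a Huge backbone.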
if model_variant == "Base":
model_name = "yonigozlan/synthpose-vitpose-base-hf"
else:
model_name = "yonigozlan/synthpose-vitpose-huge-hf"
image_processor = AutoProcessor.from_pretrained(model_name)
model = VitPoseForPoseEstimation.from_pretrained(model_name, device_map=device)
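    # Skeleton definition: pairs of keypoint indices to connect when drawing links.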
keypoint_edges = model.config.edges
frame = np.array(input_image)
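    # ------------------------------------------------------------------------
    # Stage 1. Detect humans on the image
    # ------------------------------------------------------------------------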
inputs = person_image_processor(images=frame, return_tensors="pt").to(device)
with torch.no_grad():
outputs = person_model(**inputs)
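    # Rescale boxes to the original image size and keep detections scoring above 0.4.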
results = person_image_processor.post_process_object_detection(
outputs,
target_sizes=torch.tensor([(frame.shape[0], frame.shape[1])]),
threshold=0.4,
)
result = results[0] # take first image results
    # The person class has label index 0 in the COCO label set.
person_boxes = result["boxes"][result["labels"] == 0]
person_boxes = person_boxes.cpu().numpy()
# Convert boxes from VOC (x1, y1, x2, y2) to COCO (x1, y1, w, h) format
person_boxes[:, 2] = person_boxes[:, 2] - person_boxes[:, 0]
person_boxes[:, 3] = person_boxes[:, 3] - person_boxes[:, 1]
    # ------------------------------------------------------------------------
    # Stage 2. Detect keypoints for each person found
    # ------------------------------------------------------------------------
    if len(person_boxes) == 0:
        # No person detected: return the input image unchanged.
        return frame
inputs = image_processor(frame, boxes=[person_boxes], return_tensors="pt").to(
device
)
with torch.no_grad():
outputs = model(**inputs)
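    # Map the predicted keypoints back to original image coordinates using the person boxes.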
pose_results = image_processor.post_process_pose_estimation(
outputs, boxes=[person_boxes]
)
image_pose_result = pose_results[0] # results for first image
for pose_result in image_pose_result:
scores = np.array(pose_result["scores"])
keypoints = np.array(pose_result["keypoints"])
# draw each point on image
draw_points(
frame,
keypoints,
scores,
keypoint_colors,
keypoint_score_threshold=0.3,
radius=max(2, int(max(frame.shape[0], frame.shape[1]) / 500)),
show_keypoint_weight=False,
)
# draw links
draw_links(
frame,
keypoints,
scores,
keypoint_edges,
link_colors,
keypoint_score_threshold=0.3,
thickness=max(2, int(max(frame.shape[0], frame.shape[1]) / 1000)),
show_keypoint_weight=False,
)
return frame
@spaces.GPU
def process_video(
input_video,
model_variant,
progress=gr.Progress(track_tqdm=True),
):
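    """Run the same two-stage pose pipeline frame by frame over a video.

    Each frame is annotated in place and written to output.mp4 in the current
    working directory.
    """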
video_info = sv.VideoInfo.from_video_path(input_video)
total = calculate_end_frame_index(input_video)
frame_generator = sv.get_video_frames_generator(source_path=input_video, end=total)
result_file_name = "output.mp4"
result_file_path = os.path.join(os.getcwd(), result_file_name)
    # Any person detector can be used here; RT-DETR is used by default.
person_image_processor = AutoProcessor.from_pretrained(
"PekingU/rtdetr_r50vd_coco_o365"
)
person_model = RTDetrForObjectDetection.from_pretrained(
"PekingU/rtdetr_r50vd_coco_o365", device_map=device
)
if model_variant == "Base":
model_name = "yonigozlan/synthpose-vitpose-base-hf"
else:
model_name = "yonigozlan/synthpose-vitpose-huge-hf"
image_processor = AutoProcessor.from_pretrained(model_name)
model = VitPoseForPoseEstimation.from_pretrained(model_name, device_map=device)
keypoint_edges = model.config.edges
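    # Stream annotated frames into the output file, reusing the source video's resolution and FPS.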
with sv.VideoSink(result_file_path, video_info=video_info) as sink:
        for _ in tqdm(range(total), desc="Processing video..."):
try:
frame = next(frame_generator)
except StopIteration:
break
# ------------------------------------------------------------------------
# Stage 1. Detect humans on the image
# ------------------------------------------------------------------------
inputs = person_image_processor(images=frame, return_tensors="pt").to(
device
)
with torch.no_grad():
outputs = person_model(**inputs)
results = person_image_processor.post_process_object_detection(
outputs,
target_sizes=torch.tensor([(frame.shape[0], frame.shape[1])]),
threshold=0.4,
)
result = results[0] # take first image results
            # The person class has label index 0 in the COCO label set.
person_boxes = result["boxes"][result["labels"] == 0]
person_boxes = person_boxes.cpu().numpy()
# Convert boxes from VOC (x1, y1, x2, y2) to COCO (x1, y1, w, h) format
person_boxes[:, 2] = person_boxes[:, 2] - person_boxes[:, 0]
person_boxes[:, 3] = person_boxes[:, 3] - person_boxes[:, 1]
# ------------------------------------------------------------------------
# Stage 2. Detect keypoints for each person found
# ------------------------------------------------------------------------
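            # No person detected in this frame: write it unchanged and move on.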
if len(person_boxes) == 0:
sink.write_frame(frame)
continue
inputs = image_processor(
frame, boxes=[person_boxes], return_tensors="pt"
).to(device)
with torch.no_grad():
outputs = model(**inputs)
pose_results = image_processor.post_process_pose_estimation(
outputs, boxes=[person_boxes]
)
image_pose_result = pose_results[0] # results for first image
for pose_result in image_pose_result:
scores = np.array(pose_result["scores"])
keypoints = np.array(pose_result["keypoints"])
# draw each point on image
draw_points(
frame,
keypoints,
scores,
keypoint_colors,
keypoint_score_threshold=0.3,
radius=max(2, int(frame.shape[0] / 500)),
show_keypoint_weight=False,
)
# draw links
draw_links(
frame,
keypoints,
scores,
keypoint_edges,
link_colors,
keypoint_score_threshold=0.3,
thickness=max(1, int(frame.shape[0] / 1000)),
show_keypoint_weight=False,
)
sink.write_frame(frame)
return result_file_path
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
gr.Markdown("## Markerless Motion Capture with SynthPose")
gr.Markdown(
"""
SynthPose is a new approach that uses synthetic data to finetune pre-trained 2D human pose models to predict an arbitrarily denser set of keypoints for accurate kinematic analysis.
More details are available in [OpenCapBench: A Benchmark to Bridge Pose Estimation and Biomechanics](https://arxiv.org/abs/2406.09788).<br />
This particular variant was finetuned on a set of keypoints commonly used in motion capture setups, and includes the COCO keypoints as well.<br />
The keypoints connected by the skeleton links are the COCO keypoints, and the pink ones are the anatomical markers.
"""
)
gr.Markdown(
"Simply upload a video, and press run to start the inference! You can also try the examples below. πŸ‘‡"
)
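    # Two tabs: one for videos, one for single images.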
with gr.Tabs():
with gr.Tab("Video"):
with gr.Row():
with gr.Column():
model_variant = gr.Radio(
["Base", "Huge"],
label="Model Variant",
value="Base",
interactive=True,
)
input_video = gr.Video(label="Input Video")
with gr.Column():
output_video = gr.Video(label="Output Video")
with gr.Row():
submit_video = gr.Button(variant="primary")
example = gr.Examples(
examples=[
["./tennis.mp4"],
["./football.mp4"],
["./basket.mp4"],
["./hurdles.mp4"],
],
inputs=[input_video],
outputs=output_video,
)
submit_video.click(
fn=process_video,
inputs=[input_video, model_variant],
outputs=[output_video],
)
with gr.Tab("Image"):
with gr.Row():
with gr.Column():
model_variant = gr.Radio(
["Base", "Huge"],
label="Model Variant",
value="Base",
interactive=True,
)
input_image = gr.Image(label="Input Image")
with gr.Column():
output_image = gr.Image(label="Output Image")
with gr.Row():
submit_image = gr.Button(variant="primary")
example_image = gr.Examples(
examples=[
["demo.jpeg"],
],
inputs=[input_image],
outputs=output_image,
)
submit_image.click(
fn=process_image,
inputs=[input_image, model_variant],
outputs=[output_image],
)
if __name__ == "__main__":
demo.launch(show_error=True)