import os

import gradio as gr
import numpy as np
import spaces
import supervision as sv
import torch
from render import draw_links, draw_points, keypoint_colors, link_colors
from tqdm import tqdm
from transformers import (
    AutoProcessor,
    RTDetrForObjectDetection,
    VitPoseForPoseEstimation,
)

css = """
.feedback textarea {font-size: 24px !important}
"""

device = "cuda"
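
# Note: on a ZeroGPU Space the GPU is only attached while a function decorated
# with @spaces.GPU is running, so "cuda" is assumed to be available inside
# process_image / process_video below.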


def calculate_end_frame_index(source_video_path):
    video_info = sv.VideoInfo.from_video_path(source_video_path)
    return video_info.total_frames


@spaces.GPU  # attach a ZeroGPU device for the duration of this call
def process_image(
    input_image,
    model_variant,
    progress=gr.Progress(track_tqdm=True),
):
    # You can swap in a different person detector here; RT-DETR is used by default.
    person_image_processor = AutoProcessor.from_pretrained(
        "PekingU/rtdetr_r50vd_coco_o365"
    )
    person_model = RTDetrForObjectDetection.from_pretrained(
        "PekingU/rtdetr_r50vd_coco_o365", device_map=device
    )
    if model_variant == "Base":
        model_name = "yonigozlan/synthpose-vitpose-base-hf"
    else:
        model_name = "yonigozlan/synthpose-vitpose-huge-hf"
    image_processor = AutoProcessor.from_pretrained(model_name)
    model = VitPoseForPoseEstimation.from_pretrained(model_name, device_map=device)
    keypoint_edges = model.config.edges
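
    # Note: both models are re-instantiated on every call; the weights are cached
    # on disk after the first download, but they could also be loaded once at
    # module level if persistent GPU memory were available.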

    frame = np.array(input_image)
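
    # ------------------------------------------------------------------------
    # Stage 1. Detect humans on the image
    # ------------------------------------------------------------------------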
    inputs = person_image_processor(images=frame, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = person_model(**inputs)

    results = person_image_processor.post_process_object_detection(
        outputs,
        target_sizes=torch.tensor([(frame.shape[0], frame.shape[1])]),
        threshold=0.4,
    )
    result = results[0]  # take first image results

    # Label 0 corresponds to the "person" class in the COCO label set
    person_boxes = result["boxes"][result["labels"] == 0]
    person_boxes = person_boxes.cpu().numpy()

    # Convert boxes from VOC (x1, y1, x2, y2) to COCO (x1, y1, w, h) format
    person_boxes[:, 2] = person_boxes[:, 2] - person_boxes[:, 0]
    person_boxes[:, 3] = person_boxes[:, 3] - person_boxes[:, 1]

    # ------------------------------------------------------------------------
    # Stage 2. Detect keypoints for each person found
    # ------------------------------------------------------------------------
    # No person detected: return the input frame unchanged (mirrors the video path)
    if len(person_boxes) == 0:
        return frame

    inputs = image_processor(frame, boxes=[person_boxes], return_tensors="pt").to(
        device
    )
    with torch.no_grad():
        outputs = model(**inputs)

    pose_results = image_processor.post_process_pose_estimation(
        outputs, boxes=[person_boxes]
    )
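    # Each entry of pose_results[0] corresponds to one detected person and exposes
    # "keypoints" (x, y pixel coordinates) and "scores" (per-keypoint confidences),
    # which the drawing helpers below consume.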
    image_pose_result = pose_results[0]  # results for first image

    for pose_result in image_pose_result:
        scores = np.array(pose_result["scores"])
        keypoints = np.array(pose_result["keypoints"])

        # draw each point on image
        draw_points(
            frame,
            keypoints,
            scores,
            keypoint_colors,
            keypoint_score_threshold=0.3,
            radius=max(2, int(max(frame.shape[0], frame.shape[1]) / 500)),
            show_keypoint_weight=False,
        )

        # draw links
        draw_links(
            frame,
            keypoints,
            scores,
            keypoint_edges,
            link_colors,
            keypoint_score_threshold=0.3,
            thickness=max(2, int(max(frame.shape[0], frame.shape[1]) / 1000)),
            show_keypoint_weight=False,
        )

    return frame


@spaces.GPU  # attach a ZeroGPU device; long videos may need a larger duration
def process_video(
    input_video,
    model_variant,
    progress=gr.Progress(track_tqdm=True),
):
    video_info = sv.VideoInfo.from_video_path(input_video)
    total = calculate_end_frame_index(input_video)
    frame_generator = sv.get_video_frames_generator(source_path=input_video, end=total)

    result_file_name = "output.mp4"
    result_file_path = os.path.join(os.getcwd(), result_file_name)

    # You can swap in a different person detector here; RT-DETR is used by default.
    person_image_processor = AutoProcessor.from_pretrained(
        "PekingU/rtdetr_r50vd_coco_o365"
    )
    person_model = RTDetrForObjectDetection.from_pretrained(
        "PekingU/rtdetr_r50vd_coco_o365", device_map=device
    )
    if model_variant == "Base":
        model_name = "yonigozlan/synthpose-vitpose-base-hf"
    else:
        model_name = "yonigozlan/synthpose-vitpose-huge-hf"
    image_processor = AutoProcessor.from_pretrained(model_name)
    model = VitPoseForPoseEstimation.from_pretrained(model_name, device_map=device)
    keypoint_edges = model.config.edges

    with sv.VideoSink(result_file_path, video_info=video_info) as sink:
        for _ in tqdm(range(total), desc="Processing video..."):
            try:
                frame = next(frame_generator)
            except StopIteration:
                break

            # ------------------------------------------------------------------
            # Stage 1. Detect humans on the image
            # ------------------------------------------------------------------
            inputs = person_image_processor(images=frame, return_tensors="pt").to(
                device
            )
            with torch.no_grad():
                outputs = person_model(**inputs)

            results = person_image_processor.post_process_object_detection(
                outputs,
                target_sizes=torch.tensor([(frame.shape[0], frame.shape[1])]),
                threshold=0.4,
            )
            result = results[0]  # take first image results

            # Label 0 corresponds to the "person" class in the COCO label set
            person_boxes = result["boxes"][result["labels"] == 0]
            person_boxes = person_boxes.cpu().numpy()

            # Convert boxes from VOC (x1, y1, x2, y2) to COCO (x1, y1, w, h) format
            person_boxes[:, 2] = person_boxes[:, 2] - person_boxes[:, 0]
            person_boxes[:, 3] = person_boxes[:, 3] - person_boxes[:, 1]

            # ------------------------------------------------------------------
            # Stage 2. Detect keypoints for each person found
            # ------------------------------------------------------------------
            if len(person_boxes) == 0:
                sink.write_frame(frame)
                continue

            inputs = image_processor(
                frame, boxes=[person_boxes], return_tensors="pt"
            ).to(device)
            with torch.no_grad():
                outputs = model(**inputs)

            pose_results = image_processor.post_process_pose_estimation(
                outputs, boxes=[person_boxes]
            )
            image_pose_result = pose_results[0]  # results for first image

            for pose_result in image_pose_result:
                scores = np.array(pose_result["scores"])
                keypoints = np.array(pose_result["keypoints"])

                # draw each point on image
                draw_points(
                    frame,
                    keypoints,
                    scores,
                    keypoint_colors,
                    keypoint_score_threshold=0.3,
                    radius=max(2, int(frame.shape[0] / 500)),
                    show_keypoint_weight=False,
                )

                # draw links
                draw_links(
                    frame,
                    keypoints,
                    scores,
                    keypoint_edges,
                    link_colors,
                    keypoint_score_threshold=0.3,
                    thickness=max(1, int(frame.shape[0] / 1000)),
                    show_keypoint_weight=False,
                )

            sink.write_frame(frame)

    return result_file_path
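

# ---------------------------------------------------------------------------
# Gradio UI: two tabs (Video and Image), each with its own model-variant selector
# ---------------------------------------------------------------------------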
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown("## Markerless Motion Capture with SynthPose")
    gr.Markdown(
        """
        SynthPose is a new approach that uses synthetic data to finetune pre-trained 2D human pose models so they predict an arbitrarily denser set of keypoints for accurate kinematic analysis.
        More details are available in [OpenCapBench: A Benchmark to Bridge Pose Estimation and Biomechanics](https://arxiv.org/abs/2406.09788).<br />
        This particular variant was finetuned on a set of keypoints usually found in motion capture setups, and it includes the COCO keypoints as well.<br />
        The keypoints connected by the skeleton are the COCO keypoints; the pink ones are the anatomical markers.
        """
    )
    gr.Markdown(
        "Simply upload a video or an image and press Run to start inference! You can also try the examples below."
    )
    with gr.Tabs():
        with gr.Tab("Video"):
            with gr.Row():
                with gr.Column():
                    model_variant = gr.Radio(
                        ["Base", "Huge"],
                        label="Model Variant",
                        value="Base",
                        interactive=True,
                    )
                    input_video = gr.Video(label="Input Video")
                with gr.Column():
                    output_video = gr.Video(label="Output Video")
            with gr.Row():
                submit_video = gr.Button(variant="primary")
            example = gr.Examples(
                examples=[
                    ["./tennis.mp4"],
                    ["./football.mp4"],
                    ["./basket.mp4"],
                    ["./hurdles.mp4"],
                ],
                inputs=[input_video],
                outputs=output_video,
            )
            submit_video.click(
                fn=process_video,
                inputs=[input_video, model_variant],
                outputs=[output_video],
            )
        with gr.Tab("Image"):
            with gr.Row():
                with gr.Column():
                    model_variant = gr.Radio(
                        ["Base", "Huge"],
                        label="Model Variant",
                        value="Base",
                        interactive=True,
                    )
                    input_image = gr.Image(label="Input Image")
                with gr.Column():
                    output_image = gr.Image(label="Output Image")
            with gr.Row():
                submit_image = gr.Button(variant="primary")
            example_image = gr.Examples(
                examples=[
                    ["demo.jpeg"],
                ],
                inputs=[input_image],
                outputs=output_image,
            )
            submit_image.click(
                fn=process_image,
                inputs=[input_image, model_variant],
                outputs=[output_image],
            )


if __name__ == "__main__":
    demo.launch(show_error=True)