yonigozlan (HF Staff) committed
Commit 51c9688 · 1 Parent(s): ec83bbc

initial commit

Files changed (8)
  1. .gitattributes +4 -0
  2. app.py +332 -0
  3. basket.mp4 +3 -0
  4. football.mp4 +3 -0
  5. hurdles.mp4 +3 -0
  6. render.py +125 -0
  7. requirements.txt +6 -0
  8. tennis.mp4 +3 -0
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ basket.mp4 filter=lfs diff=lfs merge=lfs -text
+ football.mp4 filter=lfs diff=lfs merge=lfs -text
+ hurdles.mp4 filter=lfs diff=lfs merge=lfs -text
+ tennis.mp4 filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,332 @@
+ import os
+
+ import gradio as gr
+ import numpy as np
+ import spaces
+ import supervision as sv
+ import torch
+ from render import draw_links, draw_points, keypoint_colors, link_colors
+ from tqdm import tqdm
+
+ from transformers import (
+     AutoProcessor,
+     RTDetrForObjectDetection,
+     VitPoseForPoseEstimation,
+ )
+
+ css = """
+ .feedback textarea {font-size: 24px !important}
+ """
+
+ device = "cuda"
+
+
+ def calculate_end_frame_index(source_video_path):
+     video_info = sv.VideoInfo.from_video_path(source_video_path)
+     return video_info.total_frames
+
+
+ @spaces.GPU
+ def process_image(
+     input_image,
+     model_variant,
+     progress=gr.Progress(track_tqdm=True),
+ ):
+     # Any person detector can be used here; RT-DETR is the default
+     person_image_processor = AutoProcessor.from_pretrained(
+         "PekingU/rtdetr_r50vd_coco_o365"
+     )
+     person_model = RTDetrForObjectDetection.from_pretrained(
+         "PekingU/rtdetr_r50vd_coco_o365", device_map=device
+     )
+
+     if model_variant == "Base":
+         model_name = "yonigozlan/synthpose-vitpose-base-hf"
+     else:
+         model_name = "yonigozlan/synthpose-vitpose-huge-hf"
+
+     image_processor = AutoProcessor.from_pretrained(model_name)
+     model = VitPoseForPoseEstimation.from_pretrained(model_name, device_map=device)
+
+     keypoint_edges = model.config.edges
+
+     # ------------------------------------------------------------------------
+     # Stage 1. Detect humans on the image
+     # ------------------------------------------------------------------------
+
+     frame = np.array(input_image)
+     inputs = person_image_processor(images=frame, return_tensors="pt").to(device)
+
+     with torch.no_grad():
+         outputs = person_model(**inputs)
+
+     results = person_image_processor.post_process_object_detection(
+         outputs,
+         target_sizes=torch.tensor([(frame.shape[0], frame.shape[1])]),
+         threshold=0.4,
+     )
+     result = results[0]  # take first image results
+
+     # The "person" class corresponds to label index 0 in the COCO label set
+     person_boxes = result["boxes"][result["labels"] == 0]
+     person_boxes = person_boxes.cpu().numpy()
+
+     # Convert boxes from VOC (x1, y1, x2, y2) to COCO (x1, y1, w, h) format
+     person_boxes[:, 2] = person_boxes[:, 2] - person_boxes[:, 0]
+     person_boxes[:, 3] = person_boxes[:, 3] - person_boxes[:, 1]
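+     # e.g. a VOC box [100, 50, 300, 450] becomes [100, 50, 200, 400] (width 200, height 400)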
+
+     # ------------------------------------------------------------------------
+     # Stage 2. Detect keypoints for each person found
+     # ------------------------------------------------------------------------
+
+     inputs = image_processor(frame, boxes=[person_boxes], return_tensors="pt").to(
+         device
+     )
+
+     with torch.no_grad():
+         outputs = model(**inputs)
+
+     pose_results = image_processor.post_process_pose_estimation(
+         outputs, boxes=[person_boxes]
+     )
+     image_pose_result = pose_results[0]  # results for first image
+
+     for pose_result in image_pose_result:
+         scores = np.array(pose_result["scores"])
+         keypoints = np.array(pose_result["keypoints"])
+
+         # draw each point on image
+         draw_points(
+             frame,
+             keypoints,
+             scores,
+             keypoint_colors,
+             keypoint_score_threshold=0.3,
+             radius=max(2, int(max(frame.shape[0], frame.shape[1]) / 500)),
+             show_keypoint_weight=False,
+         )
+
+         # draw links
+         draw_links(
+             frame,
+             keypoints,
+             scores,
+             keypoint_edges,
+             link_colors,
+             keypoint_score_threshold=0.3,
+             thickness=max(2, int(max(frame.shape[0], frame.shape[1]) / 1000)),
+             show_keypoint_weight=False,
+         )
+
+     return frame
+
+
+ @spaces.GPU
+ def process_video(
+     input_video,
+     model_variant,
+     progress=gr.Progress(track_tqdm=True),
+ ):
+     video_info = sv.VideoInfo.from_video_path(input_video)
+     total = calculate_end_frame_index(input_video)
+     frame_generator = sv.get_video_frames_generator(source_path=input_video, end=total)
+
+     result_file_name = "output.mp4"
+     result_file_path = os.path.join(os.getcwd(), result_file_name)
+     # Any person detector can be used here; RT-DETR is the default
+     person_image_processor = AutoProcessor.from_pretrained(
+         "PekingU/rtdetr_r50vd_coco_o365"
+     )
+     person_model = RTDetrForObjectDetection.from_pretrained(
+         "PekingU/rtdetr_r50vd_coco_o365", device_map=device
+     )
+     if model_variant == "Base":
+         model_name = "yonigozlan/synthpose-vitpose-base-hf"
+     else:
+         model_name = "yonigozlan/synthpose-vitpose-huge-hf"
+
+     image_processor = AutoProcessor.from_pretrained(model_name)
+     model = VitPoseForPoseEstimation.from_pretrained(model_name, device_map=device)
+
+     keypoint_edges = model.config.edges
+
+     with sv.VideoSink(result_file_path, video_info=video_info) as sink:
+         for _ in tqdm(range(total), desc="Processing video..."):
+             try:
+                 frame = next(frame_generator)
+             except StopIteration:
+                 break
+             # ------------------------------------------------------------------------
+             # Stage 1. Detect humans on the image
+             # ------------------------------------------------------------------------
+
+             inputs = person_image_processor(images=frame, return_tensors="pt").to(
+                 device
+             )
+
+             with torch.no_grad():
+                 outputs = person_model(**inputs)
+
+             results = person_image_processor.post_process_object_detection(
+                 outputs,
+                 target_sizes=torch.tensor([(frame.shape[0], frame.shape[1])]),
+                 threshold=0.4,
+             )
+             result = results[0]  # take first image results
+
+             # The "person" class corresponds to label index 0 in the COCO label set
+             person_boxes = result["boxes"][result["labels"] == 0]
+             person_boxes = person_boxes.cpu().numpy()
+
+             # Convert boxes from VOC (x1, y1, x2, y2) to COCO (x1, y1, w, h) format
+             person_boxes[:, 2] = person_boxes[:, 2] - person_boxes[:, 0]
+             person_boxes[:, 3] = person_boxes[:, 3] - person_boxes[:, 1]
+
+             # ------------------------------------------------------------------------
+             # Stage 2. Detect keypoints for each person found
+             # ------------------------------------------------------------------------
+
+             inputs = image_processor(
+                 frame, boxes=[person_boxes], return_tensors="pt"
+             ).to(device)
+
+             with torch.no_grad():
+                 outputs = model(**inputs)
+
+             pose_results = image_processor.post_process_pose_estimation(
+                 outputs, boxes=[person_boxes]
+             )
+             image_pose_result = pose_results[0]  # results for first image
+
+             for pose_result in image_pose_result:
+                 scores = np.array(pose_result["scores"])
+                 keypoints = np.array(pose_result["keypoints"])
+
+                 # draw each point on image
+                 draw_points(
+                     frame,
+                     keypoints,
+                     scores,
+                     keypoint_colors,
+                     keypoint_score_threshold=0.3,
+                     radius=max(2, int(frame.shape[0] / 500)),
+                     show_keypoint_weight=False,
+                 )
+
+                 # draw links
+                 draw_links(
+                     frame,
+                     keypoints,
+                     scores,
+                     keypoint_edges,
+                     link_colors,
+                     keypoint_score_threshold=0.3,
+                     thickness=max(1, int(frame.shape[0] / 1000)),
+                     show_keypoint_weight=False,
+                 )
+
+             sink.write_frame(frame)
+
+     return result_file_path
+
+
+ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
+     gr.Markdown("## Markerless Motion Capture with SynthPose")
+     gr.Markdown(
+         """
+ SynthPose is a new approach that uses synthetic data to finetune pre-trained 2D human pose models so they predict an arbitrarily denser set of keypoints for accurate kinematic analysis.
+ More details are available in [OpenCapBench: A Benchmark to Bridge Pose Estimation and Biomechanics](https://arxiv.org/abs/2406.09788).
+ This particular variant was finetuned on a set of keypoints typically used in motion capture setups, and it includes the COCO keypoints as well.<br />
+ The keypoints connected by the skeleton are the COCO keypoints; the pink ones are the anatomical markers.
+ """
+     )
+     gr.Markdown(
+         "Simply upload a video and press Run to start the inference! You can also try the examples below. 👇"
+     )
+
+     with gr.Row():
+         with gr.Column():
+             input_choice = gr.Radio(
+                 ["Video", "Image"], label="Input Type", value="Video", interactive=True
+             )
+             model_variant = gr.Radio(
+                 ["Base", "Huge"], label="Model Variant", value="Base", interactive=True
+             )
+             input_video = gr.Video(label="Input Video")
+             input_image = gr.Image(label="Input Image", visible=False)
+         with gr.Column():
+             output_video = gr.Video(label="Output Video")
+             output_image = gr.Image(label="Output Image", visible=False)
+
+     with gr.Row():
+         submit_video = gr.Button(variant="primary")
+         submit_image = gr.Button(variant="primary", visible=False)
+
+     def switch_input_type(input_choice):
+         # Toggle visibility of the video/image widgets and their submit buttons
+         if input_choice == "Video":
+             return [
+                 gr.update(visible=True),
+                 gr.update(visible=False),
+                 gr.update(visible=True),
+                 gr.update(visible=False),
+                 gr.update(visible=True),
+                 gr.update(visible=False),
+             ]
+         else:
+             return [
+                 gr.update(visible=False),
+                 gr.update(visible=True),
+                 gr.update(visible=False),
+                 gr.update(visible=True),
+                 gr.update(visible=False),
+                 gr.update(visible=True),
+             ]
+
+     input_choice.change(
+         switch_input_type,
+         inputs=input_choice,
+         outputs=[
+             input_video,
+             input_image,
+             output_video,
+             output_image,
+             submit_video,
+             submit_image,
+         ],
+     )
+
+     example = gr.Examples(
+         examples=[
+             ["./tennis.mp4"],
+             ["./football.mp4"],
+             ["./basket.mp4"],
+             ["./hurdles.mp4"],
+         ],
+         inputs=[input_video],
+         outputs=output_video,
+     )
+
+     submit_video.click(
+         fn=process_video,
+         inputs=[input_video, model_variant],
+         outputs=[output_video],
+     )
+     submit_image.click(
+         fn=process_image,
+         inputs=[input_image, model_variant],
+         outputs=[output_image],
+     )
+
+ if __name__ == "__main__":
+     demo.launch(show_error=True)
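
For reference, the same two-stage pipeline (RT-DETR person detection followed by SynthPose keypoint estimation) can also be run outside Gradio. The sketch below is a minimal adaptation of `process_image` above, assuming the same checkpoints, a hypothetical local image `frame.jpg`, and that Pillow is installed; it skips the drawing helpers from `render.py`.

import numpy as np
import torch
from PIL import Image

from transformers import (
    AutoProcessor,
    RTDetrForObjectDetection,
    VitPoseForPoseEstimation,
)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Stage 1: person detection with RT-DETR
det_processor = AutoProcessor.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")
det_model = RTDetrForObjectDetection.from_pretrained(
    "PekingU/rtdetr_r50vd_coco_o365", device_map=device
)

# Stage 2: dense keypoints with the SynthPose "Base" checkpoint
pose_processor = AutoProcessor.from_pretrained("yonigozlan/synthpose-vitpose-base-hf")
pose_model = VitPoseForPoseEstimation.from_pretrained(
    "yonigozlan/synthpose-vitpose-base-hf", device_map=device
)

frame = np.array(Image.open("frame.jpg").convert("RGB"))  # hypothetical input image

inputs = det_processor(images=frame, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = det_model(**inputs)
detections = det_processor.post_process_object_detection(
    outputs, target_sizes=torch.tensor([frame.shape[:2]]), threshold=0.4
)[0]

# Keep "person" boxes (label 0) and convert (x1, y1, x2, y2) -> (x, y, w, h)
boxes = detections["boxes"][detections["labels"] == 0].cpu().numpy()
boxes[:, 2] -= boxes[:, 0]
boxes[:, 3] -= boxes[:, 1]

inputs = pose_processor(frame, boxes=[boxes], return_tensors="pt").to(device)
with torch.no_grad():
    outputs = pose_model(**inputs)
poses = pose_processor.post_process_pose_estimation(outputs, boxes=[boxes])[0]

for person in poses:
    print(person["keypoints"].shape, person["scores"].shape)  # expected: 52 keypoints per person
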
basket.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:52ade15f3ec0cb1838627090d646c2c12a21dedbe70d4bd60d9ca3fa6ff45e37
+ size 9347210
football.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:56a85c5c7d5d6e0825f76a71e5e3ee2ce35c8ffbe841ef4bfa544af1089259aa
+ size 2855852
hurdles.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6ee5aa420ea2629dcefd9bb3a26221f30b4639f6de001c372d6c2f84e79b0b66
+ size 6714353
render.py ADDED
@@ -0,0 +1,125 @@
+ ### Visualization utilities for advanced users
+ import math
+
+ import cv2
+ import numpy as np
+
+
+ def draw_points(
+     image,
+     keypoints,
+     scores,
+     pose_keypoint_color,
+     keypoint_score_threshold,
+     radius,
+     show_keypoint_weight,
+ ):
+     if pose_keypoint_color is not None:
+         assert len(pose_keypoint_color) == len(keypoints)
+     for kid, (kpt, kpt_score) in enumerate(zip(keypoints, scores)):
+         x_coord, y_coord = int(kpt[0]), int(kpt[1])
+         if kpt_score > keypoint_score_threshold:
+             color = tuple(int(c) for c in pose_keypoint_color[kid])
+             if show_keypoint_weight:
+                 cv2.circle(image, (int(x_coord), int(y_coord)), radius, color, -1)
+                 transparency = max(0, min(1, kpt_score))
+                 cv2.addWeighted(
+                     image, transparency, image, 1 - transparency, 0, dst=image
+                 )
+             else:
+                 cv2.circle(image, (int(x_coord), int(y_coord)), radius, color, -1)
+
+
+ def draw_links(
+     image,
+     keypoints,
+     scores,
+     keypoint_edges,
+     link_colors,
+     keypoint_score_threshold,
+     thickness,
+     show_keypoint_weight,
+     stick_width=2,
+ ):
+     height, width, _ = image.shape
+     if keypoint_edges is not None and link_colors is not None:
+         assert len(link_colors) == len(keypoint_edges)
+         for sk_id, sk in enumerate(keypoint_edges):
+             x1, y1, score1 = (
+                 int(keypoints[sk[0], 0]),
+                 int(keypoints[sk[0], 1]),
+                 scores[sk[0]],
+             )
+             x2, y2, score2 = (
+                 int(keypoints[sk[1], 0]),
+                 int(keypoints[sk[1], 1]),
+                 scores[sk[1]],
+             )
+             if (
+                 x1 > 0
+                 and x1 < width
+                 and y1 > 0
+                 and y1 < height
+                 and x2 > 0
+                 and x2 < width
+                 and y2 > 0
+                 and y2 < height
+                 and score1 > keypoint_score_threshold
+                 and score2 > keypoint_score_threshold
+             ):
+                 color = tuple(int(c) for c in link_colors[sk_id])
+                 if show_keypoint_weight:
+                     # Draw the link as a filled, rotated ellipse between the two
+                     # joints, blended according to the mean keypoint score
+                     X = (x1, x2)
+                     Y = (y1, y2)
+                     mean_x = np.mean(X)
+                     mean_y = np.mean(Y)
+                     length = ((Y[0] - Y[1]) ** 2 + (X[0] - X[1]) ** 2) ** 0.5
+                     angle = math.degrees(math.atan2(Y[0] - Y[1], X[0] - X[1]))
+                     polygon = cv2.ellipse2Poly(
+                         (int(mean_x), int(mean_y)),
+                         (int(length / 2), int(stick_width)),
+                         int(angle),
+                         0,
+                         360,
+                         1,
+                     )
+                     cv2.fillConvexPoly(image, polygon, color)
+                     transparency = max(0, min(1, 0.5 * (score1 + score2)))
+                     cv2.addWeighted(
+                         image, transparency, image, 1 - transparency, 0, dst=image
+                     )
+                 else:
+                     cv2.line(image, (x1, y1), (x2, y2), color, thickness=thickness)
+
+
+ palette = np.array(
+     [
+         [255, 128, 0],
+         [255, 153, 51],
+         [255, 178, 102],
+         [230, 230, 0],
+         [255, 153, 255],
+         [153, 204, 255],
+         [255, 102, 255],
+         [255, 51, 255],
+         [102, 178, 255],
+         [51, 153, 255],
+         [255, 153, 153],
+         [255, 102, 102],
+         [255, 51, 51],
+         [153, 255, 153],
+         [102, 255, 102],
+         [51, 255, 51],
+         [0, 255, 0],
+         [0, 0, 255],
+         [255, 0, 0],
+         [255, 255, 255],
+     ]
+ )
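+
+ # keypoint_colors: the first 17 entries cover the standard COCO keypoints, while the
+ # remaining 35 SynthPose anatomical markers (indices 17-51) all use palette[4], the
+ # pink mentioned in the demo description; link_colors provides one color per
+ # skeleton edge in model.config.edges.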
+ link_colors = palette[[0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16]]
+ keypoint_colors = palette[
+     [16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0] + [4] * (52 - 17)
+ ]
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ torch
+ timm
+ numpy==1.26.3
+ git+https://github.com/huggingface/transformers.git@main
+ supervision
+ spaces
tennis.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cc0868023eb6fa2d68338406964396b2cb1123610fdc6af05ba37c539ee9e92a
+ size 6586057