onuralpszr committed
Commit acb8452 · verified · 1 Parent(s): 55b540a

feat: 🚀 new supervision ViTPose support, annotator improvements, and docs and Gradio UI updates for more functionality

Files changed (3)
  1. app.py +122 -39
  2. pyproject.toml +1 -1
  3. requirements.txt +2 -8
app.py CHANGED
@@ -13,10 +13,62 @@ import torch
 import tqdm
 from transformers import AutoProcessor, RTDetrForObjectDetection, VitPoseForPoseEstimation

-DESCRIPTION = "# ViTPose"
+DESCRIPTION = """
+# ViTPose
+
+<div style="display: flex; gap: 10px;">
+<a href="https://huggingface.co/docs/transformers/en/model_doc/vitpose">
+<img src="https://img.shields.io/badge/Huggingface-FFD21E?style=flat&logo=Huggingface&logoColor=black" alt="Huggingface">
+</a>
+<a href="https://arxiv.org/abs/2204.12484">
+<img src="https://img.shields.io/badge/Arvix-B31B1B?style=flat&logo=arXiv&logoColor=white" alt="Paper">
+</a>
+<a href="https://github.com/ViTAE-Transformer/ViTPose">
+<img src="https://img.shields.io/badge/Github-100000?style=flat&logo=github&logoColor=white" alt="Github">
+</a>
+</div>
+
+ViTPose is a state-of-the-art human pose estimation model based on Vision Transformers (ViT). It employs a standard, non-hierarchical ViT backbone and a simple decoder head to predict keypoint heatmaps from images. Despite its simplicity, ViTPose achieves top results on the MS COCO Keypoint Detection benchmark.
+
+ViTPose++ further improves performance with a mixture-of-experts (MoE) module and extensive pre-training. The model is scalable, flexible, and demonstrates strong transferability across pose estimation tasks.
+
+**Key features:**
+- PyTorch implementation
+- Scalable model size (100M to 1B parameters)
+- Flexible training and inference
+- State-of-the-art accuracy on challenging benchmarks
+
+"""
+
+
+COLORS = [
+    "#A351FB",
+    "#FF4040",
+    "#FFA1A0",
+    "#FF7633",
+    "#FFB633",
+    "#D1D435",
+    "#4CFB12",
+    "#94CF1A",
+    "#40DE8A",
+    "#1B9640",
+    "#00D6C1",
+    "#2E9CAA",
+    "#00C4FF",
+    "#364797",
+    "#6675FF",
+    "#0019EF",
+    "#863AFF",
+]
+COLORS = [sv.Color.from_hex(color_hex=c) for c in COLORS]

 MAX_NUM_FRAMES = 300

+keypoint_score = 0.3
+enable_labels_annotator = True
+enable_vertices_annotator = True
+
+
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

 person_detector_name = "PekingU/rtdetr_r50vd_coco_o365"
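A small aside on the new `COLORS` palette: each hex string is converted to a `supervision` `Color` via `from_hex` and later handed to `VertexLabelAnnotator(color=COLORS, ...)` so keypoint labels get per-index colors. A minimal sketch, assuming supervision >= 0.26.0 as pinned by this commit:

```python
# Minimal sketch: hex palette -> supervision Color objects (subset of COLORS above).
import supervision as sv

hex_palette = ["#A351FB", "#FF4040", "#FFA1A0"]
palette = [sv.Color.from_hex(color_hex=c) for c in hex_palette]

# A list like this is what VertexLabelAnnotator(color=COLORS, ...) consumes
# in the updated detect_pose_image().
print(palette[0].as_bgr())  # (251, 81, 163) for "#A351FB"
```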
@@ -30,11 +82,19 @@ pose_model = VitPoseForPoseEstimation.from_pretrained(pose_model_name, device_ma

 @spaces.GPU(duration=5)
 @torch.inference_mode()
-def detect_pose_image(image: PIL.Image.Image) -> tuple[PIL.Image.Image, list[dict]]:
+def detect_pose_image(
+    image: PIL.Image.Image,
+    threshold: float = 0.3,
+    enable_labels_annotator: bool = True,
+    enable_vertices_annotator: bool = True,
+) -> tuple[PIL.Image.Image, list[dict]]:
     """Detects persons and estimates their poses in a single image.

     Args:
         image (PIL.Image.Image): Input image in which to detect persons and estimate poses.
+        threshold (Float): Confidence threshold for pose keypoints.
+        enable_labels_annotator (bool): Whether to enable annotating labels for pose keypoints.
+        enable_vertices_annotator (bool): Whether to enable annotating vertices for pose keypoints

     Returns:
         tuple[PIL.Image.Image, list[dict]]:
@@ -44,20 +104,14 @@ def detect_pose_image(image: PIL.Image.Image) -> tuple[PIL.Image.Image, list[dict]]:
     inputs = person_image_processor(images=image, return_tensors="pt").to(device)
     outputs = person_model(**inputs)
     results = person_image_processor.post_process_object_detection(
-        outputs, target_sizes=torch.tensor([(image.height, image.width)]), threshold=0.3
+        outputs, target_sizes=torch.tensor([(image.height, image.width)]), threshold=threshold
     )
     result = results[0]  # take first image results

-    # Human label refers 0 index in COCO dataset
-    person_boxes_xyxy = result["boxes"][result["labels"] == 0]
-    person_boxes_xyxy = person_boxes_xyxy.cpu().numpy()
-
-    # Convert boxes from VOC (x1, y1, x2, y2) to COCO (x1, y1, w, h) format
-    person_boxes = person_boxes_xyxy.copy()
-    person_boxes[:, 2] = person_boxes[:, 2] - person_boxes[:, 0]
-    person_boxes[:, 3] = person_boxes[:, 3] - person_boxes[:, 1]
+    detections = sv.Detections.from_transformers(result)
+    person_detections_xywh = sv.xyxy_to_xywh(detections[detections.class_id == 0].xyxy)

-    inputs = pose_image_processor(image, boxes=[person_boxes], return_tensors="pt").to(device)
+    inputs = pose_image_processor(image, boxes=[person_detections_xywh], return_tensors="pt").to(device)

     # for vitpose-plus-base checkpoint we should additionally provide dataset_index
     # to specify which MOE experts to use for inference
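The removed hand-rolled conversion and the new `sv.xyxy_to_xywh` call do the same thing: turn `(x1, y1, x2, y2)` boxes into the `(x, y, w, h)` format that `pose_image_processor` expects, with class id 0 selecting the COCO "person" class. A NumPy-only sketch of that equivalence (the box values are made up):

```python
# Illustrative only: the (x1, y1, x2, y2) -> (x, y, w, h) conversion that the
# removed lines did by hand and that sv.xyxy_to_xywh performs in the new code.
import numpy as np

xyxy = np.array([[10.0, 20.0, 110.0, 220.0]])  # hypothetical person box

xywh = xyxy.copy()
xywh[:, 2] = xyxy[:, 2] - xyxy[:, 0]  # width  = x2 - x1
xywh[:, 3] = xyxy[:, 3] - xyxy[:, 1]  # height = y2 - y1

print(xywh)  # [[ 10.  20. 100. 200.]] -- the format the pose processor expects
```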
@@ -68,11 +122,12 @@ def detect_pose_image(image: PIL.Image.Image) -> tuple[PIL.Image.Image, list[dict]]:

     outputs = pose_model(**inputs)

-    pose_results = pose_image_processor.post_process_pose_estimation(outputs, boxes=[person_boxes])
+    pose_results = pose_image_processor.post_process_pose_estimation(outputs, boxes=[person_detections_xywh])
     image_pose_result = pose_results[0]  # results for first image

     # make results more human-readable
     human_readable_results = []
+    person_pose_labels = []
     for i, person_pose in enumerate(image_pose_result):
         data = {
             "person_id": i,
@@ -83,43 +138,55 @@ def detect_pose_image(image: PIL.Image.Image) -> tuple[PIL.Image.Image, list[dict]]:
             person_pose["keypoints"], person_pose["labels"], person_pose["scores"], strict=True
         ):
             keypoint_name = pose_model.config.id2label[label.item()]
+            person_pose_labels.append(keypoint_name)
             x, y = keypoint
             data["keypoints"].append({"name": keypoint_name, "x": x.item(), "y": y.item(), "score": score.item()})
         human_readable_results.append(data)

-    # preprocess to torch tensor of shape (n_objects, n_keypoints, 2)
-    xy = [pose_result["keypoints"] for pose_result in image_pose_result]
-    xy = torch.stack(xy).cpu().numpy()
-
-    scores = [pose_result["scores"] for pose_result in image_pose_result]
-    scores = torch.stack(scores).cpu().numpy()
-
-    keypoints = sv.KeyPoints(xy=xy, confidence=scores)
-    detections = sv.Detections(xyxy=person_boxes_xyxy)
+    line_thickness = sv.calculate_optimal_line_thickness(resolution_wh=(image.width, image.height))
+    text_scale = sv.calculate_optimal_text_scale(resolution_wh=(image.width, image.height))

-    edge_annotator = sv.EdgeAnnotator(color=sv.Color.GREEN, thickness=1)
-    vertex_annotator = sv.VertexAnnotator(color=sv.Color.RED, radius=2)
-    bounding_box_annotator = sv.BoxAnnotator(color=sv.Color.WHITE, color_lookup=sv.ColorLookup.INDEX, thickness=1)
+    edge_annotator = sv.EdgeAnnotator(color=sv.Color.WHITE, thickness=line_thickness)
+    vertex_annotator = sv.VertexAnnotator(color=sv.Color.BLUE, radius=3)
+    box_annotator = sv.BoxAnnotator(color=sv.Color.WHITE, color_lookup=sv.ColorLookup.INDEX, thickness=3)

-    annotated_frame = image.copy()
+    vertex_label_annotator = sv.VertexLabelAnnotator(
+        color=COLORS, smart_position=True, border_radius=3, text_thickness=2, text_scale=text_scale
+    )

-    # annotate bounding boxes
-    annotated_frame = bounding_box_annotator.annotate(scene=image.copy(), detections=detections)
+    annotated_frame = box_annotator.annotate(scene=image.copy(), detections=detections)
+
+    for _, person_pose in enumerate(image_pose_result):
+        person_keypoints = sv.KeyPoints.from_transformers([person_pose])
+        person_labels = [pose_model.config.id2label[label.item()] for label in person_pose["labels"]]
+        # annotate edges and vertices for this person
+        annotated_frame = edge_annotator.annotate(scene=annotated_frame, key_points=person_keypoints)
+        # annotate labels for this person
+        if enable_labels_annotator:
+            annotated_frame = vertex_label_annotator.annotate(
+                scene=np.array(annotated_frame), key_points=person_keypoints, labels=person_labels
+            )
+        # annotate vertices for this person
+        if enable_vertices_annotator:
+            annotated_frame = vertex_annotator.annotate(scene=annotated_frame, key_points=person_keypoints)

-    # annotate edges and vertices
-    annotated_frame = edge_annotator.annotate(scene=annotated_frame, key_points=keypoints)
-    return vertex_annotator.annotate(scene=annotated_frame, key_points=keypoints), human_readable_results
+    return annotated_frame, human_readable_results


-@spaces.GPU(duration=90)
 def detect_pose_video(
     video_path: str,
+    threshold: float,
+    enable_labels_annotator: bool = True,
+    enable_vertices_annotator: bool = True,
     progress: gr.Progress = gr.Progress(track_tqdm=True),  # noqa: ARG001, B008
 ) -> str:
     """Detects persons and estimates their poses for each frame in a video, saving the annotated video.

     Args:
         video_path (str): Path to the input video file.
+        threshold (Float): Confidence threshold for pose keypoints.
+        enable_labels_annotator (bool): Whether to enable annotating labels for pose keypoints.
+        enable_vertices_annotator (bool): Whether to enable annotating vertices for pose keypoints.
         progress (gr.Progress, optional): Gradio progress tracker. Defaults to gr.Progress(track_tqdm=True).

     Returns:
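The loop above builds one `sv.KeyPoints` object per detected person via `KeyPoints.from_transformers` and runs it through the edge, vertex, and label annotators. A self-contained sketch of that flow with dummy data (an `sv.KeyPoints` constructed directly, with 17 COCO-style keypoints so the default skeleton resolves; no model required):

```python
# Standalone sketch of the supervision keypoint-annotation flow used above.
# Dummy data stands in for the model outputs; annotator settings mirror the diff.
import numpy as np
import supervision as sv

scene = np.zeros((480, 640, 3), dtype=np.uint8)  # placeholder image (H, W, BGR)

# One "person" with 17 random keypoints (COCO layout) and full confidence.
xy = np.random.default_rng(0).uniform(low=100, high=400, size=(1, 17, 2)).astype(np.float32)
confidence = np.ones((1, 17), dtype=np.float32)
keypoints = sv.KeyPoints(xy=xy, confidence=confidence)

edge_annotator = sv.EdgeAnnotator(color=sv.Color.WHITE, thickness=2)
vertex_annotator = sv.VertexAnnotator(color=sv.Color.BLUE, radius=3)

annotated = edge_annotator.annotate(scene=scene.copy(), key_points=keypoints)
annotated = vertex_annotator.annotate(scene=annotated, key_points=keypoints)
```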
@@ -140,7 +207,12 @@ def detect_pose_video(
         if not ok:
             break
         rgb_frame = frame[:, :, ::-1]
-        annotated_frame, _ = detect_pose_image(PIL.Image.fromarray(rgb_frame))
+        annotated_frame, _ = detect_pose_image(
+            PIL.Image.fromarray(rgb_frame),
+            threshold=threshold,
+            enable_labels_annotator=enable_labels_annotator,
+            enable_vertices_annotator=enable_vertices_annotator,
+        )
         writer.write(np.asarray(annotated_frame)[:, :, ::-1])
     writer.release()
     cap.release()
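A note on the `[:, :, ::-1]` slices kept in this hunk: `cv2.VideoCapture` yields BGR frames while PIL and the annotated output are RGB, so channels are reversed on the way in and again before writing. A minimal sketch of that round trip (the file name is hypothetical; requires opencv-python, Pillow, and numpy):

```python
# Minimal sketch of the BGR <-> RGB round trip in the video loop above.
import cv2
import numpy as np
import PIL.Image

cap = cv2.VideoCapture("input.mp4")  # hypothetical path; frames come back as BGR
ok, frame_bgr = cap.read()
if ok:
    rgb = np.ascontiguousarray(frame_bgr[:, :, ::-1])  # BGR -> RGB (contiguous copy for PIL)
    pil_image = PIL.Image.fromarray(rgb)
    # ... run detect_pose_image(pil_image, ...) here ...
    out_bgr = np.asarray(pil_image)[:, :, ::-1]  # RGB -> BGR before cv2.VideoWriter.write
cap.release()
```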
@@ -150,6 +222,17 @@
 with gr.Blocks(css_paths="style.css") as demo:
     gr.Markdown(DESCRIPTION)

+    keypoint_score = gr.Slider(
+        minimum=0.0,
+        maximum=1.0,
+        value=0.6,
+        step=0.01,
+        info="Adjust the confidence threshold for keypoint detection.",
+        label="Keypoint Score Threshold",
+    )
+    enable_labels_annotator = gr.Checkbox(interactive=True, value=True, label="Enable Labels")
+    enable_vertices_annotator = gr.Checkbox(interactive=True, value=True, label="Enable Vertices")
+
     with gr.Tabs():
         with gr.Tab("Image"):
             with gr.Row():
@@ -160,15 +243,15 @@ with gr.Blocks(css_paths="style.css") as demo:
                    output_image = gr.Image(label="Output Image")
                    output_json = gr.JSON(label="Output JSON")
            gr.Examples(
-                examples=sorted(pathlib.Path("images").glob("*.jpg")),
-                inputs=input_image,
+                examples=[[str(img), 0.5, True, True] for img in sorted(pathlib.Path("images").glob("*.jpg"))],
+                inputs=[input_image, keypoint_score, enable_labels_annotator, enable_vertices_annotator],
                outputs=[output_image, output_json],
                fn=detect_pose_image,
            )

            run_button_image.click(
                fn=detect_pose_image,
-                inputs=input_image,
+                inputs=[input_image, keypoint_score, enable_labels_annotator, enable_vertices_annotator],
                outputs=[output_image, output_json],
            )

@@ -183,15 +266,15 @@ with gr.Blocks(css_paths="style.css") as demo:
                    output_video = gr.Video(label="Output Video")

            gr.Examples(
-                examples=sorted(pathlib.Path("videos").glob("*.mp4")),
-                inputs=input_video,
+                examples=[[str(video), 0.5, True, True] for video in sorted(pathlib.Path("videos").glob("*.mp4"))],
+                inputs=[input_video, keypoint_score, enable_labels_annotator, enable_vertices_annotator],
                outputs=output_video,
                fn=detect_pose_video,
                cache_examples=False,
            )
            run_button_video.click(
                fn=detect_pose_video,
-                inputs=input_video,
+                inputs=[input_video, keypoint_score, enable_labels_annotator, enable_vertices_annotator],
                outputs=output_video,
            )

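For context on the UI wiring in the two hunks above: Gradio passes the components listed in `inputs` to `fn` positionally, and each `examples` row must provide one value per input component, which is why every example row now carries `0.5, True, True` alongside the file path. A stripped-down sketch of the same pattern (function and component names here are illustrative, not taken from app.py):

```python
# Stripped-down sketch of the inputs/examples wiring pattern used above.
import gradio as gr

def annotate(image, threshold: float, show_labels: bool, show_vertices: bool):
    # stand-in for detect_pose_image(...)
    return image

with gr.Blocks() as demo:
    image_in = gr.Image(label="Input")
    threshold = gr.Slider(0.0, 1.0, value=0.6, step=0.01, label="Keypoint Score Threshold")
    show_labels = gr.Checkbox(value=True, label="Enable Labels")
    show_vertices = gr.Checkbox(value=True, label="Enable Vertices")
    image_out = gr.Image(label="Output")
    run = gr.Button("Run")

    # Values arrive at annotate() in the same order as this inputs list.
    run.click(fn=annotate, inputs=[image_in, threshold, show_labels, show_vertices], outputs=image_out)

demo.launch()
```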
 
 
pyproject.toml CHANGED
@@ -10,7 +10,7 @@ dependencies = [
     "hf-transfer>=0.1.9",
     "setuptools>=80.9.0",
     "spaces>=0.37.1",
-    "supervision>=0.25.1",
+    "supervision>=0.26.0",
     "torch==2.5.1",
     "transformers>=4.53.0",
 ]
requirements.txt CHANGED
@@ -25,15 +25,11 @@ click==8.1.8
     # typer
     # uvicorn
 contourpy==1.3.1
-    # via
-    # matplotlib
-    # supervision
+    # via matplotlib
 cycler==0.12.1
     # via matplotlib
 defusedxml==0.7.1
     # via supervision
-exceptiongroup==1.2.2
-    # via anyio
 fastapi==0.115.7
     # via gradio
 ffmpy==0.5.0
@@ -254,7 +250,7 @@ starlette==0.45.3
     # fastapi
     # gradio
     # mcp
-supervision==0.25.1
+supervision==0.26.0
     # via vitpose-transformers (pyproject.toml)
 sympy==1.13.1
     # via torch
@@ -286,12 +282,10 @@ typing-extensions==4.12.2
     # huggingface-hub
     # pydantic
     # pydantic-core
-    # rich
     # spaces
     # torch
     # typer
     # typing-inspection
-    # uvicorn
 typing-inspection==0.4.1
     # via
     # pydantic