nicolasbuitragob commited on
Commit
478a673
1 Parent(s): d1dd306

dead lift exercise

Browse files
Files changed (4) hide show
  1. .gitignore +2 -1
  2. app.py +45 -1
  3. src/exercises/dead_lift.py +316 -0
  4. tasks.py +2 -0
.gitignore CHANGED
@@ -175,4 +175,5 @@ cython_debug/
175
 
176
  static/
177
  *.mp4
178
- *.MOV
 
 
175
 
176
  static/
177
  *.mp4
178
+ *.MOV
179
+ *.json
app.py CHANGED
@@ -14,6 +14,7 @@ from fastapi.responses import JSONResponse
14
  from config import AI_API_TOKEN
15
  import logging
16
  import json
 
17
 
18
 
19
  logging.basicConfig(level=logging.INFO)
@@ -73,7 +74,7 @@ async def upload(background_tasks: BackgroundTasks,
73
  token: str = Header(...),
74
  player_data: str = Body(...),
75
  repetitions: int|str = Body(...),
76
- exercise_id: str = Body(...)
77
  ):
78
 
79
 
@@ -106,3 +107,46 @@ async def upload(background_tasks: BackgroundTasks,
106
  return JSONResponse(content={"message": "Video uploaded successfully",
107
  "status": 200})
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  from config import AI_API_TOKEN
15
  import logging
16
  import json
17
+ from src.exercises.dead_lift import analyze_dead_lift
18
 
19
 
20
  logging.basicConfig(level=logging.INFO)
 
74
  token: str = Header(...),
75
  player_data: str = Body(...),
76
  repetitions: int|str = Body(...),
77
+ exercise_id: str = Body(...),
78
  ):
79
 
80
 
 
107
  return JSONResponse(content={"message": "Video uploaded successfully",
108
  "status": 200})
109
 
110
+
111
+
112
+ @app.post("/exercise/dead_lift")
113
+ async def upload(background_tasks: BackgroundTasks,
114
+ file: UploadFile = File(...),
115
+ token: str = Header(...),
116
+ player_data: str = Body(...),
117
+ repetitions: int|str = Body(...),
118
+ exercise_id: str = Body(...),
119
+ weight: int|str = Body(...)
120
+ ):
121
+
122
+ player_data = json.loads(player_data)
123
+
124
+ if token != AI_API_TOKEN:
125
+
126
+ raise HTTPException(status_code=401, detail="Unauthorized")
127
+
128
+ logger.info("reading contents")
129
+ contents = await file.read()
130
+
131
+ # Save the file to the local directory
132
+ logger.info("saving file")
133
+ with open(file.filename, "wb") as f:
134
+ f.write(contents)
135
+
136
+ logger.info(f"file saved {file.filename}")
137
+
138
+
139
+
140
+ background_tasks.add_task(analyze_dead_lift,
141
+ file.filename,
142
+ repetitions,
143
+ weight,
144
+ player_data["height"],
145
+ vitpose,
146
+ player_data["id"],
147
+ exercise_id)
148
+
149
+ return JSONResponse(content={"message": "Video uploaded successfully",
150
+ "status": 200})
151
+
152
+
src/exercises/dead_lift.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ from scipy.optimize import linear_sum_assignment
4
+ import requests
5
+ import json
6
+ from fastapi.responses import JSONResponse
7
+ from config import API_URL, API_KEY
8
+ import logging
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+ class RepCounter:
13
+ def __init__(self, fps, height_cm, mass_kg=0, target_reps=None):
14
+ self.count = 0
15
+ self.last_state = None
16
+ self.cooldown_frames = 15
17
+ self.cooldown = 0
18
+ self.rep_start_frame = None
19
+ self.start_wrist_y = None
20
+ self.rep_data = []
21
+ self.power_data = []
22
+ self.fps = fps
23
+ self.cm_per_pixel = None
24
+ self.real_distance_cm = height_cm * 0.2735
25
+ self.calibration_done = False
26
+ self.mass_kg = mass_kg
27
+ self.gravity = 9.81
28
+ self.target_reps = int(target_reps)
29
+ self.target_reached = False
30
+ self.final_speed = None
31
+ self.final_power = None
32
+ self.SKELETON = [
33
+ (5, 6), (5, 7), (7, 9), (6, 8), (8, 10),
34
+ (9, 10), (11, 12), (5, 11), (6, 12),
35
+ (11, 13), (13, 15), (12, 14), (14, 16), (13, 14)
36
+ ]
37
+
38
+ def update(self, wrist_y, knee_y, current_frame):
39
+ if self.target_reached or self.cooldown > 0:
40
+ self.cooldown = max(0, self.cooldown - 1)
41
+ return
42
+
43
+ current_state = 'above' if wrist_y < knee_y else 'below'
44
+
45
+ if self.last_state != current_state:
46
+ if current_state == 'below':
47
+ self.rep_start_frame = current_frame
48
+ self.start_wrist_y = wrist_y
49
+ elif current_state == 'above' and self.last_state == 'below':
50
+ if self.rep_start_frame is not None and self.cm_per_pixel is not None:
51
+ end_frame = current_frame
52
+ duration = (end_frame - self.rep_start_frame) / self.fps
53
+ distance_pixels = self.start_wrist_y - wrist_y
54
+ distance_cm = distance_pixels * self.cm_per_pixel
55
+
56
+ if duration > 0:
57
+ speed_cmps = abs(distance_cm) / duration
58
+ self.rep_data.append(speed_cmps)
59
+
60
+ if self.mass_kg > 0:
61
+ speed_mps = speed_cmps / 100
62
+ force = self.mass_kg * self.gravity
63
+ power = force * speed_mps
64
+ self.power_data.append(power)
65
+
66
+ self.count += 1
67
+ if self.target_reps and self.count >= self.target_reps:
68
+ self.count = self.target_reps
69
+ self.target_reached = True
70
+ self.final_speed = np.mean(self.rep_data) if self.rep_data else 0
71
+ self.final_power = np.mean(self.power_data) if self.power_data else 0
72
+
73
+ self.cooldown = self.cooldown_frames
74
+
75
+ self.last_state = current_state
76
+
77
+ # CentroidTracker class
78
+ class CentroidTracker:
79
+ def __init__(self, max_disappeared=50, max_distance=100):
80
+ self.next_id = 0
81
+ self.objects = {}
82
+ self.max_disappeared = max_disappeared
83
+ self.max_distance = max_distance
84
+
85
+ def _update_missing(self):
86
+ to_delete = []
87
+ for obj_id in list(self.objects.keys()):
88
+ self.objects[obj_id]["missed"] += 1
89
+ if self.objects[obj_id]["missed"] > self.max_disappeared:
90
+ to_delete.append(obj_id)
91
+ for obj_id in to_delete:
92
+ del self.objects[obj_id]
93
+
94
+ def update(self, detections):
95
+ if len(detections) == 0:
96
+ self._update_missing()
97
+ return []
98
+
99
+ centroids = np.array([[(x1 + x2) / 2, (y1 + y2) / 2] for x1, y1, x2, y2 in detections])
100
+
101
+ if len(self.objects) == 0:
102
+ return self._register_new(centroids)
103
+
104
+ return self._match_existing(centroids, detections)
105
+
106
+ def _register_new(self, centroids):
107
+ new_ids = []
108
+ for centroid in centroids:
109
+ self.objects[self.next_id] = {"centroid": centroid, "missed": 0}
110
+ new_ids.append(self.next_id)
111
+ self.next_id += 1
112
+ return new_ids
113
+
114
+ def _match_existing(self, centroids, detections):
115
+ existing_ids = list(self.objects.keys())
116
+ existing_centroids = [self.objects[obj_id]["centroid"] for obj_id in existing_ids]
117
+
118
+ cost = np.linalg.norm(np.array(existing_centroids)[:, np.newaxis] - centroids, axis=2)
119
+ row_ind, col_ind = linear_sum_assignment(cost)
120
+
121
+ used_rows = set()
122
+ used_cols = set()
123
+ matches = {}
124
+
125
+ for (row, col) in zip(row_ind, col_ind):
126
+ if cost[row, col] <= self.max_distance:
127
+ obj_id = existing_ids[row]
128
+ matches[obj_id] = centroids[col]
129
+ used_rows.add(row)
130
+ used_cols.add(col)
131
+
132
+ for obj_id in existing_ids:
133
+ if obj_id not in matches:
134
+ self.objects[obj_id]["missed"] += 1
135
+ if self.objects[obj_id]["missed"] > self.max_disappeared:
136
+ del self.objects[obj_id]
137
+
138
+ new_ids = []
139
+ for col in range(len(centroids)):
140
+ if col not in used_cols:
141
+ self.objects[self.next_id] = {"centroid": centroids[col], "missed": 0}
142
+ new_ids.append(self.next_id)
143
+ self.next_id += 1
144
+
145
+ for obj_id, centroid in matches.items():
146
+ self.objects[obj_id]["centroid"] = centroid
147
+ self.objects[obj_id]["missed"] = 0
148
+
149
+ all_ids = []
150
+ for detection in detections:
151
+ centroid = np.array([(detection[0] + detection[2]) / 2, (detection[1] + detection[3]) / 2])
152
+ min_id = None
153
+ min_dist = float('inf')
154
+ for obj_id, data in self.objects.items():
155
+ dist = np.linalg.norm(centroid - data["centroid"])
156
+ if dist < min_dist and dist <= self.max_distance:
157
+ min_dist = dist
158
+ min_id = obj_id
159
+ if min_id is not None:
160
+ all_ids.append(min_id)
161
+ self.objects[min_id]["centroid"] = centroid
162
+ else:
163
+ all_ids.append(self.next_id)
164
+ self.objects[self.next_id] = {"centroid": centroid, "missed": 0}
165
+ self.next_id += 1
166
+
167
+ return all_ids
168
+
169
+ # Funci贸n de procesamiento optimizada
170
+ def process_frame_for_counting(frame, tracker, rep_counter, frame_number,vitpose):
171
+
172
+ pose_results = vitpose.pipeline(frame)
173
+
174
+
175
+ keypoints = pose_results.keypoints_xy.float().cpu().numpy()[0]
176
+ scores = pose_results.scores.float().cpu().numpy()[0]
177
+ valid_points = {}
178
+ wrist_midpoint = None
179
+ knee_line_y = None
180
+
181
+
182
+ print(keypoints)
183
+ print(scores)
184
+
185
+ # Procesar puntos clave
186
+ for i, (kp, conf) in enumerate(zip(keypoints, scores)):
187
+ if conf > 0.3 and 5 <= i <= 16:
188
+ x, y = map(int, kp[:2])
189
+ valid_points[i] = (x, y)
190
+
191
+ # Calibraci贸n usando keypoints de rodilla (14) y pie (16)
192
+ if not rep_counter.calibration_done and 14 in valid_points and 16 in valid_points:
193
+ knee = valid_points[14]
194
+ ankle = valid_points[16]
195
+ pixel_distance = np.sqrt((knee[0] - ankle[0])**2 + (knee[1] - ankle[1])**2)
196
+ if pixel_distance > 0:
197
+ rep_counter.cm_per_pixel = rep_counter.real_distance_cm / pixel_distance
198
+ rep_counter.calibration_done = True
199
+
200
+ # Calcular puntos de referencia para conteo
201
+ if 9 in valid_points and 10 in valid_points:
202
+ wrist_midpoint = (
203
+ (valid_points[9][0] + valid_points[10][0]) // 2,
204
+ (valid_points[9][1] + valid_points[10][1]) // 2
205
+ )
206
+ if 13 in valid_points and 14 in valid_points:
207
+ pt1 = np.array(valid_points[13])
208
+ pt2 = np.array(valid_points[14])
209
+ direction = pt2 - pt1
210
+ extension = 0.2
211
+ new_pt1 = pt1 - direction * extension
212
+ new_pt2 = pt2 + direction * extension
213
+ knee_line_y = (new_pt1[1] + new_pt2[1]) // 2
214
+
215
+ # Actualizar contador
216
+ if wrist_midpoint and knee_line_y:
217
+ rep_counter.update(wrist_midpoint[1], knee_line_y, frame_number)
218
+
219
+
220
+
221
+ # Funci贸n principal de Gradio
222
+ def analyze_dead_lift(input_video, reps, weight, height,vitpose,player_id,exercise_id):
223
+ cap = cv2.VideoCapture(input_video)
224
+ fps = cap.get(cv2.CAP_PROP_FPS)
225
+
226
+
227
+ rep_counter = RepCounter(fps, int(height), int(weight), int(reps))
228
+ tracker = CentroidTracker(max_distance=150)
229
+
230
+ frame_number = 0
231
+ while cap.isOpened():
232
+ ret, frame = cap.read()
233
+ if not ret:
234
+ break
235
+
236
+ process_frame_for_counting(frame, tracker, rep_counter, frame_number,vitpose)
237
+ frame_number += 1
238
+
239
+ cap.release()
240
+
241
+ # Preparar payload para webhook
242
+ if rep_counter.mass_kg > 0:
243
+ power_data = rep_counter.power_data
244
+ else:
245
+ # Si no hay masa, usar ceros para potencia
246
+ power_data = [0] * len(rep_counter.rep_data) if rep_counter.rep_data else []
247
+
248
+ # Asegurar que tenemos datos para enviar
249
+ if rep_counter.rep_data:
250
+ payload = {"repetition_data": [
251
+ {"repetition": i, "velocidad": round(s,1), "potencia": round(p,1)}
252
+ for i, (s, p) in enumerate(zip(rep_counter.rep_data, power_data), start=1)
253
+ ]}
254
+ else:
255
+ # En caso de no detectar repeticiones
256
+ payload = {"repetition_data": []}
257
+
258
+
259
+ send_results_api(payload, player_id, exercise_id, input_video)
260
+
261
+
262
+ def send_results_api(results_dict: dict,
263
+ player_id: str,
264
+ exercise_id: str,
265
+ video_path: str) -> JSONResponse:
266
+ """
267
+ Send video analysis results to the API webhook endpoint.
268
+
269
+ This function uploads the analyzed video file along with the computed metrics
270
+ to the API's webhook endpoint for processing and storage.
271
+
272
+ Args:
273
+ results_dict (dict): Dictionary containing analysis results including:
274
+ - video_analysis: Information about the processed video
275
+ - repetition_data: List of metrics for each jump repetition
276
+ player_id (str): Unique identifier for the player
277
+ exercise_id (str): Unique identifier for the exercise
278
+ video_path (str): Path to the video file to upload
279
+
280
+ Returns:
281
+ JSONResponse: HTTP response from the API endpoint
282
+
283
+ Raises:
284
+ FileNotFoundError: If the video file doesn't exist
285
+ requests.RequestException: If the API request fails
286
+ json.JSONEncodeError: If results_dict cannot be serialized to JSON
287
+ """
288
+ url = API_URL + "/exercises/webhooks/video-processed-results"
289
+ logger.info(f"Sending video results to {url}")
290
+
291
+ # Open the video file
292
+ with open(video_path, 'rb') as video_file:
293
+ # Prepare the files dictionary for file upload
294
+ files = {
295
+ 'file': (video_path.split('/')[-1], video_file, 'video/mp4')
296
+ }
297
+
298
+ # Prepare the form data
299
+ data = {
300
+ 'player_id': player_id,
301
+ 'exercise_id': exercise_id,
302
+ 'results': json.dumps(results_dict) # Convert dict to JSON string
303
+ }
304
+
305
+ # Send the request with both files and data
306
+ response = requests.post(
307
+ url,
308
+ headers={"token": API_KEY},
309
+ files=files,
310
+ data=data,
311
+ stream=True
312
+ )
313
+
314
+ logger.info(f"Response: {response.status_code}")
315
+ logger.info(f"Response: {response.text}")
316
+ return response
tasks.py CHANGED
@@ -97,6 +97,8 @@ class FramePosition:
97
  y: int
98
  width: int
99
  height: int
 
 
100
  def process_video(file_name: str,vitpose: VitPose,user_id: str,player_id: str):
101
  """
102
  Process a video file using VitPose for pose estimation and send results to webhook.
 
97
  y: int
98
  width: int
99
  height: int
100
+
101
+
102
  def process_video(file_name: str,vitpose: VitPose,user_id: str,player_id: str):
103
  """
104
  Process a video file using VitPose for pose estimation and send results to webhook.