Spaces:
Paused
Paused
import cv2 | |
import numpy as np | |
from scipy.optimize import linear_sum_assignment | |
import requests | |
import json | |
from fastapi.responses import JSONResponse | |
from config import API_URL, API_KEY | |
import logging | |
logger = logging.getLogger(__name__) | |
class RepCounter:
    """Count repetitions from wrist/knee heights and record per-rep speed and power.

    A repetition is a transition from the wrists being below the knee line
    to being above it (image coordinates: a smaller ``y`` is higher up).
    Speed is the vertical wrist travel divided by the elapsed time; power is
    ``mass_kg * gravity * speed``.

    Args:
        fps: Frames per second of the video; converts frame counts to seconds.
        height_cm: Height in centimetres. ``height_cm * 0.2735`` is stored as
            ``real_distance_cm``, the real-world reference length; the frame
            processor in this file calibrates it against the knee-to-ankle
            keypoint distance.
        mass_kg: Mass used for the power estimate. Power is only recorded
            when this is greater than zero.
        target_reps: Optional number of repetitions after which counting
            stops. ``None`` (the default) means no target.
    """

    def __init__(self, fps, height_cm, mass_kg=0, target_reps=None):
        self.count = 0
        self.last_state = None  # 'above', 'below', or None before the first update
        self.cooldown_frames = 15  # updates ignored after a counted rep
        self.cooldown = 0
        self.rep_start_frame = None
        self.start_wrist_y = None
        self.rep_data = []  # speed of each rep, in cm/s
        self.power_data = []  # power of each rep, in W (only when mass_kg > 0)
        self.fps = fps
        self.cm_per_pixel = None  # set externally once calibration succeeds
        self.real_distance_cm = height_cm * 0.2735
        self.calibration_done = False
        self.mass_kg = mass_kg
        self.gravity = 9.81
        # int(None) raises TypeError, so the default must be handled explicitly.
        self.target_reps = int(target_reps) if target_reps is not None else None
        self.target_reached = False
        self.final_speed = None
        self.final_power = None
        # Keypoint index pairs forming the skeleton edges.
        self.SKELETON = [
            (5, 6), (5, 7), (7, 9), (6, 8), (8, 10),
            (9, 10), (11, 12), (5, 11), (6, 12),
            (11, 13), (13, 15), (12, 14), (14, 16), (13, 14)
        ]

    def update(self, wrist_y, knee_y, current_frame):
        """Feed one frame's wrist and knee heights into the counter.

        Does nothing once the target has been reached; while a cooldown is
        active it only decrements the cooldown.

        NOTE(review): the nesting of the count/cooldown statements was
        reconstructed from a paste that had lost its indentation. Here a rep
        is counted only when its speed could be measured, and the cooldown
        starts right after a counted rep -- confirm against the original.

        Args:
            wrist_y: Vertical pixel coordinate of the wrist midpoint.
            knee_y: Vertical pixel coordinate of the knee line.
            current_frame: Index of the frame being processed.
        """
        if self.target_reached or self.cooldown > 0:
            self.cooldown = max(0, self.cooldown - 1)
            return

        current_state = 'above' if wrist_y < knee_y else 'below'

        if self.last_state != current_state:
            if current_state == 'below':
                # Start of a pull: remember where and when it began.
                self.rep_start_frame = current_frame
                self.start_wrist_y = wrist_y
            elif self.last_state == 'below':
                self._finish_rep(wrist_y, current_frame)

        self.last_state = current_state

    def _finish_rep(self, wrist_y, end_frame):
        """Record speed/power for a completed below->above transition."""
        if self.rep_start_frame is None or self.cm_per_pixel is None:
            return
        # A capture that reports 0 FPS would otherwise divide by zero.
        if not self.fps or self.fps <= 0:
            return

        duration = (end_frame - self.rep_start_frame) / self.fps
        if duration <= 0:
            return

        distance_cm = (self.start_wrist_y - wrist_y) * self.cm_per_pixel
        speed_cmps = abs(distance_cm) / duration
        self.rep_data.append(speed_cmps)

        if self.mass_kg > 0:
            speed_mps = speed_cmps / 100
            self.power_data.append(self.mass_kg * self.gravity * speed_mps)

        self.count += 1
        if self.target_reps and self.count >= self.target_reps:
            self.count = self.target_reps
            self.target_reached = True
            self.final_speed = np.mean(self.rep_data) if self.rep_data else 0
            self.final_power = np.mean(self.power_data) if self.power_data else 0

        self.cooldown = self.cooldown_frames
# CentroidTracker class | |
class CentroidTracker:
    """Assign stable integer IDs to bounding boxes by matching their centroids.

    Each tracked object stores its last centroid and a ``missed`` counter.
    Objects are matched to new detections with the Hungarian algorithm on
    Euclidean centroid distance; an object is dropped once it has gone
    unmatched for more than ``max_disappeared`` consecutive updates.

    Args:
        max_disappeared: Number of consecutive unmatched updates tolerated
            before an object is removed.
        max_distance: Largest centroid distance at which a detection may be
            matched to an existing object.
    """

    def __init__(self, max_disappeared=50, max_distance=100):
        self.next_id = 0
        self.objects = {}  # id -> {"centroid": ndarray, "missed": int}
        self.max_disappeared = max_disappeared
        self.max_distance = max_distance

    def _mark_missed(self, obj_id):
        """Increase an object's miss counter and drop it when it expires."""
        entry = self.objects[obj_id]
        entry["missed"] += 1
        if entry["missed"] > self.max_disappeared:
            del self.objects[obj_id]

    def _register(self, centroid):
        """Start tracking ``centroid`` under a fresh ID and return that ID."""
        obj_id = self.next_id
        self.objects[obj_id] = {"centroid": centroid, "missed": 0}
        self.next_id += 1
        return obj_id

    def _update_missing(self):
        """Mark every tracked object as missed for this update."""
        for obj_id in list(self.objects):
            self._mark_missed(obj_id)

    def update(self, detections):
        """Match detections to tracked objects.

        Args:
            detections: Sequence of ``(x1, y1, x2, y2)`` boxes.

        Returns:
            A list of object IDs, one per detection and in the same order.
            Empty when there are no detections.
        """
        if len(detections) == 0:
            self._update_missing()
            return []
        centroids = np.array(
            [[(x1 + x2) / 2, (y1 + y2) / 2] for x1, y1, x2, y2 in detections]
        )
        if len(self.objects) == 0:
            return self._register_new(centroids)
        return self._match_existing(centroids, detections)

    def _register_new(self, centroids):
        """Register every centroid as a new object and return the new IDs."""
        return [self._register(centroid) for centroid in centroids]

    def _match_existing(self, centroids, detections):
        """Match centroids against tracked objects and return one ID per centroid.

        ``detections`` is unused and kept only so the signature is unchanged;
        the centroids already carry everything needed.
        """
        existing_ids = list(self.objects)
        existing_centroids = np.array(
            [self.objects[obj_id]["centroid"] for obj_id in existing_ids]
        )
        # cost[i, j] = distance between tracked object i and detection j.
        cost = np.linalg.norm(existing_centroids[:, np.newaxis] - centroids, axis=2)
        row_ind, col_ind = linear_sum_assignment(cost)

        # Keep only assignments within the allowed distance.
        assigned = {}
        for row, col in zip(row_ind, col_ind):
            if cost[row, col] <= self.max_distance:
                assigned[int(col)] = existing_ids[row]

        matched_ids = set(assigned.values())
        for obj_id in existing_ids:
            if obj_id not in matched_ids:
                self._mark_missed(obj_id)

        # Take the IDs straight from the assignment so that every detection
        # gets a distinct ID, even when two detections share a centroid.
        ids = []
        for col, centroid in enumerate(centroids):
            obj_id = assigned.get(col)
            if obj_id is None:
                obj_id = self._register(centroid)
            else:
                self.objects[obj_id]["centroid"] = centroid
                self.objects[obj_id]["missed"] = 0
            ids.append(obj_id)
        return ids
# Optimized frame-processing function
def process_frame_for_counting(frame, tracker, rep_counter, frame_number, vitpose):
    """Run pose estimation on one frame and feed the result to the rep counter.

    Keypoint indices follow the original comments (9/10 wrists, 13/14 knees,
    16 foot/ankle) -- presumably the COCO-17 layout; confirm against the
    pose model in use.

    Args:
        frame: Image passed unchanged to ``vitpose.pipeline``.
        tracker: Unused; kept so existing callers keep working.
        rep_counter: Object exposing ``calibration_done``, ``cm_per_pixel``,
            ``real_distance_cm`` and ``update(wrist_y, knee_y, frame_number)``.
        frame_number: Index of ``frame`` within the video.
        vitpose: Object whose ``pipeline(frame)`` returns a result with
            ``keypoints_xy`` and ``scores`` tensors.

    Returns:
        None. ``rep_counter`` is updated in place.
    """
    pose_results = vitpose.pipeline(frame)
    all_keypoints = pose_results.keypoints_xy.float().cpu().numpy()
    all_scores = pose_results.scores.float().cpu().numpy()

    # Assumes the leading axis indexes detected people (TODO confirm); with
    # nobody in the frame there is nothing to index, so skip the frame.
    if len(all_keypoints) == 0 or len(all_scores) == 0:
        return
    keypoints = all_keypoints[0]
    scores = all_scores[0]

    # Keep confident body keypoints (indices 5..16) as integer pixel coords.
    valid_points = {
        index: (int(kp[0]), int(kp[1]))
        for index, (kp, conf) in enumerate(zip(keypoints, scores))
        if conf > 0.3 and 5 <= index <= 16
    }

    # One-time calibration: knee (14) to foot/ankle (16) distance in pixels
    # is mapped to the counter's real-world reference length.
    if not rep_counter.calibration_done and 14 in valid_points and 16 in valid_points:
        knee_x, knee_y = valid_points[14]
        ankle_x, ankle_y = valid_points[16]
        pixel_distance = np.hypot(knee_x - ankle_x, knee_y - ankle_y)
        if pixel_distance > 0:
            rep_counter.cm_per_pixel = rep_counter.real_distance_cm / pixel_distance
            rep_counter.calibration_done = True

    wrist_y = None
    if 9 in valid_points and 10 in valid_points:
        wrist_y = (valid_points[9][1] + valid_points[10][1]) // 2

    knee_line_y = None
    if 13 in valid_points and 14 in valid_points:
        # The original lengthened the knee segment by 20% at both ends before
        # taking the midpoint; the two offsets cancel, so this is the same
        # midpoint height.
        knee_line_y = (valid_points[13][1] + valid_points[14][1]) // 2

    # Compare against None: a coordinate of 0 is valid but falsy.
    if wrist_y is not None and knee_line_y is not None:
        rep_counter.update(wrist_y, knee_line_y, frame_number)
# Main Gradio function
def analyze_dead_lift(input_video, reps, weight, height, vitpose, player_id, exercise_id):
    """Analyze a deadlift video and send per-repetition metrics to the API.

    Every frame of ``input_video`` is run through
    ``process_frame_for_counting``; the speeds and powers collected by the
    ``RepCounter`` are then packed into a payload and passed to
    ``send_results_api`` together with the video path.

    Args:
        input_video: Path of the video, as accepted by ``cv2.VideoCapture``.
        reps: Target number of repetitions (converted with ``int``).
        weight: Lifted mass (converted with ``int``); when it is not greater
            than zero, power is reported as 0 for every repetition.
        height: Height passed to ``RepCounter`` (converted with ``int``).
        vitpose: Pose estimator forwarded to ``process_frame_for_counting``.
        player_id: Player identifier forwarded to the API.
        exercise_id: Exercise identifier forwarded to the API.

    Returns:
        None.
    """
    cap = cv2.VideoCapture(input_video)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        rep_counter = RepCounter(fps, int(height), int(weight), int(reps))
        tracker = CentroidTracker(max_distance=150)

        frame_number = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            process_frame_for_counting(frame, tracker, rep_counter, frame_number, vitpose)
            frame_number += 1
    finally:
        # Release the capture even if frame processing raises.
        cap.release()

    # Build the webhook payload.
    if rep_counter.mass_kg > 0:
        power_data = rep_counter.power_data
    else:
        # Without a mass there is no power estimate; report zeros instead.
        power_data = [0] * len(rep_counter.rep_data)

    # With no detected repetitions this is simply an empty list.
    payload = {"repetition_data": [
        {"repetition": i, "velocidad": round(s, 1), "potencia": round(p, 1)}
        for i, (s, p) in enumerate(zip(rep_counter.rep_data, power_data), start=1)
    ]}

    send_results_api(payload, player_id, exercise_id, input_video)
def send_results_api(results_dict: dict,
                     player_id: str,
                     exercise_id: str,
                     video_path: str,
                     timeout=(10.0, 300.0)) -> "requests.Response":
    """
    Send video analysis results to the API webhook endpoint.

    Uploads the video file together with the computed metrics as a
    multipart POST to ``API_URL + "/exercises/webhooks/video-processed-results"``,
    authenticated with ``API_KEY`` in the ``token`` header.

    Args:
        results_dict (dict): Analysis results; serialized with ``json.dumps``
            and sent as the ``results`` form field.
        player_id (str): Unique identifier for the player
        exercise_id (str): Unique identifier for the exercise
        video_path (str): Path to the video file to upload
        timeout: ``requests`` timeout, either one number or a
            ``(connect, read)`` tuple in seconds. Without it a stalled
            server would block the caller indefinitely.

    Returns:
        requests.Response: HTTP response from the API endpoint. The status
        code is logged but not checked here; callers decide how to treat
        non-2xx responses.

    Raises:
        FileNotFoundError: If the video file doesn't exist
        requests.RequestException: If the request cannot be completed
            (connection error, timeout, ...)
        TypeError: If results_dict cannot be serialized to JSON
    """
    import os  # local import: the module's import block does not provide it

    url = API_URL + "/exercises/webhooks/video-processed-results"
    logger.info("Sending video results to %s", url)

    # Serialize before opening the file so a bad payload fails early.
    data = {
        'player_id': player_id,
        'exercise_id': exercise_id,
        'results': json.dumps(results_dict)  # Convert dict to JSON string
    }

    with open(video_path, 'rb') as video_file:
        # basename handles the platform's path separator, unlike split('/').
        files = {
            'file': (os.path.basename(video_path), video_file, 'video/mp4')
        }
        # Send the request with both the file and the form data.
        response = requests.post(
            url,
            headers={"token": API_KEY},
            files=files,
            data=data,
            stream=True,
            timeout=timeout
        )

    logger.info("Response: %s", response.status_code)
    logger.info("Response: %s", response.text)
    return response