# Spaces: Sleeping
# lbw_detector.py
import torch
import numpy as np
from torchvision import transforms
import cv2

from utils import extract_frames
from trajectory_predictor import predict_trajectory
from visualizer import draw_visuals
# Load the custom LBW model once, at import time, so every call to
# detect_objects_with_model() reuses the same in-memory network.
model_path = "models/lbw_drs_unet_model.pth"
device = "cpu"  # Hugging Face Free Tier
# NOTE(review): torch.load() unpickles the checkpoint, so the file must come
# from a trusted source. Because .eval() is called directly on the loaded
# object, the checkpoint is presumably a fully pickled nn.Module rather than
# a state_dict; newer PyTorch releases default to weights_only=True and may
# refuse such a file -- confirm against the deployed torch version.
model = torch.load(model_path, map_location=device)
model.eval()  # inference mode: no dropout / batch-norm statistic updates

# Pre-processing applied to every frame before it is fed to the model.
transform = transforms.Compose([
    transforms.ToTensor(),
])
def detect_objects_with_model(frame):
    """Run segmentation on a single frame using the custom model.

    Args:
        frame: Image accepted by the module-level ``transform``
            (presumably an H x W x C array from OpenCV -- confirm against
            ``extract_frames``).

    Returns:
        numpy.ndarray: Per-pixel sigmoid probabilities in [0, 1], not a
        thresholded binary mask; callers threshold it themselves. All
        size-1 dimensions are squeezed away, so a single-channel model
        output comes back 2-D and a multi-channel one comes back
        channel-first.
    """
    # Add the batch dimension the model expects and move to the target device.
    input_tensor = transform(frame).unsqueeze(0).to(device)

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        output = model(input_tensor)

    # Logits -> probabilities, drop singleton dims, hand back a NumPy array.
    mask = torch.sigmoid(output).squeeze().cpu().numpy()
    return mask
def analyze_video(video_path, *, mask_threshold=0.5, min_overlap_px=10):
    """Analyze a delivery video and produce an LBW decision.

    Each frame is segmented, the ball centroid is tracked, and scanning
    stops at the first frame where the ball mask overlaps the pad mask by
    more than ``min_overlap_px`` pixels. The tracked positions are then
    extrapolated and checked against the stump zone.

    Args:
        video_path: Path of the video, forwarded to ``extract_frames``.
        mask_threshold: Probability above which a pixel counts as
            belonging to a segmented object.
        min_overlap_px: Number of overlapping ball/pad pixels that must be
            exceeded for a frame to count as the pad impact.

    Returns:
        tuple: ``(result_path, decision)`` where ``result_path`` is whatever
        ``draw_visuals`` returns and ``decision`` is ``"OUT"`` or
        ``"NOT OUT"``.
    """
    frames = extract_frames(video_path)
    ball_positions = []
    impact_frame_idx = None
    impact_zone = "unknown"

    for i, frame in enumerate(frames):
        mask = np.asarray(detect_objects_with_model(frame))

        if mask.ndim == 2:
            # Single-channel output arrives already squeezed to H x W:
            # treat the whole map as the ball mask (indexing [0] here would
            # select one pixel row, not a channel).
            ball_mask = mask > mask_threshold
            pad_mask = None
        else:
            ball_mask = mask[0] > mask_threshold  # channel 0 for ball
            # channel 1 for pad, when the model provides it
            pad_mask = mask[1] > mask_threshold if mask.shape[0] > 1 else None

        # Detect ball center. findContours returns 2 values on OpenCV 4.x
        # but 3 on 3.x; the contour list is always second from the end.
        contours = cv2.findContours(
            ball_mask.astype(np.uint8),
            cv2.RETR_EXTERNAL,
            cv2.CHAIN_APPROX_SIMPLE,
        )[-2]
        if contours:
            largest = max(contours, key=cv2.contourArea)
            M = cv2.moments(largest)
            if M['m00'] != 0:  # zero-area contour has no centroid
                cx = int(M['m10'] / M['m00'])
                cy = int(M['m01'] / M['m00'])
                ball_positions.append((i, cx, cy))

        # Detect pad hit: ball and pad masks share enough pixels.
        if pad_mask is not None and contours:
            overlap = np.logical_and(ball_mask, pad_mask).sum()
            if overlap > min_overlap_px:
                impact_frame_idx = i
                impact_zone = "pad"
                break  # frames after the impact are not analyzed

    # Run trajectory prediction only if the ball was detected; with no
    # observations there is nothing to extrapolate from.
    trajectory = predict_trajectory(ball_positions) if ball_positions else []

    # Predict outcome: OUT needs both a pad impact and a stump-bound path.
    hits_stumps = trajectory_hits_stumps(trajectory)
    decision = "OUT" if hits_stumps and impact_zone == "pad" else "NOT OUT"

    # Visualize
    result_path = draw_visuals(
        frames, ball_positions, trajectory, impact_frame_idx, decision
    )
    return result_path, decision
def trajectory_hits_stumps(trajectory, stump_x_range=(300, 340), max_y=480):
    """Return True if any projected point falls inside the stump zone.

    Simple rule-based check: a point hits when its x lies strictly between
    the bounds of ``stump_x_range`` and its y is strictly below ``max_y``.
    The defaults are the placeholder pixel values this module has always
    used (stumps assumed around x=300..340 px).

    Args:
        trajectory: Iterable of ``(x, y)`` points; ``None`` is treated as
            an empty trajectory.
        stump_x_range: ``(low, high)`` exclusive horizontal bounds of the
            stump zone, in pixels.
        max_y: Exclusive upper bound on y for a point to count.

    Returns:
        bool: True on the first point inside the zone, otherwise False.
    """
    if trajectory is None:
        return False

    x_low, x_high = stump_x_range
    return any(x_low < x < x_high and y < max_y for x, y in trajectory)