# lbw_detector.py
"""LBW decision pipeline: segment frames, track the ball, predict its path.

The module loads a segmentation model at import time, runs it on every frame
of a video, records the ball centroid per frame, detects ball/pad overlap, and
derives an "OUT" / "NOT OUT" decision from the predicted trajectory.
"""
import torch
import numpy as np
from torchvision import transforms
import cv2

from utils import extract_frames
from trajectory_predictor import predict_trajectory
from visualizer import draw_visuals

# Probability above which a pixel counts as belonging to a class.
MASK_THRESHOLD = 0.5
# Number of overlapping ball/pad pixels that must be exceeded for a pad impact.
PAD_OVERLAP_MIN_PIXELS = 10

# Load the custom LBW model
model_path = "models/lbw_drs_unet_model.pth"
device = "cpu"  # Hugging Face Free Tier
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code; only load checkpoints from a trusted source. Recent PyTorch releases
# default to weights_only=True, which rejects fully pickled modules such as
# the one expected here (model.eval() is called on the result) - confirm the
# installed version and pass weights_only explicitly if required.
model = torch.load(model_path, map_location=device)
model.eval()

transform = transforms.Compose([
    transforms.ToTensor(),
])


def detect_objects_with_model(frame):
    """Run segmentation on a frame using the custom model.

    Args:
        frame: A single image accepted by ``transforms.ToTensor`` (presumably
            an HxWxC array from ``extract_frames`` - confirm against caller).

    Returns:
        A NumPy array of sigmoid probabilities with all size-1 dimensions
        squeezed out. The values are NOT thresholded; callers binarize them.
    """
    input_tensor = transform(frame).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(input_tensor)
    # Convert raw model output to per-pixel probabilities.
    return torch.sigmoid(output).squeeze().cpu().numpy()


def analyze_video(video_path):
    """Analyze a video and return the rendered result together with a decision.

    Each frame is segmented; the centroid of the largest ball contour is
    recorded as ``(frame_index, cx, cy)``. Processing stops at the first frame
    where the ball mask overlaps the pad mask by more than
    ``PAD_OVERLAP_MIN_PIXELS`` pixels.

    Args:
        video_path: Path passed straight through to ``extract_frames``.

    Returns:
        A ``(result_path, decision)`` tuple, where ``result_path`` is whatever
        ``draw_visuals`` returns and ``decision`` is ``"OUT"`` or ``"NOT OUT"``.
    """
    frames = extract_frames(video_path)
    ball_positions = []
    impact_frame_idx = None
    impact_zone = "unknown"

    for i, frame in enumerate(frames):
        mask = detect_objects_with_model(frame)

        if mask.ndim > 2:
            # Multi-channel output: channel 0 is the ball, channel 1 the pad.
            ball_mask = mask[0] > MASK_THRESHOLD
            pad_mask = mask[1] > MASK_THRESHOLD
        else:
            # Single-channel output was squeezed to 2-D; indexing mask[0]
            # here would select a pixel row, not a channel.
            ball_mask = mask > MASK_THRESHOLD
            pad_mask = None

        # Detect ball center
        contours, _ = cv2.findContours(
            ball_mask.astype(np.uint8),
            cv2.RETR_EXTERNAL,
            cv2.CHAIN_APPROX_SIMPLE,
        )
        if not contours:
            # No ball in this frame, so there is nothing to track or test.
            continue

        largest = max(contours, key=cv2.contourArea)
        M = cv2.moments(largest)
        if M['m00'] != 0:  # zero area would divide by zero
            cx = int(M['m10'] / M['m00'])
            cy = int(M['m01'] / M['m00'])
            ball_positions.append((i, cx, cy))

        # Detect pad hit (ball pixels overlapping pad pixels)
        if pad_mask is not None:
            overlap = np.logical_and(ball_mask, pad_mask).sum()
            if overlap > PAD_OVERLAP_MIN_PIXELS:
                impact_frame_idx = i
                impact_zone = "pad"
                break

    # Run trajectory prediction on whatever ball positions were collected
    trajectory = predict_trajectory(ball_positions)

    # Predict outcome
    hits_stumps = trajectory_hits_stumps(trajectory)
    decision = "OUT" if hits_stumps and impact_zone == "pad" else "NOT OUT"

    # Visualize
    result_path = draw_visuals(
        frames, ball_positions, trajectory, impact_frame_idx, decision
    )
    return result_path, decision


def trajectory_hits_stumps(trajectory, stump_x_min=300, stump_x_max=340,
                           max_y=480):
    """Return True if any trajectory point falls inside the stump zone.

    Simple rule-based check. The defaults assume the stumps sit between
    x=300 and x=340 px, which is a placeholder rather than a calibrated value.

    Args:
        trajectory: Iterable of ``(x, y)`` points, or ``None``.
        stump_x_min: Exclusive lower x bound of the stump zone, in pixels.
        stump_x_max: Exclusive upper x bound of the stump zone, in pixels.
        max_y: Exclusive upper bound on y for a point to count as a hit.

    Returns:
        ``True`` when at least one point satisfies
        ``stump_x_min < x < stump_x_max and y < max_y``; ``False`` otherwise,
        including when ``trajectory`` is ``None`` or empty.
    """
    if trajectory is None:
        return False
    return any(
        stump_x_min < x < stump_x_max and y < max_y
        for (x, y) in trajectory
    )