# Origin: lbw_drs_app_new / lbw_detector.py (file-page header captured with the source)
# Author: dschandra
# Commit 8ba6e46 (verified): "Update lbw_detector.py"
# Size: 2.68 kB
# lbw_detector.py
import torch
import numpy as np
from torchvision import transforms
import cv2
from utils import extract_frames
from trajectory_predictor import predict_trajectory
from visualizer import draw_visuals
# Load the custom LBW model once, at import time.
model_path = "models/lbw_drs_unet_model.pth"
device = "cpu"  # Hugging Face Free Tier
# ``model.eval()`` is called on whatever ``torch.load`` returns, so the
# checkpoint is expected to be a full pickled model object, not a state_dict.
# PyTorch >= 2.6 defaults ``torch.load`` to ``weights_only=True``, which refuses
# to unpickle arbitrary model classes, so opt out explicitly (the keyword
# exists since torch 1.13). Unpickling can execute arbitrary code: only ever
# point ``model_path`` at a trusted, bundled file.
model = torch.load(model_path, map_location=device, weights_only=False)
model.eval()  # switch to evaluation (inference) mode
# ToTensor turns an HWC image into a CHW float tensor (uint8 input is scaled
# to [0, 1]); no normalisation is applied.
transform = transforms.Compose([
    transforms.ToTensor(),
])
def detect_objects_with_model(frame):
    """Segment one frame with the module-level ``model``.

    The frame goes through ``transform``, gains a leading batch dimension, is
    moved to ``device`` and run through ``model`` with gradients disabled.
    A sigmoid is then applied to the output.

    Returns:
        A numpy array of per-pixel probabilities in [0, 1]. This is NOT a
        binary mask; callers threshold it themselves. Every size-1 dimension
        is squeezed away, including the batch dimension and a lone channel.
        NOTE(review): this presumably gives (H, W) for a one-channel model
        and (C, H, W) for a multi-channel one, assuming the model returns a
        single (1, C, H, W) tensor. Confirm against the trained model.
    """
    batch = transform(frame)[None, ...].to(device)
    with torch.no_grad():
        logits = model(batch)
    probabilities = torch.sigmoid(logits)
    return probabilities.squeeze().cpu().numpy()
def _ball_and_pad_masks(mask, threshold=0.5):
    """Threshold a probability mask into ``(ball_mask, pad_mask)`` boolean arrays.

    A 2-D ``mask`` (a single-channel output whose channel axis was squeezed
    away) is treated as the ball channel and yields ``pad_mask=None``.
    Otherwise channel 0 is the ball and channel 1, when present, is the pad.
    """
    if mask.ndim == 2:
        # Indexing [0] here would pick the first pixel ROW, not a channel.
        return mask > threshold, None
    ball_mask = mask[0] > threshold
    pad_mask = mask[1] > threshold if mask.shape[0] > 1 else None
    return ball_mask, pad_mask


def _largest_blob_centre(binary_mask):
    """Return ``(contours, centre)`` for a 2-D boolean mask.

    ``centre`` is the integer ``(cx, cy)`` centroid of the largest external
    contour. It is ``None`` when no contour is found or the largest one has
    zero area.
    """
    # ``[-2]`` selects the contour list under both the OpenCV 3.x return shape
    # (image, contours, hierarchy) and the 4.x one (contours, hierarchy).
    contours = cv2.findContours(
        binary_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )[-2]
    if not contours:
        return contours, None
    moments = cv2.moments(max(contours, key=cv2.contourArea))
    if moments["m00"] == 0:
        return contours, None
    cx = int(moments["m10"] / moments["m00"])
    cy = int(moments["m01"] / moments["m00"])
    return contours, (cx, cy)


def analyze_video(video_path, *, mask_threshold=0.5, overlap_threshold=10):
    """Run the LBW pipeline on a video and return ``(result_path, decision)``.

    Each frame from ``extract_frames`` is segmented. The centroid of the
    largest ball blob is recorded as ``(frame_index, cx, cy)``. Scanning stops
    at the first frame where the ball and pad masks overlap by more than
    ``overlap_threshold`` pixels; that frame is recorded as a pad impact.
    The positions collected up to and including that frame are passed to
    ``predict_trajectory``. The decision is ``"OUT"`` only when a pad impact
    was seen and ``trajectory_hits_stumps`` accepts the trajectory.

    Args:
        video_path: Path handed to ``extract_frames``.
        mask_threshold: Probability above which a pixel counts as foreground.
        overlap_threshold: The ball/pad overlap must exceed this many pixels
            to count as a pad impact.

    Returns:
        ``(result_path, decision)``. ``result_path`` is whatever
        ``draw_visuals`` returns, and ``decision`` is ``"OUT"`` or
        ``"NOT OUT"``.
    """
    frames = extract_frames(video_path)
    ball_positions = []
    impact_frame_idx = None
    impact_zone = "unknown"

    for i, frame in enumerate(frames):
        ball_mask, pad_mask = _ball_and_pad_masks(
            detect_objects_with_model(frame), mask_threshold
        )
        contours, centre = _largest_blob_centre(ball_mask)
        if centre is not None:
            ball_positions.append((i, *centre))

        # Pad impact: the ball and pad segmentations overlap enough.
        if pad_mask is not None and contours:
            overlap = np.logical_and(ball_mask, pad_mask).sum()
            if overlap > overlap_threshold:
                impact_frame_idx = i
                impact_zone = "pad"
                break

    # Called even when no ball was found (empty ``ball_positions``).
    # NOTE(review): confirm predict_trajectory handles an empty list.
    trajectory = predict_trajectory(ball_positions)

    hits_stumps = trajectory_hits_stumps(trajectory)
    decision = "OUT" if hits_stumps and impact_zone == "pad" else "NOT OUT"

    result_path = draw_visuals(frames, ball_positions, trajectory, impact_frame_idx, decision)
    return result_path, decision
def trajectory_hits_stumps(trajectory, *, x_min=300, x_max=340, y_max=480):
    """Return True if any projected ``(x, y)`` point falls in the stump zone.

    A point is inside the zone when ``x_min < x < x_max`` and ``y < y_max``
    (strict inequalities, pixel coordinates). The defaults are the
    placeholder bounds from the original rule ("stumps are around x=300 to
    340 px for now"); they are not calibrated to any camera.

    Args:
        trajectory: Iterable of ``(x, y)`` pairs, or ``None`` for no trajectory.
        x_min, x_max: Exclusive horizontal bounds of the stump zone.
        y_max: Exclusive upper bound on ``y`` for a point to count.

    Returns:
        ``True`` if at least one point lies in the zone, otherwise ``False``.
        A ``None`` or empty trajectory gives ``False``.
    """
    # Compare against None explicitly: ``not trajectory`` would raise on a
    # multi-element numpy array.
    if trajectory is None:
        return False
    return any(x_min < x < x_max and y < y_max for x, y in trajectory)