File size: 2,676 Bytes
5eaa37a
 
3bf3b3c
5eaa37a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ba6e46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# lbw_detector.py
import torch
import numpy as np
from torchvision import transforms
import cv2
from utils import extract_frames
from trajectory_predictor import predict_trajectory
from visualizer import draw_visuals

# Load the custom LBW segmentation model once, at import time, so every call
# to detect_objects_with_model() reuses the same in-memory network.
model_path = "models/lbw_drs_unet_model.pth"
device = "cpu"  # Hugging Face Free Tier

# map_location remaps the checkpoint's tensors onto `device` while loading.
# NOTE(review): .eval() is called directly on the loaded object, so the
# checkpoint presumably stores a whole pickled nn.Module rather than a
# state_dict. torch.load unpickles the file, so only load trusted
# checkpoints, and confirm the pinned PyTorch version's `weights_only`
# default still permits full-module loading.
model = torch.load(model_path, map_location=device)
# Switch to inference mode (affects layers such as dropout / batch norm).
model.eval()

# Pre-processing applied to each frame before inference. ToTensor is the only
# step; no resizing or normalisation is applied here.
transform = transforms.Compose([
    transforms.ToTensor(),
])

def detect_objects_with_model(frame):
    """Segment a single frame with the module-level model.

    The frame is passed through the module-level ``transform``, given a
    leading batch axis, moved to ``device`` and fed to ``model`` with gradient
    tracking disabled.

    Args:
        frame: One video frame in a form accepted by ``transform``
            (presumably an H x W x C image array — confirm against
            ``extract_frames``).

    Returns:
        A NumPy array of per-pixel sigmoid scores in [0, 1]. All singleton
        dimensions are squeezed out, so the batch axis (and a channel axis of
        size one, if the model produces one) is removed. Callers are expected
        to threshold the scores themselves.
    """
    batch = transform(frame).unsqueeze(0).to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = model(batch)

    # Map raw outputs to (0, 1) scores and hand back a plain NumPy array.
    scores = torch.sigmoid(logits)
    return scores.squeeze().cpu().numpy()

def analyze_video(video_path):
    """Analyze a delivery video and produce an LBW decision.

    Every frame returned by ``extract_frames`` is segmented with
    ``detect_objects_with_model``. The centroid of the largest ball contour is
    recorded per frame, and scanning stops at the first frame in which the
    ball and pad masks overlap by more than 10 pixels. The recorded positions
    are then handed to ``predict_trajectory``; the decision is ``"OUT"`` only
    when the predicted path crosses the stump zone and a pad impact was seen.

    Args:
        video_path: Location of the video, forwarded to ``extract_frames``.

    Returns:
        A ``(result_path, decision)`` tuple, where ``result_path`` is whatever
        ``draw_visuals`` returns and ``decision`` is ``"OUT"`` or
        ``"NOT OUT"``.
    """
    frames = extract_frames(video_path)

    ball_positions = []  # (frame index, centroid x, centroid y) per detection
    impact_frame_idx = None
    impact_zone = "unknown"

    for i, frame in enumerate(frames):
        mask = detect_objects_with_model(frame)

        # The detector squeezes singleton dimensions, so a single-channel
        # output arrives as a 2-D mask with no channel axis. In that case the
        # whole mask is the ball mask; indexing mask[0] would pick only the
        # first image row. With a channel axis, channel 0 is the ball and
        # channel 1 is the pad.
        has_channels = mask.ndim > 2
        ball_mask = (mask[0] if has_channels else mask) > 0.5
        pad_mask = mask[1] > 0.5 if has_channels else None

        # Detect ball center. Taking [-2] picks the contour list under both
        # the 2-tuple and the 3-tuple return signatures of findContours.
        contours = cv2.findContours(
            ball_mask.astype(np.uint8),
            cv2.RETR_EXTERNAL,
            cv2.CHAIN_APPROX_SIMPLE,
        )[-2]
        if contours:
            largest = max(contours, key=cv2.contourArea)
            M = cv2.moments(largest)
            # m00 is the contour area; skip degenerate contours to avoid a
            # division by zero when computing the centroid.
            if M['m00'] != 0:
                cx = int(M['m10'] / M['m00'])
                cy = int(M['m01'] / M['m00'])
                ball_positions.append((i, cx, cy))

        # Detect pad hit: the ball counts as striking the pad once the two
        # masks share more than a handful of pixels.
        if pad_mask is not None and contours:
            overlap = np.logical_and(ball_mask, pad_mask).sum()
            if overlap > 10:  # simple overlap threshold
                impact_frame_idx = i
                impact_zone = "pad"
                break

    # Trajectory prediction always runs, even when no ball was detected.
    # NOTE(review): confirm predict_trajectory copes with an empty list.
    trajectory = predict_trajectory(ball_positions)

    # Predict outcome
    decision = "OUT" if trajectory_hits_stumps(trajectory) and impact_zone == "pad" else "NOT OUT"

    # Visualize
    result_path = draw_visuals(frames, ball_positions, trajectory, impact_frame_idx, decision)

    return result_path, decision


def trajectory_hits_stumps(trajectory, *, stump_x_range=(300, 340), max_y=480):
    """Return True if any predicted point falls inside the stump zone.

    This is a simple rule-based check in pixel coordinates. A point counts as
    a hit when its x lies strictly between the two bounds of
    ``stump_x_range`` and its y is strictly below ``max_y``.

    Args:
        trajectory: Iterable of ``(x, y)`` pairs.
        stump_x_range: ``(x_min, x_max)`` exclusive horizontal bounds of the
            stumps. Defaults to the placeholder 300-340 px window.
        max_y: Exclusive upper bound on y for a point to count. Defaults to
            480.

    Returns:
        ``True`` if at least one point satisfies both conditions, otherwise
        ``False`` (including for an empty trajectory).
    """
    x_min, x_max = stump_x_range
    return any(x_min < x < x_max and y < max_y for x, y in trajectory)