File size: 846 Bytes
5eaa37a
 
3bf3b3c
5eaa37a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# lbw_detector.py
import torch
import numpy as np
from torchvision import transforms
import cv2
from utils import extract_frames
from trajectory_predictor import predict_trajectory
from visualizer import draw_visuals

# Load the custom LBW segmentation model once, at import time.
model_path = "models/lbw_drs_unet_model.pth"
device = "cpu"  # Hugging Face Free Tier

# The code below treats the loaded object as a complete nn.Module (it calls
# .eval() on it directly), not as a state_dict. Since PyTorch 2.6, torch.load
# defaults to weights_only=True, which refuses to unpickle arbitrary classes,
# so the flag has to be disabled explicitly for this checkpoint format.
# SECURITY: weights_only=False executes arbitrary pickle code during loading —
# only ever point model_path at a checkpoint you produced or fully trust.
model = torch.load(model_path, map_location=device, weights_only=False)
model.eval()  # switch layers such as dropout / batch-norm to inference behaviour

# Per-frame preprocessing. For an HxWxC uint8 array, ToTensor yields a CxHxW
# float tensor scaled to [0, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
])

def detect_objects_with_model(frame, *, threshold=None):
    """Run the segmentation model on a single frame.

    Args:
        frame: Image accepted by ``transforms.ToTensor`` (e.g. an HxWxC uint8
            numpy array). NOTE(review): frames read with cv2 are BGR — confirm
            the model was trained on BGR rather than RGB input.
        threshold: Optional cut-off in [0, 1]. When ``None`` (the default,
            matching the original behaviour) the raw per-pixel sigmoid
            probabilities are returned. When set, pixels with probability
            ``>= threshold`` become 1 and all others 0.

    Returns:
        numpy.ndarray: the model output passed through a sigmoid, with every
        size-1 dimension squeezed out (the batch axis and, for a
        single-channel model, the channel axis). Values are float
        probabilities in [0, 1] — NOT a binary mask — unless ``threshold`` is
        given, in which case a uint8 0/1 mask is returned. Presumably the
        channels correspond to ball/pad/stump segmentation — TODO confirm
        against the training code.

    Raises:
        ValueError: if ``threshold`` is not within [0, 1].
    """
    if threshold is not None and not 0.0 <= threshold <= 1.0:
        raise ValueError(f"threshold must be in [0, 1], got {threshold!r}")

    # ToTensor gives (C, H, W); add a leading batch axis -> (1, C, H, W).
    input_tensor = transform(frame).unsqueeze(0).to(device)
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        output = model(input_tensor)

    # The output is treated as logits; sigmoid turns them into independent
    # per-pixel probabilities. squeeze() drops all size-1 dimensions.
    mask = torch.sigmoid(output).squeeze().cpu().numpy()
    if threshold is not None:
        # Binarise into a 0/1 uint8 mask (the dtype OpenCV routines expect).
        mask = (mask >= threshold).astype(np.uint8)
    return mask