lbw_drs_app_new / lbw_detector.py
dschandra's picture
Update lbw_detector.py
5eaa37a verified
raw
history blame
846 Bytes
# lbw_detector.py
import inspect

import torch
import numpy as np
from torchvision import transforms
import cv2

from utils import extract_frames
from trajectory_predictor import predict_trajectory
from visualizer import draw_visuals
# Load the custom LBW model.
model_path = "models/lbw_drs_unet_model.pth"  # relative to the working directory
device = "cpu"  # Hugging Face Free Tier

# ``model.eval()`` is called on the loaded object below, so the checkpoint is
# expected to be a whole pickled ``nn.Module``, not a bare state_dict.
# PyTorch >= 2.6 defaults ``torch.load(..., weights_only=True)``, which refuses
# to unpickle arbitrary module classes, so full unpickling has to be requested
# explicitly. PyTorch < 1.13 has no ``weights_only`` argument at all, hence the
# signature check instead of passing it unconditionally.
# SECURITY: weights_only=False executes pickle code; only ship trusted
# checkpoint files at ``model_path``.
_load_kwargs = {"map_location": device}
if "weights_only" in inspect.signature(torch.load).parameters:
    _load_kwargs["weights_only"] = False
model = torch.load(model_path, **_load_kwargs)
model.eval()  # switch layers such as dropout/batch-norm to evaluation behaviour
# Preprocessing pipeline applied to every frame before inference: torchvision's
# ToTensor converts an HxWxC image (ndarray or PIL) into a CxHxW float tensor.
transform = transforms.Compose([transforms.ToTensor()])
def detect_objects_with_model(frame, threshold=None):
    """Run the custom segmentation model on a single frame.

    Args:
        frame: An image accepted by the module-level ``transform`` (torchvision
            ``ToTensor``), e.g. an HxWxC ``numpy.ndarray`` or a PIL image.
            NOTE(review): frames read with OpenCV are BGR; whether the model
            was trained on BGR or RGB is not visible here -- TODO confirm.
        threshold: Optional probability cut-off. When ``None`` (the default),
            the per-pixel sigmoid probabilities are returned unchanged. When
            given, pixels with probability strictly greater than ``threshold``
            become 1 and all others 0.

    Returns:
        A ``numpy.ndarray`` produced by ``squeeze()`` on the model output, so
        every size-1 dimension (the batch dimension and, presumably, a single
        output channel) is dropped. It holds sigmoid probabilities in [0, 1]
        when ``threshold`` is ``None``. Otherwise it is a ``uint8`` 0/1 mask.
        NOTE(review): this is presumably a ball/pad/stump segmentation; the
        channel meaning is not visible here.
    """
    # transform() yields CxHxW; unsqueeze(0) adds the batch dimension.
    input_tensor = transform(frame).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(input_tensor)
    # Sigmoid maps the (presumably logit) output to per-pixel probabilities.
    # The result is NOT binary unless a threshold is applied below.
    mask = torch.sigmoid(output).squeeze().cpu().numpy()
    if threshold is None:
        return mask
    return (mask > threshold).astype(np.uint8)