Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import logging | |
from PIL import Image, ImageDraw, ImageFont | |
from huggingface_hub import hf_hub_download | |
# Per-module logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
class ObjectDetector:
    """YOLO object detector whose weights are fetched from the Hugging Face Hub.

    The weights file is resolved with ``hf_hub_download`` when the object is
    constructed. The YOLO model itself is loaded lazily by :meth:`load_model`,
    which :meth:`predict` calls automatically.

    Args:
        model_key: One of the keys of ``_HF_WEIGHTS``. The key is lower-cased
            and any ``".pt"`` substring is removed, so ``"YOLOv8n.pt"`` is
            accepted.
        device: ``"cpu"`` or ``"cuda"``. The model is moved to CUDA only when
            ``"cuda"`` is requested and ``torch.cuda.is_available()`` is true.

    Raises:
        ValueError: If ``model_key`` is not a supported model.
    """

    # model_key -> (Hub repo id, weights filename inside that repo).
    # NOTE(review): Ultralytics' YOLO11 checkpoints appear to be named
    # "yolo11<size>.pt" with no "b" size; confirm that "yolov11b.pt" exists in
    # the Ultralytics/YOLO11 repo before relying on this entry.
    _HF_WEIGHTS = {
        "yolov8n": ("ultralytics/yolov8", "yolov8n.pt"),
        "yolov8s": ("ultralytics/yolov8", "yolov8s.pt"),
        "yolov8l": ("ultralytics/yolov8", "yolov8l.pt"),
        "yolov11b": ("Ultralytics/YOLO11", "yolov11b.pt"),
    }

    def __init__(self, model_key="yolov8n", device="cpu"):
        self.device = device
        self.model = None  # populated lazily by load_model()
        self.model_key = model_key.lower().replace(".pt", "")

        if self.model_key not in self._HF_WEIGHTS:
            raise ValueError(f"Unsupported model key: {self.model_key}")
        repo_id, filename = self._HF_WEIGHTS[self.model_key]

        # Network I/O at construction time. With force_download=False,
        # hf_hub_download reuses a copy already present under cache_dir.
        self.weights_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            cache_dir="models/detection/weights",
            force_download=False,
        )

    def load_model(self):
        """Load the YOLO weights if they are not already loaded.

        Returns:
            ``self``, so the call can be chained.
        """
        if self.model is None:
            logger.info("Loading model from path: %s", self.weights_path)

            if self.device == "cpu":
                # Hide GPUs before torch/ultralytics are imported. This is
                # process-wide: it also affects any other CUDA user in this
                # interpreter. It presumably has no effect if CUDA was already
                # initialised elsewhere in the process — confirm if mixing
                # CPU and CUDA detectors.
                os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

            # Deferred imports: heavy, and only needed once a model is used.
            import torch
            from ultralytics import YOLO

            self.model = YOLO(self.weights_path)

            if self.device == "cuda":
                if torch.cuda.is_available():
                    self.model.to("cuda")
                else:
                    logger.warning(
                        "CUDA requested but torch.cuda.is_available() is False; "
                        "model not moved to CUDA"
                    )
        return self

    def predict(self, image: Image.Image, conf_threshold=0.25):
        """Run object detection on ``image``.

        Args:
            image: PIL image to run detection on.
            conf_threshold: Minimum confidence for a detection to be kept.
                This value is passed to the Ultralytics predictor as ``conf``.

        Returns:
            A list of dicts, each with keys ``"class_name"`` (str),
            ``"confidence"`` (float) and ``"bbox"`` (the ``[x1, y1, x2, y2]``
            list taken from ``box.xyxy``).

        Raises:
            RuntimeError: If the model is still unloaded after ``load_model()``.
        """
        self.load_model()
        if self.model is None:
            raise RuntimeError("YOLO model not loaded. Call load_model() first.")

        # Forward the threshold so low-confidence boxes are dropped by the
        # predictor; without this, the argument had no effect.
        results = self.model(image, conf=conf_threshold)

        detections = []
        for result in results:
            if result.boxes is None:  # defensive: no box output for this result
                continue
            for box in result.boxes:
                detections.append({
                    "class_name": result.names[int(box.cls)],
                    "confidence": float(box.conf),
                    "bbox": box.xyxy[0].tolist(),
                })
        return detections

    def draw(self, image: Image.Image, detections, alpha=0.5):
        """Return a new image with ``detections`` drawn on top of ``image``.

        Boxes and labels are drawn onto a copy of ``image``. That copy is then
        combined with the original via ``Image.blend``, so ``alpha`` sets the
        annotation opacity: 0 gives the original image, 1 gives fully opaque
        annotations.

        Args:
            image: Source image. It is not modified.
            detections: Iterable of dicts in the format returned by
                :meth:`predict` (``"bbox"``, ``"class_name"``, ``"confidence"``).
            alpha: Blend factor passed to ``Image.blend``.
        """
        overlay = image.copy()
        draw = ImageDraw.Draw(overlay)

        try:
            font = ImageFont.truetype("arial.ttf", 16)
        except (OSError, ImportError):
            # OSError: the font file is not present or not readable.
            # ImportError: defensive, for Pillow builds without FreeType.
            font = ImageFont.load_default()

        for det in detections:
            x1, y1, x2, y2 = det["bbox"][:4]
            label = f'{det["class_name"]} {det["confidence"]:.2f}'

            # Emulate a 3 px outline with nested 1 px rectangles that grow
            # outward from the box edge.
            for offset in range(3):
                draw.rectangle(
                    [x1 - offset, y1 - offset, x2 + offset, y2 + offset],
                    outline="red",
                )

            text_bbox = draw.textbbox((x1, y1), label, font=font)
            text_width = text_bbox[2] - text_bbox[0]
            text_height = text_bbox[3] - text_bbox[1]

            # Put the label just above the box. If that would start above the
            # image's top edge, put it just inside the box instead so it stays
            # visible.
            label_top = y1 - text_height
            if label_top < 0:
                label_top = y1

            draw.rectangle(
                [x1, label_top, x1 + text_width + 4, label_top + text_height],
                fill="red",
            )
            draw.text((x1 + 2, label_top), label, fill="white", font=font)

        return Image.blend(image, overlay, alpha)