import logging
import shutil

from huggingface_hub import hf_hub_download
from PIL import Image, ImageDraw
from ultralytics import YOLO

logger = logging.getLogger(__name__)

# Directory used as the Hugging Face cache for detection weights.
_WEIGHTS_DIR = "models/detection/weights"

# NOTE: import-time side effect, kept as in the original: the weights
# directory is removed every time this module is imported, so weights are
# always downloaded fresh by ObjectDetector.__init__.
shutil.rmtree(_WEIGHTS_DIR, ignore_errors=True)


class ObjectDetector:
    """
    Generalized Object Detection Wrapper for YOLOv5, YOLOv8, and future variants.

    Weights are fetched from the Hugging Face Hub on construction and loaded
    through ``ultralytics.YOLO``.
    """

    # Alternative spellings mapped onto a canonical key of ``_REPO_MAP``.
    # Keys that are already canonical need no entry here.
    _ALIAS_MAP = {
        "yolov5n-seg": "yolov5n",
        "yolov5s-seg": "yolov5s",
    }

    # canonical model key -> (Hugging Face repo id, weights filename)
    # NOTE(review): repo ids and filenames are carried over unchanged; confirm
    # that each of them actually exists on the Hub.
    _REPO_MAP = {
        "yolov5n": ("ultralytics/yolov5", "yolov5n.pt"),
        "yolov5s": ("ultralytics/yolov5", "yolov5s.pt"),
        "yolov8n": ("ultralytics/yolov8", "yolov8n.pt"),
        "yolov8s": ("ultralytics/yolov8", "yolov8s.pt"),
        "yolov8l": ("ultralytics/yolov8", "yolov8l.pt"),
        "yolov11b": ("Ultralytics/YOLO11", "yolov11b.pt"),
        "rtdetr": ("IDEA-Research/RT-DETR", "rtdetr_r50vd_detr.pth"),
    }

    def __init__(self, model_key="yolov5n", device="cpu"):
        """
        Initialize the Object Detection model.

        Downloads the weights for ``model_key`` into the local weights
        directory (always re-downloading) and loads them.

        Args:
            model_key (str): Model identifier, case-insensitive. Must be a key
                of ``_REPO_MAP`` or an alias listed in ``_ALIAS_MAP``.
            device (str): Inference device ("cpu" or "cuda"). Passed to the
                model on every ``predict`` call.

        Raises:
            ValueError: If ``model_key`` does not resolve to a supported model.
        """
        # Lower-case first so that lookups are case-insensitive.
        raw_key = model_key.lower()
        model_key = self._ALIAS_MAP.get(raw_key, raw_key)

        if model_key not in self._REPO_MAP:
            raise ValueError(f"Unsupported model_key: {model_key}")

        repo_id, filename = self._REPO_MAP[model_key]

        weights_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            cache_dir=_WEIGHTS_DIR,
            force_download=True,  # always fetch a fresh copy, ignoring the cache
        )

        self.device = device
        logger.info("Loading weights from: %s", weights_path)
        # NOTE(review): every checkpoint, including the RT-DETR ``.pth`` file,
        # is loaded through the YOLO class; confirm that this is supported.
        self.model = YOLO(weights_path)
        logger.debug("Model object type: %s", type(self.model))

    def predict(self, image: Image.Image) -> list:
        """
        Run object detection.

        Args:
            image (PIL.Image.Image): Input image.

        Returns:
            List[Dict]: One dict per detected object with the keys
            ``class_name``, ``confidence`` and ``bbox`` (taken from the
            box's ``xyxy`` coordinates).
        """
        logger.info("Running object detection")
        # Honour the device requested at construction time; previously the
        # value was stored but never used.
        results = self.model(image, device=self.device)

        detections = []
        for result in results:
            if result.boxes is None:
                # Nothing to report for results that carry no box output.
                continue
            for box in result.boxes:
                detections.append({
                    "class_name": result.names[int(box.cls)],
                    "confidence": float(box.conf),
                    "bbox": box.xyxy[0].tolist(),
                })

        logger.info("Detected %d objects", len(detections))
        return detections

    def draw(self, image: Image.Image, detections, alpha=0.5) -> Image.Image:
        """
        Draw bounding boxes on image.

        The boxes and labels are drawn on a copy of ``image`` which is then
        blended with the untouched original, so the input is not modified.

        Args:
            image (PIL.Image.Image): Input image.
            detections (List[Dict]): Detection results as returned by
                ``predict`` (``bbox``, ``class_name`` and ``confidence`` keys
                are read).
            alpha (float): Blend strength passed to ``Image.blend``.

        Returns:
            PIL.Image.Image: Image with bounding boxes drawn.
        """
        overlay = image.copy()
        draw = ImageDraw.Draw(overlay)

        for det in detections:
            bbox = det["bbox"]
            label = f'{det["class_name"]} {det["confidence"]:.2f}'
            draw.rectangle(bbox, outline="red", width=2)
            # The label is anchored at the first two bbox coordinates.
            draw.text((bbox[0], bbox[1]), label, fill="red")

        return Image.blend(image, overlay, alpha)