File size: 3,645 Bytes
ce2b58f
5567efd
9538100
ce2b58f
8386bf1
9538100
ce2b58f
 
8386bf1
 
ce2b58f
 
 
 
 
9538100
ce2b58f
 
 
 
 
 
 
 
4cfdbcf
 
 
 
 
 
 
 
0bd515d
 
 
ce2b58f
 
b9a2f92
ce2b58f
9538100
 
 
 
ce2b58f
 
 
 
 
 
 
 
 
 
5567efd
 
ce2b58f
a79c4c9
ce2b58f
 
a79c4c9
ce2b58f
e999761
 
 
ce2b58f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import logging
from PIL import Image, ImageDraw
from huggingface_hub import hf_hub_download
from ultralytics import YOLO
import shutil



# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Import-time side effect: remove the local weights directory (the same path
# ObjectDetector passes to hf_hub_download as cache_dir). ignore_errors=True
# makes this a no-op when the directory is missing or cannot be removed.
shutil.rmtree("models/detection/weights", ignore_errors=True)
class ObjectDetector:
    """
    Generalized Object Detection Wrapper for YOLOv5, YOLOv8, and future variants.

    On construction the requested weights are downloaded from the Hugging Face
    Hub and loaded through the Ultralytics ``YOLO`` class.
    """

    # Alternate names accepted for a model key (looked up after lower-casing).
    _ALIASES = {
        "yolov5n-seg": "yolov5n",
        "yolov5s-seg": "yolov5s",
    }

    # Canonical model key -> (Hugging Face repo id, weight filename in that repo).
    # NOTE(review): the "rtdetr" entry is a .pth checkpoint that is still loaded
    # through YOLO() below -- confirm that this combination actually loads.
    _REPOS = {
        "yolov5n": ("ultralytics/yolov5", "yolov5n.pt"),
        "yolov5s": ("ultralytics/yolov5", "yolov5s.pt"),
        "yolov8n": ("ultralytics/yolov8", "yolov8n.pt"),
        "yolov8s": ("ultralytics/yolov8", "yolov8s.pt"),
        "yolov8l": ("ultralytics/yolov8", "yolov8l.pt"),
        "yolov11b": ("Ultralytics/YOLO11", "yolov11b.pt"),
        "rtdetr": ("IDEA-Research/RT-DETR", "rtdetr_r50vd_detr.pth"),
    }

    # Directory handed to hf_hub_download as its cache location.
    _CACHE_DIR = "models/detection/weights"

    def __init__(self, model_key: str = "yolov5n", device: str = "cpu"):
        """
        Initialize the Object Detection model.

        Args:
            model_key (str): Model identifier (case-insensitive). Must be a key
                of ``_REPOS`` or an alias listed in ``_ALIASES``.
            device (str): Inference device ("cpu" or "cuda"). It is forwarded
                to the model on every ``predict`` call.

        Raises:
            ValueError: If ``model_key`` does not resolve to a supported model.
        """
        # Lower-casing is what makes the key case-insensitive; the alias table
        # then folds alternate names onto their canonical key.
        normalized = model_key.lower()
        model_key = self._ALIASES.get(normalized, normalized)

        if model_key not in self._REPOS:
            raise ValueError(f"Unsupported model_key: {model_key}")

        repo_id, filename = self._REPOS[model_key]

        weights_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            cache_dir=self._CACHE_DIR,
            force_download=True,  # always re-fetch instead of reusing the cache
        )

        self.device = device
        logger.info("Loading weights from: %s", weights_path)
        self.model = YOLO(weights_path)
        logger.info("Model object type: %s", type(self.model))

    def predict(self, image: Image.Image) -> list:
        """
        Run object detection.

        Args:
            image (PIL.Image.Image): Input image.

        Returns:
            List[Dict]: One dict per detected box with keys ``"class_name"``
            (str), ``"confidence"`` (float) and ``"bbox"`` (list taken from the
            box's ``xyxy`` coordinates).
        """
        logger.info("Running object detection")
        # Forward the configured device so the constructor argument is honoured
        # rather than silently falling back to the library's default choice.
        results = self.model(image, device=self.device)
        detections = []
        for result in results:
            for box in result.boxes:
                detections.append({
                    "class_name": result.names[int(box.cls)],
                    "confidence": float(box.conf),
                    "bbox": box.xyxy[0].tolist(),
                })
        logger.info("Detected %d objects", len(detections))
        return detections

    def draw(self, image: Image.Image, detections, alpha: float = 0.5) -> Image.Image:
        """
        Draw bounding boxes on image.

        The boxes and labels are painted on a copy of ``image``; the input
        image itself is not modified. The copy is then blended with the
        original.

        Args:
            image (PIL.Image.Image): Input image.
            detections (List[Dict]): Detection results, as returned by
                ``predict`` (each needs ``"bbox"``, ``"class_name"`` and
                ``"confidence"``).
            alpha (float): Blend strength passed to ``Image.blend``; 0 returns
                the original image, 1 returns the fully annotated copy.

        Returns:
            PIL.Image.Image: Image with bounding boxes drawn.
        """
        overlay = image.copy()
        painter = ImageDraw.Draw(overlay)
        for det in detections:
            bbox = det["bbox"]
            label = f'{det["class_name"]} {det["confidence"]:.2f}'
            painter.rectangle(bbox, outline="red", width=2)
            # Anchor the label at the first two bbox coordinates.
            painter.text((bbox[0], bbox[1]), label, fill="red")
        return Image.blend(image, overlay, alpha)