File size: 3,477 Bytes
75fbee5
 
e29184b
75fbee5
d71beb6
5ec100e
 
ce2b58f
487b7f6
bbc95d9
1044803
 
487b7f6
 
b0c7a24
 
 
 
 
cdbafa3
487b7f6
1044803
487b7f6
 
 
 
 
 
 
 
 
 
93d071e
8cfece3
5caf904
 
 
5ec100e
 
 
5caf904
8cfece3
5caf904
 
8cfece3
5caf904
 
8cfece3
178b1f7
ce2b58f
5caf904
b0c7a24
bbc95d9
5052f14
93d071e
5052f14
 
ce2b58f
 
 
 
 
 
 
 
 
 
5ec100e
55317e4
5ec100e
 
e29184b
7374951
e29184b
99328a4
e29184b
 
 
5ec100e
 
 
e29184b
99328a4
 
e29184b
99328a4
e29184b
 
 
7374951
 
 
 
99328a4
7374951
 
e29184b
7374951
 
99328a4
e29184b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os
import logging
from PIL import Image, ImageDraw, ImageFont
from huggingface_hub import hf_hub_download

logger = logging.getLogger(__name__)

class ObjectDetector:
    """Thin wrapper around an Ultralytics YOLO model whose weights come from the Hugging Face Hub.

    The weights file is resolved (downloaded if not already cached) in the
    constructor; the model itself is loaded lazily by ``load_model``, which
    ``predict`` calls automatically.
    """

    # Supported model keys -> (Hugging Face repo id, weights filename).
    # NOTE(review): YOLO11 checkpoints are usually published as
    # "yolo11{n,s,m,l,x}.pt" — confirm "yolov11b.pt" exists in Ultralytics/YOLO11.
    _HF_WEIGHTS = {
        "yolov8n": ("ultralytics/yolov8", "yolov8n.pt"),
        "yolov8s": ("ultralytics/yolov8", "yolov8s.pt"),
        "yolov8l": ("ultralytics/yolov8", "yolov8l.pt"),
        "yolov11b": ("Ultralytics/YOLO11", "yolov11b.pt"),
    }

    def __init__(self, model_key="yolov8n", device="cpu",
                 cache_dir="models/detection/weights"):
        """Resolve the weights file for ``model_key``.

        Args:
            model_key: A key of ``_HF_WEIGHTS``; case-insensitive, and any
                ".pt" in it is stripped.
            device: "cpu", "cuda" or "cuda:<index>"; applied in ``load_model``.
            cache_dir: Hugging Face download cache directory. Relative paths
                are resolved against the current working directory.

        Raises:
            ValueError: If ``model_key`` is not supported.
        """
        self.device = device
        self.model = None
        self.model_key = model_key.lower().replace(".pt", "")

        if self.model_key not in self._HF_WEIGHTS:
            raise ValueError(f"Unsupported model key: {self.model_key}")

        repo_id, filename = self._HF_WEIGHTS[self.model_key]
        # Downloads only if the file is not already in the cache.
        self.weights_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            cache_dir=cache_dir,
            force_download=False,
        )

    def load_model(self):
        """Load the YOLO model if it is not loaded yet, then return ``self``.

        After the first successful call this is a no-op (and logs nothing),
        so it is cheap to call before every prediction.
        """
        if self.model is not None:
            return self

        logger.info("Loading model from path: %s", self.weights_path)

        # Heavy imports are deferred so constructing the detector stays cheap.
        import torch
        from ultralytics import YOLO

        if self.device == "cpu":
            # NOTE(review): this hides GPUs from the *whole process* and only
            # takes effect if CUDA has not been initialised yet.
            os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

        self.model = YOLO(self.weights_path)

        # Move to a CUDA device only when one is actually available.
        if str(self.device).startswith("cuda"):
            if torch.cuda.is_available():
                self.model.to(self.device)
            else:
                logger.warning(
                    "Device %r requested but CUDA is not available; "
                    "leaving the model on its default device.",
                    self.device,
                )

        return self

    def predict(self, image: Image.Image, conf_threshold=0.25):
        """Run object detection on ``image``.

        Args:
            image: Input PIL image.
            conf_threshold: Minimum confidence for a box to be returned;
                forwarded to the model as its ``conf`` argument.

        Returns:
            A list of dicts with keys "class_name" (str), "confidence" (float)
            and "bbox" (``[x1, y1, x2, y2]`` list of floats).

        Raises:
            RuntimeError: If the model could not be loaded.
        """
        self.load_model()

        if self.model is None:
            raise RuntimeError("YOLO model not loaded. Call load_model() first.")

        # Threshold inside the model: post-filtering could never recover
        # boxes below the model's own default cut-off.
        results = self.model(image, conf=conf_threshold)

        detections = []
        for r in results:
            if r.boxes is None:  # no box output for this result
                continue
            for box in r.boxes:
                detections.append({
                    "class_name": r.names[int(box.cls)],
                    "confidence": float(box.conf),
                    "bbox": box.xyxy[0].tolist(),
                })
        return detections

    def draw(self, image: Image.Image, detections, alpha=0.5):
        """Return a copy of ``image`` annotated with ``detections``.

        Red boxes and labels are drawn on a copy, which is then blended with
        the (possibly mode-converted) input via ``Image.blend(image, overlay,
        alpha)``: 0 keeps the input, 1 shows the annotations fully opaque.

        Args:
            image: Source PIL image; not modified. Modes other than RGB/RGBA
                are converted to RGB first so the red annotations are visible
                and ``Image.blend`` accepts the image.
            detections: Dicts shaped like the output of ``predict``.
            alpha: Blend factor for ``Image.blend``.

        Returns:
            The blended PIL image.
        """
        if image.mode not in ("RGB", "RGBA"):
            image = image.convert("RGB")

        overlay = image.copy()
        draw = ImageDraw.Draw(overlay)

        try:
            font = ImageFont.truetype("arial.ttf", 16)
        except (OSError, ImportError):
            # Font file missing or no FreeType support: use Pillow's default font.
            font = ImageFont.load_default()

        for det in detections:
            bbox = det["bbox"]
            x1, y1, x2, y2 = bbox[0], bbox[1], bbox[2], bbox[3]
            label = f'{det["class_name"]} {det["confidence"]:.2f}'

            # 3-px box, growing outward from the detection edges.
            for offset in range(3):
                draw.rectangle(
                    [x1 - offset, y1 - offset, x2 + offset, y2 + offset],
                    outline="red",
                )

            # Measure at the origin; (left, top) is the glyph offset that
            # draw.text() applies relative to its anchor point.
            left, top, right, bottom = draw.textbbox((0, 0), label, font=font)
            text_width = right - left
            text_height = bottom - top

            # Label sits above the box, or just inside it when the box
            # touches the top edge and there is no room above.
            label_top = y1 - text_height
            if label_top < 0:
                label_top = y1

            draw.rectangle(
                [x1, label_top, x1 + text_width + 4, label_top + text_height],
                fill="red",
            )
            # Undo the glyph offset so the text lands inside its background.
            draw.text(
                (x1 + 2 - left, label_top - top), label, fill="white", font=font
            )

        return Image.blend(image, overlay, alpha)