Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,413 Bytes
75fbee5 d71beb6 5ec100e ce2b58f 487b7f6 bbc95d9 1044803 487b7f6 b0c7a24 cdbafa3 487b7f6 1044803 487b7f6 8cfece3 5caf904 5ec100e 5caf904 8cfece3 5caf904 8cfece3 5caf904 8cfece3 178b1f7 ce2b58f 5caf904 b0c7a24 bbc95d9 ce2b58f 5ec100e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import os
import logging
from PIL import Image, ImageDraw
from huggingface_hub import hf_hub_download
logger = logging.getLogger(__name__)
class ObjectDetector:
    """YOLO object detector whose weights are fetched from the Hugging Face Hub.

    The weights file is downloaded (or reused from the local cache) when the
    instance is constructed; the YOLO model itself is only built on the first
    call to :meth:`load_model` / :meth:`predict`, so ``torch`` and
    ``ultralytics`` are not imported until they are actually needed.
    """

    # Supported model keys -> (Hub repo id, weights filename).
    # NOTE(review): confirm that "yolov11b.pt" really exists in the
    # "Ultralytics/YOLO11" repo; the name is kept exactly as it was.
    _HF_WEIGHTS = {
        "yolov8n": ("ultralytics/yolov8", "yolov8n.pt"),
        "yolov8s": ("ultralytics/yolov8", "yolov8s.pt"),
        "yolov8l": ("ultralytics/yolov8", "yolov8l.pt"),
        "yolov11b": ("Ultralytics/YOLO11", "yolov11b.pt"),
    }

    def __init__(self, model_key="yolov8n", device="cpu"):
        """Resolve *model_key* to a weights file and download it if needed.

        Args:
            model_key: One of the supported keys; case-insensitive, and any
                ".pt" in the key is removed before lookup.
            device: "cpu" or "cuda"; consulted later by :meth:`load_model`.

        Raises:
            ValueError: If the normalized key is not a supported model.
        """
        self.device = device
        self.model = None
        self.model_key = model_key.lower().replace(".pt", "")

        if self.model_key not in self._HF_WEIGHTS:
            raise ValueError(f"Unsupported model key: {self.model_key}")

        repo_id, filename = self._HF_WEIGHTS[self.model_key]
        self.weights_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            cache_dir="models/detection/weights",
            force_download=False,
        )

    def load_model(self) -> "ObjectDetector":
        """Build the YOLO model on first use and return ``self``.

        Calling this again once ``self.model`` is set does nothing.
        """
        if self.model is not None:
            return self

        # Heavy imports are deferred so constructing a detector stays cheap.
        import torch
        from ultralytics import YOLO

        if self.device == "cpu":
            # Hide all GPUs from this process when CPU inference was requested.
            # This changes the environment for the whole process.
            os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

        self.model = YOLO(self.weights_path)

        # Only move to the GPU when it was requested and is actually usable.
        if self.device == "cuda" and torch.cuda.is_available():
            self.model.to("cuda")
        return self

    def predict(self, image: "Image.Image", conf_threshold=0.25):
        """Run detection on *image* and return the detections as a list.

        Args:
            image: The image to run the model on.
            conf_threshold: Minimum confidence; forwarded to the model as its
                ``conf`` inference argument.

        Returns:
            A list of dicts with keys ``"class_name"``, ``"confidence"`` and
            ``"bbox"`` (the first row of the box's ``xyxy`` as a list).
        """
        self.load_model()
        # Forward the threshold: previously it was accepted but never used,
        # so the model always ran with its own default confidence.
        results = self.model(image, conf=conf_threshold)
        return [
            {
                "class_name": result.names[int(box.cls)],
                "confidence": float(box.conf),
                "bbox": box.xyxy[0].tolist(),
            }
            for result in results
            for box in result.boxes
        ]

    def draw(self, image: "Image.Image", detections, alpha=0.5) -> "Image.Image":
        """Return a copy of *image* with the detections drawn and blended in.

        Boxes and labels are drawn in red on a copy of the image, and that
        copy is blended with the untouched original using *alpha*, so the
        input image itself is not modified.

        Args:
            image: The image the detections belong to.
            detections: Dicts as returned by :meth:`predict`.
            alpha: Blend factor passed to ``Image.blend``.
        """
        overlay = image.copy()
        canvas = ImageDraw.Draw(overlay)
        for det in detections:
            bbox = det["bbox"]
            label = f'{det["class_name"]} {det["confidence"]:.2f}'
            canvas.rectangle(bbox, outline="red", width=2)
            # Label is anchored at the first two bbox coordinates.
            canvas.text((bbox[0], bbox[1]), label, fill="red")
        return Image.blend(image, overlay, alpha)
|