# Spaces: Running on Zero
import logging
import shutil

from huggingface_hub import hf_hub_download
from PIL import Image, ImageDraw
from ultralytics import YOLO

# Logger used by ObjectDetector below.
logger = logging.getLogger(__name__)

# Delete any previously cached detection weights as soon as this module is
# imported. ignore_errors=True turns this into a no-op when the directory is
# missing (e.g. on a fresh start).
# NOTE(review): this is an import-time side effect that removes files relative
# to the current working directory; consider moving it behind an explicit call.
shutil.rmtree("models/detection/weights", ignore_errors=True)
class ObjectDetector:
    """
    Generalized Object Detection Wrapper for YOLOv5, YOLOv8, and future variants.

    Weights are downloaded from the Hugging Face Hub with ``hf_hub_download``
    and loaded through ``ultralytics.YOLO``.
    """

    # Aliases resolved (after lower-casing) before the repository lookup.
    # The "-seg" names map onto the plain detection checkpoints, so they load
    # detection weights, not segmentation weights. Keys not listed here are
    # looked up in _REPO_MAP unchanged.
    _ALIAS_MAP = {
        "yolov5n-seg": "yolov5n",
        "yolov5s-seg": "yolov5s",
    }

    # Canonical model key -> (Hugging Face repo id, weights filename).
    # NOTE(review): these entries are not verified against the Hub. Ultralytics
    # appears to name YOLO11 checkpoints "yolo11{n,s,m,l,x}.pt" (so
    # "yolov11b.pt" may not exist), legacy "yolov5*.pt" files may not load in
    # ultralytics.YOLO, and "rtdetr_r50vd_detr.pth" looks like an upstream
    # RT-DETR checkpoint rather than an Ultralytics one -- confirm before use.
    _REPO_MAP = {
        "yolov5n": ("ultralytics/yolov5", "yolov5n.pt"),
        "yolov5s": ("ultralytics/yolov5", "yolov5s.pt"),
        "yolov8n": ("ultralytics/yolov8", "yolov8n.pt"),
        "yolov8s": ("ultralytics/yolov8", "yolov8s.pt"),
        "yolov8l": ("ultralytics/yolov8", "yolov8l.pt"),
        "yolov11b": ("Ultralytics/YOLO11", "yolov11b.pt"),
        "rtdetr": ("IDEA-Research/RT-DETR", "rtdetr_r50vd_detr.pth"),
    }

    def __init__(self, model_key="yolov5n", device="cpu",
                 weights_dir="models/detection/weights", force_download=True):
        """
        Initialize the Object Detection model.

        Args:
            model_key (str): Model identifier, case-insensitive. Either a key
                of ``_REPO_MAP`` or an alias from ``_ALIAS_MAP``.
            device (str | None): Inference device (e.g. "cpu" or "cuda"),
                forwarded to the model on every ``predict`` call. Pass None to
                let Ultralytics choose the device automatically.
            weights_dir (str): Cache directory passed to ``hf_hub_download``.
            force_download (bool): Re-download the weights even when a cached
                copy exists. Defaults to True, matching the previous behavior.

        Raises:
            ValueError: If ``model_key`` does not name a supported model.
        """
        # Lower-casing makes model_key case-insensitive.
        raw_key = model_key.strip().lower()
        key = self._ALIAS_MAP.get(raw_key, raw_key)
        if key not in self._REPO_MAP:
            supported = ", ".join(sorted(set(self._REPO_MAP) | set(self._ALIAS_MAP)))
            raise ValueError(
                f"Unsupported model_key: {key}. Supported keys: {supported}"
            )

        repo_id, filename = self._REPO_MAP[key]
        weights_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            cache_dir=weights_dir,
            force_download=force_download,
        )

        self.model_key = key
        self.device = device
        logger.info("Loading weights from: %s", weights_path)
        self.model = YOLO(weights_path)
        logger.debug("Loaded model object of type %s", type(self.model))

    def predict(self, image: Image.Image):
        """
        Run object detection.

        Args:
            image (PIL.Image.Image): Input image.

        Returns:
            List[Dict]: One dict per detected object with keys "class_name"
            (str), "confidence" (float) and "bbox" ([x1, y1, x2, y2] floats,
            taken from the model's ``xyxy`` box coordinates).
        """
        logger.info("Running object detection")
        # Forward the configured device. Without it Ultralytics auto-selects
        # hardware and the constructor's ``device`` argument has no effect.
        # None keeps that auto-selection.
        overrides = {} if self.device is None else {"device": self.device}
        results = self.model(image, **overrides)

        detections = []
        for result in results:
            boxes = result.boxes
            if boxes is None:  # result carries no boxes; nothing to report
                continue
            names = result.names
            for box in boxes:
                detections.append({
                    "class_name": names[int(box.cls)],
                    "confidence": float(box.conf),
                    "bbox": box.xyxy[0].tolist(),
                })
        logger.info("Detected %d objects", len(detections))
        return detections

    def draw(self, image: Image.Image, detections, alpha=0.5):
        """
        Draw bounding boxes and labels on a copy of the image.

        Args:
            image (PIL.Image.Image): Input image; it is not modified.
            detections (List[Dict]): Detection results as returned by
                ``predict`` ("class_name", "confidence", "bbox" keys).
            alpha (float): Weight of the annotated copy in ``Image.blend``:
                0.0 returns the original pixels, 1.0 fully opaque annotations.

        Returns:
            PIL.Image.Image: New image with the annotations blended in.
            Inputs that are not RGB/RGBA are returned as RGB.
        """
        # Image.blend rejects palette ("P"), bilevel ("1") and non-8-bit modes.
        # A grayscale image would also render the red annotations as gray.
        # Normalize to RGB so blending works and the colors are preserved.
        if image.mode not in ("RGB", "RGBA"):
            image = image.convert("RGB")

        overlay = image.copy()
        painter = ImageDraw.Draw(overlay)
        for det in detections:
            bbox = det["bbox"]
            label = f'{det["class_name"]} {det["confidence"]:.2f}'
            painter.rectangle(bbox, outline="red", width=2)
            painter.text((bbox[0], bbox[1]), label, fill="red")
        return Image.blend(image, overlay, alpha)