UVIS / models /detection /detector.py
DurgaDeepak's picture
Update models/detection/detector.py
a79c4c9 verified
raw
history blame
3.65 kB
import logging
from PIL import Image, ImageDraw
from huggingface_hub import hf_hub_download
from ultralytics import YOLO
import shutil
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Import-time side effect: delete the local weights directory every time this
# module is imported. Any error while deleting (e.g. the directory does not
# exist yet) is ignored. ObjectDetector downloads into this same directory
# with force_download=True, so weights are always fetched again.
shutil.rmtree("models/detection/weights", ignore_errors=True)
class ObjectDetector:
    """
    Generalized Object Detection Wrapper for YOLOv5, YOLOv8, and future variants.

    Weights are downloaded from the Hugging Face Hub with ``hf_hub_download``
    and loaded through the Ultralytics ``YOLO`` class.
    """

    # Alias -> canonical model key. Keys that are not listed map to themselves.
    _ALIASES = {
        "yolov5n-seg": "yolov5n",
        "yolov5s-seg": "yolov5s",
    }

    # Canonical model key -> (Hub repo id, weights filename inside that repo).
    # NOTE(review): every entry is loaded through ``YOLO(...)``, including the
    # "rtdetr" ``.pth`` checkpoint -- confirm that this checkpoint and these
    # repo/file names actually resolve and load.
    _WEIGHTS = {
        "yolov5n": ("ultralytics/yolov5", "yolov5n.pt"),
        "yolov5s": ("ultralytics/yolov5", "yolov5s.pt"),
        "yolov8n": ("ultralytics/yolov8", "yolov8n.pt"),
        "yolov8s": ("ultralytics/yolov8", "yolov8s.pt"),
        "yolov8l": ("ultralytics/yolov8", "yolov8l.pt"),
        "yolov11b": ("Ultralytics/YOLO11", "yolov11b.pt"),
        "rtdetr": ("IDEA-Research/RT-DETR", "rtdetr_r50vd_detr.pth"),
    }

    def __init__(self, model_key="yolov5n", device="cpu"):
        """
        Initialize the Object Detection model.

        Args:
            model_key (str): Model identifier (case-insensitive). Must be one
                of the supported keys or a known alias of one.
            device (str): Inference device ("cpu" or "cuda").

        Raises:
            ValueError: If ``model_key`` does not resolve to a supported model.
        """
        # Lower-case first so that the lookup is case-insensitive.
        raw_key = model_key.lower()
        model_key = self._ALIASES.get(raw_key, raw_key)
        if model_key not in self._WEIGHTS:
            raise ValueError(f"Unsupported model_key: {model_key}")

        repo_id, filename = self._WEIGHTS[model_key]
        weights_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            cache_dir="models/detection/weights",
            # Always re-fetch instead of reusing a previously cached file.
            force_download=True,
        )

        self.device = device
        logger.info("Loading weights from: %s", weights_path)
        self.model = YOLO(weights_path)
        logger.debug("Model object type: %s", type(self.model))

    def predict(self, image: Image.Image):
        """
        Run object detection.

        Args:
            image (PIL.Image.Image): Input image.

        Returns:
            List[Dict]: List of detected objects with class name, confidence, and bbox.
        """
        logger.info("Running object detection")
        # Pass the configured device explicitly; it was previously stored on
        # the instance but never handed to the model.
        results = self.model(image, device=self.device)

        detections = []
        for result in results:
            # Skip results that carry no box output at all.
            if result.boxes is None:
                continue
            for box in result.boxes:
                detections.append({
                    "class_name": result.names[int(box.cls)],
                    "confidence": float(box.conf),
                    "bbox": box.xyxy[0].tolist(),
                })

        logger.info("Detected %d objects", len(detections))
        return detections

    def draw(self, image: Image.Image, detections, alpha=0.5):
        """
        Draw bounding boxes on image.

        The boxes and labels are painted on a copy of ``image``, which is then
        blended with the untouched original; ``image`` itself is not modified.

        Args:
            image (PIL.Image.Image): Input image.
            detections (List[Dict]): Detection results, each with the keys
                "class_name", "confidence" and "bbox".
            alpha (float): Blend strength passed to ``Image.blend``.

        Returns:
            PIL.Image.Image: Image with bounding boxes drawn.
        """
        overlay = image.copy()
        canvas = ImageDraw.Draw(overlay)
        for det in detections:
            bbox = det["bbox"]
            label = f'{det["class_name"]} {det["confidence"]:.2f}'
            canvas.rectangle(bbox, outline="red", width=2)
            # Anchor the label at the first two bbox coordinates.
            canvas.text((bbox[0], bbox[1]), label, fill="red")
        return Image.blend(image, overlay, alpha)