Spaces:
Sleeping
Sleeping
# yolo_module.py (updated to use .pt weights instead of ONNX)
from ultralytics import YOLO | |
from PIL import Image, ImageDraw | |
import numpy as np | |
# Class labels come from the checkpoint itself (``model.names``), so no
# local CLASS_NAMES table is kept here.

# Path to the trained YOLO weights.
# NOTE(review): this is a relative path, so it is resolved against the
# process's current working directory, not this file's directory.
MODEL_PATH = "models/best.pt"

# Loaded once when the module is imported; run_yolo() reuses this instance.
model = YOLO(MODEL_PATH)
# Detection class labels that run_yolo() reports as arrows rather than boxes.
_ARROW_LABELS = frozenset({"arrow", "control_flow"})


def run_yolo(image: Image.Image, *, conf: float = 0.25):
    """Detect boxes and arrows in *image* with the module-level YOLO model.

    Each detection whose class label is in ``_ARROW_LABELS`` is reported as an
    arrow; every other detection is reported as a box.

    Args:
        image: Input image, passed straight to ``model.predict``.
        conf: Minimum detection confidence, forwarded to ``model.predict``.
            Defaults to 0.25, the value that used to be hard-coded.

    Returns:
        A ``(boxes, arrows, vis_image)`` tuple:

        * ``boxes``: list of dicts with keys ``"id"``, ``"bbox"``,
          ``"type"`` and ``"label"``.
          - ``"id"`` is ``"node<k>"``, where ``k`` is the 1-based position of
            the detection among *all* detections, so ids can skip numbers
            where arrows were found.
          - ``"bbox"`` is ``[x1, y1, x2, y2]`` as ints.
          - ``"type"`` is always ``"box"``.
        * ``arrows``: list of dicts with keys ``"id"``, ``"tail"`` and
          ``"head"``.
          - ``"id"`` is ``"arrow1"``, ``"arrow2"``, and so on.
          - ``"tail"`` is ``(x1, y1)`` and ``"head"`` is ``(x2, y2)``.
        * ``vis_image``: an RGB PIL image with the detections drawn on it.
    """
    # One image in, so take the first (only) Results object.
    result = model.predict(image, conf=conf, verbose=False)[0]

    # Convert the tensors to Python lists once instead of indexing per box.
    class_ids = result.boxes.cls.tolist()
    coords = result.boxes.xyxy.tolist()

    boxes = []
    arrows = []
    for i, (cls_id, xyxy) in enumerate(zip(class_ids, coords), start=1):
        label = model.names[int(cls_id)]
        # int() truncates the float pixel coordinates, matching prior output.
        x1, y1, x2, y2 = (int(v) for v in xyxy)
        if label in _ARROW_LABELS:
            # NOTE(review): an axis-aligned bbox carries no direction, so tail
            # and head are simply the top-left and bottom-right corners.
            # Arrows pointing left or up therefore get reversed endpoints.
            # TODO confirm whether a downstream step corrects this.
            arrows.append({
                "id": f"arrow{len(arrows) + 1}",
                "tail": (x1, y1),
                "head": (x2, y2),
            })
        else:
            boxes.append({
                "id": f"node{i}",
                "bbox": [x1, y1, x2, y2],
                "type": "box",
                "label": label,
            })

    # Results.plot() returns the annotated image as a BGR numpy array.
    # Flip the channels to RGB before wrapping it in a PIL image.
    # plot(pil=True) builds its PIL image from the same BGR data, which swaps
    # red and blue; some Ultralytics versions return the raw array instead.
    vis_image = Image.fromarray(result.plot()[..., ::-1])
    return boxes, arrows, vis_image