# Author: Venkat V
# Last commit 928873d: "UPDATED changes to streamlit, ocr"
# (header copied from the repository file view: raw / history / blame, 1.27 kB)
# yolo_module.py (updated to use .pt instead of ONNX)
from ultralytics import YOLO
from PIL import Image, ImageDraw
import numpy as np
# Class labels are read from the .pt model itself (model.names), so no
# hard-coded CLASS_NAMES table is needed unless you add custom filtering.
# Load the trained detector weights once, at import time; run_yolo() reuses
# this single model instance for every call.
# NOTE(review): MODEL_PATH is a relative path, so it is resolved against the
# current working directory rather than this file's location.
MODEL_PATH = "models/best.pt"
model = YOLO(model=MODEL_PATH)
def run_yolo(image: Image.Image, conf: float = 0.25):
    """Detect diagram boxes and arrows in *image* using the module-level model.

    Detections whose class label is ``"arrow"`` or ``"control_flow"`` are
    reported as arrows; every other detection is reported as a box.

    Args:
        image: Image to run detection on (passed straight to ``model.predict``).
        conf: Minimum confidence for a detection to be kept; forwarded to
            ``model.predict`` as ``conf=``. Defaults to 0.25.

    Returns:
        ``(boxes, arrows, vis_image)`` where

        * ``boxes`` is a list of dicts with keys ``"id"`` (``"node<k>"``),
          ``"bbox"`` (``[x1, y1, x2, y2]`` as ints), ``"type"`` (always
          ``"box"``) and ``"label"`` (class name from ``model.names``).
          ``k`` is the 1-based index of the detection among *all*
          detections, so box ids skip the numbers taken by arrows.
        * ``arrows`` is a list of dicts with keys ``"id"`` (``"arrow<k>"``,
          numbered consecutively), ``"tail"`` ``(x1, y1)`` and ``"head"``
          ``(x2, y2)``.
        * ``vis_image`` is an RGB ``PIL.Image.Image`` with the detections drawn.
    """
    arrow_labels = ("arrow", "control_flow")

    results = model.predict(image, conf=conf, verbose=False)[0]  # single image

    boxes = []
    arrows = []
    detections = results.boxes
    # `boxes` is None for models without a detection head (e.g. classification);
    # treat that as "nothing detected" instead of failing on iteration.
    if detections is not None:
        # Convert the class-id and coordinate tensors to Python lists once,
        # rather than converting each detection's tensors individually.
        class_ids = detections.cls.tolist()
        coords = detections.xyxy.tolist()
        for i, (cls_id, xyxy) in enumerate(zip(class_ids, coords)):
            label = model.names[int(cls_id)]
            x1, y1, x2, y2 = map(int, xyxy)  # truncate to whole pixels
            if label in arrow_labels:
                # NOTE(review): a bounding box carries no direction, so every
                # arrow is reported as running from corner (x1, y1) to corner
                # (x2, y2) of its box. Confirm this is acceptable downstream.
                arrows.append({
                    "id": f"arrow{len(arrows) + 1}",
                    "tail": (x1, y1),
                    "head": (x2, y2),
                })
            else:
                boxes.append({
                    "id": f"node{i + 1}",
                    "bbox": [x1, y1, x2, y2],
                    "type": "box",
                    "label": label,
                })

    # Results.plot() returns a BGR numpy array (ultralytics stores the source
    # image in BGR order); reverse the channel axis so the returned PIL image
    # has correct RGB colours.
    vis_image = Image.fromarray(results.plot()[..., ::-1])
    return boxes, arrows, vis_image