Venkat V
Fresh start with clean LFS tracking
c2fb848
raw
history blame
1.84 kB
# yolo_module.py
from ultralytics import YOLO
from device_config import get_device
from PIL import Image, ImageDraw
import numpy as np
import easyocr
# Load the YOLO detector once at import time so every run_yolo() call reuses
# the same model instance instead of re-reading the weights.
MODEL_PATH = "models/best.pt"  # relative path to the trained weights
device = get_device()
model = YOLO(MODEL_PATH).to(device)
print(f"✅ YOLO model loaded on: {device}")

# OCR reader used to look for text labels near detected arrows.
# NOTE(review): the reader is pinned to CPU (gpu=False) regardless of the
# device selected for YOLO above -- confirm this is intentional.
reader = easyocr.Reader(['en'], gpu=False)
# Class names that are reported as arrows; every other class is a box.
_ARROW_LABELS = ("arrow", "control_flow")


def _read_arrow_label(pixels, cx, cy, pad):
    """Return the text OCR finds in a square window centred on (cx, cy).

    Args:
        pixels: Image as a NumPy array indexed ``[row, column, ...]``.
        cx: Column of the window centre.
        cy: Row of the window centre.
        pad: Half-size of the window in pixels.

    Returns:
        The text of the highest-confidence OCR hit, or ``""`` when the window
        is empty or OCR finds nothing.
    """
    # Clamp the lower bounds at 0: a negative start would wrap around to the
    # far side of the array. Upper bounds past the edge are truncated by
    # NumPy slicing, so they need no clamping.
    window = pixels[max(cy - pad, 0):cy + pad, max(cx - pad, 0):cx + pad]
    if window.size == 0:
        return ""

    hits = reader.readtext(window)
    if not hits:
        return ""

    # Each hit is (bbox, text, confidence); prefer the most confident reading
    # rather than whichever one happens to be listed first.
    return max(hits, key=lambda hit: hit[2])[1]


def run_yolo(image: Image.Image, *, conf: float = 0.25, label_pad: int = 20):
    """Detect diagram elements in *image* and split them into boxes and arrows.

    Args:
        image: Input image as a PIL image.
        conf: Confidence threshold passed to ``model.predict``.
        label_pad: Half-size, in pixels, of the square window around an
            arrow's centre that is searched for a text label.

    Returns:
        A ``(boxes, arrows, vis_image)`` tuple:

        * ``boxes`` -- list of dicts with keys ``id``, ``bbox``
          (``[x1, y1, x2, y2]``), ``type`` (always ``"box"``) and ``label``
          (the class name). Ids are ``node<k>`` where ``k`` is the 1-based
          index of the detection among *all* detections, so the numbering
          has gaps wherever an arrow was detected.
        * ``arrows`` -- list of dicts with keys ``id`` (``arrow<k>``,
          numbered consecutively), ``tail``, ``head`` and ``label`` (OCR text
          or ``""``).
        * ``vis_image`` -- the value of ``results.plot(pil=True)``
          (presumably the annotated image as a PIL image -- see the
          Ultralytics documentation).
    """
    results = model.predict(image, conf=conf, verbose=False)[0]

    # Pixel array for OCR. Converting to RGB first ensures palette ("P") or
    # alpha-carrying images become real 3-channel pixel data instead of
    # palette indices or a 4-channel array.
    ocr_pixels = np.array(image.convert("RGB"))

    boxes = []
    arrows = []
    for index, detection in enumerate(results.boxes):
        label = model.names[int(detection.cls)]
        x1, y1, x2, y2 = map(int, detection.xyxy[0])

        if label not in _ARROW_LABELS:
            boxes.append({
                "id": f"node{index + 1}",
                "bbox": [x1, y1, x2, y2],
                "type": "box",
                "label": label,
            })
            continue

        # Heuristic: look for a label in a small region around the middle of
        # the arrow's bounding box.
        centre_x, centre_y = (x1 + x2) // 2, (y1 + y2) // 2
        arrows.append({
            "id": f"arrow{len(arrows) + 1}",
            # Tail/head are the top-left and bottom-right corners of the
            # bounding box, not a measured arrow direction.
            "tail": (x1, y1),
            "head": (x2, y2),
            "label": _read_arrow_label(ocr_pixels, centre_x, centre_y, label_pad),
        })

    vis_image = results.plot(pil=True)
    return boxes, arrows, vis_image