|
|
|
from ultralytics import YOLO |
|
from device_config import get_device |
|
from PIL import Image, ImageDraw |
|
import numpy as np |
|
import easyocr |
|
|
|
|
|
# Relative path to the trained YOLO weights.
MODEL_PATH = "models/best.pt"

# Pick the compute device once at import time and move the detector onto it.
device = get_device()
model = YOLO(MODEL_PATH).to(device)
print(f"✅ YOLO model loaded on: {device}")

# English-only EasyOCR reader, used by run_yolo to read text near arrows.
# NOTE(review): OCR is pinned to CPU (gpu=False) even when `device` is a GPU;
# confirm this is intentional.
reader = easyocr.Reader(['en'], gpu=False)
|
|
|
# Detector class names that are treated as connectors (arrows) rather than nodes.
_ARROW_LABELS = frozenset({"arrow", "control_flow"})


def _read_label_near(np_img: np.ndarray, cx: int, cy: int, pad: int) -> str:
    """Return the most confident OCR text in a square window around (cx, cy).

    The window extends ``pad`` pixels on each side of the point and is clipped
    to the image. Returns ``""`` when the window is empty or OCR finds nothing.
    """
    top, left = max(cy - pad, 0), max(cx - pad, 0)
    crop = np_img[top:cy + pad, left:cx + pad]
    if crop.size == 0:
        return ""
    # readtext returns (bbox, text, confidence) entries; keep the best text.
    hits = reader.readtext(crop)
    if not hits:
        return ""
    return max(hits, key=lambda hit: hit[2])[1]


def run_yolo(image: Image.Image, *, conf: float = 0.25, ocr_pad: int = 20):
    """Detect diagram elements in *image* and split them into boxes and arrows.

    Args:
        image: Input PIL image.
        conf: Minimum detection confidence passed to ``model.predict``
            (default 0.25).
        ocr_pad: Half-size, in pixels, of the square window around each
            arrow's bbox centre that is OCR'd for an arrow label (default 20).

    Returns:
        A ``(boxes, arrows, vis_image)`` tuple:

        * ``boxes`` -- list of dicts with keys ``id`` (``"node<k>"``, where
          ``k`` is the 1-based index among *all* detections, so numbering can
          skip indices taken by arrows), ``bbox`` (``[x1, y1, x2, y2]`` ints),
          ``type`` (always ``"box"``) and ``label`` (detector class name).
        * ``arrows`` -- list of dicts with keys ``id`` (``"arrow1"``,
          ``"arrow2"``, ...), ``tail``, ``head`` and ``label`` (OCR text found
          near the arrow's centre, or ``""``).
        * ``vis_image`` -- the annotated image from ``results.plot(pil=True)``.
    """
    results = model.predict(image, conf=conf, verbose=False)[0]

    # Normalise to 3-channel RGB before OCR: e.g. a palette ("P") image would
    # otherwise become a 2-D array of palette indices.
    np_img = np.array(image.convert("RGB"))

    boxes = []
    arrows = []
    for i, box in enumerate(results.boxes):
        label = model.names[int(box.cls)]
        x1, y1, x2, y2 = map(int, box.xyxy[0])

        if label in _ARROW_LABELS:
            cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
            arrows.append({
                "id": f"arrow{len(arrows) + 1}",
                # NOTE(review): tail/head are simply the bbox's top-left and
                # bottom-right corners; the arrow's actual direction is not
                # inferred here.
                "tail": (x1, y1),
                "head": (x2, y2),
                "label": _read_label_near(np_img, cx, cy, ocr_pad),
            })
        else:
            boxes.append({
                "id": f"node{i + 1}",
                "bbox": [x1, y1, x2, y2],
                "type": "box",
                "label": label,
            })

    # NOTE(review): confirm the channel order of the image returned by
    # plot(pil=True) before displaying or saving it.
    vis_image = results.plot(pil=True)
    return boxes, arrows, vis_image