File size: 1,840 Bytes
c2fb848
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# yolo_module.py
from ultralytics import YOLO
from device_config import get_device
from PIL import Image, ImageDraw
import numpy as np
import easyocr

# Load the YOLO detector once, at import time, so every run_yolo() call reuses it.
# MODEL_PATH is a relative path, so it is resolved against the current working directory.
MODEL_PATH = "models/best.pt"
device = get_device()  # project helper from device_config; its result is handed to .to() below
model = YOLO(MODEL_PATH).to(device)
print(f"✅ YOLO model loaded on: {device}")


# Optional OCR reader for arrow label detection.
# English only; gpu=False is hard-coded here and does not follow `device` above.
reader = easyocr.Reader(['en'], gpu=False)

# Detector class names that denote connectors rather than shapes.
_ARROW_LABELS = frozenset({"arrow", "control_flow"})


def _read_center_text(np_img: np.ndarray, cx: int, cy: int, pad: int) -> str:
    """Return the first OCR text found in a square window centred on (cx, cy).

    Args:
        np_img: Pixel array of the full image, indexed as [row, column, ...].
        cx: Column of the window centre, in pixels.
        cy: Row of the window centre, in pixels.
        pad: Half the side length of the window, in pixels.

    Returns:
        The text of the first OCR result, or "" when the window is empty or
        OCR finds nothing.
    """
    # Only the start indices need clamping: a negative start would wrap around
    # to the opposite edge, while NumPy truncates an over-long stop by itself.
    window = np_img[max(cy - pad, 0):cy + pad, max(cx - pad, 0):cx + pad]
    if window.size == 0:
        return ""

    ocr_results = reader.readtext(window)
    # Each result is unpacked as (bbox, text, conf); keep the first text only.
    return ocr_results[0][1] if ocr_results else ""


def run_yolo(
    image: Image.Image,
    *,
    conf_threshold: float = 0.25,
    label_pad: int = 20,
):
    """Detect diagram elements in *image* and split them into boxes and arrows.

    Every detection whose class name is "arrow" or "control_flow" is reported
    as an arrow; everything else is reported as a box. For each arrow, a small
    window around the centre of its bounding box is passed to OCR to pick up a
    label, if any.

    Args:
        image: The picture to analyse.
        conf_threshold: Minimum detection confidence forwarded to
            ``model.predict``.
        label_pad: Half the side length, in pixels, of the OCR window placed
            at the centre of each arrow's bounding box.

    Returns:
        A ``(boxes, arrows, vis_image)`` tuple.

        * ``boxes``: list of dicts with keys ``"id"`` (``"node<k>"``, where k
          is the 1-based index of the detection among *all* detections, so
          numbers may skip where arrows occurred), ``"bbox"``
          (``[x1, y1, x2, y2]`` as ints), ``"type"`` (always ``"box"``) and
          ``"label"`` (the class name).
        * ``arrows``: list of dicts with keys ``"id"`` (``"arrow<k>"``,
          numbered consecutively), ``"tail"`` and ``"head"`` (the top-left and
          bottom-right corners of the bounding box) and ``"label"`` (OCR text
          or ``""``).
        * ``vis_image``: whatever ``results.plot(pil=True)`` returns.
    """
    results = model.predict(image, conf=conf_threshold, verbose=False)[0]

    # Pixel array used only for cropping OCR windows. Modes other than RGB and
    # greyscale (palette, bilevel, CMYK, ...) would otherwise turn into index,
    # boolean or wrong-channel arrays that OCR cannot interpret.
    ocr_source = image if image.mode in ("RGB", "L") else image.convert("RGB")
    np_img = np.array(ocr_source)

    boxes = []
    arrows = []

    for i, box in enumerate(results.boxes):
        label = model.names[int(box.cls)]
        x1, y1, x2, y2 = map(int, box.xyxy[0])

        if label not in _ARROW_LABELS:
            boxes.append({
                "id": f"node{i+1}",
                "bbox": [x1, y1, x2, y2],
                "type": "box",
                "label": label,
            })
            continue

        # Heuristic: an arrow's caption usually sits near its midpoint.
        cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
        arrows.append({
            "id": f"arrow{len(arrows)+1}",
            # NOTE(review): the bbox corners stand in for the endpoints; the
            # arrow's actual direction is not inferred here.
            "tail": (x1, y1),
            "head": (x2, y2),
            "label": _read_center_text(np_img, cx, cy, label_pad),
        })

    vis_image = results.plot(pil=True)
    return boxes, arrows, vis_image