File size: 3,844 Bytes
152df72
 
 
 
 
928873d
c2717d6
152df72
928873d
c842ab7
152df72
 
928873d
152df72
928873d
6ea5d07
c2717d6
 
 
152df72
 
 
 
 
 
 
 
 
 
928873d
152df72
 
 
c842ab7
152df72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
928873d
152df72
 
 
 
 
 
 
 
 
 
 
 
c842ab7
928873d
 
 
152df72
c842ab7
928873d
 
 
 
 
 
 
152df72
 
 
 
 
 
 
 
 
 
 
 
928873d
 
 
152df72
928873d
 
 
152df72
 
c842ab7
 
 
 
 
152df72
 
 
 
 
 
c842ab7
928873d
 
 
c842ab7
 
928873d
 
 
 
152df72
 
 
 
928873d
152df72
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""
YOLO module for detecting flowchart elements (boxes and arrows).
Includes optional OCR for labeling arrows and deduplication to eliminate overlapping detections.
"""

from ultralytics import YOLO
from device_config import get_device
from PIL import Image
import numpy as np
import easyocr
from shapely.geometry import box as shapely_box
import torch

# Load the YOLO model once at import time and move it to the configured device.
MODEL_PATH = "models/best.pt"
device = get_device()
model = YOLO(MODEL_PATH).to(device)
print(f"✅ YOLO model loaded on: {device}")

# EasyOCR reader used for detecting optional text labels near arrows.
# NOTE(review): get_device() is defined elsewhere; comparing on str(device)
# enables GPU OCR whether it returns the plain string "cuda", an indexed
# string such as "cuda:0", or a torch.device (whose str() has the same form).
reader = easyocr.Reader(['en'], gpu=str(device).startswith("cuda"))


def iou(box1, box2):
    """
    Compute Intersection over Union (IoU) between two axis-aligned bounding boxes.

    Args:
        box1 (sequence): Box as (x1, y1, x2, y2).
        box2 (sequence): Box as (x1, y1, x2, y2).

    Returns:
        float: Intersection area divided by union area, in [0, 1].
        Returns 0.0 when the union has zero area (e.g. both boxes are
        degenerate), instead of raising ZeroDivisionError.
    """
    # Normalise corner order so inverted boxes behave like shapely's box(),
    # whose polygon area is always non-negative.
    ax1, ax2 = sorted((box1[0], box1[2]))
    ay1, ay2 = sorted((box1[1], box1[3]))
    bx1, bx2 = sorted((box2[0], box2[2]))
    by1, by2 = sorted((box2[1], box2[3]))

    # Overlap extent along each axis; negative extent means no overlap.
    inter_w = max(0, min(ax2, bx2) - max(ax1, bx1))
    inter_h = max(0, min(ay2, by2) - max(ay1, by1))
    intersection = inter_w * inter_h

    area_a = (ax2 - ax1) * (ay2 - ay1)
    area_b = (bx2 - bx1) * (by2 - by1)
    union = area_a + area_b - intersection
    if union <= 0:
        # Both boxes have zero area: there is no meaningful overlap ratio.
        return 0.0
    return intersection / union


def deduplicate_boxes(boxes, iou_threshold=0.6):
    """
    Drop boxes that overlap an already-kept box too much.

    Boxes are visited in input order. A box is kept only if its IoU with
    every previously kept box is strictly below ``iou_threshold``, so earlier
    boxes win over later ones.

    Args:
        boxes (list): Box dictionaries, each carrying a 'bbox' entry.
        iou_threshold (float): IoU at or above which a box counts as a duplicate.

    Returns:
        list: The kept boxes, in their original relative order.
    """
    kept = []
    for candidate in boxes:
        is_duplicate = any(
            iou(candidate['bbox'], survivor['bbox']) >= iou_threshold
            for survivor in kept
        )
        if not is_duplicate:
            kept.append(candidate)
    return kept


def _classify_detection(label, x1, y1, x2, y2):
    """
    Map a YOLO class name and box geometry to "arrow", "decision" or "box".

    Args:
        label (str): Class name reported by the model.
        x1, y1, x2, y2 (int): Detection bounding box in pixel coordinates.

    Returns:
        str: "arrow" for arrow-like classes; "decision" for nearly square
        non-arrow boxes; otherwise "box".
    """
    if label in ("arrow", "control_flow"):
        return "arrow"
    width = x2 - x1
    height = y2 - y1
    # A zero-height box (possible after int truncation) has no meaningful
    # aspect ratio; treat it as a plain box instead of dividing by zero.
    if height == 0:
        return "box"
    aspect_ratio = width / height
    # Heuristic: nearly square non-arrow detections are likely decision diamonds.
    if 0.8 < aspect_ratio < 1.2:
        return "decision"
    return "box"


def _read_arrow_label(np_img, x1, y1, x2, y2, pad=20):
    """
    OCR a small square patch centred on a detection and return its first text hit.

    Args:
        np_img (np.ndarray): The input image converted with ``np.array``.
        x1, y1, x2, y2 (int): Detection bounding box in pixel coordinates.
        pad (int): Half-size of the square patch around the box centre.

    Returns:
        str: Stripped, lower-cased text of the first OCR result, or "" when
        the patch is empty, OCR finds nothing, or OCR raises.
    """
    cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
    # Clamp the start indices at 0: a negative start would wrap around in
    # numpy slicing. Stop indices past the image edge are clipped by numpy.
    crop = np_img[max(cy - pad, 0):cy + pad, max(cx - pad, 0):cx + pad]
    if crop.size == 0:
        return ""
    try:
        ocr_results = reader.readtext(crop)
        if ocr_results:
            return ocr_results[0][1].strip().lower()
    except Exception as e:
        # OCR is best-effort: a failure must not abort the detection pass.
        print(f"⚠️ Arrow OCR failed: {e}")
    return ""


@torch.no_grad()
def run_yolo(image: Image.Image):
    """
    Run the YOLO model on an image and return detected boxes, arrows and an annotated image.

    Args:
        image (PIL.Image): Input RGB image of a flowchart.

    Returns:
        tuple:
            boxes (list of dict): Each box has id, bbox, type ("box" or
                "decision") and label (model class name). Overlapping
                duplicates are removed via ``deduplicate_boxes``.
            arrows (list of dict): Each arrow has id, tail, head and label
                (OCR text near the arrow centre, or "").
            vis_image (PIL.Image): Annotated image with detections drawn.
    """
    results = model.predict(image, conf=0.25, verbose=False)[0]

    boxes = []
    arrows = []
    np_img = np.array(image)  # numpy copy of the input, used for OCR crops

    for i, det in enumerate(results.boxes):
        label = model.names[int(det.cls)]
        x1, y1, x2, y2 = map(int, det.xyxy[0])
        item_type = _classify_detection(label, x1, y1, x2, y2)

        if item_type == "arrow":
            arrows.append({
                "id": f"arrow{len(arrows) + 1}",
                # NOTE(review): tail/head are the bbox's top-left and
                # bottom-right corners, not a detected arrow direction.
                "tail": (x1, y1),
                "head": (x2, y2),
                "label": _read_arrow_label(np_img, x1, y1, x2, y2),
            })
        else:
            boxes.append({
                # Ids use the overall detection index: unique, but numbers
                # belonging to arrows are skipped.
                "id": f"node{i + 1}",
                "bbox": [x1, y1, x2, y2],
                "type": item_type,
                "label": label,
            })

    # Remove overlapping duplicate boxes.
    boxes = deduplicate_boxes(boxes)

    # Create an annotated image with the model's detections drawn on it.
    vis_image = results.plot(pil=True)

    return boxes, arrows, vis_image