""" | |
YOLO module for detecting flowchart elements (boxes and arrows). | |
Includes optional OCR for labeling arrows and deduplication to eliminate overlapping detections. | |
""" | |
from ultralytics import YOLO | |
from device_config import get_device | |
from PIL import Image | |
import numpy as np | |
import easyocr | |
from shapely.geometry import box as shapely_box | |
# Load YOLO model and move it to the appropriate device
MODEL_PATH = "models/best.pt"
device = get_device()
model = YOLO(MODEL_PATH).to(device)
print(f"✅ YOLO model loaded on: {device}")

# EasyOCR reader used for detecting optional labels near arrows
reader = easyocr.Reader(['en'], gpu=(device == "cuda"))

def iou(box1, box2):
    """Compute Intersection over Union (IoU) between two bounding boxes."""
    b1 = shapely_box(*box1)
    b2 = shapely_box(*box2)
    return b1.intersection(b2).area / b1.union(b2).area
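
# Quick sanity check (illustrative values, not part of the pipeline):
# iou([0, 0, 10, 10], [5, 5, 15, 15]) -> 25 / 175 ≈ 0.143,
# well below the 0.6 default threshold used by deduplicate_boxes below.
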
def deduplicate_boxes(boxes, iou_threshold=0.6):
    """
    Eliminate overlapping or duplicate boxes based on an IoU threshold.

    Args:
        boxes (list): List of box dictionaries with a 'bbox' key.
        iou_threshold (float): Threshold above which boxes are considered duplicates.

    Returns:
        list: Filtered list of unique boxes.
    """
    filtered = []
    for box in boxes:
        if all(iou(box['bbox'], other['bbox']) < iou_threshold for other in filtered):
            filtered.append(box)
    return filtered
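
# Illustration (hypothetical boxes, not model output): the two bboxes below
# overlap with IoU ≈ 0.96 > 0.6, so only the first dictionary survives:
# deduplicate_boxes([{"bbox": [0, 0, 100, 100]}, {"bbox": [2, 2, 100, 100]}])
# A third, non-overlapping box would be kept.
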
def run_yolo(image: Image.Image):
    """
    Run the YOLO model on an input image and return detected boxes, arrows, and an annotated image.

    Args:
        image (PIL.Image): Input RGB image of a flowchart.

    Returns:
        tuple:
            boxes (list of dict): Each box has id, bbox, type, label.
            arrows (list of dict): Each arrow has id, tail, head, label.
            vis_image (PIL.Image): Annotated image with detections drawn.
    """
    results = model.predict(image, conf=0.25, verbose=False)[0]
    boxes = []
    arrows = []
    np_img = np.array(image)  # Convert image to a numpy array for OCR crops

    for i, box in enumerate(results.boxes):
        cls_id = int(box.cls)
        label = model.names[cls_id]
        x1, y1, x2, y2 = map(int, box.xyxy[0])
        bbox = [x1, y1, x2, y2]
        width = x2 - x1
        height = y2 - y1
        aspect_ratio = width / max(height, 1)  # Guard against zero-height boxes

        # Default type assignment
        item_type = "arrow" if label in ["arrow", "control_flow"] else "box"

        # Adjust to 'decision' if the box is nearly square (likely a diamond shape)
        if item_type == "box" and 0.8 < aspect_ratio < 1.2:
            item_type = "decision"

        # Create the basic detection item
        item = {
            "id": f"node{i+1}",
            "bbox": bbox,
            "type": item_type,
            "label": label
        }

        if item_type == "arrow":
            # Extract a small patch at the arrow's center for an OCR label
            cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
            pad = 20
            crop = np_img[max(cy - pad, 0):cy + pad, max(cx - pad, 0):cx + pad]
            detected_label = ""
            if crop.size > 0:
                try:
                    ocr_results = reader.readtext(crop)
                    if ocr_results:
                        detected_label = ocr_results[0][1].strip().lower()
                except Exception as e:
                    print(f"⚠️ Arrow OCR failed: {e}")
            arrows.append({
                "id": f"arrow{len(arrows)+1}",
                # Approximate the arrow's tail and head with opposite bbox corners
                "tail": (x1, y1),
                "head": (x2, y2),
                "label": detected_label
            })
        else:
            boxes.append(item)

    # Remove overlapping duplicate boxes
    boxes = deduplicate_boxes(boxes)

    # Create an annotated image with bounding boxes drawn
    vis_image = results.plot(pil=True)

    return boxes, arrows, vis_image
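
# Minimal usage sketch. The file names below are illustrative assumptions,
# not part of the module; any RGB flowchart image works.
if __name__ == "__main__":
    sample = Image.open("flowchart.png").convert("RGB")
    detected_boxes, detected_arrows, annotated = run_yolo(sample)
    print(f"Detected {len(detected_boxes)} boxes and {len(detected_arrows)} arrows")
    for node in detected_boxes:
        print(node["id"], node["type"], node["bbox"])
    annotated.save("flowchart_annotated.png")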