"""
YOLO module for detecting flowchart elements (boxes and arrows).
Includes optional OCR for labeling arrows and deduplication to eliminate overlapping detections.
"""
from ultralytics import YOLO
from device_config import get_device
from PIL import Image
import numpy as np
import easyocr
from shapely.geometry import box as shapely_box
import torch
# Load YOLO model and move to appropriate device
MODEL_PATH = "models/best.pt"  # path to the trained flowchart-detection weights
device = get_device()  # resolved once at import time; shared by model and OCR
model = YOLO(MODEL_PATH).to(device)
print(f"✅ YOLO model loaded on: {device}")
# EasyOCR reader used for detecting optional labels near arrows.
# GPU is enabled only for CUDA — EasyOCR does not use other accelerators here.
reader = easyocr.Reader(['en'], gpu=(device == "cuda"))
def iou(box1, box2):
    """Compute Intersection over Union (IoU) of two axis-aligned boxes.

    Replaces the shapely-based implementation with plain arithmetic,
    which is exact for axis-aligned rectangles and avoids constructing
    geometry objects per comparison.

    Args:
        box1: Sequence (x1, y1, x2, y2) with x1 <= x2 and y1 <= y2.
        box2: Same format as box1.

    Returns:
        float: IoU in [0.0, 1.0]. Returns 0.0 when the union has zero
        area (both boxes degenerate) instead of raising
        ZeroDivisionError, as the shapely version did.
    """
    ax1, ay1, ax2, ay2 = box1
    bx1, by1, bx2, by2 = box2
    # Overlap extents; clamp at 0 when the boxes do not intersect.
    inter_w = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    inter_h = max(0.0, min(ay2, by2) - max(ay1, by1))
    inter = inter_w * inter_h
    union = (ax2 - ax1) * (ay2 - ay1) + (bx2 - bx1) * (by2 - by1) - inter
    return inter / union if union > 0 else 0.0
def deduplicate_boxes(boxes, iou_threshold=0.6):
    """Drop boxes that overlap an already-kept box too heavily.

    Boxes are scanned in their original order; a candidate survives only
    if its IoU with every previously kept box stays below the threshold,
    so earlier detections win over later duplicates.

    Args:
        boxes (list): Box dictionaries, each carrying a 'bbox' key.
        iou_threshold (float): IoU at or above which a candidate counts
            as a duplicate of a kept box.

    Returns:
        list: The surviving, non-duplicate boxes.
    """
    kept = []
    for candidate in boxes:
        is_duplicate = any(
            iou(candidate['bbox'], existing['bbox']) >= iou_threshold
            for existing in kept
        )
        if not is_duplicate:
            kept.append(candidate)
    return kept
@torch.no_grad()
def run_yolo(image: Image.Image):
    """
    Run YOLO model on input image and return detected boxes, arrows, and annotated image.

    Args:
        image (PIL.Image): Input RGB image of a flowchart.

    Returns:
        tuple:
            boxes (list of dict): Each box has id, bbox, type, label.
            arrows (list of dict): Each arrow has id, tail, head, label.
            vis_image (PIL.Image): Annotated image with detections drawn.
    """
    results = model.predict(image, conf=0.25, verbose=False)[0]
    boxes = []
    arrows = []
    np_img = np.array(image)  # numpy copy of the image for OCR crops

    for i, box in enumerate(results.boxes):
        cls_id = int(box.cls)
        label = model.names[cls_id]
        x1, y1, x2, y2 = map(int, box.xyxy[0])
        bbox = [x1, y1, x2, y2]
        width = x2 - x1
        height = y2 - y1
        # Guard against degenerate detections: a zero-height box would
        # otherwise raise ZeroDivisionError here.
        aspect_ratio = width / height if height > 0 else 0.0

        # Class-based type, refined by geometry: a near-square box is
        # assumed to be a decision diamond.
        item_type = "arrow" if label in ["arrow", "control_flow"] else "box"
        if item_type == "box" and 0.8 < aspect_ratio < 1.2:
            item_type = "decision"

        if item_type == "arrow":
            # OCR a small patch at the arrow's center to pick up an
            # optional edge label (e.g. "yes"/"no").
            cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
            pad = 20
            # numpy slicing clamps the upper bound automatically; only the
            # lower bound needs an explicit max() to avoid negative indices.
            crop = np_img[max(cy - pad, 0):cy + pad, max(cx - pad, 0):cx + pad]
            detected_label = ""
            if crop.size > 0:
                try:
                    ocr_results = reader.readtext(crop)
                    if ocr_results:
                        detected_label = ocr_results[0][1].strip().lower()
                except Exception as e:
                    # OCR failure is non-fatal: the arrow simply stays unlabeled.
                    print(f"⚠️ Arrow OCR failed: {e}")
            arrows.append({
                "id": f"arrow{len(arrows)+1}",
                # NOTE(review): tail/head are just bbox corners, not the true
                # arrow direction — confirm how downstream consumers use them.
                "tail": (x1, y1),
                "head": (x2, y2),
                "label": detected_label
            })
        else:
            # Build the detection dict only for boxes; arrows get their own
            # record above (the original built it unconditionally, then
            # discarded it on the arrow path).
            boxes.append({
                "id": f"node{i+1}",
                "bbox": bbox,
                "type": item_type,
                "label": label
            })

    # Remove overlapping duplicate boxes
    boxes = deduplicate_boxes(boxes)
    # Annotated copy of the input with all raw detections drawn
    vis_image = results.plot(pil=True)
    return boxes, arrows, vis_image
|