"""
OCR module with support for EasyOCR and Doctr.

Provides the `extract_text` function that accepts a cropped bounding box and
image, and runs OCR based on the selected engine ("easyocr" or "doctr"), plus
small helpers for counting detected elements and validating a flowchart JSON
structure against expected counts.
"""

import numpy as np
from PIL import Image
import cv2
from textblob import TextBlob

from device_config import get_device

# OCR engine availability flags. Both start False and are set to True only
# after the corresponding engine has been imported and initialised below, so
# extract_text never dispatches to an engine object that was never created.
USE_EASYOCR = False
USE_DOCTR = False

# Import EasyOCR if available
try:
    import easyocr

    reader = easyocr.Reader(['en'], gpu=(get_device() == "cuda"))
    print(f"✅ EasyOCR reader initialized on: {get_device()}")
    USE_EASYOCR = True
except ImportError:
    # NOTE(review): extract_text does not switch engines automatically; a
    # caller requesting "easyocr" gets "" when it is unavailable and must pass
    # engine="doctr" explicitly.
    print("⚠️ EasyOCR not installed. Falling back if Doctr is available.")

# Import Doctr if available
try:
    from doctr.io import DocumentFile  # noqa: F401  (no longer used here; still importable from this module)
    from doctr.models import ocr_predictor

    doctr_model = ocr_predictor(pretrained=True)
    print("✅ Doctr model loaded.")
    USE_DOCTR = True
except ImportError:
    print("⚠️ Doctr not installed.")


def expand_bbox(bbox, image_size, pad=10):
    """Expand a bounding box by ``pad`` pixels on every side, clamped to the image.

    Parameters:
        bbox (sequence): [x1, y1, x2, y2] box.
        image_size (tuple): (width, height), e.g. ``PIL.Image.size``.
        pad (int): Padding to add on each side.

    Returns:
        list: [x1, y1, x2, y2] with x in [0, width] and y in [0, height].
    """
    x1, y1, x2, y2 = bbox
    x1 = max(0, x1 - pad)
    y1 = max(0, y1 - pad)
    x2 = min(image_size[0], x2 + pad)
    y2 = min(image_size[1], y2 + pad)
    return [x1, y1, x2, y2]


def clean_text(text):
    """Use TextBlob's spelling correction to fix basic OCR errors in ``text``."""
    blob = TextBlob(text)
    return str(blob.correct())


def _preprocess_crop(cropped):
    """Turn an RGB PIL crop into a binarised, 2.5x upscaled, 3-channel uint8 array."""
    gray = cv2.cvtColor(np.array(cropped), cv2.COLOR_RGB2GRAY)

    # Enhance contrast using CLAHE (Contrast Limited Adaptive Histogram Equalization)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced = clahe.apply(gray)

    # Adaptive mean threshold separates text strokes from uneven backgrounds.
    thresh = cv2.adaptiveThreshold(
        enhanced, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 11, 4
    )

    # Upscale so small glyphs reach a size the OCR engines handle well.
    resized = cv2.resize(thresh, (0, 0), fx=2.5, fy=2.5, interpolation=cv2.INTER_LINEAR)

    # Both engines are fed a 3-channel image.
    return cv2.cvtColor(resized, cv2.COLOR_GRAY2RGB)


def _run_doctr(preprocessed, debug):
    """OCR ``preprocessed`` with Doctr; return cleaned text, or "" on failure."""
    try:
        # The Doctr predictor accepts a list of HxWxC uint8 numpy pages directly.
        result = doctr_model([preprocessed])
        page = result.pages[0]
        # Doctr nests text as Page -> Block -> Line -> Word; only Word has `.value`.
        out_text = " ".join(
            word.value
            for block in page.blocks
            for line in block.lines
            for word in line.words
        )
        if debug:
            print(f"📘 Doctr OCR: {out_text}")
        return clean_text(out_text) if out_text else ""
    except Exception as e:
        # Best-effort OCR: a failing engine yields no text rather than a crash.
        if debug:
            print(f"❌ Doctr failed: {e}")
        return ""


def _run_easyocr(preprocessed, debug):
    """OCR ``preprocessed`` with EasyOCR; return filtered, cleaned text or ""."""
    try:
        results = reader.readtext(preprocessed, paragraph=False, min_size=10)

        # Each result is (corner_points, text, confidence). Keep confident
        # detections longer than 2 chars that contain an alphanumeric char.
        filtered = []
        for r in results:
            text = r[1].strip()
            conf = r[2]
            if conf > 0.5 and len(text) > 2 and any(c.isalnum() for c in text):
                filtered.append(r)

        # Drop repeated detections of the same text (case-insensitive); the
        # first occurrence in EasyOCR's output order is kept.
        final = []
        seen = set()
        for r in filtered:
            key = r[1].strip().lower()
            if key not in seen:
                seen.add(key)
                final.append(r)

        # Reading order: by the y, then x, of each detection's first corner point.
        final.sort(key=lambda r: (r[0][0][1], r[0][0][0]))
        text = " ".join(r[1].strip() for r in final).strip()

        if debug:
            for r in final:
                print(f"📱 EasyOCR: {r[1]} (conf: {r[2]:.2f})")

        return clean_text(text) if text else ""
    except Exception as e:
        # Best-effort OCR: a failing engine yields no text rather than a crash.
        if debug:
            print(f"❌ EasyOCR failed: {e}")
        return ""


def extract_text(image, bbox, debug=False, engine="easyocr"):
    """
    Run OCR on a cropped region using EasyOCR or Doctr.

    The region is padded by 10 px (clamped to the image), converted to RGB,
    contrast-enhanced with CLAHE, binarised with an adaptive mean threshold,
    upscaled 2.5x and passed to the selected engine. The recognised text is
    then spell-corrected with TextBlob.

    Parameters:
        image (PIL.Image): Full input image (any PIL mode; the crop is converted to RGB).
        bbox (list): [x1, y1, x2, y2] bounding box.
        debug (bool): Print diagnostics and save the preprocessed crop as
            ``debug_ocr_crop_{x1}_{y1}.png`` (padded coordinates) in the
            current working directory.
        engine (str): 'easyocr' or 'doctr'.

    Returns:
        str: Cleaned OCR output, or "" when the padded box is empty, the engine
        is unknown or not installed, nothing survives filtering, or the engine
        raises.
    """
    # Expand and crop image region
    x1, y1, x2, y2 = expand_bbox(bbox, image.size, pad=10)
    if x2 <= x1 or y2 <= y1:
        # Box lies outside the image (or is inverted): PIL/cv2 cannot process
        # an empty or negative-size crop.
        if debug:
            print(f"⚠️ Empty OCR crop for bbox {bbox}")
        return ""

    # Normalise mode (L, RGBA, P, 1, ...) so the RGB->GRAY conversion is valid.
    cropped = image.crop((x1, y1, x2, y2)).convert("RGB")
    preprocessed = _preprocess_crop(cropped)

    if debug:
        Image.fromarray(preprocessed).save(f"debug_ocr_crop_{x1}_{y1}.png")

    if engine == "doctr" and USE_DOCTR:
        return _run_doctr(preprocessed, debug)
    if engine == "easyocr" and USE_EASYOCR:
        return _run_easyocr(preprocessed, debug)

    if debug:
        print(f"⚠️ Unsupported OCR engine: {engine} or not available.")
    return ""


def count_elements(boxes, arrows, debug=False):
    """Return the number of detected boxes and arrows.

    Returns:
        dict: {"box_count": len(boxes), "arrow_count": len(arrows)}.
    """
    box_count = len(boxes)
    arrow_count = len(arrows)
    if debug:
        print(f"📦 Boxes: {box_count} | ➡️ Arrows: {arrow_count}")
    return {"box_count": box_count, "arrow_count": arrow_count}


def validate_structure(flowchart_json, expected_boxes=None, expected_arrows=None, debug=False):
    """Check a flowchart dict's "steps"/"edges" counts against expected counts.

    Parameters:
        flowchart_json (dict): Flowchart with optional "steps" and "edges" lists.
        expected_boxes (int | None): Expected number of steps; None skips the check.
        expected_arrows (int | None): Expected number of edges; None skips the check.
        debug (bool): Print the counts found in the JSON.

    Returns:
        dict: {"boxes_valid": bool, "arrows_valid": bool}. When "edges" is
        absent, the edge count is None and any non-None expectation fails.
    """
    actual_boxes = len(flowchart_json.get("steps", []))
    actual_arrows = len(flowchart_json["edges"]) if "edges" in flowchart_json else None

    if debug:
        print(f"🔍 JSON boxes: {actual_boxes}, edges: {actual_arrows}")

    return {
        "boxes_valid": (expected_boxes is None or expected_boxes == actual_boxes),
        "arrows_valid": (expected_arrows is None or expected_arrows == actual_arrows),
    }