"""
OCR module with support for EasyOCR and Doctr.

Provides the `extract_text` function that accepts a cropped bounding box and image,
and runs OCR based on the selected engine ("easyocr" or "doctr").
"""
import numpy as np | |
from PIL import Image | |
import cv2 | |
from textblob import TextBlob | |
from device_config import get_device | |
# OCR engine availability flags.
# Both default to False and are flipped to True only after the corresponding
# engine imports and initializes successfully.  (Previously USE_EASYOCR
# defaulted to True, so a failed `import easyocr` left the flag enabled and
# extract_text() would reference an undefined `reader`.)
USE_EASYOCR = False
USE_DOCTR = False

# Import EasyOCR if available
try:
    import easyocr
    # GPU is enabled only when the configured device is CUDA.
    reader = easyocr.Reader(['en'], gpu=(get_device() == "cuda"))
    print(f"β EasyOCR reader initialized on: {get_device()}")
    USE_EASYOCR = True
except ImportError:
    print("β οΈ EasyOCR not installed. Falling back if Doctr is available.")

# Import Doctr if available
try:
    from doctr.io import DocumentFile
    from doctr.models import ocr_predictor
    doctr_model = ocr_predictor(pretrained=True)
    print("β Doctr model loaded.")
    USE_DOCTR = True
except ImportError:
    print("β οΈ Doctr not installed.")
def expand_bbox(bbox, image_size, pad=10):
    """Grow a bounding box by `pad` pixels per side, clamped to the image.

    Parameters:
        bbox (list): [x1, y1, x2, y2] box to expand.
        image_size (tuple): (width, height) of the full image.
        pad (int): Padding added on each side.

    Returns:
        list: Expanded [x1, y1, x2, y2], clipped to image bounds.
    """
    left, top, right, bottom = bbox
    width, height = image_size
    return [
        max(0, left - pad),
        max(0, top - pad),
        min(width, right + pad),
        min(height, bottom + pad),
    ]
def clean_text(text):
    """Autocorrect common OCR misspellings using TextBlob's spell corrector."""
    return str(TextBlob(text).correct())
def extract_text(image, bbox, debug=False, engine="easyocr"):
    """
    Run OCR on a cropped region using EasyOCR or Doctr.

    Parameters:
        image (PIL.Image): Full input image.
        bbox (list): [x1, y1, x2, y2] bounding box.
        debug (bool): Enable debug output (saves the preprocessed crop,
            prints per-detection results).
        engine (str): 'easyocr' or 'doctr'.

    Returns:
        str: Cleaned OCR output, or "" on failure / unsupported engine.
    """
    # Expand and crop the image region; padding avoids clipping glyphs that
    # touch the box edge.
    bbox = expand_bbox(bbox, image.size, pad=10)
    x1, y1, x2, y2 = bbox
    # .convert("RGB") makes the pipeline robust to grayscale / palette / RGBA
    # source images, which would otherwise break the RGB2GRAY conversion below.
    cropped = image.crop((x1, y1, x2, y2)).convert("RGB")
    # --- Preprocessing: grayscale -> CLAHE -> adaptive threshold -> upscale ---
    cv_img = np.array(cropped)
    gray = cv2.cvtColor(cv_img, cv2.COLOR_RGB2GRAY)
    # Enhance local contrast with CLAHE before thresholding.
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced = clahe.apply(gray)
    # Adaptive threshold separates text from uneven backgrounds.
    thresh = cv2.adaptiveThreshold(enhanced, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 11, 4)
    # Upscale so small glyphs have enough pixels for the OCR models.
    resized = cv2.resize(thresh, (0, 0), fx=2.5, fy=2.5, interpolation=cv2.INTER_LINEAR)
    # Back to 3 channels (both OCR engines expect RGB-shaped input).
    preprocessed = cv2.cvtColor(resized, cv2.COLOR_GRAY2RGB)
    if debug:
        Image.fromarray(preprocessed).save(f"debug_ocr_crop_{x1}_{y1}.png")
    if engine == "doctr" and USE_DOCTR:
        try:
            # Doctr predictors accept a list of numpy page arrays directly;
            # DocumentFile.from_images expects file paths/bytes (not PIL
            # images), so feed the array straight to the model.
            result = doctr_model([preprocessed])
            # Recognized text lives on Word objects; blocks and lines are
            # just containers (Block has no .value attribute).
            out_text = " ".join(
                word.value
                for block in result.pages[0].blocks
                for line in block.lines
                for word in line.words
            )
            if debug:
                print(f"π Doctr OCR: {out_text}")
            return clean_text(out_text)
        except Exception as e:
            if debug:
                print(f"β Doctr failed: {e}")
            return ""
    elif engine == "easyocr" and USE_EASYOCR:
        try:
            results = reader.readtext(preprocessed, paragraph=False, min_size=10)
            # Keep confident, non-trivial detections containing alphanumerics.
            filtered = []
            for r in results:
                text = r[1].strip()
                conf = r[2]
                if conf > 0.5 and len(text) > 2 and any(c.isalnum() for c in text):
                    filtered.append(r)
            # Remove duplicate detections by case-insensitive text match.
            final = []
            seen = set()
            for r in filtered:
                t = r[1].strip()
                if t.lower() not in seen:
                    seen.add(t.lower())
                    final.append(r)
            # Sort top-to-bottom, then left-to-right by the top-left corner.
            final.sort(key=lambda r: (r[0][0][1], r[0][0][0]))
            text = " ".join([r[1] for r in final]).strip()
            if debug:
                for r in final:
                    print(f"π± EasyOCR: {r[1]} (conf: {r[2]:.2f})")
            return clean_text(text) if text else ""
        except Exception as e:
            if debug:
                print(f"β EasyOCR failed: {e}")
            return ""
    else:
        if debug:
            print(f"β οΈ Unsupported OCR engine: {engine} or not available.")
        return ""
def count_elements(boxes, arrows, debug=False):
    """Count detected boxes and arrows.

    Parameters:
        boxes (list): Detected box elements.
        arrows (list): Detected arrow elements.
        debug (bool): Print the counts when True.

    Returns:
        dict: {"box_count": int, "arrow_count": int}
    """
    counts = {"box_count": len(boxes), "arrow_count": len(arrows)}
    if debug:
        print(f"π¦ Boxes: {counts['box_count']} | β‘οΈ Arrows: {counts['arrow_count']}")
    return counts
def validate_structure(flowchart_json, expected_boxes=None, expected_arrows=None, debug=False):
    """Check a flowchart JSON against expected box/arrow counts.

    An expected count of None means "skip validation for that dimension".
    When the JSON has no "edges" key, the actual arrow count is None, so an
    explicit expected_arrows will fail to match.

    Returns:
        dict: {"boxes_valid": bool, "arrows_valid": bool}
    """
    actual_boxes = len(flowchart_json.get("steps", []))
    if "edges" in flowchart_json:
        actual_arrows = len(flowchart_json["edges"])
    else:
        actual_arrows = None
    if debug:
        print(f"π JSON boxes: {actual_boxes}, edges: {actual_arrows}")
    boxes_ok = expected_boxes is None or expected_boxes == actual_boxes
    arrows_ok = expected_arrows is None or expected_arrows == actual_arrows
    return {"boxes_valid": boxes_ok, "arrows_valid": arrows_ok}