# Author: Venkat V
# updated with fixes to all modules
# commit: 152df72
"""
OCR module with support for EasyOCR and Doctr.
Provides the `extract_text` function that accepts a cropped bounding box and image,
and runs OCR based on the selected engine ("easyocr" or "doctr").
"""
import numpy as np
from PIL import Image
import cv2
from textblob import TextBlob
from device_config import get_device
# OCR engine availability flags.
# BUG FIX: both default to False and are flipped to True only after the
# corresponding import AND model initialization succeed. Previously
# USE_EASYOCR defaulted to True, so a failed `import easyocr` left the flag
# set while `reader` was undefined, causing a NameError downstream.
USE_EASYOCR = False
USE_DOCTR = False

# Import and initialize EasyOCR if available.
try:
    import easyocr
    reader = easyocr.Reader(['en'], gpu=(get_device() == "cuda"))
    print(f"βœ… EasyOCR reader initialized on: {get_device()}")
    USE_EASYOCR = True
except ImportError:
    print("⚠️ EasyOCR not installed. Falling back if Doctr is available.")

# Import and initialize Doctr if available.
try:
    from doctr.io import DocumentFile
    from doctr.models import ocr_predictor
    doctr_model = ocr_predictor(pretrained=True)
    print("βœ… Doctr model loaded.")
    USE_DOCTR = True
except ImportError:
    print("⚠️ Doctr not installed.")
def expand_bbox(bbox, image_size, pad=10):
    """Grow a bounding box by `pad` pixels on every side, clamped to the image.

    Parameters:
        bbox (list): [x1, y1, x2, y2] box coordinates.
        image_size (tuple): (width, height) of the full image.
        pad (int): Padding in pixels added on each side (default 10).

    Returns:
        list: Expanded [x1, y1, x2, y2], clipped to [0, width] x [0, height].
    """
    left, top, right, bottom = bbox
    width, height = image_size
    return [
        max(0, left - pad),
        max(0, top - pad),
        min(width, right + pad),
        min(height, bottom + pad),
    ]
def clean_text(text):
    """Autocorrect common OCR misreads using TextBlob's spelling correction.

    Parameters:
        text (str): Raw OCR output.

    Returns:
        str: Spell-corrected text.
    """
    return str(TextBlob(text).correct())
def extract_text(image, bbox, debug=False, engine="easyocr"):
    """
    Run OCR on a cropped region using EasyOCR or Doctr.

    Parameters:
        image (PIL.Image): Full input image.
        bbox (list): [x1, y1, x2, y2] bounding box.
        debug (bool): Enable debug output (saves crop, prints recognized text).
        engine (str): 'easyocr' or 'doctr'.

    Returns:
        str: Cleaned OCR output, or "" when OCR fails or the engine is
        unavailable/unsupported.
    """
    # Expand the box slightly so edge characters are not clipped, then crop.
    bbox = expand_bbox(bbox, image.size, pad=10)
    x1, y1, x2, y2 = bbox
    cropped = image.crop((x1, y1, x2, y2))

    # Convert to OpenCV grayscale for preprocessing.
    cv_img = np.array(cropped)
    gray = cv2.cvtColor(cv_img, cv2.COLOR_RGB2GRAY)

    # Enhance contrast using CLAHE (Contrast Limited Adaptive Histogram Equalization).
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced = clahe.apply(gray)

    # Adaptive threshold separates text strokes from uneven backgrounds.
    thresh = cv2.adaptiveThreshold(enhanced, 255, cv2.ADAPTIVE_THRESH_MEAN_C,
                                   cv2.THRESH_BINARY, 11, 4)

    # Upscale for better OCR resolution on small text.
    resized = cv2.resize(thresh, (0, 0), fx=2.5, fy=2.5,
                         interpolation=cv2.INTER_LINEAR)

    # Convert back to 3-channel RGB (both OCR engines expect color images).
    preprocessed = cv2.cvtColor(resized, cv2.COLOR_GRAY2RGB)

    if debug:
        Image.fromarray(preprocessed).save(f"debug_ocr_crop_{x1}_{y1}.png")

    if engine == "doctr" and USE_DOCTR:
        try:
            # BUG FIX: DocumentFile.from_images expects file paths or raw
            # bytes, not PIL images; the predictor accepts a list of numpy
            # arrays (H, W, 3) directly.
            result = doctr_model([preprocessed])
            # BUG FIX: Block objects have no `.value`; recognized text lives
            # at the word level (page -> blocks -> lines -> words).
            out_text = " ".join(
                word.value
                for block in result.pages[0].blocks
                for line in block.lines
                for word in line.words
            )
            if debug:
                print(f"πŸ“˜ Doctr OCR: {out_text}")
            return clean_text(out_text)
        except Exception as e:
            if debug:
                print(f"❌ Doctr failed: {e}")
            return ""
    elif engine == "easyocr" and USE_EASYOCR:
        try:
            # Results are (bbox, text, confidence) tuples.
            results = reader.readtext(preprocessed, paragraph=False, min_size=10)
            # Keep confident, non-trivial detections that contain at least
            # one alphanumeric character.
            filtered = [
                r for r in results
                if r[2] > 0.5
                and len(r[1].strip()) > 2
                and any(c.isalnum() for c in r[1])
            ]
            # Remove duplicate detections by case-insensitive text
            # (fixed comment: dedup is textual, not IoU-based).
            final = []
            seen = set()
            for r in filtered:
                key = r[1].strip().lower()
                if key not in seen:
                    seen.add(key)
                    final.append(r)
            # Sort top-to-bottom then left-to-right (reading order) using the
            # top-left corner of each detection's bounding box.
            final.sort(key=lambda r: (r[0][0][1], r[0][0][0]))
            text = " ".join(r[1] for r in final).strip()
            if debug:
                for r in final:
                    print(f"πŸ“± EasyOCR: {r[1]} (conf: {r[2]:.2f})")
            return clean_text(text) if text else ""
        except Exception as e:
            if debug:
                print(f"❌ EasyOCR failed: {e}")
            return ""
    else:
        if debug:
            print(f"⚠️ Unsupported OCR engine: {engine} or not available.")
        return ""
def count_elements(boxes, arrows, debug=False):
    """Count detected flowchart elements.

    Parameters:
        boxes: Sequence of detected box regions.
        arrows: Sequence of detected arrow regions.
        debug (bool): Print the counts when True.

    Returns:
        dict: {"box_count": int, "arrow_count": int}
    """
    counts = {"box_count": len(boxes), "arrow_count": len(arrows)}
    if debug:
        print(f"πŸ“¦ Boxes: {counts['box_count']} | ➑️ Arrows: {counts['arrow_count']}")
    return counts
def validate_structure(flowchart_json, expected_boxes=None, expected_arrows=None, debug=False):
    """Check flowchart JSON element counts against expected values.

    Parameters:
        flowchart_json (dict): Parsed flowchart with "steps" and optional "edges".
        expected_boxes (int | None): Expected box count; None skips the check.
        expected_arrows (int | None): Expected arrow count; None skips the check.
        debug (bool): Print the observed counts when True.

    Returns:
        dict: {"boxes_valid": bool, "arrows_valid": bool}
    """
    actual_boxes = len(flowchart_json.get("steps", []))
    # When "edges" is absent the arrow count is unknown (None), so a concrete
    # expected_arrows value will fail the check below.
    if "edges" in flowchart_json:
        actual_arrows = len(flowchart_json["edges"])
    else:
        actual_arrows = None
    if debug:
        print(f"πŸ” JSON boxes: {actual_boxes}, edges: {actual_arrows}")
    boxes_ok = expected_boxes is None or expected_boxes == actual_boxes
    arrows_ok = expected_arrows is None or expected_arrows == actual_arrows
    return {"boxes_valid": boxes_ok, "arrows_valid": arrows_ok}