|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import re |
|
import io |
|
import copy |
|
import fitz |
|
from fitz import Document as FitzDocument, Page as FitzPage, Matrix as FitzMatrix |
|
import numpy as np |
|
import cv2 |
|
import requests |
|
from pathlib import Path |
|
from enum import auto, Enum |
|
from dataclasses import dataclass, field |
|
from typing import Iterable, Generator, Sequence, Callable, TypeAlias, List, Dict, Any, Optional |
|
from typing import Literal |
|
from collections import defaultdict |
|
from math import pi, ceil, sin, cos, sqrt, atan2 |
|
from PIL.Image import Image, frombytes, new as new_image, Resampling as PILResampling, Transform as PILTransform, fromarray as pil_fromarray |
|
from PIL.ImageOps import expand as pil_expand |
|
from PIL import ImageDraw |
|
from PIL.ImageFont import load_default, FreeTypeFont |
|
from shapely.geometry import Polygon |
|
import pyclipper |
|
from unicodedata import category |
|
from alphabet_detector import AlphabetDetector |
|
from munch import Munch |
|
from transformers import LayoutLMv3ForTokenClassification |
|
import onnxruntime |
|
|
|
|
|
from huggingface_hub import hf_hub_download |
|
from huggingface_hub.errors import HfHubHTTPError |
|
import time |
|
|
|
|
|
try: |
|
from doclayout_yolo import YOLOv10 |
|
except ImportError: |
|
print("Warning: Could not import YOLOv10 from doclayout_yolo. Layout detection will fail.") |
|
YOLOv10 = None |
|
try: |
|
from pix2tex.cli import LatexOCR |
|
except ImportError: |
|
print("Warning: Could not import LatexOCR from pix2tex.cli. LaTeX extraction will fail.") |
|
LatexOCR = None |
|
try: |
|
pass |
|
except ImportError: |
|
print("Warning: Could not import build_model from struct_eqtable. Table parsing might fail.") |
|
|
|
import torch |
|
|
|
# Compatibility shim: on torch builds that lack torch.get_default_device(),
# install a replacement that reports "cuda" when CUDA is available and "cpu"
# otherwise. NOTE(review): presumably required by one of the optional model
# packages imported above — confirm which one before removing.
if not hasattr(torch, "get_default_device"):
    torch.get_default_device = lambda: torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
|
|
|
|
|
|
|
|
def mdr_download_model(url: str, file_path: Path):
    """Download a model file from *url* and store it at *file_path*.

    The parent directory is created when needed. If anything fails, a partially
    written file is deleted so that a later run does not mistake it for a
    complete model.

    Args:
        url: HTTP(S) URL to fetch.
        file_path: Destination path of the downloaded file.

    Raises:
        FileNotFoundError: The HTTP request failed (network error or non-2xx
            status); the underlying ``requests`` exception is chained.
        Exception: Any other error (e.g. a disk write failure) is re-raised
            unchanged after cleanup.
    """
    try:
        # The response is used as a context manager so that the streamed
        # connection is released even if writing fails part-way through.
        with requests.get(url, stream=True, timeout=120) as response:
            response.raise_for_status()
            file_path.parent.mkdir(parents=True, exist_ok=True)
            with open(file_path, "wb") as file:
                for chunk in response.iter_content(chunk_size=8192):
                    file.write(chunk)
        print(f"Successfully downloaded {file_path.name}")
    except requests.exceptions.RequestException as e:
        print(f"ERROR: Failed to download {url}: {e}")
        if file_path.exists():
            os.remove(file_path)
        raise FileNotFoundError(f"Failed to download model from {url}") from e
    except Exception as e:
        print(f"ERROR: Failed writing file {file_path}: {e}")
        if file_path.exists():
            os.remove(file_path)
        # Bare raise keeps the original traceback untouched.
        raise
|
|
|
|
|
def mdr_ensure_directory(path: str) -> str:
    """Make sure *path* exists as a directory and return its absolute form.

    Missing parent directories are created; an existing directory is left as is.
    """
    absolute_path = os.path.abspath(path)
    os.makedirs(absolute_path, exist_ok=True)
    return absolute_path
|
|
|
|
|
def mdr_is_whitespace(text: str) -> bool:
    """Return True if *text* is empty/None or consists solely of whitespace."""
    if not text:
        return True
    return re.match(r"^\s*$", text) is not None
|
|
|
|
|
def mdr_expand_image(image: Image, percent: float) -> Image:
    """Return a copy of *image* surrounded by a white border.

    The border is ``ceil(width * percent)`` pixels wide on the left/right and
    ``ceil(height * percent)`` pixels high on the top/bottom. A non-positive
    *percent* returns an unmodified copy. The input image is not modified.

    Args:
        image: Source PIL image.
        percent: Border size as a fraction of the image size (0.1 == 10%).
    """
    if percent <= 0:
        return image.copy()
    w, h = image.size
    bw, bh = ceil(w * percent), ceil(h * percent)
    fill: tuple[int, ...] | int
    if image.mode == "RGBA":
        fill = (255, 255, 255, 255)
    elif image.mode == "LA":
        # Two bands: white luminance AND opaque alpha. A bare int is treated
        # by Pillow as a packed colour for multi-band modes, so the alpha may
        # otherwise end up 0 (transparent border).
        fill = (255, 255)
    elif image.mode in ("L", "1"):
        # Single-band modes need a scalar fill; 255 is white in both.
        fill = 255
    else:
        fill = (255, 255, 255)
    return pil_expand(image=image, border=(bw, bh), fill=fill)
|
|
|
|
|
|
|
MDRPoint: TypeAlias = tuple[float, float] |
|
|
|
|
|
@dataclass |
|
class MDRRectangle: |
|
"""Represents a geometric rectangle defined by four corner points.""" |
|
lt: MDRPoint; |
|
rt: MDRPoint; |
|
lb: MDRPoint; |
|
rb: MDRPoint |
|
|
|
def __iter__(self) -> Generator[MDRPoint, None, None]: |
|
yield self.lt; yield self.lb; yield self.rb; yield self.rt |
|
|
|
@property |
|
def is_valid(self) -> bool: |
|
try: |
|
return Polygon(self).is_valid |
|
except: |
|
return False |
|
|
|
@property |
|
def segments(self) -> Generator[tuple[MDRPoint, MDRPoint], None, None]: |
|
yield (self.lt, self.lb); yield (self.lb, self.rb); yield (self.rb, self.rt); yield (self.rt, self.lt) |
|
|
|
@property |
|
def area(self) -> float: |
|
try: |
|
return Polygon(self).area |
|
except: |
|
return 0.0 |
|
|
|
@property |
|
def size(self) -> tuple[float, float]: |
|
widths, heights = [], [] |
|
for i, (p1, p2) in enumerate(self.segments): |
|
dx, dy = p1[0] - p2[0], p1[1] - p2[1] |
|
dist = sqrt(dx * dx + dy * dy) |
|
if i % 2 == 0: |
|
heights.append(dist) |
|
else: |
|
widths.append(dist) |
|
avg_w = sum(widths) / len(widths) if widths else 0.0 |
|
avg_h = sum(heights) / len(heights) if heights else 0.0 |
|
return avg_w, avg_h |
|
|
|
@property |
|
def wrapper(self) -> tuple[float, float, float, float]: |
|
x1, y1, x2, y2 = float("inf"), float("inf"), float("-inf"), float("-inf") |
|
for x, y in self: |
|
x1, y1, x2, y2 = min(x1, x), min(y1, y), max(x2, x), max(y2, y) |
|
return x1, y1, x2, y2 |
|
|
|
|
|
def mdr_intersection_area(rect1: MDRRectangle, rect2: MDRRectangle) -> float:
    """Return the area of overlap between two rectangles.

    This is a best-effort score. It returns 0.0 when either rectangle is not a
    valid polygon (according to shapely) or when shapely raises while building
    or intersecting the polygons.
    """
    try:
        p1 = Polygon(rect1)
        p2 = Polygon(rect2)
        if not p1.is_valid or not p2.is_valid:
            return 0.0
        return p1.intersection(p2).area
    except Exception:
        # Degenerate geometry must not abort layout analysis, but unlike a
        # bare except this lets KeyboardInterrupt/SystemExit propagate.
        return 0.0
|
|
|
|
|
|
|
@dataclass
class MDROcrFragment:
    """One piece of text recognised by OCR, with its position and score."""

    order: int  # position of the fragment in the OCR output sequence
    text: str  # recognised text
    rank: float  # recognition score; presumably a confidence value — confirm range
    rect: MDRRectangle  # quadrilateral the text occupies
|
|
|
|
|
class MDRLayoutClass(Enum):
    """Kinds of layout region found on a page (fixed integer values 0-9)."""

    TITLE = 0
    PLAIN_TEXT = 1
    ABANDON = 2
    FIGURE = 3
    FIGURE_CAPTION = 4
    TABLE = 5
    TABLE_CAPTION = 6
    TABLE_FOOTNOTE = 7
    ISOLATE_FORMULA = 8
    FORMULA_CAPTION = 9
|
|
|
|
|
class MDRTableLayoutParsedFormat(Enum):
    """Markup format of a parsed table's content (values from ``auto()``: 1-3)."""

    LATEX = auto()
    MARKDOWN = auto()
    HTML = auto()
|
|
|
|
|
@dataclass(eq=False)
class MDRBaseLayoutElement:
    """Base class for layout elements found on a page.

    Elements compare and hash by identity (``is`` / ``id``). Two distinct
    elements with identical fields are therefore different, and elements stay
    hashable even though they are mutable dataclasses.
    """

    rect: MDRRectangle
    fragments: list[MDROcrFragment]

    def __eq__(self, other):
        return self is other

    def __hash__(self):
        return id(self)


# NOTE: subclasses must also pass ``eq=False``. A plain ``@dataclass`` would
# generate a field-wise ``__eq__`` and, because ``__hash__`` is not defined in
# the subclass body, set ``__hash__`` to None. That silently replaces the base
# class's identity semantics and makes instances unhashable.


@dataclass(eq=False)
class MDRPlainLayoutElement(MDRBaseLayoutElement):
    """Layout element for plain text, titles, captions, figures, etc."""

    cls: MDRLayoutClass


@dataclass(eq=False)
class MDRTableLayoutElement(MDRBaseLayoutElement):
    """Layout element specifically for tables."""

    # Parsed table content and its format, or None (presumably when not parsed).
    parsed: tuple[str, MDRTableLayoutParsedFormat] | None
    cls: MDRLayoutClass = MDRLayoutClass.TABLE


@dataclass(eq=False)
class MDRFormulaLayoutElement(MDRBaseLayoutElement):
    """Layout element specifically for formulas."""

    # LaTeX text of the formula, or None (presumably when recognition was
    # skipped or failed).
    latex: str | None
    cls: MDRLayoutClass = MDRLayoutClass.ISOLATE_FORMULA


# Union of all concrete layout element types.
MDRLayoutElement = MDRPlainLayoutElement | MDRTableLayoutElement | MDRFormulaLayoutElement
|
|
|
|
|
@dataclass
class MDRExtractionResult:
    """Everything extracted from a single page image."""

    rotation: float  # presumably radians (cf. mdr_calculate_image_rotation) — confirm
    layouts: list[MDRLayoutElement]  # layout elements detected on the page
    extracted_image: Image
    adjusted_image: Image | None  # may be None — confirm when it is produced
|
|
|
|
|
|
|
|
|
# Progress callback type: receives two ints and returns None. The arguments
# are presumably (completed, total) — confirm against the call sites.
MDRProgressReportCallback: TypeAlias = Callable[[int, int], None]
|
|
|
|
|
class MDROcrLevel(Enum):
    """How often OCR is run. Going by the names, this is once overall or once
    per layout region — confirm.

    Values are assigned by ``auto()`` (1, 2).
    """

    Once = auto()
    OncePerLayout = auto()
|
|
|
|
|
class MDRExtractedTableFormat(Enum):
    """Requested output format for extracted tables; ``DISABLE`` turns it off.

    Values are assigned by ``auto()`` (1-4).
    """

    LATEX = auto()
    MARKDOWN = auto()
    HTML = auto()
    DISABLE = auto()
|
|
|
|
|
class MDRTextKind(Enum):
    """Category of a text block.

    Values 0-2 mirror the first three members of ``MDRLayoutClass``.
    """

    TITLE = 0
    PLAIN_TEXT = 1
    ABANDON = 2
|
|
|
|
|
@dataclass
class MDRTextSpan:
    """A run of text inside a block, with its score and location."""

    content: str  # the text itself
    rank: float  # score carried over from recognition — confirm semantics
    rect: MDRRectangle  # where the span sits on the page
|
|
|
|
|
@dataclass
class MDRBasicBlock:
    """Fields shared by every structured block extracted from a document."""

    rect: MDRRectangle  # region the block covers
    texts: list[MDRTextSpan]  # text spans belonging to the block
    font_size: float  # presumably an estimated font size — units unconfirmed
|
|
|
|
|
@dataclass
class MDRTextBlock(MDRBasicBlock):
    """Structured block whose payload is text of a given ``MDRTextKind``."""

    kind: MDRTextKind
    # Layout hints, both False by default.
    has_paragraph_indentation: bool = field(default=False)
    last_line_touch_end: bool = field(default=False)
|
|
|
|
|
class MDRTableFormat(Enum):
    """Format of a table block's ``content``.

    ``UNRECOGNIZABLE`` marks tables whose content could not be produced.
    Values come from ``auto()`` (1-4).
    """

    LATEX = auto()
    MARKDOWN = auto()
    HTML = auto()
    UNRECOGNIZABLE = auto()
|
|
|
|
|
@dataclass
class MDRTableBlock(MDRBasicBlock):
    """Structured block for a table: markup content, its format and an image."""

    content: str  # table text in the format given by ``format``
    format: MDRTableFormat
    image: Image  # presumably the cropped table region — confirm
|
|
|
|
|
@dataclass
class MDRFormulaBlock(MDRBasicBlock):
    """Structured block for a formula, with optional text content and an image."""

    content: str | None  # formula text (presumably LaTeX), or None
    image: Image
|
|
|
|
|
@dataclass
class MDRFigureBlock(MDRBasicBlock):
    """Structured block for a figure; its payload is the image."""

    image: Image
|
|
|
|
|
# Blocks that carry an image asset: table, formula or figure.
MDRAssetBlock = MDRTableBlock | MDRFormulaBlock | MDRFigureBlock


# Any structured block: a text block or an asset block.
MDRStructuredBlock = MDRTextBlock | MDRAssetBlock
|
|
|
|
|
|
|
def mdr_similarity_ratio(v1: float, v2: float) -> float:
    """Return smaller / larger of two non-negative values, a score in [0, 1].

    Two zeros count as identical (1.0); any negative input scores 0.0.
    """
    if v1 == 0 and v2 == 0:
        return 1.0
    if v1 < 0 or v2 < 0:
        return 0.0
    if v1 > v2:
        smaller, larger = v2, v1
    else:
        smaller, larger = v1, v2
    return smaller / larger if larger != 0 else 1.0
|
|
|
|
|
def mdr_intersection_bounds_size(r1: MDRRectangle, r2: MDRRectangle) -> tuple[float, float]:
    """Return ``(width, height)`` of the bounding box of the overlap of r1 and r2.

    Returns ``(0.0, 0.0)`` in three cases: either shape is an invalid polygon,
    the intersection is empty, or shapely raises.
    """
    try:
        p1 = Polygon(r1)
        p2 = Polygon(r2)
        if not p1.is_valid or not p2.is_valid:
            return 0.0, 0.0
        inter = p1.intersection(p2)
        if inter.is_empty:
            return 0.0, 0.0
        minx, miny, maxx, maxy = inter.bounds
        return maxx - minx, maxy - miny
    except Exception:
        # Best-effort geometry helper; do not swallow KeyboardInterrupt etc.
        return 0.0, 0.0
|
|
|
|
|
# CJK unified ideographs, Hiragana, Katakana, Hangul syllables, Arabic.
_MDR_CJKA_PATTERN = re.compile(r"[\u4e00-\u9fff\u3040-\u309f\u30a0-\u30ff\uac00-\ud7a3\u0600-\u06ff]")


def mdr_contains_cjka(text: str):
    """Return True if *text* has at least one Chinese, Japanese, Korean or Arabic char."""
    if not text:
        return False
    return _MDR_CJKA_PATTERN.search(text) is not None
|
|
|
|
|
|
|
class _MDR_TokenPhase(Enum): |
|
Init = 0 |
|
Letter = 1 |
|
Character = 2 |
|
Number = 3 |
|
Space = 4 |
|
|
|
|
|
# Shared detector instance used to classify the script of single characters.
_mdr_alphabet_detector = AlphabetDetector()


def _mdr_is_letter(char: str):
    """Return True if *char* is a letter of the Latin, Cyrillic, Greek or Hebrew script.

    Only Unicode letters (general category ``L*``) are considered. Letters of
    other scripts (e.g. CJK or Arabic) return False, so ``mdr_split_into_words``
    emits them one character at a time.
    """
    if not category(char).startswith("L"):
        return False
    detector = _mdr_alphabet_detector
    try:
        return (detector.is_latin(char) or detector.is_cyrillic(char)
                or detector.is_greek(char) or detector.is_hebrew(char))
    except Exception:
        # The detector may fail on unusual code points; treat them as
        # non-letters. Narrower than a bare except, so Ctrl-C still works.
        return False
|
|
|
|
|
def mdr_split_into_words(text: str):
    """Split *text* into word, number and single-character tokens.

    Tokens are:

    * runs of letters (as decided by ``_mdr_is_letter``);
    * runs of digits, which may also contain ``.``, ``,`` and ``'`` after the
      first digit (e.g. ``"3,000.5"``);
    * every other non-whitespace character, on its own.

    Whitespace separates tokens and is dropped. Only non-empty strings are
    yielded.
    """
    if not text:
        return
    space_re = re.compile(r"\s")
    digit_re = re.compile(r"\d")
    number_punct_re = re.compile(r"[\.,']")
    buf: list[str] = []
    phase = _MDR_TokenPhase.Init
    for char in text:
        if _mdr_is_letter(char):
            if phase in (_MDR_TokenPhase.Number, _MDR_TokenPhase.Character):
                # The buffer is empty after a standalone character; the old
                # code yielded None here, which polluted the token stream.
                if buf:
                    yield "".join(buf)
                buf = []
            buf.append(char)
            phase = _MDR_TokenPhase.Letter
        elif digit_re.match(char):
            if phase in (_MDR_TokenPhase.Letter, _MDR_TokenPhase.Character):
                if buf:
                    yield "".join(buf)
                buf = []
            buf.append(char)
            phase = _MDR_TokenPhase.Number
        elif phase == _MDR_TokenPhase.Number and number_punct_re.match(char):
            # Separators inside a number stay part of it.
            buf.append(char)
        else:
            if phase in (_MDR_TokenPhase.Letter, _MDR_TokenPhase.Number):
                if buf:
                    yield "".join(buf)
                buf = []
            if space_re.match(char):
                phase = _MDR_TokenPhase.Space
            else:
                yield char
                phase = _MDR_TokenPhase.Character
    if phase in (_MDR_TokenPhase.Letter, _MDR_TokenPhase.Number) and buf:
        yield "".join(buf)
|
|
|
|
|
def mdr_check_text_similarity(t1: str, t2: str) -> tuple[float, int]:
    """Compare two texts by the multiset of their tokens.

    Both texts are tokenised with ``mdr_split_into_words``. Each token of the
    shorter list can pair with at most one identical token of the longer list.
    The score is the fraction of the longer list's tokens that got paired.

    Returns:
        ``(score, length)`` where ``score`` is in [0, 1] and ``length`` is the
        token count of the longer text. Two empty texts give ``(1.0, 0)``; if
        exactly one is empty the result is ``(0.0, length of the other)``.
    """
    from collections import Counter

    w1 = list(mdr_split_into_words(t1))
    w2 = list(mdr_split_into_words(t2))
    l1, l2 = len(w1), len(w2)
    if l1 == 0 and l2 == 0:
        return 1.0, 0
    if l1 == 0 or l2 == 0:
        return 0.0, max(l1, l2)
    longest = max(l1, l2)
    # Multiset intersection yields exactly the count of the greedy one-to-one
    # matching, in O(n + m) instead of the O(n * m) nested scan.
    matches = sum((Counter(w1) & Counter(w2)).values())
    mismatches = longest - matches
    return 1.0 - (mismatches / longest), longest
|
|
|
|
|
|
|
class MDRRotationAdjuster:
    """Map points between an image and a rotated version of it.

    A point is rotated (angle in radians) about the centre of the source frame
    and then re-centred in the destination frame.

    * ``to_origin_coordinate=True`` maps from the ``new_size`` frame to the
      ``origin_size`` frame using ``rotation``.
    * ``to_origin_coordinate=False`` maps from the ``origin_size`` frame to the
      ``new_size`` frame using ``-rotation``.
    """

    def __init__(self, origin_size: tuple[int, int], new_size: tuple[int, int], rotation: float,
                 to_origin_coordinate: bool):
        if to_origin_coordinate:
            src_size, dst_size = new_size, origin_size
            self._rot = rotation
        else:
            src_size, dst_size = origin_size, new_size
            self._rot = -rotation
        self._c_off = (src_size[0] / 2.0, src_size[1] / 2.0)
        self._n_off = (dst_size[0] / 2.0, dst_size[1] / 2.0)

    def adjust(self, point: MDRPoint) -> MDRPoint:
        """Return *point* expressed in the destination frame."""
        cx, cy = self._c_off
        dx = point[0] - cx
        dy = point[1] - cy
        if dx == 0 and dy == 0:
            # The centre maps to the centre; skip the trigonometry.
            rx, ry = dx, dy
        else:
            c = cos(self._rot)
            s = sin(self._rot)
            rx = dx * c - dy * s
            ry = dx * s + dy * c
        nx, ny = self._n_off
        return rx + nx, ry + ny
|
|
|
|
|
def mdr_normalize_vertical_rotation(rot: float) -> float:
    """Fold an angle in radians into the half-open range ``[0, pi)``.

    A single modulo operation is used, so the cost does not grow with the
    magnitude of *rot*. Infinite or NaN input returns NaN rather than looping
    forever.
    """
    result = rot % pi
    # Float rounding can produce exactly pi (e.g. for tiny negative inputs);
    # keep the result strictly below pi.
    if result >= pi:
        result -= pi
    return result
|
|
|
|
|
def _mdr_get_rectangle_angles(rect: MDRRectangle) -> tuple[list[float], list[float]] | None:
    """Split the edge angles of *rect* into horizontal-ish and vertical-ish lists.

    Each edge direction is first folded into [0, pi]. Then:

    * angles within pi/4 of horizontal go to the first list, mapped into
      [-pi/4, pi/4);
    * all other angles go to the second list unchanged.

    Edges shorter than 1e-6 in both axes are ignored. Returns None when either
    list would be empty.
    """
    horizontal: list[float] = []
    vertical: list[float] = []
    for start, end in rect.segments:
        dx = end[0] - start[0]
        dy = end[1] - start[1]
        if abs(dx) < 1e-6 and abs(dy) < 1e-6:
            continue  # zero-length edge has no direction
        angle = atan2(dy, dx)
        if angle < 0:
            angle += pi
        is_horizontal = angle < pi * 0.25 or angle >= pi * 0.75
        if not is_horizontal:
            vertical.append(angle)
        elif angle >= pi * 0.75:
            horizontal.append(angle - pi)
        else:
            horizontal.append(angle)
    if horizontal and vertical:
        return horizontal, vertical
    return None
|
|
|
|
|
def _mdr_normalize_horizontal_angles(rots: list[float]) -> list[float]: return rots |
|
|
|
|
|
def _mdr_find_median(data: list[float]) -> float: |
|
if not data: |
|
return 0.0 |
|
s_data = sorted(data) |
|
n = len(s_data) |
|
return s_data[n // 2] if n % 2 == 1 else (s_data[n // 2 - 1] + s_data[n // 2]) / 2.0 |
|
|
|
|
|
def _mdr_find_mean(data: list[float]) -> float: return sum(data) / len(data) if data else 0.0 |
|
|
|
|
|
def mdr_calculate_image_rotation(frags: list[MDROcrFragment]) -> float:
    """Estimate page rotation (radians, wrapped into [-pi/2, pi/2)) from OCR boxes.

    The estimate averages two quantities and negates the result:

    * the median horizontal-edge angle;
    * the median deviation of the vertical edges from pi/2.

    Returns 0.0 when the fragments provide no horizontal or no vertical edges.
    """
    horizontal: list[float] = []
    vertical: list[float] = []
    for fragment in frags:
        angles = _mdr_get_rectangle_angles(fragment.rect)
        if not angles:
            continue
        horizontal.extend(angles[0])
        vertical.extend(angles[1])
    if not (horizontal and vertical):
        return 0.0
    horizontal = _mdr_normalize_horizontal_angles(horizontal)
    vertical = [mdr_normalize_vertical_rotation(angle) for angle in vertical]
    median_h = _mdr_find_median(horizontal)
    median_v = _mdr_find_median(vertical)
    estimate = ((pi / 2 - median_v) - median_h) / 2.0
    while estimate >= pi / 2:
        estimate -= pi
    while estimate < -pi / 2:
        estimate += pi
    return estimate
|
|
|
|
|
def mdr_calculate_rectangle_rotation(rect: MDRRectangle) -> tuple[float, float]:
    """Return ``(mean horizontal angle, mean vertical angle)`` of *rect*'s edges.

    Angles are in radians. If the rectangle has no usable horizontal or
    vertical edge, ``(0.0, pi / 2)`` is returned, i.e. the angles of an upright
    box.
    """
    angles = _mdr_get_rectangle_angles(rect)
    if angles is None:
        return 0.0, pi / 2.0
    horizontal = _mdr_normalize_horizontal_angles(angles[0])
    vertical = list(map(mdr_normalize_vertical_rotation, angles[1]))
    return _mdr_find_mean(horizontal), _mdr_find_mean(vertical)
|
|
|
|
|
|
|
class _MDR_PredictBase:
    """Base class for ONNX model prediction components.

    Provides helpers to open an ``onnxruntime.InferenceSession`` and to list a
    session's input/output names.
    """

    def get_onnx_session(self, model_path: str, use_gpu: bool):
        """Create an ONNX Runtime session for *model_path*.

        CUDA is requested only when *use_gpu* is true AND onnxruntime lists
        ``CUDAExecutionProvider`` as available; CPU is always the fallback.

        Raises:
            Exception: Whatever session creation raises, re-raised unchanged
                after a diagnostic is printed.
        """
        cuda_ok = False
        try:
            sess_opts = onnxruntime.SessionOptions()
            sess_opts.log_severity_level = 3  # 3 = ERROR: hide info/warning logs
            # Query the providers once, and only when a GPU was requested.
            cuda_ok = use_gpu and 'CUDAExecutionProvider' in onnxruntime.get_available_providers()
            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda_ok else ['CPUExecutionProvider']
            session = onnxruntime.InferenceSession(model_path, sess_options=sess_opts, providers=providers)
            print(f" ONNX session loaded: {Path(model_path).name} ({session.get_providers()})")
            return session
        except Exception as e:
            print(f" ERROR loading ONNX session {Path(model_path).name}: {e}")
            if use_gpu and not cuda_ok:
                print(" CUDAExecutionProvider not available. Check ONNXRuntime-GPU installation and CUDA setup.")
            raise

    def get_output_name(self, sess: onnxruntime.InferenceSession) -> List[str]:
        """Names of the session's outputs, in declaration order."""
        return [n.name for n in sess.get_outputs()]

    def get_input_name(self, sess: onnxruntime.InferenceSession) -> List[str]:
        """Names of the session's inputs, in declaration order."""
        return [n.name for n in sess.get_inputs()]

    def get_input_feed(self, names: List[str], img_np: np.ndarray) -> Dict[str, np.ndarray]:
        """Build an input feed that binds the same array *img_np* to every input name."""
        return {name: img_np for name in names}
|
|
|
|
|
|
|
class _MDR_NormalizeImage: |
|
|
|
def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs): |
|
self.scale = np.float32( |
|
eval(scale) if isinstance(scale, str) else (scale if scale is not None else 1.0 / 255.0)) |
|
mean = mean if mean is not None else [0.485, 0.456, 0.406] |
|
std = std if std is not None else [0.229, 0.224, 0.225] |
|
shape = (3, 1, 1) if order == 'chw' else (1, 1, 3) |
|
self.mean = np.array(mean).reshape(shape).astype('float32') |
|
self.std = np.array(std).reshape(shape).astype('float32') |
|
|
|
def __call__(self, data): |
|
img = data['image'] |
|
img = np.array(img) if isinstance(img, Image) else img |
|
data['image'] = (img.astype('float32') * self.scale - self.mean) / self.std |
|
return data |
|
|
|
|
|
class _MDR_DetResizeForTest: |
|
|
|
def __init__(self, **kwargs): |
|
self.resize_type = 0 |
|
self.keep_ratio = False |
|
if 'image_shape' in kwargs: |
|
self.image_shape = kwargs['image_shape'] |
|
self.resize_type = 1 |
|
self.keep_ratio = kwargs.get('keep_ratio', False) |
|
elif 'limit_side_len' in kwargs: |
|
self.limit_side_len = kwargs['limit_side_len'] |
|
self.limit_type = kwargs.get('limit_type', 'min') |
|
elif 'resize_long' in kwargs: |
|
self.resize_type = 2 |
|
self.resize_long = kwargs.get('resize_long', 960) |
|
else: |
|
self.limit_side_len = 736 |
|
self.limit_type = 'min' |
|
|
|
def __call__(self, data): |
|
img = data['image'] |
|
src_h, src_w, _ = img.shape |
|
if src_h + src_w < 64: |
|
img = self._pad(img) |
|
if self.resize_type == 0: |
|
img, ratios = self._resize0(img) |
|
elif self.resize_type == 2: |
|
img, ratios = self._resize2(img) |
|
else: |
|
img, ratios = self._resize1(img) |
|
if img is None: |
|
return None |
|
data['image'] = img |
|
data['shape'] = np.array([src_h, src_w, ratios[0], ratios[1]]) |
|
return data |
|
|
|
def _pad(self, im, v=0): |
|
h, w, c = im.shape |
|
p = np.zeros((max(32, h), max(32, w), c), np.uint8) + v |
|
p[:h, :w, :] = im |
|
return p |
|
|
|
def _resize1(self, img): |
|
rh, rw = self.image_shape |
|
oh, ow = img.shape[:2] |
|
if self.keep_ratio: |
|
|
|
rw = ow * rh / oh |
|
|
|
N = ceil(rw / 32) |
|
rw = N * 32 |
|
|
|
r_h = float(rh) / oh |
|
r_w = float(rw) / ow |
|
|
|
img = cv2.resize(img, (int(rw), int(rh))) |
|
return img, [r_h, r_w] |
|
|
|
def _resize0(self, img): |
|
lsl = self.limit_side_len |
|
h, w, _ = img.shape |
|
r = 1.0 |
|
if self.limit_type == 'max': |
|
r = float(lsl) / max(h, w) if max(h, w) > lsl else 1.0 |
|
elif self.limit_type == 'min': |
|
r = float(lsl) / min(h, w) if min(h, w) < lsl else 1.0 |
|
elif self.limit_type == 'resize_long': |
|
r = float(lsl) / max(h, w) |
|
else: |
|
raise Exception('Unsupported limit_type') |
|
rh = int(h * r) |
|
rw = int(w * r) |
|
rh = max(int(round(rh / 32) * 32), 32) |
|
rw = max(int(round(rw / 32) * 32), 32) |
|
if int(rw) <= 0 or int(rh) <= 0: |
|
return None, (None, None) |
|
img = cv2.resize(img, (int(rw), int(rh))) |
|
r_h = rh / float(h) |
|
r_w = rw / float(w) |
|
return img, [r_h, r_w] |
|
|
|
def _resize2(self, img): |
|
h, w, _ = img.shape |
|
rl = self.resize_long |
|
r = float(rl) / max(h, w) |
|
rh = int(h * r) |
|
rw = int(w * r) |
|
ms = 128 |
|
rh = (rh + ms - 1) // ms * ms |
|
rw = (rw + ms - 1) // ms * ms |
|
img = cv2.resize(img, (int(rw), int(rh))) |
|
r_h = rh / float(h) |
|
r_w = rw / float(w) |
|
return img, [r_h, r_w] |
|
|
|
|
|
class _MDR_ToCHWImage:
    """Transpose ``data['image']`` from HWC to CHW order.

    A PIL image is converted to a numpy array first.
    """

    def __call__(self, data):
        image = data['image']
        if isinstance(image, Image):
            image = np.array(image)
        data['image'] = image.transpose((2, 0, 1))
        return data
|
|
|
|
|
class _MDR_KeepKeys: |
|
|
|
def __init__(self, keep_keys, **kwargs): self.keep_keys = keep_keys |
|
|
|
def __call__(self, data): return [data[key] for key in self.keep_keys] |
|
|
|
|
|
def mdr_ocr_transform(
        data: Any,
        ops: Optional[List[Callable[[Any], Optional[Any]]]] = None
) -> Optional[Any]:
    """Pipe *data* through the callables in *ops*, in order.

    Each operation receives the previous operation's result. If any operation
    returns None, the pipeline stops at once and None is returned. With *ops*
    set to None or empty, *data* is returned as is.
    """
    pipeline = ops if ops is not None else ()
    for operation in pipeline:
        data = operation(data)
        if data is None:
            return None
    return data
|
|
|
|
|
def mdr_ocr_create_operators(op_param_list, global_config=None):
    """Instantiate processing operators from a list of single-key config dicts.

    Each entry has the form ``{"Name": params}``:

    * ``Name`` is resolved to the module-level class ``_MDR_Name``;
    * the class is instantiated with ``params`` (None means no arguments),
      updated with *global_config* when one is given.

    The caller's config dicts are never modified.

    Raises:
        ValueError: If an entry is not a single-key dict, or no class
            ``_MDR_<Name>`` exists.
    """
    module_globals = globals()
    ops = []
    for operator in op_param_list:
        if not isinstance(operator, dict) or len(operator) != 1:
            raise ValueError("Op config error")
        op_name = next(iter(operator))
        raw_params = operator[op_name]
        # Copy so that merging global_config cannot leak into the caller's config.
        param = {} if raw_params is None else dict(raw_params)
        if global_config:
            param.update(global_config)
        op_class_name = f"_MDR_{op_name}"
        op_class = module_globals.get(op_class_name)
        if not isinstance(op_class, type):
            raise ValueError(f"Operator class '{op_class_name}' not found.")
        ops.append(op_class(**param))
    return ops
|
|
|
|
|
class _MDR_DBPostProcess: |
|
|
|
def __init__(self, thresh=0.3, box_thresh=0.7, max_candidates=1000, unclip_ratio=1.5, use_dilation=False, |
|
score_mode="fast", box_type='quad', **kwargs): |
|
self.thresh = thresh |
|
self.box_thresh = box_thresh |
|
self.max_cand = max_candidates |
|
self.unclip_r = unclip_ratio |
|
self.min_sz = 3 |
|
self.score_m = score_mode |
|
self.box_t = box_type |
|
assert score_mode in ["slow", "fast"] |
|
self.dila_k = np.array([[1, 1], [1, 1]], dtype=np.uint8) if use_dilation else None |
|
|
|
def _polygons_from_bitmap(self, pred, bmp, dw, dh): |
|
h, w = bmp.shape |
|
boxes, scores = [], [] |
|
contours, _ = cv2.findContours((bmp * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) |
|
for contour in contours[:self.max_cand]: |
|
eps = 0.002 * cv2.arcLength(contour, True) |
|
approx = cv2.approxPolyDP(contour, eps, True) |
|
pts = approx.reshape((-1, 2)) |
|
if pts.shape[0] < 4: |
|
continue |
|
score = self._box_score_fast(pred, pts.reshape(-1, 2)) |
|
if self.box_thresh > score: |
|
continue |
|
try: |
|
box = self._unclip(pts, self.unclip_r) |
|
except: |
|
continue |
|
if len(box) > 1: |
|
continue |
|
box = box.reshape(-1, 2) |
|
_, sside = self._get_mini_boxes(box.reshape((-1, 1, 2))) |
|
if sside < self.min_sz + 2: |
|
continue |
|
box = np.array(box) |
|
box[:, 0] = np.clip(np.round(box[:, 0] / w * dw), 0, dw) |
|
box[:, 1] = np.clip(np.round(box[:, 1] / h * dh), 0, dh) |
|
boxes.append(box.tolist()) |
|
scores.append(score) |
|
return boxes, scores |
|
|
|
|
|
def _boxes_from_bitmap(self, pred, bmp, dw, dh): |
|
h, w = bmp.shape |
|
|
|
print( |
|
f" DEBUG OCR: _boxes_from_bitmap: Processing bitmap of shape {h}x{w} for original dimensions {dw:.1f}x{dh:.1f}.") |
|
contours, _ = cv2.findContours((bmp * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) |
|
num_contours_found = len(contours) |
|
print(f" DEBUG OCR: _boxes_from_bitmap: Found {num_contours_found} raw contours.") |
|
|
|
num_contours_to_process = min(num_contours_found, self.max_cand) |
|
if num_contours_found > self.max_cand: |
|
print( |
|
f" DEBUG OCR: _boxes_from_bitmap: Processing limited to {self.max_cand} contours (max_candidates).") |
|
|
|
boxes, scores = [], [] |
|
kept_boxes_count = 0 |
|
for i in range(num_contours_to_process): |
|
contour = contours[i] |
|
pts_mini_box, sside = self._get_mini_boxes(contour) |
|
if sside < self.min_sz: |
|
|
|
continue |
|
|
|
pts_arr = np.array(pts_mini_box) |
|
current_score = self._box_score_fast(pred, pts_arr.reshape(-1, |
|
2)) if self.score_m == "fast" else self._box_score_slow( |
|
pred, contour) |
|
|
|
if self.box_thresh > current_score: |
|
|
|
continue |
|
|
|
try: |
|
box_unclipped = self._unclip(pts_arr, self.unclip_r).reshape(-1, 1, 2) |
|
except Exception as e_unclip: |
|
|
|
continue |
|
|
|
box_final, sside_final = self._get_mini_boxes(box_unclipped) |
|
if sside_final < self.min_sz + 2: |
|
|
|
continue |
|
|
|
box_final_arr = np.array(box_final) |
|
box_final_arr[:, 0] = np.clip(np.round(box_final_arr[:, 0] / w * dw), 0, dw) |
|
box_final_arr[:, 1] = np.clip(np.round(box_final_arr[:, 1] / h * dh), 0, dh) |
|
|
|
boxes.append(box_final_arr.astype("int32")) |
|
scores.append(current_score) |
|
kept_boxes_count += 1 |
|
print( |
|
f" DEBUG OCR: _boxes_from_bitmap: Kept {kept_boxes_count} boxes after all filtering (size, score, unclip). Configured box_thresh: {self.box_thresh}, min_sz: {self.min_sz}.") |
|
return np.array(boxes, dtype="int32"), scores |
|
|
|
def _unclip(self, box, ratio): |
|
poly = Polygon(box) |
|
dist = poly.area * ratio / poly.length |
|
offset = pyclipper.PyclipperOffset() |
|
offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) |
|
expanded = offset.Execute(dist) |
|
if not expanded: |
|
raise ValueError("Unclip failed") |
|
return np.array(expanded[0]) |
|
|
|
def _get_mini_boxes(self, contour): |
|
bb = cv2.minAreaRect(contour) |
|
pts = sorted(list(cv2.boxPoints(bb)), key=lambda x: x[0]) |
|
i1, i4 = (0, 1) if pts[1][1] > pts[0][1] else (1, 0) |
|
i2, i3 = (2, 3) if pts[3][1] > pts[2][1] else (3, 2) |
|
box = [pts[i1], pts[i2], pts[i3], pts[i4]] |
|
return box, min(bb[1]) |
|
|
|
def _box_score_fast(self, bmp, box): |
|
h, w = bmp.shape[:2] |
|
xmin = np.clip(np.floor(box[:, 0].min()).astype("int32"), 0, w - 1) |
|
xmax = np.clip(np.ceil(box[:, 0].max()).astype("int32"), 0, w - 1) |
|
ymin = np.clip(np.floor(box[:, 1].min()).astype("int32"), 0, h - 1) |
|
ymax = np.clip(np.ceil(box[:, 1].max()).astype("int32"), 0, h - 1) |
|
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) |
|
box[:, 0] -= xmin |
|
box[:, 1] -= ymin |
|
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype("int32"), 1) |
|
return cv2.mean(bmp[ymin: ymax + 1, xmin: xmax + 1], mask)[0] if np.sum(mask) > 0 else 0.0 |
|
|
|
def _box_score_slow(self, bmp, contour): |
|
h, w = bmp.shape[:2] |
|
contour = np.reshape(contour.copy(), (-1, 2)) |
|
xmin = np.clip(np.min(contour[:, 0]), 0, w - 1) |
|
xmax = np.clip(np.max(contour[:, 0]), 0, w - 1) |
|
ymin = np.clip(np.min(contour[:, 1]), 0, h - 1) |
|
ymax = np.clip(np.max(contour[:, 1]), 0, h - 1) |
|
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) |
|
contour[:, 0] -= xmin |
|
contour[:, 1] -= ymin |
|
cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype("int32"), 1) |
|
return cv2.mean(bmp[ymin: ymax + 1, xmin: xmax + 1], mask)[0] if np.sum(mask) > 0 else 0.0 |
|
|
|
def __call__(self, outs_dict, shape_list): |
|
pred = outs_dict['maps'][:, 0, :, :] |
|
seg = pred > self.thresh |
|
|
|
print( |
|
f" DEBUG OCR: _MDR_DBPostProcess: pred map shape: {pred.shape}, seg map shape: {seg.shape}, configured thresh: {self.thresh}") |
|
print( |
|
f" DEBUG OCR: _MDR_DBPostProcess: Number of pixels in seg map above threshold (sum of all batches): {np.sum(seg)}") |
|
|
|
boxes_batch = [] |
|
for batch_idx in range(pred.shape[0]): |
|
|
|
sh_orig, sw_orig, rh_ratio, rw_ratio = shape_list[batch_idx] |
|
|
|
|
|
|
|
dw_orig, dh_orig = sw_orig, sh_orig |
|
|
|
current_pred_map = pred[batch_idx] |
|
current_seg_map = seg[batch_idx] |
|
|
|
mask = cv2.dilate(np.array(current_seg_map).astype(np.uint8), |
|
self.dila_k) if self.dila_k is not None else current_seg_map |
|
print( |
|
f" DEBUG OCR: _MDR_DBPostProcess (batch {batch_idx}): Input shape to postproc (orig) {dh_orig:.1f}x{dw_orig:.1f}. Sum of mask pixels: {np.sum(mask)}") |
|
|
|
if self.box_t == 'poly': |
|
boxes, scores = self._polygons_from_bitmap(current_pred_map, mask, dh_orig, dw_orig) |
|
elif self.box_t == 'quad': |
|
boxes, scores = self._boxes_from_bitmap(current_pred_map, mask, dh_orig, dw_orig) |
|
else: |
|
raise ValueError("box_type must be 'quad' or 'poly'") |
|
print( |
|
f" DEBUG OCR: _MDR_DBPostProcess (batch {batch_idx}): Found {len(boxes)} boxes from bitmap processing.") |
|
boxes_batch.append({'points': boxes}) |
|
return boxes_batch |
|
|
|
|
|
class _MDR_TextDetector(_MDR_PredictBase): |
|
|
|
def __init__(self, args):
    """Build the text detector's pre/post-processing and load the ONNX session.

    The config attributes read from *args* are: det_limit_side_len,
    det_limit_type, det_db_thresh, det_db_box_thresh, det_db_unclip_ratio,
    use_dilation, det_db_score_mode, det_box_type, det_model_dir and use_gpu.
    """
    super().__init__()
    self.args = args
    resize_cfg = {'limit_side_len': args.det_limit_side_len, 'limit_type': args.det_limit_type}
    normalize_cfg = {'std': [0.229, 0.224, 0.225], 'mean': [0.485, 0.456, 0.406], 'scale': '1./255.',
                     'order': 'hwc'}
    self.pre_op = mdr_ocr_create_operators([
        {'DetResizeForTest': resize_cfg},
        {'NormalizeImage': normalize_cfg},
        {'ToCHWImage': None},
        {'KeepKeys': {'keep_keys': ['image', 'shape']}},
    ])
    self.post_op = _MDR_DBPostProcess(
        thresh=args.det_db_thresh,
        box_thresh=args.det_db_box_thresh,
        max_candidates=1000,
        unclip_ratio=args.det_db_unclip_ratio,
        use_dilation=args.use_dilation,
        score_mode=args.det_db_score_mode,
        box_type=args.det_box_type,
    )
    self.sess = self.get_onnx_session(args.det_model_dir, args.use_gpu)
    self.input_name = self.get_input_name(self.sess)
    self.output_name = self.get_output_name(self.sess)
|
|
|
def _order_pts(self, pts):
    """Return the 4 points of *pts* ordered top-left, top-right, bottom-right, bottom-left.

    Top-left and bottom-right are the points with the smallest and largest
    x + y. Of the remaining two, the one with the smaller y - x is top-right.
    This assumes y grows downward. The result is float32.
    """
    ordered = np.zeros((4, 2), dtype="float32")
    sums = pts.sum(axis=1)
    first, third = np.argmin(sums), np.argmax(sums)
    ordered[0] = pts[first]
    ordered[2] = pts[third]
    rest = np.delete(pts, (first, third), axis=0)
    diffs = np.diff(np.array(rest), axis=1)
    ordered[1] = rest[np.argmin(diffs)]
    ordered[3] = rest[np.argmax(diffs)]
    return ordered
|
|
|
def _clip_pts(self, pts, h, w):
    """Clamp x into [0, w-1] and y into [0, h-1], in place; returns the same array."""
    for axis, upper in ((0, w - 1), (1, h - 1)):
        pts[:, axis] = np.clip(pts[:, axis], 0, upper)
    return pts
|
|
|
def _filter_quad(self, boxes, shape):
    """Order and clip each quadrilateral to the image.

    Quads with a top or left side of 3 px or less are dropped.
    """
    height, width = shape[0:2]
    kept = []
    for raw in boxes:
        quad = np.array(raw) if isinstance(raw, list) else raw
        quad = self._clip_pts(self._order_pts(quad), height, width)
        side_w = int(np.linalg.norm(quad[0] - quad[1]))
        side_h = int(np.linalg.norm(quad[0] - quad[3]))
        if side_w > 3 and side_h > 3:
            kept.append(quad)
    return np.array(kept)
|
|
|
def _filter_poly(self, boxes, shape):
    """Clip each polygon to the image and drop those whose area is below 10."""
    height, width = shape[0:2]
    kept = []
    for raw in boxes:
        poly = np.array(raw) if isinstance(raw, list) else raw
        poly = self._clip_pts(poly, height, width)
        if Polygon(poly).area < 10:
            continue
        kept.append(poly)
    return np.array(kept)
|
|
|
|
|
|
|
def __call__(self, img): |
|
ori_im = img.copy() |
|
data = {"image": img} |
|
print(f" DEBUG OCR: _MDR_TextDetector: Original image shape: {ori_im.shape}") |
|
|
|
|
|
try: |
|
data = mdr_ocr_transform(data, self.pre_op) |
|
except Exception as e_preproc: |
|
print(f" DEBUG OCR: _MDR_TextDetector: Error during preprocessing (mdr_ocr_transform): {e_preproc}") |
|
import traceback |
|
traceback.print_exc() |
|
return np.array([]) |
|
|
|
if data is None: |
|
print( |
|
" DEBUG OCR: _MDR_TextDetector: Preprocessing (mdr_ocr_transform) returned None. No text will be detected.") |
|
return np.array([]) |
|
|
|
processed_img, shape_list = data |
|
if processed_img is None: |
|
print(" DEBUG OCR: _MDR_TextDetector: Processed image after transform is None. No text will be detected.") |
|
return np.array([]) |
|
print( |
|
f" DEBUG OCR: _MDR_TextDetector: Processed image shape for ONNX: {processed_img.shape}, shape_list: {shape_list}") |
|
|
|
img_for_onnx = np.expand_dims(processed_img, axis=0) |
|
shape_list_for_onnx = np.expand_dims(shape_list, axis=0) |
|
img_for_onnx = img_for_onnx.copy() |
|
|
|
inputs = self.get_input_feed(self.input_name, img_for_onnx) |
|
print(f" DEBUG OCR: _MDR_TextDetector: Running ONNX inference for text detection...") |
|
try: |
|
outputs = self.sess.run(self.output_name, input_feed=inputs) |
|
except Exception as e_infer: |
|
print(f" DEBUG OCR: _MDR_TextDetector: ONNX inference for detection failed: {e_infer}") |
|
import traceback |
|
traceback.print_exc() |
|
return np.array([]) |
|
print(f" DEBUG OCR: _MDR_TextDetector: ONNX inference done. Output map shape: {outputs[0].shape}") |
|
|
|
preds = {"maps": outputs[0]} |
|
try: |
|
post_res = self.post_op(preds, shape_list_for_onnx) |
|
except Exception as e_postproc: |
|
print(f" DEBUG OCR: _MDR_TextDetector: Error during DBPostProcess: {e_postproc}") |
|
import traceback |
|
traceback.print_exc() |
|
return np.array([]) |
|
|
|
|
|
|
|
if not post_res or not isinstance(post_res, list) or len(post_res) == 0 or \ |
|
not isinstance(post_res[0], dict) or 'points' not in post_res[0]: |
|
print(" DEBUG OCR: _MDR_TextDetector: DBPostProcess returned invalid or empty structure for points.") |
|
return np.array([]) |
|
|
|
boxes_from_post = post_res[0]['points'] |
|
|
|
|
|
|
|
no_boxes_found = False |
|
if isinstance(boxes_from_post, np.ndarray): |
|
if boxes_from_post.size == 0: |
|
no_boxes_found = True |
|
elif isinstance(boxes_from_post, list): |
|
if not boxes_from_post: |
|
no_boxes_found = True |
|
elif boxes_from_post is None: |
|
no_boxes_found = True |
|
else: |
|
|
|
print( |
|
f" DEBUG OCR: _MDR_TextDetector: 'points' from DBPostProcess is of unexpected type: {type(boxes_from_post)}") |
|
return np.array([]) |
|
|
|
if no_boxes_found: |
|
print(" DEBUG OCR: _MDR_TextDetector: DBPostProcess returned no actual point data.") |
|
return np.array([]) |
|
|
|
|
|
print( |
|
f" DEBUG OCR: _MDR_TextDetector: Boxes from DBPostProcess before final filtering: {len(boxes_from_post)}") |
|
|
|
|
|
|
|
if not isinstance(boxes_from_post, (list, np.ndarray)) or \ |
|
(isinstance(boxes_from_post, np.ndarray) and boxes_from_post.size == 0) or \ |
|
(isinstance(boxes_from_post, list) and not boxes_from_post): |
|
print(" DEBUG OCR: _MDR_TextDetector: No boxes from DBPostProcess to filter (secondary check).") |
|
return np.array([]) |
|
|
|
if self.args.det_box_type == 'poly': |
|
final_boxes = self._filter_poly(boxes_from_post, ori_im.shape) |
|
else: |
|
final_boxes = self._filter_quad(boxes_from_post, ori_im.shape) |
|
print(f" DEBUG OCR: _MDR_TextDetector: Boxes after final poly/quad filtering: {len(final_boxes)}") |
|
return final_boxes |
|
|
|
|
|
class _MDR_ClsPostProcess: |
|
|
|
def __init__(self, label_list=None, **kwargs): self.labels = label_list if label_list else {0: '0', 1: '180'} |
|
|
|
def __call__(self, preds, label=None, *args, **kwargs): |
|
preds = np.array(preds) if not isinstance(preds, np.ndarray) else preds; |
|
idxs = preds.argmax(axis=1) |
|
return [(self.labels[idx], float(preds[i, idx])) for i, idx in enumerate(idxs)] |
|
|
|
|
|
class _MDR_TextClassifier(_MDR_PredictBase):
    """Text-line orientation classifier (0 vs. 180 degrees) backed by ONNX Runtime.

    Options read from ``args``: ``cls_image_shape`` ("C,H,W" string or a tuple),
    ``cls_batch_num``, ``cls_thresh``, ``label_list``, ``cls_model_dir`` and
    ``use_gpu``.
    """

    def __init__(self, args):
        super().__init__()
        shape = args.cls_image_shape
        self.shape = tuple(map(int, shape.split(','))) if isinstance(shape, str) else shape
        self.batch_num = args.cls_batch_num
        self.thresh = args.cls_thresh
        self.post_op = _MDR_ClsPostProcess(label_list=args.label_list)
        self.sess = self.get_onnx_session(args.cls_model_dir, args.use_gpu)
        self.input_name = self.get_input_name(self.sess)
        self.output_name = self.get_output_name(self.sess)

    def _resize_norm(self, img):
        """Resize *img* for the classifier.

        Steps:
            1. Resize to height ``imgH``, keeping the aspect ratio, with width capped
               at ``imgW``.
            2. Scale to [-1, 1].
            3. Transpose to CHW (or add a leading axis when ``imgC == 1``).
            4. Zero-pad on the right to ``(imgC, imgH, imgW)``.
        """
        imgC, imgH, imgW = self.shape
        h, w = img.shape[:2]
        r = w / float(h) if h > 0 else 0
        rw = min(int(ceil(imgH * r)), imgW)
        resized = cv2.resize(img, (rw, imgH)).astype("float32")
        if imgC == 1:
            resized = resized / 255.0
            resized = resized[np.newaxis, :]
        else:
            resized = resized.transpose((2, 0, 1)) / 255.0
        resized -= 0.5
        resized /= 0.5
        padding = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding[:, :, 0:rw] = resized
        return padding

    def __call__(self, img_list):
        """Classify each crop's orientation and un-rotate confident 180-degree crops.

        Crops are batched in order of increasing aspect ratio (width / height).

        Returns:
            ``(img_list, results)``:
            - ``img_list`` is the same list object. An entry is replaced by a
              180-degree rotated copy when its label contains ``"180"`` and its
              score is greater than ``self.thresh``.
            - ``results[i]`` is ``[label, score]`` for crop ``i``.
            An empty input returns ``(img_list, [])``.
        """
        if not img_list:
            return img_list, []
        # A shallow snapshot is enough: rotation below replaces list entries with new
        # arrays, and _resize_norm never writes into its input.
        img_list_cp = list(img_list)
        num = len(img_list_cp)
        ratios = [img.shape[1] / float(img.shape[0]) if img.shape[0] > 0 else 0 for img in img_list_cp]
        indices = np.argsort(np.array(ratios))
        # One independent placeholder per crop; ``[[...]] * num`` would alias one list.
        results = [["", 0.0] for _ in range(num)]
        batch_n = self.batch_num
        for start in range(0, num, batch_n):
            end = min(num, start + batch_n)
            batch = [self._resize_norm(img_list_cp[indices[i]])[np.newaxis, :] for i in range(start, end)]
            if not batch:
                continue
            batch = np.concatenate(batch, axis=0).copy()
            inputs = self.get_input_feed(self.input_name, batch)
            outputs = self.sess.run(self.output_name, input_feed=inputs)
            cls_out = self.post_op(outputs[0])
            for i, (label, score) in enumerate(cls_out):
                orig_idx = indices[start + i]
                results[orig_idx] = [label, score]
                if "180" in label and score > self.thresh:
                    img_list[orig_idx] = cv2.rotate(img_list[orig_idx], cv2.ROTATE_180)
        return img_list, results
|
|
|
|
|
class _MDR_BaseRecLabelDecode: |
|
|
|
def __init__(self, char_path=None, use_space=False): |
|
self.beg, self.end, self.rev = "sos", "eos", False |
|
self.chars = [] |
|
if char_path is None: |
|
self.chars = list("0123456789abcdefghijklmnopqrstuvwxyz") |
|
else: |
|
try: |
|
with open(char_path, "rb") as f: |
|
self.chars = [l.decode("utf-8").strip("\n\r") for l in f] |
|
if use_space: |
|
self.chars.append(" ") |
|
if any("\u0600" <= c <= "\u06FF" for c in self.chars): |
|
self.rev = True |
|
except FileNotFoundError: |
|
print(f"Warn: Dict not found {char_path}") |
|
self.chars = list("0123456789abcdefghijklmnopqrstuvwxyz") |
|
if use_space: |
|
self.chars.append(" ") |
|
d_char = self.add_special_char(list(self.chars)) |
|
self.dict = {c: i for i, c in enumerate(d_char)} |
|
self.character = d_char |
|
|
|
def add_special_char(self, chars): |
|
return chars |
|
|
|
def get_ignored_tokens(self): |
|
return [] |
|
|
|
def _reverse(self, pred): |
|
res = [] |
|
cur = "" |
|
for c in pred: |
|
if not re.search("[a-zA-Z0-9 :*./%+-]", c): |
|
if cur != "": |
|
res.extend([cur, c]) |
|
else: |
|
res.extend([c]) |
|
cur = "" |
|
else: |
|
cur += c |
|
if cur != "": |
|
res.append(cur) |
|
return "".join(res[::-1]) |
|
|
|
def decode(self, idxs, probs=None, remove_dup=False): |
|
res = [] |
|
ignored = self.get_ignored_tokens() |
|
bs = len(idxs) |
|
for b_idx in range(bs): |
|
sel = np.ones(len(idxs[b_idx]), dtype=bool) |
|
if remove_dup: |
|
sel[1:] = idxs[b_idx][1:] != idxs[b_idx][:-1] |
|
for ig_tok in ignored: |
|
sel &= idxs[b_idx] != ig_tok |
|
char_l = [ |
|
self.character[tid] |
|
for tid in idxs[b_idx][sel] |
|
if 0 <= tid < len(self.character) |
|
] |
|
conf_l = probs[b_idx][sel] if probs is not None else [1] * len(char_l) |
|
if len(conf_l) == 0: |
|
conf_l = [0] |
|
txt = "".join(char_l) |
|
if self.rev: |
|
txt = self._reverse(txt) |
|
res.append((txt, float(np.mean(conf_l)))) |
|
return res |
|
|
|
|
|
class _MDR_CTCLabelDecode(_MDR_BaseRecLabelDecode):
    """CTC greedy decoder.

    Class index 0 is the ``"blank"`` token. Repeated indices are collapsed and
    blanks are dropped before characters are emitted.
    """

    def __init__(self, char_path=None, use_space=False, **kwargs):
        super().__init__(char_path, use_space)

    def add_special_char(self, chars):
        """Prepend the blank symbol so it occupies class index 0."""
        return ["blank", *chars]

    def get_ignored_tokens(self):
        """Discard only the blank (index 0) while decoding."""
        return [0]

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode raw model output into ``[(text, mean_confidence), ...]``.

        When *preds* is a tuple or list, its last element is decoded. Per step, the
        arg-max class over axis 2 and its probability are used.
        """
        if isinstance(preds, (tuple, list)):
            preds = preds[-1]
        if not isinstance(preds, np.ndarray):
            preds = np.array(preds)
        best_ids = np.argmax(preds, axis=2)
        best_probs = np.max(preds, axis=2)
        return self.decode(best_ids, best_probs, remove_dup=True)
|
|
|
|
|
class _MDR_TextRecognizer(_MDR_PredictBase):
    """Batched CTC text-line recognizer backed by an ONNX Runtime session.

    Options read from ``args`` (defaults in parentheses):
    - ``rec_image_shape``: "C,H,W" ("3,48,320").
    - ``rec_batch_num``: batch size (6).
    - ``rec_algorithm``: stored as ``self.algo``, not used in this class.
    - ``rec_char_dict_path`` and ``use_space_char`` (True): passed to the decoder.
    - ``rec_model_dir`` and ``use_gpu``: used for the session.
    """

    def __init__(self, args):
        super().__init__()
        shape_str = getattr(args, 'rec_image_shape', "3,48,320")
        self.shape = tuple(map(int, shape_str.split(',')))
        self.batch_num = getattr(args, 'rec_batch_num', 6)
        self.algo = getattr(args, 'rec_algorithm', 'SVTR_LCNet')
        self.post_op = _MDR_CTCLabelDecode(char_path=args.rec_char_dict_path,
                                           use_space=getattr(args, 'use_space_char', True))
        self.sess = self.get_onnx_session(args.rec_model_dir, args.use_gpu)
        self.input_name = self.get_input_name(self.sess)
        self.output_name = self.get_output_name(self.sess)

    def _resize_norm(self, img, max_r):
        """Resize a crop to the model height, normalize it to [-1, 1] and zero-pad it to ``imgW``.

        Steps:
            1. Scale the crop to height ``imgH``, keeping its aspect ratio; width is
               capped at ``imgW`` and is at least 1.
            2. Convert between 1 and 3 channels as needed to match ``imgC``.
            3. Transpose to CHW.
            4. Copy into the left part of a zero ``(imgC, imgH, imgW)`` float32 array.

        Crops smaller than 2 px in either dimension, resize failures and channel
        mismatches return an all-zero array instead.

        NOTE(review): *max_r* (widest aspect ratio in the batch) only feeds the debug
        log. Every crop is fitted into the fixed ``imgW``, so very long lines are
        squeezed. Consider a per-batch width if the ONNX model accepts a dynamic width.
        """
        imgC, imgH, imgW = self.shape
        h_orig, w_orig = img.shape[:2]
        print(
            f"    DEBUG RECOGNIZER: _resize_norm input crop shape: ({h_orig}, {w_orig}), target shape: {self.shape}, max_r_batch: {max_r:.2f}")

        MIN_DIM_FOR_RESIZE = 2
        if h_orig < MIN_DIM_FOR_RESIZE or w_orig < MIN_DIM_FOR_RESIZE:
            print(
                f"    DEBUG RECOGNIZER: _resize_norm received degenerate crop ({h_orig}x{w_orig}) with dimension < {MIN_DIM_FOR_RESIZE}. Returning zeros before resize attempt.")
            return np.zeros((imgC, imgH, imgW), dtype=np.float32)

        # Both dimensions are >= 2 here, so the aspect ratio is well defined.
        r_current = w_orig / float(h_orig)
        tw = max(1, min(imgW, int(ceil(imgH * r_current))))
        print(f"    DEBUG RECOGNIZER: _resize_norm calculated target width (tw): {tw} for target height (imgH): {imgH}")

        try:
            if tw <= 0 or imgH <= 0:
                print(
                    f"    DEBUG RECOGNIZER: _resize_norm calculated invalid target resize dimensions (tw: {tw}, imgH: {imgH}). Returning zeros.")
                return np.zeros((imgC, imgH, imgW), dtype=np.float32)
            resized = cv2.resize(img, (tw, imgH))
        except cv2.error as e_resize:
            print(
                f"    DEBUG RECOGNIZER: _resize_norm cv2.resize failed: {e_resize}. Original shape ({h_orig},{w_orig}), target ({tw},{imgH}). Returning zeros.")
            return np.zeros((imgC, imgH, imgW), dtype=np.float32)
        except Exception as e_resize_general:
            print(
                f"    DEBUG RECOGNIZER: _resize_norm general error during resize: {e_resize_general}. Original shape ({h_orig},{w_orig}), target ({tw},{imgH}). Returning zeros.")
            import traceback
            traceback.print_exc()
            return np.zeros((imgC, imgH, imgW), dtype=np.float32)

        resized = resized.astype("float32")
        if imgC == 1 and len(resized.shape) == 3 and resized.shape[2] == 3:
            resized = cv2.cvtColor(resized, cv2.COLOR_BGR2GRAY)
        if len(resized.shape) == 2:
            resized = resized[:, :, np.newaxis]
        if imgC == 3 and resized.shape[2] == 1:
            resized = cv2.cvtColor(resized, cv2.COLOR_GRAY2BGR)
        if resized.shape[2] != imgC:
            print(
                f"    DEBUG RECOGNIZER: Channel mismatch after processing. Expected {imgC}, got {resized.shape[2]}. Crop shape ({h_orig},{w_orig}). Returning zeros.")
            return np.zeros((imgC, imgH, imgW), dtype=np.float32)

        resized = resized.transpose((2, 0, 1)) / 255.0
        resized -= 0.5
        resized /= 0.5

        padding = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        actual_padded_width = min(tw, imgW)
        padding[:, :, 0:actual_padded_width] = resized[:, :, 0:actual_padded_width]
        print(f"    DEBUG RECOGNIZER: _resize_norm output padded shape: {padding.shape}")

        min_px, max_px, mean_px = np.min(padding), np.max(padding), np.mean(padding)
        print(f"    DEBUG RECOGNIZER: Normalized Crop Properties (before ONNX): "
              f"dtype: {padding.dtype}, "
              f"MinPx: {min_px:.4f}, "
              f"MaxPx: {max_px:.4f}, "
              f"MeanPx: {mean_px:.4f}")
        if np.all(padding == 0):
            print("    DEBUG RECOGNIZER: WARNING - Normalized image is all zeros!")
        elif np.abs(max_px - min_px) < 1e-6:
            print(f"    DEBUG RECOGNIZER: WARNING - Normalized image is a constant value: {mean_px:.4f}")
        return padding

    def __call__(self, img_list):
        """Recognize each crop in *img_list*.

        Crops are batched (``batch_num`` per batch) in order of increasing aspect
        ratio (width / height); results are written back at each crop's index.

        Returns:
            ``[(text, confidence), ...]`` in input order, or ``[]`` for empty input.
        """
        if not img_list:
            return []
        num = len(img_list)
        ratios = [img.shape[1] / float(img.shape[0]) if img.shape[0] > 0 else 0 for img in img_list]
        indices = np.argsort(np.array(ratios))
        # One independent placeholder per crop; ``[[...]] * num`` would alias one list.
        results = [["", 0.0] for _ in range(num)]
        batch_n = self.batch_num
        for start in range(0, num, batch_n):
            end = min(num, start + batch_n)
            max_r_batch = 0
            for i in range(start, end):
                h, w = img_list[indices[i]].shape[0:2]
                if h > 0:
                    max_r_batch = max(max_r_batch, w / float(h))
            batch = [self._resize_norm(img_list[indices[i]], max_r_batch)[np.newaxis, :] for i in range(start, end)]
            if not batch:
                continue
            batch = np.concatenate(batch, axis=0).copy()
            inputs = self.get_input_feed(self.input_name, batch)
            outputs = self.sess.run(self.output_name, input_feed=inputs)
            rec_out = self.post_op(outputs[0])
            for i, rec in enumerate(rec_out):
                results[indices[start + i]] = rec
        return results
|
|
|
|
|
|
|
class _MDR_TextSystem: |
|
|
|
def __init__(self, args): |
|
class ArgsObject: |
|
def __init__(self, **entries): self.__dict__.update(entries) |
|
|
|
if isinstance(args, dict): args = ArgsObject(**args) |
|
self.args = args |
|
self.detector = _MDR_TextDetector(args) |
|
self.recognizer = _MDR_TextRecognizer(args) |
|
self.use_cls = getattr(args, 'use_angle_cls', True) |
|
self.drop_score = getattr(args, 'drop_score', 0.5) |
|
self.classifier = _MDR_TextClassifier(args) if self.use_cls else None |
|
self.crop_idx = 0 |
|
self.save_crop = getattr(args, 'save_crop_res', False) |
|
self.crop_dir = getattr(args, 'crop_res_save_dir', "./output/mdr_crop_res") |
|
|
|
def __call__(self, img: np.ndarray) -> tuple[list[np.ndarray], list[tuple[str, float]]]: |
|
ori_im = img.copy() |
|
|
|
dt_boxes: np.ndarray = self.detector(img) |
|
print( |
|
f" DEBUG TextSystem: Detector found {len(dt_boxes) if dt_boxes is not None and dt_boxes.size > 0 else 0} initial boxes.") |
|
if dt_boxes is None or dt_boxes.size == 0: |
|
return [], [] |
|
|
|
dt_boxes_sorted: list[np.ndarray] = self._sort_boxes(dt_boxes) |
|
print(f" DEBUG TextSystem: Sorted {len(dt_boxes_sorted)} boxes.") |
|
if not dt_boxes_sorted: |
|
return [], [] |
|
|
|
|
|
|
|
valid_boxes_for_cropping: list[np.ndarray] = [] |
|
img_crop_list: list[np.ndarray] = [] |
|
for i, box_pts in enumerate(dt_boxes_sorted): |
|
crop_im = mdr_get_rotated_crop(ori_im, box_pts) |
|
if crop_im is not None and crop_im.shape[0] > 1 and crop_im.shape[1] > 1: |
|
valid_boxes_for_cropping.append(box_pts) |
|
img_crop_list.append(crop_im) |
|
else: |
|
print( |
|
f" DEBUG TextSystem: Crop for box {i} (pts: {box_pts}) was None or too small. Skipping this box.") |
|
|
|
dt_boxes_sorted = valid_boxes_for_cropping |
|
|
|
|
|
|
|
print(f" DEBUG TextSystem: Created {len(img_crop_list)} valid crops for further processing.") |
|
|
|
if not img_crop_list: |
|
print(" DEBUG TextSystem: No valid crops generated. Returning empty.") |
|
return [], [] |
|
|
|
if self.use_cls and self.classifier is not None: |
|
print(f" DEBUG TextSystem: Applying text classification for {len(img_crop_list)} crops.") |
|
img_crop_list, cls_results = self.classifier( |
|
img_crop_list) |
|
print(f" DEBUG TextSystem: Classification complete. {len(cls_results if cls_results else [])} results.") |
|
|
|
rec_results: list[tuple[str, float]] = [] |
|
print(f" DEBUG TextSystem: Recognizing text for {len(img_crop_list)} crops.") |
|
rec_results = self.recognizer(img_crop_list) |
|
print(f" DEBUG TextSystem: Recognizer returned {len(rec_results)} results.") |
|
|
|
|
|
expected_count = len(dt_boxes_sorted) |
|
|
|
|
|
actual_rec_count = len(rec_results) |
|
num_to_process = 0 |
|
|
|
if actual_rec_count == expected_count: |
|
num_to_process = actual_rec_count |
|
else: |
|
print(f" DEBUG TextSystem: WARNING - Mismatch in lengths after recognition! " |
|
f"Expected (from boxes/crops): {expected_count}, " |
|
f"Recognizer returned: {actual_rec_count} results. ") |
|
num_to_process = min(actual_rec_count, expected_count) |
|
if num_to_process < expected_count: |
|
print( |
|
f" DEBUG TextSystem: Will process {num_to_process} items due to mismatch. Some data might be lost if recognizer dropped results or if there was an issue in earlier stages not caught.") |
|
elif num_to_process < actual_rec_count: |
|
print( |
|
f" DEBUG TextSystem: Will process {num_to_process} items. Recognizer returned more results ({actual_rec_count}) than expected crops ({expected_count}). Extra recognition results will be ignored.") |
|
|
|
if num_to_process == 0: |
|
if expected_count > 0: |
|
print( |
|
" DEBUG TextSystem: No recognition results to process (num_to_process is 0) despite having input boxes/crops. Returning empty.") |
|
else: |
|
print( |
|
" DEBUG TextSystem: No items to process (no initial boxes or num_to_process is 0). Returning empty.") |
|
return [], [] |
|
|
|
|
|
print( |
|
f" DEBUG TextSystem: Filtering {num_to_process} recognition results with drop_score: {self.drop_score}") |
|
final_boxes_to_return: list[np.ndarray] = [] |
|
final_recs_to_return: list[tuple[str, float]] = [] |
|
final_crops_for_saving: list[np.ndarray] = [] |
|
|
|
|
|
for i in range(num_to_process): |
|
|
|
|
|
text, confidence = rec_results[i] |
|
|
|
print(f" DEBUG TextSystem: Rec item {i} - Text: '{text}', Confidence: {confidence:.4f}") |
|
|
|
if confidence >= self.drop_score: |
|
if text and not mdr_is_whitespace(text): |
|
final_boxes_to_return.append(dt_boxes_sorted[i]) |
|
final_recs_to_return.append(rec_results[i]) |
|
if self.save_crop: |
|
|
|
|
|
final_crops_for_saving.append(img_crop_list[i]) |
|
else: |
|
print(f" DEBUG TextSystem: Item {i} REJECTED (empty/whitespace text).") |
|
else: |
|
print( |
|
f" DEBUG TextSystem: Item {i} REJECTED (confidence {confidence:.4f} < drop_score {self.drop_score}).") |
|
|
|
|
|
print(f" DEBUG TextSystem: Kept {len(final_boxes_to_return)} boxes after recognition and filtering.") |
|
|
|
if self.save_crop and final_crops_for_saving: |
|
print(f" DEBUG TextSystem: Saving {len(final_crops_for_saving)} filtered crops.") |
|
self._save_crops(final_crops_for_saving, final_recs_to_return) |
|
|
|
return final_boxes_to_return, final_recs_to_return |
|
|
|
def _sort_boxes(self, boxes): |
|
if boxes is None or len(boxes) == 0: return [] |
|
|
|
def key(box): |
|
min_y = min(p[1] for p in box); min_x = min(p[0] for p in box); return (min_y, min_x) |
|
|
|
try: |
|
return list(sorted(boxes, key=key)) |
|
except: |
|
return list(boxes) |
|
|
|
def _save_crops(self, crops, recs): |
|
mdr_ensure_directory(self.crop_dir) |
|
num = len(crops) |
|
for i in range(num): |
|
txt, score = recs[i] |
|
safe = re.sub(r'\W+', '_', txt)[:20] |
|
fname = f"crop_{self.crop_idx + i}_{safe}_{score:.2f}.jpg" |
|
cv2.imwrite(os.path.join(self.crop_dir, fname), crops[i]) |
|
self.crop_idx += num |
|
|
|
|
|
|
|
def mdr_get_rotated_crop(img, points):
    """Crop the quadrilateral *points* out of *img* and warp it to an upright rectangle.

    *points* must be four (x, y) corners in top-left, top-right, bottom-right,
    bottom-left order. They are mapped onto a ``w`` x ``h`` rectangle whose sides
    are the longer of each pair of opposite edges (at least 1 px). Pixels sampled
    from outside the image are filled with (128, 128, 128).

    A result at least 1.5 times taller than wide is rotated 90 degrees clockwise.
    This is presumably meant for vertical text lines; confirm the rotation direction.

    Returns:
        The warped crop as an ndarray.
    """
    pts = np.array(points, dtype="float32")
    assert len(pts) == 4
    # Clamp to >= 1 px. cv2.warpPerspective treats a zero-sized dsize as "use the
    # source size", which would turn a degenerate box into a full-image warp.
    w = max(1, int(max(np.linalg.norm(pts[0] - pts[1]), np.linalg.norm(pts[2] - pts[3]))))
    h = max(1, int(max(np.linalg.norm(pts[0] - pts[3]), np.linalg.norm(pts[1] - pts[2]))))
    std = np.float32([[0, 0], [w, 0], [w, h], [0, h]])
    M = cv2.getPerspectiveTransform(pts, std)
    dst = cv2.warpPerspective(img, M, (w, h), borderMode=cv2.BORDER_CONSTANT, borderValue=(128, 128, 128),
                              flags=cv2.INTER_CUBIC)
    dh, dw = dst.shape[0:2]
    if dh > 0 and dw > 0 and dh * 1.0 / dw >= 1.5:
        dst = cv2.rotate(dst, cv2.ROTATE_90_CLOCKWISE)
    return dst
|
|
|
|
|
def mdr_get_min_area_crop(img, points):
    """Crop the minimum-area rotated rectangle that encloses *points*.

    ``cv2.boxPoints`` does not return the corners in the top-left, top-right,
    bottom-right, bottom-left order that ``mdr_get_rotated_crop`` maps onto its
    output rectangle. The corners are therefore reordered first; without this the
    crop can come out rotated or mirrored.

    Ordering: split the corners by x into a left pair and a right pair, then pick
    top and bottom within each pair by y.
    """
    bb = cv2.minAreaRect(np.array(points).astype(np.int32))
    corners = sorted(cv2.boxPoints(bb), key=lambda p: p[0])
    if corners[1][1] > corners[0][1]:
        top_left, bottom_left = corners[0], corners[1]
    else:
        top_left, bottom_left = corners[1], corners[0]
    if corners[3][1] > corners[2][1]:
        top_right, bottom_right = corners[2], corners[3]
    else:
        top_right, bottom_right = corners[3], corners[2]
    box_pts = np.array([top_left, top_right, bottom_right, bottom_left], dtype=np.float32)
    return mdr_get_rotated_crop(img, box_pts)
|
|
|
|
|
|
|
# Containment threshold used by _MDR_OverlapMatrixContext and
# mdr_remove_overlap_layouts. A pairwise score (_MDR_OverlapMatrixContext._rate)
# at or above this value means one layout is treated as covering the other.
_MDR_INCLUDES_MIN_RATE = 0.99
|
|
|
|
|
class _MDR_OverlapMatrixContext:
    """Pairwise coverage scores between layout rectangles.

    Attributes:
        polys: ``polys[i]`` is the shapely polygon of ``layouts[i].rect``, or
            ``None`` when it cannot be built or is invalid.
        matrix: for ``i != j``, ``matrix[i][j]`` is ``_rate(polys[i], polys[j])``,
            and 0.0 when either polygon is missing. ``matrix[i][i]`` is 1.0 for
            layouts with a valid polygon.
        removed: indices that :meth:`others` and :meth:`includes` skip.
    """

    def __init__(self, layouts: list[MDRLayoutElement]):
        length = len(layouts)
        self.polys: list[Polygon | None] = []
        for layout in layouts:
            try:
                p = Polygon(layout.rect)
                self.polys.append(p if p.is_valid else None)
            except Exception:
                self.polys.append(None)
        self.matrix = [[0.0] * length for _ in range(length)]
        self.removed = set()
        for i in range(length):
            p1 = self.polys[i]
            if p1 is None:
                continue
            # A layout fully covers itself. This assignment used to follow
            # ``continue`` on the same line and so never ran.
            self.matrix[i][i] = 1.0
            for j in range(i + 1, length):
                p2 = self.polys[j]
                if p2 is None:
                    continue
                self.matrix[i][j] = self._rate(p1, p2)
                self.matrix[j][i] = self._rate(p2, p1)

    def _rate(self, p1: Polygon, p2: Polygon) -> float:
        """Score how much of ``p2``'s bounding box is spanned by ``p1 ∩ p2``.

        The score is the mean of two ratios, each capped at 1.0:
        - intersection bounding-box width / ``p2`` bounding-box width
        - intersection bounding-box height / ``p2`` bounding-box height

        Returns 0.0 when the intersection cannot be computed, is (near) empty, or
        ``p2`` is degenerate (width or height below 1e-6).
        """
        try:
            inter = p1.intersection(p2)
        except Exception:
            return 0.0
        if inter.is_empty or inter.area < 1e-6:
            return 0.0
        ix0, iy0, ix1, iy1 = inter.bounds
        px0, py0, px1, py1 = p2.bounds
        pw = px1 - px0
        ph = py1 - py0
        if pw < 1e-6 or ph < 1e-6:
            return 0.0
        wr = min((ix1 - ix0) / pw, 1.0)
        hr = min((iy1 - iy0) / ph, 1.0)
        return (wr + hr) / 2.0

    def others(self, idx: int):
        """Yield the scores ``matrix[idx][i]`` for every other, non-removed layout ``i``."""
        for i, r in enumerate(self.matrix[idx]):
            if i != idx and i not in self.removed:
                yield r

    def includes(self, idx: int):
        """Yield indices of non-removed layouts strictly nested inside layout ``idx``.

        Layout ``idx`` must cover them (score >= threshold), and they must not
        cover ``idx`` back.
        """
        for i, r in enumerate(self.matrix[idx]):
            if i != idx and i not in self.removed and r >= _MDR_INCLUDES_MIN_RATE:
                if self.matrix[i][idx] < _MDR_INCLUDES_MIN_RATE:
                    yield i
|
|
|
|
|
def mdr_remove_overlap_layouts(layouts: list[MDRLayoutElement]) -> list[MDRLayoutElement]:
    """Drop layouts nested inside other layouts, folding their OCR fragments into the container.

    Containment comes from ``_MDR_OverlapMatrixContext``. Layout ``a`` strictly
    contains ``b`` when both hold:
    - ``matrix[a][b] >= _MDR_INCLUDES_MIN_RATE``
    - ``matrix[b][a] < _MDR_INCLUDES_MIN_RATE``

    Layouts whose rectangle does not yield a valid polygon are dropped. Passes
    repeat until a full pass removes nothing more.

    The fragments of every removed nested layout are appended to a layout that
    contains it, then sorted top-to-bottom and left-to-right. Before this change,
    a nested layout that came before its container in the list was dropped
    together with its fragments.

    Note: mutates ``fragments`` of the surviving layouts in place.

    Returns:
        The surviving layouts, in their original order.
    """
    if not layouts:
        return []
    ctx = _MDR_OverlapMatrixContext(layouts)
    n = len(layouts)

    def _sort_fragments(layout):
        # Reading order: top-to-bottom, then left-to-right by the top-left corner.
        layout.fragments.sort(key=lambda f: (f.rect.lt[1], f.rect.lt[0]))

    prev_removed = -1
    while len(ctx.removed) != prev_removed:
        prev_removed = len(ctx.removed)
        current_removed = set()
        for i in range(n):
            if i in ctx.removed or i in current_removed:
                continue
            if ctx.polys[i] is None:
                # No usable geometry for this layout: drop it.
                current_removed.add(i)
                continue
            container = next(
                (j for j in range(n)
                 if j != i and j not in ctx.removed and j not in current_removed
                 and ctx.matrix[j][i] >= _MDR_INCLUDES_MIN_RATE
                 and ctx.matrix[i][j] < _MDR_INCLUDES_MIN_RATE),
                None,
            )
            if container is not None:
                # Hand i's text to its container rather than discarding it.
                layouts[container].fragments.extend(layouts[i].fragments)
                _sort_fragments(layouts[container])
                current_removed.add(i)
                continue
            nested = list(ctx.includes(i))
            if nested:
                outer = layouts[i]
                for j in nested:
                    if j not in ctx.removed and j not in current_removed:
                        outer.fragments.extend(layouts[j].fragments)
                        current_removed.add(j)
                _sort_fragments(outer)
        ctx.removed.update(current_removed)
    return [layout for idx, layout in enumerate(layouts) if idx not in ctx.removed]
|
|
|
|
|
def _mdr_split_fragments_into_lines(frags: list[MDROcrFragment]) -> Generator[list[MDROcrFragment], None, None]:
    """Group OCR fragments into text lines by vertical position.

    Fragments are visited top-to-bottom, then left-to-right, by their ``lt``
    corner. A fragment joins the current group when the vertical centre of its
    bounding box (``rect.wrapper``) is within 40% of the group's mean height of
    the group's mean centre. Otherwise the current group is yielded and a new
    group starts. Fragments with a non-positive bounding-box height are skipped.

    The input list is left untouched; an ordered copy is iterated instead of
    sorting it in place.
    """
    if not frags:
        return
    group, y_sum, h_sum = [], 0.0, 0.0
    for f in sorted(frags, key=lambda frag: (frag.rect.lt[1], frag.rect.lt[0])):
        _, y1, _, y2 = f.rect.wrapper
        h = y2 - y1
        med_y = (y1 + y2) / 2.0
        if h <= 0:
            continue
        if not group:
            group.append(f)
            y_sum, h_sum = med_y, h
            continue
        g_len = len(group)
        avg_med_y = y_sum / g_len
        avg_h = h_sum / g_len
        max_dev = avg_h * 0.40
        if abs(med_y - avg_med_y) > max_dev:
            yield group
            group, y_sum, h_sum = [f], med_y, h
        else:
            group.append(f)
            y_sum += med_y
            h_sum += h
    if group:
        yield group
|
|
|
|
|
def mdr_merge_fragments_into_lines(orig_frags: list[MDROcrFragment]) -> list[MDROcrFragment]:
    """Merge fragments that sit on the same text line into one fragment per line.

    Lines come from ``_mdr_split_fragments_into_lines``; fragments it skips are
    not returned. Single-fragment lines pass through unchanged.

    For a multi-fragment line:
    - fragments are ordered left-to-right;
    - non-empty texts are joined with single spaces;
    - ``rank`` is the character-count-weighted mean of the ranks;
    - ``rect`` is the axis-aligned bounding box;
    - ``order`` is the smallest order in the line.
    Lines whose texts are all empty are dropped.

    The result is sorted by (order, top, left), and ``order`` is renumbered
    ``0..n-1`` on every returned fragment.
    """
    merged = []
    for line in _mdr_split_fragments_into_lines(orig_frags):
        if not line:
            continue
        if len(line) == 1:
            merged.append(line[0])
            continue

        line.sort(key=lambda frag: frag.rect.lt[0])
        min_order = min(frag.order for frag in line if hasattr(frag, 'order')) if line else 0

        pieces = []
        weighted_rank = 0.0
        total_chars = 0
        left = top = float("inf")
        right = bottom = float("-inf")
        for frag in line:
            fx1, fy1, fx2, fy2 = frag.rect.wrapper
            left, top = min(left, fx1), min(top, fy1)
            right, bottom = max(right, fx2), max(bottom, fy2)
            text = frag.text
            n_chars = len(text)
            if n_chars > 0:
                pieces.append(text)
                weighted_rank += frag.rank * n_chars
                total_chars += n_chars

        if total_chars == 0:
            continue

        box = MDRRectangle(lt=(left, top), rt=(right, top), lb=(left, bottom), rb=(right, bottom))
        merged.append(MDROcrFragment(order=min_order, text=" ".join(pieces),
                                     rank=weighted_rank / total_chars, rect=box))

    merged.sort(key=lambda frag: (frag.order, frag.rect.lt[1], frag.rect.lt[0]))
    for new_order, frag in enumerate(merged):
        frag.order = new_order
    return merged
|
|
|
|
|
|
|
# Minimum intersection-over-union (strictly greater than this value) for an existing
# fragment and a re-OCR'd fragment to be treated as the same text in
# mdr_correct_layout_fragments.
_MDR_CORRECTION_MIN_OVERLAP = 0.5
|
|
|
|
|
def mdr_correct_layout_fragments(ocr_engine: 'MDROcrEngine', source_img: Image, layout: MDRLayoutElement):
    """Re-run OCR on a layout's region and keep the better reading of each fragment.

    Steps:
        1. Crop the layout's bounding box, padded by 5 px and clamped to the image,
           from *source_img*.
        2. OCR the crop with ``ocr_engine`` (converted RGB -> BGR).
        3. Shift the new fragments back to page coordinates.
        4. Match each original fragment to at most one new fragment with IoU
           above ``_MDR_CORRECTION_MIN_OVERLAP``, choosing the highest IoU.

    ``layout.fragments`` is replaced by:
    - for each matched pair, the fragment with the higher ``rank`` (ties go to
      the new one);
    - every unmatched original fragment, including those whose polygon cannot
      be built (previously these were silently dropped);
    - every unmatched new fragment.
    The result is sorted top-to-bottom, then left-to-right.

    Nothing changes when the layout has no fragments, the crop is empty or
    smaller than 5 px, or cropping/OCR fails; failures are printed.
    """
    if not layout.fragments:
        return

    try:
        x1, y1, x2, y2 = layout.rect.wrapper
        margin = 5
        crop_x1 = max(0, round(x1) - margin)
        crop_y1 = max(0, round(y1) - margin)
        crop_x2 = min(source_img.width, round(x2) + margin)
        crop_y2 = min(source_img.height, round(y2) + margin)
        if crop_x1 >= crop_x2 or crop_y1 >= crop_y2:
            print(
                f"Correct: Crop box for layout {type(layout.cls).__name__} is invalid/empty ({crop_x1},{crop_y1},{crop_x2},{crop_y2}). Skipping OCR correction.")
            return
        cropped = source_img.crop((crop_x1, crop_y1, crop_x2, crop_y2))
        off_x, off_y = crop_x1, crop_y1
    except Exception as e:
        print(f"Correct: Crop error for layout {type(layout.cls).__name__}: {e}")
        return

    if cropped.width < 5 or cropped.height < 5:
        print(
            f"Correct: Cropped image for layout {type(layout.cls).__name__} is too small ({cropped.width}x{cropped.height}). Skipping OCR correction.")
        return

    try:
        # PIL gives RGB; the OCR engine expects BGR.
        cropped_np = np.array(cropped.convert("RGB"))[:, :, ::-1]
        new_frags_local = list(ocr_engine.find_text_fragments(cropped_np))
    except Exception as e:
        print(f"Correct: OCR error during correction for layout {type(layout.cls).__name__}: {e}")
        return

    # Shift the crop-local fragments back into page coordinates.
    new_frags_global = []
    for f in new_frags_local:
        r = f.rect
        lt, rt, lb, rb = r.lt, r.rt, r.lb, r.rb
        f.rect = MDRRectangle(lt=(lt[0] + off_x, lt[1] + off_y), rt=(rt[0] + off_x, rt[1] + off_y),
                              lb=(lb[0] + off_x, lb[1] + off_y), rb=(rb[0] + off_x, rb[1] + off_y))
        new_frags_global.append(f)

    def _valid_polygon(rect):
        # Polygon for *rect*, or None when it cannot be built or is invalid.
        try:
            poly = Polygon(rect)
        except Exception:
            return None
        return poly if poly.is_valid else None

    # Build each new fragment's polygon once instead of once per original fragment.
    new_polys = [_valid_polygon(f.rect) for f in new_frags_global]

    matched, unmatched_orig = [], []
    used_new = set()
    for orig_f in layout.fragments:
        poly_o = _valid_polygon(orig_f.rect)
        if poly_o is None:
            # Cannot be compared: keep the original fragment as-is.
            unmatched_orig.append(orig_f)
            continue
        best_j, best_rate = -1, -1.0
        for j, poly_n in enumerate(new_polys):
            if j in used_new or poly_n is None:
                continue
            try:
                inter = poly_o.intersection(poly_n)
                union = poly_o.union(poly_n)
            except Exception:
                continue
            rate = inter.area / union.area if union.area > 1e-6 else 0.0
            if rate > _MDR_CORRECTION_MIN_OVERLAP and rate > best_rate:
                best_rate, best_j = rate, j
        if best_j != -1:
            matched.append((orig_f, new_frags_global[best_j]))
            used_new.add(best_j)
        else:
            unmatched_orig.append(orig_f)

    unmatched_new = [f for j, f in enumerate(new_frags_global) if j not in used_new]

    final = [n if n.rank >= o.rank else o for o, n in matched]
    final.extend(unmatched_orig)
    final.extend(unmatched_new)

    layout.fragments = final
    if layout.fragments:
        layout.fragments.sort(key=lambda f: (f.rect.lt[1], f.rect.lt[0]))
|
|
|
|
|
|
|
# OCR assets as relative path parts. MDROcrEngine stores each one at
# <model_dir>/<parts...> and, if it is missing, downloads it from
# _MDR_OCR_URL_BASE + "/".join(parts).
# det/cls/rec are the detector, angle-classifier and recognizer ONNX models;
# "keys" is the recognizer's character dictionary (rec_char_dict_path).
_MDR_OCR_MODELS = {"det": ("ppocr_onnx", "model", "det_model", "en_PP-OCRv3_det_infer.onnx"),
                   "cls": ("ppocr_onnx", "model", "cls_model", "ch_ppocr_mobile_v2.0_cls_infer.onnx"),
                   "rec": ("ppocr_onnx", "model", "rec_model", "en_PP-OCRv3_rec_infer.onnx"),
                   "keys": ("ppocr_onnx", "ppocr", "utils", "dict", "en_dict.txt")}

# Raw-file base URL for the downloads above (must end with "/").
_MDR_OCR_URL_BASE = "https://raw.githubusercontent.com/Kazuhito00/PaddleOCR-ONNX-Sample/main/"
|
|
|
|
|
@dataclass
class _MDR_ONNXParams:
    """Options for the ONNX OCR pipeline (_MDR_TextSystem and its stages).

    Field order is the positional constructor order of this dataclass.
    """

    # Required: GPU flag for the ONNX sessions, the three model paths and the
    # recognizer's character dictionary.
    use_gpu: bool
    det_model_dir: str
    cls_model_dir: str
    rec_model_dir: str
    rec_char_dict_path: str

    # use_angle_cls: whether _MDR_TextSystem builds and runs the angle classifier.
    use_angle_cls: bool = True
    # "C,H,W" input size for the recognizer.
    rec_image_shape: str = "3,48,256"
    # Angle classifier (_MDR_TextClassifier): "C,H,W" input size, batch size,
    # minimum score for a "180" label before a crop is rotated, and class labels.
    cls_image_shape: str = "3,48,192"
    cls_batch_num: int = 6
    cls_thresh: float = 0.9
    label_list: List[str] = field(default_factory=lambda: ['0', '180'])

    # Detector. det_db_* / use_dilation / det_box_type feed the DB post-processor,
    # and det_box_type also selects poly vs. quad filtering. det_algorithm and
    # det_limit_* are presumably read by detector pre-processing that is not
    # shown here — TODO confirm.
    det_algorithm: str = "DB"
    det_limit_side_len: int = 1280
    det_limit_type: str = 'min'
    det_db_thresh: float = 0.3
    det_db_box_thresh: float = 0.6
    det_db_unclip_ratio: float = 1.5
    use_dilation: bool = False
    det_db_score_mode: str = 'fast'
    det_box_type: str = 'quad'

    # Recognizer batch size; drop_score is the minimum confidence _MDR_TextSystem
    # keeps; rec_algorithm is stored but not otherwise read in this section;
    # use_space_char appends " " to the character table.
    rec_batch_num: int = 6
    drop_score: float = 0.5
    rec_algorithm: str = "SVTR_LCNet"
    use_space_char: bool = True

    # Optional dumping of kept crops by _MDR_TextSystem. show_log and use_onnx are
    # not read by the code in this section.
    save_crop_res: bool = False
    crop_res_save_dir: str = "./output/mdr_crop_res"
    show_log: bool = False
    use_onnx: bool = True
|
|
|
|
|
class MDROcrEngine: |
|
"""Handles OCR detection and recognition using ONNX models.""" |
|
|
|
def __init__(self, device: Literal["cpu", "cuda"], model_dir_path: str): |
|
self._device = device; |
|
self._model_dir = mdr_ensure_directory(model_dir_path) |
|
self._text_system: _MDR_TextSystem | None = None; |
|
self._onnx_params: _MDR_ONNXParams | None = None |
|
self._ensure_models(); |
|
self._get_system() |
|
|
|
def _ensure_models(self): |
|
for key, parts in _MDR_OCR_MODELS.items(): |
|
fp = Path(self._model_dir) / Path(*parts) |
|
if not fp.exists(): print(f"Downloading MDR OCR model: {fp.name}..."); url = _MDR_OCR_URL_BASE + "/".join( |
|
parts); mdr_download_model(url, fp) |
|
|
|
def _get_system(self) -> _MDR_TextSystem | None: |
|
if self._text_system is None: |
|
paths = {k: str(Path(self._model_dir) / Path(*p)) for k, p in _MDR_OCR_MODELS.items()} |
|
|
|
self._onnx_params = _MDR_ONNXParams( |
|
use_gpu=(self._device == "cuda"), |
|
det_model_dir=paths["det"], |
|
cls_model_dir=paths["cls"], |
|
rec_model_dir=paths["rec"], |
|
rec_char_dict_path=paths["keys"], |
|
|
|
det_db_thresh=0.3, |
|
det_db_box_thresh=0.5, |
|
drop_score=0.1, |
|
use_angle_cls=False, |
|
) |
|
try: |
|
self._text_system = _MDR_TextSystem(self._onnx_params); print(f"MDR OCR System initialized.") |
|
except Exception as e: |
|
print(f"ERROR initializing MDR OCR System: {e}"); self._text_system = None |
|
return self._text_system |
|
|
|
|
|
def find_text_fragments(self, image_np: np.ndarray) -> Generator[MDROcrFragment, None, None]:
    """Finds and recognizes text fragments in a NumPy image (BGR).

    The image is normalized by ``_preprocess`` and run through the ONNX text
    system. One ``MDROcrFragment`` (with ``order=-1``) is yielded per
    recognized box that has non-whitespace text, a valid 4-point rectangle and
    area > 1. Failures are printed and end (or skip) generation; nothing is
    raised.

    Args:
        image_np: grayscale, 3-channel or 4-channel image array.

    Yields:
        ``MDROcrFragment`` objects.
    """
    system = self._get_system()
    if system is None:
        print(" DEBUG OCR Engine: MDR OCR System unavailable. No fragments will be found.")
        return

    img_for_system = self._preprocess(image_np)
    print(f" DEBUG OCR Engine: Image preprocessed for TextSystem. Shape: {img_for_system.shape}")

    try:
        boxes, recs = system(img_for_system)
    except Exception as e:
        print(f" DEBUG OCR Engine: Error during TextSystem prediction: {e}")
        import traceback
        traceback.print_exc()
        return

    # Test for None / emptiness explicitly: ``not boxes`` raises ValueError
    # when the system hands back a multi-element NumPy array.
    if boxes is None or recs is None or len(boxes) == 0 or len(recs) == 0:
        print(
            f" DEBUG OCR Engine: TextSystem returned no boxes ({len(boxes) if boxes is not None else 'None'}) or no recs ({len(recs) if recs is not None else 'None'}). No fragments generated.")
        return

    if len(boxes) != len(recs):
        print(
            f" DEBUG OCR Engine: Mismatch between boxes ({len(boxes)}) and recs ({len(recs)}) from TextSystem. This is problematic. No fragments generated.")
        return

    print(
        f" DEBUG OCR Engine: TextSystem returned {len(boxes)} boxes and {len(recs)} recognition results. Converting to MDROcrFragment.")
    fragments_generated_count = 0
    for i, (box_pts, rec_tuple) in enumerate(zip(boxes, recs)):
        if not isinstance(rec_tuple, (list, tuple)) or len(rec_tuple) != 2:
            print(f" DEBUG OCR Engine: Rec item {i} is not a valid (text, score) tuple: {rec_tuple}. Skipping.")
            continue
        txt, conf = rec_tuple
        if not txt or mdr_is_whitespace(txt):
            continue

        try:
            pts = [(float(p[0]), float(p[1])) for p in box_pts]
            fragment = None
            if len(pts) == 4:
                # Detector polygons are taken to be clockwise from the top-left:
                # [top-left, top-right, bottom-right, bottom-left]. Index 2 is
                # therefore the bottom-RIGHT corner; swapping lb/rb would build a
                # self-intersecting rectangle that fails ``is_valid``.
                # NOTE(review): confirm this point order against _MDR_TextSystem.
                rect = MDRRectangle(lt=pts[0], rt=pts[1], rb=pts[2], lb=pts[3])
                if rect.is_valid and rect.area > 1:
                    fragment = MDROcrFragment(order=-1, text=txt, rank=float(conf), rect=rect)
        except Exception as e_frag:
            print(f" DEBUG OCR Engine: Error creating MDROcrFragment for item {i}: {e_frag}")
            continue

        # Yield outside the try so exceptions thrown into the generator by the
        # consumer are not swallowed by the per-item handler above.
        if fragment is None:
            continue
        yield fragment
        fragments_generated_count += 1

    print(f" DEBUG OCR Engine: Generated {fragments_generated_count} MDROcrFragment objects.")
|
|
|
def _preprocess(self, img: np.ndarray) -> np.ndarray:
    """Normalize an image array to 3-channel ``uint8`` for the OCR system.

    * 4-channel input is alpha-composited (alpha = channel 3) over a white
      background. The first three channels keep their order, and the result
      is truncated to ``uint8``.
    * 2-D grayscale input and single-channel ``(H, W, 1)`` input are expanded
      with ``cv2.COLOR_GRAY2BGR``.
    * 3-channel input is returned unchanged (same object).

    Raises:
        ValueError: for any other shape.
    """
    if img.ndim == 3 and img.shape[2] == 4:
        # Vectorized blend. The previous per-channel version crashed: numpy
        # arrays reject ``setattr(..., 'flags.writeable', ...)``, and
        # ``np.copyto`` refuses float -> uint8 under the default casting rule.
        alpha = img[:, :, 3:4] / 255.0
        background = np.array([255.0, 255.0, 255.0])
        blended = background * (1.0 - alpha) + img[:, :, :3] * alpha
        return blended.astype(np.uint8)
    if img.ndim == 2:
        return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    if img.ndim == 3 and img.shape[2] == 1:
        return cv2.cvtColor(img[:, :, 0], cv2.COLOR_GRAY2BGR)
    if img.ndim == 3 and img.shape[2] == 3:
        return img
    raise ValueError("Unsupported image format")
|
|
|
|
|
|
|
_MDR_MAX_LEN = 510; |
|
_MDR_CLS_ID = 0; |
|
_MDR_SEP_ID = 2; |
|
_MDR_PAD_ID = 1 |
|
|
|
|
|
def mdr_boxes_to_reader_inputs(boxes: List[List[int]], max_len=_MDR_MAX_LEN) -> Dict[str, torch.Tensor]:
    """Pack boxes into a single padded LayoutLMv3 batch.

    The sequence is ``[CLS] box_1 ... box_k [SEP]`` where
    ``k = len(boxes[:max_len])``, padded up to ``max_len + 2`` positions. Box
    tokens use the PAD id, so only the ``bbox`` tensor carries geometry.
    Padding positions get a zero box, the PAD id and attention 0.

    Returns:
        ``{"bbox": (1, L, 4), "input_ids": (1, L), "attention_mask": (1, L)}``
        tensors, with ``L = max_len + 2`` for non-negative ``max_len``.
    """
    kept = boxes[:max_len]
    empty_box = [0, 0, 0, 0]
    used = len(kept) + 2
    padding = max(0, (max_len + 2) - used)

    bbox_rows = [empty_box] + kept + [empty_box] + [empty_box] * padding
    token_ids = [_MDR_CLS_ID] + [_MDR_PAD_ID] * len(kept) + [_MDR_SEP_ID] + [_MDR_PAD_ID] * padding
    mask = [1] * used + [0] * padding

    return {
        "bbox": torch.tensor([bbox_rows]),
        "input_ids": torch.tensor([token_ids]),
        "attention_mask": torch.tensor([mask]),
    }
|
|
|
|
|
def mdr_prepare_reader_inputs(inputs: Dict[str, torch.Tensor], model: LayoutLMv3ForTokenClassification) -> Dict[
        str, torch.Tensor]:
    """Return a new dict with every tensor in *inputs* moved to ``model.device``.

    Non-tensor values are copied over unchanged.
    """
    moved = {}
    for key, value in inputs.items():
        if isinstance(value, torch.Tensor):
            value = value.to(model.device)
        moved[key] = value
    return moved
|
|
|
|
|
def mdr_parse_reader_logits(logits: torch.Tensor, length: int) -> List[int]:
    """Turn LayoutReader logits into a reading-order permutation.

    Row ``i + 1`` of *logits* (row 0 is the CLS token) scores box ``i``
    against candidate positions ``0 .. length - 1``. Each box first claims its
    best-scoring position. Whenever several boxes claim the same position,
    the one with the highest logit for it keeps it. The others move on to
    their next-best position they have not tried yet.

    This is deferred acceptance: a claimed position never becomes free again,
    so a box is rejected at most ``length - 1`` times. The loop therefore
    always terminates with a true permutation. The previous variant excluded
    only the current position and could oscillate, returning duplicates once
    its loop cap was hit.

    Args:
        logits: 2-D tensor (sequence positions x labels), e.g. squeezed model output.
        length: number of real boxes.

    Returns:
        ``orders`` where ``orders[i]`` is the reading position of box ``i``.
        If *logits* cannot supply a full ``length x length`` block, the
        identity order ``list(range(length))`` is returned.
    """
    print(f"mdr_parse_reader_logits: Called with logits shape: {logits.shape}, length: {length}")
    if length <= 0:
        print("mdr_parse_reader_logits: length is 0, returning empty list.")
        return []

    try:
        rel_logits = logits[1: length + 1, :length]
    except IndexError as e:
        print(f"mdr_parse_reader_logits: IndexError during rel_logits slicing! Error: {e}")
        import traceback
        traceback.print_exc()
        return list(range(length))

    if tuple(rel_logits.shape) != (length, length):
        # Slicing silently truncates; without a full square block some boxes
        # would have no candidate positions.
        print(f"mdr_parse_reader_logits: rel_logits shape {tuple(rel_logits.shape)} is not ({length}, {length}); "
              f"returning sequential order.")
        return list(range(length))

    scores = rel_logits.tolist()
    # Candidate positions per box, best first. The sort is stable, so equal
    # scores keep the lower index first, which matches argmax.
    preferences = [sorted(range(length), key=row.__getitem__, reverse=True) for row in scores]
    next_choice = [1] * length
    orders = [prefs[0] for prefs in preferences]

    rounds = 0
    while True:
        claimants = defaultdict(list)
        for idx, order in enumerate(orders):
            claimants[order].append(idx)
        conflicts = {order: idxs for order, idxs in claimants.items() if len(idxs) > 1}
        if not conflicts:
            break
        rounds += 1
        for order, idxs in conflicts.items():
            winner = max(idxs, key=lambda idx, o=order: scores[idx][o])
            for idx in idxs:
                if idx != winner:
                    orders[idx] = preferences[idx][next_choice[idx]]
                    next_choice[idx] += 1

    print(f"mdr_parse_reader_logits: Resolved conflicts in {rounds} rounds. Returning {len(orders)} orders.")
    return orders
|
|
|
|
|
|
|
@dataclass |
|
class _MDR_ReaderBBox: layout_index: int; fragment_index: int; virtual: bool; order: int; value: tuple[ |
|
float, float, float, float] |
|
|
|
|
|
class MDRLayoutReader:
    """Determines reading order of layout elements using LayoutLMv3.

    Plain-text layouts contribute one box per OCR fragment. Every other
    layout contributes synthetic "virtual" line boxes. The boxes are scored by
    a LayoutLMv3 token classifier and turned into a permutation by
    ``mdr_parse_reader_logits``. That permutation orders both the layouts and
    the fragments inside them.
    """

    def __init__(self, model_path: str):
        """Remember the model cache directory and pick the inference device.

        The model itself is loaded lazily by ``_get_model``.
        """
        self._model_path = model_path
        self._model: LayoutLMv3ForTokenClassification | None = None
        # Set after a failed load so later pages don't retry the load on every call.
        self._load_failed = False
        if torch.cuda.is_available():
            self._device = "cuda"
            print("MDRLayoutReader: CUDA is available. Setting device to cuda.")
        else:
            self._device = "cpu"
            print("MDRLayoutReader: CUDA not available. Setting device to cpu.")

    def _get_model(self) -> LayoutLMv3ForTokenClassification | None:
        """Return the LayoutLMv3 model, loading it on first call.

        Returns None if loading raised. The failure is remembered and not retried.
        """
        if self._model is not None or self._load_failed:
            return self._model
        layoutreader_cache_dir = Path(self._model_path)
        mdr_ensure_directory(str(layoutreader_cache_dir))
        name = "lakshya-rawat/document-qa-model"
        print(f"MDRLayoutReader: Attempting to load LayoutLMv3 model '{name}'. Cache dir: {layoutreader_cache_dir}")
        try:
            model = LayoutLMv3ForTokenClassification.from_pretrained(
                name,
                cache_dir=str(layoutreader_cache_dir),
                local_files_only=False,
                num_labels=_MDR_MAX_LEN + 1,
            )
            model.to(torch.device(self._device))
            model.eval()
            self._model = model
            print(f"MDR LayoutReader model '{name}' loaded successfully on device: {self._model.device}.")
        except Exception as e:
            print(f"ERROR loading MDR LayoutReader model '{name}': {e}")
            import traceback
            traceback.print_exc()
            self._model = None
            self._load_failed = True
        return self._model

    def determine_reading_order(self, layouts: list[MDRLayoutElement], size: tuple[int, int]) -> list[MDRLayoutElement]:
        """Return *layouts* in reading order and number their fragments.

        Args:
            layouts: layout elements of one page. They may be sorted in place
                on the geometric fallback paths.
            size: (width, height) of the page image the rects refer to.

        Returns:
            Layouts in predicted reading order. Fallbacks:
            - non-positive size: returned untouched;
            - model unavailable or too many boxes: top-to-bottom, left-to-right sort;
            - degenerate scaled boxes or inference failure: sequential
              (geometric) box order.
        """
        w, h = size
        if w <= 0 or h <= 0:
            print("MDRLayoutReader: Invalid image size (w or h <= 0), returning layouts as is.")
            return layouts
        if not layouts:
            print("MDRLayoutReader: No layouts to process, returning empty list.")
            return []

        model = self._get_model()
        if model is None:
            print("MDRLayoutReader: Model not available, returning layouts sorted geometrically.")
            self._sort_geometrically(layouts)
            return layouts

        print("MDRLayoutReader: Preparing bboxes...")
        bbox_list = self._prepare_bboxes(layouts, w, h)
        if not bbox_list:
            print("MDRLayoutReader: No bboxes prepared from layouts, returning layouts as is (sorted geometrically).")
            self._sort_geometrically(layouts)
            return layouts
        print(f"MDRLayoutReader: Prepared {len(bbox_list)} bboxes.")

        # w and h are known to be positive here, so every box can be scaled.
        scaled_bboxes = [self._scale_bbox(item.value, w, h) for item in bbox_list]

        if self._boxes_are_degenerate(scaled_bboxes):
            print("MDRLayoutReader: Applying fallback sequential order due to problematic scaled_bboxes.")
            self._assign_sequential_order(bbox_list)
            return self._apply_order(layouts, bbox_list)

        try:
            with torch.no_grad():
                print("MDRLayoutReader: Creating reader inputs...")
                inputs = mdr_boxes_to_reader_inputs(scaled_bboxes)
                print("MDRLayoutReader: Preparing inputs for model device...")
                inputs = mdr_prepare_reader_inputs(inputs, model)
                print("MDRLayoutReader: Running model inference...")
                logits = model(**inputs).logits.cpu().squeeze(0)
                print("MDRLayoutReader: Model inference complete. Parsing logits...")
                orders = mdr_parse_reader_logits(logits, len(bbox_list))
                print(f"MDRLayoutReader: Logits parsed. Orders count: {len(orders)}")

            if len(orders) == len(bbox_list):
                for item, order_val in zip(bbox_list, orders):
                    item.order = order_val
            else:
                print(
                    f"MDRLayoutReader: Warning - Mismatch between orders ({len(orders)}) and bbox_list ({len(bbox_list)}). Using sequential order.")
                self._assign_sequential_order(bbox_list)
        except Exception as e:
            print(f"MDR LayoutReader prediction error: {e}")
            import traceback
            traceback.print_exc()
            self._assign_sequential_order(bbox_list)
            print("MDRLayoutReader: Applying fallback sequential order due to error...")
            return self._apply_order(layouts, bbox_list)

        print("MDRLayoutReader: Applying order...")
        result_layouts = self._apply_order(layouts, bbox_list)
        print("MDRLayoutReader: Order applied. Returning layouts.")
        return result_layouts

    @staticmethod
    def _sort_geometrically(layouts: list[MDRLayoutElement]) -> None:
        """Sort *layouts* in place by the top-left corner: top-to-bottom, then left-to-right."""
        layouts.sort(key=lambda l: (l.rect.lt[1], l.rect.lt[0]))

    @staticmethod
    def _assign_sequential_order(bbox_list: list[_MDR_ReaderBBox]) -> None:
        """Use each box's list position as its order (the list is sorted geometrically)."""
        for i, item in enumerate(bbox_list):
            item.order = i

    @staticmethod
    def _scale_bbox(value: tuple[float, float, float, float], w: int, h: int) -> list[int]:
        """Clamp an (x0, y0, x1, y1) box to the page and map it onto LayoutLMv3's 0-1000 grid.

        The scaled box never has x1 < x0 or y1 < y0. Requires w > 0 and h > 0.
        """
        x0, y0, x1, y1 = value
        x0_c = max(0.0, min(x0, float(w)))
        y0_c = max(0.0, min(y0, float(h)))
        x1_c = max(0.0, min(x1, float(w)))
        y1_c = max(0.0, min(y1, float(h)))
        scaled_x0 = max(0, min(1000, int(1000 * x0_c / w)))
        scaled_y0 = max(0, min(1000, int(1000 * y0_c / h)))
        scaled_x1 = max(scaled_x0, min(1000, int(1000 * x1_c / w)))
        scaled_y1 = max(scaled_y0, min(1000, int(1000 * y1_c / h)))
        return [scaled_x0, scaled_y0, scaled_x1, scaled_y1]

    @staticmethod
    def _boxes_are_degenerate(scaled_bboxes: list[list[int]]) -> bool:
        """Return True when the scaled boxes give the model nothing useful to work with.

        That is the case when all boxes are identical and zero-width/height,
        or (with more than one box) when over 90% of them are zero-width/height.
        """
        count = len(scaled_bboxes)
        if count == 0:
            return False
        first = scaled_bboxes[0]
        if all(b == first for b in scaled_bboxes) and (first[2] - first[0] == 0 or first[3] - first[1] == 0):
            print("MDRLayoutReader: All scaled bboxes are identical and degenerate. Bypassing LayoutLMv3.")
            return True
        if count > 1:
            degenerate_count = sum(1 for b in scaled_bboxes if b[2] - b[0] == 0 or b[3] - b[1] == 0)
            if degenerate_count / count > 0.9:
                print(
                    f"MDRLayoutReader: High percentage ({degenerate_count / count * 100:.1f}%) of scaled bboxes are degenerate. Bypassing LayoutLMv3.")
                return True
        return False

    def _prepare_bboxes(self, layouts: list[MDRLayoutElement], w: int, h: int) -> list[_MDR_ReaderBBox] | None:
        """Build the reader boxes for *layouts*, sorted by (y0, x0).

        Plain-text layouts with fragments get one real box per fragment; every
        other layout gets virtual boxes from ``_gen_virtual``. Returns None
        when more than ``_MDR_MAX_LEN`` boxes would be needed.
        """
        line_h = self._estimate_line_h(layouts)
        bbox_list: list[_MDR_ReaderBBox] = []
        for i, l in enumerate(layouts):
            if l.cls == MDRLayoutClass.PLAIN_TEXT and l.fragments:
                bbox_list.extend(
                    _MDR_ReaderBBox(i, j, False, -1, f.rect.wrapper) for j, f in enumerate(l.fragments))
            else:
                bbox_list.extend(self._gen_virtual(l, i, line_h, w, h))
        if len(bbox_list) > _MDR_MAX_LEN:
            print(f"Too many boxes ({len(bbox_list)}>{_MDR_MAX_LEN})")
            return None
        bbox_list.sort(key=lambda b: (b.value[1], b.value[0]))
        return bbox_list

    def _apply_order(self, original_layouts_list: list[MDRLayoutElement],
                     ordered_bbox_list_with_final_orders: list[_MDR_ReaderBBox]) -> list[MDRLayoutElement]:
        """Sort layouts by the median order of their boxes and number every fragment.

        Layouts are ranked by the median ``order`` of their reader boxes; ties
        keep first-box order. Layouts that produced no reader box at all are
        appended afterwards in their original order instead of being dropped.

        Within a layout, fragments backed by real (non-virtual) boxes are
        sorted by their box order. Otherwise they are sorted top-to-bottom,
        left-to-right. Finally each fragment gets a page-wide ``order``
        counting up from 0 through the resulting sequence.
        """
        layout_map: dict[int, list[_MDR_ReaderBBox]] = defaultdict(list)
        for bbox_item in ordered_bbox_list_with_final_orders:
            layout_map[bbox_item.layout_index].append(bbox_item)

        ranked = sorted(
            ((idx, self._median([b.order for b in boxes])) for idx, boxes in layout_map.items()),
            key=lambda item: item[1],
        )
        # Carry indices directly. ``list.index`` would be O(n^2) and would
        # return the wrong slot for layouts that compare equal.
        ordered_indices = [idx for idx, _ in ranked]
        seen = set(ordered_indices)
        ordered_indices.extend(i for i in range(len(original_layouts_list)) if i not in seen)

        final_sorted_layouts = []
        next_frag_order = 0
        for layout_idx in ordered_indices:
            layout_obj = original_layouts_list[layout_idx]
            final_sorted_layouts.append(layout_obj)
            if not layout_obj.fragments:
                continue

            real_boxes = [b for b in layout_map.get(layout_idx, []) if not b.virtual]
            if real_boxes:
                frag_rank = {b.fragment_index: b.order for b in real_boxes}
                indexed = sorted(enumerate(layout_obj.fragments),
                                 key=lambda item: frag_rank.get(item[0], float('inf')))
                layout_obj.fragments = [frag for _, frag in indexed]
            else:
                print(
                    f" LayoutReader ApplyOrder: No reader_bboxes for layout (orig_idx {layout_idx}). Sorting frags geometrically.")
                layout_obj.fragments.sort(key=lambda f: (f.rect.lt[1], f.rect.lt[0]))

            for frag in layout_obj.fragments:
                frag.order = next_frag_order
                next_frag_order += 1

        return final_sorted_layouts

    def _estimate_line_h(self, layouts: list[MDRLayoutElement]) -> float:
        """Median positive fragment height over all layouts, or 15.0 if there is none."""
        heights = [f.rect.size[1] for l in layouts for f in l.fragments if f.rect.size[1] > 0]
        return self._median(heights) if heights else 15.0

    def _gen_virtual(self, l: MDRLayoutElement, l_idx: int, line_h: float, pw: int, ph: int) -> Generator[
            _MDR_ReaderBBox, None, None]:
        """Yield virtual line boxes covering layout *l*.

        A degenerate rect (or a non-positive *line_h*) is yielded unchanged as
        a single box. Otherwise the rect is cut into a heuristic number of
        equal-height horizontal strips, clamped to the page (*pw* x *ph*).
        Strips with no area after clamping are skipped.
        """
        x0, y0, x1, y1 = l.rect.wrapper
        lh = y1 - y0
        lw = x1 - x0
        if lh <= 0 or lw <= 0 or line_h <= 0:
            yield _MDR_ReaderBBox(l_idx, -1, True, -1, (x0, y0, x1, y1))
            return

        if lh > line_h * 1.5:
            if lh <= ph * 0.25 or lw >= pw * 0.5:
                lines = 3
            elif lw > pw * 0.25:
                lines = 3 if lw > pw * 0.4 else 2
            else:
                # Narrow block: tall, thin ones get one strip per ~1.5 text lines.
                lines = max(1, int(lh / (line_h * 1.5))) if lh / lw > 1.5 else 2
        else:
            lines = max(1, int(round(lh / line_h)))
        lines = max(1, lines)

        act_line_h = lh / lines
        lx0 = max(0, min(pw, x0))
        lx1 = max(0, min(pw, x1))
        cur_y = y0
        for _ in range(lines):
            ly0 = max(0, min(ph, cur_y))
            ly1 = max(0, min(ph, cur_y + act_line_h))
            if ly1 > ly0 and lx1 > lx0:
                yield _MDR_ReaderBBox(l_idx, -1, True, -1, (lx0, ly0, lx1, ly1))
            cur_y += act_line_h

    def _median(self, nums: list[float | int]) -> float:
        """Median of *nums* as a float (mean of the middle pair for even sizes); 0.0 if empty."""
        if not nums:
            return 0.0
        s_nums = sorted(nums)
        n = len(s_nums)
        if n % 2 == 1:
            return float(s_nums[n // 2])
        return float((s_nums[n // 2 - 1] + s_nums[n // 2]) / 2.0)
|
|
|
|
|
|
|
class MDRLatexExtractor:
    """Extracts LaTeX from formula images using pix2tex."""

    def __init__(self, model_path: str):
        """Remember where the pix2tex files live; the model is loaded lazily.

        Args:
            model_path: directory expected to hold ``weights.pth``,
                ``image_resizer.pth`` and ``config.yaml``.
        """
        self._model_path = model_path
        self._model: LatexOCR | None = None
        self._device = "cuda" if torch.cuda.is_available() else "cpu"

    def extract(self, image: Image) -> str | None:
        """Return the LaTeX for a formula image.

        The image is passed through ``mdr_expand_image(image, 0.1)`` and
        converted to RGB first. Returns None in four cases: pix2tex is not
        installed, the model cannot be loaded, recognition raises, or the
        result is empty.
        """
        if LatexOCR is None:
            return None
        image = mdr_expand_image(image, 0.1)
        model = self._get_model()
        if model is None:
            return None
        try:
            with torch.no_grad():
                img_rgb = image.convert('RGB') if image.mode != 'RGB' else image
                latex = model(img_rgb)
            return latex if latex else None
        except Exception as e:
            print(f"MDR LaTeX error: {e}")
            return None

    # The return annotation is a string on purpose. When pix2tex fails to
    # import, LatexOCR is None, and evaluating ``None | None`` while the class
    # body executes would raise TypeError and break importing the whole module.
    def _get_model(self) -> "LatexOCR | None":
        """Load pix2tex on first use, downloading missing weights; None on failure."""
        if self._model is None and LatexOCR is not None:
            model_dir = Path(self._model_path)
            mdr_ensure_directory(self._model_path)
            wp = model_dir / "weights.pth"
            rp = model_dir / "image_resizer.pth"
            cp = model_dir / "config.yaml"
            if not wp.exists() or not rp.exists():
                print("Downloading MDR LaTeX models...")
                self._download()
            if not cp.exists():
                print(f"Warn: MDR LaTeX config not found {self._model_path}")
            try:
                args = Munch({
                    "config": str(cp),
                    "checkpoint": str(wp),
                    "device": self._device,
                    # pix2tex forces CPU when no_cuda is True, so it must be
                    # True only when we are NOT running on CUDA.
                    "no_cuda": self._device != "cuda",
                    "no_resize": False,
                    "temperature": 0.0,
                })
                self._model = LatexOCR(args)
                print(f"MDR LaTeX loaded on {self._device}.")
            except Exception as e:
                print(f"ERROR initializing MDR LatexOCR: {e}")
                self._model = None
        return self._model

    def _download(self):
        """Fetch the v0.0.1 pix2tex weights that are missing from the model directory."""
        tag = "v0.0.1"
        base = f"https://github.com/lukas-blecher/LaTeX-OCR/releases/download/{tag}/"
        files = {"weights.pth": base + "weights.pth", "image_resizer.pth": base + "image_resizer.pth"}
        mdr_ensure_directory(self._model_path)
        for name, url in files.items():
            target = Path(self._model_path) / name
            if not target.exists():
                mdr_download_model(url, target)
|
|
|
|
|
|
|
MDRTableOutputFormat = Literal["latex", "markdown", "html"] |
|
|
|
|
|
class MDRTableParser:
    """Parses table structure/content from images using StructTable model.

    Table parsing requires CUDA. On CPU the parser is disabled and every call
    returns None.
    """

    def __init__(self, device: Literal["cpu", "cuda"], model_path: str):
        """Pick the device and model directory; the model is loaded lazily.

        Args:
            device: requested device. It is used only if it is ``"cuda"`` and
                CUDA is available; otherwise the parser runs on (and is
                disabled for) CPU.
            model_path: cache directory for the StructTable weights.
        """
        self._model: Any | None = None
        self._model_path = mdr_ensure_directory(model_path)
        self._device = device if torch.cuda.is_available() and device == "cuda" else "cpu"
        # StructTable needs a GPU. Disable when we ended up on CPU (this
        # condition used to be inverted, disabling parsing exactly when CUDA
        # was available).
        self._disabled = self._device != "cuda"
        if self._disabled:
            print("Warning: MDR Table parsing requires CUDA. Disabled.")

    def parse_table_image(self, image: Image, format: MDRTableLayoutParsedFormat) -> str | None:
        """Return the table in *format* (LaTeX, Markdown or HTML) as text.

        Returns None in these cases: the parser is disabled, the format is
        not one of the three, the model is unavailable, parsing raises, or the
        model returns no result.
        """
        if self._disabled:
            return None
        fmt: MDRTableOutputFormat | None = None
        if format == MDRTableLayoutParsedFormat.LATEX:
            fmt = "latex"
        elif format == MDRTableLayoutParsedFormat.MARKDOWN:
            fmt = "markdown"
        elif format == MDRTableLayoutParsedFormat.HTML:
            fmt = "html"
        else:
            return None
        image = mdr_expand_image(image, 0.05)
        model = self._get_model()
        if model is None:
            return None
        try:
            img_rgb = image.convert('RGB') if image.mode != 'RGB' else image
            with torch.no_grad():
                results = model([img_rgb], output_format=fmt)
            return results[0] if results else None
        except Exception as e:
            print(f"MDR Table parsing error: {e}")
            return None

    def _get_model(self):
        """Load StructTable on first use.

        Uses local files only when the cache directory is non-empty. A missing
        ``struct_eqtable`` package disables the parser; other load errors
        leave the model unset (a later call retries).
        """
        if self._model is None and not self._disabled:
            try:
                from struct_eqtable import build_model
                name = "U4R/StructTable-InternVL2-1B"
                local = any(Path(self._model_path).iterdir())
                print(f"Loading MDR StructTable model '{name}'...")
                model = build_model(model_ckpt=name, max_new_tokens=1024, max_time=30, lmdeploy=False, flash_attn=True,
                                    batch_size=1, cache_dir=self._model_path, local_files_only=local)
                self._model = model.to(self._device)
                print(f"MDR StructTable loaded on {self._device}.")
            except ImportError:
                print("ERROR: struct_eqtable not found.")
                self._disabled = True
                self._model = None
            except Exception as e:
                print(f"ERROR loading MDR StructTable: {e}")
                self._model = None
        return self._model
|
|
|
|
|
|
|
_MDR_TINY_ROTATION = 0.005 |
|
|
|
|
|
@dataclass
class _MDR_RotationContext:
    """Coordinate mappings captured when the page image is de-rotated."""

    to_origin: MDRRotationAdjuster  # maps rotated-image points back to original-image points
    to_new: MDRRotationAdjuster  # maps original-image points into the rotated image
    fragment_origin_rectangles: list[MDRRectangle]  # fragment rects as they were before rotation
|
|
|
|
|
class MDRImageOptimizer:
    """Handles image rotation detection and coordinate adjustments.

    Typical use (as in the extraction engine):
    1. Construct with the raw page.
    2. Run OCR on ``image_np``.
    3. Pass the fragments to ``receive_fragments``. This may de-rotate the
       page and move the fragment rects into the rotated image.
    4. Detect layouts on ``image``.
    5. Call ``finalize_layout_coords``.
    """

    def __init__(self, raw_image: Image, adjust_points: bool):
        # raw_image: the original page image.
        # adjust_points: when False, finalize_layout_coords maps coordinates back
        # to the original (unrotated) image.
        self._raw = raw_image;
        self._image = raw_image;  # working image; replaced by the rotated copy if de-rotated
        self._adjust_points = adjust_points
        self._fragments: list[MDROcrFragment] = [];
        self._rotation: float = 0.0;  # detected rotation in radians (0.0 when not applied)
        self._rot_ctx: _MDR_RotationContext | None = None  # set only when the image was actually rotated

    @property
    def image(self) -> Image:
        # Working image: the de-rotated copy after a rotation, otherwise the raw image.
        return self._image

    @property
    def adjusted_image(self) -> Image | None:
        # The de-rotated image, or None when no rotation was applied.
        return self._image if self._rot_ctx is not None else None

    @property
    def rotation(self) -> float:
        # Applied rotation in radians; 0.0 when below threshold or when rotating failed.
        return self._rotation

    @property
    def image_np(self) -> np.ndarray:
        # BGR NumPy copy of the RAW (unrotated) image, as fed to the OCR engine.
        img_rgb = np.array(self._raw.convert("RGB")); return cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)

    def receive_fragments(self, fragments: list[MDROcrFragment]):
        """Store *fragments*, estimate page rotation from them and de-rotate if needed.

        Skipped when the list is empty or the estimated rotation
        (``mdr_calculate_image_rotation``) is below ``_MDR_TINY_ROTATION`` in
        magnitude. Otherwise the raw image is rotated onto an expanded canvas
        with white fill, and the original fragment rects are saved. Each
        fragment's ``rect`` is then replaced in place by its position in the
        rotated image. If rotating fails, the raw image and a 0.0 rotation are
        kept.
        """
        self._fragments = fragments
        if not fragments:
            return
        self._rotation = mdr_calculate_image_rotation(fragments)
        if abs(self._rotation) < _MDR_TINY_ROTATION:
            self._rotation = 0.0
            return
        orig_sz = self._raw.size
        try:
            # PIL's rotate() takes degrees counter-clockwise; the angle is negated to
            # undo the detected rotation (assumes mdr_calculate_image_rotation's sign
            # convention — confirm).
            self._image = self._raw.rotate(-np.degrees(self._rotation), resample=PILResampling.BICUBIC,
                                           fillcolor=(255, 255, 255), expand=True)
        except Exception as e:
            print(f"Optimizer rotation error: {e}")
            self._rotation = 0.0
            self._image = self._raw
            return
        new_sz = self._image.size
        self._rot_ctx = _MDR_RotationContext(
            fragment_origin_rectangles=[f.rect for f in fragments],
            to_new=MDRRotationAdjuster(orig_sz, new_sz, self._rotation, False),
            to_origin=MDRRotationAdjuster(orig_sz, new_sz, self._rotation, True))
        adj = self._rot_ctx.to_new
        # Move every (truthy) fragment rect into rotated-image coordinates, in place.
        [setattr(f, 'rect',
                 MDRRectangle(lt=adj.adjust(r.lt), rt=adj.adjust(r.rt), lb=adj.adjust(r.lb), rb=adj.adjust(r.rb))) for f
         in fragments if (r := f.rect)]

    def finalize_layout_coords(self, layouts: list[MDRLayoutElement]):
        """Map coordinates back to the original image when ``adjust_points`` is False.

        No-op when no rotation was applied. With ``adjust_points`` True,
        everything stays in rotated-image (``adjusted_image``) coordinates.
        """
        if self._rot_ctx is None:
            return

        if not self._adjust_points:
            # Restore the stored fragments' pre-rotation rects (skipped if the list
            # length no longer matches the saved rects).
            # NOTE(review): only the objects in self._fragments are restored; fragments
            # that layouts hold after line merging may be different objects — verify
            # they end up in the same coordinate space as the layout rects.
            if len(self._fragments) == len(self._rot_ctx.fragment_origin_rectangles):
                for f, orig_r in zip(self._fragments, self._rot_ctx.fragment_origin_rectangles):
                    f.rect = orig_r

            # NOTE(review): layout rects are mapped back to original-image coordinates
            # only in this adjust_points=False branch, keeping them consistent with the
            # restored fragment rects — confirm this nesting is intended.
            adj = self._rot_ctx.to_origin
            for l in layouts:
                if (r := l.rect):
                    l.rect = MDRRectangle(lt=adj.adjust(r.lt), rt=adj.adjust(r.rt), lb=adj.adjust(r.lb),
                                          rb=adj.adjust(r.rb))
|
|
|
|
|
|
|
|
|
|
|
def mdr_clip_from_image(image: Image, rect: MDRRectangle, wrap_w: float = 0.0, wrap_h: float = 0.0) -> Image:
    """Clips a potentially rotated rectangle from an image.

    The region is rotated by the angle from
    ``mdr_calculate_rectangle_rotation(rect)`` so that it comes out upright.
    It is padded with ``wrap_w`` / ``wrap_h`` extra pixels, split evenly
    between the two sides, and filled with white.

    Returns:
        - a 1x1 white image for rectangles with non-positive size;
        - a plain axis-aligned crop of ``rect.wrapper`` if the transform
          matrix cannot be inverted;
        - a 10x10 white image if anything else fails (the error is printed).
    """
    try:
        angle, _ = mdr_calculate_rectangle_rotation(rect)
        width, height = rect.size
        if width <= 0 or height <= 0:
            return new_image("RGB", (1, 1), (255, 255, 255))

        origin_x, origin_y = rect.lt
        # Forward transform (source -> output): move rect.lt to the origin,
        # undo the rotation, then shift by half the padding.
        to_origin = np.array([[1, 0, -origin_x], [0, 1, -origin_y], [0, 0, 1]])
        c = cos(-angle)
        s = sin(-angle)
        rotate = np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]])
        add_padding = np.array([[1, 0, wrap_w / 2.0], [0, 1, wrap_h / 2.0], [0, 0, 1]])
        forward = add_padding @ rotate @ to_origin

        try:
            backward = np.linalg.inv(forward)
        except np.linalg.LinAlgError:
            left, top, right, bottom = rect.wrapper
            return image.crop((round(left), round(top), round(right), round(bottom)))

        # PIL's AFFINE transform maps output pixels to input pixels, so it takes
        # the first two rows of the inverse matrix.
        coeffs = tuple(backward[row, col] for row in range(2) for col in range(3))
        out_size = (ceil(width + wrap_w), ceil(height + wrap_h))
        return image.transform(out_size, PILTransform.AFFINE, coeffs, PILResampling.BICUBIC,
                               fillcolor=(255, 255, 255))
    except Exception as e:
        print(f"MDR Clipping error: {e}")
        return new_image("RGB", (10, 10), (255, 255, 255))
|
|
|
|
|
def mdr_clip_layout(res: MDRExtractionResult, layout: MDRLayoutElement, wrap_w: float = 0.0,
                    wrap_h: float = 0.0) -> Image:
    """Clip *layout*'s rectangle out of the page image held by *res*.

    The de-rotated ``res.adjusted_image`` is preferred when set; otherwise
    ``res.extracted_image`` is used. Padding arguments are forwarded to
    ``mdr_clip_from_image``.
    """
    source = res.adjusted_image or res.extracted_image
    return mdr_clip_from_image(source, layout.rect, wrap_w, wrap_h)
|
|
|
|
|
|
|
# RGBA colours used by the debug plotter.
_MDR_FRAG_COLOR = (0x49, 0xCF, 0xCB, 200)  # fragment outlines

# Outline colour per layout class.
_MDR_LAYOUT_COLORS = {
    MDRLayoutClass.TITLE: (0x0A, 0x12, 0x2C, 255),
    MDRLayoutClass.PLAIN_TEXT: (0x3C, 0x67, 0x90, 255),
    MDRLayoutClass.ABANDON: (0xC0, 0xBB, 0xA9, 180),
    MDRLayoutClass.FIGURE: (0x5B, 0x91, 0x3C, 255),
    MDRLayoutClass.FIGURE_CAPTION: (0x77, 0xB3, 0x54, 255),
    MDRLayoutClass.TABLE: (0x44, 0x17, 0x52, 255),
    MDRLayoutClass.TABLE_CAPTION: (0x81, 0x75, 0xA0, 255),
    MDRLayoutClass.TABLE_FOOTNOTE: (0xEF, 0xB6, 0xC9, 255),
    MDRLayoutClass.ISOLATE_FORMULA: (0xFA, 0x38, 0x27, 255),
    MDRLayoutClass.FORMULA_CAPTION: (0xFF, 0x9D, 0x24, 255),
}

_MDR_DEFAULT_COLOR = (0x80, 0x80, 0x80, 255)  # classes missing from _MDR_LAYOUT_COLORS
_MDR_RGBA = tuple[int, int, int, int]
|
|
|
|
|
def mdr_plot_layout(image: Image, layouts: Iterable[MDRLayoutElement]) -> None:
    """Draws layout and fragment boxes onto an image for debugging.

    The image is modified in place:
    - Each layout outline is drawn in its class colour (``_MDR_LAYOUT_COLORS``,
      falling back to ``_MDR_DEFAULT_COLOR``), with a 1-based index label at
      its top-left corner.
    - Fragment outlines are then drawn in ``_MDR_FRAG_COLOR``.
    Drawing errors are printed and skipped.

    Args:
        image: PIL image to draw on.
        layouts: layouts to draw; any iterable (it is materialized once).
    """
    # Materialize first. The layouts are walked twice below, so a generator
    # would be exhausted by the first pass and every fragment silently skipped.
    # ``not layouts`` is also always False for a generator.
    layouts = list(layouts)
    if not layouts:
        return
    try:
        l_font = load_default(size=25)
        draw = ImageDraw.Draw(image, mode="RGBA")
    except Exception as e:
        print(f"MDR Plot init error: {e}")
        return

    def _draw_num(pos: MDRPoint, num: int, font: FreeTypeFont, color: _MDR_RGBA):
        """Draw *num* in white on a translucent label of *color* just right of / below *pos*."""
        try:
            x, y = pos
            txt = str(num)
            txt_pos = (round(x) + 3, round(y) + 1)
            bbox = draw.textbbox(txt_pos, txt, font=font)
            bg_rect = (bbox[0] - 2, bbox[1] - 1, bbox[2] + 2, bbox[3] + 1)
            bg_color = (color[0], color[1], color[2], 180)
            draw.rectangle(bg_rect, fill=bg_color)
            draw.text(txt_pos, txt, font=font, fill=(255, 255, 255, 255))
        except Exception as e:
            print(f"MDR Draw num error: {e}")

    for i, l in enumerate(layouts, start=1):
        try:
            l_color = _MDR_LAYOUT_COLORS.get(l.cls, _MDR_DEFAULT_COLOR)
            draw.polygon(list(l.rect), outline=l_color, width=3)
            _draw_num(l.rect.lt, i, l_font, l_color)
        except Exception as e:
            print(f"MDR Layout draw error: {e}")

    for l in layouts:
        for f in l.fragments:
            try:
                draw.polygon(list(f.rect), outline=_MDR_FRAG_COLOR, width=1)
            except Exception as e:
                print(f"MDR Fragment draw error: {e}")
|
|
|
|
|
|
|
class MDRExtractionEngine: |
|
"""Core engine for extracting structured information from a document image.""" |
|
|
|
def __init__(self, model_dir_path: str, device: Literal["cpu", "cuda"] = "cpu", ocr_for_each_layouts: bool = True,
             extract_formula: bool = True, extract_table_format: MDRTableLayoutParsedFormat | None = None):
    """Create the OCR, table, LaTeX and reading-order sub-engines under *model_dir_path*.

    ``device`` falls back to ``"cpu"`` when CUDA is unavailable. The YOLO
    layout model is loaded lazily by ``_get_yolo_model``.
    """
    self._model_dir = model_dir_path
    self._device = device if torch.cuda.is_available() else "cpu"
    self._ocr_each = ocr_for_each_layouts
    self._ext_formula = extract_formula
    self._ext_table = extract_table_format
    self._yolo: YOLOv10 | None = None

    def sub_dir(name: str) -> str:
        # Each sub-engine keeps its models in its own folder under the model dir.
        return os.path.join(self._model_dir, name)

    self._ocr_engine = MDROcrEngine(device=self._device, model_dir_path=sub_dir("onnx_ocr"))
    self._table_parser = MDRTableParser(device=self._device, model_path=sub_dir("struct_eqtable"))
    self._latex_extractor = MDRLatexExtractor(model_path=sub_dir("latex"))
    self._layout_reader = MDRLayoutReader(model_path=sub_dir("layoutreader"))
    print(f"MDR Extraction Engine initialized on device: {self._device}")
|
|
|
|
|
def _get_yolo_model(self) -> Any | None:
    """Loads the YOLOv10b-DocLayNet layout detection model using ultralytics.YOLO.

    The weights file is fetched from the Hugging Face Hub and cached under
    ``<model_dir>/yolo_hf_cache_doclaynet``. Returns the cached model, or None
    if downloading or loading failed. Nothing is cached on failure, so a later
    call tries again.
    """
    if self._yolo is not None:
        return self._yolo

    repo_id = "hantian/yolo-doclaynet"
    filename = "yolov10b-doclaynet.pt"
    yolo_cache_dir = Path(self._model_dir) / "yolo_hf_cache_doclaynet"
    mdr_ensure_directory(str(yolo_cache_dir))

    print(f"Attempting to load YOLO model '{filename}' from repo '{repo_id}' using ultralytics.YOLO...")
    print(f"Hugging Face Hub cache directory for YOLO: {yolo_cache_dir}")

    # Bound before the try so the generic handler below can still report it.
    # Previously a non-HTTP download error (e.g. a connection error) left it
    # unbound and crashed the handler with UnboundLocalError.
    yolo_model_filepath = None
    try:
        yolo_model_filepath = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            cache_dir=yolo_cache_dir,
            local_files_only=False,
            force_download=False,
        )
        print(f"YOLO model file path: {yolo_model_filepath}")

        from ultralytics import YOLO as UltralyticsYOLO
        self._yolo = UltralyticsYOLO(yolo_model_filepath)
        print("MDR YOLOv10b-DocLayNet model loaded successfully using ultralytics.YOLO.")
    except ImportError:
        print("ERROR: ultralytics library not found. Cannot load YOLOv10b-DocLayNet.")
        print("Please ensure it's installed: pip install ultralytics (matching version if possible)")
        self._yolo = None
    except HfHubHTTPError as e:
        print(f"ERROR: Failed to download YOLO model '{filename}' via Hugging Face Hub: {e}")
        self._yolo = None
    except Exception as e:
        target = yolo_model_filepath if yolo_model_filepath is not None else filename
        print(f"ERROR: Failed to load YOLO model '{target}' with ultralytics.YOLO: {e}")
        import traceback
        traceback.print_exc()
        self._yolo = None

    return self._yolo
|
|
|
|
|
def analyze_image(self, image: Image, adjust_points: bool = False) -> MDRExtractionResult:
    """Analyzes a single page image to extract layout and content.

    Pipeline:
    1. OCR on the raw page, then rotation estimate / de-rotation.
    2. YOLO layout detection.
    3. Fragment-to-layout matching. If no layout was matched, a whole-page
       plain-text layout holds all fragments.
    4. Overlap removal, then optional per-layout OCR correction.
    5. Reading order, then filtering.
    6. Optional table/formula parsing.
    7. Line merging and coordinate finalization.

    Args:
        image: page image; it is used as-is (no preprocessing).
        adjust_points: forwarded to ``MDRImageOptimizer`` (see its
            ``finalize_layout_coords``).

    Returns:
        ``MDRExtractionResult`` with the applied rotation, the ordered layouts,
        the input image and the de-rotated image (or None).
    """
    print(" Engine: Analyzing image...")
    # No preprocessing is applied. The former "CLAHE preprocessing applied"
    # log line was misleading and has been dropped.
    optimizer = MDRImageOptimizer(image, adjust_points)

    print(" Engine: Initial OCR...")
    frags = list(self._ocr_engine.find_text_fragments(optimizer.image_np))
    print(f" Engine: {len(frags)} fragments found.")
    # Stores this same list and, if the page is de-rotated, rewrites the
    # fragments' rects in place.
    optimizer.receive_fragments(frags)

    print(" Engine: Layout detection...")
    yolo = self._get_yolo_model()
    raw_layouts = []
    if yolo:
        try:
            raw_layouts = list(self._run_yolo_detection(optimizer.image, yolo))
            print(f" Engine: {len(raw_layouts)} raw layouts found.")
        except Exception:
            # Best effort: continue without detected layouts.
            import traceback, sys
            traceback.print_exc(file=sys.stderr)

    print(" Engine: Matching fragments...")
    layouts = self._match_fragments_to_layouts(frags, raw_layouts)
    if not layouts and frags:
        # Nothing matched: keep the text by wrapping all fragments in one page-sized layout.
        page_rect = MDRRectangle(
            lt=(0, 0), rt=(optimizer.image.width, 0),
            lb=(0, optimizer.image.height), rb=(optimizer.image.width, optimizer.image.height)
        )
        dummy = MDRPlainLayoutElement(
            cls=MDRLayoutClass.PLAIN_TEXT, rect=page_rect, fragments=frags.copy()
        )
        layouts.append(dummy)

    print(" Engine: Removing overlaps...")
    layouts = mdr_remove_overlap_layouts(layouts)
    print(f" Engine: {len(layouts)} layouts after overlap removal.")

    if self._ocr_each and layouts:
        print(" Engine: OCR correction...")
        self._run_ocr_correction(optimizer.image, layouts)

    print(" Engine: Determining reading order...")
    layouts = self._layout_reader.determine_reading_order(layouts, optimizer.image.size)
    layouts = [l for l in layouts if self._should_keep_layout(l)]
    print(f" Engine: {len(layouts)} layouts after filtering.")

    if self._ext_table or self._ext_formula:
        print(" Engine: Parsing tables/formulas...")
        self._parse_special_layouts(layouts, optimizer)

    print(" Engine: Merging fragments...")
    for layout in layouts:
        layout.fragments = mdr_merge_fragments_into_lines(layout.fragments)

    print(" Engine: Finalizing coords...")
    optimizer.finalize_layout_coords(layouts)

    print(" Engine: Analysis complete.")
    return MDRExtractionResult(rotation=optimizer.rotation, layouts=layouts, extracted_image=image,
                               adjusted_image=optimizer.adjusted_image)
|
|
|
|
|
def _run_yolo_detection(self, img: Image, yolo: Any): |
|
img_rgb = img.convert("RGB") |
|
|
|
res_list = yolo.predict(source=img_rgb, imgsz=1024, conf=0.25, |
|
device=self._device, verbose=False) |
|
|
|
if not res_list or not hasattr(res_list[0], 'boxes') or res_list[0].boxes is None: |
|
print(" Engine: YOLO detection (ultralytics) returned no results or no boxes.") |
|
return |
|
|
|
results = res_list[0] |
|
|
|
model_class_names = {} |
|
if hasattr(results, 'names') and isinstance(results.names, dict): |
|
model_class_names = results.names |
|
print(f" Engine: YOLO model class names from ultralytics: {model_class_names}") |
|
else: |
|
|
|
|
|
print( |
|
" Engine: CRITICAL WARNING - Could not get class names from YOLO model. Layout mapping will likely be incorrect.") |
|
|
|
_doclaynet_names_fallback = ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', |
|
'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'] |
|
model_class_names = {i: name for i, name in enumerate(_doclaynet_names_fallback)} |
|
print(f" Engine: Using FALLBACK class names (VERIFY!): {model_class_names}") |
|
|
|
plain_mdr_classes: set[MDRLayoutClass] = { |
|
MDRLayoutClass.TITLE, MDRLayoutClass.PLAIN_TEXT, |
|
MDRLayoutClass.FIGURE_CAPTION, MDRLayoutClass.TABLE_CAPTION, |
|
MDRLayoutClass.TABLE_FOOTNOTE, MDRLayoutClass.FORMULA_CAPTION, |
|
} |
|
|
|
if results.boxes.cls is None or results.boxes.xyxy is None: |
|
print(" Engine: YOLO results.boxes.cls or .xyxy is None.") |
|
return |
|
|
|
print(f" Engine: Processing {len(results.boxes.cls)} detected YOLO boxes...") |
|
for i in range(len(results.boxes.cls)): |
|
yolo_cls_id = int(results.boxes.cls[i].item()) |
|
xyxy_tensor = results.boxes.xyxy[i] |
|
|
|
yolo_cls_name = model_class_names.get(yolo_cls_id, f"UnknownID-{yolo_cls_id}") |
|
|
|
mdr_cls = None |
|
|
|
|
|
if yolo_cls_name == 'Text': |
|
mdr_cls = MDRLayoutClass.PLAIN_TEXT |
|
elif yolo_cls_name == 'Title': |
|
mdr_cls = MDRLayoutClass.TITLE |
|
elif yolo_cls_name == 'Section-header': |
|
mdr_cls = MDRLayoutClass.TITLE |
|
elif yolo_cls_name == 'List-item': |
|
mdr_cls = MDRLayoutClass.PLAIN_TEXT |
|
elif yolo_cls_name == 'Table': |
|
mdr_cls = MDRLayoutClass.TABLE |
|
elif yolo_cls_name == 'Picture': |
|
mdr_cls = MDRLayoutClass.FIGURE |
|
elif yolo_cls_name == 'Formula': |
|
mdr_cls = MDRLayoutClass.ISOLATE_FORMULA |
|
elif yolo_cls_name == 'Caption': |
|
mdr_cls = MDRLayoutClass.FIGURE_CAPTION |
|
elif yolo_cls_name == 'Footnote': |
|
mdr_cls = MDRLayoutClass.TABLE_FOOTNOTE |
|
elif yolo_cls_name in ['Page-header', 'Page-footer']: |
|
mdr_cls = MDRLayoutClass.ABANDON |
|
|
|
if mdr_cls is None: |
|
|
|
continue |
|
|
|
|
|
|
|
x1, y1, x2, y2 = map(float, xyxy_tensor) |
|
rect = MDRRectangle(lt=(x1, y1), rt=(x2, y1), lb=(x1, y2), rb=(x2, y2)) |
|
if rect.area < 10: continue |
|
|
|
if mdr_cls == MDRLayoutClass.TABLE: |
|
yield MDRTableLayoutElement(rect=rect, fragments=[], parsed=None, cls=mdr_cls) |
|
elif mdr_cls == MDRLayoutClass.ISOLATE_FORMULA: |
|
yield MDRFormulaLayoutElement(rect=rect, fragments=[], latex=None, cls=mdr_cls) |
|
elif mdr_cls == MDRLayoutClass.FIGURE: |
|
yield MDRPlainLayoutElement(cls=mdr_cls, rect=rect, fragments=[]) |
|
elif mdr_cls in plain_mdr_classes: |
|
yield MDRPlainLayoutElement(cls=mdr_cls, rect=rect, fragments=[]) |
|
elif mdr_cls == MDRLayoutClass.ABANDON: |
|
yield MDRPlainLayoutElement(cls=mdr_cls, rect=rect, fragments=[]) |
|
def _match_fragments_to_layouts(self, frags: list[MDROcrFragment], layouts: list[MDRLayoutElement]) -> list[ |
|
MDRLayoutElement]: |
|
if not frags or not layouts: |
|
return layouts |
|
layout_polys = [(Polygon(l.rect) if l.rect.is_valid else None) for l in layouts] |
|
for frag in frags: |
|
try: |
|
frag_poly = Polygon(frag.rect) |
|
frag_area = frag_poly.area |
|
except: |
|
continue |
|
if not frag_poly.is_valid or frag_area < 1e-6: |
|
continue |
|
candidates = [] |
|
for idx, l_poly in enumerate(layout_polys): |
|
if l_poly is None: |
|
continue |
|
try: |
|
inter_area = frag_poly.intersection(l_poly).area |
|
except: |
|
continue |
|
overlap = inter_area / frag_area if frag_area > 0 else 0 |
|
if overlap > 0.85: |
|
candidates.append((idx, l_poly.area, overlap)) |
|
if candidates: |
|
candidates.sort(key=lambda x: (x[1], -x[2])) |
|
best_idx = candidates[0][0] |
|
layouts[best_idx].fragments.append(frag) |
|
for l in layouts: |
|
l.fragments.sort(key=lambda f: (f.rect.lt[1], f.rect.lt[0])) |
|
return layouts |
|
|
|
def _run_ocr_correction(self, img: Image, layouts: list[MDRLayoutElement]): |
|
for i, l in enumerate(layouts): |
|
if l.cls == MDRLayoutClass.FIGURE: continue |
|
try: |
|
mdr_correct_layout_fragments(self._ocr_engine, img, l) |
|
except Exception as e: |
|
print(f" Engine: OCR correction error layout {i}: {e}") |
|
|
|
def _parse_special_layouts(self, layouts: list[MDRLayoutElement], optimizer: MDRImageOptimizer): |
|
img_to_clip = optimizer.image |
|
for l in layouts: |
|
if isinstance(l, MDRFormulaLayoutElement) and self._ext_formula: |
|
try: |
|
f_img = mdr_clip_from_image(img_to_clip, l.rect) |
|
l.latex = self._latex_extractor.extract(f_img) if f_img.width > 1 and f_img.height > 1 else None |
|
except Exception as e: |
|
print(f" Engine: LaTeX extract error: {e}") |
|
elif isinstance(l, MDRTableLayoutElement) and self._ext_table is not None: |
|
try: |
|
t_img = mdr_clip_from_image(img_to_clip, l.rect) |
|
parsed = self._table_parser.parse_table_image(t_img, |
|
self._ext_table) if t_img.width > 1 and t_img.height > 1 else None |
|
except Exception as e: |
|
print(f" Engine: Table parse error: {e}") |
|
parsed = None |
|
if parsed: |
|
l.parsed = (parsed, self._ext_table) |
|
|
|
def _should_keep_layout(self, l: MDRLayoutElement) -> bool: |
|
if l.fragments and not all(mdr_is_whitespace(f.text) for f in l.fragments): return True |
|
return l.cls in [MDRLayoutClass.FIGURE, MDRLayoutClass.TABLE, MDRLayoutClass.ISOLATE_FORMULA] |
|
|
|
|
|
|
|
class _MDR_LinkedShape:
    """Wraps one layout together with its links to layouts on other pages.

    ``pre`` and ``nex`` each have two slots. ``MDRPageSection.link_to_next``
    stores the matched layout at index ``offset - 1``, where ``offset`` is 1 or
    2. Unmatched slots stay ``None``.
    """

    def __init__(self, layout: MDRLayoutElement):
        self.layout = layout
        self.pre: list[MDRLayoutElement | None] = [None, None]
        self.nex: list[MDRLayoutElement | None] = [None, None]

    @property
    def distance2(self) -> float:
        """Squared distance of the layout's top-left corner (``rect.lt``) from (0, 0)."""
        left, top = self.layout.rect.lt
        return left * left + top * top
|
|
|
|
|
class MDRPageSection:
    """Represents a page's layouts for framework detection via linking.

    A layout is linked to a layout on a following page when the two layouts:

    * have similar sizes,
    * share enough similar fragments, and
    * sit at nearly the same position relative to an origin pair of matched layouts.

    Layouts that end up with any link are reported by
    ``find_framework_elements``. These are presumably recurring page furniture
    such as running headers or footers.
    """

    def __init__(self, page_index: int, layouts: Iterable[MDRLayoutElement]):
        self._page_index = page_index
        self._shapes = [_MDR_LinkedShape(l) for l in layouts]
        # Order top-to-bottom, then left-to-right, by the top-left corner.
        self._shapes.sort(key=lambda s: (s.layout.rect.lt[1], s.layout.rect.lt[0]))

    @property
    def page_index(self) -> int:
        return self._page_index

    def find_framework_elements(self) -> list[MDRLayoutElement]:
        """Return the layouts linked to a layout on at least one other page."""
        return [s.layout for s in self._shapes if any(s.pre) or any(s.nex)]

    def link_to_next(self, next_section: 'MDRPageSection', offset: int) -> None:
        """Link matching shapes of this section to those of ``next_section``.

        Args:
            next_section: Section of a later page.
            offset: Distance to that section. Only 1 and 2 are accepted, since
                each shape keeps two link slots; other values are ignored.
        """
        if offset not in (1, 2):
            return
        matches_matrix = [[sn for sn in next_section._shapes if self._shapes_match(ss, sn)] for ss in self._shapes]
        origin_pair = self._find_origin_pair(matches_matrix, next_section._shapes)
        if origin_pair is None:
            return
        orig_s, orig_n = origin_pair
        orig_s_pt = orig_s.layout.rect.lt
        orig_n_pt = orig_n.layout.rect.lt
        for i, s1 in enumerate(self._shapes):
            potentials = matches_matrix[i]
            if not potentials:
                continue
            # Compare positions relative to each page's origin shape so a uniform
            # shift of the page content does not prevent a match.
            r1_rel = self._relative_rect(orig_s_pt, s1.layout.rect)
            best_s2 = None
            max_ovr = -1.0
            for s2 in potentials:
                r2_rel = self._relative_rect(orig_n_pt, s2.layout.rect)
                ovr = self._symmetric_iou(r1_rel, r2_rel)
                if ovr > max_ovr:
                    max_ovr = ovr
                    best_s2 = s2
            if max_ovr >= 0.80 and best_s2 is not None:
                s1.nex[offset - 1] = best_s2.layout
                best_s2.pre[offset - 1] = s1.layout

    def _shapes_match(self, s1: _MDR_LinkedShape, s2: _MDR_LinkedShape) -> bool:
        """Return True when two layouts look like the same element.

        Width and height must each have a similarity ratio of at least 0.90.
        If both layouts have no fragments, that is enough. If only one has
        fragments, they do not match. Otherwise the fraction of greedily
        matched fragment pairs (similarity > 0.75) must reach a threshold
        that depends on the fragment count.
        """
        l1 = s1.layout
        l2 = s2.layout
        sz1 = l1.rect.size
        sz2 = l2.rect.size
        thresh = 0.90
        if mdr_similarity_ratio(sz1[0], sz2[0]) < thresh or mdr_similarity_ratio(sz1[1], sz2[1]) < thresh:
            return False
        f1 = l1.fragments
        f2 = l2.fragments
        c1 = len(f1)
        c2 = len(f2)
        if c1 == 0 and c2 == 0:
            return True
        if c1 == 0 or c2 == 0:
            return False
        matches = 0
        used_f2 = [False] * c2
        for frag1 in f1:
            best_j = -1
            max_sim = -1.0
            for j, frag2 in enumerate(f2):
                if not used_f2[j]:
                    sim = self._fragment_sim(l1, l2, frag1, frag2)
                    if sim > max_sim:
                        max_sim = sim
                        best_j = j
            if max_sim > 0.75:
                matches += 1
            # NOTE(review): the best candidate is consumed even when its similarity
            # is below the 0.75 match threshold; confirm this one-to-one greedy
            # assignment is intended.
            if best_j != -1:
                used_f2[best_j] = True
        max_c = max(c1, c2)
        rate_frags = matches / max_c
        return self._check_match_threshold(rate_frags, max_c, (0.0, 0.45, 0.45, 0.6, 0.8, 0.95))

    def _fragment_sim(self, l1: MDRLayoutElement, l2: MDRLayoutElement, f1: MDROcrFragment,
                      f2: MDROcrFragment) -> float:
        """Average of relative-position IoU and text similarity for two fragments."""
        r1_rel = self._relative_rect(l1.rect.lt, f1.rect)
        r2_rel = self._relative_rect(l2.rect.lt, f2.rect)
        geom_sim = self._symmetric_iou(r1_rel, r2_rel)
        text_sim, _ = mdr_check_text_similarity(f1.text, f2.text)
        return (geom_sim + text_sim) / 2.0

    def _find_origin_pair(self, matches_matrix: list[list[_MDR_LinkedShape]], next_shapes: list[_MDR_LinkedShape]) -> \
            tuple[_MDR_LinkedShape, _MDR_LinkedShape] | None:
        """Return the matched pair whose top-left corners are jointly closest to (0, 0).

        ``next_shapes`` is currently unused.
        """
        best_pair = None
        min_dist2 = float('inf')
        for i, s1 in enumerate(self._shapes):
            match_list = matches_matrix[i]
            if not match_list:
                continue
            for s2 in match_list:
                dist2 = s1.distance2 + s2.distance2
                if dist2 < min_dist2:
                    min_dist2 = dist2
                    best_pair = (s1, s2)
        return best_pair

    def _check_match_threshold(self, rate: float, count: int, thresholds: Sequence[float]) -> bool:
        """Return True when ``rate`` reaches the threshold selected by ``count``.

        The threshold is ``thresholds[min(count, len(thresholds) - 1)]``, so
        counts past the end reuse the last entry. Empty ``thresholds`` never match.
        """
        # Separate statements on purpose: on one line after ``if ...:`` they would
        # all belong to the if-suite, and the method would return None.
        if not thresholds:
            return False
        idx = min(count, len(thresholds) - 1)
        return rate >= thresholds[idx]

    def _relative_rect(self, origin: MDRPoint, rect: MDRRectangle) -> MDRRectangle:
        """Translate ``rect`` so that ``origin`` becomes (0, 0)."""
        ox, oy = origin
        r = rect
        return MDRRectangle(lt=(r.lt[0] - ox, r.lt[1] - oy), rt=(r.rt[0] - ox, r.rt[1] - oy),
                            lb=(r.lb[0] - ox, r.lb[1] - oy), rb=(r.rb[0] - ox, r.rb[1] - oy))

    def _symmetric_iou(self, r1: MDRRectangle, r2: MDRRectangle) -> float:
        """Intersection-over-union of two rectangles.

        Invalid or failing geometry yields 0.0. An intersection below 1e-6 also
        yields 0.0.
        """
        try:
            p1 = Polygon(r1)
            p2 = Polygon(r2)
        except Exception:
            return 0.0
        if not p1.is_valid or not p2.is_valid:
            return 0.0
        try:
            inter = p1.intersection(p2)
            union = p1.union(p2)
        except Exception:
            return 0.0
        if inter.is_empty or inter.area < 1e-6:
            return 0.0
        union_area = union.area
        return inter.area / union_area if union_area > 1e-6 else 1.0
|
|
|
|
|
|
|
# This constant sets three things:
#   - how many neighbouring pages on each side of a requested page are also
#     scanned (MDRDocumentIterator._get_page_ranges);
#   - the largest offset at which page sections are linked;
#   - how many earlier sections are held back for linking
#     (MDRDocumentIterator._process_and_link_sections).
# NOTE(review): MDRPageSection.link_to_next only accepts offsets 1 and 2, and
# _MDR_LinkedShape keeps two link slots. Values above 2 therefore widen the
# scan window without extending linking.
_MDR_CONTEXT_PAGES = 2
|
|
|
|
|
@dataclass
class MDRProcessingParams:
    """Inputs for a single document-processing run."""

    # A path (opened with fitz.open) or an already opened fitz document.
    pdf: str | FitzDocument
    # 0-based page indexes to process; None selects every page.
    page_indexes: Iterable[int] | None
    # Optional callback, called as report_progress(completed_scans, total_scans).
    report_progress: MDRProgressReportCallback | None
|
|
|
|
|
class MDRDocumentIterator:
    """Iterates through document pages, handles context, and calls the extraction engine.

    Besides the requested pages, up to ``_MDR_CONTEXT_PAGES`` neighbouring
    pages on each side are analysed. Their sections are linked with the
    requested pages (``MDRPageSection``), so layouts repeated across pages can
    be excluded from the content of the requested pages. Context-only pages are
    never yielded.
    """

    def __init__(self, device: Literal["cpu", "cuda"], model_dir_path: str, ocr_level: MDROcrLevel,
                 extract_formula: bool, extract_table_format: MDRTableLayoutParsedFormat | None,
                 debug_dir_path: str | None):
        self._debug_dir = debug_dir_path
        self._engine = MDRExtractionEngine(device=device, model_dir_path=model_dir_path,
                                           ocr_for_each_layouts=(ocr_level == MDROcrLevel.OncePerLayout),
                                           extract_formula=extract_formula, extract_table_format=extract_table_format)

    def iterate_sections(self, params: MDRProcessingParams) -> Generator[
            tuple[int, MDRExtractionResult, list[MDRLayoutElement]], None, None]:
        """Yields page index, extraction result, and content layouts for each requested page.

        ``content`` is ``result.layouts`` minus the layouts that were linked to
        layouts on neighbouring pages (``MDRPageSection.find_framework_elements``).
        """
        for res, sec, enabled in self._process_and_link_sections(params):
            if not enabled:
                # Context page: analysed only to help link the requested pages.
                continue
            framework_element_ids = {id(fw_el) for fw_el in sec.find_framework_elements()}
            content = [l for l in res.layouts if id(l) not in framework_element_ids]
            yield sec.page_index, res, content

    def _process_and_link_sections(self, params: MDRProcessingParams) -> Generator[
            tuple[MDRExtractionResult, MDRPageSection, bool], None, None]:
        """Build a section for every analysed page and link it to earlier sections.

        The link offset is the real page-index distance, so pages on either side
        of a gap in the scan list are never linked. A section is released as
        soon as no later page can still be linked to it.

        Yields:
            ``(result, section, enabled)`` in page order.
        """
        queue: list[tuple[MDRExtractionResult, MDRPageSection, bool]] = []
        for page_idx, res, enabled in self._run_extraction_on_pages(params):
            cur_sec = MDRPageSection(page_idx, res.layouts)
            for _, prev_sec, _ in queue:
                offset = page_idx - prev_sec.page_index
                if 1 <= offset <= _MDR_CONTEXT_PAGES:
                    prev_sec.link_to_next(cur_sec, offset)
            queue.append((res, cur_sec, enabled))
            # Pages arrive in ascending order, so once the current page is
            # _MDR_CONTEXT_PAGES past a queued section, nothing later can link to it.
            while queue and page_idx - queue[0][1].page_index >= _MDR_CONTEXT_PAGES:
                yield queue.pop(0)
        yield from queue

    def _run_extraction_on_pages(self, params: MDRProcessingParams) -> Generator[
            tuple[int, MDRExtractionResult, bool], None, None]:
        """Analyse every page in the scan range.

        Yields:
            ``(page_index, result, enabled)``. ``enabled`` is True for requested
            pages and False for pages scanned only as linking context.

        Pages that fail to load or analyse are reported and skipped. A document
        opened here from a path is closed when the generator finishes or is closed.
        """
        if self._debug_dir:
            mdr_ensure_directory(self._debug_dir)
        doc, should_close = None, False
        if isinstance(params.pdf, str):
            try:
                doc = fitz.open(params.pdf)
                should_close = True
            except Exception as e:
                print(f"ERROR: PDF open failed: {e}")
                return
        elif isinstance(params.pdf, FitzDocument):
            doc = params.pdf
        else:
            print(f"ERROR: Invalid PDF type: {type(params.pdf)}")
            return
        try:
            scan_idxs, enable_idxs = self._get_page_ranges(doc, params.page_indexes)
            enable_set = set(enable_idxs)
            total_scan = len(scan_idxs)
            for i, page_idx in enumerate(scan_idxs):
                print(f"  Iterator: Processing page {page_idx + 1}/{doc.page_count} (Scan {i + 1}/{total_scan})...")
                try:
                    page = doc.load_page(page_idx)
                    img = self._render_page_image(page, 300)
                    res = self._engine.analyze_image(image=img, adjust_points=False)
                    if self._debug_dir:
                        self._save_debug_plot(img, page_idx, res, self._debug_dir)
                    yield page_idx, res, page_idx in enable_set
                    if params.report_progress:
                        params.report_progress(i + 1, total_scan)
                except Exception as e:
                    print(f"  Iterator: Page {page_idx + 1} processing error: {e}")
        finally:
            # Do not test ``doc`` for truthiness: a fitz Document's truthiness
            # follows its page count, so an empty document would never be closed.
            if should_close and doc is not None:
                doc.close()

    def _get_page_ranges(self, doc: FitzDocument, idxs: Iterable[int] | None) -> tuple[Sequence[int], Sequence[int]]:
        """Return ``(scan_indexes, enabled_indexes)``, both sorted ascending.

        If ``idxs`` is None, every page is both scanned and enabled. Otherwise:

        * out-of-range indexes are ignored;
        * each valid index is enabled;
        * the index and up to ``_MDR_CONTEXT_PAGES`` pages on each side are scanned.
        """
        count = doc.page_count
        if idxs is None:
            all_p = list(range(count))
            return all_p, all_p
        enable: set[int] = set()
        scan: set[int] = set()
        for i in idxs:
            if 0 <= i < count:
                enable.add(i)
                scan.update(range(max(0, i - _MDR_CONTEXT_PAGES), min(count, i + _MDR_CONTEXT_PAGES + 1)))
        return sorted(scan), sorted(enable)

    def _render_page_image(self, page: FitzPage, dpi: int) -> Image:
        """Rasterise ``page`` at ``dpi`` into an RGB PIL image (no alpha channel)."""
        mat = FitzMatrix(dpi / 72.0, dpi / 72.0)  # PDF user space is 72 units per inch
        pix = page.get_pixmap(matrix=mat, alpha=False)
        return frombytes("RGB", (pix.width, pix.height), pix.samples)

    def _save_debug_plot(self, img: Image, idx: int, res: MDRExtractionResult, path: str):
        """Save a layout overlay for page ``idx`` under ``path``.

        Failures are printed and do not raise.
        """
        try:
            plot_img = res.adjusted_image.copy() if res.adjusted_image is not None else img.copy()
            mdr_plot_layout(plot_img, res.layouts)
            plot_img.save(os.path.join(path, f"mdr_plot_page_{idx + 1}.png"))
        except Exception as e:
            print(f"  Iterator: Plot generation error page {idx + 1}: {e}")
|
|
|
|
|
|
|
class MagicPDFProcessor:
    """
    Main class for processing PDF documents to extract structured data blocks
    using the MagicDataReadiness pipeline.
    """

    def __init__(self, device: Literal["cpu", "cuda"] = "cuda", model_dir_path: str = "./mdr_models",
                 ocr_level: MDROcrLevel = MDROcrLevel.Once, extract_formula: bool = True,
                 extract_table_format: MDRExtractedTableFormat | None = None, debug_dir_path: str | None = None):
        """
        Initializes the MagicPDFProcessor.

        Args:
            device: Computation device ('cpu' or 'cuda'). Defaults to 'cuda'.
                Falls back to 'cpu' if CUDA is not available.
            model_dir_path: Path to directory for storing/caching downloaded models. Defaults to './mdr_models'.
            ocr_level: Level of OCR application (once per page or once per layout). Defaults to once per page.
            extract_formula: Whether to attempt LaTeX extraction from formula images. Defaults to True.
            extract_table_format: Desired format for extracted table content (LATEX, MARKDOWN, HTML, DISABLE,
                or None). None selects LATEX when the effective device is 'cuda', otherwise DISABLE.
                Any value other than LATEX, MARKDOWN or HTML disables table parsing.
            debug_dir_path: Optional path to save debug plots and intermediate files. Defaults to None (disabled).
        """
        actual_dev = device if torch.cuda.is_available() else "cpu"
        print(f"MagicPDFProcessor using device: {actual_dev}.")
        if extract_table_format is None:
            extract_table_format = (MDRExtractedTableFormat.LATEX if actual_dev == "cuda"
                                    else MDRExtractedTableFormat.DISABLE)
        table_fmt_internal: MDRTableLayoutParsedFormat | None = None
        if extract_table_format == MDRExtractedTableFormat.LATEX:
            table_fmt_internal = MDRTableLayoutParsedFormat.LATEX
        elif extract_table_format == MDRExtractedTableFormat.MARKDOWN:
            table_fmt_internal = MDRTableLayoutParsedFormat.MARKDOWN
        elif extract_table_format == MDRExtractedTableFormat.HTML:
            table_fmt_internal = MDRTableLayoutParsedFormat.HTML
        self._iterator = MDRDocumentIterator(device=actual_dev, model_dir_path=model_dir_path, ocr_level=ocr_level,
                                             extract_formula=extract_formula, extract_table_format=table_fmt_internal,
                                             debug_dir_path=debug_dir_path)
        print("MagicPDFProcessor initialized.")

    def process_document(self, pdf_input: str | FitzDocument,
                         report_progress: MDRProgressReportCallback | None = None) -> Generator[
            MDRStructuredBlock, None, None]:
        """
        Processes the entire PDF document and yields all extracted structured blocks.

        Args:
            pdf_input: Path to the PDF file or a loaded fitz.Document object.
            report_progress: Optional callback function for progress updates (receives completed_scan_pages, total_scan_pages).

        Yields:
            MDRStructuredBlock: An extracted block (MDRTextBlock, MDRTableBlock, etc.).
        """
        print(f"Processing document: {pdf_input if isinstance(pdf_input, str) else 'FitzDocument object'}")
        for _, blocks, _ in self.process_document_pages(pdf_input=pdf_input, report_progress=report_progress,
                                                        page_indexes=None):
            yield from blocks
        print("Document processing complete.")

    def process_document_pages(self, pdf_input: str | FitzDocument, page_indexes: Iterable[int] | None = None,
                               report_progress: MDRProgressReportCallback | None = None) -> Generator[
            tuple[int, list[MDRStructuredBlock], Image], None, None]:
        """
        Processes specific pages (or all if page_indexes is None) of the PDF document.

        Yields results page by page, including the page index, extracted blocks, and the rendered page image.

        Args:
            pdf_input: Path to the PDF file or a loaded fitz.Document object.
            page_indexes: An iterable of 0-based page indices to process. If None, processes all pages.
            report_progress: Optional callback function for progress updates.

        Yields:
            tuple[int, list[MDRStructuredBlock], Image]:
                - page_index (0-based)
                - list of extracted MDRStructuredBlock objects for that page
                - PIL Image object of the rendered page (the result's ``extracted_image``)
        """
        params = MDRProcessingParams(pdf=pdf_input, page_indexes=page_indexes, report_progress=report_progress)
        page_count = 0
        for page_idx, extraction_result, content_layouts in self._iterator.iterate_sections(params):
            page_count += 1
            print(f"Processor: Converting layouts to blocks for page {page_idx + 1}...")
            blocks = self._create_structured_blocks(extraction_result, content_layouts)
            print(f"Processor: Analyzing paragraph structure for page {page_idx + 1}...")
            self._analyze_paragraph_structure(blocks)
            print(f"Processor: Yielding results for page {page_idx + 1}.")
            yield page_idx, blocks, extraction_result.extracted_image
        print(f"Processor: Finished processing {page_count} pages.")

    def _create_structured_blocks(self, result: MDRExtractionResult, layouts: list[MDRLayoutElement]) -> list[
            MDRStructuredBlock]:
        """Converts MDRLayoutElement objects into MDRStructuredBlock objects, keeping layout order."""
        temp_store: list[tuple[MDRLayoutElement, MDRStructuredBlock]] = []
        for layout in layouts:
            if isinstance(layout, MDRPlainLayoutElement):
                self._add_plain_block(temp_store, layout, result)
            elif isinstance(layout, MDRTableLayoutElement):
                temp_store.append((layout, self._create_table_block(layout, result)))
            elif isinstance(layout, MDRFormulaLayoutElement):
                temp_store.append((layout, self._create_formula_block(layout, result)))
        self._assign_relative_font_sizes(temp_store)
        return [block for _, block in temp_store]

    def _analyze_paragraph_structure(self, blocks: list[MDRStructuredBlock]):
        """
        Calculates indentation and line-end heuristics for MDRTextBlocks
        based on page-level text boundaries and average line height.

        For each non-ABANDON text block with spans:

        * ``has_paragraph_indentation``: the first span's left-edge midpoint lies
          more than one average line height right of the page's leftmost text.
        * ``last_line_touch_end``: the last span's right-edge midpoint lies within
          one average line height of the page's rightmost text.

        Nothing is changed when the page has no measurable text height.
        """
        MIN_VALID_HEIGHT = 1e-6
        INDENTATION_THRESHOLD_FACTOR = 1.0
        LINE_END_THRESHOLD_FACTOR = 1.0

        page_avg_line_height, page_min_x, page_max_x = self._calculate_text_range(
            b for b in blocks if isinstance(b, MDRTextBlock) and b.kind != MDRTextKind.ABANDON
        )
        if page_avg_line_height <= MIN_VALID_HEIGHT:
            return

        indent_limit = page_avg_line_height * INDENTATION_THRESHOLD_FACTOR
        line_end_limit = page_avg_line_height * LINE_END_THRESHOLD_FACTOR
        for block in blocks:
            if not isinstance(block, MDRTextBlock) or block.kind == MDRTextKind.ABANDON or not block.texts:
                continue
            first_text_span = block.texts[0]
            last_text_span = block.texts[-1]
            try:
                # Midpoint of the left edge copes with slightly skewed quads.
                first_line_start_x = (first_text_span.rect.lt[0] + first_text_span.rect.lb[0]) / 2.0
                block.has_paragraph_indentation = (first_line_start_x - page_min_x) > indent_limit
                last_line_end_x = (last_text_span.rect.rt[0] + last_text_span.rect.rb[0]) / 2.0
                block.last_line_touch_end = (page_max_x - last_line_end_x) < line_end_limit
            except Exception as e:
                print(f"Warn: Error calculating paragraph structure for block: {e}")
                block.has_paragraph_indentation = False
                block.last_line_touch_end = False

    def _calculate_text_range(self, blocks_iter: Iterable[MDRStructuredBlock]) -> tuple[float, float, float]:
        """Return (mean span height, min x, max x) over non-ABANDON text blocks.

        Only spans taller than 1e-6 contribute to the mean; every span contributes
        to the x range. Returns (0.0, 0.0, 0.0) when no span has a usable height.
        """
        h_sum = 0.0
        count = 0
        x1 = float('inf')
        x2 = float('-inf')
        for b in blocks_iter:
            if not isinstance(b, MDRTextBlock) or b.kind == MDRTextKind.ABANDON:
                continue
            for t in b.texts:
                _, h = t.rect.size
                if h > 1e-6:
                    h_sum += h
                    count += 1
                tx1, _, tx2, _ = t.rect.wrapper
                x1 = min(x1, tx1)
                x2 = max(x2, tx2)
        if count == 0:
            return 0.0, 0.0, 0.0
        mean_h = h_sum / count
        x1 = 0.0 if x1 == float('inf') else x1
        x2 = 0.0 if x2 == float('-inf') else x2
        return mean_h, x1, x2

    def _add_plain_block(self, store: list[tuple[MDRLayoutElement, MDRStructuredBlock]], layout: MDRPlainLayoutElement,
                         result: MDRExtractionResult):
        """Creates MDRStructuredBlocks for plain layout types.

        Title, plain-text and abandon layouts become text blocks. Figure layouts
        become figure blocks.

        Captions and footnotes are appended to the most recent figure, table or
        formula block already in ``store``. When no such block exists, the
        caption text is kept as a plain-text block instead of being discarded.
        """
        cls = layout.cls
        texts = self._convert_fragments_to_spans(layout.fragments)
        caption_owner_types = {
            MDRLayoutClass.FIGURE_CAPTION: MDRFigureBlock,
            MDRLayoutClass.TABLE_CAPTION: MDRTableBlock,
            MDRLayoutClass.TABLE_FOOTNOTE: MDRTableBlock,
            MDRLayoutClass.FORMULA_CAPTION: MDRFormulaBlock,
        }
        if cls == MDRLayoutClass.TITLE:
            store.append((layout, MDRTextBlock(layout.rect, texts, 0.0, MDRTextKind.TITLE)))
        elif cls == MDRLayoutClass.PLAIN_TEXT:
            store.append((layout, MDRTextBlock(layout.rect, texts, 0.0, MDRTextKind.PLAIN_TEXT)))
        elif cls == MDRLayoutClass.ABANDON:
            store.append((layout, MDRTextBlock(layout.rect, texts, 0.0, MDRTextKind.ABANDON)))
        elif cls == MDRLayoutClass.FIGURE:
            store.append((layout, MDRFigureBlock(layout.rect, [], 0.0, mdr_clip_layout(result, layout))))
        elif cls in caption_owner_types:
            owner = self._find_previous_block(store, caption_owner_types[cls])
            if owner is not None:
                owner.texts.extend(texts)
            elif texts:
                # The caption precedes (or has no) owning block; keep its text.
                store.append((layout, MDRTextBlock(layout.rect, texts, 0.0, MDRTextKind.PLAIN_TEXT)))

    def _find_previous_block(self, store: list[tuple[MDRLayoutElement, MDRStructuredBlock]],
                             block_type: type) -> MDRStructuredBlock | None:
        """Finds the most recent block of a specific type, or None."""
        for _, block in reversed(store):
            if isinstance(block, block_type):
                return block
        return None

    def _create_table_block(self, layout: MDRTableLayoutElement, result: MDRExtractionResult) -> MDRTableBlock:
        """Converts MDRTableLayoutElement to MDRTableBlock.

        Parsed LaTeX is discarded when the table's OCR text contains CJK
        characters (per ``mdr_contains_cjka``). In that case, or when nothing
        was parsed, the block is UNRECOGNIZABLE with empty content.
        """
        fmt = MDRTableFormat.UNRECOGNIZABLE
        content = ""
        if layout.parsed:
            p_content, p_fmt = layout.parsed
            can_use = not (p_fmt == MDRTableLayoutParsedFormat.LATEX and mdr_contains_cjka(
                "".join(f.text for f in layout.fragments)))
            if can_use:
                content = p_content
                if p_fmt == MDRTableLayoutParsedFormat.LATEX:
                    fmt = MDRTableFormat.LATEX
                elif p_fmt == MDRTableLayoutParsedFormat.MARKDOWN:
                    fmt = MDRTableFormat.MARKDOWN
                elif p_fmt == MDRTableLayoutParsedFormat.HTML:
                    fmt = MDRTableFormat.HTML
        return MDRTableBlock(layout.rect, [], 0.0, fmt, content, mdr_clip_layout(result, layout))

    def _create_formula_block(self, layout: MDRFormulaLayoutElement, result: MDRExtractionResult) -> MDRFormulaBlock:
        """Converts MDRFormulaLayoutElement to MDRFormulaBlock.

        LaTeX is dropped when the formula's OCR text contains CJK characters.
        """
        content = layout.latex if layout.latex and not mdr_contains_cjka(
            "".join(f.text for f in layout.fragments)) else None
        return MDRFormulaBlock(layout.rect, [], 0.0, content, mdr_clip_layout(result, layout))

    def _assign_relative_font_sizes(self, store: list[tuple[MDRLayoutElement, MDRStructuredBlock]]):
        """Set each block's ``font_size`` to its min-max normalised mean fragment height (0-1).

        Blocks without measurable fragments get 0.0. Every block gets 0.0 when
        all measurable sizes are equal.
        """
        sizes = []
        for layout, _ in store:
            heights = [f.rect.size[1] for f in layout.fragments if f.rect.size[1] > 1e-6]
            sizes.append(sum(heights) / len(heights) if heights else 0.0)
        valid = [s for s in sizes if s > 1e-6]
        min_s, max_s = (min(valid), max(valid)) if valid else (0.0, 0.0)
        rng = max_s - min_s
        for size, (_, block) in zip(sizes, store):
            if rng < 1e-6 or size <= 1e-6:
                block.font_size = 0.0
            else:
                block.font_size = (size - min_s) / rng

    def _convert_fragments_to_spans(self, frags: list[MDROcrFragment]) -> list[MDRTextSpan]:
        """Converts MDROcrFragment list to MDRTextSpan list."""
        return [MDRTextSpan(f.text, f.rank, f.rect) for f in frags]
|
|
|
|
|
|
|
if __name__ == '__main__':
    import sys
    import traceback

    print("=" * 60)
    print("  MagicDataReadiness PDF Processor - Example Usage")
    print("=" * 60)

    # Directory where pipeline models are stored/cached.
    MDR_MODEL_DIRECTORY = "./mdr_pipeline_models"

    # Input document; a one-page dummy PDF is generated when it does not exist.
    MDR_INPUT_PDF = "example_input.pdf"
    if not Path(MDR_INPUT_PDF).exists():
        try:
            print(f"Creating dummy PDF: {MDR_INPUT_PDF}")
            # fitz.open() with no arguments creates a new, empty PDF document.
            with fitz.open() as doc:
                page = doc.new_page()
                page.insert_text((72, 72), "This is a dummy PDF for testing.")
                doc.save(MDR_INPUT_PDF)
        except Exception as e:
            print(f"Warning: Could not create dummy PDF: {e}")

    # Where per-page layout plots are written; set to None to disable.
    MDR_DEBUG_DIRECTORY = "./mdr_debug_output"
    # Requested device; MagicPDFProcessor falls back to CPU without CUDA.
    MDR_DEVICE = "cuda"
    MDR_TABLE_FORMAT = MDRExtractedTableFormat.MARKDOWN
    # 0-based page indexes to process, or None for all pages.
    MDR_PAGES = None

    print(f"Model Directory: {os.path.abspath(MDR_MODEL_DIRECTORY)}")
    print(f"Input PDF: {os.path.abspath(MDR_INPUT_PDF)}")
    print(f"Debug Output: {os.path.abspath(MDR_DEBUG_DIRECTORY) if MDR_DEBUG_DIRECTORY else 'Disabled'}")
    print(f"Target Device: {MDR_DEVICE}")
    print(f"Table Format: {MDR_TABLE_FORMAT.name}")
    print(f"Pages: {'All' if MDR_PAGES is None else MDR_PAGES}")
    print("-" * 60)

    mdr_ensure_directory(MDR_MODEL_DIRECTORY)
    if MDR_DEBUG_DIRECTORY:
        mdr_ensure_directory(MDR_DEBUG_DIRECTORY)
    if not Path(MDR_INPUT_PDF).is_file():
        print(f"ERROR: Input PDF not found at '{MDR_INPUT_PDF}'. Please place a PDF file there or update the path.")
        sys.exit(1)

    def mdr_progress_update(completed, total):
        """Print scan progress as a count and a percentage."""
        perc = (completed / total) * 100 if total > 0 else 0
        print(f"  [Progress] Scanned {completed}/{total} pages ({perc:.1f}%)")

    print("Initializing MagicPDFProcessor...")
    init_start = time.time()
    try:
        mdr_processor = MagicPDFProcessor(
            device=MDR_DEVICE,
            model_dir_path=MDR_MODEL_DIRECTORY,
            debug_dir_path=MDR_DEBUG_DIRECTORY,
            extract_table_format=MDR_TABLE_FORMAT
        )
        print(f"Initialization took {time.time() - init_start:.2f}s")
    except Exception as e:
        print(f"FATAL ERROR during initialization: {e}")
        traceback.print_exc()
        sys.exit(1)

    print("\nStarting document processing...")
    proc_start = time.time()
    all_blocks_count = 0
    processed_pages_count = 0

    try:
        block_generator = mdr_processor.process_document_pages(
            pdf_input=MDR_INPUT_PDF,
            page_indexes=MDR_PAGES,
            report_progress=mdr_progress_update
        )
        for page_idx, page_blocks, page_img in block_generator:
            processed_pages_count += 1
            print(f"\n--- Page {page_idx + 1} Results ---")
            if not page_blocks:
                print("  No blocks extracted.")
                continue

            print(f"  Extracted {len(page_blocks)} blocks:")
            for block_idx, block in enumerate(page_blocks):
                all_blocks_count += 1
                info = f"  - Block {block_idx + 1}: {type(block).__name__}"
                if isinstance(block, MDRTextBlock):
                    preview = block.texts[0].content[:70].replace('\n', ' ') + "..." if block.texts else "[EMPTY]"
                    info += f" (Kind: {block.kind.name}, FontSz: {block.font_size:.2f}, Indent: {block.has_paragraph_indentation}, EndTouch: {block.last_line_touch_end}) | Text: '{preview}'"
                elif isinstance(block, MDRTableBlock):
                    info += f" (Format: {block.format.name}, HasContent: {bool(block.content)}, FontSz: {block.font_size:.2f})"
                elif isinstance(block, MDRFormulaBlock):
                    info += f" (HasLatex: {bool(block.content)}, FontSz: {block.font_size:.2f})"
                elif isinstance(block, MDRFigureBlock):
                    info += f" (FontSz: {block.font_size:.2f})"
                print(info)

        proc_time = time.time() - proc_start
        print("\n" + "=" * 60)
        print("  Processing Summary")
        print(f"  Total time: {proc_time:.2f} seconds")
        print(f"  Pages processed: {processed_pages_count}")
        print(f"  Total blocks extracted: {all_blocks_count}")
        print("=" * 60)
    except Exception as e:
        print(f"\nFATAL ERROR during processing: {e}")
        traceback.print_exc()
        sys.exit(1)