diff --git "a/mdr_pdf_parser.py" "b/mdr_pdf_parser.py" --- "a/mdr_pdf_parser.py" +++ "b/mdr_pdf_parser.py" @@ -17,18 +17,17 @@ # \=====================================================================/ # - # --- External Library Imports --- import os import re import io import copy -import fitz # PyMuPDF +import fitz # PyMuPDF from fitz import Document as FitzDocument, Page as FitzPage, Matrix as FitzMatrix import numpy as np -import cv2 # OpenCV -import torch # PyTorch -import requests # For downloading models +import cv2 # OpenCV +import torch # PyTorch +import requests # For downloading models from pathlib import Path from enum import auto, Enum from dataclasses import dataclass, field @@ -51,7 +50,7 @@ from enum import auto, Enum # --- HUGGING FACE HUB IMPORT ONLY BECAUSE RUNNING IN SPACES NOT NECESSARY IN PROD --- from huggingface_hub import hf_hub_download from huggingface_hub.errors import HfHubHTTPError -import time # Added for example usage timing +import time # Added for example usage timing # --- External Dependencies --- try: @@ -65,560 +64,660 @@ except ImportError: print("Warning: Could not import LatexOCR from pix2tex.cli. LaTeX extraction will fail.") LatexOCR = None try: - pass # from struct_eqtable import build_model # Keep commented as per original + pass # from struct_eqtable import build_model # Keep commented as per original except ImportError: print("Warning: Could not import build_model from struct_eqtable. Table parsing might fail.") import torch + if not hasattr(torch, "get_default_device"): torch.get_default_device = lambda: torch.device("cuda" if torch.cuda.is_available() else "cpu") + # --- MagicDataReadiness Core Components --- # --- MDR Utilities --- def mdr_download_model(url: str, file_path: Path): - """Downloads a model file from a URL to a local path.""" - try: - response = requests.get(url, stream=True, timeout=120) # Increased timeout - response.raise_for_status() - file_path.parent.mkdir(parents=True, exist_ok=True) - with open(file_path, "wb") as file: - for chunk in response.iter_content(chunk_size=8192): - file.write(chunk) - print(f"Successfully downloaded {file_path.name}") - except requests.exceptions.RequestException as e: - print(f"ERROR: Failed to download {url}: {e}") - if file_path.exists(): os.remove(file_path) - raise FileNotFoundError(f"Failed to download model from {url}") from e - except Exception as e: - print(f"ERROR: Failed writing file {file_path}: {e}") - if file_path.exists(): os.remove(file_path) - raise e + """Downloads a model file from a URL to a local path.""" + try: + response = requests.get(url, stream=True, timeout=120) # Increased timeout + response.raise_for_status() + file_path.parent.mkdir(parents=True, exist_ok=True) + with open(file_path, "wb") as file: + for chunk in response.iter_content(chunk_size=8192): + file.write(chunk) + print(f"Successfully downloaded {file_path.name}") + except requests.exceptions.RequestException as e: + print(f"ERROR: Failed to download {url}: {e}") + if file_path.exists(): os.remove(file_path) + raise FileNotFoundError(f"Failed to download model from {url}") from e + except Exception as e: + print(f"ERROR: Failed writing file {file_path}: {e}") + if file_path.exists(): os.remove(file_path) + raise e + def mdr_ensure_directory(path: str) -> str: - """Ensures a directory exists, creating it if necessary.""" - path = os.path.abspath(path) - os.makedirs(path, exist_ok=True) - return path + """Ensures a directory exists, creating it if necessary.""" + path = os.path.abspath(path) + os.makedirs(path, exist_ok=True) + return path + def mdr_is_whitespace(text: str) -> bool: - """Checks if a string contains only whitespace.""" - return bool(re.match(r"^\s*$", text)) if text else True + """Checks if a string contains only whitespace.""" + return bool(re.match(r"^\s*$", text)) if text else True + def mdr_expand_image(image: Image, percent: float) -> Image: - """Expands an image with a white border.""" - if percent <= 0: return image.copy() - w, h = image.size - bw, bh = ceil(w * percent), ceil(h * percent) - fill: tuple[int, ...] | int - if image.mode == "RGBA": fill = (255, 255, 255, 255) - elif image.mode in ("LA", "L"): fill = 255 - else: fill = (255, 255, 255) - return pil_expand(image=image, border=(bw, bh), fill=fill) + """Expands an image with a white border.""" + if percent <= 0: return image.copy() + w, h = image.size + bw, bh = ceil(w * percent), ceil(h * percent) + fill: tuple[int, ...] | int + if image.mode == "RGBA": + fill = (255, 255, 255, 255) + elif image.mode in ("LA", "L"): + fill = 255 + else: + fill = (255, 255, 255) + return pil_expand(image=image, border=(bw, bh), fill=fill) + # --- MDR Geometry --- MDRPoint: TypeAlias = tuple[float, float] + + @dataclass class MDRRectangle: - """Represents a geometric rectangle defined by four corner points.""" - lt: MDRPoint; rt: MDRPoint; lb: MDRPoint; rb: MDRPoint - def __iter__(self) -> Generator[MDRPoint, None, None]: yield self.lt; yield self.lb; yield self.rb; yield self.rt - @property - def is_valid(self) -> bool: - try: return Polygon(self).is_valid - except: return False - @property - def segments(self) -> Generator[tuple[MDRPoint, MDRPoint], None, None]: yield (self.lt, self.lb); yield (self.lb, self.rb); yield (self.rb, self.rt); yield (self.rt, self.lt) - @property - def area(self) -> float: - try: return Polygon(self).area - except: return 0.0 - @property - def size(self) -> tuple[float, float]: - widths, heights = [], [] - for i, (p1, p2) in enumerate(self.segments): - dx, dy = p1[0]-p2[0], p1[1]-p2[1] - dist = sqrt(dx*dx + dy*dy) - if i % 2 == 0: heights.append(dist) - else: widths.append(dist) - avg_w = sum(widths)/len(widths) if widths else 0.0 - avg_h = sum(heights)/len(heights) if heights else 0.0 - return avg_w, avg_h - @property - def wrapper(self) -> tuple[float, float, float, float]: - x1, y1, x2, y2 = float("inf"), float("inf"), float("-inf"), float("-inf") - for x, y in self: - x1, y1, x2, y2 = min(x1, x), min(y1, y), max(x2, x), max(y2, y) - return x1, y1, x2, y2 + """Represents a geometric rectangle defined by four corner points.""" + lt: MDRPoint; + rt: MDRPoint; + lb: MDRPoint; + rb: MDRPoint + + def __iter__(self) -> Generator[MDRPoint, None, None]: + yield self.lt; yield self.lb; yield self.rb; yield self.rt + + @property + def is_valid(self) -> bool: + try: + return Polygon(self).is_valid + except: + return False + + @property + def segments(self) -> Generator[tuple[MDRPoint, MDRPoint], None, None]: + yield (self.lt, self.lb); yield (self.lb, self.rb); yield (self.rb, self.rt); yield (self.rt, self.lt) + + @property + def area(self) -> float: + try: + return Polygon(self).area + except: + return 0.0 + + @property + def size(self) -> tuple[float, float]: + widths, heights = [], [] + for i, (p1, p2) in enumerate(self.segments): + dx, dy = p1[0] - p2[0], p1[1] - p2[1] + dist = sqrt(dx * dx + dy * dy) + if i % 2 == 0: + heights.append(dist) + else: + widths.append(dist) + avg_w = sum(widths) / len(widths) if widths else 0.0 + avg_h = sum(heights) / len(heights) if heights else 0.0 + return avg_w, avg_h + + @property + def wrapper(self) -> tuple[float, float, float, float]: + x1, y1, x2, y2 = float("inf"), float("inf"), float("-inf"), float("-inf") + for x, y in self: + x1, y1, x2, y2 = min(x1, x), min(y1, y), max(x2, x), max(y2, y) + return x1, y1, x2, y2 + def mdr_intersection_area(rect1: MDRRectangle, rect2: MDRRectangle) -> float: - """Calculates intersection area between two MDRRectangles.""" - try: - p1 = Polygon(rect1) - p2 = Polygon(rect2) - if not p1.is_valid or not p2.is_valid: + """Calculates intersection area between two MDRRectangles.""" + try: + p1 = Polygon(rect1) + p2 = Polygon(rect2) + if not p1.is_valid or not p2.is_valid: + return 0.0 + return p1.intersection(p2).area + except: return 0.0 - return p1.intersection(p2).area - except: - return 0.0 + # --- MDR Data Structures --- @dataclass class MDROcrFragment: - """Represents a fragment of text identified by OCR.""" - order: int; text: str; rank: float; rect: MDRRectangle + """Represents a fragment of text identified by OCR.""" + order: int; + text: str; + rank: float; + rect: MDRRectangle + class MDRLayoutClass(Enum): - """Enumeration of different layout types identified.""" - TITLE=0; PLAIN_TEXT=1; ABANDON=2; FIGURE=3; FIGURE_CAPTION=4; TABLE=5; TABLE_CAPTION=6; TABLE_FOOTNOTE=7; ISOLATE_FORMULA=8; FORMULA_CAPTION=9 + """Enumeration of different layout types identified.""" + TITLE = 0; + PLAIN_TEXT = 1; + ABANDON = 2; + FIGURE = 3; + FIGURE_CAPTION = 4; + TABLE = 5; + TABLE_CAPTION = 6; + TABLE_FOOTNOTE = 7; + ISOLATE_FORMULA = 8; + FORMULA_CAPTION = 9 + class MDRTableLayoutParsedFormat(Enum): - """Enumeration for formats of parsed table content.""" - LATEX=auto(); MARKDOWN=auto(); HTML=auto() + """Enumeration for formats of parsed table content.""" + LATEX = auto(); + MARKDOWN = auto(); + HTML = auto() + @dataclass class MDRBaseLayoutElement: - """Base class for layout elements found on a page.""" - rect: MDRRectangle; fragments: list[MDROcrFragment] + """Base class for layout elements found on a page.""" + rect: MDRRectangle; + fragments: list[MDROcrFragment] + @dataclass class MDRPlainLayoutElement(MDRBaseLayoutElement): - """Layout element for plain text, titles, captions, figures, etc.""" - # MODIFIED: Replaced Literal[...] with the Enum class name - cls: MDRLayoutClass # The type hint is now the Enum class itself + """Layout element for plain text, titles, captions, figures, etc.""" + # MODIFIED: Replaced Literal[...] with the Enum class name + cls: MDRLayoutClass # The type hint is now the Enum class itself + @dataclass class MDRTableLayoutElement(MDRBaseLayoutElement): - """Layout element specifically for tables.""" - parsed: tuple[str, MDRTableLayoutParsedFormat] | None - # MODIFIED: Replaced Literal[EnumMember] with the Enum class name - cls: MDRLayoutClass = MDRLayoutClass.TABLE # Hint with Enum, assign default member + """Layout element specifically for tables.""" + parsed: tuple[str, MDRTableLayoutParsedFormat] | None + # MODIFIED: Replaced Literal[EnumMember] with the Enum class name + cls: MDRLayoutClass = MDRLayoutClass.TABLE # Hint with Enum, assign default member + @dataclass class MDRFormulaLayoutElement(MDRBaseLayoutElement): - """Layout element specifically for formulas.""" - latex: str | None - # MODIFIED: Replaced Literal[EnumMember] with the Enum class name - cls: MDRLayoutClass = MDRLayoutClass.ISOLATE_FORMULA # Hint with Enum, assign default member + """Layout element specifically for formulas.""" + latex: str | None + # MODIFIED: Replaced Literal[EnumMember] with the Enum class name + cls: MDRLayoutClass = MDRLayoutClass.ISOLATE_FORMULA # Hint with Enum, assign default member + + +MDRLayoutElement = MDRPlainLayoutElement | MDRTableLayoutElement | MDRFormulaLayoutElement # Type alias -MDRLayoutElement = MDRPlainLayoutElement | MDRTableLayoutElement | MDRFormulaLayoutElement # Type alias @dataclass class MDRExtractionResult: - """Holds the complete result of extracting from a single page image.""" - rotation: float; layouts: list[MDRLayoutElement]; extracted_image: Image; adjusted_image: Image | None + """Holds the complete result of extracting from a single page image.""" + rotation: float; + layouts: list[MDRLayoutElement]; + extracted_image: Image; + adjusted_image: Image | None + # --- MDR Data Structures --- MDRProgressReportCallback: TypeAlias = Callable[[int, int], None] -class MDROcrLevel(Enum): Once=auto(); OncePerLayout=auto() -class MDRExtractedTableFormat(Enum): LATEX=auto(); MARKDOWN=auto(); HTML=auto(); DISABLE=auto() +class MDROcrLevel(Enum): Once = auto(); OncePerLayout = auto() + + +class MDRExtractedTableFormat(Enum): LATEX = auto(); MARKDOWN = auto(); HTML = auto(); DISABLE = auto() + + +class MDRTextKind(Enum): TITLE = 0; PLAIN_TEXT = 1; ABANDON = 2 -class MDRTextKind(Enum): TITLE=0; PLAIN_TEXT=1; ABANDON=2 @dataclass class MDRTextSpan: - """Represents a span of text content within a block.""" - content: str; rank: float; rect: MDRRectangle + """Represents a span of text content within a block.""" + content: str; + rank: float; + rect: MDRRectangle + @dataclass class MDRBasicBlock: - """Base class for structured blocks extracted from the document.""" - rect: MDRRectangle - texts: list[MDRTextSpan] - font_size: float # Relative font size (0-1) + """Base class for structured blocks extracted from the document.""" + rect: MDRRectangle + texts: list[MDRTextSpan] + font_size: float # Relative font size (0-1) + @dataclass class MDRTextBlock(MDRBasicBlock): - """A structured block containing text content.""" - kind: MDRTextKind - has_paragraph_indentation: bool = False - last_line_touch_end: bool = False + """A structured block containing text content.""" + kind: MDRTextKind + has_paragraph_indentation: bool = False + last_line_touch_end: bool = False + class MDRTableFormat(Enum): - LATEX=auto() - MARKDOWN=auto() - HTML=auto() - UNRECOGNIZABLE=auto() + LATEX = auto() + MARKDOWN = auto() + HTML = auto() + UNRECOGNIZABLE = auto() + @dataclass class MDRTableBlock(MDRBasicBlock): - """A structured block representing a table.""" - content: str - format: MDRTableFormat - image: Image # Image clip of the table + """A structured block representing a table.""" + content: str + format: MDRTableFormat + image: Image # Image clip of the table + @dataclass class MDRFormulaBlock(MDRBasicBlock): - """A structured block representing a formula.""" - content: str | None - image: Image # Image clip of the formula + """A structured block representing a formula.""" + content: str | None + image: Image # Image clip of the formula + @dataclass class MDRFigureBlock(MDRBasicBlock): - """A structured block representing a figure/image.""" - image: Image # Image clip of the figure + """A structured block representing a figure/image.""" + image: Image # Image clip of the figure + + +MDRAssetBlock = MDRTableBlock | MDRFormulaBlock | MDRFigureBlock # Type alias -MDRAssetBlock = MDRTableBlock | MDRFormulaBlock | MDRFigureBlock # Type alias +MDRStructuredBlock = MDRTextBlock | MDRAssetBlock # Type alias -MDRStructuredBlock = MDRTextBlock | MDRAssetBlock # Type alias # --- MDR Utilities --- def mdr_similarity_ratio(v1: float, v2: float) -> float: - """Calculates the ratio of the smaller value to the larger value (0-1).""" - if v1 == 0 and v2 == 0: - return 1.0 - if v1 < 0 or v2 < 0: - return 0.0 - v1, v2 = (v2, v1) if v1 > v2 else (v1, v2) - return 1.0 if v2 == 0 else v1 / v2 + """Calculates the ratio of the smaller value to the larger value (0-1).""" + if v1 == 0 and v2 == 0: + return 1.0 + if v1 < 0 or v2 < 0: + return 0.0 + v1, v2 = (v2, v1) if v1 > v2 else (v1, v2) + return 1.0 if v2 == 0 else v1 / v2 + def mdr_intersection_bounds_size(r1: MDRRectangle, r2: MDRRectangle) -> tuple[float, float]: - """Calculates width/height of the intersection bounding box.""" - try: - p1 = Polygon(r1) - p2 = Polygon(r2) - if not p1.is_valid or not p2.is_valid: - return 0.0, 0.0 - inter = p1.intersection(p2) - if inter.is_empty: + """Calculates width/height of the intersection bounding box.""" + try: + p1 = Polygon(r1) + p2 = Polygon(r2) + if not p1.is_valid or not p2.is_valid: + return 0.0, 0.0 + inter = p1.intersection(p2) + if inter.is_empty: + return 0.0, 0.0 + minx, miny, maxx, maxy = inter.bounds + return maxx - minx, maxy - miny + except: return 0.0, 0.0 - minx, miny, maxx, maxy = inter.bounds - return maxx - minx, maxy - miny - except: - return 0.0, 0.0 + _MDR_CJKA_PATTERN = re.compile(r"[\u4e00-\u9fff\u3040-\u309f\u30a0-\u30ff\uac00-\ud7a3\u0600-\u06ff]") + def mdr_contains_cjka(text: str): - """Checks if text contains Chinese, Japanese, Korean, or Arabic chars.""" - return bool(_MDR_CJKA_PATTERN.search(text)) if text else False + """Checks if text contains Chinese, Japanese, Korean, or Arabic chars.""" + return bool(_MDR_CJKA_PATTERN.search(text)) if text else False + # --- MDR Text Processing --- class _MDR_TokenPhase(Enum): - Init=0 - Letter=1 - Character=2 - Number=3 - Space=4 + Init = 0 + Letter = 1 + Character = 2 + Number = 3 + Space = 4 + _mdr_alphabet_detector = AlphabetDetector() + def _mdr_is_letter(char: str): - if not category(char).startswith("L"): - return False - try: - return _mdr_alphabet_detector.is_latin(char) or _mdr_alphabet_detector.is_cyrillic(char) or _mdr_alphabet_detector.is_greek(char) or _mdr_alphabet_detector.is_hebrew(char) - except: return False + if not category(char).startswith("L"): + return False + try: + return _mdr_alphabet_detector.is_latin(char) or _mdr_alphabet_detector.is_cyrillic( + char) or _mdr_alphabet_detector.is_greek(char) or _mdr_alphabet_detector.is_hebrew(char) + except: + return False + def mdr_split_into_words(text: str): - """Splits text into words, numbers, and individual non-alphanumeric chars.""" - if not text: return - sp = re.compile(r"\s") - np = re.compile(r"\d") - nsp = re.compile(r"[\.,']") - buf = io.StringIO() - phase = _MDR_TokenPhase.Init - for char in text: - is_l = _mdr_is_letter(char) - is_d = np.match(char) - is_s = sp.match(char) - is_ns = nsp.match(char) - if is_l: - if phase in (_MDR_TokenPhase.Number, _MDR_TokenPhase.Character): - w = buf.getvalue() - yield w if w else None - buf = io.StringIO() - buf.write(char) - phase = _MDR_TokenPhase.Letter - elif is_d: - if phase in (_MDR_TokenPhase.Letter, _MDR_TokenPhase.Character): - w = buf.getvalue() - yield w if w else None - buf = io.StringIO() - buf.write(char) - phase = _MDR_TokenPhase.Number - elif phase == _MDR_TokenPhase.Number and is_ns: - buf.write(char) - else: - if phase in (_MDR_TokenPhase.Letter, _MDR_TokenPhase.Number): + """Splits text into words, numbers, and individual non-alphanumeric chars.""" + if not text: return + sp = re.compile(r"\s") + np = re.compile(r"\d") + nsp = re.compile(r"[\.,']") + buf = io.StringIO() + phase = _MDR_TokenPhase.Init + for char in text: + is_l = _mdr_is_letter(char) + is_d = np.match(char) + is_s = sp.match(char) + is_ns = nsp.match(char) + if is_l: + if phase in (_MDR_TokenPhase.Number, _MDR_TokenPhase.Character): + w = buf.getvalue() + yield w if w else None + buf = io.StringIO() + buf.write(char) + phase = _MDR_TokenPhase.Letter + elif is_d: + if phase in (_MDR_TokenPhase.Letter, _MDR_TokenPhase.Character): + w = buf.getvalue() + yield w if w else None + buf = io.StringIO() + buf.write(char) + phase = _MDR_TokenPhase.Number + elif phase == _MDR_TokenPhase.Number and is_ns: + buf.write(char) + else: + if phase in (_MDR_TokenPhase.Letter, _MDR_TokenPhase.Number): + w = buf.getvalue() + yield w if w else None + buf = io.StringIO() + if is_s: + phase = _MDR_TokenPhase.Space + else: + yield char + phase = _MDR_TokenPhase.Character + if phase in (_MDR_TokenPhase.Letter, _MDR_TokenPhase.Number): w = buf.getvalue() yield w if w else None - buf = io.StringIO() - if is_s: - phase = _MDR_TokenPhase.Space - else: - yield char - phase = _MDR_TokenPhase.Character - if phase in (_MDR_TokenPhase.Letter, _MDR_TokenPhase.Number): - w = buf.getvalue() - yield w if w else None + def mdr_check_text_similarity(t1: str, t2: str) -> tuple[float, int]: - """Calculates word-based similarity between two texts.""" - w1 = list(mdr_split_into_words(t1)) - w2 = list(mdr_split_into_words(t2)) - l1 = len(w1) - l2 = len(w2) - if l1 == 0 and l2 == 0: - return 1.0, 0 - if l1 == 0 or l2 == 0: - return 0.0, max(l1, l2) - if l1 > l2: - w1, w2, l1, l2 = w2, w1, l2, l1 - taken = [False] * l2 - matches = 0 - for word1 in w1: - for i, word2 in enumerate(w2): - if not taken[i] and word1 == word2: - taken[i] = True - matches += 1 - break - mismatches = l2 - matches - return 1.0 - (mismatches / l2), l2 + """Calculates word-based similarity between two texts.""" + w1 = list(mdr_split_into_words(t1)) + w2 = list(mdr_split_into_words(t2)) + l1 = len(w1) + l2 = len(w2) + if l1 == 0 and l2 == 0: + return 1.0, 0 + if l1 == 0 or l2 == 0: + return 0.0, max(l1, l2) + if l1 > l2: + w1, w2, l1, l2 = w2, w1, l2, l1 + taken = [False] * l2 + matches = 0 + for word1 in w1: + for i, word2 in enumerate(w2): + if not taken[i] and word1 == word2: + taken[i] = True + matches += 1 + break + mismatches = l2 - matches + return 1.0 - (mismatches / l2), l2 + # --- MDR Geometry Processing --- class MDRRotationAdjuster: - """Adjusts point coordinates based on image rotation.""" - - def __init__(self, origin_size: tuple[int, int], new_size: tuple[int, int], rotation: float, to_origin_coordinate: bool): - fs, ts = (new_size, origin_size) if to_origin_coordinate else (origin_size, new_size) - self._rot = rotation if to_origin_coordinate else -rotation - self._c_off = (fs[0]/2.0, fs[1]/2.0) - self._n_off = (ts[0]/2.0, ts[1]/2.0) - - def adjust(self, point: MDRPoint) -> MDRPoint: - x = point[0] - self._c_off[0] - y = point[1] - self._c_off[1] - if x != 0 or y != 0: - cos_r = cos(self._rot) - sin_r = sin(self._rot) - x, y = x * cos_r - y * sin_r, x * sin_r + y * cos_r - return x + self._n_off[0], y + self._n_off[1] + """Adjusts point coordinates based on image rotation.""" + + def __init__(self, origin_size: tuple[int, int], new_size: tuple[int, int], rotation: float, + to_origin_coordinate: bool): + fs, ts = (new_size, origin_size) if to_origin_coordinate else (origin_size, new_size) + self._rot = rotation if to_origin_coordinate else -rotation + self._c_off = (fs[0] / 2.0, fs[1] / 2.0) + self._n_off = (ts[0] / 2.0, ts[1] / 2.0) + + def adjust(self, point: MDRPoint) -> MDRPoint: + x = point[0] - self._c_off[0] + y = point[1] - self._c_off[1] + if x != 0 or y != 0: + cos_r = cos(self._rot) + sin_r = sin(self._rot) + x, y = x * cos_r - y * sin_r, x * sin_r + y * cos_r + return x + self._n_off[0], y + self._n_off[1] + def mdr_normalize_vertical_rotation(rot: float) -> float: - while rot >= pi: - rot -= pi - while rot < 0: - rot += pi - return rot + while rot >= pi: + rot -= pi + while rot < 0: + rot += pi + return rot + def _mdr_get_rectangle_angles(rect: MDRRectangle) -> tuple[list[float], list[float]] | None: - h_angs, v_angs = [], [] - for i, (p1, p2) in enumerate(rect.segments): - dx = p2[0] - p1[0] - dy = p2[1] - p1[1] - if abs(dx) < 1e-6 and abs(dy) < 1e-6: - continue - ang = atan2(dy, dx) - if ang < 0: - ang += pi - if ang < pi * 0.25 or ang >= pi * 0.75: - h_angs.append(ang - pi if ang >= pi * 0.75 else ang) - else: - v_angs.append(ang) - if not h_angs or not v_angs: - return None - return h_angs, v_angs + h_angs, v_angs = [], [] + for i, (p1, p2) in enumerate(rect.segments): + dx = p2[0] - p1[0] + dy = p2[1] - p1[1] + if abs(dx) < 1e-6 and abs(dy) < 1e-6: + continue + ang = atan2(dy, dx) + if ang < 0: + ang += pi + if ang < pi * 0.25 or ang >= pi * 0.75: + h_angs.append(ang - pi if ang >= pi * 0.75 else ang) + else: + v_angs.append(ang) + if not h_angs or not v_angs: + return None + return h_angs, v_angs + def _mdr_normalize_horizontal_angles(rots: list[float]) -> list[float]: return rots + def _mdr_find_median(data: list[float]) -> float: - if not data: - return 0.0 - s_data = sorted(data) - n = len(s_data) - return s_data[n // 2] if n % 2 == 1 else (s_data[n // 2 - 1] + s_data[n // 2]) / 2.0 + if not data: + return 0.0 + s_data = sorted(data) + n = len(s_data) + return s_data[n // 2] if n % 2 == 1 else (s_data[n // 2 - 1] + s_data[n // 2]) / 2.0 + + +def _mdr_find_mean(data: list[float]) -> float: return sum(data) / len(data) if data else 0.0 -def _mdr_find_mean(data: list[float]) -> float: return sum(data)/len(data) if data else 0.0 def mdr_calculate_image_rotation(frags: list[MDROcrFragment]) -> float: - all_h, all_v = [], [] - for f in frags: - res = _mdr_get_rectangle_angles(f.rect) - if res: - h, v = res - all_h.extend(h) - all_v.extend(v) - if not all_h or not all_v: - return 0.0 - all_h = _mdr_normalize_horizontal_angles(all_h) - all_v = [mdr_normalize_vertical_rotation(a) for a in all_v] - med_h = _mdr_find_median(all_h) - med_v = _mdr_find_median(all_v) - rot_est = ((pi / 2 - med_v) - med_h) / 2.0 - while rot_est >= pi / 2: - rot_est -= pi - while rot_est < -pi / 2: - rot_est += pi - return rot_est + all_h, all_v = [], [] + for f in frags: + res = _mdr_get_rectangle_angles(f.rect) + if res: + h, v = res + all_h.extend(h) + all_v.extend(v) + if not all_h or not all_v: + return 0.0 + all_h = _mdr_normalize_horizontal_angles(all_h) + all_v = [mdr_normalize_vertical_rotation(a) for a in all_v] + med_h = _mdr_find_median(all_h) + med_v = _mdr_find_median(all_v) + rot_est = ((pi / 2 - med_v) - med_h) / 2.0 + while rot_est >= pi / 2: + rot_est -= pi + while rot_est < -pi / 2: + rot_est += pi + return rot_est + def mdr_calculate_rectangle_rotation(rect: MDRRectangle) -> tuple[float, float]: - res = _mdr_get_rectangle_angles(rect); - if res is None: return 0.0, pi/2.0; - h_rots, v_rots = res; - h_rots = _mdr_normalize_horizontal_angles(h_rots); v_rots = [mdr_normalize_vertical_rotation(a) for a in v_rots] - return _mdr_find_mean(h_rots), _mdr_find_mean(v_rots) + res = _mdr_get_rectangle_angles(rect); + if res is None: return 0.0, pi / 2.0; + h_rots, v_rots = res; + h_rots = _mdr_normalize_horizontal_angles(h_rots); + v_rots = [mdr_normalize_vertical_rotation(a) for a in v_rots] + return _mdr_find_mean(h_rots), _mdr_find_mean(v_rots) + # --- MDR ONNX OCR Internals --- class _MDR_PredictBase: - """Base class for ONNX model prediction components.""" + """Base class for ONNX model prediction components.""" - def get_onnx_session(self, model_path: str, use_gpu: bool): - try: - sess_opts = onnxruntime.SessionOptions() - sess_opts.log_severity_level = 3 - providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if use_gpu and 'CUDAExecutionProvider' in onnxruntime.get_available_providers() else ['CPUExecutionProvider'] - session = onnxruntime.InferenceSession(model_path, sess_options=sess_opts, providers=providers) - print(f" ONNX session loaded: {Path(model_path).name} ({session.get_providers()})") - return session - except Exception as e: - print(f" ERROR loading ONNX session {Path(model_path).name}: {e}") - if use_gpu and 'CUDAExecutionProvider' not in onnxruntime.get_available_providers(): - print(" CUDAExecutionProvider not available. Check ONNXRuntime-GPU installation and CUDA setup.") - raise e + def get_onnx_session(self, model_path: str, use_gpu: bool): + try: + sess_opts = onnxruntime.SessionOptions() + sess_opts.log_severity_level = 3 + providers = ['CUDAExecutionProvider', + 'CPUExecutionProvider'] if use_gpu and 'CUDAExecutionProvider' in onnxruntime.get_available_providers() else [ + 'CPUExecutionProvider'] + session = onnxruntime.InferenceSession(model_path, sess_options=sess_opts, providers=providers) + print(f" ONNX session loaded: {Path(model_path).name} ({session.get_providers()})") + return session + except Exception as e: + print(f" ERROR loading ONNX session {Path(model_path).name}: {e}") + if use_gpu and 'CUDAExecutionProvider' not in onnxruntime.get_available_providers(): + print(" CUDAExecutionProvider not available. Check ONNXRuntime-GPU installation and CUDA setup.") + raise e + + def get_output_name(self, sess: onnxruntime.InferenceSession) -> List[str]: + return [n.name for n in sess.get_outputs()] - def get_output_name(self, sess: onnxruntime.InferenceSession) -> List[str]: - return [n.name for n in sess.get_outputs()] + def get_input_name(self, sess: onnxruntime.InferenceSession) -> List[str]: + return [n.name for n in sess.get_inputs()] - def get_input_name(self, sess: onnxruntime.InferenceSession) -> List[str]: - return [n.name for n in sess.get_inputs()] + def get_input_feed(self, names: List[str], img_np: np.ndarray) -> Dict[str, np.ndarray]: + return {name: img_np for name in names} - def get_input_feed(self, names: List[str], img_np: np.ndarray) -> Dict[str, np.ndarray]: - return {name: img_np for name in names} # --- MDR ONNX OCR Internals --- class _MDR_NormalizeImage: - def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs): - self.scale = np.float32(eval(scale) if isinstance(scale, str) else (scale if scale is not None else 1.0 / 255.0)) - mean = mean if mean is not None else [0.485, 0.456, 0.406] - std = std if std is not None else [0.229, 0.224, 0.225] - shape = (3, 1, 1) if order == 'chw' else (1, 1, 3) - self.mean = np.array(mean).reshape(shape).astype('float32') - self.std = np.array(std).reshape(shape).astype('float32') + def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs): + self.scale = np.float32( + eval(scale) if isinstance(scale, str) else (scale if scale is not None else 1.0 / 255.0)) + mean = mean if mean is not None else [0.485, 0.456, 0.406] + std = std if std is not None else [0.229, 0.224, 0.225] + shape = (3, 1, 1) if order == 'chw' else (1, 1, 3) + self.mean = np.array(mean).reshape(shape).astype('float32') + self.std = np.array(std).reshape(shape).astype('float32') + + def __call__(self, data): + img = data['image'] + img = np.array(img) if isinstance(img, Image) else img + data['image'] = (img.astype('float32') * self.scale - self.mean) / self.std + return data - def __call__(self, data): - img = data['image'] - img = np.array(img) if isinstance(img, Image) else img - data['image'] = (img.astype('float32') * self.scale - self.mean) / self.std - return data class _MDR_DetResizeForTest: - def __init__(self, **kwargs): - self.resize_type = 0 - self.keep_ratio = False - if 'image_shape' in kwargs: - self.image_shape = kwargs['image_shape'] - self.resize_type = 1 - self.keep_ratio = kwargs.get('keep_ratio', False) - elif 'limit_side_len' in kwargs: - self.limit_side_len = kwargs['limit_side_len'] - self.limit_type = kwargs.get('limit_type', 'min') - elif 'resize_long' in kwargs: - self.resize_type = 2 - self.resize_long = kwargs.get('resize_long', 960) - else: - self.limit_side_len = 736 - self.limit_type = 'min' - - def __call__(self, data): - img = data['image'] - src_h, src_w, _ = img.shape - if src_h + src_w < 64: - img = self._pad(img) - if self.resize_type == 0: - img, ratios = self._resize0(img) - elif self.resize_type == 2: - img, ratios = self._resize2(img) - else: - img, ratios = self._resize1(img) - if img is None: - return None - data['image'] = img - data['shape'] = np.array([src_h, src_w, ratios[0], ratios[1]]) - return data - - def _pad(self, im, v=0): - h, w, c = im.shape - p = np.zeros((max(32, h), max(32, w), c), np.uint8) + v - p[:h, :w, :] = im - return p - - def _resize1(self, img): - rh, rw = self.image_shape - oh, ow = img.shape[:2] - if self.keep_ratio: - # Calculate new width based on aspect ratio - rw = ow * rh / oh - # Ensure width is a multiple of 32 - N = ceil(rw / 32) - rw = N * 32 - # Calculate resize ratios - r_h = float(rh) / oh - r_w = float(rw) / ow - # Resize image - img = cv2.resize(img, (int(rw), int(rh))) - return img, [r_h, r_w] - - def _resize0(self, img): - lsl = self.limit_side_len - h, w, _ = img.shape - r = 1.0 - if self.limit_type == 'max': - r = float(lsl) / max(h, w) if max(h, w) > lsl else 1.0 - elif self.limit_type == 'min': - r = float(lsl) / min(h, w) if min(h, w) < lsl else 1.0 - elif self.limit_type == 'resize_long': - r = float(lsl) / max(h, w) - else: - raise Exception('Unsupported limit_type') - rh = int(h * r) - rw = int(w * r) - rh = max(int(round(rh / 32) * 32), 32) - rw = max(int(round(rw / 32) * 32), 32) - if int(rw) <= 0 or int(rh) <= 0: - return None, (None, None) - img = cv2.resize(img, (int(rw), int(rh))) - r_h = rh / float(h) - r_w = rw / float(w) - return img, [r_h, r_w] - - def _resize2(self, img): - h, w, _ = img.shape - rl = self.resize_long - r = float(rl) / max(h, w) - rh = int(h * r) - rw = int(w * r) - ms = 128 - rh = (rh + ms - 1) // ms * ms - rw = (rw + ms - 1) // ms * ms - img = cv2.resize(img, (int(rw), int(rh))) - r_h = rh / float(h) - r_w = rw / float(w) - return img, [r_h, r_w] + def __init__(self, **kwargs): + self.resize_type = 0 + self.keep_ratio = False + if 'image_shape' in kwargs: + self.image_shape = kwargs['image_shape'] + self.resize_type = 1 + self.keep_ratio = kwargs.get('keep_ratio', False) + elif 'limit_side_len' in kwargs: + self.limit_side_len = kwargs['limit_side_len'] + self.limit_type = kwargs.get('limit_type', 'min') + elif 'resize_long' in kwargs: + self.resize_type = 2 + self.resize_long = kwargs.get('resize_long', 960) + else: + self.limit_side_len = 736 + self.limit_type = 'min' + + def __call__(self, data): + img = data['image'] + src_h, src_w, _ = img.shape + if src_h + src_w < 64: + img = self._pad(img) + if self.resize_type == 0: + img, ratios = self._resize0(img) + elif self.resize_type == 2: + img, ratios = self._resize2(img) + else: + img, ratios = self._resize1(img) + if img is None: + return None + data['image'] = img + data['shape'] = np.array([src_h, src_w, ratios[0], ratios[1]]) + return data + + def _pad(self, im, v=0): + h, w, c = im.shape + p = np.zeros((max(32, h), max(32, w), c), np.uint8) + v + p[:h, :w, :] = im + return p + + def _resize1(self, img): + rh, rw = self.image_shape + oh, ow = img.shape[:2] + if self.keep_ratio: + # Calculate new width based on aspect ratio + rw = ow * rh / oh + # Ensure width is a multiple of 32 + N = ceil(rw / 32) + rw = N * 32 + # Calculate resize ratios + r_h = float(rh) / oh + r_w = float(rw) / ow + # Resize image + img = cv2.resize(img, (int(rw), int(rh))) + return img, [r_h, r_w] + + def _resize0(self, img): + lsl = self.limit_side_len + h, w, _ = img.shape + r = 1.0 + if self.limit_type == 'max': + r = float(lsl) / max(h, w) if max(h, w) > lsl else 1.0 + elif self.limit_type == 'min': + r = float(lsl) / min(h, w) if min(h, w) < lsl else 1.0 + elif self.limit_type == 'resize_long': + r = float(lsl) / max(h, w) + else: + raise Exception('Unsupported limit_type') + rh = int(h * r) + rw = int(w * r) + rh = max(int(round(rh / 32) * 32), 32) + rw = max(int(round(rw / 32) * 32), 32) + if int(rw) <= 0 or int(rh) <= 0: + return None, (None, None) + img = cv2.resize(img, (int(rw), int(rh))) + r_h = rh / float(h) + r_w = rw / float(w) + return img, [r_h, r_w] + + def _resize2(self, img): + h, w, _ = img.shape + rl = self.resize_long + r = float(rl) / max(h, w) + rh = int(h * r) + rw = int(w * r) + ms = 128 + rh = (rh + ms - 1) // ms * ms + rw = (rw + ms - 1) // ms * ms + img = cv2.resize(img, (int(rw), int(rh))) + r_h = rh / float(h) + r_w = rw / float(w) + return img, [r_h, r_w] + class _MDR_ToCHWImage: - def __call__(self, data): - img = data['image'] - img = np.array(img) if isinstance(img, Image) else img - data['image'] = img.transpose((2, 0, 1)) - return data + def __call__(self, data): + img = data['image'] + img = np.array(img) if isinstance(img, Image) else img + data['image'] = img.transpose((2, 0, 1)) + return data + class _MDR_KeepKeys: - def __init__(self, keep_keys, **kwargs): self.keep_keys=keep_keys + def __init__(self, keep_keys, **kwargs): self.keep_keys = keep_keys + + def __call__(self, data): return [data[key] for key in self.keep_keys] - def __call__(self, data): return [data[key] for key in self.keep_keys] def mdr_ocr_transform( - data: Any, - ops: Optional[List[Callable[[Any], Optional[Any]]]] = None + data: Any, + ops: Optional[List[Callable[[Any], Optional[Any]]]] = None ) -> Optional[Any]: """ Applies a sequence of transformation operations to the input data. @@ -644,386 +743,418 @@ def mdr_ocr_transform( else: operations_to_apply = ops - current_data = data # Use a separate variable to track the evolving data + current_data = data # Use a separate variable to track the evolving data # Sequentially apply each operation for op in operations_to_apply: - current_data = op(current_data) # Apply the operation + current_data = op(current_data) # Apply the operation # Check if the operation signaled failure or requested early exit # by returning None. if current_data is None: - return None # Short-circuit the pipeline + return None # Short-circuit the pipeline # If the loop completes without returning None, all operations succeeded. return current_data + def mdr_ocr_create_operators(op_param_list, global_config=None): - ops = [] - for operator in op_param_list: - assert isinstance(operator, dict) and len(operator)==1, "Op config error"; op_name = list(operator)[0] - param = {} if operator[op_name] is None else operator[op_name]; - if global_config: param.update(global_config) - op_class_name = f"_MDR_{op_name}" # Map to internal prefixed names - if op_class_name in globals() and isinstance(globals()[op_class_name], type): ops.append(globals()[op_class_name](**param)) - else: raise ValueError(f"Operator class '{op_class_name}' not found.") - return ops + ops = [] + for operator in op_param_list: + assert isinstance(operator, dict) and len(operator) == 1, "Op config error"; + op_name = list(operator)[0] + param = {} if operator[op_name] is None else operator[op_name]; + if global_config: param.update(global_config) + op_class_name = f"_MDR_{op_name}" # Map to internal prefixed names + if op_class_name in globals() and isinstance(globals()[op_class_name], type): + ops.append(globals()[op_class_name](**param)) + else: + raise ValueError(f"Operator class '{op_class_name}' not found.") + return ops + class _MDR_DBPostProcess: - def __init__(self, thresh=0.3, box_thresh=0.7, max_candidates=1000, unclip_ratio=1.5, use_dilation=False, score_mode="fast", box_type='quad', **kwargs): - self.thresh = thresh - self.box_thresh = box_thresh - self.max_cand = max_candidates - self.unclip_r = unclip_ratio - self.min_sz = 3 - self.score_m = score_mode - self.box_t = box_type - assert score_mode in ["slow", "fast"] - self.dila_k = np.array([[1, 1], [1, 1]], dtype=np.uint8) if use_dilation else None - - def _polygons_from_bitmap(self, pred, bmp, dw, dh): - h, w = bmp.shape - boxes, scores = [], [] - contours, _ = cv2.findContours((bmp * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) - for contour in contours[:self.max_cand]: - eps = 0.002 * cv2.arcLength(contour, True) - approx = cv2.approxPolyDP(contour, eps, True) - pts = approx.reshape((-1, 2)) - if pts.shape[0] < 4: - continue - score = self._box_score_fast(pred, pts.reshape(-1, 2)) - if self.box_thresh > score: - continue - try: - box = self._unclip(pts, self.unclip_r) - except: - continue - if len(box) > 1: - continue - box = box.reshape(-1, 2) - _, sside = self._get_mini_boxes(box.reshape((-1, 1, 2))) - if sside < self.min_sz + 2: - continue - box = np.array(box) - box[:, 0] = np.clip(np.round(box[:, 0] / w * dw), 0, dw) - box[:, 1] = np.clip(np.round(box[:, 1] / h * dh), 0, dh) - boxes.append(box.tolist()) - scores.append(score) - return boxes, scores - -# In class _MDR_DBPostProcess: - def _boxes_from_bitmap(self, pred, bmp, dw, dh): # pred is the probability map, bmp is the binarized map - h, w = bmp.shape - # ADDED: More detailed logging - print(f" DEBUG OCR: _boxes_from_bitmap: Processing bitmap of shape {h}x{w} for original dimensions {dw:.1f}x{dh:.1f}.") - contours, _ = cv2.findContours((bmp * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) - num_contours_found = len(contours) - print(f" DEBUG OCR: _boxes_from_bitmap: Found {num_contours_found} raw contours.") - - num_contours_to_process = min(num_contours_found, self.max_cand) - if num_contours_found > self.max_cand: - print(f" DEBUG OCR: _boxes_from_bitmap: Processing limited to {self.max_cand} contours (max_candidates).") - - boxes, scores = [], [] - kept_boxes_count = 0 - for i in range(num_contours_to_process): - contour = contours[i] - pts_mini_box, sside = self._get_mini_boxes(contour) - if sside < self.min_sz: - # print(f" DEBUG OCR: Contour {i} too small (sside {sside:.2f} < min_sz {self.min_sz}). Skipping.") # Can be too verbose - continue - - pts_arr = np.array(pts_mini_box) - current_score = self._box_score_fast(pred, pts_arr.reshape(-1, 2)) if self.score_m == "fast" else self._box_score_slow(pred, contour) - - if self.box_thresh > current_score: - # print(f" DEBUG OCR: Contour {i} score {current_score:.4f} < box_thresh {self.box_thresh}. Skipping.") # Can be too verbose - continue - - try: - box_unclipped = self._unclip(pts_arr, self.unclip_r).reshape(-1, 1, 2) - except Exception as e_unclip: - # print(f" DEBUG OCR: Contour {i} unclip failed: {e_unclip}. Skipping.") # Can be too verbose - continue - - box_final, sside_final = self._get_mini_boxes(box_unclipped) - if sside_final < self.min_sz + 2: # min_sz is 3 - # print(f" DEBUG OCR: Contour {i} final size after unclip too small (sside_final {sside_final:.2f} < {self.min_sz + 2}). Skipping.") # Can be too verbose - continue - - box_final_arr = np.array(box_final) - box_final_arr[:, 0] = np.clip(np.round(box_final_arr[:, 0] / w * dw), 0, dw) - box_final_arr[:, 1] = np.clip(np.round(box_final_arr[:, 1] / h * dh), 0, dh) - - boxes.append(box_final_arr.astype("int32")) - scores.append(current_score) - kept_boxes_count +=1 - print(f" DEBUG OCR: _boxes_from_bitmap: Kept {kept_boxes_count} boxes after all filtering (size, score, unclip). Configured box_thresh: {self.box_thresh}, min_sz: {self.min_sz}.") - return np.array(boxes, dtype="int32"), scores - - def _unclip(self, box, ratio): - poly = Polygon(box) - dist = poly.area * ratio / poly.length - offset = pyclipper.PyclipperOffset() - offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) - expanded = offset.Execute(dist) - if not expanded: - raise ValueError("Unclip failed") - return np.array(expanded[0]) - - def _get_mini_boxes(self, contour): - bb = cv2.minAreaRect(contour) - pts = sorted(list(cv2.boxPoints(bb)), key=lambda x: x[0]) - i1, i4 = (0, 1) if pts[1][1] > pts[0][1] else (1, 0) - i2, i3 = (2, 3) if pts[3][1] > pts[2][1] else (3, 2) - box = [pts[i1], pts[i2], pts[i3], pts[i4]] - return box, min(bb[1]) - - def _box_score_fast(self, bmp, box): - h, w = bmp.shape[:2] - xmin = np.clip(np.floor(box[:, 0].min()).astype("int32"), 0, w - 1) - xmax = np.clip(np.ceil(box[:, 0].max()).astype("int32"), 0, w - 1) - ymin = np.clip(np.floor(box[:, 1].min()).astype("int32"), 0, h - 1) - ymax = np.clip(np.ceil(box[:, 1].max()).astype("int32"), 0, h - 1) - mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) - box[:, 0] -= xmin - box[:, 1] -= ymin - cv2.fillPoly(mask, box.reshape(1, -1, 2).astype("int32"), 1) - return cv2.mean(bmp[ymin : ymax + 1, xmin : xmax + 1], mask)[0] if np.sum(mask) > 0 else 0.0 - - def _box_score_slow(self, bmp, contour): # Not used if fast - h, w = bmp.shape[:2] - contour = np.reshape(contour.copy(), (-1, 2)) - xmin = np.clip(np.min(contour[:, 0]), 0, w - 1) - xmax = np.clip(np.max(contour[:, 0]), 0, w - 1) - ymin = np.clip(np.min(contour[:, 1]), 0, h - 1) - ymax = np.clip(np.max(contour[:, 1]), 0, h - 1) - mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) - contour[:, 0] -= xmin - contour[:, 1] -= ymin - cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype("int32"), 1) - return cv2.mean(bmp[ymin : ymax + 1, xmin : xmax + 1], mask)[0] if np.sum(mask) > 0 else 0.0 - - def __call__(self, outs_dict, shape_list): - pred = outs_dict['maps'][:, 0, :, :] - seg = pred > self.thresh - # ADDED: More detailed logging - print(f" DEBUG OCR: _MDR_DBPostProcess: pred map shape: {pred.shape}, seg map shape: {seg.shape}, configured thresh: {self.thresh}") - print(f" DEBUG OCR: _MDR_DBPostProcess: Number of pixels in seg map above threshold (sum of all batches): {np.sum(seg)}") - - boxes_batch = [] - for batch_idx in range(pred.shape[0]): - # MODIFIED: Ensure sh, sw are floats for division if they come from shape_list - sh_orig, sw_orig, rh_ratio, rw_ratio = shape_list[batch_idx] - # The dw, dh for _boxes_from_bitmap should be the original image dimensions before DetResizeForTest - # shape_list contains [src_h, src_w, ratio_h, ratio_w] - # So dw = src_w, dh = src_h - dw_orig, dh_orig = sw_orig, sh_orig - - current_pred_map = pred[batch_idx] - current_seg_map = seg[batch_idx] - - mask = cv2.dilate(np.array(current_seg_map).astype(np.uint8), self.dila_k) if self.dila_k is not None else current_seg_map - print(f" DEBUG OCR: _MDR_DBPostProcess (batch {batch_idx}): Input shape to postproc (orig) {dh_orig:.1f}x{dw_orig:.1f}. Sum of mask pixels: {np.sum(mask)}") - - if self.box_t == 'poly': - boxes, scores = self._polygons_from_bitmap(current_pred_map, mask, dw_orig, dh_orig) - elif self.box_t == 'quad': - boxes, scores = self._boxes_from_bitmap(current_pred_map, mask, dw_orig, dh_orig) # Pass original dimensions - else: - raise ValueError("box_type must be 'quad' or 'poly'") - print(f" DEBUG OCR: _MDR_DBPostProcess (batch {batch_idx}): Found {len(boxes)} boxes from bitmap processing.") - boxes_batch.append({'points': boxes}) - return boxes_batch + def __init__(self, thresh=0.3, box_thresh=0.7, max_candidates=1000, unclip_ratio=1.5, use_dilation=False, + score_mode="fast", box_type='quad', **kwargs): + self.thresh = thresh + self.box_thresh = box_thresh + self.max_cand = max_candidates + self.unclip_r = unclip_ratio + self.min_sz = 3 + self.score_m = score_mode + self.box_t = box_type + assert score_mode in ["slow", "fast"] + self.dila_k = np.array([[1, 1], [1, 1]], dtype=np.uint8) if use_dilation else None + + def _polygons_from_bitmap(self, pred, bmp, dw, dh): + h, w = bmp.shape + boxes, scores = [], [] + contours, _ = cv2.findContours((bmp * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) + for contour in contours[:self.max_cand]: + eps = 0.002 * cv2.arcLength(contour, True) + approx = cv2.approxPolyDP(contour, eps, True) + pts = approx.reshape((-1, 2)) + if pts.shape[0] < 4: + continue + score = self._box_score_fast(pred, pts.reshape(-1, 2)) + if self.box_thresh > score: + continue + try: + box = self._unclip(pts, self.unclip_r) + except: + continue + if len(box) > 1: + continue + box = box.reshape(-1, 2) + _, sside = self._get_mini_boxes(box.reshape((-1, 1, 2))) + if sside < self.min_sz + 2: + continue + box = np.array(box) + box[:, 0] = np.clip(np.round(box[:, 0] / w * dw), 0, dw) + box[:, 1] = np.clip(np.round(box[:, 1] / h * dh), 0, dh) + boxes.append(box.tolist()) + scores.append(score) + return boxes, scores + + # In class _MDR_DBPostProcess: + def _boxes_from_bitmap(self, pred, bmp, dw, dh): # pred is the probability map, bmp is the binarized map + h, w = bmp.shape + # ADDED: More detailed logging + print( + f" DEBUG OCR: _boxes_from_bitmap: Processing bitmap of shape {h}x{w} for original dimensions {dw:.1f}x{dh:.1f}.") + contours, _ = cv2.findContours((bmp * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) + num_contours_found = len(contours) + print(f" DEBUG OCR: _boxes_from_bitmap: Found {num_contours_found} raw contours.") + + num_contours_to_process = min(num_contours_found, self.max_cand) + if num_contours_found > self.max_cand: + print( + f" DEBUG OCR: _boxes_from_bitmap: Processing limited to {self.max_cand} contours (max_candidates).") + + boxes, scores = [], [] + kept_boxes_count = 0 + for i in range(num_contours_to_process): + contour = contours[i] + pts_mini_box, sside = self._get_mini_boxes(contour) + if sside < self.min_sz: + # print(f" DEBUG OCR: Contour {i} too small (sside {sside:.2f} < min_sz {self.min_sz}). Skipping.") # Can be too verbose + continue + + pts_arr = np.array(pts_mini_box) + current_score = self._box_score_fast(pred, pts_arr.reshape(-1, + 2)) if self.score_m == "fast" else self._box_score_slow( + pred, contour) + + if self.box_thresh > current_score: + # print(f" DEBUG OCR: Contour {i} score {current_score:.4f} < box_thresh {self.box_thresh}. Skipping.") # Can be too verbose + continue + + try: + box_unclipped = self._unclip(pts_arr, self.unclip_r).reshape(-1, 1, 2) + except Exception as e_unclip: + # print(f" DEBUG OCR: Contour {i} unclip failed: {e_unclip}. Skipping.") # Can be too verbose + continue + + box_final, sside_final = self._get_mini_boxes(box_unclipped) + if sside_final < self.min_sz + 2: # min_sz is 3 + # print(f" DEBUG OCR: Contour {i} final size after unclip too small (sside_final {sside_final:.2f} < {self.min_sz + 2}). Skipping.") # Can be too verbose + continue + + box_final_arr = np.array(box_final) + box_final_arr[:, 0] = np.clip(np.round(box_final_arr[:, 0] / w * dw), 0, dw) + box_final_arr[:, 1] = np.clip(np.round(box_final_arr[:, 1] / h * dh), 0, dh) + + boxes.append(box_final_arr.astype("int32")) + scores.append(current_score) + kept_boxes_count += 1 + print( + f" DEBUG OCR: _boxes_from_bitmap: Kept {kept_boxes_count} boxes after all filtering (size, score, unclip). Configured box_thresh: {self.box_thresh}, min_sz: {self.min_sz}.") + return np.array(boxes, dtype="int32"), scores + + def _unclip(self, box, ratio): + poly = Polygon(box) + dist = poly.area * ratio / poly.length + offset = pyclipper.PyclipperOffset() + offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) + expanded = offset.Execute(dist) + if not expanded: + raise ValueError("Unclip failed") + return np.array(expanded[0]) + + def _get_mini_boxes(self, contour): + bb = cv2.minAreaRect(contour) + pts = sorted(list(cv2.boxPoints(bb)), key=lambda x: x[0]) + i1, i4 = (0, 1) if pts[1][1] > pts[0][1] else (1, 0) + i2, i3 = (2, 3) if pts[3][1] > pts[2][1] else (3, 2) + box = [pts[i1], pts[i2], pts[i3], pts[i4]] + return box, min(bb[1]) + + def _box_score_fast(self, bmp, box): + h, w = bmp.shape[:2] + xmin = np.clip(np.floor(box[:, 0].min()).astype("int32"), 0, w - 1) + xmax = np.clip(np.ceil(box[:, 0].max()).astype("int32"), 0, w - 1) + ymin = np.clip(np.floor(box[:, 1].min()).astype("int32"), 0, h - 1) + ymax = np.clip(np.ceil(box[:, 1].max()).astype("int32"), 0, h - 1) + mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) + box[:, 0] -= xmin + box[:, 1] -= ymin + cv2.fillPoly(mask, box.reshape(1, -1, 2).astype("int32"), 1) + return cv2.mean(bmp[ymin: ymax + 1, xmin: xmax + 1], mask)[0] if np.sum(mask) > 0 else 0.0 + + def _box_score_slow(self, bmp, contour): # Not used if fast + h, w = bmp.shape[:2] + contour = np.reshape(contour.copy(), (-1, 2)) + xmin = np.clip(np.min(contour[:, 0]), 0, w - 1) + xmax = np.clip(np.max(contour[:, 0]), 0, w - 1) + ymin = np.clip(np.min(contour[:, 1]), 0, h - 1) + ymax = np.clip(np.max(contour[:, 1]), 0, h - 1) + mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) + contour[:, 0] -= xmin + contour[:, 1] -= ymin + cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype("int32"), 1) + return cv2.mean(bmp[ymin: ymax + 1, xmin: xmax + 1], mask)[0] if np.sum(mask) > 0 else 0.0 + + def __call__(self, outs_dict, shape_list): + pred = outs_dict['maps'][:, 0, :, :] + seg = pred > self.thresh + # ADDED: More detailed logging + print( + f" DEBUG OCR: _MDR_DBPostProcess: pred map shape: {pred.shape}, seg map shape: {seg.shape}, configured thresh: {self.thresh}") + print( + f" DEBUG OCR: _MDR_DBPostProcess: Number of pixels in seg map above threshold (sum of all batches): {np.sum(seg)}") + + boxes_batch = [] + for batch_idx in range(pred.shape[0]): + # MODIFIED: Ensure sh, sw are floats for division if they come from shape_list + sh_orig, sw_orig, rh_ratio, rw_ratio = shape_list[batch_idx] + # The dw, dh for _boxes_from_bitmap should be the original image dimensions before DetResizeForTest + # shape_list contains [src_h, src_w, ratio_h, ratio_w] + # So dw = src_w, dh = src_h + dw_orig, dh_orig = sw_orig, sh_orig + + current_pred_map = pred[batch_idx] + current_seg_map = seg[batch_idx] + + mask = cv2.dilate(np.array(current_seg_map).astype(np.uint8), + self.dila_k) if self.dila_k is not None else current_seg_map + print( + f" DEBUG OCR: _MDR_DBPostProcess (batch {batch_idx}): Input shape to postproc (orig) {dh_orig:.1f}x{dw_orig:.1f}. Sum of mask pixels: {np.sum(mask)}") + + if self.box_t == 'poly': + boxes, scores = self._polygons_from_bitmap(current_pred_map, mask, dw_orig, dh_orig) + elif self.box_t == 'quad': + boxes, scores = self._boxes_from_bitmap(current_pred_map, mask, dw_orig, + dh_orig) # Pass original dimensions + else: + raise ValueError("box_type must be 'quad' or 'poly'") + print( + f" DEBUG OCR: _MDR_DBPostProcess (batch {batch_idx}): Found {len(boxes)} boxes from bitmap processing.") + boxes_batch.append({'points': boxes}) + return boxes_batch + class _MDR_TextDetector(_MDR_PredictBase): - def __init__(self, args): - super().__init__() - self.args = args - pre_ops = [{'DetResizeForTest': {'limit_side_len': args.det_limit_side_len, 'limit_type': args.det_limit_type}}, {'NormalizeImage': {'std': [0.229,0.224,0.225], 'mean': [0.485,0.456,0.406], 'scale': '1./255.', 'order': 'hwc'}}, {'ToCHWImage': None}, {'KeepKeys': {'keep_keys': ['image', 'shape']}}] - self.pre_op = mdr_ocr_create_operators(pre_ops) - post_params = {'thresh': args.det_db_thresh, 'box_thresh': args.det_db_box_thresh, 'max_candidates': 1000, 'unclip_ratio': args.det_db_unclip_ratio, 'use_dilation': args.use_dilation, 'score_mode': args.det_db_score_mode, 'box_type': args.det_box_type} - self.post_op = _MDR_DBPostProcess(**post_params) - self.sess = self.get_onnx_session(args.det_model_dir, args.use_gpu) - self.input_name = self.get_input_name(self.sess) - self.output_name = self.get_output_name(self.sess) - - def _order_pts(self, pts): - r = np.zeros((4, 2), dtype="float32") - s = pts.sum(axis=1) - r[0] = pts[np.argmin(s)] - r[2] = pts[np.argmax(s)] - tmp = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0) - d = np.diff(np.array(tmp), axis=1) - r[1] = tmp[np.argmin(d)] - r[3] = tmp[np.argmax(d)] - return r - - def _clip_pts(self, pts, h, w): - pts[:, 0] = np.clip(pts[:, 0], 0, w - 1) - pts[:, 1] = np.clip(pts[:, 1], 0, h - 1) - return pts - - def _filter_quad(self, boxes, shape): - h, w = shape[0:2] - new_boxes = [] - for box in boxes: - box = np.array(box) if isinstance(box, list) else box - box = self._order_pts(box) - box = self._clip_pts(box, h, w) - rw = int(np.linalg.norm(box[0] - box[1])) - rh = int(np.linalg.norm(box[0] - box[3])) - if rw <= 3 or rh <= 3: - continue - new_boxes.append(box) - return np.array(new_boxes) - - def _filter_poly(self, boxes, shape): - h, w = shape[0:2] - new_boxes = [] - for box in boxes: - box = np.array(box) if isinstance(box, list) else box - box = self._clip_pts(box, h, w) - if Polygon(box).area < 10: - continue - new_boxes.append(box) - return np.array(new_boxes) - -# In class _MDR_TextDetector: - def __call__(self, img): - ori_im = img.copy() - data = {"image": img} - print(f" DEBUG OCR: _MDR_TextDetector: Original image shape: {ori_im.shape}") - - # Preprocessing - try: - data = mdr_ocr_transform(data, self.pre_op) - except Exception as e_preproc: - print(f" DEBUG OCR: _MDR_TextDetector: Error during preprocessing (mdr_ocr_transform): {e_preproc}") - import traceback - traceback.print_exc() - return np.array([]) # Return empty array on failure - - if data is None: - print(" DEBUG OCR: _MDR_TextDetector: Preprocessing (mdr_ocr_transform) returned None. No text will be detected.") - return np.array([]) - - processed_img, shape_list = data # shape_list is [src_h, src_w, ratio_h, ratio_w] - if processed_img is None: - print(" DEBUG OCR: _MDR_TextDetector: Processed image after transform is None. No text will be detected.") - return np.array([]) - print(f" DEBUG OCR: _MDR_TextDetector: Processed image shape for ONNX: {processed_img.shape}, shape_list: {shape_list}") - - img_for_onnx = np.expand_dims(processed_img, axis=0) - shape_list_for_onnx = np.expand_dims(shape_list, axis=0) - img_for_onnx = img_for_onnx.copy() - - inputs = self.get_input_feed(self.input_name, img_for_onnx) - print(f" DEBUG OCR: _MDR_TextDetector: Running ONNX inference for text detection...") - try: - outputs = self.sess.run(self.output_name, input_feed=inputs) - except Exception as e_infer: - print(f" DEBUG OCR: _MDR_TextDetector: ONNX inference for detection failed: {e_infer}") - import traceback - traceback.print_exc() - return np.array([]) # Return empty array on failure - print(f" DEBUG OCR: _MDR_TextDetector: ONNX inference done. Output map shape: {outputs[0].shape}") - - preds = {"maps": outputs[0]} - try: - post_res = self.post_op(preds, shape_list_for_onnx) - except Exception as e_postproc: - print(f" DEBUG OCR: _MDR_TextDetector: Error during DBPostProcess: {e_postproc}") - import traceback - traceback.print_exc() - return np.array([]) - - if not post_res or not post_res[0].get('points'): - print(" DEBUG OCR: _MDR_TextDetector: DBPostProcess returned no points.") - return np.array([]) - - boxes_from_post = post_res[0]['points'] - print(f" DEBUG OCR: _MDR_TextDetector: Boxes from DBPostProcess before final filtering: {len(boxes_from_post)}") - - if not isinstance(boxes_from_post, (list, np.ndarray)) or len(boxes_from_post) == 0: # Check if it's empty or not list-like - print(" DEBUG OCR: _MDR_TextDetector: No boxes from DBPostProcess to filter.") - return np.array([]) - - if self.args.det_box_type == 'poly': - final_boxes = self._filter_poly(boxes_from_post, ori_im.shape) - else: # 'quad' - final_boxes = self._filter_quad(boxes_from_post, ori_im.shape) - print(f" DEBUG OCR: _MDR_TextDetector: Boxes after final poly/quad filtering: {len(final_boxes)}") - return final_boxes + def __init__(self, args): + super().__init__() + self.args = args + pre_ops = [{'DetResizeForTest': {'limit_side_len': args.det_limit_side_len, 'limit_type': args.det_limit_type}}, + {'NormalizeImage': {'std': [0.229, 0.224, 0.225], 'mean': [0.485, 0.456, 0.406], 'scale': '1./255.', + 'order': 'hwc'}}, {'ToCHWImage': None}, + {'KeepKeys': {'keep_keys': ['image', 'shape']}}] + self.pre_op = mdr_ocr_create_operators(pre_ops) + post_params = {'thresh': args.det_db_thresh, 'box_thresh': args.det_db_box_thresh, 'max_candidates': 1000, + 'unclip_ratio': args.det_db_unclip_ratio, 'use_dilation': args.use_dilation, + 'score_mode': args.det_db_score_mode, 'box_type': args.det_box_type} + self.post_op = _MDR_DBPostProcess(**post_params) + self.sess = self.get_onnx_session(args.det_model_dir, args.use_gpu) + self.input_name = self.get_input_name(self.sess) + self.output_name = self.get_output_name(self.sess) + + def _order_pts(self, pts): + r = np.zeros((4, 2), dtype="float32") + s = pts.sum(axis=1) + r[0] = pts[np.argmin(s)] + r[2] = pts[np.argmax(s)] + tmp = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0) + d = np.diff(np.array(tmp), axis=1) + r[1] = tmp[np.argmin(d)] + r[3] = tmp[np.argmax(d)] + return r + + def _clip_pts(self, pts, h, w): + pts[:, 0] = np.clip(pts[:, 0], 0, w - 1) + pts[:, 1] = np.clip(pts[:, 1], 0, h - 1) + return pts + + def _filter_quad(self, boxes, shape): + h, w = shape[0:2] + new_boxes = [] + for box in boxes: + box = np.array(box) if isinstance(box, list) else box + box = self._order_pts(box) + box = self._clip_pts(box, h, w) + rw = int(np.linalg.norm(box[0] - box[1])) + rh = int(np.linalg.norm(box[0] - box[3])) + if rw <= 3 or rh <= 3: + continue + new_boxes.append(box) + return np.array(new_boxes) + + def _filter_poly(self, boxes, shape): + h, w = shape[0:2] + new_boxes = [] + for box in boxes: + box = np.array(box) if isinstance(box, list) else box + box = self._clip_pts(box, h, w) + if Polygon(box).area < 10: + continue + new_boxes.append(box) + return np.array(new_boxes) + + # In class _MDR_TextDetector: + def __call__(self, img): + ori_im = img.copy() + data = {"image": img} + print(f" DEBUG OCR: _MDR_TextDetector: Original image shape: {ori_im.shape}") + + # Preprocessing + try: + data = mdr_ocr_transform(data, self.pre_op) + except Exception as e_preproc: + print(f" DEBUG OCR: _MDR_TextDetector: Error during preprocessing (mdr_ocr_transform): {e_preproc}") + import traceback + traceback.print_exc() + return np.array([]) # Return empty array on failure + + if data is None: + print( + " DEBUG OCR: _MDR_TextDetector: Preprocessing (mdr_ocr_transform) returned None. No text will be detected.") + return np.array([]) + + processed_img, shape_list = data # shape_list is [src_h, src_w, ratio_h, ratio_w] + if processed_img is None: + print(" DEBUG OCR: _MDR_TextDetector: Processed image after transform is None. No text will be detected.") + return np.array([]) + print( + f" DEBUG OCR: _MDR_TextDetector: Processed image shape for ONNX: {processed_img.shape}, shape_list: {shape_list}") + + img_for_onnx = np.expand_dims(processed_img, axis=0) + shape_list_for_onnx = np.expand_dims(shape_list, axis=0) + img_for_onnx = img_for_onnx.copy() + + inputs = self.get_input_feed(self.input_name, img_for_onnx) + print(f" DEBUG OCR: _MDR_TextDetector: Running ONNX inference for text detection...") + try: + outputs = self.sess.run(self.output_name, input_feed=inputs) + except Exception as e_infer: + print(f" DEBUG OCR: _MDR_TextDetector: ONNX inference for detection failed: {e_infer}") + import traceback + traceback.print_exc() + return np.array([]) # Return empty array on failure + print(f" DEBUG OCR: _MDR_TextDetector: ONNX inference done. Output map shape: {outputs[0].shape}") + + preds = {"maps": outputs[0]} + try: + post_res = self.post_op(preds, shape_list_for_onnx) + except Exception as e_postproc: + print(f" DEBUG OCR: _MDR_TextDetector: Error during DBPostProcess: {e_postproc}") + import traceback + traceback.print_exc() + return np.array([]) + + if not post_res or not post_res[0].get('points'): + print(" DEBUG OCR: _MDR_TextDetector: DBPostProcess returned no points.") + return np.array([]) + + boxes_from_post = post_res[0]['points'] + print( + f" DEBUG OCR: _MDR_TextDetector: Boxes from DBPostProcess before final filtering: {len(boxes_from_post)}") + + if not isinstance(boxes_from_post, (list, np.ndarray)) or len( + boxes_from_post) == 0: # Check if it's empty or not list-like + print(" DEBUG OCR: _MDR_TextDetector: No boxes from DBPostProcess to filter.") + return np.array([]) + + if self.args.det_box_type == 'poly': + final_boxes = self._filter_poly(boxes_from_post, ori_im.shape) + else: # 'quad' + final_boxes = self._filter_quad(boxes_from_post, ori_im.shape) + print(f" DEBUG OCR: _MDR_TextDetector: Boxes after final poly/quad filtering: {len(final_boxes)}") + return final_boxes + class _MDR_ClsPostProcess: - def __init__(self, label_list=None, **kwargs): self.labels = label_list if label_list else {0:'0', 1:'180'} + def __init__(self, label_list=None, **kwargs): self.labels = label_list if label_list else {0: '0', 1: '180'} + + def __call__(self, preds, label=None, *args, **kwargs): + preds = np.array(preds) if not isinstance(preds, np.ndarray) else preds; + idxs = preds.argmax(axis=1) + return [(self.labels[idx], float(preds[i, idx])) for i, idx in enumerate(idxs)] - def __call__(self, preds, label=None, *args, **kwargs): - preds = np.array(preds) if not isinstance(preds, np.ndarray) else preds; idxs = preds.argmax(axis=1) - return [(self.labels[idx], float(preds[i,idx])) for i,idx in enumerate(idxs)] class _MDR_TextClassifier(_MDR_PredictBase): - def __init__(self, args): - super().__init__() - self.shape = tuple(map(int, args.cls_image_shape.split(','))) if isinstance(args.cls_image_shape, str) else args.cls_image_shape - self.batch_num = args.cls_batch_num - self.thresh = args.cls_thresh - self.post_op = _MDR_ClsPostProcess(label_list=args.label_list) - self.sess = self.get_onnx_session(args.cls_model_dir, args.use_gpu) - self.input_name = self.get_input_name(self.sess) - self.output_name = self.get_output_name(self.sess) - - def _resize_norm(self, img): - imgC, imgH, imgW = self.shape - h, w = img.shape[:2] - r = w / float(h) if h > 0 else 0 - rw = int(ceil(imgH * r)) - rw = min(rw, imgW) - resized = cv2.resize(img, (rw, imgH)) - resized = resized.astype("float32") - if imgC == 1: - resized = resized / 255.0 - resized = resized[np.newaxis, :] - else: - resized = resized.transpose((2, 0, 1)) / 255.0 - resized -= 0.5 - resized /= 0.5 - padding = np.zeros((imgC, imgH, imgW), dtype=np.float32) - padding[:, :, 0:rw] = resized - return padding - - def __call__(self, img_list): - if not img_list: - return img_list, [] - img_list_cp = copy.deepcopy(img_list) - num = len(img_list_cp) - ratios = [img.shape[1] / float(img.shape[0]) if img.shape[0] > 0 else 0 for img in img_list_cp] - indices = np.argsort(np.array(ratios)) - results = [["", 0.0]] * num - batch_n = self.batch_num - for start in range(0, num, batch_n): - end = min(num, start + batch_n) - batch = [] - for i in range(start, end): - batch.append(self._resize_norm(img_list_cp[indices[i]])[np.newaxis, :]) - if not batch: - continue - batch = np.concatenate(batch, axis=0).copy() - inputs = self.get_input_feed(self.input_name, batch) - outputs = self.sess.run(self.output_name, input_feed=inputs) - cls_out = self.post_op(outputs[0]) - for i in range(len(cls_out)): - orig_idx = indices[start + i] - label, score = cls_out[i] - results[orig_idx] = [label, score] - if "180" in label and score > self.thresh: - img_list[orig_idx] = cv2.rotate(img_list[orig_idx], cv2.ROTATE_180) - return img_list, results + def __init__(self, args): + super().__init__() + self.shape = tuple(map(int, args.cls_image_shape.split(','))) if isinstance(args.cls_image_shape, + str) else args.cls_image_shape + self.batch_num = args.cls_batch_num + self.thresh = args.cls_thresh + self.post_op = _MDR_ClsPostProcess(label_list=args.label_list) + self.sess = self.get_onnx_session(args.cls_model_dir, args.use_gpu) + self.input_name = self.get_input_name(self.sess) + self.output_name = self.get_output_name(self.sess) + + def _resize_norm(self, img): + imgC, imgH, imgW = self.shape + h, w = img.shape[:2] + r = w / float(h) if h > 0 else 0 + rw = int(ceil(imgH * r)) + rw = min(rw, imgW) + resized = cv2.resize(img, (rw, imgH)) + resized = resized.astype("float32") + if imgC == 1: + resized = resized / 255.0 + resized = resized[np.newaxis, :] + else: + resized = resized.transpose((2, 0, 1)) / 255.0 + resized -= 0.5 + resized /= 0.5 + padding = np.zeros((imgC, imgH, imgW), dtype=np.float32) + padding[:, :, 0:rw] = resized + return padding + + def __call__(self, img_list): + if not img_list: + return img_list, [] + img_list_cp = copy.deepcopy(img_list) + num = len(img_list_cp) + ratios = [img.shape[1] / float(img.shape[0]) if img.shape[0] > 0 else 0 for img in img_list_cp] + indices = np.argsort(np.array(ratios)) + results = [["", 0.0]] * num + batch_n = self.batch_num + for start in range(0, num, batch_n): + end = min(num, start + batch_n) + batch = [] + for i in range(start, end): + batch.append(self._resize_norm(img_list_cp[indices[i]])[np.newaxis, :]) + if not batch: + continue + batch = np.concatenate(batch, axis=0).copy() + inputs = self.get_input_feed(self.input_name, batch) + outputs = self.sess.run(self.output_name, input_feed=inputs) + cls_out = self.post_op(outputs[0]) + for i in range(len(cls_out)): + orig_idx = indices[start + i] + label, score = cls_out[i] + results[orig_idx] = [label, score] + if "180" in label and score > self.thresh: + img_list[orig_idx] = cv2.rotate(img_list[orig_idx], cv2.ROTATE_180) + return img_list, results + class _MDR_BaseRecLabelDecode: @@ -1095,449 +1226,514 @@ class _MDR_BaseRecLabelDecode: res.append((txt, float(np.mean(conf_l)))) return res + class _MDR_CTCLabelDecode(_MDR_BaseRecLabelDecode): - def __init__(self, char_path=None, use_space=False, **kwargs): super().__init__(char_path, use_space) - def add_special_char(self, chars): return ["blank"]+chars - def get_ignored_tokens(self): return [0] # blank index - def __call__(self, preds, label=None, *args, **kwargs): - preds = preds[-1] if isinstance(preds,(tuple,list)) else preds; preds = np.array(preds) if not isinstance(preds,np.ndarray) else preds - idxs=preds.argmax(axis=2); probs=preds.max(axis=2); txt=self.decode(idxs, probs, remove_dup=True); return txt + def __init__(self, char_path=None, use_space=False, **kwargs): super().__init__(char_path, use_space) + + def add_special_char(self, chars): return ["blank"] + chars + + def get_ignored_tokens(self): return [0] # blank index + + def __call__(self, preds, label=None, *args, **kwargs): + preds = preds[-1] if isinstance(preds, (tuple, list)) else preds; + preds = np.array(preds) if not isinstance(preds, np.ndarray) else preds + idxs = preds.argmax(axis=2); + probs = preds.max(axis=2); + txt = self.decode(idxs, probs, remove_dup=True); + return txt + class _MDR_TextRecognizer(_MDR_PredictBase): - def __init__(self, args): - super().__init__() - shape_str = getattr(args, 'rec_image_shape', "3,48,320") - self.shape = tuple(map(int, shape_str.split(','))) - self.batch_num = getattr(args, 'rec_batch_num', 6) - self.algo = getattr(args, 'rec_algorithm', 'SVTR_LCNet') - self.post_op = _MDR_CTCLabelDecode(char_path=args.rec_char_dict_path, use_space=getattr(args, 'use_space_char', True)) - self.sess = self.get_onnx_session(args.rec_model_dir, args.use_gpu) - self.input_name = self.get_input_name(self.sess) - self.output_name = self.get_output_name(self.sess) - -# In class _MDR_TextRecognizer - def _resize_norm(self, img, max_r): # img is a single crop - imgC, imgH, imgW = self.shape # e.g., (3, 48, 320) - h_orig, w_orig = img.shape[:2] - # ADDED: Log input crop shape - print(f" DEBUG RECOGNIZER: _resize_norm input crop shape: ({h_orig}, {w_orig}), target shape: {self.shape}, max_r_batch: {max_r:.2f}") - - if h_orig == 0 or w_orig == 0: - print(f" DEBUG RECOGNIZER: _resize_norm received zero-dimension crop ({h_orig}x{w_orig}). Returning zeros.") - return np.zeros((imgC, imgH, imgW), dtype=np.float32) - - r_current = w_orig / float(h_orig) - tw = min(imgW, int(ceil(imgH * r_current))) - tw = max(1, tw) - print(f" DEBUG RECOGNIZER: _resize_norm calculated target width (tw): {tw} for target height (imgH): {imgH}") + def __init__(self, args): + super().__init__() + shape_str = getattr(args, 'rec_image_shape', "3,48,320") + self.shape = tuple(map(int, shape_str.split(','))) + self.batch_num = getattr(args, 'rec_batch_num', 6) + self.algo = getattr(args, 'rec_algorithm', 'SVTR_LCNet') + self.post_op = _MDR_CTCLabelDecode(char_path=args.rec_char_dict_path, + use_space=getattr(args, 'use_space_char', True)) + self.sess = self.get_onnx_session(args.rec_model_dir, args.use_gpu) + self.input_name = self.get_input_name(self.sess) + self.output_name = self.get_output_name(self.sess) + + # In class _MDR_TextRecognizer + def _resize_norm(self, img, max_r): # img is a single crop + imgC, imgH, imgW = self.shape # e.g., (3, 48, 320) + h_orig, w_orig = img.shape[:2] + # ADDED: Log input crop shape + print( + f" DEBUG RECOGNIZER: _resize_norm input crop shape: ({h_orig}, {w_orig}), target shape: {self.shape}, max_r_batch: {max_r:.2f}") + + if h_orig == 0 or w_orig == 0: + print( + f" DEBUG RECOGNIZER: _resize_norm received zero-dimension crop ({h_orig}x{w_orig}). Returning zeros.") + return np.zeros((imgC, imgH, imgW), dtype=np.float32) + + r_current = w_orig / float(h_orig) + tw = min(imgW, int(ceil(imgH * r_current))) + tw = max(1, tw) + print(f" DEBUG RECOGNIZER: _resize_norm calculated target width (tw): {tw} for target height (imgH): {imgH}") + + try: + resized = cv2.resize(img, (tw, imgH)) + except cv2.error as e_resize: # Catch specific cv2 error + print( + f" DEBUG RECOGNIZER: _resize_norm cv2.resize failed: {e_resize}. Original shape ({h_orig},{w_orig}), target ({tw},{imgH}). Returning zeros.") + return np.zeros((imgC, imgH, imgW), dtype=np.float32) + except Exception as e_resize_general: # Catch any other unexpected error + print( + f" DEBUG RECOGNIZER: _resize_norm general error during resize: {e_resize_general}. Original shape ({h_orig},{w_orig}), target ({tw},{imgH}). Returning zeros.") + import traceback + traceback.print_exc() + return np.zeros((imgC, imgH, imgW), dtype=np.float32) + + resized = resized.astype("float32") + if imgC == 1 and len(resized.shape) == 3: + resized = cv2.cvtColor(resized, cv2.COLOR_BGR2GRAY) + if len(resized.shape) == 2: + resized = resized[:, :, np.newaxis] # Add channel dim if grayscale + + # Ensure resized has 3 channels if imgC is 3, even if input was grayscale + if imgC == 3 and resized.shape[2] == 1: + resized = cv2.cvtColor(resized, cv2.COLOR_GRAY2BGR) + + resized = resized.transpose((2, 0, 1)) / 255.0 + resized -= 0.5 + resized /= 0.5 + + padding = np.zeros((imgC, imgH, imgW), dtype=np.float32) + padding[:, :, 0:tw] = resized + print(f" DEBUG RECOGNIZER: _resize_norm output padded shape: {padding.shape}") + + # ADDED: Log normalized crop properties + min_px, max_px, mean_px = np.min(padding), np.max(padding), np.mean(padding) + print(f" DEBUG RECOGNIZER: Normalized Crop Properties (before ONNX): " + f"dtype: {padding.dtype}, " + f"MinPx: {min_px:.4f}, " + f"MaxPx: {max_px:.4f}, " + f"MeanPx: {mean_px:.4f}") + if np.all(padding == 0): + print(" DEBUG RECOGNIZER: WARNING - Normalized image is all zeros!") + elif np.abs(max_px - min_px) < 1e-6: # Check if all elements are (close to) the same + print(f" DEBUG RECOGNIZER: WARNING - Normalized image is a constant value: {mean_px:.4f}") + return padding + + def __call__(self, img_list): + if not img_list: + return [] + num = len(img_list) + ratios = [img.shape[1] / float(img.shape[0]) if img.shape[0] > 0 else 0 for img in img_list] + indices = np.argsort(np.array(ratios)) + results = [["", 0.0]] * num + batch_n = self.batch_num + for start in range(0, num, batch_n): + end = min(num, start + batch_n) + batch = [] + max_r_batch = 0 + for i in range(start, end): + h, w = img_list[indices[i]].shape[0:2] + if h > 0: + max_r_batch = max(max_r_batch, w / float(h)) + for i in range(start, end): + batch.append(self._resize_norm(img_list[indices[i]], max_r_batch)[np.newaxis, :]) + if not batch: + continue + batch = np.concatenate(batch, axis=0).copy() + inputs = self.get_input_feed(self.input_name, batch) + outputs = self.sess.run(self.output_name, input_feed=inputs) + rec_out = self.post_op(outputs[0]) + for i in range(len(rec_out)): + results[indices[start + i]] = rec_out[i] + return results - try: - resized = cv2.resize(img, (tw, imgH)) - except cv2.error as e_resize: # Catch specific cv2 error - print(f" DEBUG RECOGNIZER: _resize_norm cv2.resize failed: {e_resize}. Original shape ({h_orig},{w_orig}), target ({tw},{imgH}). Returning zeros.") - return np.zeros((imgC, imgH, imgW), dtype=np.float32) - except Exception as e_resize_general: # Catch any other unexpected error - print(f" DEBUG RECOGNIZER: _resize_norm general error during resize: {e_resize_general}. Original shape ({h_orig},{w_orig}), target ({tw},{imgH}). Returning zeros.") - import traceback - traceback.print_exc() - return np.zeros((imgC, imgH, imgW), dtype=np.float32) - - - resized = resized.astype("float32") - if imgC == 1 and len(resized.shape) == 3: - resized = cv2.cvtColor(resized, cv2.COLOR_BGR2GRAY) - if len(resized.shape) == 2: - resized = resized[:, :, np.newaxis] # Add channel dim if grayscale - - # Ensure resized has 3 channels if imgC is 3, even if input was grayscale - if imgC == 3 and resized.shape[2] == 1: - resized = cv2.cvtColor(resized, cv2.COLOR_GRAY2BGR) - - - resized = resized.transpose((2, 0, 1)) / 255.0 - resized -= 0.5 - resized /= 0.5 - - padding = np.zeros((imgC, imgH, imgW), dtype=np.float32) - padding[:, :, 0:tw] = resized - print(f" DEBUG RECOGNIZER: _resize_norm output padded shape: {padding.shape}") - - # ADDED: Log normalized crop properties - min_px, max_px, mean_px = np.min(padding), np.max(padding), np.mean(padding) - print(f" DEBUG RECOGNIZER: Normalized Crop Properties (before ONNX): " - f"dtype: {padding.dtype}, " - f"MinPx: {min_px:.4f}, " - f"MaxPx: {max_px:.4f}, " - f"MeanPx: {mean_px:.4f}") - if np.all(padding == 0): - print(" DEBUG RECOGNIZER: WARNING - Normalized image is all zeros!") - elif np.abs(max_px - min_px) < 1e-6 : # Check if all elements are (close to) the same - print(f" DEBUG RECOGNIZER: WARNING - Normalized image is a constant value: {mean_px:.4f}") - return padding - - def __call__(self, img_list): - if not img_list: - return [] - num = len(img_list) - ratios = [img.shape[1] / float(img.shape[0]) if img.shape[0] > 0 else 0 for img in img_list] - indices = np.argsort(np.array(ratios)) - results = [["", 0.0]] * num - batch_n = self.batch_num - for start in range(0, num, batch_n): - end = min(num, start + batch_n) - batch = [] - max_r_batch = 0 - for i in range(start, end): - h, w = img_list[indices[i]].shape[0:2] - if h > 0: - max_r_batch = max(max_r_batch, w / float(h)) - for i in range(start, end): - batch.append(self._resize_norm(img_list[indices[i]], max_r_batch)[np.newaxis, :]) - if not batch: - continue - batch = np.concatenate(batch, axis=0).copy() - inputs = self.get_input_feed(self.input_name, batch) - outputs = self.sess.run(self.output_name, input_feed=inputs) - rec_out = self.post_op(outputs[0]) - for i in range(len(rec_out)): - results[indices[start + i]] = rec_out[i] - return results # --- MDR ONNX OCR System --- class _MDR_TextSystem: - def __init__(self, args): - class ArgsObject: # Helper to access dict args with dot notation - def __init__(self, **entries): self.__dict__.update(entries) - if isinstance(args, dict): args = ArgsObject(**args) - self.args = args - self.detector = _MDR_TextDetector(args) - self.recognizer = _MDR_TextRecognizer(args) - self.use_cls = getattr(args, 'use_angle_cls', True) - self.drop_score = getattr(args, 'drop_score', 0.5) - self.classifier = _MDR_TextClassifier(args) if self.use_cls else None - self.crop_idx = 0 - self.save_crop = getattr(args, 'save_crop_res', False) - self.crop_dir = getattr(args, 'crop_res_save_dir', "./output/mdr_crop_res") - - def _sort_boxes(self, boxes): - if boxes is None or len(boxes)==0: return [] - def key(box): min_y=min(p[1] for p in box); min_x=min(p[0] for p in box); return (min_y, min_x) - try: return list(sorted(boxes, key=key)) - except: return list(boxes) # Fallback - -# In class _MDR_TextRecognizer - def _resize_norm(self, img, max_r): # img is a single crop - imgC, imgH, imgW = self.shape # e.g., (3, 48, 320) - h_orig, w_orig = img.shape[:2] - # ADDED: Log input crop shape - print(f" DEBUG RECOGNIZER: _resize_norm input crop shape: ({h_orig}, {w_orig}), target shape: {self.shape}, max_r_batch: {max_r:.2f}") - - if h_orig == 0 or w_orig == 0: - print(f" DEBUG RECOGNIZER: _resize_norm received zero-dimension crop ({h_orig}x{w_orig}). Returning zeros.") - return np.zeros((imgC, imgH, imgW), dtype=np.float32) - - r_current = w_orig / float(h_orig) - tw = min(imgW, int(ceil(imgH * r_current))) - tw = max(1, tw) - print(f" DEBUG RECOGNIZER: _resize_norm calculated target width (tw): {tw} for target height (imgH): {imgH}") + def __init__(self, args): + class ArgsObject: # Helper to access dict args with dot notation + def __init__(self, **entries): self.__dict__.update(entries) + + if isinstance(args, dict): args = ArgsObject(**args) + self.args = args + self.detector = _MDR_TextDetector(args) + self.recognizer = _MDR_TextRecognizer(args) + self.use_cls = getattr(args, 'use_angle_cls', True) + self.drop_score = getattr(args, 'drop_score', 0.5) + self.classifier = _MDR_TextClassifier(args) if self.use_cls else None + self.crop_idx = 0 + self.save_crop = getattr(args, 'save_crop_res', False) + self.crop_dir = getattr(args, 'crop_res_save_dir', "./output/mdr_crop_res") + + # --- START: CORRECTED/ADDED __call__ METHOD --- + def __call__(self, img: np.ndarray) -> tuple[list[np.ndarray], list[tuple[str, float]]]: + """ + Processes an image to detect and recognize text. + Args: + img: A NumPy array representing the image (BGR format). + Returns: + A tuple containing: + - A list of detected text bounding boxes (each box is a NumPy array of 4 points). + - A list of recognition results (each result is a tuple of [text, confidence_score]). + """ + ori_im = img.copy() # Keep original for cropping + + # 1. Detect text boxes using self.detector + # The detector's __call__ method handles its own preprocessing. + # dt_boxes are expected to be in original image coordinates. + dt_boxes: np.ndarray = self.detector(img) # This is an np.ndarray of shape (N, 4, 2) or empty + print( + f" DEBUG TextSystem: Detector found {len(dt_boxes) if dt_boxes is not None and dt_boxes.size > 0 else 0} initial boxes.") + + if dt_boxes is None or dt_boxes.size == 0: # Check if array is empty + return [], [] + + # 2. Sort boxes (typically top-to-bottom, left-to-right) + dt_boxes_sorted: list[np.ndarray] = self._sort_boxes(dt_boxes) + print(f" DEBUG TextSystem: Sorted {len(dt_boxes_sorted)} boxes.") + + if not dt_boxes_sorted: # If sorting resulted in empty list (e.g. due to unexpected format) + return [], [] + + # 3. Get cropped images from detected boxes + img_crop_list: list[np.ndarray] = [] + for i in range(len(dt_boxes_sorted)): + # dt_boxes_sorted[i] is a single box (e.g., 4x2 array of points) + crop_im = mdr_get_rotated_crop(ori_im, dt_boxes_sorted[i]) + img_crop_list.append(crop_im) + print(f" DEBUG TextSystem: Created {len(img_crop_list)} crops for further processing.") + + # 4. (Optional) Classify text orientation and rotate crops if necessary + # The classifier's __call__ method handles its own preprocessing and modifies img_crop_list in place. + if self.use_cls and self.classifier is not None and img_crop_list: + print(f" DEBUG TextSystem: Applying text classification for {len(img_crop_list)} crops.") + img_crop_list, cls_results = self.classifier(img_crop_list) # classifier updates img_crop_list + print(f" DEBUG TextSystem: Classification complete. {len(cls_results if cls_results else [])} results.") + + # 5. Recognize text in the (potentially rotated) cropped images + # The recognizer's __call__ method handles its own preprocessing. + rec_results: list[tuple[str, float]] = [] + if img_crop_list: + print(f" DEBUG TextSystem: Recognizing text for {len(img_crop_list)} crops.") + rec_results = self.recognizer(img_crop_list) + else: + print(f" DEBUG TextSystem: No crops to recognize.") + + # 6. Filter results + final_boxes_to_return: list[np.ndarray] = [] + final_recs_to_return: list[tuple[str, float]] = [] + final_crops_for_saving: list[np.ndarray] = [] + + if rec_results and len(rec_results) == len(dt_boxes_sorted) and len(rec_results) == len(img_crop_list): + for i in range(len(rec_results)): + text, confidence = rec_results[i] + if confidence >= self.drop_score and text and not mdr_is_whitespace(text): + final_boxes_to_return.append(dt_boxes_sorted[i]) + final_recs_to_return.append(rec_results[i]) + if self.save_crop: + final_crops_for_saving.append(img_crop_list[i]) + else: + print(f" DEBUG TextSystem: Warning - Mismatch or empty rec_results. " + f"len(rec_results)={len(rec_results) if rec_results else 'None'}, " + f"len(dt_boxes_sorted)={len(dt_boxes_sorted)}, " + f"len(img_crop_list)={len(img_crop_list)}. No results will be returned from this stage.") + # Do not return here, allow empty lists to propagate if that's the case + + print(f" DEBUG TextSystem: Kept {len(final_boxes_to_return)} boxes after recognition and filtering.") + + # 7. (Optional) Save cropped images that passed all filters + if self.save_crop and final_crops_for_saving: + print(f" DEBUG TextSystem: Saving {len(final_crops_for_saving)} filtered crops.") + self._save_crops(final_crops_for_saving, final_recs_to_return) + + return final_boxes_to_return, final_recs_to_return + + # --- END: CORRECTED/ADDED __call__ METHOD --- + + def _sort_boxes(self, boxes): + if boxes is None or len(boxes) == 0: return [] + + def key(box): + min_y = min(p[1] for p in box); min_x = min(p[0] for p in box); return (min_y, min_x) + + try: + return list(sorted(boxes, key=key)) + except: + return list(boxes) # Fallback + + def _save_crops(self, crops, recs): + mdr_ensure_directory(self.crop_dir) + num = len(crops) + for i in range(num): + txt, score = recs[i] + safe = re.sub(r'\W+', '_', txt)[:20] + fname = f"crop_{self.crop_idx + i}_{safe}_{score:.2f}.jpg" + cv2.imwrite(os.path.join(self.crop_dir, fname), crops[i]) + self.crop_idx += num - try: - resized = cv2.resize(img, (tw, imgH)) - except cv2.error as e_resize: # Catch specific cv2 error - print(f" DEBUG RECOGNIZER: _resize_norm cv2.resize failed: {e_resize}. Original shape ({h_orig},{w_orig}), target ({tw},{imgH}). Returning zeros.") - return np.zeros((imgC, imgH, imgW), dtype=np.float32) - except Exception as e_resize_general: # Catch any other unexpected error - print(f" DEBUG RECOGNIZER: _resize_norm general error during resize: {e_resize_general}. Original shape ({h_orig},{w_orig}), target ({tw},{imgH}). Returning zeros.") - import traceback - traceback.print_exc() - return np.zeros((imgC, imgH, imgW), dtype=np.float32) - - - resized = resized.astype("float32") - if imgC == 1 and len(resized.shape) == 3: - resized = cv2.cvtColor(resized, cv2.COLOR_BGR2GRAY) - if len(resized.shape) == 2: - resized = resized[:, :, np.newaxis] # Add channel dim if grayscale - - # Ensure resized has 3 channels if imgC is 3, even if input was grayscale - if imgC == 3 and resized.shape[2] == 1: - resized = cv2.cvtColor(resized, cv2.COLOR_GRAY2BGR) - - - resized = resized.transpose((2, 0, 1)) / 255.0 - resized -= 0.5 - resized /= 0.5 - - padding = np.zeros((imgC, imgH, imgW), dtype=np.float32) - padding[:, :, 0:tw] = resized - print(f" DEBUG RECOGNIZER: _resize_norm output padded shape: {padding.shape}") - - # ADDED: Log normalized crop properties - min_px, max_px, mean_px = np.min(padding), np.max(padding), np.mean(padding) - print(f" DEBUG RECOGNIZER: Normalized Crop Properties (before ONNX): " - f"dtype: {padding.dtype}, " - f"MinPx: {min_px:.4f}, " - f"MaxPx: {max_px:.4f}, " - f"MeanPx: {mean_px:.4f}") - if np.all(padding == 0): - print(" DEBUG RECOGNIZER: WARNING - Normalized image is all zeros!") - elif np.abs(max_px - min_px) < 1e-6 : # Check if all elements are (close to) the same - print(f" DEBUG RECOGNIZER: WARNING - Normalized image is a constant value: {mean_px:.4f}") - return padding - - def _save_crops(self, crops, recs): - mdr_ensure_directory(self.crop_dir) - num = len(crops) - for i in range(num): - txt, score = recs[i] - safe = re.sub(r'\W+', '_', txt)[:20] - fname = f"crop_{self.crop_idx + i}_{safe}_{score:.2f}.jpg" - cv2.imwrite(os.path.join(self.crop_dir, fname), crops[i]) - self.crop_idx += num # --- MDR ONNX OCR Utilities --- def mdr_get_rotated_crop(img, points): - """Crops and perspective-transforms a quadrilateral region.""" - pts = np.array(points, dtype="float32") - assert len(pts) == 4 - w = int(max(np.linalg.norm(pts[0] - pts[1]), np.linalg.norm(pts[2] - pts[3]))) - h = int(max(np.linalg.norm(pts[0] - pts[3]), np.linalg.norm(pts[1] - pts[2]))) - std = np.float32([[0, 0], [w, 0], [w, h], [0, h]]) - M = cv2.getPerspectiveTransform(pts, std) - dst = cv2.warpPerspective(img, M, (w, h), borderMode=cv2.BORDER_REPLICATE, flags=cv2.INTER_CUBIC) - dh, dw = dst.shape[0:2] - if dh > 0 and dw > 0 and dh * 1.0 / dw >= 1.5: - dst = cv2.rotate(dst, cv2.ROTATE_90_CLOCKWISE) - return dst + """Crops and perspective-transforms a quadrilateral region.""" + pts = np.array(points, dtype="float32") + assert len(pts) == 4 + w = int(max(np.linalg.norm(pts[0] - pts[1]), np.linalg.norm(pts[2] - pts[3]))) + h = int(max(np.linalg.norm(pts[0] - pts[3]), np.linalg.norm(pts[1] - pts[2]))) + std = np.float32([[0, 0], [w, 0], [w, h], [0, h]]) + M = cv2.getPerspectiveTransform(pts, std) + dst = cv2.warpPerspective(img, M, (w, h), borderMode=cv2.BORDER_REPLICATE, flags=cv2.INTER_CUBIC) + dh, dw = dst.shape[0:2] + if dh > 0 and dw > 0 and dh * 1.0 / dw >= 1.5: + dst = cv2.rotate(dst, cv2.ROTATE_90_CLOCKWISE) + return dst + def mdr_get_min_area_crop(img, points): - """Crops the minimum area rectangle containing the points.""" - bb = cv2.minAreaRect(np.array(points).astype(np.int32)) - box_pts = cv2.boxPoints(bb) - return mdr_get_rotated_crop(img, box_pts) + """Crops the minimum area rectangle containing the points.""" + bb = cv2.minAreaRect(np.array(points).astype(np.int32)) + box_pts = cv2.boxPoints(bb) + return mdr_get_rotated_crop(img, box_pts) + # --- MDR Layout Processing --- _MDR_INCLUDES_MIN_RATE = 0.99 + class _MDR_OverlapMatrixContext: - def __init__(self, layouts: list[MDRLayoutElement]): - length = len(layouts); self.polys: list[Polygon|None] = [] - for l in layouts: - try: p = Polygon(l.rect); self.polys.append(p if p.is_valid else None) - except: self.polys.append(None) - self.matrix = [[0.0]*length for _ in range(length)]; self.removed = set() - for i in range(length): - p1 = self.polys[i]; - if p1 is None: continue; self.matrix[i][i] = 1.0 - for j in range(i+1, length): - p2 = self.polys[j]; - if p2 is None: continue - r_ij = self._rate(p1, p2); r_ji = self._rate(p2, p1); self.matrix[i][j]=r_ij; self.matrix[j][i]=r_ji - - def _rate(self, p1: Polygon, p2: Polygon) -> float: # Rate p1 covers p2 - try: - inter = p1.intersection(p2) - except: - return 0.0 - if inter.is_empty or inter.area < 1e-6: - return 0.0 - _, _, ix1, iy1 = inter.bounds - iw = ix1 - inter.bounds[0] - ih = iy1 - inter.bounds[1] - _, _, px1, py1 = p2.bounds - pw = px1 - p2.bounds[0] - ph = py1 - p2.bounds[1] - if pw < 1e-6 or ph < 1e-6: - return 0.0 - wr = min(iw / pw, 1.0) - hr = min(ih / ph, 1.0) - return (wr + hr) / 2.0 - - def others(self, idx: int): - for i, r in enumerate(self.matrix[idx]): - if i != idx and i not in self.removed: yield r - - def includes(self, idx: int): # Layouts included BY idx - for i, r in enumerate(self.matrix[idx]): - if i != idx and i not in self.removed and r >= _MDR_INCLUDES_MIN_RATE: - if self.matrix[i][idx] < _MDR_INCLUDES_MIN_RATE: yield i + def __init__(self, layouts: list[MDRLayoutElement]): + length = len(layouts); + self.polys: list[Polygon | None] = [] + for l in layouts: + try: + p = Polygon(l.rect); self.polys.append(p if p.is_valid else None) + except: + self.polys.append(None) + self.matrix = [[0.0] * length for _ in range(length)]; + self.removed = set() + for i in range(length): + p1 = self.polys[i]; + if p1 is None: continue; self.matrix[i][i] = 1.0 + for j in range(i + 1, length): + p2 = self.polys[j]; + if p2 is None: continue + r_ij = self._rate(p1, p2); + r_ji = self._rate(p2, p1); + self.matrix[i][j] = r_ij; + self.matrix[j][i] = r_ji + + def _rate(self, p1: Polygon, p2: Polygon) -> float: # Rate p1 covers p2 + try: + inter = p1.intersection(p2) + except: + return 0.0 + if inter.is_empty or inter.area < 1e-6: + return 0.0 + _, _, ix1, iy1 = inter.bounds + iw = ix1 - inter.bounds[0] + ih = iy1 - inter.bounds[1] + _, _, px1, py1 = p2.bounds + pw = px1 - p2.bounds[0] + ph = py1 - p2.bounds[1] + if pw < 1e-6 or ph < 1e-6: + return 0.0 + wr = min(iw / pw, 1.0) + hr = min(ih / ph, 1.0) + return (wr + hr) / 2.0 + + def others(self, idx: int): + for i, r in enumerate(self.matrix[idx]): + if i != idx and i not in self.removed: yield r + + def includes(self, idx: int): # Layouts included BY idx + for i, r in enumerate(self.matrix[idx]): + if i != idx and i not in self.removed and r >= _MDR_INCLUDES_MIN_RATE: + if self.matrix[i][idx] < _MDR_INCLUDES_MIN_RATE: yield i + def mdr_remove_overlap_layouts(layouts: list[MDRLayoutElement]) -> list[MDRLayoutElement]: - if not layouts: - return [] - ctx = _MDR_OverlapMatrixContext(layouts) - prev_removed = -1 - while len(ctx.removed) != prev_removed: - prev_removed = len(ctx.removed) - current_removed = set() - for i in range(len(layouts)): - if i in ctx.removed or i in current_removed: - continue - li = layouts[i] - pi = ctx.polys[i] - if pi is None: - current_removed.add(i) - continue - contained = False - for j in range(len(layouts)): - if i == j or j in ctx.removed or j in current_removed: - continue - if ctx.matrix[j][i] >= _MDR_INCLUDES_MIN_RATE and ctx.matrix[i][j] < _MDR_INCLUDES_MIN_RATE: - contained = True - break - if contained: - current_removed.add(i) - continue - contained_by_i = list(ctx.includes(i)) - if contained_by_i: - for j in contained_by_i: - if j not in ctx.removed and j not in current_removed: - li.fragments.extend(layouts[j].fragments) - current_removed.add(j) - li.fragments.sort(key=lambda f: (f.rect.lt[1], f.rect.lt[0])) - ctx.removed.update(current_removed) - return [l for i, l in enumerate(layouts) if i not in ctx.removed] + if not layouts: + return [] + ctx = _MDR_OverlapMatrixContext(layouts) + prev_removed = -1 + while len(ctx.removed) != prev_removed: + prev_removed = len(ctx.removed) + current_removed = set() + for i in range(len(layouts)): + if i in ctx.removed or i in current_removed: + continue + li = layouts[i] + pi = ctx.polys[i] + if pi is None: + current_removed.add(i) + continue + contained = False + for j in range(len(layouts)): + if i == j or j in ctx.removed or j in current_removed: + continue + if ctx.matrix[j][i] >= _MDR_INCLUDES_MIN_RATE and ctx.matrix[i][j] < _MDR_INCLUDES_MIN_RATE: + contained = True + break + if contained: + current_removed.add(i) + continue + contained_by_i = list(ctx.includes(i)) + if contained_by_i: + for j in contained_by_i: + if j not in ctx.removed and j not in current_removed: + li.fragments.extend(layouts[j].fragments) + current_removed.add(j) + li.fragments.sort(key=lambda f: (f.rect.lt[1], f.rect.lt[0])) + ctx.removed.update(current_removed) + return [l for i, l in enumerate(layouts) if i not in ctx.removed] + def _mdr_split_fragments_into_lines(frags: list[MDROcrFragment]) -> Generator[list[MDROcrFragment], None, None]: - if not frags: - return - frags.sort(key=lambda f: (f.rect.lt[1], f.rect.lt[0])) - group, y_sum, h_sum = [], 0.0, 0.0 - for f in frags: - _, y1, _, y2 = f.rect.wrapper - h = y2 - y1 - med_y = (y1 + y2) / 2.0 - if h <= 0: - continue - if not group: - group.append(f) - y_sum, h_sum = med_y, h - else: - g_len = len(group) - avg_med_y = y_sum / g_len - avg_h = h_sum / g_len - max_dev = avg_h * 0.40 - if abs(med_y - avg_med_y) > max_dev: - yield group - group, y_sum, h_sum = [f], med_y, h - else: - group.append(f) - y_sum += med_y - h_sum += h - if group: - yield group + if not frags: + return + frags.sort(key=lambda f: (f.rect.lt[1], f.rect.lt[0])) + group, y_sum, h_sum = [], 0.0, 0.0 + for f in frags: + _, y1, _, y2 = f.rect.wrapper + h = y2 - y1 + med_y = (y1 + y2) / 2.0 + if h <= 0: + continue + if not group: + group.append(f) + y_sum, h_sum = med_y, h + else: + g_len = len(group) + avg_med_y = y_sum / g_len + avg_h = h_sum / g_len + max_dev = avg_h * 0.40 + if abs(med_y - avg_med_y) > max_dev: + yield group + group, y_sum, h_sum = [f], med_y, h + else: + group.append(f) + y_sum += med_y + h_sum += h + if group: + yield group + def mdr_merge_fragments_into_lines(orig_frags: list[MDROcrFragment]) -> list[MDROcrFragment]: - merged = [] - for group in _mdr_split_fragments_into_lines(orig_frags): - if not group: - continue - if len(group) == 1: - merged.append(group[0]) - continue - group.sort(key=lambda f: f.rect.lt[0]) - min_order = min(f.order for f in group if hasattr(f, 'order')) if group else 0 - texts, rank_w, txt_len = [], 0.0, 0 - x1, y1, x2, y2 = float("inf"), float("inf"), float("-inf"), float("-inf") - for f in group: - fx1, fy1, fx2, fy2 = f.rect.wrapper - x1, y1, x2, y2 = min(x1, fx1), min(y1, fy1), max(x2, fx2), max(y2, fy2) - t = f.text - l = len(t) - if l > 0: - texts.append(t) - rank_w += f.rank * l - txt_len += l - if txt_len == 0: - continue - m_txt = " ".join(texts) - m_rank = rank_w / txt_len if txt_len > 0 else 0.0 - m_rect = MDRRectangle(lt=(x1, y1), rt=(x2, y1), lb=(x1, y2), rb=(x2, y2)) - merged.append(MDROcrFragment(order=min_order, text=m_txt, rank=m_rank, rect=m_rect)) - merged.sort(key=lambda f: (f.order, f.rect.lt[1], f.rect.lt[0])) - for i, f in enumerate(merged): - f.order = i - return merged + merged = [] + for group in _mdr_split_fragments_into_lines(orig_frags): + if not group: + continue + if len(group) == 1: + merged.append(group[0]) + continue + group.sort(key=lambda f: f.rect.lt[0]) + min_order = min(f.order for f in group if hasattr(f, 'order')) if group else 0 + texts, rank_w, txt_len = [], 0.0, 0 + x1, y1, x2, y2 = float("inf"), float("inf"), float("-inf"), float("-inf") + for f in group: + fx1, fy1, fx2, fy2 = f.rect.wrapper + x1, y1, x2, y2 = min(x1, fx1), min(y1, fy1), max(x2, fx2), max(y2, fy2) + t = f.text + l = len(t) + if l > 0: + texts.append(t) + rank_w += f.rank * l + txt_len += l + if txt_len == 0: + continue + m_txt = " ".join(texts) + m_rank = rank_w / txt_len if txt_len > 0 else 0.0 + m_rect = MDRRectangle(lt=(x1, y1), rt=(x2, y1), lb=(x1, y2), rb=(x2, y2)) + merged.append(MDROcrFragment(order=min_order, text=m_txt, rank=m_rank, rect=m_rect)) + merged.sort(key=lambda f: (f.order, f.rect.lt[1], f.rect.lt[0])) + for i, f in enumerate(merged): + f.order = i + return merged + # --- MDR Layout Processing --- _MDR_CORRECTION_MIN_OVERLAP = 0.5 + def mdr_correct_layout_fragments(ocr_engine: 'MDROcrEngine', source_img: Image, layout: MDRLayoutElement): - if not layout.fragments: - return - try: - x1, y1, x2, y2 = layout.rect.wrapper - margin = 5 - crop_box = (max(0, round(x1) - margin), max(0, round(y1) - margin), min(source_img.width, round(x2) + margin), min(source_img.height, round(y2) + margin)) - if crop_box[0] >= crop_box[2] or crop_box[1] >= crop_box[3]: + if not layout.fragments: return - cropped = source_img.crop(crop_box) - off_x, off_y = crop_box[0], crop_box[1] - except Exception as e: - print(f"Correct: Crop error: {e}") - return - try: - cropped_np = np.array(cropped.convert("RGB"))[:, :, ::-1] - new_frags_local = list(ocr_engine.find_text_fragments(cropped_np)) - except Exception as e: - print(f"Correct: OCR error: {e}") - return - new_frags_global = [] - for f in new_frags_local: - r = f.rect - lt, rt, lb, rb = r.lt, r.rt, r.lb, r.rb - f.rect = MDRRectangle(lt=(lt[0] + off_x, lt[1] + off_y), rt=(rt[0] + off_x, rt[1] + off_y), lb=(lb[0] + off_x, lb[1] + off_y), rb=(rb[0] + off_x, rb[1] + off_y)) - new_frags_global.append(f) - orig_frags = layout.fragments - matched, unmatched_orig = [], [] - used_new = set() - for i, orig_f in enumerate(orig_frags): - best_j, best_rate = -1, -1.0 try: - poly_o = Polygon(orig_f.rect) - except: - continue - if not poly_o.is_valid: - continue - for j, new_f in enumerate(new_frags_global): - if j in used_new: - continue - try: - poly_n = Polygon(new_f.rect) - except: - continue - if not poly_n.is_valid: - continue - try: - inter = poly_o.intersection(poly_n) - union = poly_o.union(poly_n) - except: - continue - rate = inter.area / union.area if union.area > 1e-6 else 0.0 - if rate > _MDR_CORRECTION_MIN_OVERLAP and rate > best_rate: - best_rate = rate - best_j = j - if best_j != -1: - matched.append((orig_f, new_frags_global[best_j])) - used_new.add(best_j) - else: - unmatched_orig.append(orig_f) - unmatched_new = [f for j, f in enumerate(new_frags_global) if j not in used_new] - final = [n if n.rank >= o.rank else o for o, n in matched] - final.extend(unmatched_orig) - final.extend(unmatched_new) - layout.fragments = final - layout.fragments.sort(key=lambda f: (f.rect.lt[1], f.rect.lt[0])) + x1, y1, x2, y2 = layout.rect.wrapper + margin = 5 + crop_box = (max(0, round(x1) - margin), max(0, round(y1) - margin), min(source_img.width, round(x2) + margin), + min(source_img.height, round(y2) + margin)) + if crop_box[0] >= crop_box[2] or crop_box[1] >= crop_box[3]: + return + cropped = source_img.crop(crop_box) + off_x, off_y = crop_box[0], crop_box[1] + except Exception as e: + print(f"Correct: Crop error: {e}") + return + try: + cropped_np = np.array(cropped.convert("RGB"))[:, :, ::-1] + new_frags_local = list(ocr_engine.find_text_fragments(cropped_np)) + except Exception as e: + print(f"Correct: OCR error: {e}") + return + new_frags_global = [] + for f in new_frags_local: + r = f.rect + lt, rt, lb, rb = r.lt, r.rt, r.lb, r.rb + f.rect = MDRRectangle(lt=(lt[0] + off_x, lt[1] + off_y), rt=(rt[0] + off_x, rt[1] + off_y), + lb=(lb[0] + off_x, lb[1] + off_y), rb=(rb[0] + off_x, rb[1] + off_y)) + new_frags_global.append(f) + orig_frags = layout.fragments + matched, unmatched_orig = [], [] + used_new = set() + for i, orig_f in enumerate(orig_frags): + best_j, best_rate = -1, -1.0 + try: + poly_o = Polygon(orig_f.rect) + except: + continue + if not poly_o.is_valid: + continue + for j, new_f in enumerate(new_frags_global): + if j in used_new: + continue + try: + poly_n = Polygon(new_f.rect) + except: + continue + if not poly_n.is_valid: + continue + try: + inter = poly_o.intersection(poly_n) + union = poly_o.union(poly_n) + except: + continue + rate = inter.area / union.area if union.area > 1e-6 else 0.0 + if rate > _MDR_CORRECTION_MIN_OVERLAP and rate > best_rate: + best_rate = rate + best_j = j + if best_j != -1: + matched.append((orig_f, new_frags_global[best_j])) + used_new.add(best_j) + else: + unmatched_orig.append(orig_f) + unmatched_new = [f for j, f in enumerate(new_frags_global) if j not in used_new] + final = [n if n.rank >= o.rank else o for o, n in matched] + final.extend(unmatched_orig) + final.extend(unmatched_new) + layout.fragments = final + layout.fragments.sort(key=lambda f: (f.rect.lt[1], f.rect.lt[0])) + # --- MDR OCR Engine --- -_MDR_OCR_MODELS = {"det": ("ppocrv4","det","det.onnx"), "cls": ("ppocrv4","cls","cls.onnx"), "rec": ("ppocrv4","rec","rec.onnx"), "keys": ("ch_ppocr_server_v2.0","ppocr_keys_v1.txt")} +_MDR_OCR_MODELS = {"det": ("ppocrv4", "det", "det.onnx"), "cls": ("ppocrv4", "cls", "cls.onnx"), + "rec": ("ppocrv4", "rec", "rec.onnx"), "keys": ("ch_ppocr_server_v2.0", "ppocr_keys_v1.txt")} _MDR_OCR_URL_BASE = "https://huggingface.co/moskize/OnnxOCR/resolve/main/" + @dataclass class _MDR_ONNXParams: # Attributes without default values @@ -1578,122 +1774,144 @@ class _MDR_ONNXParams: show_log: bool = False use_onnx: bool = True -class MDROcrEngine: - """Handles OCR detection and recognition using ONNX models.""" - - def __init__(self, device: Literal["cpu", "cuda"], model_dir_path: str): - self._device = device; self._model_dir = mdr_ensure_directory(model_dir_path) - self._text_system: _MDR_TextSystem | None = None; self._onnx_params: _MDR_ONNXParams | None = None - self._ensure_models(); self._get_system() # Init on creation - - def _ensure_models(self): - for key, parts in _MDR_OCR_MODELS.items(): - fp = Path(self._model_dir) / Path(*parts) - if not fp.exists(): print(f"Downloading MDR OCR model: {fp.name}..."); url = _MDR_OCR_URL_BASE + "/".join(parts); mdr_download_model(url, fp) - - def _get_system(self) -> _MDR_TextSystem | None: - if self._text_system is None: - paths = {k: str(Path(self._model_dir)/Path(*p)) for k,p in _MDR_OCR_MODELS.items()} - # In MDROcrEngine._get_system() - self._onnx_params = _MDR_ONNXParams( - use_gpu=(self._device=="cuda"), - det_model_dir=paths["det"], - cls_model_dir=paths["cls"], - rec_model_dir=paths["rec"], - rec_char_dict_path=paths["keys"], - # much lower thresholds so we actually get some candidate masks: - det_db_thresh=0.1, - det_db_box_thresh=0.3, - drop_score=0.1, - use_angle_cls=False, - ) - try: self._text_system = _MDR_TextSystem(self._onnx_params); print(f"MDR OCR System initialized.") - except Exception as e: print(f"ERROR initializing MDR OCR System: {e}"); self._text_system = None - return self._text_system - -# In class MDROcrEngine: - def find_text_fragments(self, image_np: np.ndarray) -> Generator[MDROcrFragment, None, None]: - """Finds and recognizes text fragments in a NumPy image (BGR).""" - system = self._get_system() - if system is None: - print(" DEBUG OCR Engine: MDR OCR System unavailable. No fragments will be found.") - return - - img_for_system = self._preprocess(image_np) - print(f" DEBUG OCR Engine: Image preprocessed for TextSystem. Shape: {img_for_system.shape}") - try: - boxes, recs = system(img_for_system) - except Exception as e: - print(f" DEBUG OCR Engine: Error during TextSystem prediction: {e}") - import traceback - traceback.print_exc() - return +class MDROcrEngine: + """Handles OCR detection and recognition using ONNX models.""" + + def __init__(self, device: Literal["cpu", "cuda"], model_dir_path: str): + self._device = device; + self._model_dir = mdr_ensure_directory(model_dir_path) + self._text_system: _MDR_TextSystem | None = None; + self._onnx_params: _MDR_ONNXParams | None = None + self._ensure_models(); + self._get_system() # Init on creation + + def _ensure_models(self): + for key, parts in _MDR_OCR_MODELS.items(): + fp = Path(self._model_dir) / Path(*parts) + if not fp.exists(): print(f"Downloading MDR OCR model: {fp.name}..."); url = _MDR_OCR_URL_BASE + "/".join( + parts); mdr_download_model(url, fp) + + def _get_system(self) -> _MDR_TextSystem | None: + if self._text_system is None: + paths = {k: str(Path(self._model_dir) / Path(*p)) for k, p in _MDR_OCR_MODELS.items()} + # In MDROcrEngine._get_system() + self._onnx_params = _MDR_ONNXParams( + use_gpu=(self._device == "cuda"), + det_model_dir=paths["det"], + cls_model_dir=paths["cls"], + rec_model_dir=paths["rec"], + rec_char_dict_path=paths["keys"], + # much lower thresholds so we actually get some candidate masks: + det_db_thresh=0.1, + det_db_box_thresh=0.3, + drop_score=0.1, + use_angle_cls=False, + ) + try: + self._text_system = _MDR_TextSystem(self._onnx_params); print(f"MDR OCR System initialized.") + except Exception as e: + print(f"ERROR initializing MDR OCR System: {e}"); self._text_system = None + return self._text_system - if not boxes or not recs: - print(f" DEBUG OCR Engine: TextSystem returned no boxes ({len(boxes) if boxes is not None else 'None'}) or no recs ({len(recs) if recs is not None else 'None'}). No fragments generated.") - return + # In class MDROcrEngine: + def find_text_fragments(self, image_np: np.ndarray) -> Generator[MDROcrFragment, None, None]: + """Finds and recognizes text fragments in a NumPy image (BGR).""" + system = self._get_system() + if system is None: + print(" DEBUG OCR Engine: MDR OCR System unavailable. No fragments will be found.") + return - if len(boxes) != len(recs): - print(f" DEBUG OCR Engine: Mismatch between boxes ({len(boxes)}) and recs ({len(recs)}) from TextSystem. This is problematic. No fragments generated.") - return + img_for_system = self._preprocess(image_np) + print(f" DEBUG OCR Engine: Image preprocessed for TextSystem. Shape: {img_for_system.shape}") - print(f" DEBUG OCR Engine: TextSystem returned {len(boxes)} boxes and {len(recs)} recognition results. Converting to MDROcrFragment.") - fragments_generated_count = 0 - for i, (box_pts, rec_tuple) in enumerate(zip(boxes, recs)): - if not isinstance(rec_tuple, (list, tuple)) or len(rec_tuple) != 2: - print(f" DEBUG OCR Engine: Rec item {i} is not a valid (text, score) tuple: {rec_tuple}. Skipping.") - continue + try: + boxes, recs = system(img_for_system) + except Exception as e: + print(f" DEBUG OCR Engine: Error during TextSystem prediction: {e}") + import traceback + traceback.print_exc() + return + + if not boxes or not recs: + print( + f" DEBUG OCR Engine: TextSystem returned no boxes ({len(boxes) if boxes is not None else 'None'}) or no recs ({len(recs) if recs is not None else 'None'}). No fragments generated.") + return + + if len(boxes) != len(recs): + print( + f" DEBUG OCR Engine: Mismatch between boxes ({len(boxes)}) and recs ({len(recs)}) from TextSystem. This is problematic. No fragments generated.") + return + + print( + f" DEBUG OCR Engine: TextSystem returned {len(boxes)} boxes and {len(recs)} recognition results. Converting to MDROcrFragment.") + fragments_generated_count = 0 + for i, (box_pts, rec_tuple) in enumerate(zip(boxes, recs)): + if not isinstance(rec_tuple, (list, tuple)) or len(rec_tuple) != 2: + print(f" DEBUG OCR Engine: Rec item {i} is not a valid (text, score) tuple: {rec_tuple}. Skipping.") + continue - txt, conf = rec_tuple - if not txt or mdr_is_whitespace(txt): - # print(f" DEBUG OCR Engine: Fragment {i} has empty/whitespace text after system call. Text: '{txt}'. Skipping.") # Already logged in TextSystem - continue + txt, conf = rec_tuple + if not txt or mdr_is_whitespace(txt): + # print(f" DEBUG OCR Engine: Fragment {i} has empty/whitespace text after system call. Text: '{txt}'. Skipping.") # Already logged in TextSystem + continue - try: - pts = [(float(p[0]), float(p[1])) for p in box_pts] - if len(pts) == 4: - r = MDRRectangle(lt=pts[0], rt=pts[1], rb=pts[2], lb=pts[3]) - if r.is_valid and r.area > 1: - yield MDROcrFragment(order=-1, text=txt, rank=float(conf), rect=r) - fragments_generated_count += 1 - # else: + try: + pts = [(float(p[0]), float(p[1])) for p in box_pts] + if len(pts) == 4: + r = MDRRectangle(lt=pts[0], rt=pts[1], rb=pts[2], lb=pts[3]) + if r.is_valid and r.area > 1: + yield MDROcrFragment(order=-1, text=txt, rank=float(conf), rect=r) + fragments_generated_count += 1 + # else: # print(f" DEBUG OCR Engine: Fragment {i} has invalid/small rectangle. Area: {r.area:.2f}. Valid: {r.is_valid}. Skipping.") - # else: + # else: # print(f" DEBUG OCR Engine: Fragment {i} box_pts not length 4: {len(pts)}. Skipping.") - except Exception as e_frag: - print(f" DEBUG OCR Engine: Error creating MDROcrFragment for item {i}: {e_frag}") - continue + except Exception as e_frag: + print(f" DEBUG OCR Engine: Error creating MDROcrFragment for item {i}: {e_frag}") + continue + + print(f" DEBUG OCR Engine: Generated {fragments_generated_count} MDROcrFragment objects.") + + def _preprocess(self, img: np.ndarray) -> np.ndarray: + if len(img.shape) == 3 and img.shape[2] == 4: + a = img[:, :, 3] / 255.0 + bg = (255, 255, 255) + new = np.zeros_like(img[:, :, :3]) + [setattr(new[:, :, i], 'flags.writeable', True) for i in range(3)] + [np.copyto(new[:, :, i], (bg[i] * (1 - a) + img[:, :, i] * a)) for i in range(3)] + img = new.astype(np.uint8) + elif len(img.shape) == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + elif not (len(img.shape) == 3 and img.shape[2] == 3): + raise ValueError("Unsupported image format") + return img - print(f" DEBUG OCR Engine: Generated {fragments_generated_count} MDROcrFragment objects.") - - def _preprocess(self, img: np.ndarray) -> np.ndarray: - if len(img.shape) == 3 and img.shape[2] == 4: - a = img[:, :, 3] / 255.0 - bg = (255, 255, 255) - new = np.zeros_like(img[:, :, :3]) - [setattr(new[:, :, i], 'flags.writeable', True) for i in range(3)] - [np.copyto(new[:, :, i], (bg[i] * (1 - a) + img[:, :, i] * a)) for i in range(3)] - img = new.astype(np.uint8) - elif len(img.shape) == 2: - img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) - elif not (len(img.shape) == 3 and img.shape[2] == 3): - raise ValueError("Unsupported image format") - return img # --- MDR Layout Reading Internals --- -_MDR_MAX_LEN = 510; _MDR_CLS_ID = 0; _MDR_SEP_ID = 2; _MDR_PAD_ID = 1 +_MDR_MAX_LEN = 510; +_MDR_CLS_ID = 0; +_MDR_SEP_ID = 2; +_MDR_PAD_ID = 1 -def mdr_boxes_to_reader_inputs(boxes: List[List[int]], max_len=_MDR_MAX_LEN) -> Dict[str, torch.Tensor]: - t_boxes = boxes[:max_len]; i_boxes = [[0,0,0,0]] + t_boxes + [[0,0,0,0]] - i_ids = [_MDR_CLS_ID] + [_MDR_PAD_ID]*len(t_boxes) + [_MDR_SEP_ID] - a_mask = [1]*len(i_ids); pad_len = (max_len+2) - len(i_ids) - if pad_len > 0: i_boxes.extend([[0,0,0,0]]*pad_len); i_ids.extend([_MDR_PAD_ID]*pad_len); a_mask.extend([0]*pad_len) - return {"bbox": torch.tensor([i_boxes]), "input_ids": torch.tensor([i_ids]), "attention_mask": torch.tensor([a_mask])} -def mdr_prepare_reader_inputs(inputs: Dict[str, torch.Tensor], model: LayoutLMv3ForTokenClassification) -> Dict[str, torch.Tensor]: +def mdr_boxes_to_reader_inputs(boxes: List[List[int]], max_len=_MDR_MAX_LEN) -> Dict[str, torch.Tensor]: + t_boxes = boxes[:max_len]; + i_boxes = [[0, 0, 0, 0]] + t_boxes + [[0, 0, 0, 0]] + i_ids = [_MDR_CLS_ID] + [_MDR_PAD_ID] * len(t_boxes) + [_MDR_SEP_ID] + a_mask = [1] * len(i_ids); + pad_len = (max_len + 2) - len(i_ids) + if pad_len > 0: i_boxes.extend([[0, 0, 0, 0]] * pad_len); i_ids.extend([_MDR_PAD_ID] * pad_len); a_mask.extend( + [0] * pad_len) + return {"bbox": torch.tensor([i_boxes]), "input_ids": torch.tensor([i_ids]), + "attention_mask": torch.tensor([a_mask])} + + +def mdr_prepare_reader_inputs(inputs: Dict[str, torch.Tensor], model: LayoutLMv3ForTokenClassification) -> Dict[ + str, torch.Tensor]: return {k: v.to(model.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()} + def mdr_parse_reader_logits(logits: torch.Tensor, length: int) -> List[int]: print(f"mdr_parse_reader_logits: Called with logits shape: {logits.shape}, length: {length}") if length == 0: @@ -1702,14 +1920,14 @@ def mdr_parse_reader_logits(logits: torch.Tensor, length: int) -> List[int]: print(f"mdr_parse_reader_logits: Attempting to slice logits with [1 : {length + 1}, :{length}]") try: - rel_logits = logits[1 : length + 1, :length] + rel_logits = logits[1: length + 1, :length] print(f"mdr_parse_reader_logits: rel_logits shape: {rel_logits.shape}") except IndexError as e: print(f"mdr_parse_reader_logits: IndexError during rel_logits slicing! Error: {e}") import traceback traceback.print_exc() # Depending on desired behavior, either raise or return empty/fallback - return list(range(length)) # Fallback to sequential order if slicing fails + return list(range(length)) # Fallback to sequential order if slicing fails orders = rel_logits.argmax(dim=1).tolist() print(f"mdr_parse_reader_logits: Initial orders calculated. Count: {len(orders)}") @@ -1720,29 +1938,30 @@ def mdr_parse_reader_logits(logits: torch.Tensor, length: int) -> List[int]: # For N=33, N^2 = 1089. For N=21, N^2 = 441. This matches the logs. # A tighter bound might be N * (N-1) / 2 or N * some_factor. # Let's use N * N as seen in logs, or a fixed large number if N is small. - max_loops = max(50, length * length) # Ensure at least 50 loops for small N + max_loops = max(50, length * length) # Ensure at least 50 loops for small N while True: loop_count += 1 if loop_count > max_loops: - print(f"mdr_parse_reader_logits: Exceeded max_loops ({max_loops}), breaking while loop to prevent infinite loop.") - break + print( + f"mdr_parse_reader_logits: Exceeded max_loops ({max_loops}), breaking while loop to prevent infinite loop.") + break - # print(f"mdr_parse_reader_logits: While loop iteration: {loop_count}") # Can be too verbose + # print(f"mdr_parse_reader_logits: While loop iteration: {loop_count}") # Can be too verbose conflicts = defaultdict(list) [conflicts[order].append(idx) for idx, order in enumerate(orders)] - + # Filter to find actual conflicting orders (where multiple original indices map to the same target order) conflicting_orders_map = {o: idxs for o, idxs in conflicts.items() if len(idxs) > 1} if not conflicting_orders_map: # print("mdr_parse_reader_logits: No conflicting orders, breaking while loop.") # Verbose break - - # Log only if there are actual conflicts to resolve - if loop_count == 1 or loop_count % 10 == 0 : # Log first and every 10th iteration with conflicts - print(f"mdr_parse_reader_logits: While loop iteration: {loop_count}. Found {len(conflicting_orders_map)} conflicting orders.") + # Log only if there are actual conflicts to resolve + if loop_count == 1 or loop_count % 10 == 0: # Log first and every 10th iteration with conflicts + print( + f"mdr_parse_reader_logits: While loop iteration: {loop_count}. Found {len(conflicting_orders_map)} conflicting orders.") for order_val, c_idxs in conflicting_orders_map.items(): # This logic seems to pick the one with the highest score for that conflicting order. @@ -1772,7 +1991,7 @@ def mdr_parse_reader_logits(logits: torch.Tensor, length: int) -> List[int]: if confidence > max_confidence_for_this_order: max_confidence_for_this_order = confidence best_c_idx_for_this_order = current_c_idx - + # Now, for all conflicting indices for this 'order_val', # if they are not the 'best_c_idx_for_this_order', # they need a new order. A simple strategy is to make them point to themselves initially. @@ -1780,1061 +1999,1242 @@ def mdr_parse_reader_logits(logits: torch.Tensor, length: int) -> List[int]: for current_c_idx in c_idxs: if current_c_idx != best_c_idx_for_this_order: # Option 1: Reset to self (might not resolve complex cycles) - # orders[current_c_idx] = current_c_idx - + # orders[current_c_idx] = current_c_idx + # Option 2: Find next best order for this current_c_idx, excluding the conflicting 'order_val' # Create a temporary copy of its logits row, set the conflicting order's logit to -inf temp_logits_row = rel_logits[current_c_idx, :].clone() temp_logits_row[order_val] = -float('inf') orders[current_c_idx] = temp_logits_row.argmax().item() - - print(f"mdr_parse_reader_logits: While loop finished after {loop_count} iterations. Returning {len(orders)} orders.") + print( + f"mdr_parse_reader_logits: While loop finished after {loop_count} iterations. Returning {len(orders)} orders.") return orders + # --- MDR Layout Reading Engine --- @dataclass -class _MDR_ReaderBBox: layout_index: int; fragment_index: int; virtual: bool; order: int; value: tuple[float, float, float, float] +class _MDR_ReaderBBox: layout_index: int; fragment_index: int; virtual: bool; order: int; value: tuple[ + float, float, float, float] -class MDRLayoutReader: - """Determines reading order of layout elements using LayoutLMv3.""" - - def __init__(self, model_path: str): - self._model_path = model_path - self._model: LayoutLMv3ForTokenClassification | None = None - # Determine device more robustly, self._device will be 'cuda' or 'cpu' - if torch.cuda.is_available(): # Check if CUDA is actually available at runtime - self._device = "cuda" - print("MDRLayoutReader: CUDA is available. Setting device to cuda.") - else: - self._device = "cpu" - print("MDRLayoutReader: CUDA not available. Setting device to cpu.") - -# In class MDRLayoutReader: - def _get_model(self) -> LayoutLMv3ForTokenClassification | None: - if self._model is None: - # MODIFIED: Use self._model_path for the layoutreader's specific cache, - # and ensure it's a directory. self._model_path is passed during MDRLayoutReader init. - layoutreader_cache_dir = Path(self._model_path) # self._model_path is like "./mdr_models/layoutreader" - mdr_ensure_directory(str(layoutreader_cache_dir)) # Ensure this specific directory exists - - name = "microsoft/layoutlmv3-base" - - print(f"MDRLayoutReader: Attempting to load LayoutLMv3 model '{name}'. Cache dir: {layoutreader_cache_dir}") - try: - self._model = LayoutLMv3ForTokenClassification.from_pretrained( - name, - cache_dir=str(layoutreader_cache_dir), - local_files_only=False, - num_labels=_MDR_MAX_LEN+1 - ) - self._model.to(torch.device(self._device)) - self._model.eval() - print(f"MDR LayoutReader model '{name}' loaded successfully on device: {self._model.device}.") - except Exception as e: - print(f"ERROR loading MDR LayoutReader model '{name}': {e}") - import traceback - traceback.print_exc() - self._model = None - return self._model - def determine_reading_order(self, layouts: list[MDRLayoutElement], size: tuple[int, int]) -> list[MDRLayoutElement]: - w, h = size - if w <= 0 or h <= 0: # ADDED check for invalid size - print("MDRLayoutReader: Invalid image size (w or h <= 0), returning layouts as is.") - return layouts - if not layouts: - print("MDRLayoutReader: No layouts to process, returning empty list.") - return [] # Return empty list if no layouts - - model = self._get_model() - # ... (rest of the method, add logging as needed) ... - print("MDRLayoutReader: Preparing bboxes...") - bbox_list = self._prepare_bboxes(layouts, w, h) - - if bbox_list is None or len(bbox_list) == 0: # Check if bbox_list is None or empty - print("MDRLayoutReader: No bboxes prepared from layouts, returning layouts as is (possibly sorted geometrically).") - # Fallback geometric sort if no bboxes could be prepared - layouts.sort(key=lambda l: (l.rect.lt[1], l.rect.lt[0])) - return layouts - print(f"MDRLayoutReader: Prepared {len(bbox_list)} bboxes.") - # ... (rest of the scaling and inference logic) ... - try: - with torch.no_grad(): - print("MDRLayoutReader: Creating reader inputs...") - inputs = mdr_boxes_to_reader_inputs(scaled_bboxes) # scaled_bboxes comes from the loop above - print("MDRLayoutReader: Preparing inputs for model device...") - inputs = mdr_prepare_reader_inputs(inputs, model) - print("MDRLayoutReader: Running model inference...") - logits = model(**inputs).logits.cpu().squeeze(0) - print("MDRLayoutReader: Model inference complete. Parsing logits...") - orders = mdr_parse_reader_logits(logits, len(bbox_list)) - print(f"MDRLayoutReader: Logits parsed. Orders count: {len(orders)}") - except Exception as e: - print(f"MDR LayoutReader prediction error: {e}") - import traceback - traceback.print_exc() - # Fallback geometric sort on error - layouts.sort(key=lambda l: (l.rect.lt[1], l.rect.lt[0])) - return layouts - # ... (rest of applying order) ... - print("MDRLayoutReader: Applying order...") - result_layouts = self._apply_order(layouts, bbox_list) # Ensure bbox_list has 'order' attribute set - print("MDRLayoutReader: Order applied. Returning layouts.") - return result_layouts - - def _prepare_bboxes(self, layouts: list[MDRLayoutElement], w: int, h: int) -> list[_MDR_ReaderBBox] | None: - line_h = self._estimate_line_h(layouts) - bbox_list = [] - for i, l in enumerate(layouts): - if l.cls == MDRLayoutClass.PLAIN_TEXT and l.fragments: - [bbox_list.append(_MDR_ReaderBBox(i, j, False, -1, f.rect.wrapper)) for j, f in enumerate(l.fragments)] - else: - bbox_list.extend(self._gen_virtual(l, i, line_h, w, h)) - if len(bbox_list) > _MDR_MAX_LEN: - print(f"Too many boxes ({len(bbox_list)}>{_MDR_MAX_LEN})") - return None - bbox_list.sort(key=lambda b: (b.value[1], b.value[0])) - return bbox_list - - def _apply_order(self, layouts: list[MDRLayoutElement], bbox_list: list[_MDR_ReaderBBox]) -> list[MDRLayoutElement]: - layout_map = defaultdict(list) - [layout_map[b.layout_index].append(b) for b in bbox_list] - layout_orders = [(idx, self._median([b.order for b in bboxes])) for idx, bboxes in layout_map.items() if bboxes] - layout_orders.sort(key=lambda x: x[1]) - sorted_layouts = [layouts[idx] for idx, _ in layout_orders] - nfo = 0 - for l in sorted_layouts: - frags = l.fragments - if not frags: - continue - frag_bboxes = [b for b in layout_map[layouts.index(l)] if not b.virtual] - if frag_bboxes: - idx_to_order = {b.fragment_index: b.order for b in frag_bboxes} - frags.sort(key=lambda f: idx_to_order.get(frags.index(f), float('inf'))) +class MDRLayoutReader: + """Determines reading order of layout elements using LayoutLMv3.""" + + def __init__(self, model_path: str): + self._model_path = model_path + self._model: LayoutLMv3ForTokenClassification | None = None + # Determine device more robustly, self._device will be 'cuda' or 'cpu' + if torch.cuda.is_available(): # Check if CUDA is actually available at runtime + self._device = "cuda" + print("MDRLayoutReader: CUDA is available. Setting device to cuda.") else: - frags.sort(key=lambda f: (f.rect.lt[1], f.rect.lt[0])) - for frag in frags: - frag.order = nfo - nfo += 1 - return sorted_layouts - - def _estimate_line_h(self, layouts: list[MDRLayoutElement]) -> float: - heights = [f.rect.size[1] for l in layouts for f in l.fragments if f.rect.size[1] > 0] - return self._median(heights) if heights else 15.0 - - def _gen_virtual(self, l: MDRLayoutElement, l_idx: int, line_h: float, pw: int, ph: int) -> Generator[_MDR_ReaderBBox, None, None]: - x0, y0, x1, y1 = l.rect.wrapper - lh = y1 - y0 - lw = x1 - x0 - if lh <= 0 or lw <= 0 or line_h <= 0: - yield _MDR_ReaderBBox(l_idx, -1, True, -1, (x0, y0, x1, y1)) - return - lines = 1 - if lh > line_h * 1.5: - if lh <= ph * 0.25 or lw >= pw * 0.5: - lines = 3 - elif lw > pw * 0.25: - lines = 3 if lw > pw * 0.4 else 2 - elif lw <= pw * 0.25: - lines = max(1, int(lh / (line_h * 1.5))) if lh / lw > 1.5 else 2 + self._device = "cpu" + print("MDRLayoutReader: CUDA not available. Setting device to cpu.") + + # In class MDRLayoutReader: + def _get_model(self) -> LayoutLMv3ForTokenClassification | None: + if self._model is None: + # MODIFIED: Use self._model_path for the layoutreader's specific cache, + # and ensure it's a directory. self._model_path is passed during MDRLayoutReader init. + layoutreader_cache_dir = Path(self._model_path) # self._model_path is like "./mdr_models/layoutreader" + mdr_ensure_directory(str(layoutreader_cache_dir)) # Ensure this specific directory exists + + name = "microsoft/layoutlmv3-base" + + print(f"MDRLayoutReader: Attempting to load LayoutLMv3 model '{name}'. Cache dir: {layoutreader_cache_dir}") + try: + self._model = LayoutLMv3ForTokenClassification.from_pretrained( + name, + cache_dir=str(layoutreader_cache_dir), + local_files_only=False, + num_labels=_MDR_MAX_LEN + 1 + ) + self._model.to(torch.device(self._device)) + self._model.eval() + print(f"MDR LayoutReader model '{name}' loaded successfully on device: {self._model.device}.") + except Exception as e: + print(f"ERROR loading MDR LayoutReader model '{name}': {e}") + import traceback + traceback.print_exc() + self._model = None + return self._model + + # In class MDRLayoutReader: + + def determine_reading_order(self, layouts: list[MDRLayoutElement], size: tuple[int, int]) -> list[MDRLayoutElement]: + w, h = size + if w <= 0 or h <= 0: + print("MDRLayoutReader: Invalid image size (w or h <= 0), returning layouts as is.") + return layouts + if not layouts: + print("MDRLayoutReader: No layouts to process, returning empty list.") + return [] + + model = self._get_model() + if model is None: + print("MDRLayoutReader: Model not available, returning layouts sorted geometrically.") + layouts.sort(key=lambda l: (l.rect.lt[1], l.rect.lt[0])) # Sort by top-left y, then x + return layouts + + print("MDRLayoutReader: Preparing bboxes...") + # bbox_list contains _MDR_ReaderBBox objects, each with .value = (x0,y0,x1,y1) in original pixels + bbox_list = self._prepare_bboxes(layouts, w, h) + + if bbox_list is None or len(bbox_list) == 0: + print("MDRLayoutReader: No bboxes prepared from layouts, returning layouts as is (sorted geometrically).") + layouts.sort(key=lambda l: (l.rect.lt[1], l.rect.lt[0])) + return layouts + print(f"MDRLayoutReader: Prepared {len(bbox_list)} bboxes.") + + # --- START: ADDED SCALING LOGIC --- + scaled_bboxes: list[list[int]] = [] + if w > 0 and h > 0: # Ensure valid width and height for division + for bbox_item in bbox_list: + # bbox_item.value is (x0, y0, x1, y1) in original image coordinates + x0, y0, x1, y1 = bbox_item.value + + # Scale to 0-1000 range based on image width (w) and height (h) + # Ensure coordinates are within [0, 1000] and x1>=x0, y1>=y0 + # Clamp values to image boundaries before scaling to prevent negative scaled values if original box is outside + x0_c = max(0.0, min(x0, float(w))) + y0_c = max(0.0, min(y0, float(h))) + x1_c = max(0.0, min(x1, float(w))) + y1_c = max(0.0, min(y1, float(h))) + + scaled_x0 = max(0, min(1000, int(1000 * x0_c / w))) + scaled_y0 = max(0, min(1000, int(1000 * y0_c / h))) + scaled_x1 = max(scaled_x0, min(1000, int(1000 * x1_c / w))) # Ensure x1 >= x0 + scaled_y1 = max(scaled_y0, min(1000, int(1000 * y1_c / h))) # Ensure y1 >= y0 + scaled_bboxes.append([scaled_x0, scaled_y0, scaled_x1, scaled_y1]) else: - lines = max(1, int(round(lh / line_h))) - lines = max(1, lines) - act_line_h = lh / lines - cur_y = y0 - for i in range(lines): - ly0 = max(0, min(ph, cur_y)) - ly1 = max(0, min(ph, cur_y + act_line_h)) - lx0 = max(0, min(pw, x0)) - lx1 = max(0, min(pw, x1)) - if ly1 > ly0 and lx1 > lx0: - yield _MDR_ReaderBBox(l_idx, -1, True, -1, (lx0, ly0, lx1, ly1)) - cur_y += act_line_h - - def _median(self, nums: list[float|int]) -> float: - if not nums: - return 0.0 - s_nums = sorted(nums) - n = len(s_nums) - return float(s_nums[n // 2]) if n % 2 == 1 else float((s_nums[n // 2 - 1] + s_nums[n // 2]) / 2.0) + print( + "MDRLayoutReader: Warning - Invalid image dimensions (w or h is zero) for scaling bboxes. Cannot determine reading order.") + layouts.sort(key=lambda l: (l.rect.lt[1], l.rect.lt[0])) + return layouts + # --- END: ADDED SCALING LOGIC --- + + if not scaled_bboxes: # If scaling resulted in no bboxes (e.g. w/h was 0) + print( + "MDRLayoutReader: No scaled bboxes available after scaling step. Returning geometrically sorted layouts.") + layouts.sort(key=lambda l: (l.rect.lt[1], l.rect.lt[0])) + return layouts + + orders: list[int] = [] + try: + with torch.no_grad(): + print("MDRLayoutReader: Creating reader inputs...") + inputs = mdr_boxes_to_reader_inputs(scaled_bboxes) # Use the newly created scaled_bboxes + print("MDRLayoutReader: Preparing inputs for model device...") + inputs = mdr_prepare_reader_inputs(inputs, model) + print("MDRLayoutReader: Running model inference...") + logits = model(**inputs).logits.cpu().squeeze(0) + print("MDRLayoutReader: Model inference complete. Parsing logits...") + # length is based on original bbox_list (which should match scaled_bboxes length) + orders = mdr_parse_reader_logits(logits, len(bbox_list)) + print(f"MDRLayoutReader: Logits parsed. Orders count: {len(orders)}") + + # Assign the determined orders back to the bbox_list items + if len(orders) == len(bbox_list): + for i, order_val in enumerate(orders): + bbox_list[i].order = order_val + else: + print( + f"MDRLayoutReader: Warning - Mismatch between orders ({len(orders)}) and bbox_list ({len(bbox_list)}). Order assignment might be incorrect. Using sequential order.") + for i in range(len(bbox_list)): # Fallback to sequential order + bbox_list[i].order = i + except Exception as e: + print(f"MDR LayoutReader prediction error: {e}") + import traceback + traceback.print_exc() + # Fallback: assign sequential order to bbox_list items before geometric sort of layouts + for i in range(len(bbox_list)): + bbox_list[i].order = i + # Then apply this sequential order (which effectively becomes a geometric sort) + print("MDRLayoutReader: Applying fallback sequential order due to error...") + result_layouts = self._apply_order(layouts, bbox_list) + return result_layouts # Return here after applying fallback order + + print("MDRLayoutReader: Applying order...") + result_layouts = self._apply_order(layouts, bbox_list) + print("MDRLayoutReader: Order applied. Returning layouts.") + return result_layouts + + def _prepare_bboxes(self, layouts: list[MDRLayoutElement], w: int, h: int) -> list[_MDR_ReaderBBox] | None: + line_h = self._estimate_line_h(layouts) + bbox_list = [] + for i, l in enumerate(layouts): + if l.cls == MDRLayoutClass.PLAIN_TEXT and l.fragments: + [bbox_list.append(_MDR_ReaderBBox(i, j, False, -1, f.rect.wrapper)) for j, f in enumerate(l.fragments)] + else: + bbox_list.extend(self._gen_virtual(l, i, line_h, w, h)) + if len(bbox_list) > _MDR_MAX_LEN: + print(f"Too many boxes ({len(bbox_list)}>{_MDR_MAX_LEN})") + return None + bbox_list.sort(key=lambda b: (b.value[1], b.value[0])) + return bbox_list + + def _apply_order(self, layouts: list[MDRLayoutElement], bbox_list: list[_MDR_ReaderBBox]) -> list[MDRLayoutElement]: + layout_map = defaultdict(list) + [layout_map[b.layout_index].append(b) for b in bbox_list] + layout_orders = [(idx, self._median([b.order for b in bboxes])) for idx, bboxes in layout_map.items() if bboxes] + layout_orders.sort(key=lambda x: x[1]) + sorted_layouts = [layouts[idx] for idx, _ in layout_orders] + nfo = 0 + for l in sorted_layouts: + frags = l.fragments + if not frags: + continue + frag_bboxes = [b for b in layout_map[layouts.index(l)] if not b.virtual] + if frag_bboxes: + idx_to_order = {b.fragment_index: b.order for b in frag_bboxes} + frags.sort(key=lambda f: idx_to_order.get(frags.index(f), float('inf'))) + else: + frags.sort(key=lambda f: (f.rect.lt[1], f.rect.lt[0])) + for frag in frags: + frag.order = nfo + nfo += 1 + return sorted_layouts + + def _estimate_line_h(self, layouts: list[MDRLayoutElement]) -> float: + heights = [f.rect.size[1] for l in layouts for f in l.fragments if f.rect.size[1] > 0] + return self._median(heights) if heights else 15.0 + + def _gen_virtual(self, l: MDRLayoutElement, l_idx: int, line_h: float, pw: int, ph: int) -> Generator[ + _MDR_ReaderBBox, None, None]: + x0, y0, x1, y1 = l.rect.wrapper + lh = y1 - y0 + lw = x1 - x0 + if lh <= 0 or lw <= 0 or line_h <= 0: + yield _MDR_ReaderBBox(l_idx, -1, True, -1, (x0, y0, x1, y1)) + return + lines = 1 + if lh > line_h * 1.5: + if lh <= ph * 0.25 or lw >= pw * 0.5: + lines = 3 + elif lw > pw * 0.25: + lines = 3 if lw > pw * 0.4 else 2 + elif lw <= pw * 0.25: + lines = max(1, int(lh / (line_h * 1.5))) if lh / lw > 1.5 else 2 + else: + lines = max(1, int(round(lh / line_h))) + lines = max(1, lines) + act_line_h = lh / lines + cur_y = y0 + for i in range(lines): + ly0 = max(0, min(ph, cur_y)) + ly1 = max(0, min(ph, cur_y + act_line_h)) + lx0 = max(0, min(pw, x0)) + lx1 = max(0, min(pw, x1)) + if ly1 > ly0 and lx1 > lx0: + yield _MDR_ReaderBBox(l_idx, -1, True, -1, (lx0, ly0, lx1, ly1)) + cur_y += act_line_h + + def _median(self, nums: list[float | int]) -> float: + if not nums: + return 0.0 + s_nums = sorted(nums) + n = len(s_nums) + return float(s_nums[n // 2]) if n % 2 == 1 else float((s_nums[n // 2 - 1] + s_nums[n // 2]) / 2.0) + # --- MDR LaTeX Extractor --- class MDRLatexExtractor: - """Extracts LaTeX from formula images using pix2tex.""" + """Extracts LaTeX from formula images using pix2tex.""" + + def __init__(self, model_path: str): + self._model_path = model_path; + self._model: LatexOCR | None = None + self._device = "cuda" if torch.cuda.is_available() else "cpu" + + def extract(self, image: Image) -> str | None: + if LatexOCR is None: return None; + image = mdr_expand_image(image, 0.1) + model = self._get_model() + if model is None: return None; + try: + with torch.no_grad(): + img_rgb = image.convert('RGB') if image.mode != 'RGB' else image; latex = model( + img_rgb); return latex if latex else None + except Exception as e: + print(f"MDR LaTeX error: {e}"); return None + + def _get_model(self) -> LatexOCR | None: + if self._model is None and LatexOCR is not None: + mdr_ensure_directory(self._model_path) + wp = Path(self._model_path) / "weights.pth" + rp = Path(self._model_path) / "image_resizer.pth" + cp = Path(self._model_path) / "config.yaml" + if not wp.exists() or not rp.exists(): + print("Downloading MDR LaTeX models...") + self._download() + if not cp.exists(): + print(f"Warn: MDR LaTeX config not found {self._model_path}") + try: + args = Munch({"config": str(cp), "checkpoint": str(wp), "device": self._device, + "no_cuda": self._device == "cuda", "no_resize": False, "temperature": 0.0}) + self._model = LatexOCR(args) + print(f"MDR LaTeX loaded on {self._device}.") + except Exception as e: + print(f"ERROR initializing MDR LatexOCR: {e}") + self._model = None + return self._model - def __init__(self, model_path: str): - self._model_path = model_path; self._model: LatexOCR | None = None - self._device = "cuda" if torch.cuda.is_available() else "cpu" + def _download(self): + tag = "v0.0.1" + base = f"https://github.com/lukas-blecher/LaTeX-OCR/releases/download/{tag}/" + files = {"weights.pth": base + "weights.pth", "image_resizer.pth": base + "image_resizer.pth"} + mdr_ensure_directory(self._model_path) + [mdr_download_model(url, Path(self._model_path) / name) for name, url in files.items() if + not (Path(self._model_path) / name).exists()] - def extract(self, image: Image) -> str | None: - if LatexOCR is None: return None; - image = mdr_expand_image(image, 0.1) - model = self._get_model() - if model is None: return None; - try: - with torch.no_grad(): img_rgb = image.convert('RGB') if image.mode!='RGB' else image; latex = model(img_rgb); return latex if latex else None - except Exception as e: print(f"MDR LaTeX error: {e}"); return None - - def _get_model(self) -> LatexOCR | None: - if self._model is None and LatexOCR is not None: - mdr_ensure_directory(self._model_path) - wp = Path(self._model_path) / "weights.pth" - rp = Path(self._model_path) / "image_resizer.pth" - cp = Path(self._model_path) / "config.yaml" - if not wp.exists() or not rp.exists(): - print("Downloading MDR LaTeX models...") - self._download() - if not cp.exists(): - print(f"Warn: MDR LaTeX config not found {self._model_path}") - try: - args = Munch({"config": str(cp), "checkpoint": str(wp), "device": self._device, "no_cuda": self._device == "cuda", "no_resize": False, "temperature": 0.0}) - self._model = LatexOCR(args) - print(f"MDR LaTeX loaded on {self._device}.") - except Exception as e: - print(f"ERROR initializing MDR LatexOCR: {e}") - self._model = None - return self._model - - def _download(self): - tag = "v0.0.1" - base = f"https://github.com/lukas-blecher/LaTeX-OCR/releases/download/{tag}/" - files = {"weights.pth": base + "weights.pth", "image_resizer.pth": base + "image_resizer.pth"} - mdr_ensure_directory(self._model_path) - [mdr_download_model(url, Path(self._model_path) / name) for name, url in files.items() if not (Path(self._model_path) / name).exists()] # --- MDR Table Parser --- MDRTableOutputFormat = Literal["latex", "markdown", "html"] -class MDRTableParser: - """Parses table structure/content from images using StructTable model.""" - - def __init__(self, device: Literal["cpu", "cuda"], model_path: str): - self._model: Any | None = None; self._model_path = mdr_ensure_directory(model_path) - self._device = device if torch.cuda.is_available() and device=="cuda" else "cpu" - self._disabled = self._device == "cuda" - if self._disabled: print("Warning: MDR Table parsing requires CUDA. Disabled.") - - def parse_table_image(self, image: Image, format: MDRTableLayoutParsedFormat) -> str | None: - if self._disabled: return None; - fmt: MDRTableOutputFormat | None = None - if format == MDRTableLayoutParsedFormat.LATEX: fmt="latex" - elif format == MDRTableLayoutParsedFormat.MARKDOWN: fmt="markdown" - elif format == MDRTableLayoutParsedFormat.HTML: fmt="html" - else: return None - image = mdr_expand_image(image, 0.05) - model = self._get_model() - if model is None: return None; - try: - img_rgb = image.convert('RGB') if image.mode!='RGB' else image - with torch.no_grad(): results = model([img_rgb], output_format=fmt) - return results[0] if results else None - except Exception as e: print(f"MDR Table parsing error: {e}"); return None - - def _get_model(self): - if self._model is None and not self._disabled: - try: - from struct_eqtable import build_model # Dynamic import - name = "U4R/StructTable-InternVL2-1B"; local = any(Path(self._model_path).iterdir()) - print(f"Loading MDR StructTable model '{name}'...") - model = build_model(model_ckpt=name, max_new_tokens=1024, max_time=30, lmdeploy=False, flash_attn=True, batch_size=1, cache_dir=self._model_path, local_files_only=local) - self._model = model.to(self._device); print(f"MDR StructTable loaded on {self._device}.") - except ImportError: print("ERROR: struct_eqtable not found."); self._disabled=True; self._model=None - except Exception as e: print(f"ERROR loading MDR StructTable: {e}"); self._model=None - return self._model -# --- MDR Image Optimizer --- -_MDR_TINY_ROTATION = 0.005 +class MDRTableParser: + """Parses table structure/content from images using StructTable model.""" + + def __init__(self, device: Literal["cpu", "cuda"], model_path: str): + self._model: Any | None = None; + self._model_path = mdr_ensure_directory(model_path) + self._device = device if torch.cuda.is_available() and device == "cuda" else "cpu" + self._disabled = self._device == "cuda" + if self._disabled: print("Warning: MDR Table parsing requires CUDA. Disabled.") + + def parse_table_image(self, image: Image, format: MDRTableLayoutParsedFormat) -> str | None: + if self._disabled: return None; + fmt: MDRTableOutputFormat | None = None + if format == MDRTableLayoutParsedFormat.LATEX: + fmt = "latex" + elif format == MDRTableLayoutParsedFormat.MARKDOWN: + fmt = "markdown" + elif format == MDRTableLayoutParsedFormat.HTML: + fmt = "html" + else: + return None + image = mdr_expand_image(image, 0.05) + model = self._get_model() + if model is None: return None; + try: + img_rgb = image.convert('RGB') if image.mode != 'RGB' else image + with torch.no_grad(): + results = model([img_rgb], output_format=fmt) + return results[0] if results else None + except Exception as e: + print(f"MDR Table parsing error: {e}"); return None -@dataclass -class _MDR_RotationContext: to_origin: MDRRotationAdjuster; to_new: MDRRotationAdjuster; fragment_origin_rectangles: list[MDRRectangle] + def _get_model(self): + if self._model is None and not self._disabled: + try: + from struct_eqtable import build_model # Dynamic import + name = "U4R/StructTable-InternVL2-1B"; + local = any(Path(self._model_path).iterdir()) + print(f"Loading MDR StructTable model '{name}'...") + model = build_model(model_ckpt=name, max_new_tokens=1024, max_time=30, lmdeploy=False, flash_attn=True, + batch_size=1, cache_dir=self._model_path, local_files_only=local) + self._model = model.to(self._device); + print(f"MDR StructTable loaded on {self._device}.") + except ImportError: + print("ERROR: struct_eqtable not found."); self._disabled = True; self._model = None + except Exception as e: + print(f"ERROR loading MDR StructTable: {e}"); self._model = None + return self._model -class MDRImageOptimizer: - """Handles image rotation detection and coordinate adjustments.""" - def __init__(self, raw_image: Image, adjust_points: bool): - self._raw = raw_image; self._image = raw_image; self._adjust_points = adjust_points - self._fragments: list[MDROcrFragment] = []; self._rotation: float = 0.0; self._rot_ctx: _MDR_RotationContext | None = None +# --- MDR Image Optimizer --- +_MDR_TINY_ROTATION = 0.005 - @property - def image(self) -> Image: return self._image - @property - def adjusted_image(self) -> Image | None: return self._image if self._rot_ctx is not None else None +@dataclass +class _MDR_RotationContext: to_origin: MDRRotationAdjuster; to_new: MDRRotationAdjuster; fragment_origin_rectangles: \ +list[MDRRectangle] - @property - def rotation(self) -> float: return self._rotation - @property - def image_np(self) -> np.ndarray: img_rgb = np.array(self._raw.convert("RGB")); return cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR) +class MDRImageOptimizer: + """Handles image rotation detection and coordinate adjustments.""" + + def __init__(self, raw_image: Image, adjust_points: bool): + self._raw = raw_image; + self._image = raw_image; + self._adjust_points = adjust_points + self._fragments: list[MDROcrFragment] = []; + self._rotation: float = 0.0; + self._rot_ctx: _MDR_RotationContext | None = None + + @property + def image(self) -> Image: + return self._image + + @property + def adjusted_image(self) -> Image | None: + return self._image if self._rot_ctx is not None else None + + @property + def rotation(self) -> float: + return self._rotation + + @property + def image_np(self) -> np.ndarray: + img_rgb = np.array(self._raw.convert("RGB")); return cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR) + + def receive_fragments(self, fragments: list[MDROcrFragment]): + self._fragments = fragments + if not fragments: + return + self._rotation = mdr_calculate_image_rotation(fragments) + if abs(self._rotation) < _MDR_TINY_ROTATION: + self._rotation = 0.0 + return + orig_sz = self._raw.size + try: + self._image = self._raw.rotate(-np.degrees(self._rotation), resample=PILResampling.BICUBIC, + fillcolor=(255, 255, 255), expand=True) + except Exception as e: + print(f"Optimizer rotation error: {e}") + self._rotation = 0.0 + self._image = self._raw + return + new_sz = self._image.size + self._rot_ctx = _MDR_RotationContext( + fragment_origin_rectangles=[f.rect for f in fragments], + to_new=MDRRotationAdjuster(orig_sz, new_sz, self._rotation, False), + to_origin=MDRRotationAdjuster(orig_sz, new_sz, self._rotation, True)) + adj = self._rot_ctx.to_new + [setattr(f, 'rect', + MDRRectangle(lt=adj.adjust(r.lt), rt=adj.adjust(r.rt), lb=adj.adjust(r.lb), rb=adj.adjust(r.rb))) for f + in fragments if (r := f.rect)] + + def finalize_layout_coords(self, layouts: list[MDRLayoutElement]): + if self._rot_ctx is None or self._adjust_points: return + if len(self._fragments) == len(self._rot_ctx.fragment_origin_rectangles): [setattr(f, 'rect', orig_r) for + f, orig_r in zip(self._fragments, + self._rot_ctx.fragment_origin_rectangles)] + adj = self._rot_ctx.to_origin; + [setattr(l, 'rect', + MDRRectangle(lt=adj.adjust(r.lt), rt=adj.adjust(r.rt), lb=adj.adjust(r.lb), rb=adj.adjust(r.rb))) for l + in layouts if (r := l.rect)] - def receive_fragments(self, fragments: list[MDROcrFragment]): - self._fragments = fragments - if not fragments: - return - self._rotation = mdr_calculate_image_rotation(fragments) - if abs(self._rotation) < _MDR_TINY_ROTATION: - self._rotation = 0.0 - return - orig_sz = self._raw.size - try: - self._image = self._raw.rotate(-np.degrees(self._rotation), resample=PILResampling.BICUBIC, fillcolor=(255, 255, 255), expand=True) - except Exception as e: - print(f"Optimizer rotation error: {e}") - self._rotation = 0.0 - self._image = self._raw - return - new_sz = self._image.size - self._rot_ctx = _MDR_RotationContext( - fragment_origin_rectangles=[f.rect for f in fragments], - to_new=MDRRotationAdjuster(orig_sz, new_sz, self._rotation, False), - to_origin=MDRRotationAdjuster(orig_sz, new_sz, self._rotation, True)) - adj = self._rot_ctx.to_new - [setattr(f, 'rect', MDRRectangle(lt=adj.adjust(r.lt), rt=adj.adjust(r.rt), lb=adj.adjust(r.lb), rb=adj.adjust(r.rb))) for f in fragments if (r := f.rect)] - - def finalize_layout_coords(self, layouts: list[MDRLayoutElement]): - if self._rot_ctx is None or self._adjust_points: return - if len(self._fragments) == len(self._rot_ctx.fragment_origin_rectangles): [setattr(f, 'rect', orig_r) for f, orig_r in zip(self._fragments, self._rot_ctx.fragment_origin_rectangles)] - adj = self._rot_ctx.to_origin; [setattr(l, 'rect', MDRRectangle(lt=adj.adjust(r.lt), rt=adj.adjust(r.rt), lb=adj.adjust(r.lb), rb=adj.adjust(r.rb))) for l in layouts if (r:=l.rect)] # --- MDR Image Clipping --- def mdr_clip_from_image(image: Image, rect: MDRRectangle, wrap_w: float = 0.0, wrap_h: float = 0.0) -> Image: - """Clips a potentially rotated rectangle from an image.""" - try: - h_rot, _ = mdr_calculate_rectangle_rotation(rect) - avg_w, avg_h = rect.size - if avg_w <= 0 or avg_h <= 0: - return new_image("RGB", (1, 1), (255, 255, 255)) - tx, ty = rect.lt - trans_orig = np.array([[1, 0, -tx], [0, 1, -ty], [0, 0, 1]]) - cos_r = cos(-h_rot) - sin_r = sin(-h_rot) - rot = np.array([[cos_r, -sin_r, 0], [sin_r, cos_r, 0], [0, 0, 1]]) - pad_dx = wrap_w / 2.0 - pad_dy = wrap_h / 2.0 - trans_pad = np.array([[1, 0, pad_dx], [0, 1, pad_dy], [0, 0, 1]]) - matrix = trans_pad @ rot @ trans_orig + """Clips a potentially rotated rectangle from an image.""" try: - inv_matrix = np.linalg.inv(matrix) - except np.linalg.LinAlgError: - x0, y0, x1, y1 = rect.wrapper - return image.crop((round(x0), round(y0), round(x1), round(y1))) - p_mat = (inv_matrix[0, 0], inv_matrix[0, 1], inv_matrix[0, 2], inv_matrix[1, 0], inv_matrix[1, 1], inv_matrix[1, 2]) - out_w = ceil(avg_w + wrap_w) - out_h = ceil(avg_h + wrap_h) - return image.transform((out_w, out_h), PILTransform.AFFINE, p_mat, PILResampling.BICUBIC, fillcolor=(255, 255, 255)) - except Exception as e: - print(f"MDR Clipping error: {e}") - return new_image("RGB", (10, 10), (255, 255, 255)) - -def mdr_clip_layout(res: MDRExtractionResult, layout: MDRLayoutElement, wrap_w: float = 0.0, wrap_h: float = 0.0) -> Image: - """Clips a layout region from the MDRExtractionResult image.""" - img = res.adjusted_image if res.adjusted_image else res.extracted_image - return mdr_clip_from_image(img, layout.rect, wrap_w, wrap_h) + h_rot, _ = mdr_calculate_rectangle_rotation(rect) + avg_w, avg_h = rect.size + if avg_w <= 0 or avg_h <= 0: + return new_image("RGB", (1, 1), (255, 255, 255)) + tx, ty = rect.lt + trans_orig = np.array([[1, 0, -tx], [0, 1, -ty], [0, 0, 1]]) + cos_r = cos(-h_rot) + sin_r = sin(-h_rot) + rot = np.array([[cos_r, -sin_r, 0], [sin_r, cos_r, 0], [0, 0, 1]]) + pad_dx = wrap_w / 2.0 + pad_dy = wrap_h / 2.0 + trans_pad = np.array([[1, 0, pad_dx], [0, 1, pad_dy], [0, 0, 1]]) + matrix = trans_pad @ rot @ trans_orig + try: + inv_matrix = np.linalg.inv(matrix) + except np.linalg.LinAlgError: + x0, y0, x1, y1 = rect.wrapper + return image.crop((round(x0), round(y0), round(x1), round(y1))) + p_mat = ( + inv_matrix[0, 0], inv_matrix[0, 1], inv_matrix[0, 2], inv_matrix[1, 0], inv_matrix[1, 1], inv_matrix[1, 2]) + out_w = ceil(avg_w + wrap_w) + out_h = ceil(avg_h + wrap_h) + return image.transform((out_w, out_h), PILTransform.AFFINE, p_mat, PILResampling.BICUBIC, + fillcolor=(255, 255, 255)) + except Exception as e: + print(f"MDR Clipping error: {e}") + return new_image("RGB", (10, 10), (255, 255, 255)) + + +def mdr_clip_layout(res: MDRExtractionResult, layout: MDRLayoutElement, wrap_w: float = 0.0, + wrap_h: float = 0.0) -> Image: + """Clips a layout region from the MDRExtractionResult image.""" + img = res.adjusted_image if res.adjusted_image else res.extracted_image + return mdr_clip_from_image(img, layout.rect, wrap_w, wrap_h) + # --- MDR Debug Plotting --- -_MDR_FRAG_COLOR = (0x49, 0xCF, 0xCB, 200); _MDR_LAYOUT_COLORS = { MDRLayoutClass.TITLE: (0x0A,0x12,0x2C,255), MDRLayoutClass.PLAIN_TEXT: (0x3C,0x67,0x90,255), MDRLayoutClass.ABANDON: (0xC0,0xBB,0xA9,180), MDRLayoutClass.FIGURE: (0x5B,0x91,0x3C,255), MDRLayoutClass.FIGURE_CAPTION: (0x77,0xB3,0x54,255), MDRLayoutClass.TABLE: (0x44,0x17,0x52,255), MDRLayoutClass.TABLE_CAPTION: (0x81,0x75,0xA0,255), MDRLayoutClass.TABLE_FOOTNOTE: (0xEF,0xB6,0xC9,255), MDRLayoutClass.ISOLATE_FORMULA: (0xFA,0x38,0x27,255), MDRLayoutClass.FORMULA_CAPTION: (0xFF,0x9D,0x24,255) }; _MDR_DEFAULT_COLOR = (0x80,0x80,0x80,255); _MDR_RGBA = tuple[int,int,int,int] +_MDR_FRAG_COLOR = (0x49, 0xCF, 0xCB, 200); +_MDR_LAYOUT_COLORS = {MDRLayoutClass.TITLE: (0x0A, 0x12, 0x2C, 255), MDRLayoutClass.PLAIN_TEXT: (0x3C, 0x67, 0x90, 255), + MDRLayoutClass.ABANDON: (0xC0, 0xBB, 0xA9, 180), MDRLayoutClass.FIGURE: (0x5B, 0x91, 0x3C, 255), + MDRLayoutClass.FIGURE_CAPTION: (0x77, 0xB3, 0x54, 255), + MDRLayoutClass.TABLE: (0x44, 0x17, 0x52, 255), + MDRLayoutClass.TABLE_CAPTION: (0x81, 0x75, 0xA0, 255), + MDRLayoutClass.TABLE_FOOTNOTE: (0xEF, 0xB6, 0xC9, 255), + MDRLayoutClass.ISOLATE_FORMULA: (0xFA, 0x38, 0x27, 255), + MDRLayoutClass.FORMULA_CAPTION: (0xFF, 0x9D, 0x24, 255)}; +_MDR_DEFAULT_COLOR = (0x80, 0x80, 0x80, 255); +_MDR_RGBA = tuple[int, int, int, int] + def mdr_plot_layout(image: Image, layouts: Iterable[MDRLayoutElement]) -> None: - """Draws layout and fragment boxes onto an image for debugging.""" - if not layouts: return; - try: - l_font = load_default(size=25) - f_font = load_default(size=15) # Not used currently, but kept for potential future use - draw = ImageDraw.Draw(image, mode="RGBA") - except Exception as e: - print(f"MDR Plot init error: {e}") - return - - def _draw_num(pos: MDRPoint, num: int, font: FreeTypeFont, color: _MDR_RGBA): + """Draws layout and fragment boxes onto an image for debugging.""" + if not layouts: return; try: - x, y = pos - txt = str(num) - txt_pos = (round(x) + 3, round(y) + 1) - bbox = draw.textbbox(txt_pos, txt, font=font) - bg_rect = (bbox[0] - 2, bbox[1] - 1, bbox[2] + 2, bbox[3] + 1) - bg_color = (color[0], color[1], color[2], 180) - draw.rectangle(bg_rect, fill=bg_color) - draw.text(txt_pos, txt, font=font, fill=(255, 255, 255, 255)) + l_font = load_default(size=25) + f_font = load_default(size=15) # Not used currently, but kept for potential future use + draw = ImageDraw.Draw(image, mode="RGBA") except Exception as e: - print(f"MDR Draw num error: {e}") + print(f"MDR Plot init error: {e}") + return + + def _draw_num(pos: MDRPoint, num: int, font: FreeTypeFont, color: _MDR_RGBA): + try: + x, y = pos + txt = str(num) + txt_pos = (round(x) + 3, round(y) + 1) + bbox = draw.textbbox(txt_pos, txt, font=font) + bg_rect = (bbox[0] - 2, bbox[1] - 1, bbox[2] + 2, bbox[3] + 1) + bg_color = (color[0], color[1], color[2], 180) + draw.rectangle(bg_rect, fill=bg_color) + draw.text(txt_pos, txt, font=font, fill=(255, 255, 255, 255)) + except Exception as e: + print(f"MDR Draw num error: {e}") + + for i, l in enumerate(layouts): + try: + l_color = _MDR_LAYOUT_COLORS.get(l.cls, _MDR_DEFAULT_COLOR) + draw.polygon([p for p in l.rect], outline=l_color, width=3) + _draw_num(l.rect.lt, i + 1, l_font, l_color) + except Exception as e: + print(f"MDR Layout draw error: {e}") + for l in layouts: + for f in l.fragments: + try: + draw.polygon([p for p in f.rect], outline=_MDR_FRAG_COLOR, width=1) + except Exception as e: + print(f"MDR Fragment draw error: {e}") - for i, l in enumerate(layouts): - try: - l_color = _MDR_LAYOUT_COLORS.get(l.cls, _MDR_DEFAULT_COLOR) - draw.polygon([p for p in l.rect], outline=l_color, width=3) - _draw_num(l.rect.lt, i + 1, l_font, l_color) - except Exception as e: - print(f"MDR Layout draw error: {e}") - for l in layouts: - for f in l.fragments: - try: - draw.polygon([p for p in f.rect], outline=_MDR_FRAG_COLOR, width=1) - except Exception as e: - print(f"MDR Fragment draw error: {e}") # --- MDR Extraction Engine --- class MDRExtractionEngine: - """Core engine for extracting structured information from a document image.""" - - def __init__(self, model_dir_path: str, device: Literal["cpu", "cuda"]="cpu", ocr_for_each_layouts: bool=True, extract_formula: bool=True, extract_table_format: MDRTableLayoutParsedFormat|None=None): - self._model_dir = model_dir_path # Base directory for all models - self._device = device if torch.cuda.is_available() else "cpu" - self._ocr_each = ocr_for_each_layouts; self._ext_formula = extract_formula; self._ext_table = extract_table_format - self._yolo: YOLOv10 | None = None - # Initialize sub-modules, passing the main model_dir_path - self._ocr_engine = MDROcrEngine(device=self._device, model_dir_path=os.path.join(self._model_dir, "onnx_ocr")) - self._table_parser = MDRTableParser(device=self._device, model_path=os.path.join(self._model_dir, "struct_eqtable")) - self._latex_extractor = MDRLatexExtractor(model_path=os.path.join(self._model_dir, "latex")) - self._layout_reader = MDRLayoutReader(model_path=os.path.join(self._model_dir, "layoutreader")) - print(f"MDR Extraction Engine initialized on device: {self._device}") - - # --- MODIFIED _get_yolo_model METHOD for HF --- - def _get_yolo_model(self) -> YOLOv10 | None: - """Loads the YOLOv10 layout detection model using hf_hub_download.""" - if self._yolo is None and YOLOv10 is not None: - repo_id = "juliozhao/DocLayout-YOLO-DocStructBench" - filename = "doclayout_yolo_docstructbench_imgsz1024.pt" - # Use a subdirectory within the main model dir for YOLO cache via HF Hub - yolo_cache_dir = Path(self._model_dir) / "yolo_hf_cache" - mdr_ensure_directory(str(yolo_cache_dir)) # Ensure cache dir exists - - print(f"Attempting to load YOLO model '{filename}' from repo '{repo_id}'...") - print(f"Hugging Face Hub cache directory for YOLO: {yolo_cache_dir}") + """Core engine for extracting structured information from a document image.""" + + def __init__(self, model_dir_path: str, device: Literal["cpu", "cuda"] = "cpu", ocr_for_each_layouts: bool = True, + extract_formula: bool = True, extract_table_format: MDRTableLayoutParsedFormat | None = None): + self._model_dir = model_dir_path # Base directory for all models + self._device = device if torch.cuda.is_available() else "cpu" + self._ocr_each = ocr_for_each_layouts; + self._ext_formula = extract_formula; + self._ext_table = extract_table_format + self._yolo: YOLOv10 | None = None + # Initialize sub-modules, passing the main model_dir_path + self._ocr_engine = MDROcrEngine(device=self._device, model_dir_path=os.path.join(self._model_dir, "onnx_ocr")) + self._table_parser = MDRTableParser(device=self._device, + model_path=os.path.join(self._model_dir, "struct_eqtable")) + self._latex_extractor = MDRLatexExtractor(model_path=os.path.join(self._model_dir, "latex")) + self._layout_reader = MDRLayoutReader(model_path=os.path.join(self._model_dir, "layoutreader")) + print(f"MDR Extraction Engine initialized on device: {self._device}") + + # --- MODIFIED _get_yolo_model METHOD for HF --- + def _get_yolo_model(self) -> YOLOv10 | None: + """Loads the YOLOv10 layout detection model using hf_hub_download.""" + if self._yolo is None and YOLOv10 is not None: + repo_id = "juliozhao/DocLayout-YOLO-DocStructBench" + filename = "doclayout_yolo_docstructbench_imgsz1024.pt" + # Use a subdirectory within the main model dir for YOLO cache via HF Hub + yolo_cache_dir = Path(self._model_dir) / "yolo_hf_cache" + mdr_ensure_directory(str(yolo_cache_dir)) # Ensure cache dir exists + + print(f"Attempting to load YOLO model '{filename}' from repo '{repo_id}'...") + print(f"Hugging Face Hub cache directory for YOLO: {yolo_cache_dir}") - try: - # Download the model file using huggingface_hub, caching it - yolo_model_filepath = hf_hub_download( - repo_id=repo_id, - filename=filename, - cache_dir=yolo_cache_dir, # Cache within our designated structure - local_files_only=False, # Allow download - force_download=False, # Use cache if available + try: + # Download the model file using huggingface_hub, caching it + yolo_model_filepath = hf_hub_download( + repo_id=repo_id, + filename=filename, + cache_dir=yolo_cache_dir, # Cache within our designated structure + local_files_only=False, # Allow download + force_download=False, # Use cache if available + ) + print(f"YOLO model file path: {yolo_model_filepath}") + + # Load the model using the downloaded file path + self._yolo = YOLOv10(yolo_model_filepath) + print("MDR YOLOv10 model loaded successfully.") + + # --- MODIFIED EXCEPTION HANDLING --- + except HfHubHTTPError as e: # <-- CHANGED THIS LINE + print( + f"ERROR: Failed to download/access YOLO model via Hugging Face Hub: {e}") # Slightly updated message + self._yolo = None + except FileNotFoundError as e: # Catch if hf_hub_download fails finding file OR YOLOv10 constructor fails + print(f"ERROR: YOLO model file not found or failed to load locally: {e}") # Slightly updated message + self._yolo = None + except Exception as e: + # Keep the general exception catch, but make the message more specific + print( + f"ERROR: An unexpected issue occurred loading YOLOv10 model from {yolo_cache_dir}/{filename}: {e}") + self._yolo = None + + elif YOLOv10 is None: + print("MDR YOLOv10 class not available. Layout detection skipped.") + + return self._yolo + + def analyze_image(self, image: Image, adjust_points: bool = False) -> MDRExtractionResult: + """Analyzes a single page image to extract layout and content.""" + print(" Engine: Analyzing image...") + optimizer = MDRImageOptimizer(image, adjust_points) + print(" Engine: Initial OCR...") + frags = list(self._ocr_engine.find_text_fragments(optimizer.image_np)) + print(f" Engine: {len(frags)} fragments found.") + optimizer.receive_fragments(frags) + frags = optimizer._fragments # Use adjusted fragments + print(" Engine: Layout detection...") + yolo = self._get_yolo_model() + raw_layouts = [] + if yolo: + try: + raw_layouts = list(self._run_yolo_detection(optimizer.image, yolo)) + print(f" Engine: {len(raw_layouts)} raw layouts found.") + except Exception: + import traceback, sys + traceback.print_exc(file=sys.stderr) + print(" Engine: Matching fragments...") + layouts = self._match_fragments_to_layouts(frags, raw_layouts) + if not layouts and frags: + # treat the whole page as one plain-text layout + page_rect = MDRRectangle( + lt=(0, 0), rt=(optimizer.image.width, 0), + lb=(0, optimizer.image.height), rb=(optimizer.image.width, optimizer.image.height) ) - print(f"YOLO model file path: {yolo_model_filepath}") - - # Load the model using the downloaded file path - self._yolo = YOLOv10(yolo_model_filepath) - print("MDR YOLOv10 model loaded successfully.") - - # --- MODIFIED EXCEPTION HANDLING --- - except HfHubHTTPError as e: # <-- CHANGED THIS LINE - print(f"ERROR: Failed to download/access YOLO model via Hugging Face Hub: {e}") # Slightly updated message - self._yolo = None - except FileNotFoundError as e: # Catch if hf_hub_download fails finding file OR YOLOv10 constructor fails - print(f"ERROR: YOLO model file not found or failed to load locally: {e}") # Slightly updated message - self._yolo = None - except Exception as e: - # Keep the general exception catch, but make the message more specific - print(f"ERROR: An unexpected issue occurred loading YOLOv10 model from {yolo_cache_dir}/{filename}: {e}") - self._yolo = None - - elif YOLOv10 is None: - print("MDR YOLOv10 class not available. Layout detection skipped.") - - return self._yolo - - def analyze_image(self, image: Image, adjust_points: bool=False) -> MDRExtractionResult: - """Analyzes a single page image to extract layout and content.""" - print(" Engine: Analyzing image...") - optimizer = MDRImageOptimizer(image, adjust_points) - print(" Engine: Initial OCR...") - frags = list(self._ocr_engine.find_text_fragments(optimizer.image_np)) - print(f" Engine: {len(frags)} fragments found.") - optimizer.receive_fragments(frags) - frags = optimizer._fragments # Use adjusted fragments - print(" Engine: Layout detection...") - yolo = self._get_yolo_model() - raw_layouts = [] - if yolo: - try: - raw_layouts = list(self._run_yolo_detection(optimizer.image, yolo)) - print(f" Engine: {len(raw_layouts)} raw layouts found.") - except Exception: - import traceback, sys - traceback.print_exc(file=sys.stderr) - print(" Engine: Matching fragments...") - layouts = self._match_fragments_to_layouts(frags, raw_layouts) - if not layouts and frags: - # treat the whole page as one plain-text layout - page_rect = MDRRectangle( - lt=(0, 0), rt=(optimizer.image.width, 0), - lb=(0, optimizer.image.height), rb=(optimizer.image.width, optimizer.image.height) - ) - dummy = MDRPlainLayoutElement( - cls=MDRLayoutClass.PLAIN_TEXT, rect=page_rect, fragments=frags.copy() - ) - layouts.append(dummy) - print(" Engine: Removing overlaps...") - layouts = mdr_remove_overlap_layouts(layouts) - print(f" Engine: {len(layouts)} layouts after overlap removal.") - if self._ocr_each and layouts: - print(" Engine: OCR correction...") - self._run_ocr_correction(optimizer.image, layouts) - print(" Engine: Determining reading order...") - layouts = self._layout_reader.determine_reading_order(layouts, optimizer.image.size) - layouts = [l for l in layouts if self._should_keep_layout(l)] - print(f" Engine: {len(layouts)} layouts after filtering.") - if self._ext_table or self._ext_formula: - print(" Engine: Parsing tables/formulas...") - self._parse_special_layouts(layouts, optimizer) - print(" Engine: Merging fragments...") - [setattr(l, 'fragments', mdr_merge_fragments_into_lines(l.fragments)) for l in layouts] - print(" Engine: Finalizing coords...") - optimizer.finalize_layout_coords(layouts) - print(" Engine: Analysis complete.") - return MDRExtractionResult(rotation=optimizer.rotation, layouts=layouts, extracted_image=image, adjusted_image=optimizer.adjusted_image) - - def _run_yolo_detection(self, img: Image, yolo: YOLOv10): - img_rgb = img.convert("RGB") - res = yolo.predict(source=img_rgb, imgsz=1024, conf=0.20, - device=self._device, verbose=False) - - if not res or not res[0].boxes: - return - - plain_classes: set[MDRLayoutClass] = { - MDRLayoutClass.TITLE, - MDRLayoutClass.PLAIN_TEXT, - MDRLayoutClass.ABANDON, - MDRLayoutClass.FIGURE_CAPTION, - MDRLayoutClass.TABLE_CAPTION, - MDRLayoutClass.TABLE_FOOTNOTE, - MDRLayoutClass.FORMULA_CAPTION, - } - - for cls_id_t, xyxy_t in zip(res[0].boxes.cls, res[0].boxes.xyxy): - cls = MDRLayoutClass(int(cls_id_t)) - x1, y1, x2, y2 = map(float, xyxy_t) - rect = MDRRectangle((x1, y1), (x2, y1), (x1, y2), (x2, y2)) - if rect.area < 10: - continue - - if cls == MDRLayoutClass.TABLE: - yield MDRTableLayoutElement(rect=rect, fragments=[], parsed=None) - elif cls == MDRLayoutClass.ISOLATE_FORMULA: - yield MDRFormulaLayoutElement(rect=rect, fragments=[], latex=None) - elif cls in plain_classes: - yield MDRPlainLayoutElement(cls=cls, rect=rect, fragments=[]) - - def _match_fragments_to_layouts(self, frags: list[MDROcrFragment], layouts: list[MDRLayoutElement]) -> list[MDRLayoutElement]: - if not frags or not layouts: - return layouts - layout_polys = [(Polygon(l.rect) if l.rect.is_valid else None) for l in layouts] - for frag in frags: - try: - frag_poly = Polygon(frag.rect) - frag_area = frag_poly.area - except: - continue - if not frag_poly.is_valid or frag_area < 1e-6: - continue - candidates = [] # (layout_idx, layout_area, overlap_ratio) - for idx, l_poly in enumerate(layout_polys): - if l_poly is None: + dummy = MDRPlainLayoutElement( + cls=MDRLayoutClass.PLAIN_TEXT, rect=page_rect, fragments=frags.copy() + ) + layouts.append(dummy) + print(" Engine: Removing overlaps...") + layouts = mdr_remove_overlap_layouts(layouts) + print(f" Engine: {len(layouts)} layouts after overlap removal.") + if self._ocr_each and layouts: + print(" Engine: OCR correction...") + self._run_ocr_correction(optimizer.image, layouts) + print(" Engine: Determining reading order...") + layouts = self._layout_reader.determine_reading_order(layouts, optimizer.image.size) + layouts = [l for l in layouts if self._should_keep_layout(l)] + print(f" Engine: {len(layouts)} layouts after filtering.") + if self._ext_table or self._ext_formula: + print(" Engine: Parsing tables/formulas...") + self._parse_special_layouts(layouts, optimizer) + print(" Engine: Merging fragments...") + [setattr(l, 'fragments', mdr_merge_fragments_into_lines(l.fragments)) for l in layouts] + print(" Engine: Finalizing coords...") + optimizer.finalize_layout_coords(layouts) + print(" Engine: Analysis complete.") + return MDRExtractionResult(rotation=optimizer.rotation, layouts=layouts, extracted_image=image, + adjusted_image=optimizer.adjusted_image) + + def _run_yolo_detection(self, img: Image, yolo: YOLOv10): + img_rgb = img.convert("RGB") + res = yolo.predict(source=img_rgb, imgsz=1024, conf=0.20, + device=self._device, verbose=False) + + if not res or not res[0].boxes: + return + + plain_classes: set[MDRLayoutClass] = { + MDRLayoutClass.TITLE, + MDRLayoutClass.PLAIN_TEXT, + MDRLayoutClass.ABANDON, + MDRLayoutClass.FIGURE_CAPTION, + MDRLayoutClass.TABLE_CAPTION, + MDRLayoutClass.TABLE_FOOTNOTE, + MDRLayoutClass.FORMULA_CAPTION, + } + + for cls_id_t, xyxy_t in zip(res[0].boxes.cls, res[0].boxes.xyxy): + cls = MDRLayoutClass(int(cls_id_t)) + x1, y1, x2, y2 = map(float, xyxy_t) + rect = MDRRectangle((x1, y1), (x2, y1), (x1, y2), (x2, y2)) + if rect.area < 10: continue + + if cls == MDRLayoutClass.TABLE: + yield MDRTableLayoutElement(rect=rect, fragments=[], parsed=None) + elif cls == MDRLayoutClass.ISOLATE_FORMULA: + yield MDRFormulaLayoutElement(rect=rect, fragments=[], latex=None) + elif cls in plain_classes: + yield MDRPlainLayoutElement(cls=cls, rect=rect, fragments=[]) + + def _match_fragments_to_layouts(self, frags: list[MDROcrFragment], layouts: list[MDRLayoutElement]) -> list[ + MDRLayoutElement]: + if not frags or not layouts: + return layouts + layout_polys = [(Polygon(l.rect) if l.rect.is_valid else None) for l in layouts] + for frag in frags: try: - inter_area = frag_poly.intersection(l_poly).area + frag_poly = Polygon(frag.rect) + frag_area = frag_poly.area except: continue - overlap = inter_area / frag_area if frag_area > 0 else 0 - if overlap > 0.85: - candidates.append((idx, l_poly.area, overlap)) - if candidates: - candidates.sort(key=lambda x: (x[1], -x[2])) - best_idx = candidates[0][0] - layouts[best_idx].fragments.append(frag) - for l in layouts: - l.fragments.sort(key=lambda f: (f.rect.lt[1], f.rect.lt[0])) - return layouts - - def _run_ocr_correction(self, img: Image, layouts: list[MDRLayoutElement]): - for i, l in enumerate(layouts): - if l.cls == MDRLayoutClass.FIGURE: continue - try: mdr_correct_layout_fragments(self._ocr_engine, img, l) - except Exception as e: print(f" Engine: OCR correction error layout {i}: {e}") + if not frag_poly.is_valid or frag_area < 1e-6: + continue + candidates = [] # (layout_idx, layout_area, overlap_ratio) + for idx, l_poly in enumerate(layout_polys): + if l_poly is None: + continue + try: + inter_area = frag_poly.intersection(l_poly).area + except: + continue + overlap = inter_area / frag_area if frag_area > 0 else 0 + if overlap > 0.85: + candidates.append((idx, l_poly.area, overlap)) + if candidates: + candidates.sort(key=lambda x: (x[1], -x[2])) + best_idx = candidates[0][0] + layouts[best_idx].fragments.append(frag) + for l in layouts: + l.fragments.sort(key=lambda f: (f.rect.lt[1], f.rect.lt[0])) + return layouts - def _parse_special_layouts(self, layouts: list[MDRLayoutElement], optimizer: MDRImageOptimizer): - img_to_clip = optimizer.image - for l in layouts: - if isinstance(l, MDRFormulaLayoutElement) and self._ext_formula: - try: - f_img = mdr_clip_from_image(img_to_clip, l.rect) - l.latex = self._latex_extractor.extract(f_img) if f_img.width > 1 and f_img.height > 1 else None - except Exception as e: - print(f" Engine: LaTeX extract error: {e}") - elif isinstance(l, MDRTableLayoutElement) and self._ext_table is not None: + def _run_ocr_correction(self, img: Image, layouts: list[MDRLayoutElement]): + for i, l in enumerate(layouts): + if l.cls == MDRLayoutClass.FIGURE: continue try: - t_img = mdr_clip_from_image(img_to_clip, l.rect) - parsed = self._table_parser.parse_table_image(t_img, self._ext_table) if t_img.width > 1 and t_img.height > 1 else None + mdr_correct_layout_fragments(self._ocr_engine, img, l) except Exception as e: - print(f" Engine: Table parse error: {e}") - parsed = None - if parsed: - l.parsed = (parsed, self._ext_table) + print(f" Engine: OCR correction error layout {i}: {e}") + + def _parse_special_layouts(self, layouts: list[MDRLayoutElement], optimizer: MDRImageOptimizer): + img_to_clip = optimizer.image + for l in layouts: + if isinstance(l, MDRFormulaLayoutElement) and self._ext_formula: + try: + f_img = mdr_clip_from_image(img_to_clip, l.rect) + l.latex = self._latex_extractor.extract(f_img) if f_img.width > 1 and f_img.height > 1 else None + except Exception as e: + print(f" Engine: LaTeX extract error: {e}") + elif isinstance(l, MDRTableLayoutElement) and self._ext_table is not None: + try: + t_img = mdr_clip_from_image(img_to_clip, l.rect) + parsed = self._table_parser.parse_table_image(t_img, + self._ext_table) if t_img.width > 1 and t_img.height > 1 else None + except Exception as e: + print(f" Engine: Table parse error: {e}") + parsed = None + if parsed: + l.parsed = (parsed, self._ext_table) + + def _should_keep_layout(self, l: MDRLayoutElement) -> bool: + if l.fragments and not all(mdr_is_whitespace(f.text) for f in l.fragments): return True + return l.cls in [MDRLayoutClass.FIGURE, MDRLayoutClass.TABLE, MDRLayoutClass.ISOLATE_FORMULA] - def _should_keep_layout(self, l: MDRLayoutElement) -> bool: - if l.fragments and not all(mdr_is_whitespace(f.text) for f in l.fragments): return True - return l.cls in [MDRLayoutClass.FIGURE, MDRLayoutClass.TABLE, MDRLayoutClass.ISOLATE_FORMULA] # --- MDR Page Section Linking --- class _MDR_LinkedShape: - """Internal helper for managing layout linking across pages.""" - - def __init__(self, layout: MDRLayoutElement): self.layout=layout; self.pre:list[MDRLayoutElement|None]=[None,None]; self.nex:list[MDRLayoutElement|None]=[None,None] - - @property - def distance2(self) -> float: x,y=self.layout.rect.lt; return x*x+y*y + """Internal helper for managing layout linking across pages.""" -class MDRPageSection: - """Represents a page's layouts for framework detection via linking.""" + def __init__(self, layout: MDRLayoutElement): self.layout = layout; self.pre: list[MDRLayoutElement | None] = [None, + None]; self.nex: \ + list[MDRLayoutElement | None] = [None, None] - def __init__(self, page_index: int, layouts: Iterable[MDRLayoutElement]): - self._page_index = page_index; self._shapes = [_MDR_LinkedShape(l) for l in layouts]; self._shapes.sort(key=lambda s: (s.layout.rect.lt[1], s.layout.rect.lt[0])) + @property + def distance2(self) -> float: x, y = self.layout.rect.lt; return x * x + y * y - @property - def page_index(self) -> int: return self._page_index - def find_framework_elements(self) -> list[MDRLayoutElement]: - """Identifies framework layouts based on links to other pages.""" - return [s.layout for s in self._shapes if any(s.pre) or any(s.nex)] +class MDRPageSection: + """Represents a page's layouts for framework detection via linking.""" + + def __init__(self, page_index: int, layouts: Iterable[MDRLayoutElement]): + self._page_index = page_index; + self._shapes = [_MDR_LinkedShape(l) for l in layouts]; + self._shapes.sort(key=lambda s: (s.layout.rect.lt[1], s.layout.rect.lt[0])) + + @property + def page_index(self) -> int: + return self._page_index + + def find_framework_elements(self) -> list[MDRLayoutElement]: + """Identifies framework layouts based on links to other pages.""" + return [s.layout for s in self._shapes if any(s.pre) or any(s.nex)] + + def link_to_next(self, next_section: 'MDRPageSection', offset: int) -> None: + """Links matching shapes between this section and the next.""" + if offset not in (1, 2): + return + matches_matrix = [[sn for sn in next_section._shapes if self._shapes_match(ss, sn)] for ss in self._shapes] + origin_pair = self._find_origin_pair(matches_matrix, next_section._shapes) + if origin_pair is None: + return + orig_s, orig_n = origin_pair + orig_s_pt = orig_s.layout.rect.lt + orig_n_pt = orig_n.layout.rect.lt + for i, s1 in enumerate(self._shapes): + potentials = matches_matrix[i] + if not potentials: + continue + r1_rel = self._relative_rect(orig_s_pt, s1.layout.rect) + best_s2 = None + max_ovr = -1.0 + for s2 in potentials: + r2_rel = self._relative_rect(orig_n_pt, s2.layout.rect) + ovr = self._symmetric_iou(r1_rel, r2_rel) + if ovr > max_ovr: + max_ovr = ovr + best_s2 = s2 + if max_ovr >= 0.80 and best_s2 is not None: + s1.nex[offset - 1] = best_s2.layout + best_s2.pre[offset - 1] = s1.layout # Link both ways + + def _shapes_match(self, s1: _MDR_LinkedShape, s2: _MDR_LinkedShape) -> bool: + l1 = s1.layout + l2 = s2.layout + sz1 = l1.rect.size + sz2 = l2.rect.size + thresh = 0.90 + if mdr_similarity_ratio(sz1[0], sz2[0]) < thresh or mdr_similarity_ratio(sz1[1], sz2[1]) < thresh: + return False + f1 = l1.fragments + f2 = l2.fragments + c1 = len(f1) + c2 = len(f2) + if c1 == 0 and c2 == 0: + return True + if c1 == 0 or c2 == 0: + return False + matches = 0 + used_f2 = [False] * c2 + for frag1 in f1: + best_j = -1 + max_sim = -1.0 + for j, frag2 in enumerate(f2): + if not used_f2[j]: + sim = self._fragment_sim(l1, l2, frag1, frag2) + if sim > max_sim: + max_sim = sim + best_j = j + if max_sim > 0.75: + matches += 1 + if best_j != -1: + used_f2[best_j] = True + max_c = max(c1, c2) + rate_frags = matches / max_c + return self._check_match_threshold(rate_frags, max_c, (0.0, 0.45, 0.45, 0.6, 0.8, 0.95)) + + def _fragment_sim(self, l1: MDRLayoutElement, l2: MDRLayoutElement, f1: MDROcrFragment, + f2: MDROcrFragment) -> float: + r1_rel = self._relative_rect(l1.rect.lt, f1.rect) + r2_rel = self._relative_rect(l2.rect.lt, f2.rect) + geom_sim = self._symmetric_iou(r1_rel, r2_rel) + text_sim, _ = mdr_check_text_similarity(f1.text, f2.text) + return (geom_sim + text_sim) / 2.0 + + def _find_origin_pair(self, matches_matrix: list[list[_MDR_LinkedShape]], next_shapes: list[_MDR_LinkedShape]) -> \ + tuple[_MDR_LinkedShape, _MDR_LinkedShape] | None: + best_pair = None + min_dist2 = float('inf') + for i, s1 in enumerate(self._shapes): + match_list = matches_matrix[i] + if not match_list: + continue + for s2 in match_list: + dist2 = s1.distance2 + s2.distance2 + if dist2 < min_dist2: + min_dist2 = dist2 + best_pair = (s1, s2) + return best_pair + + def _check_match_threshold(self, rate: float, count: int, thresholds: Sequence[float]) -> bool: + if not thresholds: return False; idx = min(count, len(thresholds) - 1); return rate >= thresholds[idx] + + def _relative_rect(self, origin: MDRPoint, rect: MDRRectangle) -> MDRRectangle: + ox, oy = origin + r = rect + return MDRRectangle(lt=(r.lt[0] - ox, r.lt[1] - oy), rt=(r.rt[0] - ox, r.rt[1] - oy), + lb=(r.lb[0] - ox, r.lb[1] - oy), rb=(r.rb[0] - ox, r.rb[1] - oy)) + + def _symmetric_iou(self, r1: MDRRectangle, r2: MDRRectangle) -> float: + try: + p1 = Polygon(r1) + p2 = Polygon(r2) + except: + return 0.0 + if not p1.is_valid or not p2.is_valid: + return 0.0 + try: + inter = p1.intersection(p2) + union = p1.union(p2) + except: + return 0.0 + if inter.is_empty or inter.area < 1e-6: + return 0.0 + union_area = union.area + return inter.area / union_area if union_area > 1e-6 else 1.0 - def link_to_next(self, next_section: 'MDRPageSection', offset: int) -> None: - """Links matching shapes between this section and the next.""" - if offset not in (1, 2): - return - matches_matrix = [[sn for sn in next_section._shapes if self._shapes_match(ss, sn)] for ss in self._shapes] - origin_pair = self._find_origin_pair(matches_matrix, next_section._shapes) - if origin_pair is None: - return - orig_s, orig_n = origin_pair - orig_s_pt = orig_s.layout.rect.lt - orig_n_pt = orig_n.layout.rect.lt - for i, s1 in enumerate(self._shapes): - potentials = matches_matrix[i] - if not potentials: - continue - r1_rel = self._relative_rect(orig_s_pt, s1.layout.rect) - best_s2 = None - max_ovr = -1.0 - for s2 in potentials: - r2_rel = self._relative_rect(orig_n_pt, s2.layout.rect) - ovr = self._symmetric_iou(r1_rel, r2_rel) - if ovr > max_ovr: - max_ovr = ovr - best_s2 = s2 - if max_ovr >= 0.80 and best_s2 is not None: - s1.nex[offset - 1] = best_s2.layout - best_s2.pre[offset - 1] = s1.layout # Link both ways - - def _shapes_match(self, s1: _MDR_LinkedShape, s2: _MDR_LinkedShape) -> bool: - l1 = s1.layout - l2 = s2.layout - sz1 = l1.rect.size - sz2 = l2.rect.size - thresh = 0.90 - if mdr_similarity_ratio(sz1[0], sz2[0]) < thresh or mdr_similarity_ratio(sz1[1], sz2[1]) < thresh: - return False - f1 = l1.fragments - f2 = l2.fragments - c1 = len(f1) - c2 = len(f2) - if c1 == 0 and c2 == 0: - return True - if c1 == 0 or c2 == 0: - return False - matches = 0 - used_f2 = [False] * c2 - for frag1 in f1: - best_j = -1 - max_sim = -1.0 - for j, frag2 in enumerate(f2): - if not used_f2[j]: - sim = self._fragment_sim(l1, l2, frag1, frag2) - if sim > max_sim: - max_sim = sim - best_j = j - if max_sim > 0.75: - matches += 1 - if best_j != -1: - used_f2[best_j] = True - max_c = max(c1, c2) - rate_frags = matches / max_c - return self._check_match_threshold(rate_frags, max_c, (0.0, 0.45, 0.45, 0.6, 0.8, 0.95)) - - def _fragment_sim(self, l1: MDRLayoutElement, l2: MDRLayoutElement, f1: MDROcrFragment, f2: MDROcrFragment) -> float: - r1_rel = self._relative_rect(l1.rect.lt, f1.rect) - r2_rel = self._relative_rect(l2.rect.lt, f2.rect) - geom_sim = self._symmetric_iou(r1_rel, r2_rel) - text_sim, _ = mdr_check_text_similarity(f1.text, f2.text) - return (geom_sim + text_sim) / 2.0 - - def _find_origin_pair(self, matches_matrix: list[list[_MDR_LinkedShape]], next_shapes: list[_MDR_LinkedShape]) -> tuple[_MDR_LinkedShape, _MDR_LinkedShape] | None: - best_pair = None - min_dist2 = float('inf') - for i, s1 in enumerate(self._shapes): - match_list = matches_matrix[i] - if not match_list: - continue - for s2 in match_list: - dist2 = s1.distance2 + s2.distance2 - if dist2 < min_dist2: - min_dist2 = dist2 - best_pair = (s1, s2) - return best_pair - - def _check_match_threshold(self, rate: float, count: int, thresholds: Sequence[float]) -> bool: - if not thresholds: return False; idx = min(count, len(thresholds)-1); return rate >= thresholds[idx] - - def _relative_rect(self, origin: MDRPoint, rect: MDRRectangle) -> MDRRectangle: - ox, oy = origin - r = rect - return MDRRectangle(lt=(r.lt[0] - ox, r.lt[1] - oy), rt=(r.rt[0] - ox, r.rt[1] - oy), lb=(r.lb[0] - ox, r.lb[1] - oy), rb=(r.rb[0] - ox, r.rb[1] - oy)) - - def _symmetric_iou(self, r1: MDRRectangle, r2: MDRRectangle) -> float: - try: - p1 = Polygon(r1) - p2 = Polygon(r2) - except: - return 0.0 - if not p1.is_valid or not p2.is_valid: - return 0.0 - try: - inter = p1.intersection(p2) - union = p1.union(p2) - except: - return 0.0 - if inter.is_empty or inter.area < 1e-6: - return 0.0 - union_area = union.area - return inter.area / union_area if union_area > 1e-6 else 1.0 # --- MDR Document Iterator --- -_MDR_CONTEXT_PAGES = 2 # Look behind/ahead pages for context +_MDR_CONTEXT_PAGES = 2 # Look behind/ahead pages for context + @dataclass class MDRProcessingParams: - """Parameters for processing a document.""" - pdf: str | FitzDocument; page_indexes: Iterable[int] | None; report_progress: MDRProgressReportCallback | None + """Parameters for processing a document.""" + pdf: str | FitzDocument; + page_indexes: Iterable[int] | None; + report_progress: MDRProgressReportCallback | None + class MDRDocumentIterator: - """Iterates through document pages, handles context, and calls the extraction engine.""" - - def __init__(self, device: Literal["cpu", "cuda"], model_dir_path: str, ocr_level: MDROcrLevel, extract_formula: bool, extract_table_format: MDRTableLayoutParsedFormat | None, debug_dir_path: str | None): - self._debug_dir = debug_dir_path - self._engine = MDRExtractionEngine(device=device, model_dir_path=model_dir_path, ocr_for_each_layouts=(ocr_level==MDROcrLevel.OncePerLayout), extract_formula=extract_formula, extract_table_format=extract_table_format) - - def iterate_sections(self, params: MDRProcessingParams) -> Generator[tuple[int, MDRExtractionResult, list[MDRLayoutElement]], None, None]: - """Yields page index, extraction result, and content layouts for each requested page.""" - for res, sec in self._process_and_link_sections(params): - framework = set(sec.find_framework_elements()); content = [l for l in res.layouts if l not in framework]; yield sec.page_index, res, content - - def _process_and_link_sections(self, params: MDRProcessingParams) -> Generator[tuple[MDRExtractionResult, MDRPageSection], None, None]: - queue: list[tuple[MDRExtractionResult, MDRPageSection]] = [] - for page_idx, res in self._run_extraction_on_pages(params): - cur_sec = MDRPageSection(page_idx, res.layouts) - for i, (_, prev_sec) in enumerate(queue): - offset = len(queue) - i - if offset <= _MDR_CONTEXT_PAGES: - prev_sec.link_to_next(cur_sec, offset) - queue.append((res, cur_sec)) - if len(queue) > _MDR_CONTEXT_PAGES: - yield queue.pop(0) - for res, sec in queue: - yield res, sec - - def _run_extraction_on_pages(self, params: MDRProcessingParams) -> Generator[tuple[int, MDRExtractionResult], None, None]: - if self._debug_dir: mdr_ensure_directory(self._debug_dir) - doc, should_close = None, False - if isinstance(params.pdf, str): - try: doc = fitz.open(params.pdf); should_close = True - except Exception as e: print(f"ERROR: PDF open failed: {e}"); return - elif isinstance(params.pdf, FitzDocument): doc = params.pdf - else: print(f"ERROR: Invalid PDF type: {type(params.pdf)}"); return - scan_idxs, enable_idxs = self._get_page_ranges(doc, params.page_indexes) - enable_set = set(enable_idxs); total_scan = len(scan_idxs) - try: - for i, page_idx in enumerate(scan_idxs): - print(f" Iterator: Processing page {page_idx+1}/{doc.page_count} (Scan {i+1}/{total_scan})...") + """Iterates through document pages, handles context, and calls the extraction engine.""" + + def __init__(self, device: Literal["cpu", "cuda"], model_dir_path: str, ocr_level: MDROcrLevel, + extract_formula: bool, extract_table_format: MDRTableLayoutParsedFormat | None, + debug_dir_path: str | None): + self._debug_dir = debug_dir_path + self._engine = MDRExtractionEngine(device=device, model_dir_path=model_dir_path, + ocr_for_each_layouts=(ocr_level == MDROcrLevel.OncePerLayout), + extract_formula=extract_formula, extract_table_format=extract_table_format) + + def iterate_sections(self, params: MDRProcessingParams) -> Generator[ + tuple[int, MDRExtractionResult, list[MDRLayoutElement]], None, None]: + """Yields page index, extraction result, and content layouts for each requested page.""" + for res, sec in self._process_and_link_sections(params): + framework = set(sec.find_framework_elements()); + content = [l for l in res.layouts if l not in framework]; + yield sec.page_index, res, content + + def _process_and_link_sections(self, params: MDRProcessingParams) -> Generator[ + tuple[MDRExtractionResult, MDRPageSection], None, None]: + queue: list[tuple[MDRExtractionResult, MDRPageSection]] = [] + for page_idx, res in self._run_extraction_on_pages(params): + cur_sec = MDRPageSection(page_idx, res.layouts) + for i, (_, prev_sec) in enumerate(queue): + offset = len(queue) - i + if offset <= _MDR_CONTEXT_PAGES: + prev_sec.link_to_next(cur_sec, offset) + queue.append((res, cur_sec)) + if len(queue) > _MDR_CONTEXT_PAGES: + yield queue.pop(0) + for res, sec in queue: + yield res, sec + + def _run_extraction_on_pages(self, params: MDRProcessingParams) -> Generator[ + tuple[int, MDRExtractionResult], None, None]: + if self._debug_dir: mdr_ensure_directory(self._debug_dir) + doc, should_close = None, False + if isinstance(params.pdf, str): + try: + doc = fitz.open(params.pdf); should_close = True + except Exception as e: + print(f"ERROR: PDF open failed: {e}"); return + elif isinstance(params.pdf, FitzDocument): + doc = params.pdf + else: + print(f"ERROR: Invalid PDF type: {type(params.pdf)}"); return + scan_idxs, enable_idxs = self._get_page_ranges(doc, params.page_indexes) + enable_set = set(enable_idxs); + total_scan = len(scan_idxs) + try: + for i, page_idx in enumerate(scan_idxs): + print(f" Iterator: Processing page {page_idx + 1}/{doc.page_count} (Scan {i + 1}/{total_scan})...") + try: + page = doc.load_page(page_idx) + img = self._render_page_image(page, 300) + res = self._engine.analyze_image(image=img, adjust_points=False) # Engine analyzes image + if self._debug_dir: + self._save_debug_plot(img, page_idx, res, self._debug_dir) + if page_idx in enable_set: + yield page_idx, res # Yield result for requested pages + if params.report_progress: + params.report_progress(i + 1, total_scan) + except Exception as e: + print(f" Iterator: Page {page_idx + 1} processing error: {e}") + finally: + if should_close and doc: doc.close() + + def _get_page_ranges(self, doc: FitzDocument, idxs: Iterable[int] | None) -> tuple[Sequence[int], Sequence[int]]: + count = doc.page_count + if idxs is None: + all_p = list(range(count)) + return all_p, all_p + enable = set() + scan = set() + for i in idxs: + if 0 <= i < count: + enable.add(i) + [scan.add(j) for j in range(max(0, i - _MDR_CONTEXT_PAGES), min(count, i + _MDR_CONTEXT_PAGES + 1))] + return sorted(list(scan)), sorted(list(enable)) + + def _render_page_image(self, page: FitzPage, dpi: int) -> Image: + mat = FitzMatrix(dpi / 72.0, dpi / 72.0) + pix = page.get_pixmap(matrix=mat, alpha=False) + return frombytes("RGB", (pix.width, pix.height), pix.samples) + + def _save_debug_plot(self, img: Image, idx: int, res: MDRExtractionResult, path: str): try: - page = doc.load_page(page_idx) - img = self._render_page_image(page, 300) - res = self._engine.analyze_image(image=img, adjust_points=False) # Engine analyzes image - if self._debug_dir: - self._save_debug_plot(img, page_idx, res, self._debug_dir) - if page_idx in enable_set: - yield page_idx, res # Yield result for requested pages - if params.report_progress: - params.report_progress(i + 1, total_scan) + plot_img = res.adjusted_image.copy() if res.adjusted_image else img.copy() + mdr_plot_layout(plot_img, res.layouts) + plot_img.save(os.path.join(path, f"mdr_plot_page_{idx + 1}.png")) except Exception as e: - print(f" Iterator: Page {page_idx + 1} processing error: {e}") - finally: - if should_close and doc: doc.close() - - def _get_page_ranges(self, doc: FitzDocument, idxs: Iterable[int]|None) -> tuple[Sequence[int], Sequence[int]]: - count = doc.page_count - if idxs is None: - all_p = list(range(count)) - return all_p, all_p - enable = set() - scan = set() - for i in idxs: - if 0 <= i < count: - enable.add(i) - [scan.add(j) for j in range(max(0, i - _MDR_CONTEXT_PAGES), min(count, i + _MDR_CONTEXT_PAGES + 1))] - return sorted(list(scan)), sorted(list(enable)) - - def _render_page_image(self, page: FitzPage, dpi: int) -> Image: - mat = FitzMatrix(dpi / 72.0, dpi / 72.0) - pix = page.get_pixmap(matrix=mat, alpha=False) - return frombytes("RGB", (pix.width, pix.height), pix.samples) - - def _save_debug_plot(self, img: Image, idx: int, res: MDRExtractionResult, path: str): - try: - plot_img = res.adjusted_image.copy() if res.adjusted_image else img.copy() - mdr_plot_layout(plot_img, res.layouts) - plot_img.save(os.path.join(path, f"mdr_plot_page_{idx + 1}.png")) - except Exception as e: - print(f" Iterator: Plot generation error page {idx + 1}: {e}") + print(f" Iterator: Plot generation error page {idx + 1}: {e}") + # --- MagicDataReadiness Main Processor --- class MagicPDFProcessor: - """ - Main class for processing PDF documents to extract structured data blocks - using the MagicDataReadiness pipeline. - """ - - def __init__(self, device: Literal["cpu", "cuda"]="cuda", model_dir_path: str="./mdr_models", ocr_level: MDROcrLevel=MDROcrLevel.Once, extract_formula: bool=True, extract_table_format: MDRExtractedTableFormat|None=None, debug_dir_path: str|None=None): - """ - Initializes the MagicPDFProcessor. - Args: - device: Computation device ('cpu' or 'cuda'). Defaults to 'cuda'. Fallbacks to 'cpu' if CUDA not available. - model_dir_path: Path to directory for storing/caching downloaded models. Defaults to './mdr_models'. - ocr_level: Level of OCR application (Once per page or Once per layout). Defaults to Once per page. - extract_formula: Whether to attempt LaTeX extraction from formula images. Defaults to True. - extract_table_format: Desired format for extracted table content (LATEX, MARKDOWN, HTML, DISABLE, or None). - Defaults to LATEX if CUDA is available, otherwise DISABLE. - debug_dir_path: Optional path to save debug plots and intermediate files. Defaults to None (disabled). """ - actual_dev = device if torch.cuda.is_available() else "cpu"; print(f"MagicPDFProcessor using device: {actual_dev}.") - if extract_table_format is None: extract_table_format = MDRExtractedTableFormat.LATEX if actual_dev=="cuda" else MDRExtractedTableFormat.DISABLE - table_fmt_internal: MDRTableLayoutParsedFormat|None = None - if extract_table_format==MDRExtractedTableFormat.LATEX: table_fmt_internal=MDRTableLayoutParsedFormat.LATEX - elif extract_table_format==MDRExtractedTableFormat.MARKDOWN: table_fmt_internal=MDRTableLayoutParsedFormat.MARKDOWN - elif extract_table_format==MDRExtractedTableFormat.HTML: table_fmt_internal=MDRTableLayoutParsedFormat.HTML - self._iterator = MDRDocumentIterator(device=actual_dev, model_dir_path=model_dir_path, ocr_level=ocr_level, extract_formula=extract_formula, extract_table_format=table_fmt_internal, debug_dir_path=debug_dir_path) - print("MagicPDFProcessor initialized.") - - def process_document(self, pdf_input: str|FitzDocument, report_progress: MDRProgressReportCallback|None=None) -> Generator[MDRStructuredBlock, None, None]: + Main class for processing PDF documents to extract structured data blocks + using the MagicDataReadiness pipeline. """ - Processes the entire PDF document and yields all extracted structured blocks. - Args: - pdf_input: Path to the PDF file or a loaded fitz.Document object. - report_progress: Optional callback function for progress updates (receives completed_scan_pages, total_scan_pages). - Yields: - MDRStructuredBlock: An extracted block (MDRTextBlock, MDRTableBlock, etc.). - """ - print(f"Processing document: {pdf_input if isinstance(pdf_input, str) else 'FitzDocument object'}") - for _, blocks, _ in self.process_document_pages(pdf_input=pdf_input, report_progress=report_progress, page_indexes=None): - yield from blocks - print("Document processing complete.") - def process_document_pages(self, pdf_input: str|FitzDocument, page_indexes: Iterable[int]|None=None, report_progress: MDRProgressReportCallback|None=None) -> Generator[tuple[int, list[MDRStructuredBlock], Image], None, None]: - """ - Processes specific pages (or all if page_indexes is None) of the PDF document. - Yields results page by page, including the page index, extracted blocks, and the original page image. - Args: - pdf_input: Path to the PDF file or a loaded fitz.Document object. - page_indexes: An iterable of 0-based page indices to process. If None, processes all pages. - report_progress: Optional callback function for progress updates. - Yields: - tuple[int, list[MDRStructuredBlock], Image]: - - page_index (0-based) - - list of extracted MDRStructuredBlock objects for that page - - PIL Image object of the original rendered page - """ - params = MDRProcessingParams(pdf=pdf_input, page_indexes=page_indexes, report_progress=report_progress) - page_count = 0 - for page_idx, extraction_result, content_layouts in self._iterator.iterate_sections(params): - page_count += 1 - print(f"Processor: Converting layouts to blocks for page {page_idx+1}...") - blocks = self._create_structured_blocks(extraction_result, content_layouts) - print(f"Processor: Analyzing paragraph structure for page {page_idx+1}...") - self._analyze_paragraph_structure(blocks) - print(f"Processor: Yielding results for page {page_idx+1}.") - yield page_idx, blocks, extraction_result.extracted_image # Yield original image - print(f"Processor: Finished processing {page_count} pages.") - - def _create_structured_blocks(self, result: MDRExtractionResult, layouts: list[MDRLayoutElement]) -> list[MDRStructuredBlock]: - """Converts MDRLayoutElement objects into MDRStructuredBlock objects.""" - temp_store: list[tuple[MDRLayoutElement, MDRStructuredBlock]] = [] - for layout in layouts: - if isinstance(layout, MDRPlainLayoutElement): self._add_plain_block(temp_store, layout, result) - elif isinstance(layout, MDRTableLayoutElement): temp_store.append((layout, self._create_table_block(layout, result))) - elif isinstance(layout, MDRFormulaLayoutElement): temp_store.append((layout, self._create_formula_block(layout, result))) - self._assign_relative_font_sizes(temp_store) - return [block for _, block in temp_store] - - def _analyze_paragraph_structure(self, blocks: list[MDRStructuredBlock]): - """ - Calculates indentation and line-end heuristics for MDRTextBlocks - based on page-level text boundaries and average line height. - """ - # Define constants for clarity and maintainability - MIN_VALID_HEIGHT = 1e-6 - # Heuristic: Indent if first line starts more than 1.0 * avg line height from page text start - INDENTATION_THRESHOLD_FACTOR = 1.0 - # Heuristic: Last line touches end if it ends less than 1.0 * avg line height from page text end - LINE_END_THRESHOLD_FACTOR = 1.0 - - # Calculate average line height and text boundaries for the relevant text blocks on the page - page_avg_line_height, page_min_x, page_max_x = self._calculate_text_range( - (b for b in blocks if isinstance(b, MDRTextBlock) and b.kind != MDRTextKind.ABANDON) - ) - - # Avoid calculations if page metrics are invalid (e.g., no text, zero height) - if page_avg_line_height <= MIN_VALID_HEIGHT: - return - - # Iterate through each block to determine its paragraph properties - for block in blocks: - # Process only valid text blocks with actual text content - if not isinstance(block, MDRTextBlock) or block.kind == MDRTextKind.ABANDON or not block.texts: - continue - - # Use calculated page-level metrics for consistency in thresholds - avg_line_height = page_avg_line_height - page_text_start_x = page_min_x - page_text_end_x = page_max_x - - # Get the first and last text spans (assumed to be lines after merging) within the block - first_text_span = block.texts[0] - last_text_span = block.texts[-1] - - try: - # --- Calculate Indentation --- - # Estimate the starting x-coordinate of the first line (average of left top/bottom) - first_line_start_x = (first_text_span.rect.lt[0] + first_text_span.rect.lb[0]) / 2.0 - # Calculate the difference between the first line's start and the page's text start boundary - indentation_delta = first_line_start_x - page_text_start_x - # Determine indentation based on the heuristic threshold relative to average line height - block.has_paragraph_indentation = indentation_delta > (avg_line_height * INDENTATION_THRESHOLD_FACTOR) - - # --- Calculate Last Line End --- - # Estimate the ending x-coordinate of the last line (average of right top/bottom) - last_line_end_x = (last_text_span.rect.rt[0] + last_text_span.rect.rb[0]) / 2.0 - # Calculate the difference between the page's text end boundary and the last line's end - line_end_delta = page_text_end_x - last_line_end_x - # Determine if the last line reaches near the end based on the heuristic threshold - block.last_line_touch_end = line_end_delta < (avg_line_height * LINE_END_THRESHOLD_FACTOR) - - except Exception as e: - # Handle potential errors during calculation (e.g., invalid rect data) - print(f"Warn: Error calculating paragraph structure for block: {e}") - # Default to False if calculation fails to ensure attributes are set - block.has_paragraph_indentation = False - block.last_line_touch_end = False - - def _calculate_text_range(self, blocks_iter: Iterable[MDRStructuredBlock]) -> tuple[float, float, float]: - """Calculates average line height and min/max x-coordinates for text.""" - h_sum = 0.0 - count = 0 - x1 = float('inf') - x2 = float('-inf') - for b in blocks_iter: - if not isinstance(b, MDRTextBlock) or b.kind == MDRTextKind.ABANDON: - continue - for t in b.texts: - _, h = t.rect.size - if h > 1e-6: # Use small threshold for valid height - h_sum += h - count += 1 - tx1, _, tx2, _ = t.rect.wrapper - x1 = min(x1, tx1) - x2 = max(x2, tx2) - if count == 0: - return 0.0, 0.0, 0.0 - mean_h = h_sum / count - x1 = 0.0 if x1 == float('inf') else x1 - x2 = 0.0 if x2 == float('-inf') else x2 - return mean_h, x1, x2 - - def _add_plain_block(self, store: list[tuple[MDRLayoutElement, MDRStructuredBlock]], layout: MDRPlainLayoutElement, result: MDRExtractionResult): - """Creates MDRStructuredBlocks for plain layout types.""" - cls = layout.cls - texts = self._convert_fragments_to_spans(layout.fragments) - if cls == MDRLayoutClass.TITLE: - store.append((layout, MDRTextBlock(layout.rect, texts, 0.0, MDRTextKind.TITLE))) - elif cls == MDRLayoutClass.PLAIN_TEXT: - store.append((layout, MDRTextBlock(layout.rect, texts, 0.0, MDRTextKind.PLAIN_TEXT))) - elif cls == MDRLayoutClass.ABANDON: - store.append((layout, MDRTextBlock(layout.rect, texts, 0.0, MDRTextKind.ABANDON))) - elif cls == MDRLayoutClass.FIGURE: - store.append((layout, MDRFigureBlock(layout.rect, [], 0.0, mdr_clip_layout(result, layout)))) - elif cls == MDRLayoutClass.FIGURE_CAPTION: - block = self._find_previous_block(store, MDRFigureBlock) - if block: block.texts.extend(texts) - elif cls == MDRLayoutClass.TABLE_CAPTION or cls == MDRLayoutClass.TABLE_FOOTNOTE: - block = self._find_previous_block(store, MDRTableBlock) - if block: block.texts.extend(texts) - elif cls == MDRLayoutClass.FORMULA_CAPTION: - block = self._find_previous_block(store, MDRFormulaBlock) - if block: block.texts.extend(texts) - - def _find_previous_block(self, store: list[tuple[MDRLayoutElement, MDRStructuredBlock]], block_type: type) -> MDRStructuredBlock | None: - """Finds the most recent block of a specific type.""" - for i in range(len(store) - 1, -1, -1): - _, block = store[i] - if isinstance(block, block_type): - return block - return None - - def _create_table_block(self, layout: MDRTableLayoutElement, result: MDRExtractionResult) -> MDRTableBlock: - """Converts MDRTableLayoutElement to MDRTableBlock.""" - fmt = MDRTableFormat.UNRECOGNIZABLE - content = "" - if layout.parsed: - p_content, p_fmt = layout.parsed - can_use = not (p_fmt == MDRTableLayoutParsedFormat.LATEX and mdr_contains_cjka("".join(f.text for f in layout.fragments))) - if can_use: - content = p_content - if p_fmt == MDRTableLayoutParsedFormat.LATEX: - fmt = MDRTableFormat.LATEX - elif p_fmt == MDRTableLayoutParsedFormat.MARKDOWN: - fmt = MDRTableFormat.MARKDOWN - elif p_fmt == MDRTableLayoutParsedFormat.HTML: - fmt = MDRTableFormat.HTML - return MDRTableBlock(layout.rect, [], 0.0, fmt, content, mdr_clip_layout(result, layout)) - - def _create_formula_block(self, layout: MDRFormulaLayoutElement, result: MDRExtractionResult) -> MDRFormulaBlock: - """Converts MDRFormulaLayoutElement to MDRFormulaBlock.""" - content = layout.latex if layout.latex and not mdr_contains_cjka("".join(f.text for f in layout.fragments)) else None - return MDRFormulaBlock(layout.rect, [], 0.0, content, mdr_clip_layout(result, layout)) - - def _assign_relative_font_sizes(self, store: list[tuple[MDRLayoutElement, MDRStructuredBlock]]): - """Calculates and assigns relative font size (0-1) to blocks.""" - sizes = [] - for l, _ in store: - heights = [f.rect.size[1] for f in l.fragments if f.rect.size[1] > 1e-6] # Use small threshold - avg_h = sum(heights) / len(heights) if heights else 0.0 - sizes.append(avg_h) - valid = [s for s in sizes if s > 1e-6] - min_s, max_s = (min(valid), max(valid)) if valid else (0.0, 0.0) - rng = max_s - min_s - if rng < 1e-6: - [setattr(b, 'font_size', 0.0) for _, b in store] - else: - [setattr(b, 'font_size', (s - min_s) / rng if s > 1e-6 else 0.0) for s, (_, b) in zip(sizes, store)] + def __init__(self, device: Literal["cpu", "cuda"] = "cuda", model_dir_path: str = "./mdr_models", + ocr_level: MDROcrLevel = MDROcrLevel.Once, extract_formula: bool = True, + extract_table_format: MDRExtractedTableFormat | None = None, debug_dir_path: str | None = None): + """ + Initializes the MagicPDFProcessor. + Args: + device: Computation device ('cpu' or 'cuda'). Defaults to 'cuda'. Fallbacks to 'cpu' if CUDA not available. + model_dir_path: Path to directory for storing/caching downloaded models. Defaults to './mdr_models'. + ocr_level: Level of OCR application (Once per page or Once per layout). Defaults to Once per page. + extract_formula: Whether to attempt LaTeX extraction from formula images. Defaults to True. + extract_table_format: Desired format for extracted table content (LATEX, MARKDOWN, HTML, DISABLE, or None). + Defaults to LATEX if CUDA is available, otherwise DISABLE. + debug_dir_path: Optional path to save debug plots and intermediate files. Defaults to None (disabled). + """ + actual_dev = device if torch.cuda.is_available() else "cpu"; + print(f"MagicPDFProcessor using device: {actual_dev}.") + if extract_table_format is None: extract_table_format = MDRExtractedTableFormat.LATEX if actual_dev == "cuda" else MDRExtractedTableFormat.DISABLE + table_fmt_internal: MDRTableLayoutParsedFormat | None = None + if extract_table_format == MDRExtractedTableFormat.LATEX: + table_fmt_internal = MDRTableLayoutParsedFormat.LATEX + elif extract_table_format == MDRExtractedTableFormat.MARKDOWN: + table_fmt_internal = MDRTableLayoutParsedFormat.MARKDOWN + elif extract_table_format == MDRExtractedTableFormat.HTML: + table_fmt_internal = MDRTableLayoutParsedFormat.HTML + self._iterator = MDRDocumentIterator(device=actual_dev, model_dir_path=model_dir_path, ocr_level=ocr_level, + extract_formula=extract_formula, extract_table_format=table_fmt_internal, + debug_dir_path=debug_dir_path) + print("MagicPDFProcessor initialized.") + + def process_document(self, pdf_input: str | FitzDocument, + report_progress: MDRProgressReportCallback | None = None) -> Generator[ + MDRStructuredBlock, None, None]: + """ + Processes the entire PDF document and yields all extracted structured blocks. + Args: + pdf_input: Path to the PDF file or a loaded fitz.Document object. + report_progress: Optional callback function for progress updates (receives completed_scan_pages, total_scan_pages). + Yields: + MDRStructuredBlock: An extracted block (MDRTextBlock, MDRTableBlock, etc.). + """ + print(f"Processing document: {pdf_input if isinstance(pdf_input, str) else 'FitzDocument object'}") + for _, blocks, _ in self.process_document_pages(pdf_input=pdf_input, report_progress=report_progress, + page_indexes=None): + yield from blocks + print("Document processing complete.") + + def process_document_pages(self, pdf_input: str | FitzDocument, page_indexes: Iterable[int] | None = None, + report_progress: MDRProgressReportCallback | None = None) -> Generator[ + tuple[int, list[MDRStructuredBlock], Image], None, None]: + """ + Processes specific pages (or all if page_indexes is None) of the PDF document. + Yields results page by page, including the page index, extracted blocks, and the original page image. + Args: + pdf_input: Path to the PDF file or a loaded fitz.Document object. + page_indexes: An iterable of 0-based page indices to process. If None, processes all pages. + report_progress: Optional callback function for progress updates. + Yields: + tuple[int, list[MDRStructuredBlock], Image]: + - page_index (0-based) + - list of extracted MDRStructuredBlock objects for that page + - PIL Image object of the original rendered page + """ + params = MDRProcessingParams(pdf=pdf_input, page_indexes=page_indexes, report_progress=report_progress) + page_count = 0 + for page_idx, extraction_result, content_layouts in self._iterator.iterate_sections(params): + page_count += 1 + print(f"Processor: Converting layouts to blocks for page {page_idx + 1}...") + blocks = self._create_structured_blocks(extraction_result, content_layouts) + print(f"Processor: Analyzing paragraph structure for page {page_idx + 1}...") + self._analyze_paragraph_structure(blocks) + print(f"Processor: Yielding results for page {page_idx + 1}.") + yield page_idx, blocks, extraction_result.extracted_image # Yield original image + print(f"Processor: Finished processing {page_count} pages.") + + def _create_structured_blocks(self, result: MDRExtractionResult, layouts: list[MDRLayoutElement]) -> list[ + MDRStructuredBlock]: + """Converts MDRLayoutElement objects into MDRStructuredBlock objects.""" + temp_store: list[tuple[MDRLayoutElement, MDRStructuredBlock]] = [] + for layout in layouts: + if isinstance(layout, MDRPlainLayoutElement): + self._add_plain_block(temp_store, layout, result) + elif isinstance(layout, MDRTableLayoutElement): + temp_store.append((layout, self._create_table_block(layout, result))) + elif isinstance(layout, MDRFormulaLayoutElement): + temp_store.append((layout, self._create_formula_block(layout, result))) + self._assign_relative_font_sizes(temp_store) + return [block for _, block in temp_store] + + def _analyze_paragraph_structure(self, blocks: list[MDRStructuredBlock]): + """ + Calculates indentation and line-end heuristics for MDRTextBlocks + based on page-level text boundaries and average line height. + """ + # Define constants for clarity and maintainability + MIN_VALID_HEIGHT = 1e-6 + # Heuristic: Indent if first line starts more than 1.0 * avg line height from page text start + INDENTATION_THRESHOLD_FACTOR = 1.0 + # Heuristic: Last line touches end if it ends less than 1.0 * avg line height from page text end + LINE_END_THRESHOLD_FACTOR = 1.0 + + # Calculate average line height and text boundaries for the relevant text blocks on the page + page_avg_line_height, page_min_x, page_max_x = self._calculate_text_range( + (b for b in blocks if isinstance(b, MDRTextBlock) and b.kind != MDRTextKind.ABANDON) + ) + + # Avoid calculations if page metrics are invalid (e.g., no text, zero height) + if page_avg_line_height <= MIN_VALID_HEIGHT: + return + + # Iterate through each block to determine its paragraph properties + for block in blocks: + # Process only valid text blocks with actual text content + if not isinstance(block, MDRTextBlock) or block.kind == MDRTextKind.ABANDON or not block.texts: + continue + + # Use calculated page-level metrics for consistency in thresholds + avg_line_height = page_avg_line_height + page_text_start_x = page_min_x + page_text_end_x = page_max_x + + # Get the first and last text spans (assumed to be lines after merging) within the block + first_text_span = block.texts[0] + last_text_span = block.texts[-1] + + try: + # --- Calculate Indentation --- + # Estimate the starting x-coordinate of the first line (average of left top/bottom) + first_line_start_x = (first_text_span.rect.lt[0] + first_text_span.rect.lb[0]) / 2.0 + # Calculate the difference between the first line's start and the page's text start boundary + indentation_delta = first_line_start_x - page_text_start_x + # Determine indentation based on the heuristic threshold relative to average line height + block.has_paragraph_indentation = indentation_delta > (avg_line_height * INDENTATION_THRESHOLD_FACTOR) + + # --- Calculate Last Line End --- + # Estimate the ending x-coordinate of the last line (average of right top/bottom) + last_line_end_x = (last_text_span.rect.rt[0] + last_text_span.rect.rb[0]) / 2.0 + # Calculate the difference between the page's text end boundary and the last line's end + line_end_delta = page_text_end_x - last_line_end_x + # Determine if the last line reaches near the end based on the heuristic threshold + block.last_line_touch_end = line_end_delta < (avg_line_height * LINE_END_THRESHOLD_FACTOR) + + except Exception as e: + # Handle potential errors during calculation (e.g., invalid rect data) + print(f"Warn: Error calculating paragraph structure for block: {e}") + # Default to False if calculation fails to ensure attributes are set + block.has_paragraph_indentation = False + block.last_line_touch_end = False + + def _calculate_text_range(self, blocks_iter: Iterable[MDRStructuredBlock]) -> tuple[float, float, float]: + """Calculates average line height and min/max x-coordinates for text.""" + h_sum = 0.0 + count = 0 + x1 = float('inf') + x2 = float('-inf') + for b in blocks_iter: + if not isinstance(b, MDRTextBlock) or b.kind == MDRTextKind.ABANDON: + continue + for t in b.texts: + _, h = t.rect.size + if h > 1e-6: # Use small threshold for valid height + h_sum += h + count += 1 + tx1, _, tx2, _ = t.rect.wrapper + x1 = min(x1, tx1) + x2 = max(x2, tx2) + if count == 0: + return 0.0, 0.0, 0.0 + mean_h = h_sum / count + x1 = 0.0 if x1 == float('inf') else x1 + x2 = 0.0 if x2 == float('-inf') else x2 + return mean_h, x1, x2 + + def _add_plain_block(self, store: list[tuple[MDRLayoutElement, MDRStructuredBlock]], layout: MDRPlainLayoutElement, + result: MDRExtractionResult): + """Creates MDRStructuredBlocks for plain layout types.""" + cls = layout.cls + texts = self._convert_fragments_to_spans(layout.fragments) + if cls == MDRLayoutClass.TITLE: + store.append((layout, MDRTextBlock(layout.rect, texts, 0.0, MDRTextKind.TITLE))) + elif cls == MDRLayoutClass.PLAIN_TEXT: + store.append((layout, MDRTextBlock(layout.rect, texts, 0.0, MDRTextKind.PLAIN_TEXT))) + elif cls == MDRLayoutClass.ABANDON: + store.append((layout, MDRTextBlock(layout.rect, texts, 0.0, MDRTextKind.ABANDON))) + elif cls == MDRLayoutClass.FIGURE: + store.append((layout, MDRFigureBlock(layout.rect, [], 0.0, mdr_clip_layout(result, layout)))) + elif cls == MDRLayoutClass.FIGURE_CAPTION: + block = self._find_previous_block(store, MDRFigureBlock) + if block: block.texts.extend(texts) + elif cls == MDRLayoutClass.TABLE_CAPTION or cls == MDRLayoutClass.TABLE_FOOTNOTE: + block = self._find_previous_block(store, MDRTableBlock) + if block: block.texts.extend(texts) + elif cls == MDRLayoutClass.FORMULA_CAPTION: + block = self._find_previous_block(store, MDRFormulaBlock) + if block: block.texts.extend(texts) + + def _find_previous_block(self, store: list[tuple[MDRLayoutElement, MDRStructuredBlock]], + block_type: type) -> MDRStructuredBlock | None: + """Finds the most recent block of a specific type.""" + for i in range(len(store) - 1, -1, -1): + _, block = store[i] + if isinstance(block, block_type): + return block + return None + + def _create_table_block(self, layout: MDRTableLayoutElement, result: MDRExtractionResult) -> MDRTableBlock: + """Converts MDRTableLayoutElement to MDRTableBlock.""" + fmt = MDRTableFormat.UNRECOGNIZABLE + content = "" + if layout.parsed: + p_content, p_fmt = layout.parsed + can_use = not (p_fmt == MDRTableLayoutParsedFormat.LATEX and mdr_contains_cjka( + "".join(f.text for f in layout.fragments))) + if can_use: + content = p_content + if p_fmt == MDRTableLayoutParsedFormat.LATEX: + fmt = MDRTableFormat.LATEX + elif p_fmt == MDRTableLayoutParsedFormat.MARKDOWN: + fmt = MDRTableFormat.MARKDOWN + elif p_fmt == MDRTableLayoutParsedFormat.HTML: + fmt = MDRTableFormat.HTML + return MDRTableBlock(layout.rect, [], 0.0, fmt, content, mdr_clip_layout(result, layout)) + + def _create_formula_block(self, layout: MDRFormulaLayoutElement, result: MDRExtractionResult) -> MDRFormulaBlock: + """Converts MDRFormulaLayoutElement to MDRFormulaBlock.""" + content = layout.latex if layout.latex and not mdr_contains_cjka( + "".join(f.text for f in layout.fragments)) else None + return MDRFormulaBlock(layout.rect, [], 0.0, content, mdr_clip_layout(result, layout)) + + def _assign_relative_font_sizes(self, store: list[tuple[MDRLayoutElement, MDRStructuredBlock]]): + """Calculates and assigns relative font size (0-1) to blocks.""" + sizes = [] + for l, _ in store: + heights = [f.rect.size[1] for f in l.fragments if f.rect.size[1] > 1e-6] # Use small threshold + avg_h = sum(heights) / len(heights) if heights else 0.0 + sizes.append(avg_h) + valid = [s for s in sizes if s > 1e-6] + min_s, max_s = (min(valid), max(valid)) if valid else (0.0, 0.0) + rng = max_s - min_s + if rng < 1e-6: + [setattr(b, 'font_size', 0.0) for _, b in store] + else: + [setattr(b, 'font_size', (s - min_s) / rng if s > 1e-6 else 0.0) for s, (_, b) in zip(sizes, store)] + + def _convert_fragments_to_spans(self, frags: list[MDROcrFragment]) -> list[MDRTextSpan]: + """Converts MDROcrFragment list to MDRTextSpan list.""" + return [MDRTextSpan(f.text, f.rank, f.rect) for f in frags] - def _convert_fragments_to_spans(self, frags: list[MDROcrFragment]) -> list[MDRTextSpan]: - """Converts MDROcrFragment list to MDRTextSpan list.""" - return [MDRTextSpan(f.text, f.rank, f.rect) for f in frags] # --- MagicDataReadiness Example Usage --- if __name__ == '__main__': - print("="*60) + print("=" * 60) print(" MagicDataReadiness PDF Processor - Example Usage") - print("="*60) + print("=" * 60) # --- 1. Configuration (!!! MODIFY THESE PATHS WHEN OUTSIDE HF !!!) --- # Directory where models are stored or will be downloaded @@ -2844,7 +3244,7 @@ if __name__ == '__main__': # Path to the PDF file you want to process # IMPORTANT: Place a PDF file here for testing! # Create a dummy PDF if it doesn't exist for the example to run - MDR_INPUT_PDF = "example_input.pdf" # <--- CHANGE THIS + MDR_INPUT_PDF = "example_input.pdf" # <--- CHANGE THIS if not Path(MDR_INPUT_PDF).exists(): try: print(f"Creating dummy PDF: {MDR_INPUT_PDF}") @@ -2856,7 +3256,6 @@ if __name__ == '__main__': except Exception as e: print(f"Warning: Could not create dummy PDF: {e}") - # Optional: Directory to save debug plots (set to None to disable) MDR_DEBUG_DIRECTORY = "./mdr_debug_output" @@ -2867,7 +3266,7 @@ if __name__ == '__main__': MDR_TABLE_FORMAT = MDRExtractedTableFormat.MARKDOWN # Specify pages (list of 0-based indices, or None for all) - MDR_PAGES = None # Example: [0, 1] for first two pages + MDR_PAGES = None # Example: [0, 1] for first two pages # --- 2. Setup & Pre-checks --- print(f"Model Directory: {os.path.abspath(MDR_MODEL_DIRECTORY)}") @@ -2885,11 +3284,13 @@ if __name__ == '__main__': print(f"ERROR: Input PDF not found at '{MDR_INPUT_PDF}'. Please place a PDF file there or update the path.") exit(1) + # --- 3. Progress Callback --- def mdr_progress_update(completed, total): perc = (completed / total) * 100 if total > 0 else 0 print(f" [Progress] Scanned {completed}/{total} pages ({perc:.1f}%)") + # --- 4. Initialize Processor --- print("Initializing MagicPDFProcessor...") init_start = time.time() @@ -2904,6 +3305,7 @@ if __name__ == '__main__': except Exception as e: print(f"FATAL ERROR during initialization: {e}") import traceback + traceback.print_exc() exit(1) @@ -2935,7 +3337,7 @@ if __name__ == '__main__': info = f" - Block {block_idx + 1}: {type(block).__name__}" if isinstance(block, MDRTextBlock): preview = block.texts[0].content[:70].replace('\n', ' ') + "..." if block.texts else "[EMPTY]" - info += f" (Kind: {block.kind.name}, FontSz: {block.font_size:.2f}, Indent: {block.has_paragraph_indentation}, EndTouch: {block.last_line_touch_end}) | Text: '{preview}'" # Added indent/endtouch + info += f" (Kind: {block.kind.name}, FontSz: {block.font_size:.2f}, Indent: {block.has_paragraph_indentation}, EndTouch: {block.last_line_touch_end}) | Text: '{preview}'" # Added indent/endtouch elif isinstance(block, MDRTableBlock): info += f" (Format: {block.format.name}, HasContent: {bool(block.content)}, FontSz: {block.font_size:.2f})" # if block.content: print(f" Content:\n{block.content}") # Uncomment to see content @@ -2947,15 +3349,16 @@ if __name__ == '__main__': print(info) proc_time = time.time() - proc_start - print("\n" + "="*60) + print("\n" + "=" * 60) print(" Processing Summary") print(f" Total time: {proc_time:.2f} seconds") print(f" Pages processed: {processed_pages_count}") print(f" Total blocks extracted: {all_blocks_count}") - print("="*60) + print("=" * 60) except Exception as e: print(f"\nFATAL ERROR during processing: {e}") import traceback + traceback.print_exc() exit(1) \ No newline at end of file