diff --git "a/mdr_pdf_parser.py" "b/mdr_pdf_parser.py" --- "a/mdr_pdf_parser.py" +++ "b/mdr_pdf_parser.py" @@ -1,21 +1,3 @@ -# -*- coding: utf-8 -*- -# /=====================================================================\ # -# | MagicDataReadiness - MAGIC PDF Parser | # -# |---------------------------------------------------------------------| # -# | Description: | # -# | Extracts structured content (text, tables, figures, formulas) | # -# | from PDF documents using layout analysis and OCR. | # -# | Combines logic from various internal components. | # -# |---------------------------------------------------------------------| # -# | Dependencies: | # -# | - Python 3.11+ | # -# | - External Libraries (See imports below and installation notes) | # -# | - Pre-trained CV Models (Downloaded automatically to model dir) | # -# |---------------------------------------------------------------------| # -# | Usage: | # -# | See the __main__ block at the end of the script for an example. | # -# \=====================================================================/ # - # --- External Library Imports --- import os import re @@ -101,7 +83,8 @@ def mdr_is_whitespace(text: str) -> bool: def mdr_expand_image(image: Image, percent: float) -> Image: """Expands an image with a white border.""" if percent <= 0: return image.copy() - w, h = image.size; bw, bh = ceil(w * percent), ceil(h * percent) + w, h = image.size + bw, bh = ceil(w * percent), ceil(h * percent) fill: tuple[int, ...] | int if image.mode == "RGBA": fill = (255, 255, 255, 255) elif image.mode in ("LA", "L"): fill = 255 @@ -127,26 +110,32 @@ class MDRRectangle: except: return 0.0 @property def size(self) -> tuple[float, float]: - widths, heights = [], []; + widths, heights = [], [] for i, (p1, p2) in enumerate(self.segments): - dx, dy = p1[0]-p2[0], p1[1]-p2[1]; dist = sqrt(dx*dx + dy*dy) + dx, dy = p1[0]-p2[0], p1[1]-p2[1] + dist = sqrt(dx*dx + dy*dy) if i % 2 == 0: heights.append(dist) else: widths.append(dist) - avg_w = sum(widths)/len(widths) if widths else 0.0; avg_h = sum(heights)/len(heights) if heights else 0.0 + avg_w = sum(widths)/len(widths) if widths else 0.0 + avg_h = sum(heights)/len(heights) if heights else 0.0 return avg_w, avg_h @property def wrapper(self) -> tuple[float, float, float, float]: x1, y1, x2, y2 = float("inf"), float("inf"), float("-inf"), float("-inf") - for x, y in self: x1, y1, x2, y2 = min(x1, x), min(y1, y), max(x2, x), max(y2, y) + for x, y in self: + x1, y1, x2, y2 = min(x1, x), min(y1, y), max(x2, x), max(y2, y) return x1, y1, x2, y2 def mdr_intersection_area(rect1: MDRRectangle, rect2: MDRRectangle) -> float: """Calculates intersection area between two MDRRectangles.""" try: - p1, p2 = Polygon(rect1), Polygon(rect2); - if not p1.is_valid or not p2.is_valid: return 0.0 + p1 = Polygon(rect1) + p2 = Polygon(rect2) + if not p1.is_valid or not p2.is_valid: + return 0.0 return p1.intersection(p2).area - except: return 0.0 + except: + return 0.0 # --- MDR Data Structures --- @dataclass @@ -238,20 +227,27 @@ MDRStructuredBlock = MDRTextBlock | MDRAssetBlock # Type alias # --- MDR Utilities --- def mdr_similarity_ratio(v1: float, v2: float) -> float: """Calculates the ratio of the smaller value to the larger value (0-1).""" - if v1==0 and v2==0: return 1.0; - if v1<0 or v2<0: return 0.0; - v1, v2 = (v2, v1) if v1 > v2 else (v1, v2); - return 1.0 if v2==0 else v1/v2 + if v1 == 0 and v2 == 0: + return 1.0 + if v1 < 0 or v2 < 0: + return 0.0 + v1, v2 = (v2, v1) if v1 > v2 else (v1, v2) + return 1.0 if v2 == 0 else v1 / v2 def mdr_intersection_bounds_size(r1: MDRRectangle, r2: MDRRectangle) -> tuple[float, float]: """Calculates width/height of the intersection bounding box.""" try: - p1, p2 = Polygon(r1), Polygon(r2); - if not p1.is_valid or not p2.is_valid: return 0.0, 0.0; - inter = p1.intersection(p2); - if inter.is_empty: return 0.0, 0.0; - minx, miny, maxx, maxy = inter.bounds; return maxx-minx, maxy-miny - except: return 0.0, 0.0 + p1 = Polygon(r1) + p2 = Polygon(r2) + if not p1.is_valid or not p2.is_valid: + return 0.0, 0.0 + inter = p1.intersection(p2) + if inter.is_empty: + return 0.0, 0.0 + minx, miny, maxx, maxy = inter.bounds + return maxx - minx, maxy - miny + except: + return 0.0, 0.0 _MDR_CJKA_PATTERN = re.compile(r"[\u4e00-\u9fff\u3040-\u309f\u30a0-\u30ff\uac00-\ud7a3\u0600-\u06ff]") @@ -271,35 +267,69 @@ def _mdr_is_letter(char: str): def mdr_split_into_words(text: str): """Splits text into words, numbers, and individual non-alphanumeric chars.""" - if not text: return; - sp, np, nsp = re.compile(r"\s"), re.compile(r"\d"), re.compile(r"[\.,']"); - buf, phase = io.StringIO(), _MDR_TokenPhase.Init; + if not text: return + sp = re.compile(r"\s") + np = re.compile(r"\d") + nsp = re.compile(r"[\.,']") + buf = io.StringIO() + phase = _MDR_TokenPhase.Init for char in text: - is_l, is_d, is_s, is_ns = _mdr_is_letter(char), np.match(char), sp.match(char), nsp.match(char) + is_l = _mdr_is_letter(char) + is_d = np.match(char) + is_s = sp.match(char) + is_ns = nsp.match(char) if is_l: - if phase in (_MDR_TokenPhase.Number, _MDR_TokenPhase.Character): w=buf.getvalue(); yield w if w else None; buf=io.StringIO() - buf.write(char); phase=_MDR_TokenPhase.Letter + if phase in (_MDR_TokenPhase.Number, _MDR_TokenPhase.Character): + w = buf.getvalue() + yield w if w else None + buf = io.StringIO() + buf.write(char) + phase = _MDR_TokenPhase.Letter elif is_d: - if phase in (_MDR_TokenPhase.Letter, _MDR_TokenPhase.Character): w=buf.getvalue(); yield w if w else None; buf=io.StringIO() - buf.write(char); phase=_MDR_TokenPhase.Number - elif phase==_MDR_TokenPhase.Number and is_ns: buf.write(char) + if phase in (_MDR_TokenPhase.Letter, _MDR_TokenPhase.Character): + w = buf.getvalue() + yield w if w else None + buf = io.StringIO() + buf.write(char) + phase = _MDR_TokenPhase.Number + elif phase == _MDR_TokenPhase.Number and is_ns: + buf.write(char) else: - if phase in (_MDR_TokenPhase.Letter, _MDR_TokenPhase.Number): w=buf.getvalue(); yield w if w else None; buf=io.StringIO() - if is_s: phase=_MDR_TokenPhase.Space - else: yield char; phase=_MDR_TokenPhase.Character - if phase in (_MDR_TokenPhase.Letter, _MDR_TokenPhase.Number): w=buf.getvalue(); yield w if w else None + if phase in (_MDR_TokenPhase.Letter, _MDR_TokenPhase.Number): + w = buf.getvalue() + yield w if w else None + buf = io.StringIO() + if is_s: + phase = _MDR_TokenPhase.Space + else: + yield char + phase = _MDR_TokenPhase.Character + if phase in (_MDR_TokenPhase.Letter, _MDR_TokenPhase.Number): + w = buf.getvalue() + yield w if w else None def mdr_check_text_similarity(t1: str, t2: str) -> tuple[float, int]: """Calculates word-based similarity between two texts.""" - w1, w2 = list(mdr_split_into_words(t1)), list(mdr_split_into_words(t2)); l1, l2 = len(w1), len(w2) - if l1==0 and l2==0: return 1.0, 0; - if l1==0 or l2==0: return 0.0, max(l1, l2); - if l1 > l2: w1, w2, l1, l2 = w2, w1, l2, l1; - taken = [False]*l2; matches = 0 + w1 = list(mdr_split_into_words(t1)) + w2 = list(mdr_split_into_words(t2)) + l1 = len(w1) + l2 = len(w2) + if l1 == 0 and l2 == 0: + return 1.0, 0 + if l1 == 0 or l2 == 0: + return 0.0, max(l1, l2) + if l1 > l2: + w1, w2, l1, l2 = w2, w1, l2, l1 + taken = [False] * l2 + matches = 0 for word1 in w1: for i, word2 in enumerate(w2): - if not taken[i] and word1==word2: taken[i]=True; matches+=1; break - mismatches = l2 - matches; return 1.0 - (mismatches/l2), l2 + if not taken[i] and word1 == word2: + taken[i] = True + matches += 1 + break + mismatches = l2 - matches + return 1.0 - (mismatches / l2), l2 # --- MDR Geometry Processing --- class MDRRotationAdjuster: @@ -308,12 +338,17 @@ class MDRRotationAdjuster: def __init__(self, origin_size: tuple[int, int], new_size: tuple[int, int], rotation: float, to_origin_coordinate: bool): fs, ts = (new_size, origin_size) if to_origin_coordinate else (origin_size, new_size) self._rot = rotation if to_origin_coordinate else -rotation - self._c_off = (fs[0]/2.0, fs[1]/2.0); self._n_off = (ts[0]/2.0, ts[1]/2.0) + self._c_off = (fs[0]/2.0, fs[1]/2.0) + self._n_off = (ts[0]/2.0, ts[1]/2.0) def adjust(self, point: MDRPoint) -> MDRPoint: - x, y = point[0]-self._c_off[0], point[1]-self._c_off[1] - if x!=0 or y!=0: cos_r, sin_r = cos(self._rot), sin(self._rot); x, y = x*cos_r-y*sin_r, x*sin_r+y*cos_r - return x+self._n_off[0], y+self._n_off[1] + x = point[0] - self._c_off[0] + y = point[1] - self._c_off[1] + if x != 0 or y != 0: + cos_r = cos(self._rot) + sin_r = sin(self._rot) + x, y = x * cos_r - y * sin_r, x * sin_r + y * cos_r + return x + self._n_off[0], y + self._n_off[1] def mdr_normalize_vertical_rotation(rot: float) -> float: while rot >= pi: rot -= pi; @@ -323,34 +358,51 @@ def mdr_normalize_vertical_rotation(rot: float) -> float: def _mdr_get_rectangle_angles(rect: MDRRectangle) -> tuple[list[float], list[float]] | None: h_angs, v_angs = [], [] for i, (p1, p2) in enumerate(rect.segments): - dx, dy = p2[0]-p1[0], p2[1]-p1[1]; - if abs(dx)<1e-6 and abs(dy)<1e-6: continue; - ang = atan2(dy, dx); - if ang < 0: ang += pi; - if ang < pi*0.25 or ang >= pi*0.75: h_angs.append(ang-pi if ang>=pi*0.75 else ang) - else: v_angs.append(ang) - if not h_angs or not v_angs: return None + dx = p2[0] - p1[0] + dy = p2[1] - p1[1] + if abs(dx) < 1e-6 and abs(dy) < 1e-6: + continue + ang = atan2(dy, dx) + if ang < 0: + ang += pi + if ang < pi * 0.25 or ang >= pi * 0.75: + h_angs.append(ang - pi if ang >= pi * 0.75 else ang) + else: + v_angs.append(ang) + if not h_angs or not v_angs: + return None return h_angs, v_angs def _mdr_normalize_horizontal_angles(rots: list[float]) -> list[float]: return rots def _mdr_find_median(data: list[float]) -> float: - if not data: return 0.0; s_data = sorted(data); n = len(s_data); - return s_data[n//2] if n%2==1 else (s_data[n//2-1]+s_data[n//2])/2.0 + if not data: + return 0.0 + s_data = sorted(data) + n = len(s_data) + return s_data[n // 2] if n % 2 == 1 else (s_data[n // 2 - 1] + s_data[n // 2]) / 2.0 def _mdr_find_mean(data: list[float]) -> float: return sum(data)/len(data) if data else 0.0 def mdr_calculate_image_rotation(frags: list[MDROcrFragment]) -> float: - all_h, all_v = [], []; + all_h, all_v = [], [] for f in frags: - res = _mdr_get_rectangle_angles(f.rect); - if res: h, v = res; all_h.extend(h); all_v.extend(v) - if not all_h or not all_v: return 0.0; - all_h = _mdr_normalize_horizontal_angles(all_h); all_v = [mdr_normalize_vertical_rotation(a) for a in all_v] - med_h, med_v = _mdr_find_median(all_h), _mdr_find_median(all_v); - rot_est = ((pi/2 - med_v) - med_h) / 2.0; - while rot_est >= pi/2: rot_est -= pi; - while rot_est < -pi/2: rot_est += pi; + res = _mdr_get_rectangle_angles(f.rect) + if res: + h, v = res + all_h.extend(h) + all_v.extend(v) + if not all_h or not all_v: + return 0.0 + all_h = _mdr_normalize_horizontal_angles(all_h) + all_v = [mdr_normalize_vertical_rotation(a) for a in all_v] + med_h = _mdr_find_median(all_h) + med_v = _mdr_find_median(all_v) + rot_est = ((pi / 2 - med_v) - med_h) / 2.0 + while rot_est >= pi / 2: + rot_est -= pi + while rot_est < -pi / 2: + rot_est += pi return rot_est def mdr_calculate_rectangle_rotation(rect: MDRRectangle) -> tuple[float, float]: @@ -366,7 +418,8 @@ class _MDR_PredictBase: def get_onnx_session(self, model_path: str, use_gpu: bool): try: - sess_opts = onnxruntime.SessionOptions(); sess_opts.log_severity_level = 3 + sess_opts = onnxruntime.SessionOptions() + sess_opts.log_severity_level = 3 providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if use_gpu and 'CUDAExecutionProvider' in onnxruntime.get_available_providers() else ['CPUExecutionProvider'] session = onnxruntime.InferenceSession(model_path, sess_options=sess_opts, providers=providers) print(f" ONNX session loaded: {Path(model_path).name} ({session.get_providers()})") @@ -388,40 +441,120 @@ class _MDR_NormalizeImage: def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs): self.scale = np.float32(eval(scale) if isinstance(scale, str) else (scale if scale is not None else 1.0 / 255.0)) - mean = mean if mean is not None else [0.485, 0.456, 0.406]; std = std if std is not None else [0.229, 0.224, 0.225] - shape = (3, 1, 1) if order == 'chw' else (1, 1, 3); self.mean = np.array(mean).reshape(shape).astype('float32'); self.std = np.array(std).reshape(shape).astype('float32') + mean = mean if mean is not None else [0.485, 0.456, 0.406] + std = std if std is not None else [0.229, 0.224, 0.225] + shape = (3, 1, 1) if order == 'chw' else (1, 1, 3) + self.mean = np.array(mean).reshape(shape).astype('float32') + self.std = np.array(std).reshape(shape).astype('float32') - def __call__(self, data): img = data['image']; img = np.array(img) if isinstance(img, Image) else img; data['image'] = (img.astype('float32') * self.scale - self.mean) / self.std; return data + def __call__(self, data): + img = data['image'] + img = np.array(img) if isinstance(img, Image) else img + data['image'] = (img.astype('float32') * self.scale - self.mean) / self.std + return data class _MDR_DetResizeForTest: def __init__(self, **kwargs): - self.resize_type = 0; self.keep_ratio = False - if 'image_shape' in kwargs: self.image_shape = kwargs['image_shape']; self.resize_type = 1; self.keep_ratio = kwargs.get('keep_ratio', False) - elif 'limit_side_len' in kwargs: self.limit_side_len = kwargs['limit_side_len']; self.limit_type = kwargs.get('limit_type', 'min') - elif 'resize_long' in kwargs: self.resize_type = 2; self.resize_long = kwargs.get('resize_long', 960) - else: self.limit_side_len = 736; self.limit_type = 'min' + self.resize_type = 0 + self.keep_ratio = False + if 'image_shape' in kwargs: + self.image_shape = kwargs['image_shape'] + self.resize_type = 1 + self.keep_ratio = kwargs.get('keep_ratio', False) + elif 'limit_side_len' in kwargs: + self.limit_side_len = kwargs['limit_side_len'] + self.limit_type = kwargs.get('limit_type', 'min') + elif 'resize_long' in kwargs: + self.resize_type = 2 + self.resize_long = kwargs.get('resize_long', 960) + else: + self.limit_side_len = 736 + self.limit_type = 'min' def __call__(self, data): - img = data['image']; src_h, src_w, _ = img.shape - if src_h + src_w < 64: img = self._pad(img) - if self.resize_type == 0: img, ratios = self._resize0(img) - elif self.resize_type == 2: img, ratios = self._resize2(img) - else: img, ratios = self._resize1(img) - if img is None: return None - data['image'] = img; data['shape'] = np.array([src_h, src_w, ratios[0], ratios[1]]); return data - - def _pad(self, im, v=0): h,w,c=im.shape; p=np.zeros((max(32,h),max(32,w),c),np.uint8)+v; p[:h,:w,:]=im; return p - - def _resize1(self, img): rh,rw=self.image_shape; oh,ow=img.shape[:2]; if self.keep_ratio: rw=ow*rh/oh; N=ceil(rw/32); rw=N*32; r_h,r_w=float(rh)/oh,float(rw)/ow; img=cv2.resize(img,(int(rw),int(rh))); return img,[r_h,r_w] - - def _resize0(self, img): lsl=self.limit_side_len; h,w,_=img.shape; r=1.0; if self.limit_type=='max': r=float(lsl)/max(h,w) if max(h,w)>lsl else 1.0; elif self.limit_type=='min': r=float(lsl)/min(h,w) if min(h,w) lsl else 1.0 + elif self.limit_type == 'min': + r = float(lsl) / min(h, w) if min(h, w) < lsl else 1.0 + elif self.limit_type == 'resize_long': + r = float(lsl) / max(h, w) + else: + raise Exception('Unsupported limit_type') + rh = int(h * r) + rw = int(w * r) + rh = max(int(round(rh / 32) * 32), 32) + rw = max(int(round(rw / 32) * 32), 32) + if int(rw) <= 0 or int(rh) <= 0: + return None, (None, None) + img = cv2.resize(img, (int(rw), int(rh))) + r_h = rh / float(h) + r_w = rw / float(w) + return img, [r_h, r_w] + + def _resize2(self, img): + h, w, _ = img.shape + rl = self.resize_long + r = float(rl) / max(h, w) + rh = int(h * r) + rw = int(w * r) + ms = 128 + rh = (rh + ms - 1) // ms * ms + rw = (rw + ms - 1) // ms * ms + img = cv2.resize(img, (int(rw), int(rh))) + r_h = rh / float(h) + r_w = rw / float(w) + return img, [r_h, r_w] class _MDR_ToCHWImage: - def __call__(self, data): img=data['image']; img=np.array(img) if isinstance(img,Image) else img; data['image']=img.transpose((2,0,1)); return data + def __call__(self, data): + img = data['image'] + img = np.array(img) if isinstance(img, Image) else img + data['image'] = img.transpose((2, 0, 1)) + return data class _MDR_KeepKeys: @@ -448,101 +581,205 @@ def mdr_ocr_create_operators(op_param_list, global_config=None): class _MDR_DBPostProcess: def __init__(self, thresh=0.3, box_thresh=0.7, max_candidates=1000, unclip_ratio=1.5, use_dilation=False, score_mode="fast", box_type='quad', **kwargs): - self.thresh, self.box_thresh, self.max_cand = thresh, box_thresh, max_candidates; self.unclip_r, self.min_sz, self.score_m, self.box_t = unclip_ratio, 3, score_mode, box_type - assert score_mode in ["slow", "fast"]; self.dila_k = np.array([[1,1],[1,1]], dtype=np.uint8) if use_dilation else None + self.thresh = thresh + self.box_thresh = box_thresh + self.max_cand = max_candidates + self.unclip_r = unclip_ratio + self.min_sz = 3 + self.score_m = score_mode + self.box_t = box_type + assert score_mode in ["slow", "fast"] + self.dila_k = np.array([[1, 1], [1, 1]], dtype=np.uint8) if use_dilation else None def _polygons_from_bitmap(self, pred, bmp, dw, dh): - h, w = bmp.shape; boxes, scores = [], [] - contours, _ = cv2.findContours((bmp*255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) + h, w = bmp.shape + boxes, scores = [], [] + contours, _ = cv2.findContours((bmp * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) for contour in contours[:self.max_cand]: - eps = 0.002*cv2.arcLength(contour,True); approx = cv2.approxPolyDP(contour,eps,True); pts = approx.reshape((-1,2)) - if pts.shape[0]<4: continue - score = self._box_score_fast(pred, pts.reshape(-1,2)); - if self.box_thresh > score: continue - try: box = self._unclip(pts, self.unclip_r); - except: continue - if len(box)>1: continue; box = box.reshape(-1,2) - _, sside = self._get_mini_boxes(box.reshape((-1,1,2))); - if sside < self.min_sz+2: continue - box = np.array(box); box[:,0]=np.clip(np.round(box[:,0]/w*dw),0,dw); box[:,1]=np.clip(np.round(box[:,1]/h*dh),0,dh) - boxes.append(box.tolist()); scores.append(score) + eps = 0.002 * cv2.arcLength(contour, True) + approx = cv2.approxPolyDP(contour, eps, True) + pts = approx.reshape((-1, 2)) + if pts.shape[0] < 4: + continue + score = self._box_score_fast(pred, pts.reshape(-1, 2)) + if self.box_thresh > score: + continue + try: + box = self._unclip(pts, self.unclip_r) + except: + continue + if len(box) > 1: + continue + box = box.reshape(-1, 2) + _, sside = self._get_mini_boxes(box.reshape((-1, 1, 2))) + if sside < self.min_sz + 2: + continue + box = np.array(box) + box[:, 0] = np.clip(np.round(box[:, 0] / w * dw), 0, dw) + box[:, 1] = np.clip(np.round(box[:, 1] / h * dh), 0, dh) + boxes.append(box.tolist()) + scores.append(score) return boxes, scores def _boxes_from_bitmap(self, pred, bmp, dw, dh): - h, w = bmp.shape; contours, _ = cv2.findContours((bmp*255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) - num_contours = min(len(contours), self.max_cand); boxes, scores = [], [] + h, w = bmp.shape + contours, _ = cv2.findContours((bmp * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) + num_contours = min(len(contours), self.max_cand) + boxes, scores = [], [] for i in range(num_contours): - contour = contours[i]; pts, sside = self._get_mini_boxes(contour); - if sside < self.min_sz: continue - pts = np.array(pts); score = self._box_score_fast(pred, pts.reshape(-1,2)) if self.score_m=="fast" else self._box_score_slow(pred, contour) - if self.box_thresh > score: continue - try: box = self._unclip(pts, self.unclip_r).reshape(-1,1,2) - except: continue - box, sside = self._get_mini_boxes(box); - if sside < self.min_sz+2: continue - box = np.array(box); box[:,0]=np.clip(np.round(box[:,0]/w*dw),0,dw); box[:,1]=np.clip(np.round(box[:,1]/h*dh),0,dh) - boxes.append(box.astype("int32")); scores.append(score) + contour = contours[i] + pts, sside = self._get_mini_boxes(contour) + if sside < self.min_sz: + continue + pts = np.array(pts) + score = self._box_score_fast(pred, pts.reshape(-1, 2)) if self.score_m == "fast" else self._box_score_slow(pred, contour) + if self.box_thresh > score: + continue + try: + box = self._unclip(pts, self.unclip_r).reshape(-1, 1, 2) + except: + continue + box, sside = self._get_mini_boxes(box) + if sside < self.min_sz + 2: + continue + box = np.array(box) + box[:, 0] = np.clip(np.round(box[:, 0] / w * dw), 0, dw) + box[:, 1] = np.clip(np.round(box[:, 1] / h * dh), 0, dh) + boxes.append(box.astype("int32")) + scores.append(score) return np.array(boxes, dtype="int32"), scores def _unclip(self, box, ratio): - poly = Polygon(box); dist = poly.area*ratio/poly.length; offset = pyclipper.PyclipperOffset(); offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) - expanded = offset.Execute(dist); - if not expanded: raise ValueError("Unclip failed"); return np.array(expanded[0]) + poly = Polygon(box) + dist = poly.area * ratio / poly.length + offset = pyclipper.PyclipperOffset() + offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) + expanded = offset.Execute(dist) + if not expanded: + raise ValueError("Unclip failed") + return np.array(expanded[0]) def _get_mini_boxes(self, contour): - bb = cv2.minAreaRect(contour); pts = sorted(list(cv2.boxPoints(bb)), key=lambda x:x[0]) - i1,i4 = (0,1) if pts[1][1]>pts[0][1] else (1,0); i2,i3 = (2,3) if pts[3][1]>pts[2][1] else (3,2) - box = [pts[i1], pts[i2], pts[i3], pts[i4]]; return box, min(bb[1]) + bb = cv2.minAreaRect(contour) + pts = sorted(list(cv2.boxPoints(bb)), key=lambda x: x[0]) + i1, i4 = (0, 1) if pts[1][1] > pts[0][1] else (1, 0) + i2, i3 = (2, 3) if pts[3][1] > pts[2][1] else (3, 2) + box = [pts[i1], pts[i2], pts[i3], pts[i4]] + return box, min(bb[1]) def _box_score_fast(self, bmp, box): - h,w = bmp.shape[:2]; xmin=np.clip(np.floor(box[:,0].min()).astype("int32"),0,w-1); xmax=np.clip(np.ceil(box[:,0].max()).astype("int32"),0,w-1) - ymin=np.clip(np.floor(box[:,1].min()).astype("int32"),0,h-1); ymax=np.clip(np.ceil(box[:,1].max()).astype("int32"),0,h-1) - mask = np.zeros((ymax-ymin+1, xmax-xmin+1), dtype=np.uint8); box[:,0]-=xmin; box[:,1]-=ymin - cv2.fillPoly(mask, box.reshape(1,-1,2).astype("int32"), 1); - return cv2.mean(bmp[ymin:ymax+1, xmin:xmax+1], mask)[0] if np.sum(mask)>0 else 0.0 + h, w = bmp.shape[:2] + xmin = np.clip(np.floor(box[:, 0].min()).astype("int32"), 0, w - 1) + xmax = np.clip(np.ceil(box[:, 0].max()).astype("int32"), 0, w - 1) + ymin = np.clip(np.floor(box[:, 1].min()).astype("int32"), 0, h - 1) + ymax = np.clip(np.ceil(box[:, 1].max()).astype("int32"), 0, h - 1) + mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) + box[:, 0] -= xmin + box[:, 1] -= ymin + cv2.fillPoly(mask, box.reshape(1, -1, 2).astype("int32"), 1) + return cv2.mean(bmp[ymin : ymax + 1, xmin : xmax + 1], mask)[0] if np.sum(mask) > 0 else 0.0 def _box_score_slow(self, bmp, contour): # Not used if fast - h,w = bmp.shape[:2]; contour = np.reshape(contour.copy(),(-1,2)); xmin=np.clip(np.min(contour[:,0]),0,w-1); xmax=np.clip(np.max(contour[:,0]),0,w-1) - ymin=np.clip(np.min(contour[:,1]),0,h-1); ymax=np.clip(np.max(contour[:,1]),0,h-1); mask=np.zeros((ymax-ymin+1,xmax-xmin+1),dtype=np.uint8) - contour[:,0]-=xmin; contour[:,1]-=ymin; cv2.fillPoly(mask, contour.reshape(1,-1,2).astype("int32"), 1); - return cv2.mean(bmp[ymin:ymax+1, xmin:xmax+1], mask)[0] if np.sum(mask)>0 else 0.0 + h, w = bmp.shape[:2] + contour = np.reshape(contour.copy(), (-1, 2)) + xmin = np.clip(np.min(contour[:, 0]), 0, w - 1) + xmax = np.clip(np.max(contour[:, 0]), 0, w - 1) + ymin = np.clip(np.min(contour[:, 1]), 0, h - 1) + ymax = np.clip(np.max(contour[:, 1]), 0, h - 1) + mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) + contour[:, 0] -= xmin + contour[:, 1] -= ymin + cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype("int32"), 1) + return cv2.mean(bmp[ymin : ymax + 1, xmin : xmax + 1], mask)[0] if np.sum(mask) > 0 else 0.0 def __call__(self, outs_dict, shape_list): - pred = outs_dict['maps'][:,0,:,:]; seg = pred > self.thresh; boxes_batch = [] + pred = outs_dict['maps'][:, 0, :, :] + seg = pred > self.thresh + boxes_batch = [] for batch_idx in range(pred.shape[0]): - sh, sw, _, _ = shape_list[batch_idx]; mask = cv2.dilate(np.array(seg[batch_idx]).astype(np.uint8), self.dila_k) if self.dila_k is not None else seg[batch_idx] - if self.box_t=='poly': boxes, _ = self._polygons_from_bitmap(pred[batch_idx], mask, sw, sh) - elif self.box_t=='quad': boxes, _ = self._boxes_from_bitmap(pred[batch_idx], mask, sw, sh) - else: raise ValueError("box_type must be 'quad' or 'poly'") + sh, sw, _, _ = shape_list[batch_idx] + mask = cv2.dilate(np.array(seg[batch_idx]).astype(np.uint8), self.dila_k) if self.dila_k is not None else seg[batch_idx] + if self.box_t == 'poly': + boxes, _ = self._polygons_from_bitmap(pred[batch_idx], mask, sw, sh) + elif self.box_t == 'quad': + boxes, _ = self._boxes_from_bitmap(pred[batch_idx], mask, sw, sh) + else: + raise ValueError("box_type must be 'quad' or 'poly'") boxes_batch.append({'points': boxes}) return boxes_batch class _MDR_TextDetector(_MDR_PredictBase): def __init__(self, args): - super().__init__(); self.args = args + super().__init__() + self.args = args pre_ops = [{'DetResizeForTest': {'limit_side_len': args.det_limit_side_len, 'limit_type': args.det_limit_type}}, {'NormalizeImage': {'std': [0.229,0.224,0.225], 'mean': [0.485,0.456,0.406], 'scale': '1./255.', 'order': 'hwc'}}, {'ToCHWImage': None}, {'KeepKeys': {'keep_keys': ['image', 'shape']}}] self.pre_op = mdr_ocr_create_operators(pre_ops) post_params = {'thresh': args.det_db_thresh, 'box_thresh': args.det_db_box_thresh, 'max_candidates': 1000, 'unclip_ratio': args.det_db_unclip_ratio, 'use_dilation': args.use_dilation, 'score_mode': args.det_db_score_mode, 'box_type': args.det_box_type} self.post_op = _MDR_DBPostProcess(**post_params) self.sess = self.get_onnx_session(args.det_model_dir, args.use_gpu) - self.input_name = self.get_input_name(self.sess); self.output_name = self.get_output_name(self.sess) - - def _order_pts(self, pts): r=np.zeros((4,2),dtype="float32"); s=pts.sum(axis=1); r[0]=pts[np.argmin(s)]; r[2]=pts[np.argmax(s)]; tmp=np.delete(pts,(np.argmin(s),np.argmax(s)),axis=0); d=np.diff(np.array(tmp),axis=1); r[1]=tmp[np.argmin(d)]; r[3]=tmp[np.argmax(d)]; return r - - def _clip_pts(self, pts, h, w): pts[:,0]=np.clip(pts[:,0],0,w-1); pts[:,1]=np.clip(pts[:,1],0,h-1); return pts - - def _filter_quad(self, boxes, shape): h,w=shape[0:2]; new_boxes=[]; for box in boxes: box=np.array(box) if isinstance(box,list) else box; box=self._order_pts(box); box=self._clip_pts(box,h,w); rw=int(np.linalg.norm(box[0]-box[1])); rh=int(np.linalg.norm(box[0]-box[3])); if rw<=3 or rh<=3: continue; new_boxes.append(box); return np.array(new_boxes) - - def _filter_poly(self, boxes, shape): h,w=shape[0:2]; new_boxes=[]; for box in boxes: box=np.array(box) if isinstance(box,list) else box; box=self._clip_pts(box,h,w); if Polygon(box).area<10: continue; new_boxes.append(box); return np.array(new_boxes) + self.input_name = self.get_input_name(self.sess) + self.output_name = self.get_output_name(self.sess) + + def _order_pts(self, pts): + r = np.zeros((4, 2), dtype="float32") + s = pts.sum(axis=1) + r[0] = pts[np.argmin(s)] + r[2] = pts[np.argmax(s)] + tmp = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0) + d = np.diff(np.array(tmp), axis=1) + r[1] = tmp[np.argmin(d)] + r[3] = tmp[np.argmax(d)] + return r + + def _clip_pts(self, pts, h, w): + pts[:, 0] = np.clip(pts[:, 0], 0, w - 1) + pts[:, 1] = np.clip(pts[:, 1], 0, h - 1) + return pts + + def _filter_quad(self, boxes, shape): + h, w = shape[0:2] + new_boxes = [] + for box in boxes: + box = np.array(box) if isinstance(box, list) else box + box = self._order_pts(box) + box = self._clip_pts(box, h, w) + rw = int(np.linalg.norm(box[0] - box[1])) + rh = int(np.linalg.norm(box[0] - box[3])) + if rw <= 3 or rh <= 3: + continue + new_boxes.append(box) + return np.array(new_boxes) + + def _filter_poly(self, boxes, shape): + h, w = shape[0:2] + new_boxes = [] + for box in boxes: + box = np.array(box) if isinstance(box, list) else box + box = self._clip_pts(box, h, w) + if Polygon(box).area < 10: + continue + new_boxes.append(box) + return np.array(new_boxes) def __call__(self, img): - ori_im = img.copy(); data = {"image": img}; data = mdr_ocr_transform(data, self.pre_op) - if data is None: return None; img, shape_list = data; - if img is None: return None; img = np.expand_dims(img, axis=0); shape_list = np.expand_dims(shape_list, axis=0); img = img.copy() - inputs = self.get_input_feed(self.input_name, img); outputs = self.sess.run(self.output_name, input_feed=inputs) - preds = {"maps": outputs[0]}; post_res = self.post_op(preds, shape_list); boxes = post_res[0]['points'] - return self._filter_poly(boxes, ori_im.shape) if self.args.det_box_type=='poly' else self._filter_quad(boxes, ori_im.shape) + ori_im = img.copy() + data = {"image": img} + data = mdr_ocr_transform(data, self.pre_op) + if data is None: + return None + img, shape_list = data + if img is None: + return None + img = np.expand_dims(img, axis=0) + shape_list = np.expand_dims(shape_list, axis=0) + img = img.copy() + inputs = self.get_input_feed(self.input_name, img) + outputs = self.sess.run(self.output_name, input_feed=inputs) + preds = {"maps": outputs[0]} + post_res = self.post_op(preds, shape_list) + boxes = post_res[0]['points'] + return self._filter_poly(boxes, ori_im.shape) if self.args.det_box_type == 'poly' else self._filter_quad(boxes, ori_im.shape) class _MDR_ClsPostProcess: @@ -555,30 +792,60 @@ class _MDR_ClsPostProcess: class _MDR_TextClassifier(_MDR_PredictBase): def __init__(self, args): - super().__init__(); self.shape = tuple(map(int, args.cls_image_shape.split(','))) if isinstance(args.cls_image_shape, str) else args.cls_image_shape - self.batch_num = args.cls_batch_num; self.thresh = args.cls_thresh; self.post_op = _MDR_ClsPostProcess(label_list=args.label_list) - self.sess = self.get_onnx_session(args.cls_model_dir, args.use_gpu); self.input_name = self.get_input_name(self.sess); self.output_name = self.get_output_name(self.sess) + super().__init__() + self.shape = tuple(map(int, args.cls_image_shape.split(','))) if isinstance(args.cls_image_shape, str) else args.cls_image_shape + self.batch_num = args.cls_batch_num + self.thresh = args.cls_thresh + self.post_op = _MDR_ClsPostProcess(label_list=args.label_list) + self.sess = self.get_onnx_session(args.cls_model_dir, args.use_gpu) + self.input_name = self.get_input_name(self.sess) + self.output_name = self.get_output_name(self.sess) def _resize_norm(self, img): - imgC,imgH,imgW = self.shape; h,w = img.shape[:2]; r=w/float(h) if h>0 else 0; rw=int(ceil(imgH*r)); rw=min(rw,imgW) - resized = cv2.resize(img,(rw,imgH)); resized = resized.astype("float32") - if imgC==1: resized = resized/255.0; resized = resized[np.newaxis,:] - else: resized = resized.transpose((2,0,1))/255.0 - resized -= 0.5; resized /= 0.5; padding = np.zeros((imgC,imgH,imgW),dtype=np.float32); padding[:,:,0:rw]=resized; return padding + imgC, imgH, imgW = self.shape + h, w = img.shape[:2] + r = w / float(h) if h > 0 else 0 + rw = int(ceil(imgH * r)) + rw = min(rw, imgW) + resized = cv2.resize(img, (rw, imgH)) + resized = resized.astype("float32") + if imgC == 1: + resized = resized / 255.0 + resized = resized[np.newaxis, :] + else: + resized = resized.transpose((2, 0, 1)) / 255.0 + resized -= 0.5 + resized /= 0.5 + padding = np.zeros((imgC, imgH, imgW), dtype=np.float32) + padding[:, :, 0:rw] = resized + return padding def __call__(self, img_list): - if not img_list: return img_list, []; img_list_cp = copy.deepcopy(img_list); num = len(img_list_cp) - ratios = [img.shape[1]/float(img.shape[0]) if img.shape[0]>0 else 0 for img in img_list_cp]; indices = np.argsort(np.array(ratios)) - results = [["",0.0]]*num; batch_n = self.batch_num + if not img_list: + return img_list, [] + img_list_cp = copy.deepcopy(img_list) + num = len(img_list_cp) + ratios = [img.shape[1] / float(img.shape[0]) if img.shape[0] > 0 else 0 for img in img_list_cp] + indices = np.argsort(np.array(ratios)) + results = [["", 0.0]] * num + batch_n = self.batch_num for start in range(0, num, batch_n): - end = min(num, start+batch_n); batch = [] - for i in range(start, end): batch.append(self._resize_norm(img_list_cp[indices[i]])[np.newaxis,:]) - if not batch: continue; batch = np.concatenate(batch, axis=0).copy() - inputs = self.get_input_feed(self.input_name, batch); outputs = self.sess.run(self.output_name, input_feed=inputs) + end = min(num, start + batch_n) + batch = [] + for i in range(start, end): + batch.append(self._resize_norm(img_list_cp[indices[i]])[np.newaxis, :]) + if not batch: + continue + batch = np.concatenate(batch, axis=0).copy() + inputs = self.get_input_feed(self.input_name, batch) + outputs = self.sess.run(self.output_name, input_feed=inputs) cls_out = self.post_op(outputs[0]) for i in range(len(cls_out)): - orig_idx = indices[start+i]; label, score = cls_out[i]; results[orig_idx] = [label, score] - if "180" in label and score > self.thresh: img_list[orig_idx] = cv2.rotate(img_list[orig_idx], cv2.ROTATE_180) + orig_idx = indices[start + i] + label, score = cls_out[i] + results[orig_idx] = [label, score] + if "180" in label and score > self.thresh: + img_list[orig_idx] = cv2.rotate(img_list[orig_idx], cv2.ROTATE_180) return img_list, results class _MDR_BaseRecLabelDecode: @@ -601,16 +868,22 @@ class _MDR_BaseRecLabelDecode: def _reverse(self, pred): res=[]; cur=""; for c in pred: if not re.search("[a-zA-Z0-9 :*./%+-]",c): res.extend([cur,c] if cur!="" else [c]); cur="" else: cur+=c; if cur!="": res.append(cur); return "".join(res[::-1]) def decode(self, idxs, probs=None, remove_dup=False): - res=[]; ignored=self.get_ignored_tokens(); bs=len(idxs) + res = [] + ignored = self.get_ignored_tokens() + bs = len(idxs) for b_idx in range(bs): - sel=np.ones(len(idxs[b_idx]),dtype=bool); - if remove_dup: sel[1:]=idxs[b_idx][1:]!=idxs[b_idx][:-1] - for ig_tok in ignored: sel &= idxs[b_idx]!=ig_tok - char_l = [self.character[tid] for tid in idxs[b_idx][sel] if 0<=tid0 else 0 for img in img_list] - indices=np.argsort(np.array(ratios)); results=[["",0.0]]*num; batch_n=self.batch_num + if not img_list: + return [] + num = len(img_list) + ratios = [img.shape[1] / float(img.shape[0]) if img.shape[0] > 0 else 0 for img in img_list] + indices = np.argsort(np.array(ratios)) + results = [["", 0.0]] * num + batch_n = self.batch_num for start in range(0, num, batch_n): - end=min(num, start+batch_n); batch=[]; max_r_batch=0 - for i in range(start, end): h,w=img_list[indices[i]].shape[0:2]; if h>0: max_r_batch=max(max_r_batch, w/float(h)) - for i in range(start, end): batch.append(self._resize_norm(img_list[indices[i]], max_r_batch)[np.newaxis,:]) - if not batch: continue; batch=np.concatenate(batch, axis=0).copy() - inputs=self.get_input_feed(self.input_name, batch); outputs=self.sess.run(self.output_name, input_feed=inputs) - rec_out=self.post_op(outputs[0]) - for i in range(len(rec_out)): results[indices[start+i]]=rec_out[i] + end = min(num, start + batch_n) + batch = [] + max_r_batch = 0 + for i in range(start, end): + h, w = img_list[indices[i]].shape[0:2] + if h > 0: + max_r_batch = max(max_r_batch, w / float(h)) + for i in range(start, end): + batch.append(self._resize_norm(img_list[indices[i]], max_r_batch)[np.newaxis, :]) + if not batch: + continue + batch = np.concatenate(batch, axis=0).copy() + inputs = self.get_input_feed(self.input_name, batch) + outputs = self.sess.run(self.output_name, input_feed=inputs) + rec_out = self.post_op(outputs[0]) + for i in range(len(rec_out)): + results[indices[start + i]] = rec_out[i] return results # --- MDR ONNX OCR System --- @@ -666,7 +971,9 @@ class _MDR_TextSystem: self.use_cls = getattr(args, 'use_angle_cls', True) self.drop_score = getattr(args, 'drop_score', 0.5) self.classifier = _MDR_TextClassifier(args) if self.use_cls else None - self.crop_idx = 0; self.save_crop = getattr(args, 'save_crop_res', False); self.crop_dir = getattr(args, 'crop_res_save_dir', "./output/mdr_crop_res") + self.crop_idx = 0 + self.save_crop = getattr(args, 'save_crop_res', False) + self.crop_dir = getattr(args, 'crop_res_save_dir', "./output/mdr_crop_res") def _sort_boxes(self, boxes): if boxes is None or len(boxes)==0: return [] @@ -675,48 +982,71 @@ class _MDR_TextSystem: except: return list(boxes) # Fallback def __call__(self, img, classify=True): - ori_im = img.copy(); boxes = self.detector(img) - if boxes is None or len(boxes)==0: return [], [] - boxes = self._sort_boxes(boxes); crops = [] + ori_im = img.copy() + boxes = self.detector(img) + if boxes is None or len(boxes) == 0: + return [], [] + boxes = self._sort_boxes(boxes) + crops = [] for b in boxes: - try: crops.append(mdr_get_rotated_crop(ori_im, b)) # Use renamed util - except: crops.append(None) - valid_idxs = [i for i,c in enumerate(crops) if c is not None]; - if not valid_idxs: return [], [] - crops = [crops[i] for i in valid_idxs]; boxes = [boxes[i] for i in valid_idxs] + try: + crops.append(mdr_get_rotated_crop(ori_im, b)) # Use renamed util + except: + crops.append(None) + valid_idxs = [i for i, c in enumerate(crops) if c is not None] + if not valid_idxs: + return [], [] + crops = [crops[i] for i in valid_idxs] + boxes = [boxes[i] for i in valid_idxs] if self.use_cls and self.classifier and classify: - try: crops, _ = self.classifier(crops) # Ignore cls results, just use rotated crops - except Exception as e: print(f"Classifier error: {e}") - try: rec_res = self.recognizer(crops) - except Exception as e: print(f"Recognizer error: {e}"); return boxes, [["",0.0]]*len(boxes) + try: + crops, _ = self.classifier(crops) # Ignore cls results, just use rotated crops + except Exception as e: + print(f"Classifier error: {e}") + try: + rec_res = self.recognizer(crops) + except Exception as e: + print(f"Recognizer error: {e}") + return boxes, [["", 0.0]] * len(boxes) final_boxes, final_rec = [], [] for box, res in zip(boxes, rec_res): - txt, score = res; - if score >= self.drop_score: final_boxes.append(box); final_rec.append(res) - if self.save_crop: self._save_crops(crops, rec_res) + txt, score = res + if score >= self.drop_score: + final_boxes.append(box) + final_rec.append(res) + if self.save_crop: + self._save_crops(crops, rec_res) return final_boxes, final_rec def _save_crops(self, crops, recs): - mdr_ensure_directory(self.crop_dir); num = len(crops) - for i in range(num): txt, score = recs[i]; safe=re.sub(r'\W+', '_', txt)[:20]; fname=f"crop_{self.crop_idx+i}_{safe}_{score:.2f}.jpg"; cv2.imwrite(os.path.join(self.crop_dir, fname), crops[i]) + mdr_ensure_directory(self.crop_dir) + num = len(crops) + for i in range(num): + txt, score = recs[i] + safe = re.sub(r'\W+', '_', txt)[:20] + fname = f"crop_{self.crop_idx + i}_{safe}_{score:.2f}.jpg" + cv2.imwrite(os.path.join(self.crop_dir, fname), crops[i]) self.crop_idx += num # --- MDR ONNX OCR Utilities --- def mdr_get_rotated_crop(img, points): """Crops and perspective-transforms a quadrilateral region.""" - pts = np.array(points, dtype="float32"); assert len(pts)==4 - w = int(max(np.linalg.norm(pts[0]-pts[1]), np.linalg.norm(pts[2]-pts[3]))) - h = int(max(np.linalg.norm(pts[0]-pts[3]), np.linalg.norm(pts[1]-pts[2]))) - std = np.float32([[0,0],[w,0],[w,h],[0,h]]) + pts = np.array(points, dtype="float32") + assert len(pts) == 4 + w = int(max(np.linalg.norm(pts[0] - pts[1]), np.linalg.norm(pts[2] - pts[3]))) + h = int(max(np.linalg.norm(pts[0] - pts[3]), np.linalg.norm(pts[1] - pts[2]))) + std = np.float32([[0, 0], [w, 0], [w, h], [0, h]]) M = cv2.getPerspectiveTransform(pts, std) - dst = cv2.warpPerspective(img, M, (w,h), borderMode=cv2.BORDER_REPLICATE, flags=cv2.INTER_CUBIC) + dst = cv2.warpPerspective(img, M, (w, h), borderMode=cv2.BORDER_REPLICATE, flags=cv2.INTER_CUBIC) dh, dw = dst.shape[0:2] - if dh>0 and dw>0 and dh*1.0/dw >= 1.5: dst = cv2.rotate(dst, cv2.ROTATE_90_CLOCKWISE) + if dh > 0 and dw > 0 and dh * 1.0 / dw >= 1.5: + dst = cv2.rotate(dst, cv2.ROTATE_90_CLOCKWISE) return dst def mdr_get_min_area_crop(img, points): """Crops the minimum area rectangle containing the points.""" - bb = cv2.minAreaRect(np.array(points).astype(np.int32)); box_pts = cv2.boxPoints(bb) + bb = cv2.minAreaRect(np.array(points).astype(np.int32)) + box_pts = cv2.boxPoints(bb) return mdr_get_rotated_crop(img, box_pts) # --- MDR Layout Processing --- @@ -739,13 +1069,23 @@ class _MDR_OverlapMatrixContext: r_ij = self._rate(p1, p2); r_ji = self._rate(p2, p1); self.matrix[i][j]=r_ij; self.matrix[j][i]=r_ji def _rate(self, p1: Polygon, p2: Polygon) -> float: # Rate p1 covers p2 - try: inter = p1.intersection(p2); - except: return 0.0 - if inter.is_empty or inter.area < 1e-6: return 0.0 - _, _, ix1, iy1 = inter.bounds; iw, ih = ix1-inter.bounds[0], iy1-inter.bounds[1] - _, _, px1, py1 = p2.bounds; pw, ph = px1-p2.bounds[0], py1-p2.bounds[1] - if pw < 1e-6 or ph < 1e-6: return 0.0 - wr = min(iw/pw, 1.0); hr = min(ih/ph, 1.0); return (wr+hr)/2.0 + try: + inter = p1.intersection(p2) + except: + return 0.0 + if inter.is_empty or inter.area < 1e-6: + return 0.0 + _, _, ix1, iy1 = inter.bounds + iw = ix1 - inter.bounds[0] + ih = iy1 - inter.bounds[1] + _, _, px1, py1 = p2.bounds + pw = px1 - p2.bounds[0] + ph = py1 - p2.bounds[1] + if pw < 1e-6 or ph < 1e-6: + return 0.0 + wr = min(iw / pw, 1.0) + hr = min(ih / ph, 1.0) + return (wr + hr) / 2.0 def others(self, idx: int): for i, r in enumerate(self.matrix[idx]): @@ -757,97 +1097,171 @@ class _MDR_OverlapMatrixContext: if self.matrix[i][idx] < _MDR_INCLUDES_MIN_RATE: yield i def mdr_remove_overlap_layouts(layouts: list[MDRLayoutElement]) -> list[MDRLayoutElement]: - if not layouts: return []; ctx = _MDR_OverlapMatrixContext(layouts); prev_removed = -1 + if not layouts: + return [] + ctx = _MDR_OverlapMatrixContext(layouts) + prev_removed = -1 while len(ctx.removed) != prev_removed: - prev_removed = len(ctx.removed); current_removed = set() + prev_removed = len(ctx.removed) + current_removed = set() for i in range(len(layouts)): - if i in ctx.removed or i in current_removed: continue; - li = layouts[i]; pi = ctx.polys[i] - if pi is None: current_removed.add(i); continue; + if i in ctx.removed or i in current_removed: + continue + li = layouts[i] + pi = ctx.polys[i] + if pi is None: + current_removed.add(i) + continue contained = False for j in range(len(layouts)): - if i==j or j in ctx.removed or j in current_removed: continue - if ctx.matrix[j][i] >= _MDR_INCLUDES_MIN_RATE and ctx.matrix[i][j] < _MDR_INCLUDES_MIN_RATE: contained=True; break - if contained: current_removed.add(i); continue; + if i == j or j in ctx.removed or j in current_removed: + continue + if ctx.matrix[j][i] >= _MDR_INCLUDES_MIN_RATE and ctx.matrix[i][j] < _MDR_INCLUDES_MIN_RATE: + contained = True + break + if contained: + current_removed.add(i) + continue contained_by_i = list(ctx.includes(i)) if contained_by_i: for j in contained_by_i: - if j not in ctx.removed and j not in current_removed: li.fragments.extend(layouts[j].fragments); current_removed.add(j) + if j not in ctx.removed and j not in current_removed: + li.fragments.extend(layouts[j].fragments) + current_removed.add(j) li.fragments.sort(key=lambda f: (f.rect.lt[1], f.rect.lt[0])) ctx.removed.update(current_removed) return [l for i, l in enumerate(layouts) if i not in ctx.removed] def _mdr_split_fragments_into_lines(frags: list[MDROcrFragment]) -> Generator[list[MDROcrFragment], None, None]: - if not frags: return; frags.sort(key=lambda f: (f.rect.lt[1], f.rect.lt[0])); + if not frags: + return + frags.sort(key=lambda f: (f.rect.lt[1], f.rect.lt[0])) group, y_sum, h_sum = [], 0.0, 0.0 for f in frags: - _, y1, _, y2 = f.rect.wrapper; h = y2-y1; med_y = (y1+y2)/2.0 - if h <= 0: continue; - if not group: group.append(f); y_sum, h_sum = med_y, h + _, y1, _, y2 = f.rect.wrapper + h = y2 - y1 + med_y = (y1 + y2) / 2.0 + if h <= 0: + continue + if not group: + group.append(f) + y_sum, h_sum = med_y, h else: - g_len = len(group); avg_med_y, avg_h = y_sum/g_len, h_sum/g_len + g_len = len(group) + avg_med_y = y_sum / g_len + avg_h = h_sum / g_len max_dev = avg_h * 0.40 - if abs(med_y - avg_med_y) > max_dev: yield group; group, y_sum, h_sum = [f], med_y, h - else: group.append(f); y_sum += med_y; h_sum += h - if group: yield group + if abs(med_y - avg_med_y) > max_dev: + yield group + group, y_sum, h_sum = [f], med_y, h + else: + group.append(f) + y_sum += med_y + h_sum += h + if group: + yield group def mdr_merge_fragments_into_lines(orig_frags: list[MDROcrFragment]) -> list[MDROcrFragment]: - merged = []; + merged = [] for group in _mdr_split_fragments_into_lines(orig_frags): - if not group: continue; - if len(group) == 1: merged.append(group[0]); continue; - group.sort(key=lambda f: f.rect.lt[0]); + if not group: + continue + if len(group) == 1: + merged.append(group[0]) + continue + group.sort(key=lambda f: f.rect.lt[0]) min_order = min(f.order for f in group if hasattr(f, 'order')) if group else 0 texts, rank_w, txt_len = [], 0.0, 0 x1, y1, x2, y2 = float("inf"), float("inf"), float("-inf"), float("-inf") for f in group: - fx1, fy1, fx2, fy2 = f.rect.wrapper; x1,y1,x2,y2 = min(x1,fx1), min(y1,fy1), max(x2,fx2), max(y2,fy2) - t = f.text; l = len(t); - if l > 0: texts.append(t); rank_w += f.rank*l; txt_len += l - if txt_len == 0: continue; - m_txt = " ".join(texts); m_rank = rank_w/txt_len if txt_len>0 else 0.0 - m_rect = MDRRectangle(lt=(x1,y1), rt=(x2,y1), lb=(x1,y2), rb=(x2,y2)) + fx1, fy1, fx2, fy2 = f.rect.wrapper + x1, y1, x2, y2 = min(x1, fx1), min(y1, fy1), max(x2, fx2), max(y2, fy2) + t = f.text + l = len(t) + if l > 0: + texts.append(t) + rank_w += f.rank * l + txt_len += l + if txt_len == 0: + continue + m_txt = " ".join(texts) + m_rank = rank_w / txt_len if txt_len > 0 else 0.0 + m_rect = MDRRectangle(lt=(x1, y1), rt=(x2, y1), lb=(x1, y2), rb=(x2, y2)) merged.append(MDROcrFragment(order=min_order, text=m_txt, rank=m_rank, rect=m_rect)) merged.sort(key=lambda f: (f.order, f.rect.lt[1], f.rect.lt[0])) - for i, f in enumerate(merged): f.order = i + for i, f in enumerate(merged): + f.order = i return merged # --- MDR Layout Processing --- _MDR_CORRECTION_MIN_OVERLAP = 0.5 def mdr_correct_layout_fragments(ocr_engine: 'MDROcrEngine', source_img: Image, layout: MDRLayoutElement): - if not layout.fragments: return; + if not layout.fragments: + return try: - x1,y1,x2,y2 = layout.rect.wrapper; margin=5; crop_box=(max(0,round(x1)-margin), max(0,round(y1)-margin), min(source_img.width,round(x2)+margin), min(source_img.height,round(y2)+margin)) - if crop_box[0]>=crop_box[2] or crop_box[1]>=crop_box[3]: return; - cropped = source_img.crop(crop_box); off_x, off_y = crop_box[0], crop_box[1] - except Exception as e: print(f"Correct: Crop error: {e}"); return; + x1, y1, x2, y2 = layout.rect.wrapper + margin = 5 + crop_box = (max(0, round(x1) - margin), max(0, round(y1) - margin), min(source_img.width, round(x2) + margin), min(source_img.height, round(y2) + margin)) + if crop_box[0] >= crop_box[2] or crop_box[1] >= crop_box[3]: + return + cropped = source_img.crop(crop_box) + off_x, off_y = crop_box[0], crop_box[1] + except Exception as e: + print(f"Correct: Crop error: {e}") + return try: - cropped_np = np.array(cropped.convert("RGB"))[:,:,::-1]; new_frags_local = list(ocr_engine.find_text_fragments(cropped_np)) - except Exception as e: print(f"Correct: OCR error: {e}"); return; + cropped_np = np.array(cropped.convert("RGB"))[:, :, ::-1] + new_frags_local = list(ocr_engine.find_text_fragments(cropped_np)) + except Exception as e: + print(f"Correct: OCR error: {e}") + return new_frags_global = [] for f in new_frags_local: - r=f.rect; lt,rt,lb,rb=r.lt,r.rt,r.lb,r.rb; f.rect=MDRRectangle(lt=(lt[0]+off_x,lt[1]+off_y), rt=(rt[0]+off_x,rt[1]+off_y), lb=(lb[0]+off_x,lb[1]+off_y), rb=(rb[0]+off_x,rb[1]+off_y)); new_frags_global.append(f) - orig_frags = layout.fragments; matched, unmatched_orig = [], []; used_new = set() + r = f.rect + lt, rt, lb, rb = r.lt, r.rt, r.lb, r.rb + f.rect = MDRRectangle(lt=(lt[0] + off_x, lt[1] + off_y), rt=(rt[0] + off_x, rt[1] + off_y), lb=(lb[0] + off_x, lb[1] + off_y), rb=(rb[0] + off_x, rb[1] + off_y)) + new_frags_global.append(f) + orig_frags = layout.fragments + matched, unmatched_orig = [], [] + used_new = set() for i, orig_f in enumerate(orig_frags): - best_j, best_rate = -1, -1.0; - try: poly_o = Polygon(orig_f.rect); - except: continue; - if not poly_o.is_valid: continue; + best_j, best_rate = -1, -1.0 + try: + poly_o = Polygon(orig_f.rect) + except: + continue + if not poly_o.is_valid: + continue for j, new_f in enumerate(new_frags_global): - if j in used_new: continue; - try: poly_n = Polygon(new_f.rect); - except: continue; - if not poly_n.is_valid: continue; - try: inter=poly_o.intersection(poly_n); union=poly_o.union(poly_n) - except: continue; + if j in used_new: + continue + try: + poly_n = Polygon(new_f.rect) + except: + continue + if not poly_n.is_valid: + continue + try: + inter = poly_o.intersection(poly_n) + union = poly_o.union(poly_n) + except: + continue rate = inter.area / union.area if union.area > 1e-6 else 0.0 - if rate > _MDR_CORRECTION_MIN_OVERLAP and rate > best_rate: best_rate = rate; best_j = j - if best_j != -1: matched.append((orig_f, new_frags_global[best_j])); used_new.add(best_j) - else: unmatched_orig.append(orig_f) + if rate > _MDR_CORRECTION_MIN_OVERLAP and rate > best_rate: + best_rate = rate + best_j = j + if best_j != -1: + matched.append((orig_f, new_frags_global[best_j])) + used_new.add(best_j) + else: + unmatched_orig.append(orig_f) unmatched_new = [f for j, f in enumerate(new_frags_global) if j not in used_new] - final = [n if n.rank >= o.rank else o for o, n in matched]; final.extend(unmatched_orig); final.extend(unmatched_new) - layout.fragments = final; layout.fragments.sort(key=lambda f: (f.rect.lt[1], f.rect.lt[0])) + final = [n if n.rank >= o.rank else o for o, n in matched] + final.extend(unmatched_orig) + final.extend(unmatched_new) + layout.fragments = final + layout.fragments.sort(key=lambda f: (f.rect.lt[1], f.rect.lt[0])) # --- MDR OCR Engine --- @@ -895,12 +1309,23 @@ class MDROcrEngine: for box_pts, (txt, conf) in zip(boxes, recs): if not txt or mdr_is_whitespace(txt) or conf < 0.1: continue pts = [(float(p[0]), float(p[1])) for p in box_pts] - if len(pts)==4: r=MDRRectangle(lt=pts[0], rt=pts[1], rb=pts[2], lb=pts[3]); if r.is_valid and r.area>1: yield MDROcrFragment(order=-1, text=txt, rank=float(conf), rect=r) + if len(pts) == 4: + r = MDRRectangle(lt=pts[0], rt=pts[1], rb=pts[2], lb=pts[3]) + if r.is_valid and r.area > 1: + yield MDROcrFragment(order=-1, text=txt, rank=float(conf), rect=r) def _preprocess(self, img: np.ndarray) -> np.ndarray: - if len(img.shape)==3 and img.shape[2]==4: a=img[:,:,3]/255.0; bg=(255,255,255); new=np.zeros_like(img[:,:,:3]); [setattr(new[:,:,i], 'flags.writeable', True) for i in range(3)]; [np.copyto(new[:,:,i], (bg[i]*(1-a)+img[:,:,i]*a)) for i in range(3)]; img=new.astype(np.uint8) - elif len(img.shape)==2: img=cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) - elif not (len(img.shape)==3 and img.shape[2]==3): raise ValueError("Unsupported image format") + if len(img.shape) == 3 and img.shape[2] == 4: + a = img[:, :, 3] / 255.0 + bg = (255, 255, 255) + new = np.zeros_like(img[:, :, :3]) + [setattr(new[:, :, i], 'flags.writeable', True) for i in range(3)] + [np.copyto(new[:, :, i], (bg[i] * (1 - a) + img[:, :, i] * a)) for i in range(3)] + img = new.astype(np.uint8) + elif len(img.shape) == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + elif not (len(img.shape) == 3 and img.shape[2] == 3): + raise ValueError("Unsupported image format") return img # --- MDR Layout Reading Internals --- @@ -917,18 +1342,30 @@ def mdr_prepare_reader_inputs(inputs: Dict[str, torch.Tensor], model: LayoutLMv3 return {k: v.to(model.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()} def mdr_parse_reader_logits(logits: torch.Tensor, length: int) -> List[int]: - if length == 0: return []; rel_logits = logits[1:length+1, :length]; orders = rel_logits.argmax(dim=1).tolist() + if length == 0: + return [] + rel_logits = logits[1 : length + 1, :length] + orders = rel_logits.argmax(dim=1).tolist() while True: - conflicts = defaultdict(list); [conflicts[order].append(idx) for idx, order in enumerate(orders)] + conflicts = defaultdict(list) + [conflicts[order].append(idx) for idx, order in enumerate(orders)] conflicting_orders = {o: idxs for o, idxs in conflicts.items() if len(idxs) > 1} - if not conflicting_orders: break + if not conflicting_orders: + break for order, idxs in conflicting_orders.items(): - best_idx, max_logit = -1, -float('inf') - for idx in idxs: logit = rel_logits[idx, order].item(); if logit > max_logit: max_logit, best_idx = logit, idx + best_idx = -1 + max_logit = -float('inf') + for idx in idxs: + logit = rel_logits[idx, order].item() + if logit > max_logit: + max_logit = logit + best_idx = idx for idx in idxs: if idx != best_idx: - orig_logit = rel_logits[idx, order].item(); rel_logits[idx, order] = -float('inf') - orders[idx] = rel_logits[idx, :].argmax().item(); rel_logits[idx, order] = orig_logit + orig_logit = rel_logits[idx, order].item() + rel_logits[idx, order] = -float('inf') + orders[idx] = rel_logits[idx, :].argmax().item() + rel_logits[idx, order] = orig_logit return orders # --- MDR Layout Reading Engine --- @@ -953,76 +1390,124 @@ class MDRLayoutReader: return self._model def determine_reading_order(self, layouts: list[MDRLayoutElement], size: tuple[int, int]) -> list[MDRLayoutElement]: - w, h = size; - if w<=0 or h<=0 or not layouts: return layouts; + w, h = size + if w <= 0 or h <= 0 or not layouts: + return layouts model = self._get_model() if model is None: # Fallback geometric sort - layouts.sort(key=lambda l: (l.rect.lt[1], l.rect.lt[0])); nfo = 0 - for l in layouts: l.fragments.sort(key=lambda f: (f.rect.lt[1], f.rect.lt[0])); [setattr(f,'order',i+nfo) for i,f in enumerate(l.fragments)]; nfo += len(l.fragments) + layouts.sort(key=lambda l: (l.rect.lt[1], l.rect.lt[0])) + nfo = 0 + for l in layouts: + l.fragments.sort(key=lambda f: (f.rect.lt[1], f.rect.lt[0])) + [setattr(f, 'order', i + nfo) for i, f in enumerate(l.fragments)] + nfo += len(l.fragments) return layouts bbox_list = self._prepare_bboxes(layouts, w, h) - if bbox_list is None or len(bbox_list) == 0: return layouts - l_size = 1000.0; xs, ys = l_size/float(w), l_size/float(h) + if bbox_list is None or len(bbox_list) == 0: + return layouts + l_size = 1000.0 + xs = l_size / float(w) + ys = l_size / float(h) scaled_bboxes = [] for bbox in bbox_list: x0, y0, x1, y1 = bbox.value - sx0, sy0 = max(0, min(l_size-1, round(x0*xs))), max(0, min(l_size-1, round(y0*ys))) - sx1, sy1 = max(0, min(l_size-1, round(x1*xs))), max(0, min(l_size-1, round(y1*ys))) + sx0 = max(0, min(l_size - 1, round(x0 * xs))) + sy0 = max(0, min(l_size - 1, round(y0 * ys))) + sx1 = max(0, min(l_size - 1, round(x1 * xs))) + sy1 = max(0, min(l_size - 1, round(y1 * ys))) scaled_bboxes.append([min(sx0, sx1), min(sy0, sy1), max(sx0, sx1), max(sy0, sy1)]) orders = [] try: with torch.no_grad(): - inputs = mdr_boxes_to_reader_inputs(scaled_bboxes); inputs = mdr_prepare_reader_inputs(inputs, model) - logits = model(**inputs).logits.cpu().squeeze(0); orders = mdr_parse_reader_logits(logits, len(bbox_list)) - except Exception as e: print(f"MDR LayoutReader prediction error: {e}"); return layouts # Fallback - if len(orders) != len(bbox_list): print("MDR LayoutReader order mismatch"); return layouts # Fallback - for i, order_idx in enumerate(orders): bbox_list[i].order = order_idx + inputs = mdr_boxes_to_reader_inputs(scaled_bboxes) + inputs = mdr_prepare_reader_inputs(inputs, model) + logits = model(**inputs).logits.cpu().squeeze(0) + orders = mdr_parse_reader_logits(logits, len(bbox_list)) + except Exception as e: + print(f"MDR LayoutReader prediction error: {e}") + return layouts # Fallback + if len(orders) != len(bbox_list): + print("MDR LayoutReader order mismatch") + return layouts # Fallback + for i, order_idx in enumerate(orders): + bbox_list[i].order = order_idx return self._apply_order(layouts, bbox_list) def _prepare_bboxes(self, layouts: list[MDRLayoutElement], w: int, h: int) -> list[_MDR_ReaderBBox] | None: - line_h = self._estimate_line_h(layouts); bbox_list = [] + line_h = self._estimate_line_h(layouts) + bbox_list = [] for i, l in enumerate(layouts): - if l.cls == MDRLayoutClass.PLAIN_TEXT and l.fragments: [bbox_list.append(_MDR_ReaderBBox(i, j, False, -1, f.rect.wrapper)) for j, f in enumerate(l.fragments)] - else: bbox_list.extend(self._gen_virtual(l, i, line_h, w, h)) - if len(bbox_list) > _MDR_MAX_LEN: print(f"Too many boxes ({len(bbox_list)}>{_MDR_MAX_LEN})"); return None - bbox_list.sort(key=lambda b: (b.value[1], b.value[0])); return bbox_list + if l.cls == MDRLayoutClass.PLAIN_TEXT and l.fragments: + [bbox_list.append(_MDR_ReaderBBox(i, j, False, -1, f.rect.wrapper)) for j, f in enumerate(l.fragments)] + else: + bbox_list.extend(self._gen_virtual(l, i, line_h, w, h)) + if len(bbox_list) > _MDR_MAX_LEN: + print(f"Too many boxes ({len(bbox_list)}>{_MDR_MAX_LEN})") + return None + bbox_list.sort(key=lambda b: (b.value[1], b.value[0])) + return bbox_list def _apply_order(self, layouts: list[MDRLayoutElement], bbox_list: list[_MDR_ReaderBBox]) -> list[MDRLayoutElement]: - layout_map = defaultdict(list); [layout_map[b.layout_index].append(b) for b in bbox_list] + layout_map = defaultdict(list) + [layout_map[b.layout_index].append(b) for b in bbox_list] layout_orders = [(idx, self._median([b.order for b in bboxes])) for idx, bboxes in layout_map.items() if bboxes] - layout_orders.sort(key=lambda x: x[1]); sorted_layouts = [layouts[idx] for idx, _ in layout_orders] + layout_orders.sort(key=lambda x: x[1]) + sorted_layouts = [layouts[idx] for idx, _ in layout_orders] nfo = 0 for l in sorted_layouts: - frags = l.fragments; - if not frags: continue; + frags = l.fragments + if not frags: + continue frag_bboxes = [b for b in layout_map[layouts.index(l)] if not b.virtual] - if frag_bboxes: idx_to_order = {b.fragment_index: b.order for b in frag_bboxes}; frags.sort(key=lambda f: idx_to_order.get(frags.index(f), float('inf'))) - else: frags.sort(key=lambda f: (f.rect.lt[1], f.rect.lt[0])) - for frag in frags: frag.order = nfo; nfo += 1 + if frag_bboxes: + idx_to_order = {b.fragment_index: b.order for b in frag_bboxes} + frags.sort(key=lambda f: idx_to_order.get(frags.index(f), float('inf'))) + else: + frags.sort(key=lambda f: (f.rect.lt[1], f.rect.lt[0])) + for frag in frags: + frag.order = nfo + nfo += 1 return sorted_layouts def _estimate_line_h(self, layouts: list[MDRLayoutElement]) -> float: - heights = [f.rect.size[1] for l in layouts for f in l.fragments if f.rect.size[1]>0] + heights = [f.rect.size[1] for l in layouts for f in l.fragments if f.rect.size[1] > 0] return self._median(heights) if heights else 15.0 def _gen_virtual(self, l: MDRLayoutElement, l_idx: int, line_h: float, pw: int, ph: int) -> Generator[_MDR_ReaderBBox, None, None]: - x0,y0,x1,y1 = l.rect.wrapper; lh,lw = y1-y0,x1-x0 - if lh<=0 or lw<=0 or line_h<=0: yield _MDR_ReaderBBox(l_idx,-1,True,-1,(x0,y0,x1,y1)); return + x0, y0, x1, y1 = l.rect.wrapper + lh = y1 - y0 + lw = x1 - x0 + if lh <= 0 or lw <= 0 or line_h <= 0: + yield _MDR_ReaderBBox(l_idx, -1, True, -1, (x0, y0, x1, y1)) + return lines = 1 - if lh > line_h*1.5: - if lh<=ph*0.25 or lw>=pw*0.5: lines=3 - elif lw>pw*0.25: lines = 3 if lw>pw*0.4 else 2 - elif lw<=pw*0.25: lines = max(1, int(lh/(line_h*1.5))) if lh/lw>1.5 else 2 - else: lines = max(1, int(round(lh/line_h))) - lines = max(1, lines); act_line_h = lh/lines; cur_y = y0 + if lh > line_h * 1.5: + if lh <= ph * 0.25 or lw >= pw * 0.5: + lines = 3 + elif lw > pw * 0.25: + lines = 3 if lw > pw * 0.4 else 2 + elif lw <= pw * 0.25: + lines = max(1, int(lh / (line_h * 1.5))) if lh / lw > 1.5 else 2 + else: + lines = max(1, int(round(lh / line_h))) + lines = max(1, lines) + act_line_h = lh / lines + cur_y = y0 for i in range(lines): - ly0,ly1 = max(0,min(ph,cur_y)), max(0,min(ph,cur_y+act_line_h)); lx0,lx1 = max(0,min(pw,x0)), max(0,min(pw,x1)) - if ly1>ly0 and lx1>lx0: yield _MDR_ReaderBBox(l_idx,-1,True,-1,(lx0,ly0,lx1,ly1)) + ly0 = max(0, min(ph, cur_y)) + ly1 = max(0, min(ph, cur_y + act_line_h)) + lx0 = max(0, min(pw, x0)) + lx1 = max(0, min(pw, x1)) + if ly1 > ly0 and lx1 > lx0: + yield _MDR_ReaderBBox(l_idx, -1, True, -1, (lx0, ly0, lx1, ly1)) cur_y += act_line_h def _median(self, nums: list[float|int]) -> float: - if not nums: return 0.0; s_nums = sorted(nums); n = len(s_nums) - return float(s_nums[n//2]) if n%2==1 else float((s_nums[n//2-1]+s_nums[n//2])/2.0) + if not nums: + return 0.0 + s_nums = sorted(nums) + n = len(s_nums) + return float(s_nums[n // 2]) if n % 2 == 1 else float((s_nums[n // 2 - 1] + s_nums[n // 2]) / 2.0) # --- MDR LaTeX Extractor --- class MDRLatexExtractor: @@ -1034,7 +1519,8 @@ class MDRLatexExtractor: def extract(self, image: Image) -> str | None: if LatexOCR is None: return None; - image = mdr_expand_image(image, 0.1); model = self._get_model() + image = mdr_expand_image(image, 0.1) + model = self._get_model() if model is None: return None; try: with torch.no_grad(): img_rgb = image.convert('RGB') if image.mode!='RGB' else image; latex = model(img_rgb); return latex if latex else None @@ -1042,16 +1528,30 @@ class MDRLatexExtractor: def _get_model(self) -> LatexOCR | None: if self._model is None and LatexOCR is not None: - mdr_ensure_directory(self._model_path); wp, rp, cp = Path(self._model_path)/"weights.pth", Path(self._model_path)/"image_resizer.pth", Path(self._model_path)/"config.yaml" - if not wp.exists() or not rp.exists(): print("Downloading MDR LaTeX models..."); self._download() - if not cp.exists(): print(f"Warn: MDR LaTeX config not found {self._model_path}") - try: args = Munch({"config":str(cp), "checkpoint":str(wp), "device":self._device, "no_cuda":self._device=="cpu", "no_resize":False, "temperature":0.0}); self._model = LatexOCR(args); print(f"MDR LaTeX loaded on {self._device}.") - except Exception as e: print(f"ERROR initializing MDR LatexOCR: {e}"); self._model = None + mdr_ensure_directory(self._model_path) + wp = Path(self._model_path) / "weights.pth" + rp = Path(self._model_path) / "image_resizer.pth" + cp = Path(self._model_path) / "config.yaml" + if not wp.exists() or not rp.exists(): + print("Downloading MDR LaTeX models...") + self._download() + if not cp.exists(): + print(f"Warn: MDR LaTeX config not found {self._model_path}") + try: + args = Munch({"config": str(cp), "checkpoint": str(wp), "device": self._device, "no_cuda": self._device == "cpu", "no_resize": False, "temperature": 0.0}) + self._model = LatexOCR(args) + print(f"MDR LaTeX loaded on {self._device}.") + except Exception as e: + print(f"ERROR initializing MDR LatexOCR: {e}") + self._model = None return self._model def _download(self): - tag = "v0.0.1"; base = f"https://github.com/lukas-blecher/LaTeX-OCR/releases/download/{tag}/"; files = {"weights.pth": base+"weights.pth", "image_resizer.pth": base+"image_resizer.pth"} - mdr_ensure_directory(self._model_path); [mdr_download_model(url, Path(self._model_path)/name) for name, url in files.items() if not (Path(self._model_path)/name).exists()] + tag = "v0.0.1" + base = f"https://github.com/lukas-blecher/LaTeX-OCR/releases/download/{tag}/" + files = {"weights.pth": base + "weights.pth", "image_resizer.pth": base + "image_resizer.pth"} + mdr_ensure_directory(self._model_path) + [mdr_download_model(url, Path(self._model_path) / name) for name, url in files.items() if not (Path(self._model_path) / name).exists()] # --- MDR Table Parser --- MDRTableOutputFormat = Literal["latex", "markdown", "html"] @@ -1072,7 +1572,8 @@ class MDRTableParser: elif format == MDRTableLayoutParsedFormat.MARKDOWN: fmt="markdown" elif format == MDRTableLayoutParsedFormat.HTML: fmt="html" else: return None - image = mdr_expand_image(image, 0.05); model = self._get_model() + image = mdr_expand_image(image, 0.05) + model = self._get_model() if model is None: return None; try: img_rgb = image.convert('RGB') if image.mode!='RGB' else image @@ -1118,19 +1619,28 @@ class MDRImageOptimizer: def image_np(self) -> np.ndarray: img_rgb = np.array(self._raw.convert("RGB")); return cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR) def receive_fragments(self, fragments: list[MDROcrFragment]): - self._fragments = fragments; - if not fragments: return; + self._fragments = fragments + if not fragments: + return self._rotation = mdr_calculate_image_rotation(fragments) - if abs(self._rotation) < _MDR_TINY_ROTATION: self._rotation = 0.0; return + if abs(self._rotation) < _MDR_TINY_ROTATION: + self._rotation = 0.0 + return orig_sz = self._raw.size - try: self._image = self._raw.rotate(-np.degrees(self._rotation), resample=PILResampling.BICUBIC, fillcolor=(255,255,255), expand=True) - except Exception as e: print(f"Optimizer rotation error: {e}"); self._rotation=0.0; self._image=self._raw; return + try: + self._image = self._raw.rotate(-np.degrees(self._rotation), resample=PILResampling.BICUBIC, fillcolor=(255, 255, 255), expand=True) + except Exception as e: + print(f"Optimizer rotation error: {e}") + self._rotation = 0.0 + self._image = self._raw + return new_sz = self._image.size self._rot_ctx = _MDR_RotationContext( fragment_origin_rectangles=[f.rect for f in fragments], to_new=MDRRotationAdjuster(orig_sz, new_sz, self._rotation, False), to_origin=MDRRotationAdjuster(orig_sz, new_sz, self._rotation, True)) - adj = self._rot_ctx.to_new; [setattr(f, 'rect', MDRRectangle(lt=adj.adjust(r.lt), rt=adj.adjust(r.rt), lb=adj.adjust(r.lb), rb=adj.adjust(r.rb))) for f in fragments if (r:=f.rect)] + adj = self._rot_ctx.to_new + [setattr(f, 'rect', MDRRectangle(lt=adj.adjust(r.lt), rt=adj.adjust(r.rt), lb=adj.adjust(r.lb), rb=adj.adjust(r.rb))) for f in fragments if (r := f.rect)] def finalize_layout_coords(self, layouts: list[MDRLayoutElement]): if self._rot_ctx is None or self._adjust_points: return @@ -1141,18 +1651,31 @@ class MDRImageOptimizer: def mdr_clip_from_image(image: Image, rect: MDRRectangle, wrap_w: float = 0.0, wrap_h: float = 0.0) -> Image: """Clips a potentially rotated rectangle from an image.""" try: - h_rot, _ = mdr_calculate_rectangle_rotation(rect); avg_w, avg_h = rect.size - if avg_w<=0 or avg_h<=0: return new_image("RGB", (1,1), (255,255,255)) - tx, ty = rect.lt; trans_orig = np.array([[1,0,-tx],[0,1,-ty],[0,0,1]]) - cos_r, sin_r = cos(-h_rot), sin(-h_rot); rot = np.array([[cos_r,-sin_r,0],[sin_r,cos_r,0],[0,0,1]]) - pad_dx, pad_dy = wrap_w/2.0, wrap_h/2.0; trans_pad = np.array([[1,0,pad_dx],[0,1,pad_dy],[0,0,1]]) + h_rot, _ = mdr_calculate_rectangle_rotation(rect) + avg_w, avg_h = rect.size + if avg_w <= 0 or avg_h <= 0: + return new_image("RGB", (1, 1), (255, 255, 255)) + tx, ty = rect.lt + trans_orig = np.array([[1, 0, -tx], [0, 1, -ty], [0, 0, 1]]) + cos_r = cos(-h_rot) + sin_r = sin(-h_rot) + rot = np.array([[cos_r, -sin_r, 0], [sin_r, cos_r, 0], [0, 0, 1]]) + pad_dx = wrap_w / 2.0 + pad_dy = wrap_h / 2.0 + trans_pad = np.array([[1, 0, pad_dx], [0, 1, pad_dy], [0, 0, 1]]) matrix = trans_pad @ rot @ trans_orig - try: inv_matrix = np.linalg.inv(matrix) - except np.linalg.LinAlgError: x0,y0,x1,y1=rect.wrapper; return image.crop((round(x0),round(y0),round(x1),round(y1))) - p_mat = (inv_matrix[0,0], inv_matrix[0,1], inv_matrix[0,2], inv_matrix[1,0], inv_matrix[1,1], inv_matrix[1,2]) - out_w, out_h = ceil(avg_w+wrap_w), ceil(avg_h+wrap_h) - return image.transform((out_w, out_h), PILTransform.AFFINE, p_mat, PILResampling.BICUBIC, fillcolor=(255,255,255)) - except Exception as e: print(f"MDR Clipping error: {e}"); return new_image("RGB", (10,10), (255,255,255)) + try: + inv_matrix = np.linalg.inv(matrix) + except np.linalg.LinAlgError: + x0, y0, x1, y1 = rect.wrapper + return image.crop((round(x0), round(y0), round(x1), round(y1))) + p_mat = (inv_matrix[0, 0], inv_matrix[0, 1], inv_matrix[0, 2], inv_matrix[1, 0], inv_matrix[1, 1], inv_matrix[1, 2]) + out_w = ceil(avg_w + wrap_w) + out_h = ceil(avg_h + wrap_h) + return image.transform((out_w, out_h), PILTransform.AFFINE, p_mat, PILResampling.BICUBIC, fillcolor=(255, 255, 255)) + except Exception as e: + print(f"MDR Clipping error: {e}") + return new_image("RGB", (10, 10), (255, 255, 255)) def mdr_clip_layout(res: MDRExtractionResult, layout: MDRLayoutElement, wrap_w: float = 0.0, wrap_h: float = 0.0) -> Image: """Clips a layout region from the MDRExtractionResult image.""" @@ -1165,19 +1688,40 @@ _MDR_FRAG_COLOR = (0x49, 0xCF, 0xCB, 200); _MDR_LAYOUT_COLORS = { MDRLayoutClass def mdr_plot_layout(image: Image, layouts: Iterable[MDRLayoutElement]) -> None: """Draws layout and fragment boxes onto an image for debugging.""" if not layouts: return; - try: l_font, f_font = load_default(size=25), load_default(size=15); draw = ImageDraw.Draw(image, mode="RGBA") - except Exception as e: print(f"MDR Plot init error: {e}"); return + try: + l_font = load_default(size=25) + f_font = load_default(size=15) # Not used currently, but kept for potential future use + draw = ImageDraw.Draw(image, mode="RGBA") + except Exception as e: + print(f"MDR Plot init error: {e}") + return def _draw_num(pos: MDRPoint, num: int, font: FreeTypeFont, color: _MDR_RGBA): - try: x,y=pos; txt=str(num); txt_pos=(round(x)+3, round(y)+1); bbox=draw.textbbox(txt_pos,txt,font=font); bg_rect=(bbox[0]-2,bbox[1]-1,bbox[2]+2,bbox[3]+1); bg_color=(color[0],color[1],color[2],180); draw.rectangle(bg_rect,fill=bg_color); draw.text(txt_pos,txt,font=font,fill=(255,255,255,255)) - except Exception as e: print(f"MDR Draw num error: {e}") + try: + x, y = pos + txt = str(num) + txt_pos = (round(x) + 3, round(y) + 1) + bbox = draw.textbbox(txt_pos, txt, font=font) + bg_rect = (bbox[0] - 2, bbox[1] - 1, bbox[2] + 2, bbox[3] + 1) + bg_color = (color[0], color[1], color[2], 180) + draw.rectangle(bg_rect, fill=bg_color) + draw.text(txt_pos, txt, font=font, fill=(255, 255, 255, 255)) + except Exception as e: + print(f"MDR Draw num error: {e}") + for i, l in enumerate(layouts): - try: l_color = _MDR_LAYOUT_COLORS.get(l.cls, _MDR_DEFAULT_COLOR); draw.polygon([p for p in l.rect], outline=l_color, width=3); _draw_num(l.rect.lt, i+1, l_font, l_color) - except Exception as e: print(f"MDR Layout draw error: {e}") + try: + l_color = _MDR_LAYOUT_COLORS.get(l.cls, _MDR_DEFAULT_COLOR) + draw.polygon([p for p in l.rect], outline=l_color, width=3) + _draw_num(l.rect.lt, i + 1, l_font, l_color) + except Exception as e: + print(f"MDR Layout draw error: {e}") for l in layouts: for f in l.fragments: - try: draw.polygon([p for p in f.rect], outline=_MDR_FRAG_COLOR, width=1) - except Exception as e: print(f"MDR Fragment draw error: {e}") + try: + draw.polygon([p for p in f.rect], outline=_MDR_FRAG_COLOR, width=1) + except Exception as e: + print(f"MDR Fragment draw error: {e}") # --- MDR Extraction Engine --- class MDRExtractionEngine: @@ -1240,56 +1784,95 @@ class MDRExtractionEngine: def analyze_image(self, image: Image, adjust_points: bool=False) -> MDRExtractionResult: """Analyzes a single page image to extract layout and content.""" - print(" Engine: Analyzing image..."); optimizer = MDRImageOptimizer(image, adjust_points) - print(" Engine: Initial OCR..."); frags = list(self._ocr_engine.find_text_fragments(optimizer.image_np)); print(f" Engine: {len(frags)} fragments found.") - optimizer.receive_fragments(frags); frags = optimizer._fragments # Use adjusted fragments - print(" Engine: Layout detection..."); yolo = self._get_yolo_model(); raw_layouts = [] + print(" Engine: Analyzing image...") + optimizer = MDRImageOptimizer(image, adjust_points) + print(" Engine: Initial OCR...") + frags = list(self._ocr_engine.find_text_fragments(optimizer.image_np)) + print(f" Engine: {len(frags)} fragments found.") + optimizer.receive_fragments(frags) + frags = optimizer._fragments # Use adjusted fragments + print(" Engine: Layout detection...") + yolo = self._get_yolo_model() + raw_layouts = [] if yolo: - try: raw_layouts = list(self._run_yolo_detection(optimizer.image, yolo)); print(f" Engine: {len(raw_layouts)} raw layouts found.") - except Exception as e: print(f" Engine: YOLO error: {e}") - print(" Engine: Matching fragments..."); layouts = self._match_fragments_to_layouts(frags, raw_layouts) - print(" Engine: Removing overlaps..."); layouts = mdr_remove_overlap_layouts(layouts); print(f" Engine: {len(layouts)} layouts after overlap removal.") - if self._ocr_each and layouts: print(" Engine: OCR correction..."); self._run_ocr_correction(optimizer.image, layouts) - print(" Engine: Determining reading order..."); layouts = self._layout_reader.determine_reading_order(layouts, optimizer.image.size) - layouts = [l for l in layouts if self._should_keep_layout(l)]; print(f" Engine: {len(layouts)} layouts after filtering.") - if self._ext_table or self._ext_formula: print(" Engine: Parsing tables/formulas..."); self._parse_special_layouts(layouts, optimizer) - print(" Engine: Merging fragments..."); [setattr(l, 'fragments', mdr_merge_fragments_into_lines(l.fragments)) for l in layouts] - print(" Engine: Finalizing coords..."); optimizer.finalize_layout_coords(layouts) + try: + raw_layouts = list(self._run_yolo_detection(optimizer.image, yolo)) + print(f" Engine: {len(raw_layouts)} raw layouts found.") + except Exception as e: + print(f" Engine: YOLO error: {e}") + print(" Engine: Matching fragments...") + layouts = self._match_fragments_to_layouts(frags, raw_layouts) + print(" Engine: Removing overlaps...") + layouts = mdr_remove_overlap_layouts(layouts) + print(f" Engine: {len(layouts)} layouts after overlap removal.") + if self._ocr_each and layouts: + print(" Engine: OCR correction...") + self._run_ocr_correction(optimizer.image, layouts) + print(" Engine: Determining reading order...") + layouts = self._layout_reader.determine_reading_order(layouts, optimizer.image.size) + layouts = [l for l in layouts if self._should_keep_layout(l)] + print(f" Engine: {len(layouts)} layouts after filtering.") + if self._ext_table or self._ext_formula: + print(" Engine: Parsing tables/formulas...") + self._parse_special_layouts(layouts, optimizer) + print(" Engine: Merging fragments...") + [setattr(l, 'fragments', mdr_merge_fragments_into_lines(l.fragments)) for l in layouts] + print(" Engine: Finalizing coords...") + optimizer.finalize_layout_coords(layouts) print(" Engine: Analysis complete.") return MDRExtractionResult(rotation=optimizer.rotation, layouts=layouts, extracted_image=image, adjusted_image=optimizer.adjusted_image) def _run_yolo_detection(self, img: Image, yolo: YOLOv10) -> Generator[MDRLayoutElement, None, None]: - img_rgb = img.convert("RGB"); res = yolo.predict(source=img_rgb, imgsz=1024, conf=0.2, device=self._device, verbose=False) - if not res or not hasattr(res[0], 'boxes') or res[0].boxes is None: return + img_rgb = img.convert("RGB") + res = yolo.predict(source=img_rgb, imgsz=1024, conf=0.2, device=self._device, verbose=False) + if not res or not hasattr(res[0], 'boxes') or res[0].boxes is None: + return boxes = res[0].boxes for cls_id_t, xyxy_t in zip(boxes.cls, boxes.xyxy): - cls_id = int(cls_id_t.item()); - try: cls = MDRLayoutClass(cls_id) - except ValueError: continue - x1,y1,x2,y2 = [c.item() for c in xyxy_t]; rect = MDRRectangle(lt=(x1,y1), rt=(x2,y1), lb=(x1,y2), rb=(x2,y2)) + cls_id = int(cls_id_t.item()) + try: + cls = MDRLayoutClass(cls_id) + except ValueError: + continue + x1, y1, x2, y2 = [c.item() for c in xyxy_t] + rect = MDRRectangle(lt=(x1, y1), rt=(x2, y1), lb=(x1, y2), rb=(x2, y2)) if rect.is_valid and rect.area > 10: - if cls == MDRLayoutClass.TABLE: yield MDRTableLayoutElement(cls=cls, rect=rect, fragments=[], parsed=None) - elif cls == MDRLayoutClass.ISOLATE_FORMULA: yield MDRFormulaLayoutElement(cls=cls, rect=rect, fragments=[], latex=None) - elif cls in MDRPlainLayoutElement.__annotations__['cls'].__args__: yield MDRPlainLayoutElement(cls=cls, rect=rect, fragments=[]) + if cls == MDRLayoutClass.TABLE: + yield MDRTableLayoutElement(cls=cls, rect=rect, fragments=[], parsed=None) + elif cls == MDRLayoutClass.ISOLATE_FORMULA: + yield MDRFormulaLayoutElement(cls=cls, rect=rect, fragments=[], latex=None) + elif cls in MDRPlainLayoutElement.__annotations__['cls'].__args__: + yield MDRPlainLayoutElement(cls=cls, rect=rect, fragments=[]) def _match_fragments_to_layouts(self, frags: list[MDROcrFragment], layouts: list[MDRLayoutElement]) -> list[MDRLayoutElement]: - if not frags or not layouts: return layouts + if not frags or not layouts: + return layouts layout_polys = [(Polygon(l.rect) if l.rect.is_valid else None) for l in layouts] for frag in frags: - try: frag_poly = Polygon(frag.rect); frag_area = frag_poly.area - except: continue - if not frag_poly.is_valid or frag_area < 1e-6: continue + try: + frag_poly = Polygon(frag.rect) + frag_area = frag_poly.area + except: + continue + if not frag_poly.is_valid or frag_area < 1e-6: + continue candidates = [] # (layout_idx, layout_area, overlap_ratio) for idx, l_poly in enumerate(layout_polys): - if l_poly is None: continue - try: inter_area = frag_poly.intersection(l_poly).area - except: continue + if l_poly is None: + continue + try: + inter_area = frag_poly.intersection(l_poly).area + except: + continue overlap = inter_area / frag_area if frag_area > 0 else 0 - if overlap > 0.85: candidates.append((idx, l_poly.area, overlap)) + if overlap > 0.85: + candidates.append((idx, l_poly.area, overlap)) if candidates: - candidates.sort(key=lambda x: (x[1], -x[2])); best_idx = candidates[0][0] + candidates.sort(key=lambda x: (x[1], -x[2])) + best_idx = candidates[0][0] layouts[best_idx].fragments.append(frag) - for l in layouts: l.fragments.sort(key=lambda f: (f.rect.lt[1], f.rect.lt[0])) + for l in layouts: + l.fragments.sort(key=lambda f: (f.rect.lt[1], f.rect.lt[0])) return layouts def _run_ocr_correction(self, img: Image, layouts: list[MDRLayoutElement]): @@ -1302,12 +1885,20 @@ class MDRExtractionEngine: img_to_clip = optimizer.image for l in layouts: if isinstance(l, MDRFormulaLayoutElement) and self._ext_formula: - try: f_img = mdr_clip_from_image(img_to_clip, l.rect); l.latex = self._latex_extractor.extract(f_img) if f_img.width>1 and f_img.height>1 else None - except Exception as e: print(f" Engine: LaTeX extract error: {e}") + try: + f_img = mdr_clip_from_image(img_to_clip, l.rect) + l.latex = self._latex_extractor.extract(f_img) if f_img.width > 1 and f_img.height > 1 else None + except Exception as e: + print(f" Engine: LaTeX extract error: {e}") elif isinstance(l, MDRTableLayoutElement) and self._ext_table is not None: - try: t_img = mdr_clip_from_image(img_to_clip, l.rect); parsed = self._table_parser.parse_table_image(t_img, self._ext_table) if t_img.width>1 and t_img.height>1 else None - except Exception as e: print(f" Engine: Table parse error: {e}"); parsed = None - if parsed: l.parsed = (parsed, self._ext_table) + try: + t_img = mdr_clip_from_image(img_to_clip, l.rect) + parsed = self._table_parser.parse_table_image(t_img, self._ext_table) if t_img.width > 1 and t_img.height > 1 else None + except Exception as e: + print(f" Engine: Table parse error: {e}") + parsed = None + if parsed: + l.parsed = (parsed, self._ext_table) def _should_keep_layout(self, l: MDRLayoutElement) -> bool: if l.fragments and not all(mdr_is_whitespace(f.text) for f in l.fragments): return True @@ -1337,62 +1928,113 @@ class MDRPageSection: def link_to_next(self, next_section: 'MDRPageSection', offset: int) -> None: """Links matching shapes between this section and the next.""" - if offset not in (1,2): return + if offset not in (1, 2): + return matches_matrix = [[sn for sn in next_section._shapes if self._shapes_match(ss, sn)] for ss in self._shapes] origin_pair = self._find_origin_pair(matches_matrix, next_section._shapes) - if origin_pair is None: return - orig_s, orig_n = origin_pair; orig_s_pt, orig_n_pt = orig_s.layout.rect.lt, orig_n.layout.rect.lt + if origin_pair is None: + return + orig_s, orig_n = origin_pair + orig_s_pt = orig_s.layout.rect.lt + orig_n_pt = orig_n.layout.rect.lt for i, s1 in enumerate(self._shapes): - potentials = matches_matrix[i]; - if not potentials: continue - r1_rel = self._relative_rect(orig_s_pt, s1.layout.rect); best_s2, max_ovr = None, -1.0 + potentials = matches_matrix[i] + if not potentials: + continue + r1_rel = self._relative_rect(orig_s_pt, s1.layout.rect) + best_s2 = None + max_ovr = -1.0 for s2 in potentials: - r2_rel = self._relative_rect(orig_n_pt, s2.layout.rect); ovr = self._symmetric_iou(r1_rel, r2_rel) - if ovr > max_ovr: max_ovr, best_s2 = ovr, s2 - if max_ovr >= 0.80 and best_s2 is not None: s1.nex[offset-1] = best_s2.layout; best_s2.pre[offset-1] = s1.layout # Link both ways + r2_rel = self._relative_rect(orig_n_pt, s2.layout.rect) + ovr = self._symmetric_iou(r1_rel, r2_rel) + if ovr > max_ovr: + max_ovr = ovr + best_s2 = s2 + if max_ovr >= 0.80 and best_s2 is not None: + s1.nex[offset - 1] = best_s2.layout + best_s2.pre[offset - 1] = s1.layout # Link both ways def _shapes_match(self, s1: _MDR_LinkedShape, s2: _MDR_LinkedShape) -> bool: - l1, l2 = s1.layout, s2.layout; sz1, sz2 = l1.rect.size, l2.rect.size; thresh = 0.90 - if mdr_similarity_ratio(sz1[0], sz2[0]) < thresh or mdr_similarity_ratio(sz1[1], sz2[1]) < thresh: return False - f1, f2 = l1.fragments, l2.fragments; c1, c2 = len(f1), len(f2) - if c1==0 and c2==0: return True; - if c1==0 or c2==0: return False; - matches, used_f2 = 0, [False]*c2 + l1 = s1.layout + l2 = s2.layout + sz1 = l1.rect.size + sz2 = l2.rect.size + thresh = 0.90 + if mdr_similarity_ratio(sz1[0], sz2[0]) < thresh or mdr_similarity_ratio(sz1[1], sz2[1]) < thresh: + return False + f1 = l1.fragments + f2 = l2.fragments + c1 = len(f1) + c2 = len(f2) + if c1 == 0 and c2 == 0: + return True + if c1 == 0 or c2 == 0: + return False + matches = 0 + used_f2 = [False] * c2 for frag1 in f1: - best_j, max_sim = -1, -1.0 + best_j = -1 + max_sim = -1.0 for j, frag2 in enumerate(f2): - if not used_f2[j]: sim = self._fragment_sim(l1, l2, frag1, frag2); if sim > max_sim: max_sim, best_j = sim, j - if max_sim > 0.75: matches += 1; if best_j != -1: used_f2[best_j] = True - max_c = max(c1, c2); rate_frags = matches / max_c + if not used_f2[j]: + sim = self._fragment_sim(l1, l2, frag1, frag2) + if sim > max_sim: + max_sim = sim + best_j = j + if max_sim > 0.75: + matches += 1 + if best_j != -1: + used_f2[best_j] = True + max_c = max(c1, c2) + rate_frags = matches / max_c return self._check_match_threshold(rate_frags, max_c, (0.0, 0.45, 0.45, 0.6, 0.8, 0.95)) def _fragment_sim(self, l1: MDRLayoutElement, l2: MDRLayoutElement, f1: MDROcrFragment, f2: MDROcrFragment) -> float: - r1_rel = self._relative_rect(l1.rect.lt, f1.rect); r2_rel = self._relative_rect(l2.rect.lt, f2.rect) - geom_sim = self._symmetric_iou(r1_rel, r2_rel); text_sim, _ = mdr_check_text_similarity(f1.text, f2.text) + r1_rel = self._relative_rect(l1.rect.lt, f1.rect) + r2_rel = self._relative_rect(l2.rect.lt, f2.rect) + geom_sim = self._symmetric_iou(r1_rel, r2_rel) + text_sim, _ = mdr_check_text_similarity(f1.text, f2.text) return (geom_sim + text_sim) / 2.0 def _find_origin_pair(self, matches_matrix: list[list[_MDR_LinkedShape]], next_shapes: list[_MDR_LinkedShape]) -> tuple[_MDR_LinkedShape, _MDR_LinkedShape] | None: - best_pair, min_dist2 = None, float('inf') + best_pair = None + min_dist2 = float('inf') for i, s1 in enumerate(self._shapes): - match_list = matches_matrix[i]; - if not match_list: continue - for s2 in match_list: dist2 = s1.distance2 + s2.distance2; if dist2 < min_dist2: min_dist2, best_pair = dist2, (s1, s2) + match_list = matches_matrix[i] + if not match_list: + continue + for s2 in match_list: + dist2 = s1.distance2 + s2.distance2 + if dist2 < min_dist2: + min_dist2 = dist2 + best_pair = (s1, s2) return best_pair def _check_match_threshold(self, rate: float, count: int, thresholds: Sequence[float]) -> bool: if not thresholds: return False; idx = min(count, len(thresholds)-1); return rate >= thresholds[idx] def _relative_rect(self, origin: MDRPoint, rect: MDRRectangle) -> MDRRectangle: - ox, oy = origin; r=rect; return MDRRectangle(lt=(r.lt[0]-ox, r.lt[1]-oy), rt=(r.rt[0]-ox, r.rt[1]-oy), lb=(r.lb[0]-ox, r.lb[1]-oy), rb=(r.rb[0]-ox, r.rb[1]-oy)) + ox, oy = origin + r = rect + return MDRRectangle(lt=(r.lt[0] - ox, r.lt[1] - oy), rt=(r.rt[0] - ox, r.rt[1] - oy), lb=(r.lb[0] - ox, r.lb[1] - oy), rb=(r.rb[0] - ox, r.rb[1] - oy)) def _symmetric_iou(self, r1: MDRRectangle, r2: MDRRectangle) -> float: - try: p1, p2 = Polygon(r1), Polygon(r2); - except: return 0.0 - if not p1.is_valid or not p2.is_valid: return 0.0 - try: inter = p1.intersection(p2); union = p1.union(p2) - except: return 0.0 - if inter.is_empty or inter.area < 1e-6: return 0.0 - union_area = union.area; return inter.area / union_area if union_area > 1e-6 else 1.0 + try: + p1 = Polygon(r1) + p2 = Polygon(r2) + except: + return 0.0 + if not p1.is_valid or not p2.is_valid: + return 0.0 + try: + inter = p1.intersection(p2) + union = p1.union(p2) + except: + return 0.0 + if inter.is_empty or inter.area < 1e-6: + return 0.0 + union_area = union.area + return inter.area / union_area if union_area > 1e-6 else 1.0 # --- MDR Document Iterator --- _MDR_CONTEXT_PAGES = 2 # Look behind/ahead pages for context @@ -1419,11 +2061,14 @@ class MDRDocumentIterator: for page_idx, res in self._run_extraction_on_pages(params): cur_sec = MDRPageSection(page_idx, res.layouts) for i, (_, prev_sec) in enumerate(queue): - offset = len(queue)-i; - if offset <= _MDR_CONTEXT_PAGES: prev_sec.link_to_next(cur_sec, offset) + offset = len(queue) - i + if offset <= _MDR_CONTEXT_PAGES: + prev_sec.link_to_next(cur_sec, offset) queue.append((res, cur_sec)) - if len(queue) > _MDR_CONTEXT_PAGES: yield queue.pop(0) - for res, sec in queue: yield res, sec + if len(queue) > _MDR_CONTEXT_PAGES: + yield queue.pop(0) + for res, sec in queue: + yield res, sec def _run_extraction_on_pages(self, params: MDRProcessingParams) -> Generator[tuple[int, MDRExtractionResult], None, None]: if self._debug_dir: mdr_ensure_directory(self._debug_dir) @@ -1439,30 +2084,45 @@ class MDRDocumentIterator: for i, page_idx in enumerate(scan_idxs): print(f" Iterator: Processing page {page_idx+1}/{doc.page_count} (Scan {i+1}/{total_scan})...") try: - page = doc.load_page(page_idx); img = self._render_page_image(page, 300) + page = doc.load_page(page_idx) + img = self._render_page_image(page, 300) res = self._engine.analyze_image(image=img, adjust_points=False) # Engine analyzes image - if self._debug_dir: self._save_debug_plot(img, page_idx, res, self._debug_dir) - if page_idx in enable_set: yield page_idx, res # Yield result for requested pages - if params.report_progress: params.report_progress(i+1, total_scan) - except Exception as e: print(f" Iterator: Page {page_idx+1} processing error: {e}") + if self._debug_dir: + self._save_debug_plot(img, page_idx, res, self._debug_dir) + if page_idx in enable_set: + yield page_idx, res # Yield result for requested pages + if params.report_progress: + params.report_progress(i + 1, total_scan) + except Exception as e: + print(f" Iterator: Page {page_idx + 1} processing error: {e}") finally: if should_close and doc: doc.close() def _get_page_ranges(self, doc: FitzDocument, idxs: Iterable[int]|None) -> tuple[Sequence[int], Sequence[int]]: - count = doc.page_count; - if idxs is None: all_p = list(range(count)); return all_p, all_p - enable, scan = set(), set() + count = doc.page_count + if idxs is None: + all_p = list(range(count)) + return all_p, all_p + enable = set() + scan = set() for i in idxs: - if 0<=i Image: - mat = FitzMatrix(dpi/72.0, dpi/72.0); pix = page.get_pixmap(matrix=mat, alpha=False) + mat = FitzMatrix(dpi / 72.0, dpi / 72.0) + pix = page.get_pixmap(matrix=mat, alpha=False) return frombytes("RGB", (pix.width, pix.height), pix.samples) def _save_debug_plot(self, img: Image, idx: int, res: MDRExtractionResult, path: str): - try: plot_img = res.adjusted_image.copy() if res.adjusted_image else img.copy(); mdr_plot_layout(plot_img, res.layouts); plot_img.save(os.path.join(path, f"mdr_plot_page_{idx+1}.png")) - except Exception as e: print(f" Iterator: Plot generation error page {idx+1}: {e}") + try: + plot_img = res.adjusted_image.copy() if res.adjusted_image else img.copy() + mdr_plot_layout(plot_img, res.layouts) + plot_img.save(os.path.join(path, f"mdr_plot_page_{idx + 1}.png")) + except Exception as e: + print(f" Iterator: Plot generation error page {idx + 1}: {e}") # --- MagicDataReadiness Main Processor --- class MagicPDFProcessor: @@ -1544,10 +2204,9 @@ class MagicPDFProcessor: if isinstance(layout, MDRPlainLayoutElement): self._add_plain_block(temp_store, layout, result) elif isinstance(layout, MDRTableLayoutElement): temp_store.append((layout, self._create_table_block(layout, result))) elif isinstance(layout, MDRFormulaLayoutElement): temp_store.append((layout, self._create_formula_block(layout, result))) - self._assign_relative_font_sizes(temp_store); + self._assign_relative_font_sizes(temp_store) return [block for _, block in temp_store] - # --- START REFACTORED METHOD --- def _analyze_paragraph_structure(self, blocks: list[MDRStructuredBlock]): """ Calculates indentation and line-end heuristics for MDRTextBlocks @@ -1606,50 +2265,77 @@ class MagicPDFProcessor: print(f"Warn: Error calculating paragraph structure for block: {e}") # Default to False if calculation fails to ensure attributes are set block.has_paragraph_indentation = False - block.last_line_touch_end = False # Removed semicolon from original - - # --- END REFACTORED METHOD --- + block.last_line_touch_end = False def _calculate_text_range(self, blocks_iter: Iterable[MDRStructuredBlock]) -> tuple[float, float, float]: """Calculates average line height and min/max x-coordinates for text.""" - h_sum, count, x1, x2 = 0.0, 0, float('inf'), float('-inf') + h_sum = 0.0 + count = 0 + x1 = float('inf') + x2 = float('-inf') for b in blocks_iter: - if not isinstance(b, MDRTextBlock) or b.kind==MDRTextKind.ABANDON: continue + if not isinstance(b, MDRTextBlock) or b.kind == MDRTextKind.ABANDON: + continue for t in b.texts: - _, h = t.rect.size; - if h>1e-6: h_sum += h; count += 1 # Use small threshold for valid height - tx1, _, tx2, _ = t.rect.wrapper; x1, x2 = min(x1, tx1), max(x2, tx2) - if count==0: return 0.0, 0.0, 0.0 - mean_h = h_sum/count; x1 = 0.0 if x1==float('inf') else x1; x2 = 0.0 if x2==float('-inf') else x2; return mean_h, x1, x2 + _, h = t.rect.size + if h > 1e-6: # Use small threshold for valid height + h_sum += h + count += 1 + tx1, _, tx2, _ = t.rect.wrapper + x1 = min(x1, tx1) + x2 = max(x2, tx2) + if count == 0: + return 0.0, 0.0, 0.0 + mean_h = h_sum / count + x1 = 0.0 if x1 == float('inf') else x1 + x2 = 0.0 if x2 == float('-inf') else x2 + return mean_h, x1, x2 def _add_plain_block(self, store: list[tuple[MDRLayoutElement, MDRStructuredBlock]], layout: MDRPlainLayoutElement, result: MDRExtractionResult): """Creates MDRStructuredBlocks for plain layout types.""" - cls = layout.cls; texts = self._convert_fragments_to_spans(layout.fragments) - if cls==MDRLayoutClass.TITLE: store.append((layout, MDRTextBlock(layout.rect, texts, 0.0, MDRTextKind.TITLE))) - elif cls==MDRLayoutClass.PLAIN_TEXT: store.append((layout, MDRTextBlock(layout.rect, texts, 0.0, MDRTextKind.PLAIN_TEXT))) - elif cls==MDRLayoutClass.ABANDON: store.append((layout, MDRTextBlock(layout.rect, texts, 0.0, MDRTextKind.ABANDON))) - elif cls==MDRLayoutClass.FIGURE: store.append((layout, MDRFigureBlock(layout.rect, [], 0.0, mdr_clip_layout(result, layout)))) - elif cls==MDRLayoutClass.FIGURE_CAPTION: block=self._find_previous_block(store, MDRFigureBlock); block.texts.extend(texts) if block else None - elif cls==MDRLayoutClass.TABLE_CAPTION or cls==MDRLayoutClass.TABLE_FOOTNOTE: block=self._find_previous_block(store, MDRTableBlock); block.texts.extend(texts) if block else None - elif cls==MDRLayoutClass.FORMULA_CAPTION: block=self._find_previous_block(store, MDRFormulaBlock); block.texts.extend(texts) if block else None + cls = layout.cls + texts = self._convert_fragments_to_spans(layout.fragments) + if cls == MDRLayoutClass.TITLE: + store.append((layout, MDRTextBlock(layout.rect, texts, 0.0, MDRTextKind.TITLE))) + elif cls == MDRLayoutClass.PLAIN_TEXT: + store.append((layout, MDRTextBlock(layout.rect, texts, 0.0, MDRTextKind.PLAIN_TEXT))) + elif cls == MDRLayoutClass.ABANDON: + store.append((layout, MDRTextBlock(layout.rect, texts, 0.0, MDRTextKind.ABANDON))) + elif cls == MDRLayoutClass.FIGURE: + store.append((layout, MDRFigureBlock(layout.rect, [], 0.0, mdr_clip_layout(result, layout)))) + elif cls == MDRLayoutClass.FIGURE_CAPTION: + block = self._find_previous_block(store, MDRFigureBlock) + if block: block.texts.extend(texts) + elif cls == MDRLayoutClass.TABLE_CAPTION or cls == MDRLayoutClass.TABLE_FOOTNOTE: + block = self._find_previous_block(store, MDRTableBlock) + if block: block.texts.extend(texts) + elif cls == MDRLayoutClass.FORMULA_CAPTION: + block = self._find_previous_block(store, MDRFormulaBlock) + if block: block.texts.extend(texts) def _find_previous_block(self, store: list[tuple[MDRLayoutElement, MDRStructuredBlock]], block_type: type) -> MDRStructuredBlock | None: """Finds the most recent block of a specific type.""" - for i in range(len(store)-1, -1, -1): - _, block = store[i]; - if isinstance(block, block_type): return block + for i in range(len(store) - 1, -1, -1): + _, block = store[i] + if isinstance(block, block_type): + return block return None def _create_table_block(self, layout: MDRTableLayoutElement, result: MDRExtractionResult) -> MDRTableBlock: """Converts MDRTableLayoutElement to MDRTableBlock.""" - fmt, content = MDRTableFormat.UNRECOGNIZABLE, "" + fmt = MDRTableFormat.UNRECOGNIZABLE + content = "" if layout.parsed: - p_content, p_fmt = layout.parsed; can_use = not (p_fmt==MDRTableLayoutParsedFormat.LATEX and mdr_contains_cjka("".join(f.text for f in layout.fragments))) + p_content, p_fmt = layout.parsed + can_use = not (p_fmt == MDRTableLayoutParsedFormat.LATEX and mdr_contains_cjka("".join(f.text for f in layout.fragments))) if can_use: content = p_content - if p_fmt==MDRTableLayoutParsedFormat.LATEX: fmt=MDRTableFormat.LATEX - elif p_fmt==MDRTableLayoutParsedFormat.MARKDOWN: fmt=MDRTableFormat.MARKDOWN - elif p_fmt==MDRTableLayoutParsedFormat.HTML: fmt=MDRTableFormat.HTML + if p_fmt == MDRTableLayoutParsedFormat.LATEX: + fmt = MDRTableFormat.LATEX + elif p_fmt == MDRTableLayoutParsedFormat.MARKDOWN: + fmt = MDRTableFormat.MARKDOWN + elif p_fmt == MDRTableLayoutParsedFormat.HTML: + fmt = MDRTableFormat.HTML return MDRTableBlock(layout.rect, [], 0.0, fmt, content, mdr_clip_layout(result, layout)) def _create_formula_block(self, layout: MDRFormulaLayoutElement, result: MDRExtractionResult) -> MDRFormulaBlock: @@ -1661,13 +2347,16 @@ class MagicPDFProcessor: """Calculates and assigns relative font size (0-1) to blocks.""" sizes = [] for l, _ in store: - heights = [f.rect.size[1] for f in l.fragments if f.rect.size[1]>1e-6] # Use small threshold - avg_h = sum(heights)/len(heights) if heights else 0.0 + heights = [f.rect.size[1] for f in l.fragments if f.rect.size[1] > 1e-6] # Use small threshold + avg_h = sum(heights) / len(heights) if heights else 0.0 sizes.append(avg_h) - valid = [s for s in sizes if s>1e-6]; min_s, max_s = (min(valid), max(valid)) if valid else (0.0, 0.0) + valid = [s for s in sizes if s > 1e-6] + min_s, max_s = (min(valid), max(valid)) if valid else (0.0, 0.0) rng = max_s - min_s - if rng < 1e-6: [setattr(b, 'font_size', 0.0) for _, b in store] - else: [setattr(b, 'font_size', (s-min_s)/rng if s>1e-6 else 0.0) for s, (_, b) in zip(sizes, store)] + if rng < 1e-6: + [setattr(b, 'font_size', 0.0) for _, b in store] + else: + [setattr(b, 'font_size', (s - min_s) / rng if s > 1e-6 else 0.0) for s, (_, b) in zip(sizes, store)] def _convert_fragments_to_spans(self, frags: list[MDROcrFragment]) -> list[MDRTextSpan]: """Converts MDROcrFragment list to MDRTextSpan list.""" @@ -1722,7 +2411,8 @@ if __name__ == '__main__': print("-" * 60) mdr_ensure_directory(MDR_MODEL_DIRECTORY) - if MDR_DEBUG_DIRECTORY: mdr_ensure_directory(MDR_DEBUG_DIRECTORY) + if MDR_DEBUG_DIRECTORY: + mdr_ensure_directory(MDR_DEBUG_DIRECTORY) if not Path(MDR_INPUT_PDF).is_file(): print(f"ERROR: Input PDF not found at '{MDR_INPUT_PDF}'. Please place a PDF file there or update the path.") exit(1) @@ -1774,9 +2464,9 @@ if __name__ == '__main__': print(f" Extracted {len(page_blocks)} blocks:") for block_idx, block in enumerate(page_blocks): all_blocks_count += 1 - info = f" - Block {block_idx+1}: {type(block).__name__}" + info = f" - Block {block_idx + 1}: {type(block).__name__}" if isinstance(block, MDRTextBlock): - preview = block.texts[0].content[:70].replace('\n',' ') + "..." if block.texts else "[EMPTY]" + preview = block.texts[0].content[:70].replace('\n', ' ') + "..." if block.texts else "[EMPTY]" info += f" (Kind: {block.kind.name}, FontSz: {block.font_size:.2f}, Indent: {block.has_paragraph_indentation}, EndTouch: {block.last_line_touch_end}) | Text: '{preview}'" # Added indent/endtouch elif isinstance(block, MDRTableBlock): info += f" (Format: {block.format.name}, HasContent: {bool(block.content)}, FontSz: {block.font_size:.2f})"