rodrigomasini committed on
Commit
a00ffbd
·
verified ·
1 Parent(s): aa1e17c

Update mdr_pdf_parser.py

Browse files
Files changed (1) hide show
  1. mdr_pdf_parser.py +247 -51
mdr_pdf_parser.py CHANGED
@@ -38,15 +38,16 @@ from PIL.ImageOps import expand as pil_expand
38
  from PIL import ImageDraw
39
  from PIL.ImageFont import load_default, FreeTypeFont
40
  from shapely.geometry import Polygon
41
- import pyclipper # Used in onnxocr/db_postprocess
42
  from unicodedata import category
43
  from alphabet_detector import AlphabetDetector
44
- from munch import Munch # Required by latex.py (pix2tex wrapper)
45
- from transformers import LayoutLMv3ForTokenClassification # Required by layout_order/layoutreader
46
- import onnxruntime # Required by onnxocr components
 
 
47
 
48
- # --- Potentially Installable External Dependencies ---
49
- # Ensure these are installed or available in your environment
50
  try:
51
  from doclayout_yolo import YOLOv10
52
  except ImportError:
@@ -58,14 +59,14 @@ except ImportError:
58
  print("Warning: Could not import LatexOCR from pix2tex.cli. LaTeX extraction will fail.")
59
  LatexOCR = None
60
  try:
61
- # Dynamic import within MDRTableParser, assuming struct_eqtable.py exists or is installable
62
  pass # from struct_eqtable import build_model
63
  except ImportError:
64
  print("Warning: Could not import build_model from struct_eqtable. Table parsing might fail.")
65
 
66
  # --- MagicDataReadiness Core Components ---
67
 
68
- # --- MDR Utilities (downloader.py) ---
 
69
  def mdr_download_model(url: str, file_path: Path):
70
  """Downloads a model file from a URL to a local path."""
71
  try:
@@ -85,7 +86,7 @@ def mdr_download_model(url: str, file_path: Path):
85
  if file_path.exists(): os.remove(file_path)
86
  raise e
87
 
88
- # --- MDR Utilities (utils.py from doc_page_extractor internals) ---
89
  def mdr_ensure_directory(path: str) -> str:
90
  """Ensures a directory exists, creating it if necessary."""
91
  path = os.path.abspath(path)
@@ -146,79 +147,101 @@ def mdr_intersection_area(rect1: MDRRectangle, rect2: MDRRectangle) -> float:
146
  return p1.intersection(p2).area
147
  except: return 0.0
148
 
149
- # --- MDR Data Structures (ExtractedResult.py) ---
150
  @dataclass
151
  class MDROcrFragment:
152
  """Represents a fragment of text identified by OCR."""
153
  order: int; text: str; rank: float; rect: MDRRectangle
 
154
  class MDRLayoutClass(Enum):
155
  """Enumeration of different layout types identified."""
156
  TITLE=0; PLAIN_TEXT=1; ABANDON=2; FIGURE=3; FIGURE_CAPTION=4; TABLE=5; TABLE_CAPTION=6; TABLE_FOOTNOTE=7; ISOLATE_FORMULA=8; FORMULA_CAPTION=9
 
157
  class MDRTableLayoutParsedFormat(Enum):
158
  """Enumeration for formats of parsed table content."""
159
  LATEX=auto(); MARKDOWN=auto(); HTML=auto()
 
160
  @dataclass
161
  class MDRBaseLayoutElement:
162
  """Base class for layout elements found on a page."""
163
  rect: MDRRectangle; fragments: list[MDROcrFragment]
 
164
  @dataclass
165
  class MDRPlainLayoutElement(MDRBaseLayoutElement):
166
  """Layout element for plain text, titles, captions, figures, etc."""
167
  cls: Literal[MDRLayoutClass.TITLE, MDRLayoutClass.PLAIN_TEXT, MDRLayoutClass.ABANDON, MDRLayoutClass.FIGURE, MDRLayoutClass.FIGURE_CAPTION, MDRLayoutClass.TABLE_CAPTION, MDRLayoutClass.TABLE_FOOTNOTE, MDRLayoutClass.FORMULA_CAPTION]
 
168
  @dataclass
169
  class MDRTableLayoutElement(MDRBaseLayoutElement):
170
  """Layout element specifically for tables."""
171
  parsed: tuple[str, MDRTableLayoutParsedFormat] | None; cls: Literal[MDRLayoutClass.TABLE] = MDRLayoutClass.TABLE
 
172
  @dataclass
173
  class MDRFormulaLayoutElement(MDRBaseLayoutElement):
174
  """Layout element specifically for formulas."""
175
  latex: str | None; cls: Literal[MDRLayoutClass.ISOLATE_FORMULA] = MDRLayoutClass.ISOLATE_FORMULA
 
176
  MDRLayoutElement = MDRPlainLayoutElement | MDRTableLayoutElement | MDRFormulaLayoutElement # Type alias
 
177
  @dataclass
178
  class MDRExtractionResult:
179
  """Holds the complete result of extracting from a single page image."""
180
  rotation: float; layouts: list[MDRLayoutElement]; extracted_image: Image; adjusted_image: Image | None
181
 
182
- # --- MDR Data Structures (types.py - Original script 4) ---
 
183
  MDRProgressReportCallback: TypeAlias = Callable[[int, int], None]
 
184
  class MDROcrLevel(Enum): Once=auto(); OncePerLayout=auto()
 
185
  class MDRExtractedTableFormat(Enum): LATEX=auto(); MARKDOWN=auto(); HTML=auto(); DISABLE=auto()
 
186
  class MDRTextKind(Enum): TITLE=0; PLAIN_TEXT=1; ABANDON=2
 
187
  @dataclass
188
  class MDRTextSpan:
189
  """Represents a span of text content within a block."""
190
  content: str; rank: float; rect: MDRRectangle
 
191
  @dataclass
192
  class MDRBasicBlock:
193
  """Base class for structured blocks extracted from the document."""
194
  rect: MDRRectangle; texts: list[MDRTextSpan]; font_size: float # Relative font size (0-1)
 
195
  @dataclass
196
  class MDRTextBlock(MDRBasicBlock):
197
  """A structured block containing text content."""
198
  kind: MDRTextKind; has_paragraph_indentation: bool = False; last_line_touch_end: bool = False
 
199
  class MDRTableFormat(Enum): LATEX=auto(); MARKDOWN=auto(); HTML=auto(); UNRECOGNIZABLE=auto()
 
200
  @dataclass
201
  class MDRTableBlock(MDRBasicBlock):
202
  """A structured block representing a table."""
203
  content: str; format: MDRTableFormat; image: Image # Image clip of the table
 
204
  @dataclass
205
  class MDRFormulaBlock(MDRBasicBlock):
206
  """A structured block representing a formula."""
207
  content: str | None; image: Image # Image clip of the formula
 
208
  @dataclass
209
  class MDRFigureBlock(MDRBasicBlock):
210
  """A structured block representing a figure/image."""
211
  image: Image # Image clip of the figure
 
212
  MDRAssetBlock = MDRTableBlock | MDRFormulaBlock | MDRFigureBlock # Type alias
 
213
  MDRStructuredBlock = MDRTextBlock | MDRAssetBlock # Type alias
214
 
215
- # --- MDR Utilities (utils.py - Original script 3) ---
216
  def mdr_similarity_ratio(v1: float, v2: float) -> float:
217
  """Calculates the ratio of the smaller value to the larger value (0-1)."""
218
  if v1==0 and v2==0: return 1.0;
219
  if v1<0 or v2<0: return 0.0;
220
  v1, v2 = (v2, v1) if v1 > v2 else (v1, v2);
221
  return 1.0 if v2==0 else v1/v2
 
222
  def mdr_intersection_bounds_size(r1: MDRRectangle, r2: MDRRectangle) -> tuple[float, float]:
223
  """Calculates width/height of the intersection bounding box."""
224
  try:
@@ -228,18 +251,23 @@ def mdr_intersection_bounds_size(r1: MDRRectangle, r2: MDRRectangle) -> tuple[fl
228
  if inter.is_empty: return 0.0, 0.0;
229
  minx, miny, maxx, maxy = inter.bounds; return maxx-minx, maxy-miny
230
  except: return 0.0, 0.0
 
231
  _MDR_CJKA_PATTERN = re.compile(r"[\u4e00-\u9fff\u3040-\u309f\u30a0-\u30ff\uac00-\ud7a3\u0600-\u06ff]")
 
232
  def mdr_contains_cjka(text: str):
233
  """Checks if text contains Chinese, Japanese, Korean, or Arabic chars."""
234
  return bool(_MDR_CJKA_PATTERN.search(text)) if text else False
235
 
236
- # --- MDR Text Processing (text_matcher.py - Original script 7) ---
237
  class _MDR_TokenPhase(Enum): Init=0; Letter=1; Character=2; Number=3; Space=4
 
238
  _mdr_alphabet_detector = AlphabetDetector()
 
239
  def _mdr_is_letter(char: str):
240
  if not category(char).startswith("L"): return False
241
  try: return _mdr_alphabet_detector.is_latin(char) or _mdr_alphabet_detector.is_cyrillic(char) or _mdr_alphabet_detector.is_greek(char) or _mdr_alphabet_detector.is_hebrew(char)
242
  except: return False
 
243
  def mdr_split_into_words(text: str):
244
  """Splits text into words, numbers, and individual non-alphanumeric chars."""
245
  if not text: return;
@@ -259,6 +287,7 @@ def mdr_split_into_words(text: str):
259
  if is_s: phase=_MDR_TokenPhase.Space
260
  else: yield char; phase=_MDR_TokenPhase.Character
261
  if phase in (_MDR_TokenPhase.Letter, _MDR_TokenPhase.Number): w=buf.getvalue(); yield w if w else None
 
262
  def mdr_check_text_similarity(t1: str, t2: str) -> tuple[float, int]:
263
  """Calculates word-based similarity between two texts."""
264
  w1, w2 = list(mdr_split_into_words(t1)), list(mdr_split_into_words(t2)); l1, l2 = len(w1), len(w2)
@@ -271,21 +300,25 @@ def mdr_check_text_similarity(t1: str, t2: str) -> tuple[float, int]:
271
  if not taken[i] and word1==word2: taken[i]=True; matches+=1; break
272
  mismatches = l2 - matches; return 1.0 - (mismatches/l2), l2
273
 
274
- # --- MDR Geometry Processing (rotation.py) ---
275
  class MDRRotationAdjuster:
276
  """Adjusts point coordinates based on image rotation."""
 
277
  def __init__(self, origin_size: tuple[int, int], new_size: tuple[int, int], rotation: float, to_origin_coordinate: bool):
278
  fs, ts = (new_size, origin_size) if to_origin_coordinate else (origin_size, new_size)
279
  self._rot = rotation if to_origin_coordinate else -rotation
280
  self._c_off = (fs[0]/2.0, fs[1]/2.0); self._n_off = (ts[0]/2.0, ts[1]/2.0)
 
281
  def adjust(self, point: MDRPoint) -> MDRPoint:
282
  x, y = point[0]-self._c_off[0], point[1]-self._c_off[1]
283
  if x!=0 or y!=0: cos_r, sin_r = cos(self._rot), sin(self._rot); x, y = x*cos_r-y*sin_r, x*sin_r+y*cos_r
284
  return x+self._n_off[0], y+self._n_off[1]
 
285
  def mdr_normalize_vertical_rotation(rot: float) -> float:
286
  while rot >= pi: rot -= pi;
287
  while rot < 0: rot += pi;
288
  return rot
 
289
  def _mdr_get_rectangle_angles(rect: MDRRectangle) -> tuple[list[float], list[float]] | None:
290
  h_angs, v_angs = [], []
291
  for i, (p1, p2) in enumerate(rect.segments):
@@ -297,11 +330,15 @@ def _mdr_get_rectangle_angles(rect: MDRRectangle) -> tuple[list[float], list[flo
297
  else: v_angs.append(ang)
298
  if not h_angs or not v_angs: return None
299
  return h_angs, v_angs
 
300
  def _mdr_normalize_horizontal_angles(rots: list[float]) -> list[float]: return rots
 
301
  def _mdr_find_median(data: list[float]) -> float:
302
  if not data: return 0.0; s_data = sorted(data); n = len(s_data);
303
  return s_data[n//2] if n%2==1 else (s_data[n//2-1]+s_data[n//2])/2.0
 
304
  def _mdr_find_mean(data: list[float]) -> float: return sum(data)/len(data) if data else 0.0
 
305
  def mdr_calculate_image_rotation(frags: list[MDROcrFragment]) -> float:
306
  all_h, all_v = [], [];
307
  for f in frags:
@@ -314,6 +351,7 @@ def mdr_calculate_image_rotation(frags: list[MDROcrFragment]) -> float:
314
  while rot_est >= pi/2: rot_est -= pi;
315
  while rot_est < -pi/2: rot_est += pi;
316
  return rot_est
 
317
  def mdr_calculate_rectangle_rotation(rect: MDRRectangle) -> tuple[float, float]:
318
  res = _mdr_get_rectangle_angles(rect);
319
  if res is None: return 0.0, pi/2.0;
@@ -321,9 +359,10 @@ def mdr_calculate_rectangle_rotation(rect: MDRRectangle) -> tuple[float, float]:
321
  h_rots = _mdr_normalize_horizontal_angles(h_rots); v_rots = [mdr_normalize_vertical_rotation(a) for a in v_rots]
322
  return _mdr_find_mean(h_rots), _mdr_find_mean(v_rots)
323
 
324
- # --- MDR ONNX OCR Internals (predict_base.py) ---
325
  class _MDR_PredictBase:
326
  """Base class for ONNX model prediction components."""
 
327
  def get_onnx_session(self, model_path: str, use_gpu: bool):
328
  try:
329
  sess_opts = onnxruntime.SessionOptions(); sess_opts.log_severity_level = 3
@@ -336,24 +375,32 @@ class _MDR_PredictBase:
336
  if use_gpu and 'CUDAExecutionProvider' not in onnxruntime.get_available_providers():
337
  print(" CUDAExecutionProvider not available. Check ONNXRuntime-GPU installation and CUDA setup.")
338
  raise e
 
339
  def get_output_name(self, sess: onnxruntime.InferenceSession) -> List[str]: return [n.name for n in sess.get_outputs()]
 
340
  def get_input_name(self, sess: onnxruntime.InferenceSession) -> List[str]: return [n.name for n in sess.get_inputs()]
 
341
  def get_input_feed(self, names: List[str], img_np: np.ndarray) -> Dict[str, np.ndarray]: return {name: img_np for name in names}
342
 
343
- # --- MDR ONNX OCR Internals (operators.py) ---
344
  class _MDR_NormalizeImage:
 
345
  def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
346
  self.scale = np.float32(eval(scale) if isinstance(scale, str) else (scale if scale is not None else 1.0 / 255.0))
347
  mean = mean if mean is not None else [0.485, 0.456, 0.406]; std = std if std is not None else [0.229, 0.224, 0.225]
348
  shape = (3, 1, 1) if order == 'chw' else (1, 1, 3); self.mean = np.array(mean).reshape(shape).astype('float32'); self.std = np.array(std).reshape(shape).astype('float32')
 
349
  def __call__(self, data): img = data['image']; img = np.array(img) if isinstance(img, Image) else img; data['image'] = (img.astype('float32') * self.scale - self.mean) / self.std; return data
 
350
  class _MDR_DetResizeForTest:
 
351
  def __init__(self, **kwargs):
352
  self.resize_type = 0; self.keep_ratio = False
353
  if 'image_shape' in kwargs: self.image_shape = kwargs['image_shape']; self.resize_type = 1; self.keep_ratio = kwargs.get('keep_ratio', False)
354
  elif 'limit_side_len' in kwargs: self.limit_side_len = kwargs['limit_side_len']; self.limit_type = kwargs.get('limit_type', 'min')
355
  elif 'resize_long' in kwargs: self.resize_type = 2; self.resize_long = kwargs.get('resize_long', 960)
356
  else: self.limit_side_len = 736; self.limit_type = 'min'
 
357
  def __call__(self, data):
358
  img = data['image']; src_h, src_w, _ = img.shape
359
  if src_h + src_w < 64: img = self._pad(img)
@@ -362,21 +409,30 @@ class _MDR_DetResizeForTest:
362
  else: img, ratios = self._resize1(img)
363
  if img is None: return None
364
  data['image'] = img; data['shape'] = np.array([src_h, src_w, ratios[0], ratios[1]]); return data
 
365
  def _pad(self, im, v=0): h,w,c=im.shape; p=np.zeros((max(32,h),max(32,w),c),np.uint8)+v; p[:h,:w,:]=im; return p
 
366
  def _resize1(self, img): rh,rw=self.image_shape; oh,ow=img.shape[:2]; if self.keep_ratio: rw=ow*rh/oh; N=math.ceil(rw/32); rw=N*32; r_h,r_w=float(rh)/oh,float(rw)/ow; img=cv2.resize(img,(int(rw),int(rh))); return img,[r_h,r_w]
 
367
  def _resize0(self, img): lsl=self.limit_side_len; h,w,_=img.shape; r=1.0; if self.limit_type=='max': r=float(lsl)/max(h,w) if max(h,w)>lsl else 1.0; elif self.limit_type=='min': r=float(lsl)/min(h,w) if min(h,w)<lsl else 1.0; elif self.limit_type=='resize_long': r=float(lsl)/max(h,w); else: raise Exception('Unsupported'); rh,rw=int(h*r),int(w*r); rh=max(int(round(rh/32)*32),32); rw=max(int(round(rw/32)*32),32); if int(rw)<=0 or int(rh)<=0: return None,(None,None); img=cv2.resize(img,(int(rw),int(rh))); r_h,r_w=rh/float(h),rw/float(w); return img,[r_h,r_w]
 
368
  def _resize2(self, img): h,w,_=img.shape; rl=self.resize_long; r=float(rl)/max(h,w); rh,rw=int(h*r),int(w*r); ms=128; rh=(rh+ms-1)//ms*ms; rw=(rw+ms-1)//ms*ms; img=cv2.resize(img,(int(rw),int(rh))); r_h,r_w=rh/float(h),rw/float(w); return img,[r_h,r_w]
 
369
  class _MDR_ToCHWImage:
 
370
  def __call__(self, data): img=data['image']; img=np.array(img) if isinstance(img,Image) else img; data['image']=img.transpose((2,0,1)); return data
 
371
  class _MDR_KeepKeys:
 
372
  def __init__(self, keep_keys, **kwargs): self.keep_keys=keep_keys
 
373
  def __call__(self, data): return [data[key] for key in self.keep_keys]
374
 
375
- # --- MDR ONNX OCR Internals (imaug.py) ---
376
  def mdr_ocr_transform(data, ops=None):
377
  ops = ops if ops is not None else [];
378
  for op in ops: data = op(data); if data is None: return None;
379
  return data
 
380
  def mdr_ocr_create_operators(op_param_list, global_config=None):
381
  ops = []
382
  for operator in op_param_list:
@@ -388,11 +444,12 @@ def mdr_ocr_create_operators(op_param_list, global_config=None):
388
  else: raise ValueError(f"Operator class '{op_class_name}' not found.")
389
  return ops
390
 
391
- # --- MDR ONNX OCR Internals (db_postprocess.py) ---
392
  class _MDR_DBPostProcess:
 
393
  def __init__(self, thresh=0.3, box_thresh=0.7, max_candidates=1000, unclip_ratio=1.5, use_dilation=False, score_mode="fast", box_type='quad', **kwargs):
394
  self.thresh, self.box_thresh, self.max_cand = thresh, box_thresh, max_candidates; self.unclip_r, self.min_sz, self.score_m, self.box_t = unclip_ratio, 3, score_mode, box_type
395
  assert score_mode in ["slow", "fast"]; self.dila_k = np.array([[1,1],[1,1]], dtype=np.uint8) if use_dilation else None
 
396
  def _polygons_from_bitmap(self, pred, bmp, dw, dh):
397
  h, w = bmp.shape; boxes, scores = [], []
398
  contours, _ = cv2.findContours((bmp*255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
@@ -409,6 +466,7 @@ class _MDR_DBPostProcess:
409
  box = np.array(box); box[:,0]=np.clip(np.round(box[:,0]/w*dw),0,dw); box[:,1]=np.clip(np.round(box[:,1]/h*dh),0,dh)
410
  boxes.append(box.tolist()); scores.append(score)
411
  return boxes, scores
 
412
  def _boxes_from_bitmap(self, pred, bmp, dw, dh):
413
  h, w = bmp.shape; contours, _ = cv2.findContours((bmp*255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
414
  num_contours = min(len(contours), self.max_cand); boxes, scores = [], []
@@ -424,25 +482,30 @@ class _MDR_DBPostProcess:
424
  box = np.array(box); box[:,0]=np.clip(np.round(box[:,0]/w*dw),0,dw); box[:,1]=np.clip(np.round(box[:,1]/h*dh),0,dh)
425
  boxes.append(box.astype("int32")); scores.append(score)
426
  return np.array(boxes, dtype="int32"), scores
 
427
  def _unclip(self, box, ratio):
428
  poly = Polygon(box); dist = poly.area*ratio/poly.length; offset = pyclipper.PyclipperOffset(); offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
429
  expanded = offset.Execute(dist);
430
  if not expanded: raise ValueError("Unclip failed"); return np.array(expanded[0])
 
431
  def _get_mini_boxes(self, contour):
432
  bb = cv2.minAreaRect(contour); pts = sorted(list(cv2.boxPoints(bb)), key=lambda x:x[0])
433
  i1,i4 = (0,1) if pts[1][1]>pts[0][1] else (1,0); i2,i3 = (2,3) if pts[3][1]>pts[2][1] else (3,2)
434
  box = [pts[i1], pts[i2], pts[i3], pts[i4]]; return box, min(bb[1])
 
435
  def _box_score_fast(self, bmp, box):
436
  h,w = bmp.shape[:2]; xmin=np.clip(np.floor(box[:,0].min()).astype("int32"),0,w-1); xmax=np.clip(np.ceil(box[:,0].max()).astype("int32"),0,w-1)
437
  ymin=np.clip(np.floor(box[:,1].min()).astype("int32"),0,h-1); ymax=np.clip(np.ceil(box[:,1].max()).astype("int32"),0,h-1)
438
  mask = np.zeros((ymax-ymin+1, xmax-xmin+1), dtype=np.uint8); box[:,0]-=xmin; box[:,1]-=ymin
439
  cv2.fillPoly(mask, box.reshape(1,-1,2).astype("int32"), 1);
440
  return cv2.mean(bmp[ymin:ymax+1, xmin:xmax+1], mask)[0] if np.sum(mask)>0 else 0.0
 
441
  def _box_score_slow(self, bmp, contour): # Not used if fast
442
  h,w = bmp.shape[:2]; contour = np.reshape(contour.copy(),(-1,2)); xmin=np.clip(np.min(contour[:,0]),0,w-1); xmax=np.clip(np.max(contour[:,0]),0,w-1)
443
  ymin=np.clip(np.min(contour[:,1]),0,h-1); ymax=np.clip(np.max(contour[:,1]),0,h-1); mask=np.zeros((ymax-ymin+1,xmax-xmin+1),dtype=np.uint8)
444
  contour[:,0]-=xmin; contour[:,1]-=ymin; cv2.fillPoly(mask, contour.reshape(1,-1,2).astype("int32"), 1);
445
  return cv2.mean(bmp[ymin:ymax+1, xmin:xmax+1], mask)[0] if np.sum(mask)>0 else 0.0
 
446
  def __call__(self, outs_dict, shape_list):
447
  pred = outs_dict['maps'][:,0,:,:]; seg = pred > self.thresh; boxes_batch = []
448
  for batch_idx in range(pred.shape[0]):
@@ -453,8 +516,8 @@ class _MDR_DBPostProcess:
453
  boxes_batch.append({'points': boxes})
454
  return boxes_batch
455
 
456
- # --- MDR ONNX OCR Internals (predict_det.py) ---
457
  class _MDR_TextDetector(_MDR_PredictBase):
 
458
  def __init__(self, args):
459
  super().__init__(); self.args = args
460
  pre_ops = [{'DetResizeForTest': {'limit_side_len': args.det_limit_side_len, 'limit_type': args.det_limit_type}}, {'NormalizeImage': {'std': [0.229,0.224,0.225], 'mean': [0.485,0.456,0.406], 'scale': '1./255.', 'order': 'hwc'}}, {'ToCHWImage': None}, {'KeepKeys': {'keep_keys': ['image', 'shape']}}]
@@ -463,10 +526,15 @@ class _MDR_TextDetector(_MDR_PredictBase):
463
  self.post_op = _MDR_DBPostProcess(**post_params)
464
  self.sess = self.get_onnx_session(args.det_model_dir, args.use_gpu)
465
  self.input_name = self.get_input_name(self.sess); self.output_name = self.get_output_name(self.sess)
 
466
  def _order_pts(self, pts): r=np.zeros((4,2),dtype="float32"); s=pts.sum(axis=1); r[0]=pts[np.argmin(s)]; r[2]=pts[np.argmax(s)]; tmp=np.delete(pts,(np.argmin(s),np.argmax(s)),axis=0); d=np.diff(np.array(tmp),axis=1); r[1]=tmp[np.argmin(d)]; r[3]=tmp[np.argmax(d)]; return r
 
467
  def _clip_pts(self, pts, h, w): pts[:,0]=np.clip(pts[:,0],0,w-1); pts[:,1]=np.clip(pts[:,1],0,h-1); return pts
 
468
  def _filter_quad(self, boxes, shape): h,w=shape[0:2]; new_boxes=[]; for box in boxes: box=np.array(box) if isinstance(box,list) else box; box=self._order_pts(box); box=self._clip_pts(box,h,w); rw=int(np.linalg.norm(box[0]-box[1])); rh=int(np.linalg.norm(box[0]-box[3])); if rw<=3 or rh<=3: continue; new_boxes.append(box); return np.array(new_boxes)
 
469
  def _filter_poly(self, boxes, shape): h,w=shape[0:2]; new_boxes=[]; for box in boxes: box=np.array(box) if isinstance(box,list) else box; box=self._clip_pts(box,h,w); if Polygon(box).area<10: continue; new_boxes.append(box); return np.array(new_boxes)
 
470
  def __call__(self, img):
471
  ori_im = img.copy(); data = {"image": img}; data = mdr_ocr_transform(data, self.pre_op)
472
  if data is None: return None; img, shape_list = data;
@@ -475,25 +543,28 @@ class _MDR_TextDetector(_MDR_PredictBase):
475
  preds = {"maps": outputs[0]}; post_res = self.post_op(preds, shape_list); boxes = post_res[0]['points']
476
  return self._filter_poly(boxes, ori_im.shape) if self.args.det_box_type=='poly' else self._filter_quad(boxes, ori_im.shape)
477
 
478
- # --- MDR ONNX OCR Internals (cls_postprocess.py) ---
479
  class _MDR_ClsPostProcess:
 
480
  def __init__(self, label_list=None, **kwargs): self.labels = label_list if label_list else {0:'0', 1:'180'}
 
481
  def __call__(self, preds, label=None, *args, **kwargs):
482
  preds = np.array(preds) if not isinstance(preds, np.ndarray) else preds; idxs = preds.argmax(axis=1)
483
  return [(self.labels[idx], float(preds[i,idx])) for i,idx in enumerate(idxs)]
484
 
485
- # --- MDR ONNX OCR Internals (predict_cls.py) ---
486
  class _MDR_TextClassifier(_MDR_PredictBase):
 
487
  def __init__(self, args):
488
  super().__init__(); self.shape = tuple(map(int, args.cls_image_shape.split(','))) if isinstance(args.cls_image_shape, str) else args.cls_image_shape
489
  self.batch_num = args.cls_batch_num; self.thresh = args.cls_thresh; self.post_op = _MDR_ClsPostProcess(label_list=args.label_list)
490
  self.sess = self.get_onnx_session(args.cls_model_dir, args.use_gpu); self.input_name = self.get_input_name(self.sess); self.output_name = self.get_output_name(self.sess)
 
491
  def _resize_norm(self, img):
492
  imgC,imgH,imgW = self.shape; h,w = img.shape[:2]; r=w/float(h) if h>0 else 0; rw=int(math.ceil(imgH*r)); rw=min(rw,imgW)
493
  resized = cv2.resize(img,(rw,imgH)); resized = resized.astype("float32")
494
  if imgC==1: resized = resized/255.0; resized = resized[np.newaxis,:]
495
  else: resized = resized.transpose((2,0,1))/255.0
496
  resized -= 0.5; resized /= 0.5; padding = np.zeros((imgC,imgH,imgW),dtype=np.float32); padding[:,:,0:rw]=resized; return padding
 
497
  def __call__(self, img_list):
498
  if not img_list: return img_list, []; img_list_cp = copy.deepcopy(img_list); num = len(img_list_cp)
499
  ratios = [img.shape[1]/float(img.shape[0]) if img.shape[0]>0 else 0 for img in img_list_cp]; indices = np.argsort(np.array(ratios))
@@ -509,8 +580,8 @@ class _MDR_TextClassifier(_MDR_PredictBase):
509
  if "180" in label and score > self.thresh: img_list[orig_idx] = cv2.rotate(img_list[orig_idx], cv2.ROTATE_180)
510
  return img_list, results
511
 
512
- # --- MDR ONNX OCR Internals (rec_postprocess.py) ---
513
  class _MDR_BaseRecLabelDecode:
 
514
  def __init__(self, char_path=None, use_space=False):
515
  self.beg, self.end, self.rev = "sos", "eos", False; self.chars = []
516
  if char_path is None: self.chars = list("0123456789abcdefghijklmnopqrstuvwxyz")
@@ -521,9 +592,13 @@ class _MDR_BaseRecLabelDecode:
521
  if any("\u0600"<=c<="\u06FF" for c in self.chars): self.rev=True
522
  except FileNotFoundError: print(f"Warn: Dict not found {char_path}"); self.chars=list("0123456789abcdefghijklmnopqrstuvwxyz"); if use_space: self.chars.append(" ")
523
  d_char = self.add_special_char(list(self.chars)); self.dict={c:i for i,c in enumerate(d_char)}; self.character=d_char
 
524
  def add_special_char(self, chars): return chars
 
525
  def get_ignored_tokens(self): return []
 
526
  def _reverse(self, pred): res=[]; cur=""; for c in pred: if not re.search("[a-zA-Z0-9 :*./%+-]",c): res.extend([cur,c] if cur!="" else [c]); cur="" else: cur+=c; if cur!="": res.append(cur); return "".join(res[::-1])
 
527
  def decode(self, idxs, probs=None, remove_dup=False):
528
  res=[]; ignored=self.get_ignored_tokens(); bs=len(idxs)
529
  for b_idx in range(bs):
@@ -537,6 +612,7 @@ class _MDR_BaseRecLabelDecode:
537
  if self.rev: txt=self._reverse(txt)
538
  res.append((txt, float(np.mean(conf_l))))
539
  return res
 
540
  class _MDR_CTCLabelDecode(_MDR_BaseRecLabelDecode):
541
  def __init__(self, char_path=None, use_space=False, **kwargs): super().__init__(char_path, use_space)
542
  def add_special_char(self, chars): return ["blank"]+chars
@@ -545,13 +621,14 @@ class _MDR_CTCLabelDecode(_MDR_BaseRecLabelDecode):
545
  preds = preds[-1] if isinstance(preds,(tuple,list)) else preds; preds = np.array(preds) if not isinstance(preds,np.ndarray) else preds
546
  idxs=preds.argmax(axis=2); probs=preds.max(axis=2); txt=self.decode(idxs, probs, remove_dup=True); return txt
547
 
548
- # --- MDR ONNX OCR Internals (predict_rec.py) ---
549
  class _MDR_TextRecognizer(_MDR_PredictBase):
 
550
  def __init__(self, args):
551
  super().__init__(); shape_str=getattr(args,'rec_image_shape',"3,48,320"); self.shape=tuple(map(int,shape_str.split(',')))
552
  self.batch_num=getattr(args,'rec_batch_num',6); self.algo=getattr(args,'rec_algorithm','SVTR_LCNet')
553
  self.post_op=_MDR_CTCLabelDecode(char_path=args.rec_char_dict_path, use_space=getattr(args,'use_space_char',True))
554
  self.sess=self.get_onnx_session(args.rec_model_dir, args.use_gpu); self.input_name=self.get_input_name(self.sess); self.output_name=self.get_output_name(self.sess)
 
555
  def _resize_norm(self, img, max_r):
556
  imgC,imgH,imgW = self.shape; h,w = img.shape[:2];
557
  if h==0 or w==0: return np.zeros((imgC,imgH,imgW),dtype=np.float32)
@@ -561,6 +638,7 @@ class _MDR_TextRecognizer(_MDR_PredictBase):
561
  if len(resized.shape)==2: resized=resized[:,:,np.newaxis]
562
  resized=resized.transpose((2,0,1))/255.0; resized-=0.5; resized/=0.5
563
  padding=np.zeros((imgC,imgH,imgW),dtype=np.float32); padding[:,:,0:tw]=resized; return padding
 
564
  def __call__(self, img_list):
565
  if not img_list: return []; num=len(img_list); ratios=[img.shape[1]/float(img.shape[0]) if img.shape[0]>0 else 0 for img in img_list]
566
  indices=np.argsort(np.array(ratios)); results=[["",0.0]]*num; batch_n=self.batch_num
@@ -574,8 +652,9 @@ class _MDR_TextRecognizer(_MDR_PredictBase):
574
  for i in range(len(rec_out)): results[indices[start+i]]=rec_out[i]
575
  return results
576
 
577
- # --- MDR ONNX OCR System (predict_system.py) ---
578
  class _MDR_TextSystem:
 
579
  def __init__(self, args):
580
  class ArgsObject: # Helper to access dict args with dot notation
581
  def __init__(self, **entries): self.__dict__.update(entries)
@@ -587,11 +666,13 @@ class _MDR_TextSystem:
587
  self.drop_score = getattr(args, 'drop_score', 0.5)
588
  self.classifier = _MDR_TextClassifier(args) if self.use_cls else None
589
  self.crop_idx = 0; self.save_crop = getattr(args, 'save_crop_res', False); self.crop_dir = getattr(args, 'crop_res_save_dir', "./output/mdr_crop_res")
 
590
  def _sort_boxes(self, boxes):
591
  if boxes is None or len(boxes)==0: return []
592
  def key(box): min_y=min(p[1] for p in box); min_x=min(p[0] for p in box); return (min_y, min_x)
593
  try: return list(sorted(boxes, key=key))
594
  except: return list(boxes) # Fallback
 
595
  def __call__(self, img, classify=True):
596
  ori_im = img.copy(); boxes = self.detector(img)
597
  if boxes is None or len(boxes)==0: return [], []
@@ -613,12 +694,13 @@ class _MDR_TextSystem:
613
  if score >= self.drop_score: final_boxes.append(box); final_rec.append(res)
614
  if self.save_crop: self._save_crops(crops, rec_res)
615
  return final_boxes, final_rec
 
616
  def _save_crops(self, crops, recs):
617
  mdr_ensure_directory(self.crop_dir); num = len(crops)
618
  for i in range(num): txt, score = recs[i]; safe=re.sub(r'\W+', '_', txt)[:20]; fname=f"crop_{self.crop_idx+i}_{safe}_{score:.2f}.jpg"; cv2.imwrite(os.path.join(self.crop_dir, fname), crops[i])
619
  self.crop_idx += num
620
 
621
- # --- MDR ONNX OCR Utilities (onnxocr/utils.py) ---
622
  def mdr_get_rotated_crop(img, points):
623
  """Crops and perspective-transforms a quadrilateral region."""
624
  pts = np.array(points, dtype="float32"); assert len(pts)==4
@@ -630,14 +712,17 @@ def mdr_get_rotated_crop(img, points):
630
  dh, dw = dst.shape[0:2]
631
  if dh>0 and dw>0 and dh*1.0/dw >= 1.5: dst = cv2.rotate(dst, cv2.ROTATE_90_CLOCKWISE)
632
  return dst
 
633
  def mdr_get_min_area_crop(img, points):
634
  """Crops the minimum area rectangle containing the points."""
635
  bb = cv2.minAreaRect(np.array(points).astype(np.int32)); box_pts = cv2.boxPoints(bb)
636
  return mdr_get_rotated_crop(img, box_pts)
637
 
638
- # --- MDR Layout Processing (overlap.py) ---
639
  _MDR_INCLUDES_MIN_RATE = 0.99
 
640
  class _MDR_OverlapMatrixContext:
 
641
  def __init__(self, layouts: list[MDRLayoutElement]):
642
  length = len(layouts); self.polys: list[Polygon|None] = []
643
  for l in layouts:
@@ -651,6 +736,7 @@ class _MDR_OverlapMatrixContext:
651
  p2 = self.polys[j];
652
  if p2 is None: continue
653
  r_ij = self._rate(p1, p2); r_ji = self._rate(p2, p1); self.matrix[i][j]=r_ij; self.matrix[j][i]=r_ji
 
654
  def _rate(self, p1: Polygon, p2: Polygon) -> float: # Rate p1 covers p2
655
  try: inter = p1.intersection(p2);
656
  except: return 0.0
@@ -659,9 +745,11 @@ class _MDR_OverlapMatrixContext:
659
  _, _, px1, py1 = p2.bounds; pw, ph = px1-p2.bounds[0], py1-p2.bounds[1]
660
  if pw < 1e-6 or ph < 1e-6: return 0.0
661
  wr = min(iw/pw, 1.0); hr = min(ih/ph, 1.0); return (wr+hr)/2.0
 
662
  def others(self, idx: int):
663
  for i, r in enumerate(self.matrix[idx]):
664
  if i != idx and i not in self.removed: yield r
 
665
  def includes(self, idx: int): # Layouts included BY idx
666
  for i, r in enumerate(self.matrix[idx]):
667
  if i != idx and i not in self.removed and r >= _MDR_INCLUDES_MIN_RATE:
@@ -723,8 +811,9 @@ def mdr_merge_fragments_into_lines(orig_frags: list[MDROcrFragment]) -> list[MDR
723
  for i, f in enumerate(merged): f.order = i
724
  return merged
725
 
726
- # --- MDR Layout Processing (ocr_corrector.py) ---
727
  _MDR_CORRECTION_MIN_OVERLAP = 0.5
 
728
  def mdr_correct_layout_fragments(ocr_engine: 'MDROcrEngine', source_img: Image, layout: MDRLayoutElement):
729
  if not layout.fragments: return;
730
  try:
@@ -759,9 +848,12 @@ def mdr_correct_layout_fragments(ocr_engine: 'MDROcrEngine', source_img: Image,
759
  final = [n if n.rank >= o.rank else o for o, n in matched]; final.extend(unmatched_orig); final.extend(unmatched_new)
760
  layout.fragments = final; layout.fragments.sort(key=lambda f: (f.rect.lt[1], f.rect.lt[0]))
761
 
762
- # --- MDR OCR Engine (ocr.py) ---
 
763
  _MDR_OCR_MODELS = {"det": ("ppocrv4","det","det.onnx"), "cls": ("ppocrv4","cls","cls.onnx"), "rec": ("ppocrv4","rec","rec.onnx"), "keys": ("ch_ppocr_server_v2.0","ppocr_keys_v1.txt")}
 
764
  _MDR_OCR_URL_BASE = "https://huggingface.co/moskize/OnnxOCR/resolve/main/"
 
765
  @dataclass
766
  class _MDR_ONNXParams: # Simplified container
767
  use_gpu: bool; det_model_dir: str; cls_model_dir: str; rec_model_dir: str; rec_char_dict_path: str
@@ -772,14 +864,17 @@ class _MDR_ONNXParams: # Simplified container
772
 
773
  class MDROcrEngine:
774
  """Handles OCR detection and recognition using ONNX models."""
 
775
  def __init__(self, device: Literal["cpu", "cuda"], model_dir_path: str):
776
  self._device = device; self._model_dir = mdr_ensure_directory(model_dir_path)
777
  self._text_system: _MDR_TextSystem | None = None; self._onnx_params: _MDR_ONNXParams | None = None
778
  self._ensure_models(); self._get_system() # Init on creation
 
779
  def _ensure_models(self):
780
  for key, parts in _MDR_OCR_MODELS.items():
781
  fp = Path(self._model_dir) / Path(*parts)
782
  if not fp.exists(): print(f"Downloading MDR OCR model: {fp.name}..."); url = _MDR_OCR_URL_BASE + "/".join(parts); mdr_download_model(url, fp)
 
783
  def _get_system(self) -> _MDR_TextSystem | None:
784
  if self._text_system is None:
785
  paths = {k: str(Path(self._model_dir)/Path(*p)) for k,p in _MDR_OCR_MODELS.items()}
@@ -787,6 +882,7 @@ class MDROcrEngine:
787
  try: self._text_system = _MDR_TextSystem(self._onnx_params); print(f"MDR OCR System initialized.")
788
  except Exception as e: print(f"ERROR initializing MDR OCR System: {e}"); self._text_system = None
789
  return self._text_system
 
790
  def find_text_fragments(self, image_np: np.ndarray) -> Generator[MDROcrFragment, None, None]:
791
  """Finds and recognizes text fragments in a NumPy image (BGR)."""
792
  system = self._get_system()
@@ -799,22 +895,26 @@ class MDROcrEngine:
799
  if not txt or mdr_is_whitespace(txt) or conf < 0.1: continue
800
  pts = [(float(p[0]), float(p[1])) for p in box_pts]
801
  if len(pts)==4: r=MDRRectangle(lt=pts[0], rt=pts[1], rb=pts[2], lb=pts[3]); if r.is_valid and r.area>1: yield MDROcrFragment(order=-1, text=txt, rank=float(conf), rect=r)
 
802
  def _preprocess(self, img: np.ndarray) -> np.ndarray:
803
  if len(img.shape)==3 and img.shape[2]==4: a=img[:,:,3]/255.0; bg=(255,255,255); new=np.zeros_like(img[:,:,:3]); [setattr(new[:,:,i], 'flags.writeable', True) for i in range(3)]; [np.copyto(new[:,:,i], (bg[i]*(1-a)+img[:,:,i]*a)) for i in range(3)]; img=new.astype(np.uint8)
804
  elif len(img.shape)==2: img=cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
805
  elif not (len(img.shape)==3 and img.shape[2]==3): raise ValueError("Unsupported image format")
806
  return img
807
 
808
- # --- MDR Layout Reading Internals (layoutreader.py) ---
809
  _MDR_MAX_LEN = 510; _MDR_CLS_ID = 0; _MDR_SEP_ID = 2; _MDR_PAD_ID = 1
 
810
  def mdr_boxes_to_reader_inputs(boxes: List[List[int]], max_len=_MDR_MAX_LEN) -> Dict[str, torch.Tensor]:
811
  t_boxes = boxes[:max_len]; i_boxes = [[0,0,0,0]] + t_boxes + [[0,0,0,0]]
812
  i_ids = [_MDR_CLS_ID] + [_MDR_PAD_ID]*len(t_boxes) + [_MDR_SEP_ID]
813
  a_mask = [1]*len(i_ids); pad_len = (max_len+2) - len(i_ids)
814
  if pad_len > 0: i_boxes.extend([[0,0,0,0]]*pad_len); i_ids.extend([_MDR_PAD_ID]*pad_len); a_mask.extend([0]*pad_len)
815
  return {"bbox": torch.tensor([i_boxes]), "input_ids": torch.tensor([i_ids]), "attention_mask": torch.tensor([a_mask])}
 
816
  def mdr_prepare_reader_inputs(inputs: Dict[str, torch.Tensor], model: LayoutLMv3ForTokenClassification) -> Dict[str, torch.Tensor]:
817
  return {k: v.to(model.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
 
818
  def mdr_parse_reader_logits(logits: torch.Tensor, length: int) -> List[int]:
819
  if length == 0: return []; rel_logits = logits[1:length+1, :length]; orders = rel_logits.argmax(dim=1).tolist()
820
  while True:
@@ -830,14 +930,17 @@ def mdr_parse_reader_logits(logits: torch.Tensor, length: int) -> List[int]:
830
  orders[idx] = rel_logits[idx, :].argmax().item(); rel_logits[idx, order] = orig_logit
831
  return orders
832
 
833
- # --- MDR Layout Reading Engine (layout_order.py) ---
834
  @dataclass
835
  class _MDR_ReaderBBox: layout_index: int; fragment_index: int; virtual: bool; order: int; value: tuple[float, float, float, float]
 
836
  class MDRLayoutReader:
837
  """Determines reading order of layout elements using LayoutLMv3."""
 
838
  def __init__(self, model_path: str):
839
  self._model_path = model_path; self._model: LayoutLMv3ForTokenClassification | None = None
840
  self._device = "cuda" if torch.cuda.is_available() else "cpu"
 
841
  def _get_model(self) -> LayoutLMv3ForTokenClassification | None:
842
  if self._model is None:
843
  cache = mdr_ensure_directory(self._model_path); name = "microsoft/layoutlmv3-base"; h_path = os.path.join(cache, "models--hantian--layoutreader")
@@ -847,6 +950,7 @@ class MDRLayoutReader:
847
  self._model.to(self._device); self._model.eval(); print(f"MDR LayoutReader loaded on {self._device}.")
848
  except Exception as e: print(f"ERROR loading MDR LayoutReader: {e}"); self._model = None
849
  return self._model
 
850
  def determine_reading_order(self, layouts: list[MDRLayoutElement], size: tuple[int, int]) -> list[MDRLayoutElement]:
851
  w, h = size;
852
  if w<=0 or h<=0 or not layouts: return layouts;
@@ -873,6 +977,7 @@ class MDRLayoutReader:
873
  if len(orders) != len(bbox_list): print("MDR LayoutReader order mismatch"); return layouts # Fallback
874
  for i, order_idx in enumerate(orders): bbox_list[i].order = order_idx
875
  return self._apply_order(layouts, bbox_list)
 
876
  def _prepare_bboxes(self, layouts: list[MDRLayoutElement], w: int, h: int) -> list[_MDR_ReaderBBox] | None:
877
  line_h = self._estimate_line_h(layouts); bbox_list = []
878
  for i, l in enumerate(layouts):
@@ -880,6 +985,7 @@ class MDRLayoutReader:
880
  else: bbox_list.extend(self._gen_virtual(l, i, line_h, w, h))
881
  if len(bbox_list) > _MDR_MAX_LEN: print(f"Too many boxes ({len(bbox_list)}>{_MDR_MAX_LEN})"); return None
882
  bbox_list.sort(key=lambda b: (b.value[1], b.value[0])); return bbox_list
 
883
  def _apply_order(self, layouts: list[MDRLayoutElement], bbox_list: list[_MDR_ReaderBBox]) -> list[MDRLayoutElement]:
884
  layout_map = defaultdict(list); [layout_map[b.layout_index].append(b) for b in bbox_list]
885
  layout_orders = [(idx, self._median([b.order for b in bboxes])) for idx, bboxes in layout_map.items() if bboxes]
@@ -893,9 +999,11 @@ class MDRLayoutReader:
893
  else: frags.sort(key=lambda f: (f.rect.lt[1], f.rect.lt[0]))
894
  for frag in frags: frag.order = nfo; nfo += 1
895
  return sorted_layouts
 
896
  def _estimate_line_h(self, layouts: list[MDRLayoutElement]) -> float:
897
  heights = [f.rect.size[1] for l in layouts for f in l.fragments if f.rect.size[1]>0]
898
  return self._median(heights) if heights else 15.0
 
899
  def _gen_virtual(self, l: MDRLayoutElement, l_idx: int, line_h: float, pw: int, ph: int) -> Generator[_MDR_ReaderBBox, None, None]:
900
  x0,y0,x1,y1 = l.rect.wrapper; lh,lw = y1-y0,x1-x0
901
  if lh<=0 or lw<=0 or line_h<=0: yield _MDR_ReaderBBox(l_idx,-1,True,-1,(x0,y0,x1,y1)); return
@@ -910,16 +1018,19 @@ class MDRLayoutReader:
910
  ly0,ly1 = max(0,min(ph,cur_y)), max(0,min(ph,cur_y+act_line_h)); lx0,lx1 = max(0,min(pw,x0)), max(0,min(pw,x1))
911
  if ly1>ly0 and lx1>lx0: yield _MDR_ReaderBBox(l_idx,-1,True,-1,(lx0,ly0,lx1,ly1))
912
  cur_y += act_line_h
 
913
  def _median(self, nums: list[float|int]) -> float:
914
  if not nums: return 0.0; s_nums = sorted(nums); n = len(s_nums)
915
  return float(s_nums[n//2]) if n%2==1 else float((s_nums[n//2-1]+s_nums[n//2])/2.0)
916
 
917
- # --- MDR LaTeX Extractor (latex.py) ---
918
  class MDRLatexExtractor:
919
  """Extracts LaTeX from formula images using pix2tex."""
 
920
  def __init__(self, model_path: str):
921
  self._model_path = model_path; self._model: LatexOCR | None = None
922
  self._device = "cuda" if torch.cuda.is_available() else "cpu"
 
923
  def extract(self, image: Image) -> str | None:
924
  if LatexOCR is None: return None;
925
  image = mdr_expand_image(image, 0.1); model = self._get_model()
@@ -927,6 +1038,7 @@ class MDRLatexExtractor:
927
  try:
928
  with torch.no_grad(): img_rgb = image.convert('RGB') if image.mode!='RGB' else image; latex = model(img_rgb); return latex if latex else None
929
  except Exception as e: print(f"MDR LaTeX error: {e}"); return None
 
930
  def _get_model(self) -> LatexOCR | None:
931
  if self._model is None and LatexOCR is not None:
932
  mdr_ensure_directory(self._model_path); wp, rp, cp = Path(self._model_path)/"weights.pth", Path(self._model_path)/"image_resizer.pth", Path(self._model_path)/"config.yaml"
@@ -935,19 +1047,23 @@ class MDRLatexExtractor:
935
  try: args = Munch({"config":str(cp), "checkpoint":str(wp), "device":self._device, "no_cuda":self._device=="cpu", "no_resize":False, "temperature":0.0}); self._model = LatexOCR(args); print(f"MDR LaTeX loaded on {self._device}.")
936
  except Exception as e: print(f"ERROR initializing MDR LatexOCR: {e}"); self._model = None
937
  return self._model
 
938
  def _download(self):
939
  tag = "v0.0.1"; base = f"https://github.com/lukas-blecher/LaTeX-OCR/releases/download/{tag}/"; files = {"weights.pth": base+"weights.pth", "image_resizer.pth": base+"image_resizer.pth"}
940
  mdr_ensure_directory(self._model_path); [mdr_download_model(url, Path(self._model_path)/name) for name, url in files.items() if not (Path(self._model_path)/name).exists()]
941
 
942
- # --- MDR Table Parser (table.py) ---
943
  MDRTableOutputFormat = Literal["latex", "markdown", "html"]
 
944
  class MDRTableParser:
945
  """Parses table structure/content from images using StructTable model."""
 
946
  def __init__(self, device: Literal["cpu", "cuda"], model_path: str):
947
  self._model: Any | None = None; self._model_path = mdr_ensure_directory(model_path)
948
  self._device = device if torch.cuda.is_available() and device=="cuda" else "cpu"
949
  self._disabled = self._device == "cpu"
950
  if self._disabled: print("Warning: MDR Table parsing requires CUDA. Disabled.")
 
951
  def parse_table_image(self, image: Image, format: MDRTableLayoutParsedFormat) -> str | None:
952
  if self._disabled: return None;
953
  fmt: MDRTableOutputFormat | None = None
@@ -962,6 +1078,7 @@ class MDRTableParser:
962
  with torch.no_grad(): results = model([img_rgb], output_format=fmt)
963
  return results[0] if results else None
964
  except Exception as e: print(f"MDR Table parsing error: {e}"); return None
 
965
  def _get_model(self):
966
  if self._model is None and not self._disabled:
967
  try:
@@ -974,23 +1091,31 @@ class MDRTableParser:
974
  except Exception as e: print(f"ERROR loading MDR StructTable: {e}"); self._model=None
975
  return self._model
976
 
977
- # --- MDR Image Optimizer (raw_optimizer.py) ---
978
  _MDR_TINY_ROTATION = 0.005
 
979
  @dataclass
980
  class _MDR_RotationContext: to_origin: MDRRotationAdjuster; to_new: MDRRotationAdjuster; fragment_origin_rectangles: list[MDRRectangle]
 
981
  class MDRImageOptimizer:
982
  """Handles image rotation detection and coordinate adjustments."""
 
983
  def __init__(self, raw_image: Image, adjust_points: bool):
984
  self._raw = raw_image; self._image = raw_image; self._adjust_points = adjust_points
985
  self._fragments: list[MDROcrFragment] = []; self._rotation: float = 0.0; self._rot_ctx: _MDR_RotationContext | None = None
 
986
  @property
987
  def image(self) -> Image: return self._image
 
988
  @property
989
  def adjusted_image(self) -> Image | None: return self._image if self._rot_ctx is not None else None
 
990
  @property
991
  def rotation(self) -> float: return self._rotation
 
992
  @property
993
  def image_np(self) -> np.ndarray: img_rgb = np.array(self._raw.convert("RGB")); return cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
 
994
  def receive_fragments(self, fragments: list[MDROcrFragment]):
995
  self._fragments = fragments;
996
  if not fragments: return;
@@ -1005,12 +1130,13 @@ class MDRImageOptimizer:
1005
  to_new=MDRRotationAdjuster(orig_sz, new_sz, self._rotation, False),
1006
  to_origin=MDRRotationAdjuster(orig_sz, new_sz, self._rotation, True))
1007
  adj = self._rot_ctx.to_new; [setattr(f, 'rect', MDRRectangle(lt=adj.adjust(r.lt), rt=adj.adjust(r.rt), lb=adj.adjust(r.lb), rb=adj.adjust(r.rb))) for f in fragments if (r:=f.rect)]
 
1008
  def finalize_layout_coords(self, layouts: list[MDRLayoutElement]):
1009
  if self._rot_ctx is None or self._adjust_points: return
1010
  if len(self._fragments) == len(self._rot_ctx.fragment_origin_rectangles): [setattr(f, 'rect', orig_r) for f, orig_r in zip(self._fragments, self._rot_ctx.fragment_origin_rectangles)]
1011
  adj = self._rot_ctx.to_origin; [setattr(l, 'rect', MDRRectangle(lt=adj.adjust(r.lt), rt=adj.adjust(r.rt), lb=adj.adjust(r.lb), rb=adj.adjust(r.rb))) for l in layouts if (r:=l.rect)]
1012
 
1013
- # --- MDR Image Clipping (clipper.py) ---
1014
  def mdr_clip_from_image(image: Image, rect: MDRRectangle, wrap_w: float = 0.0, wrap_h: float = 0.0) -> Image:
1015
  """Clips a potentially rotated rectangle from an image."""
1016
  try:
@@ -1026,18 +1152,21 @@ def mdr_clip_from_image(image: Image, rect: MDRRectangle, wrap_w: float = 0.0, w
1026
  out_w, out_h = ceil(avg_w+wrap_w), ceil(avg_h+wrap_h)
1027
  return image.transform((out_w, out_h), PILTransform.AFFINE, p_mat, PILResampling.BICUBIC, fillcolor=(255,255,255))
1028
  except Exception as e: print(f"MDR Clipping error: {e}"); return new_image("RGB", (10,10), (255,255,255))
 
1029
  def mdr_clip_layout(res: MDRExtractionResult, layout: MDRLayoutElement, wrap_w: float = 0.0, wrap_h: float = 0.0) -> Image:
1030
  """Clips a layout region from the MDRExtractionResult image."""
1031
  img = res.adjusted_image if res.adjusted_image else res.extracted_image
1032
  return mdr_clip_from_image(img, layout.rect, wrap_w, wrap_h)
1033
 
1034
- # --- MDR Debug Plotting (plot.py) ---
1035
  _MDR_FRAG_COLOR = (0x49, 0xCF, 0xCB, 200); _MDR_LAYOUT_COLORS = { MDRLayoutClass.TITLE: (0x0A,0x12,0x2C,255), MDRLayoutClass.PLAIN_TEXT: (0x3C,0x67,0x90,255), MDRLayoutClass.ABANDON: (0xC0,0xBB,0xA9,180), MDRLayoutClass.FIGURE: (0x5B,0x91,0x3C,255), MDRLayoutClass.FIGURE_CAPTION: (0x77,0xB3,0x54,255), MDRLayoutClass.TABLE: (0x44,0x17,0x52,255), MDRLayoutClass.TABLE_CAPTION: (0x81,0x75,0xA0,255), MDRLayoutClass.TABLE_FOOTNOTE: (0xEF,0xB6,0xC9,255), MDRLayoutClass.ISOLATE_FORMULA: (0xFA,0x38,0x27,255), MDRLayoutClass.FORMULA_CAPTION: (0xFF,0x9D,0x24,255) }; _MDR_DEFAULT_COLOR = (0x80,0x80,0x80,255); _MDR_RGBA = tuple[int,int,int,int]
 
1036
  def mdr_plot_layout(image: Image, layouts: Iterable[MDRLayoutElement]) -> None:
1037
  """Draws layout and fragment boxes onto an image for debugging."""
1038
  if not layouts: return;
1039
  try: l_font, f_font = load_default(size=25), load_default(size=15); draw = ImageDraw.Draw(image, mode="RGBA")
1040
  except Exception as e: print(f"MDR Plot init error: {e}"); return
 
1041
  def _draw_num(pos: MDRPoint, num: int, font: FreeTypeFont, color: _MDR_RGBA):
1042
  try: x,y=pos; txt=str(num); txt_pos=(round(x)+3, round(y)+1); bbox=draw.textbbox(txt_pos,txt,font=font); bg_rect=(bbox[0]-2,bbox[1]-1,bbox[2]+2,bbox[3]+1); bg_color=(color[0],color[1],color[2],180); draw.rectangle(bg_rect,fill=bg_color); draw.text(txt_pos,txt,font=font,fill=(255,255,255,255))
1043
  except Exception as e: print(f"MDR Draw num error: {e}")
@@ -1049,26 +1178,65 @@ def mdr_plot_layout(image: Image, layouts: Iterable[MDRLayoutElement]) -> None:
1049
  try: draw.polygon([p for p in f.rect], outline=_MDR_FRAG_COLOR, width=1)
1050
  except Exception as e: print(f"MDR Fragment draw error: {e}")
1051
 
1052
- # --- MDR Extraction Engine (DocExtractor.py) ---
1053
  class MDRExtractionEngine:
1054
  """Core engine for extracting structured information from a document image."""
 
1055
  def __init__(self, model_dir_path: str, device: Literal["cpu", "cuda"]="cpu", ocr_for_each_layouts: bool=True, extract_formula: bool=True, extract_table_format: MDRTableLayoutParsedFormat|None=None):
1056
- self._model_dir = model_dir_path; self._device = device if torch.cuda.is_available() else "cpu"
 
1057
  self._ocr_each = ocr_for_each_layouts; self._ext_formula = extract_formula; self._ext_table = extract_table_format
1058
  self._yolo: YOLOv10 | None = None
1059
- self._ocr_engine = MDROcrEngine(device=self._device, model_dir_path=os.path.join(model_dir_path, "onnx_ocr"))
1060
- self._table_parser = MDRTableParser(device=self._device, model_path=os.path.join(model_dir_path, "struct_eqtable"))
1061
- self._latex_extractor = MDRLatexExtractor(model_path=os.path.join(model_dir_path, "latex"))
1062
- self._layout_reader = MDRLayoutReader(model_path=os.path.join(model_dir_path, "layoutreader"))
 
1063
  print(f"MDR Extraction Engine initialized on device: {self._device}")
 
 
1064
  def _get_yolo_model(self) -> YOLOv10 | None:
 
1065
  if self._yolo is None and YOLOv10 is not None:
1066
- bp = Path(self._model_dir)/"yolo"; mdr_ensure_directory(str(bp)); url = "https://huggingface.co/opendatalab/PDF-Extract-Kit-1.0/resolve/main/models/Layout/YOLO/doclayout_yolo_ft.pt"; name = "doclayout_yolo_ft.pt"; mp = bp/name
1067
- if not mp.exists(): print(f"Downloading MDR YOLO model..."); mdr_download_model(url, mp)
1068
- try: self._yolo = YOLOv10(str(mp)); print("MDR YOLOv10 loaded.")
1069
- except Exception as e: print(f"ERROR loading MDR YOLOv10: {e}"); self._yolo = None
1070
- elif YOLOv10 is None: print("MDR YOLOv10 unavailable.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1071
  return self._yolo
 
1072
  def analyze_image(self, image: Image, adjust_points: bool=False) -> MDRExtractionResult:
1073
  """Analyzes a single page image to extract layout and content."""
1074
  print(" Engine: Analyzing image..."); optimizer = MDRImageOptimizer(image, adjust_points)
@@ -1088,6 +1256,7 @@ class MDRExtractionEngine:
1088
  print(" Engine: Finalizing coords..."); optimizer.finalize_layout_coords(layouts)
1089
  print(" Engine: Analysis complete.")
1090
  return MDRExtractionResult(rotation=optimizer.rotation, layouts=layouts, extracted_image=image, adjusted_image=optimizer.adjusted_image)
 
1091
  def _run_yolo_detection(self, img: Image, yolo: YOLOv10) -> Generator[MDRLayoutElement, None, None]:
1092
  img_rgb = img.convert("RGB"); res = yolo.predict(source=img_rgb, imgsz=1024, conf=0.2, device=self._device, verbose=False)
1093
  if not res or not hasattr(res[0], 'boxes') or res[0].boxes is None: return
@@ -1101,6 +1270,7 @@ class MDRExtractionEngine:
1101
  if cls == MDRLayoutClass.TABLE: yield MDRTableLayoutElement(cls=cls, rect=rect, fragments=[], parsed=None)
1102
  elif cls == MDRLayoutClass.ISOLATE_FORMULA: yield MDRFormulaLayoutElement(cls=cls, rect=rect, fragments=[], latex=None)
1103
  elif cls in MDRPlainLayoutElement.__annotations__['cls'].__args__: yield MDRPlainLayoutElement(cls=cls, rect=rect, fragments=[])
 
1104
  def _match_fragments_to_layouts(self, frags: list[MDROcrFragment], layouts: list[MDRLayoutElement]) -> list[MDRLayoutElement]:
1105
  if not frags or not layouts: return layouts
1106
  layout_polys = [(Polygon(l.rect) if l.rect.is_valid else None) for l in layouts]
@@ -1120,11 +1290,13 @@ class MDRExtractionEngine:
1120
  layouts[best_idx].fragments.append(frag)
1121
  for l in layouts: l.fragments.sort(key=lambda f: (f.rect.lt[1], f.rect.lt[0]))
1122
  return layouts
 
1123
  def _run_ocr_correction(self, img: Image, layouts: list[MDRLayoutElement]):
1124
  for i, l in enumerate(layouts):
1125
  if l.cls == MDRLayoutClass.FIGURE: continue
1126
  try: mdr_correct_layout_fragments(self._ocr_engine, img, l)
1127
  except Exception as e: print(f" Engine: OCR correction error layout {i}: {e}")
 
1128
  def _parse_special_layouts(self, layouts: list[MDRLayoutElement], optimizer: MDRImageOptimizer):
1129
  img_to_clip = optimizer.image
1130
  for l in layouts:
@@ -1135,25 +1307,33 @@ class MDRExtractionEngine:
1135
  try: t_img = mdr_clip_from_image(img_to_clip, l.rect); parsed = self._table_parser.parse_table_image(t_img, self._ext_table) if t_img.width>1 and t_img.height>1 else None
1136
  except Exception as e: print(f" Engine: Table parse error: {e}"); parsed = None
1137
  if parsed: l.parsed = (parsed, self._ext_table)
 
1138
  def _should_keep_layout(self, l: MDRLayoutElement) -> bool:
1139
  if l.fragments and not all(mdr_is_whitespace(f.text) for f in l.fragments): return True
1140
  return l.cls in [MDRLayoutClass.FIGURE, MDRLayoutClass.TABLE, MDRLayoutClass.ISOLATE_FORMULA]
1141
 
1142
- # --- MDR Page Section Linking (section.py) ---
1143
  class _MDR_LinkedShape:
1144
  """Internal helper for managing layout linking across pages."""
 
1145
  def __init__(self, layout: MDRLayoutElement): self.layout=layout; self.pre:list[MDRLayoutElement|None]=[None,None]; self.nex:list[MDRLayoutElement|None]=[None,None]
 
1146
  @property
1147
  def distance2(self) -> float: x,y=self.layout.rect.lt; return x*x+y*y
 
1148
  class MDRPageSection:
1149
  """Represents a page's layouts for framework detection via linking."""
 
1150
  def __init__(self, page_index: int, layouts: Iterable[MDRLayoutElement]):
1151
  self._page_index = page_index; self._shapes = [_MDR_LinkedShape(l) for l in layouts]; self._shapes.sort(key=lambda s: (s.layout.rect.lt[1], s.layout.rect.lt[0]))
 
1152
  @property
1153
  def page_index(self) -> int: return self._page_index
 
1154
  def find_framework_elements(self) -> list[MDRLayoutElement]:
1155
  """Identifies framework layouts based on links to other pages."""
1156
  return [s.layout for s in self._shapes if any(s.pre) or any(s.nex)]
 
1157
  def link_to_next(self, next_section: 'MDRPageSection', offset: int) -> None:
1158
  """Links matching shapes between this section and the next."""
1159
  if offset not in (1,2): return
@@ -1169,6 +1349,7 @@ class MDRPageSection:
1169
  r2_rel = self._relative_rect(orig_n_pt, s2.layout.rect); ovr = self._symmetric_iou(r1_rel, r2_rel)
1170
  if ovr > max_ovr: max_ovr, best_s2 = ovr, s2
1171
  if max_ovr >= 0.80 and best_s2 is not None: s1.nex[offset-1] = best_s2.layout; best_s2.pre[offset-1] = s1.layout # Link both ways
 
1172
  def _shapes_match(self, s1: _MDR_LinkedShape, s2: _MDR_LinkedShape) -> bool:
1173
  l1, l2 = s1.layout, s2.layout; sz1, sz2 = l1.rect.size, l2.rect.size; thresh = 0.90
1174
  if mdr_similarity_ratio(sz1[0], sz2[0]) < thresh or mdr_similarity_ratio(sz1[1], sz2[1]) < thresh: return False
@@ -1183,10 +1364,12 @@ class MDRPageSection:
1183
  if max_sim > 0.75: matches += 1; if best_j != -1: used_f2[best_j] = True
1184
  max_c = max(c1, c2); rate_frags = matches / max_c
1185
  return self._check_match_threshold(rate_frags, max_c, (0.0, 0.45, 0.45, 0.6, 0.8, 0.95))
 
1186
  def _fragment_sim(self, l1: MDRLayoutElement, l2: MDRLayoutElement, f1: MDROcrFragment, f2: MDROcrFragment) -> float:
1187
  r1_rel = self._relative_rect(l1.rect.lt, f1.rect); r2_rel = self._relative_rect(l2.rect.lt, f2.rect)
1188
  geom_sim = self._symmetric_iou(r1_rel, r2_rel); text_sim, _ = mdr_check_text_similarity(f1.text, f2.text)
1189
  return (geom_sim + text_sim) / 2.0
 
1190
  def _find_origin_pair(self, matches_matrix: list[list[_MDR_LinkedShape]], next_shapes: list[_MDR_LinkedShape]) -> tuple[_MDR_LinkedShape, _MDR_LinkedShape] | None:
1191
  best_pair, min_dist2 = None, float('inf')
1192
  for i, s1 in enumerate(self._shapes):
@@ -1194,10 +1377,13 @@ class MDRPageSection:
1194
  if not match_list: continue
1195
  for s2 in match_list: dist2 = s1.distance2 + s2.distance2; if dist2 < min_dist2: min_dist2, best_pair = dist2, (s1, s2)
1196
  return best_pair
 
1197
  def _check_match_threshold(self, rate: float, count: int, thresholds: Sequence[float]) -> bool:
1198
  if not thresholds: return False; idx = min(count, len(thresholds)-1); return rate >= thresholds[idx]
 
1199
  def _relative_rect(self, origin: MDRPoint, rect: MDRRectangle) -> MDRRectangle:
1200
  ox, oy = origin; r=rect; return MDRRectangle(lt=(r.lt[0]-ox, r.lt[1]-oy), rt=(r.rt[0]-ox, r.rt[1]-oy), lb=(r.lb[0]-ox, r.lb[1]-oy), rb=(r.rb[0]-ox, r.rb[1]-oy))
 
1201
  def _symmetric_iou(self, r1: MDRRectangle, r2: MDRRectangle) -> float:
1202
  try: p1, p2 = Polygon(r1), Polygon(r2);
1203
  except: return 0.0
@@ -1207,21 +1393,26 @@ class MDRPageSection:
1207
  if inter.is_empty or inter.area < 1e-6: return 0.0
1208
  union_area = union.area; return inter.area / union_area if union_area > 1e-6 else 1.0
1209
 
1210
- # --- MDR Document Iterator (document.py) ---
1211
  _MDR_CONTEXT_PAGES = 2 # Look behind/ahead pages for context
 
1212
  @dataclass
1213
  class MDRProcessingParams:
1214
  """Parameters for processing a document."""
1215
  pdf: str | FitzDocument; page_indexes: Iterable[int] | None; report_progress: MDRProgressReportCallback | None
 
1216
  class MDRDocumentIterator:
1217
  """Iterates through document pages, handles context, and calls the extraction engine."""
 
1218
  def __init__(self, device: Literal["cpu", "cuda"], model_dir_path: str, ocr_level: MDROcrLevel, extract_formula: bool, extract_table_format: MDRTableLayoutParsedFormat | None, debug_dir_path: str | None):
1219
  self._debug_dir = debug_dir_path
1220
  self._engine = MDRExtractionEngine(device=device, model_dir_path=model_dir_path, ocr_for_each_layouts=(ocr_level==MDROcrLevel.OncePerLayout), extract_formula=extract_formula, extract_table_format=extract_table_format)
 
1221
  def iterate_sections(self, params: MDRProcessingParams) -> Generator[tuple[int, MDRExtractionResult, list[MDRLayoutElement]], None, None]:
1222
  """Yields page index, extraction result, and content layouts for each requested page."""
1223
  for res, sec in self._process_and_link_sections(params):
1224
  framework = set(sec.find_framework_elements()); content = [l for l in res.layouts if l not in framework]; yield sec.page_index, res, content
 
1225
  def _process_and_link_sections(self, params: MDRProcessingParams) -> Generator[tuple[MDRExtractionResult, MDRPageSection], None, None]:
1226
  queue: list[tuple[MDRExtractionResult, MDRPageSection]] = []
1227
  for page_idx, res in self._run_extraction_on_pages(params):
@@ -1232,6 +1423,7 @@ class MDRDocumentIterator:
1232
  queue.append((res, cur_sec))
1233
  if len(queue) > _MDR_CONTEXT_PAGES: yield queue.pop(0)
1234
  for res, sec in queue: yield res, sec
 
1235
  def _run_extraction_on_pages(self, params: MDRProcessingParams) -> Generator[tuple[int, MDRExtractionResult], None, None]:
1236
  if self._debug_dir: mdr_ensure_directory(self._debug_dir)
1237
  doc, should_close = None, False
@@ -1254,6 +1446,7 @@ class MDRDocumentIterator:
1254
  except Exception as e: print(f" Iterator: Page {page_idx+1} processing error: {e}")
1255
  finally:
1256
  if should_close and doc: doc.close()
 
1257
  def _get_page_ranges(self, doc: FitzDocument, idxs: Iterable[int]|None) -> tuple[Sequence[int], Sequence[int]]:
1258
  count = doc.page_count;
1259
  if idxs is None: all_p = list(range(count)); return all_p, all_p
@@ -1261,19 +1454,22 @@ class MDRDocumentIterator:
1261
  for i in idxs:
1262
  if 0<=i<count: enable.add(i); [scan.add(j) for j in range(max(0, i-_MDR_CONTEXT_PAGES), min(count, i+_MDR_CONTEXT_PAGES+1))]
1263
  return sorted(list(scan)), sorted(list(enable))
 
1264
  def _render_page_image(self, page: FitzPage, dpi: int) -> Image:
1265
  mat = FitzMatrix(dpi/72.0, dpi/72.0); pix = page.get_pixmap(matrix=mat, alpha=False)
1266
  return frombytes("RGB", (pix.width, pix.height), pix.samples)
 
1267
  def _save_debug_plot(self, img: Image, idx: int, res: MDRExtractionResult, path: str):
1268
  try: plot_img = res.adjusted_image.copy() if res.adjusted_image else img.copy(); mdr_plot_layout(plot_img, res.layouts); plot_img.save(os.path.join(path, f"mdr_plot_page_{idx+1}.png"))
1269
  except Exception as e: print(f" Iterator: Plot generation error page {idx+1}: {e}")
1270
 
1271
- # --- MagicDataReadiness Main Processor (PDFPageExtractor.py) ---
1272
  class MagicPDFProcessor:
1273
  """
1274
  Main class for processing PDF documents to extract structured data blocks
1275
  using the MagicDataReadiness pipeline.
1276
  """
 
1277
  def __init__(self, device: Literal["cpu", "cuda"]="cuda", model_dir_path: str="./mdr_models", ocr_level: MDROcrLevel=MDROcrLevel.Once, extract_formula: bool=True, extract_table_format: MDRExtractedTableFormat|None=None, debug_dir_path: str|None=None):
1278
  """
1279
  Initializes the MagicPDFProcessor.
@@ -1436,7 +1632,7 @@ if __name__ == '__main__':
1436
  print(" MagicDataReadiness PDF Processor - Example Usage")
1437
  print("="*60)
1438
 
1439
- # --- 1. Configuration (!!! MODIFY THESE PATHS !!!) ---
1440
  # Directory where models are stored or will be downloaded
1441
  # IMPORTANT: Create this directory or ensure it's writable!
1442
  MDR_MODEL_DIRECTORY = "./mdr_pipeline_models"
@@ -1449,7 +1645,7 @@ if __name__ == '__main__':
1449
  MDR_DEBUG_DIRECTORY = "./mdr_debug_output"
1450
 
1451
  # Specify device ('cuda' or 'cpu').
1452
- MDR_DEVICE = "cuda"
1453
 
1454
  # Specify desired table format
1455
  MDR_TABLE_FORMAT = MDRExtractedTableFormat.MARKDOWN
 
38
  from PIL import ImageDraw
39
  from PIL.ImageFont import load_default, FreeTypeFont
40
  from shapely.geometry import Polygon
41
+ import pyclipper
42
  from unicodedata import category
43
  from alphabet_detector import AlphabetDetector
44
+ from munch import Munch
45
+ from transformers import LayoutLMv3ForTokenClassification
46
+ import onnxruntime
47
+ # --- Hugging Face Hub import: needed only because this runs in Spaces; not necessary in production ---
48
+ from huggingface_hub import hf_hub_download, HfHubDownloadError
49
 
50
+ # --- External Dependencies ---
 
51
  try:
52
  from doclayout_yolo import YOLOv10
53
  except ImportError:
 
59
  print("Warning: Could not import LatexOCR from pix2tex.cli. LaTeX extraction will fail.")
60
  LatexOCR = None
61
  try:
 
62
  pass # from struct_eqtable import build_model
63
  except ImportError:
64
  print("Warning: Could not import build_model from struct_eqtable. Table parsing might fail.")
65
 
66
  # --- MagicDataReadiness Core Components ---
67
 
68
+ # --- MDR Utilities ---
69
+
70
  def mdr_download_model(url: str, file_path: Path):
71
  """Downloads a model file from a URL to a local path."""
72
  try:
 
86
  if file_path.exists(): os.remove(file_path)
87
  raise e
88
 
89
+ # --- MDR Utilities ---
90
  def mdr_ensure_directory(path: str) -> str:
91
  """Ensures a directory exists, creating it if necessary."""
92
  path = os.path.abspath(path)
 
147
  return p1.intersection(p2).area
148
  except: return 0.0
149
 
150
+ # --- MDR Data Structures ---
151
  @dataclass
152
  class MDROcrFragment:
153
  """Represents a fragment of text identified by OCR."""
154
  order: int; text: str; rank: float; rect: MDRRectangle
155
+
156
  class MDRLayoutClass(Enum):
157
  """Enumeration of different layout types identified."""
158
  TITLE=0; PLAIN_TEXT=1; ABANDON=2; FIGURE=3; FIGURE_CAPTION=4; TABLE=5; TABLE_CAPTION=6; TABLE_FOOTNOTE=7; ISOLATE_FORMULA=8; FORMULA_CAPTION=9
159
+
160
  class MDRTableLayoutParsedFormat(Enum):
161
  """Enumeration for formats of parsed table content."""
162
  LATEX=auto(); MARKDOWN=auto(); HTML=auto()
163
+
164
  @dataclass
165
  class MDRBaseLayoutElement:
166
  """Base class for layout elements found on a page."""
167
  rect: MDRRectangle; fragments: list[MDROcrFragment]
168
+
169
  @dataclass
170
  class MDRPlainLayoutElement(MDRBaseLayoutElement):
171
  """Layout element for plain text, titles, captions, figures, etc."""
172
  cls: Literal[MDRLayoutClass.TITLE, MDRLayoutClass.PLAIN_TEXT, MDRLayoutClass.ABANDON, MDRLayoutClass.FIGURE, MDRLayoutClass.FIGURE_CAPTION, MDRLayoutClass.TABLE_CAPTION, MDRLayoutClass.TABLE_FOOTNOTE, MDRLayoutClass.FORMULA_CAPTION]
173
+
174
  @dataclass
175
  class MDRTableLayoutElement(MDRBaseLayoutElement):
176
  """Layout element specifically for tables."""
177
  parsed: tuple[str, MDRTableLayoutParsedFormat] | None; cls: Literal[MDRLayoutClass.TABLE] = MDRLayoutClass.TABLE
178
+
179
  @dataclass
180
  class MDRFormulaLayoutElement(MDRBaseLayoutElement):
181
  """Layout element specifically for formulas."""
182
  latex: str | None; cls: Literal[MDRLayoutClass.ISOLATE_FORMULA] = MDRLayoutClass.ISOLATE_FORMULA
183
+
184
  MDRLayoutElement = MDRPlainLayoutElement | MDRTableLayoutElement | MDRFormulaLayoutElement # Type alias
185
+
186
  @dataclass
187
  class MDRExtractionResult:
188
  """Holds the complete result of extracting from a single page image."""
189
  rotation: float; layouts: list[MDRLayoutElement]; extracted_image: Image; adjusted_image: Image | None
190
 
191
+ # --- MDR Data Structures ---
192
+
193
  MDRProgressReportCallback: TypeAlias = Callable[[int, int], None]
194
+
195
  class MDROcrLevel(Enum): Once=auto(); OncePerLayout=auto()
196
+
197
  class MDRExtractedTableFormat(Enum): LATEX=auto(); MARKDOWN=auto(); HTML=auto(); DISABLE=auto()
198
+
199
  class MDRTextKind(Enum): TITLE=0; PLAIN_TEXT=1; ABANDON=2
200
+
201
  @dataclass
202
  class MDRTextSpan:
203
  """Represents a span of text content within a block."""
204
  content: str; rank: float; rect: MDRRectangle
205
+
206
  @dataclass
207
  class MDRBasicBlock:
208
  """Base class for structured blocks extracted from the document."""
209
  rect: MDRRectangle; texts: list[MDRTextSpan]; font_size: float # Relative font size (0-1)
210
+
211
  @dataclass
212
  class MDRTextBlock(MDRBasicBlock):
213
  """A structured block containing text content."""
214
  kind: MDRTextKind; has_paragraph_indentation: bool = False; last_line_touch_end: bool = False
215
+
216
  class MDRTableFormat(Enum): LATEX=auto(); MARKDOWN=auto(); HTML=auto(); UNRECOGNIZABLE=auto()
217
+
218
  @dataclass
219
  class MDRTableBlock(MDRBasicBlock):
220
  """A structured block representing a table."""
221
  content: str; format: MDRTableFormat; image: Image # Image clip of the table
222
+
223
  @dataclass
224
  class MDRFormulaBlock(MDRBasicBlock):
225
  """A structured block representing a formula."""
226
  content: str | None; image: Image # Image clip of the formula
227
+
228
  @dataclass
229
  class MDRFigureBlock(MDRBasicBlock):
230
  """A structured block representing a figure/image."""
231
  image: Image # Image clip of the figure
232
+
233
  MDRAssetBlock = MDRTableBlock | MDRFormulaBlock | MDRFigureBlock # Type alias
234
+
235
  MDRStructuredBlock = MDRTextBlock | MDRAssetBlock # Type alias
236
 
237
+ # --- MDR Utilities ---
238
  def mdr_similarity_ratio(v1: float, v2: float) -> float:
239
  """Calculates the ratio of the smaller value to the larger value (0-1)."""
240
  if v1==0 and v2==0: return 1.0;
241
  if v1<0 or v2<0: return 0.0;
242
  v1, v2 = (v2, v1) if v1 > v2 else (v1, v2);
243
  return 1.0 if v2==0 else v1/v2
244
+
245
  def mdr_intersection_bounds_size(r1: MDRRectangle, r2: MDRRectangle) -> tuple[float, float]:
246
  """Calculates width/height of the intersection bounding box."""
247
  try:
 
251
  if inter.is_empty: return 0.0, 0.0;
252
  minx, miny, maxx, maxy = inter.bounds; return maxx-minx, maxy-miny
253
  except: return 0.0, 0.0
254
+
255
  _MDR_CJKA_PATTERN = re.compile(r"[\u4e00-\u9fff\u3040-\u309f\u30a0-\u30ff\uac00-\ud7a3\u0600-\u06ff]")
256
+
257
  def mdr_contains_cjka(text: str):
258
  """Checks if text contains Chinese, Japanese, Korean, or Arabic chars."""
259
  return bool(_MDR_CJKA_PATTERN.search(text)) if text else False
260
 
261
+ # --- MDR Text Processing ---
262
  class _MDR_TokenPhase(Enum): Init=0; Letter=1; Character=2; Number=3; Space=4
263
+
264
  _mdr_alphabet_detector = AlphabetDetector()
265
+
266
  def _mdr_is_letter(char: str):
267
  if not category(char).startswith("L"): return False
268
  try: return _mdr_alphabet_detector.is_latin(char) or _mdr_alphabet_detector.is_cyrillic(char) or _mdr_alphabet_detector.is_greek(char) or _mdr_alphabet_detector.is_hebrew(char)
269
  except: return False
270
+
271
  def mdr_split_into_words(text: str):
272
  """Splits text into words, numbers, and individual non-alphanumeric chars."""
273
  if not text: return;
 
287
  if is_s: phase=_MDR_TokenPhase.Space
288
  else: yield char; phase=_MDR_TokenPhase.Character
289
  if phase in (_MDR_TokenPhase.Letter, _MDR_TokenPhase.Number): w=buf.getvalue(); yield w if w else None
290
+
291
  def mdr_check_text_similarity(t1: str, t2: str) -> tuple[float, int]:
292
  """Calculates word-based similarity between two texts."""
293
  w1, w2 = list(mdr_split_into_words(t1)), list(mdr_split_into_words(t2)); l1, l2 = len(w1), len(w2)
 
300
  if not taken[i] and word1==word2: taken[i]=True; matches+=1; break
301
  mismatches = l2 - matches; return 1.0 - (mismatches/l2), l2
302
 
303
+ # --- MDR Geometry Processing ---
304
  class MDRRotationAdjuster:
305
  """Adjusts point coordinates based on image rotation."""
306
+
307
  def __init__(self, origin_size: tuple[int, int], new_size: tuple[int, int], rotation: float, to_origin_coordinate: bool):
308
  fs, ts = (new_size, origin_size) if to_origin_coordinate else (origin_size, new_size)
309
  self._rot = rotation if to_origin_coordinate else -rotation
310
  self._c_off = (fs[0]/2.0, fs[1]/2.0); self._n_off = (ts[0]/2.0, ts[1]/2.0)
311
+
312
  def adjust(self, point: MDRPoint) -> MDRPoint:
313
  x, y = point[0]-self._c_off[0], point[1]-self._c_off[1]
314
  if x!=0 or y!=0: cos_r, sin_r = cos(self._rot), sin(self._rot); x, y = x*cos_r-y*sin_r, x*sin_r+y*cos_r
315
  return x+self._n_off[0], y+self._n_off[1]
316
+
317
  def mdr_normalize_vertical_rotation(rot: float) -> float:
318
  while rot >= pi: rot -= pi;
319
  while rot < 0: rot += pi;
320
  return rot
321
+
322
  def _mdr_get_rectangle_angles(rect: MDRRectangle) -> tuple[list[float], list[float]] | None:
323
  h_angs, v_angs = [], []
324
  for i, (p1, p2) in enumerate(rect.segments):
 
330
  else: v_angs.append(ang)
331
  if not h_angs or not v_angs: return None
332
  return h_angs, v_angs
333
+
334
  def _mdr_normalize_horizontal_angles(rots: list[float]) -> list[float]: return rots
335
+
336
  def _mdr_find_median(data: list[float]) -> float:
337
  if not data: return 0.0; s_data = sorted(data); n = len(s_data);
338
  return s_data[n//2] if n%2==1 else (s_data[n//2-1]+s_data[n//2])/2.0
339
+
340
  def _mdr_find_mean(data: list[float]) -> float: return sum(data)/len(data) if data else 0.0
341
+
342
  def mdr_calculate_image_rotation(frags: list[MDROcrFragment]) -> float:
343
  all_h, all_v = [], [];
344
  for f in frags:
 
351
  while rot_est >= pi/2: rot_est -= pi;
352
  while rot_est < -pi/2: rot_est += pi;
353
  return rot_est
354
+
355
  def mdr_calculate_rectangle_rotation(rect: MDRRectangle) -> tuple[float, float]:
356
  res = _mdr_get_rectangle_angles(rect);
357
  if res is None: return 0.0, pi/2.0;
 
359
  h_rots = _mdr_normalize_horizontal_angles(h_rots); v_rots = [mdr_normalize_vertical_rotation(a) for a in v_rots]
360
  return _mdr_find_mean(h_rots), _mdr_find_mean(v_rots)
361
 
362
+ # --- MDR ONNX OCR Internals ---
363
  class _MDR_PredictBase:
364
  """Base class for ONNX model prediction components."""
365
+
366
  def get_onnx_session(self, model_path: str, use_gpu: bool):
367
  try:
368
  sess_opts = onnxruntime.SessionOptions(); sess_opts.log_severity_level = 3
 
375
  if use_gpu and 'CUDAExecutionProvider' not in onnxruntime.get_available_providers():
376
  print(" CUDAExecutionProvider not available. Check ONNXRuntime-GPU installation and CUDA setup.")
377
  raise e
378
+
379
  def get_output_name(self, sess: onnxruntime.InferenceSession) -> List[str]: return [n.name for n in sess.get_outputs()]
380
+
381
  def get_input_name(self, sess: onnxruntime.InferenceSession) -> List[str]: return [n.name for n in sess.get_inputs()]
382
+
383
  def get_input_feed(self, names: List[str], img_np: np.ndarray) -> Dict[str, np.ndarray]: return {name: img_np for name in names}
384
 
385
+ # --- MDR ONNX OCR Internals ---
386
  class _MDR_NormalizeImage:
387
+
388
  def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
389
  self.scale = np.float32(eval(scale) if isinstance(scale, str) else (scale if scale is not None else 1.0 / 255.0))
390
  mean = mean if mean is not None else [0.485, 0.456, 0.406]; std = std if std is not None else [0.229, 0.224, 0.225]
391
  shape = (3, 1, 1) if order == 'chw' else (1, 1, 3); self.mean = np.array(mean).reshape(shape).astype('float32'); self.std = np.array(std).reshape(shape).astype('float32')
392
+
393
  def __call__(self, data): img = data['image']; img = np.array(img) if isinstance(img, Image) else img; data['image'] = (img.astype('float32') * self.scale - self.mean) / self.std; return data
394
+
395
  class _MDR_DetResizeForTest:
396
+
397
  def __init__(self, **kwargs):
398
  self.resize_type = 0; self.keep_ratio = False
399
  if 'image_shape' in kwargs: self.image_shape = kwargs['image_shape']; self.resize_type = 1; self.keep_ratio = kwargs.get('keep_ratio', False)
400
  elif 'limit_side_len' in kwargs: self.limit_side_len = kwargs['limit_side_len']; self.limit_type = kwargs.get('limit_type', 'min')
401
  elif 'resize_long' in kwargs: self.resize_type = 2; self.resize_long = kwargs.get('resize_long', 960)
402
  else: self.limit_side_len = 736; self.limit_type = 'min'
403
+
404
  def __call__(self, data):
405
  img = data['image']; src_h, src_w, _ = img.shape
406
  if src_h + src_w < 64: img = self._pad(img)
 
409
  else: img, ratios = self._resize1(img)
410
  if img is None: return None
411
  data['image'] = img; data['shape'] = np.array([src_h, src_w, ratios[0], ratios[1]]); return data
412
+
413
  def _pad(self, im, v=0): h,w,c=im.shape; p=np.zeros((max(32,h),max(32,w),c),np.uint8)+v; p[:h,:w,:]=im; return p
414
+
415
  def _resize1(self, img): rh,rw=self.image_shape; oh,ow=img.shape[:2]; if self.keep_ratio: rw=ow*rh/oh; N=math.ceil(rw/32); rw=N*32; r_h,r_w=float(rh)/oh,float(rw)/ow; img=cv2.resize(img,(int(rw),int(rh))); return img,[r_h,r_w]
416
+
417
  def _resize0(self, img): lsl=self.limit_side_len; h,w,_=img.shape; r=1.0; if self.limit_type=='max': r=float(lsl)/max(h,w) if max(h,w)>lsl else 1.0; elif self.limit_type=='min': r=float(lsl)/min(h,w) if min(h,w)<lsl else 1.0; elif self.limit_type=='resize_long': r=float(lsl)/max(h,w); else: raise Exception('Unsupported'); rh,rw=int(h*r),int(w*r); rh=max(int(round(rh/32)*32),32); rw=max(int(round(rw/32)*32),32); if int(rw)<=0 or int(rh)<=0: return None,(None,None); img=cv2.resize(img,(int(rw),int(rh))); r_h,r_w=rh/float(h),rw/float(w); return img,[r_h,r_w]
418
+
419
  def _resize2(self, img): h,w,_=img.shape; rl=self.resize_long; r=float(rl)/max(h,w); rh,rw=int(h*r),int(w*r); ms=128; rh=(rh+ms-1)//ms*ms; rw=(rw+ms-1)//ms*ms; img=cv2.resize(img,(int(rw),int(rh))); r_h,r_w=rh/float(h),rw/float(w); return img,[r_h,r_w]
420
+
421
  class _MDR_ToCHWImage:
422
+
423
  def __call__(self, data): img=data['image']; img=np.array(img) if isinstance(img,Image) else img; data['image']=img.transpose((2,0,1)); return data
424
+
425
  class _MDR_KeepKeys:
426
+
427
  def __init__(self, keep_keys, **kwargs): self.keep_keys=keep_keys
428
+
429
  def __call__(self, data): return [data[key] for key in self.keep_keys]
430
 
 
431
  def mdr_ocr_transform(data, ops=None):
432
  ops = ops if ops is not None else [];
433
  for op in ops: data = op(data); if data is None: return None;
434
  return data
435
+
436
  def mdr_ocr_create_operators(op_param_list, global_config=None):
437
  ops = []
438
  for operator in op_param_list:
 
444
  else: raise ValueError(f"Operator class '{op_class_name}' not found.")
445
  return ops
446
 
 
447
  class _MDR_DBPostProcess:
448
+
449
  def __init__(self, thresh=0.3, box_thresh=0.7, max_candidates=1000, unclip_ratio=1.5, use_dilation=False, score_mode="fast", box_type='quad', **kwargs):
450
  self.thresh, self.box_thresh, self.max_cand = thresh, box_thresh, max_candidates; self.unclip_r, self.min_sz, self.score_m, self.box_t = unclip_ratio, 3, score_mode, box_type
451
  assert score_mode in ["slow", "fast"]; self.dila_k = np.array([[1,1],[1,1]], dtype=np.uint8) if use_dilation else None
452
+
453
  def _polygons_from_bitmap(self, pred, bmp, dw, dh):
454
  h, w = bmp.shape; boxes, scores = [], []
455
  contours, _ = cv2.findContours((bmp*255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
 
466
  box = np.array(box); box[:,0]=np.clip(np.round(box[:,0]/w*dw),0,dw); box[:,1]=np.clip(np.round(box[:,1]/h*dh),0,dh)
467
  boxes.append(box.tolist()); scores.append(score)
468
  return boxes, scores
469
+
470
  def _boxes_from_bitmap(self, pred, bmp, dw, dh):
471
  h, w = bmp.shape; contours, _ = cv2.findContours((bmp*255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
472
  num_contours = min(len(contours), self.max_cand); boxes, scores = [], []
 
482
  box = np.array(box); box[:,0]=np.clip(np.round(box[:,0]/w*dw),0,dw); box[:,1]=np.clip(np.round(box[:,1]/h*dh),0,dh)
483
  boxes.append(box.astype("int32")); scores.append(score)
484
  return np.array(boxes, dtype="int32"), scores
485
+
486
  def _unclip(self, box, ratio):
487
  poly = Polygon(box); dist = poly.area*ratio/poly.length; offset = pyclipper.PyclipperOffset(); offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
488
  expanded = offset.Execute(dist);
489
  if not expanded: raise ValueError("Unclip failed"); return np.array(expanded[0])
490
+
491
  def _get_mini_boxes(self, contour):
492
  bb = cv2.minAreaRect(contour); pts = sorted(list(cv2.boxPoints(bb)), key=lambda x:x[0])
493
  i1,i4 = (0,1) if pts[1][1]>pts[0][1] else (1,0); i2,i3 = (2,3) if pts[3][1]>pts[2][1] else (3,2)
494
  box = [pts[i1], pts[i2], pts[i3], pts[i4]]; return box, min(bb[1])
495
+
496
  def _box_score_fast(self, bmp, box):
497
  h,w = bmp.shape[:2]; xmin=np.clip(np.floor(box[:,0].min()).astype("int32"),0,w-1); xmax=np.clip(np.ceil(box[:,0].max()).astype("int32"),0,w-1)
498
  ymin=np.clip(np.floor(box[:,1].min()).astype("int32"),0,h-1); ymax=np.clip(np.ceil(box[:,1].max()).astype("int32"),0,h-1)
499
  mask = np.zeros((ymax-ymin+1, xmax-xmin+1), dtype=np.uint8); box[:,0]-=xmin; box[:,1]-=ymin
500
  cv2.fillPoly(mask, box.reshape(1,-1,2).astype("int32"), 1);
501
  return cv2.mean(bmp[ymin:ymax+1, xmin:xmax+1], mask)[0] if np.sum(mask)>0 else 0.0
502
+
503
  def _box_score_slow(self, bmp, contour): # Not used if fast
504
  h,w = bmp.shape[:2]; contour = np.reshape(contour.copy(),(-1,2)); xmin=np.clip(np.min(contour[:,0]),0,w-1); xmax=np.clip(np.max(contour[:,0]),0,w-1)
505
  ymin=np.clip(np.min(contour[:,1]),0,h-1); ymax=np.clip(np.max(contour[:,1]),0,h-1); mask=np.zeros((ymax-ymin+1,xmax-xmin+1),dtype=np.uint8)
506
  contour[:,0]-=xmin; contour[:,1]-=ymin; cv2.fillPoly(mask, contour.reshape(1,-1,2).astype("int32"), 1);
507
  return cv2.mean(bmp[ymin:ymax+1, xmin:xmax+1], mask)[0] if np.sum(mask)>0 else 0.0
508
+
509
  def __call__(self, outs_dict, shape_list):
510
  pred = outs_dict['maps'][:,0,:,:]; seg = pred > self.thresh; boxes_batch = []
511
  for batch_idx in range(pred.shape[0]):
 
516
  boxes_batch.append({'points': boxes})
517
  return boxes_batch
518
 
 
519
  class _MDR_TextDetector(_MDR_PredictBase):
520
+
521
  def __init__(self, args):
522
  super().__init__(); self.args = args
523
  pre_ops = [{'DetResizeForTest': {'limit_side_len': args.det_limit_side_len, 'limit_type': args.det_limit_type}}, {'NormalizeImage': {'std': [0.229,0.224,0.225], 'mean': [0.485,0.456,0.406], 'scale': '1./255.', 'order': 'hwc'}}, {'ToCHWImage': None}, {'KeepKeys': {'keep_keys': ['image', 'shape']}}]
 
526
  self.post_op = _MDR_DBPostProcess(**post_params)
527
  self.sess = self.get_onnx_session(args.det_model_dir, args.use_gpu)
528
  self.input_name = self.get_input_name(self.sess); self.output_name = self.get_output_name(self.sess)
529
+
530
  def _order_pts(self, pts): r=np.zeros((4,2),dtype="float32"); s=pts.sum(axis=1); r[0]=pts[np.argmin(s)]; r[2]=pts[np.argmax(s)]; tmp=np.delete(pts,(np.argmin(s),np.argmax(s)),axis=0); d=np.diff(np.array(tmp),axis=1); r[1]=tmp[np.argmin(d)]; r[3]=tmp[np.argmax(d)]; return r
531
+
532
  def _clip_pts(self, pts, h, w): pts[:,0]=np.clip(pts[:,0],0,w-1); pts[:,1]=np.clip(pts[:,1],0,h-1); return pts
533
+
534
  def _filter_quad(self, boxes, shape): h,w=shape[0:2]; new_boxes=[]; for box in boxes: box=np.array(box) if isinstance(box,list) else box; box=self._order_pts(box); box=self._clip_pts(box,h,w); rw=int(np.linalg.norm(box[0]-box[1])); rh=int(np.linalg.norm(box[0]-box[3])); if rw<=3 or rh<=3: continue; new_boxes.append(box); return np.array(new_boxes)
535
+
536
  def _filter_poly(self, boxes, shape): h,w=shape[0:2]; new_boxes=[]; for box in boxes: box=np.array(box) if isinstance(box,list) else box; box=self._clip_pts(box,h,w); if Polygon(box).area<10: continue; new_boxes.append(box); return np.array(new_boxes)
537
+
538
  def __call__(self, img):
539
  ori_im = img.copy(); data = {"image": img}; data = mdr_ocr_transform(data, self.pre_op)
540
  if data is None: return None; img, shape_list = data;
 
543
  preds = {"maps": outputs[0]}; post_res = self.post_op(preds, shape_list); boxes = post_res[0]['points']
544
  return self._filter_poly(boxes, ori_im.shape) if self.args.det_box_type=='poly' else self._filter_quad(boxes, ori_im.shape)
545
 
 
546
  class _MDR_ClsPostProcess:
547
+
548
  def __init__(self, label_list=None, **kwargs): self.labels = label_list if label_list else {0:'0', 1:'180'}
549
+
550
  def __call__(self, preds, label=None, *args, **kwargs):
551
  preds = np.array(preds) if not isinstance(preds, np.ndarray) else preds; idxs = preds.argmax(axis=1)
552
  return [(self.labels[idx], float(preds[i,idx])) for i,idx in enumerate(idxs)]
553
 
 
554
  class _MDR_TextClassifier(_MDR_PredictBase):
555
+
556
  def __init__(self, args):
557
  super().__init__(); self.shape = tuple(map(int, args.cls_image_shape.split(','))) if isinstance(args.cls_image_shape, str) else args.cls_image_shape
558
  self.batch_num = args.cls_batch_num; self.thresh = args.cls_thresh; self.post_op = _MDR_ClsPostProcess(label_list=args.label_list)
559
  self.sess = self.get_onnx_session(args.cls_model_dir, args.use_gpu); self.input_name = self.get_input_name(self.sess); self.output_name = self.get_output_name(self.sess)
560
+
561
  def _resize_norm(self, img):
562
  imgC,imgH,imgW = self.shape; h,w = img.shape[:2]; r=w/float(h) if h>0 else 0; rw=int(math.ceil(imgH*r)); rw=min(rw,imgW)
563
  resized = cv2.resize(img,(rw,imgH)); resized = resized.astype("float32")
564
  if imgC==1: resized = resized/255.0; resized = resized[np.newaxis,:]
565
  else: resized = resized.transpose((2,0,1))/255.0
566
  resized -= 0.5; resized /= 0.5; padding = np.zeros((imgC,imgH,imgW),dtype=np.float32); padding[:,:,0:rw]=resized; return padding
567
+
568
  def __call__(self, img_list):
569
  if not img_list: return img_list, []; img_list_cp = copy.deepcopy(img_list); num = len(img_list_cp)
570
  ratios = [img.shape[1]/float(img.shape[0]) if img.shape[0]>0 else 0 for img in img_list_cp]; indices = np.argsort(np.array(ratios))
 
580
  if "180" in label and score > self.thresh: img_list[orig_idx] = cv2.rotate(img_list[orig_idx], cv2.ROTATE_180)
581
  return img_list, results
582
 
 
583
  class _MDR_BaseRecLabelDecode:
584
+
585
  def __init__(self, char_path=None, use_space=False):
586
  self.beg, self.end, self.rev = "sos", "eos", False; self.chars = []
587
  if char_path is None: self.chars = list("0123456789abcdefghijklmnopqrstuvwxyz")
 
592
  if any("\u0600"<=c<="\u06FF" for c in self.chars): self.rev=True
593
  except FileNotFoundError: print(f"Warn: Dict not found {char_path}"); self.chars=list("0123456789abcdefghijklmnopqrstuvwxyz"); if use_space: self.chars.append(" ")
594
  d_char = self.add_special_char(list(self.chars)); self.dict={c:i for i,c in enumerate(d_char)}; self.character=d_char
595
+
596
  def add_special_char(self, chars): return chars
597
+
598
  def get_ignored_tokens(self): return []
599
+
600
  def _reverse(self, pred): res=[]; cur=""; for c in pred: if not re.search("[a-zA-Z0-9 :*./%+-]",c): res.extend([cur,c] if cur!="" else [c]); cur="" else: cur+=c; if cur!="": res.append(cur); return "".join(res[::-1])
601
+
602
  def decode(self, idxs, probs=None, remove_dup=False):
603
  res=[]; ignored=self.get_ignored_tokens(); bs=len(idxs)
604
  for b_idx in range(bs):
 
612
  if self.rev: txt=self._reverse(txt)
613
  res.append((txt, float(np.mean(conf_l))))
614
  return res
615
+
616
  class _MDR_CTCLabelDecode(_MDR_BaseRecLabelDecode):
617
  def __init__(self, char_path=None, use_space=False, **kwargs): super().__init__(char_path, use_space)
618
  def add_special_char(self, chars): return ["blank"]+chars
 
621
  preds = preds[-1] if isinstance(preds,(tuple,list)) else preds; preds = np.array(preds) if not isinstance(preds,np.ndarray) else preds
622
  idxs=preds.argmax(axis=2); probs=preds.max(axis=2); txt=self.decode(idxs, probs, remove_dup=True); return txt
623
 
 
624
  class _MDR_TextRecognizer(_MDR_PredictBase):
625
+
626
  def __init__(self, args):
627
  super().__init__(); shape_str=getattr(args,'rec_image_shape',"3,48,320"); self.shape=tuple(map(int,shape_str.split(',')))
628
  self.batch_num=getattr(args,'rec_batch_num',6); self.algo=getattr(args,'rec_algorithm','SVTR_LCNet')
629
  self.post_op=_MDR_CTCLabelDecode(char_path=args.rec_char_dict_path, use_space=getattr(args,'use_space_char',True))
630
  self.sess=self.get_onnx_session(args.rec_model_dir, args.use_gpu); self.input_name=self.get_input_name(self.sess); self.output_name=self.get_output_name(self.sess)
631
+
632
  def _resize_norm(self, img, max_r):
633
  imgC,imgH,imgW = self.shape; h,w = img.shape[:2];
634
  if h==0 or w==0: return np.zeros((imgC,imgH,imgW),dtype=np.float32)
 
638
  if len(resized.shape)==2: resized=resized[:,:,np.newaxis]
639
  resized=resized.transpose((2,0,1))/255.0; resized-=0.5; resized/=0.5
640
  padding=np.zeros((imgC,imgH,imgW),dtype=np.float32); padding[:,:,0:tw]=resized; return padding
641
+
642
  def __call__(self, img_list):
643
  if not img_list: return []; num=len(img_list); ratios=[img.shape[1]/float(img.shape[0]) if img.shape[0]>0 else 0 for img in img_list]
644
  indices=np.argsort(np.array(ratios)); results=[["",0.0]]*num; batch_n=self.batch_num
 
652
  for i in range(len(rec_out)): results[indices[start+i]]=rec_out[i]
653
  return results
654
 
655
+ # --- MDR ONNX OCR System ---
656
  class _MDR_TextSystem:
657
+
658
  def __init__(self, args):
659
  class ArgsObject: # Helper to access dict args with dot notation
660
  def __init__(self, **entries): self.__dict__.update(entries)
 
666
  self.drop_score = getattr(args, 'drop_score', 0.5)
667
  self.classifier = _MDR_TextClassifier(args) if self.use_cls else None
668
  self.crop_idx = 0; self.save_crop = getattr(args, 'save_crop_res', False); self.crop_dir = getattr(args, 'crop_res_save_dir', "./output/mdr_crop_res")
669
+
670
  def _sort_boxes(self, boxes):
671
  if boxes is None or len(boxes)==0: return []
672
  def key(box): min_y=min(p[1] for p in box); min_x=min(p[0] for p in box); return (min_y, min_x)
673
  try: return list(sorted(boxes, key=key))
674
  except: return list(boxes) # Fallback
675
+
676
  def __call__(self, img, classify=True):
677
  ori_im = img.copy(); boxes = self.detector(img)
678
  if boxes is None or len(boxes)==0: return [], []
 
694
  if score >= self.drop_score: final_boxes.append(box); final_rec.append(res)
695
  if self.save_crop: self._save_crops(crops, rec_res)
696
  return final_boxes, final_rec
697
+
698
  def _save_crops(self, crops, recs):
699
  mdr_ensure_directory(self.crop_dir); num = len(crops)
700
  for i in range(num): txt, score = recs[i]; safe=re.sub(r'\W+', '_', txt)[:20]; fname=f"crop_{self.crop_idx+i}_{safe}_{score:.2f}.jpg"; cv2.imwrite(os.path.join(self.crop_dir, fname), crops[i])
701
  self.crop_idx += num
702
 
703
+ # --- MDR ONNX OCR Utilities ---
704
  def mdr_get_rotated_crop(img, points):
705
  """Crops and perspective-transforms a quadrilateral region."""
706
  pts = np.array(points, dtype="float32"); assert len(pts)==4
 
712
  dh, dw = dst.shape[0:2]
713
  if dh>0 and dw>0 and dh*1.0/dw >= 1.5: dst = cv2.rotate(dst, cv2.ROTATE_90_CLOCKWISE)
714
  return dst
715
+
716
  def mdr_get_min_area_crop(img, points):
717
  """Crops the minimum area rectangle containing the points."""
718
  bb = cv2.minAreaRect(np.array(points).astype(np.int32)); box_pts = cv2.boxPoints(bb)
719
  return mdr_get_rotated_crop(img, box_pts)
720
 
721
+ # --- MDR Layout Processing ---
722
  _MDR_INCLUDES_MIN_RATE = 0.99
723
+
724
  class _MDR_OverlapMatrixContext:
725
+
726
  def __init__(self, layouts: list[MDRLayoutElement]):
727
  length = len(layouts); self.polys: list[Polygon|None] = []
728
  for l in layouts:
 
736
  p2 = self.polys[j];
737
  if p2 is None: continue
738
  r_ij = self._rate(p1, p2); r_ji = self._rate(p2, p1); self.matrix[i][j]=r_ij; self.matrix[j][i]=r_ji
739
+
740
  def _rate(self, p1: Polygon, p2: Polygon) -> float: # Rate p1 covers p2
741
  try: inter = p1.intersection(p2);
742
  except: return 0.0
 
745
  _, _, px1, py1 = p2.bounds; pw, ph = px1-p2.bounds[0], py1-p2.bounds[1]
746
  if pw < 1e-6 or ph < 1e-6: return 0.0
747
  wr = min(iw/pw, 1.0); hr = min(ih/ph, 1.0); return (wr+hr)/2.0
748
+
749
  def others(self, idx: int):
750
  for i, r in enumerate(self.matrix[idx]):
751
  if i != idx and i not in self.removed: yield r
752
+
753
  def includes(self, idx: int): # Layouts included BY idx
754
  for i, r in enumerate(self.matrix[idx]):
755
  if i != idx and i not in self.removed and r >= _MDR_INCLUDES_MIN_RATE:
 
811
  for i, f in enumerate(merged): f.order = i
812
  return merged
813
 
814
+ # --- MDR Layout Processing ---
815
  _MDR_CORRECTION_MIN_OVERLAP = 0.5
816
+
817
  def mdr_correct_layout_fragments(ocr_engine: 'MDROcrEngine', source_img: Image, layout: MDRLayoutElement):
818
  if not layout.fragments: return;
819
  try:
 
848
  final = [n if n.rank >= o.rank else o for o, n in matched]; final.extend(unmatched_orig); final.extend(unmatched_new)
849
  layout.fragments = final; layout.fragments.sort(key=lambda f: (f.rect.lt[1], f.rect.lt[0]))
850
 
851
+ # --- MDR OCR Engine ---
852
+
853
  _MDR_OCR_MODELS = {"det": ("ppocrv4","det","det.onnx"), "cls": ("ppocrv4","cls","cls.onnx"), "rec": ("ppocrv4","rec","rec.onnx"), "keys": ("ch_ppocr_server_v2.0","ppocr_keys_v1.txt")}
854
+
855
  _MDR_OCR_URL_BASE = "https://huggingface.co/moskize/OnnxOCR/resolve/main/"
856
+
857
  @dataclass
858
  class _MDR_ONNXParams: # Simplified container
859
  use_gpu: bool; det_model_dir: str; cls_model_dir: str; rec_model_dir: str; rec_char_dict_path: str
 
864
 
865
  class MDROcrEngine:
866
  """Handles OCR detection and recognition using ONNX models."""
867
+
868
  def __init__(self, device: Literal["cpu", "cuda"], model_dir_path: str):
869
  self._device = device; self._model_dir = mdr_ensure_directory(model_dir_path)
870
  self._text_system: _MDR_TextSystem | None = None; self._onnx_params: _MDR_ONNXParams | None = None
871
  self._ensure_models(); self._get_system() # Init on creation
872
+
873
  def _ensure_models(self):
874
  for key, parts in _MDR_OCR_MODELS.items():
875
  fp = Path(self._model_dir) / Path(*parts)
876
  if not fp.exists(): print(f"Downloading MDR OCR model: {fp.name}..."); url = _MDR_OCR_URL_BASE + "/".join(parts); mdr_download_model(url, fp)
877
+
878
  def _get_system(self) -> _MDR_TextSystem | None:
879
  if self._text_system is None:
880
  paths = {k: str(Path(self._model_dir)/Path(*p)) for k,p in _MDR_OCR_MODELS.items()}
 
882
  try: self._text_system = _MDR_TextSystem(self._onnx_params); print(f"MDR OCR System initialized.")
883
  except Exception as e: print(f"ERROR initializing MDR OCR System: {e}"); self._text_system = None
884
  return self._text_system
885
+
886
  def find_text_fragments(self, image_np: np.ndarray) -> Generator[MDROcrFragment, None, None]:
887
  """Finds and recognizes text fragments in a NumPy image (BGR)."""
888
  system = self._get_system()
 
895
  if not txt or mdr_is_whitespace(txt) or conf < 0.1: continue
896
  pts = [(float(p[0]), float(p[1])) for p in box_pts]
897
  if len(pts)==4: r=MDRRectangle(lt=pts[0], rt=pts[1], rb=pts[2], lb=pts[3]); if r.is_valid and r.area>1: yield MDROcrFragment(order=-1, text=txt, rank=float(conf), rect=r)
898
+
899
  def _preprocess(self, img: np.ndarray) -> np.ndarray:
900
  if len(img.shape)==3 and img.shape[2]==4: a=img[:,:,3]/255.0; bg=(255,255,255); new=np.zeros_like(img[:,:,:3]); [setattr(new[:,:,i], 'flags.writeable', True) for i in range(3)]; [np.copyto(new[:,:,i], (bg[i]*(1-a)+img[:,:,i]*a)) for i in range(3)]; img=new.astype(np.uint8)
901
  elif len(img.shape)==2: img=cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
902
  elif not (len(img.shape)==3 and img.shape[2]==3): raise ValueError("Unsupported image format")
903
  return img
904
 
905
+ # --- MDR Layout Reading Internals ---
906
  _MDR_MAX_LEN = 510; _MDR_CLS_ID = 0; _MDR_SEP_ID = 2; _MDR_PAD_ID = 1
907
+
908
  def mdr_boxes_to_reader_inputs(boxes: List[List[int]], max_len=_MDR_MAX_LEN) -> Dict[str, torch.Tensor]:
909
  t_boxes = boxes[:max_len]; i_boxes = [[0,0,0,0]] + t_boxes + [[0,0,0,0]]
910
  i_ids = [_MDR_CLS_ID] + [_MDR_PAD_ID]*len(t_boxes) + [_MDR_SEP_ID]
911
  a_mask = [1]*len(i_ids); pad_len = (max_len+2) - len(i_ids)
912
  if pad_len > 0: i_boxes.extend([[0,0,0,0]]*pad_len); i_ids.extend([_MDR_PAD_ID]*pad_len); a_mask.extend([0]*pad_len)
913
  return {"bbox": torch.tensor([i_boxes]), "input_ids": torch.tensor([i_ids]), "attention_mask": torch.tensor([a_mask])}
914
+
915
  def mdr_prepare_reader_inputs(inputs: Dict[str, torch.Tensor], model: LayoutLMv3ForTokenClassification) -> Dict[str, torch.Tensor]:
916
  return {k: v.to(model.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
917
+
918
  def mdr_parse_reader_logits(logits: torch.Tensor, length: int) -> List[int]:
919
  if length == 0: return []; rel_logits = logits[1:length+1, :length]; orders = rel_logits.argmax(dim=1).tolist()
920
  while True:
 
930
  orders[idx] = rel_logits[idx, :].argmax().item(); rel_logits[idx, order] = orig_logit
931
  return orders
932
 
933
+ # --- MDR Layout Reading Engine ---
934
  @dataclass
935
  class _MDR_ReaderBBox: layout_index: int; fragment_index: int; virtual: bool; order: int; value: tuple[float, float, float, float]
936
+
937
  class MDRLayoutReader:
938
  """Determines reading order of layout elements using LayoutLMv3."""
939
+
940
  def __init__(self, model_path: str):
941
  self._model_path = model_path; self._model: LayoutLMv3ForTokenClassification | None = None
942
  self._device = "cuda" if torch.cuda.is_available() else "cpu"
943
+
944
  def _get_model(self) -> LayoutLMv3ForTokenClassification | None:
945
  if self._model is None:
946
  cache = mdr_ensure_directory(self._model_path); name = "microsoft/layoutlmv3-base"; h_path = os.path.join(cache, "models--hantian--layoutreader")
 
950
  self._model.to(self._device); self._model.eval(); print(f"MDR LayoutReader loaded on {self._device}.")
951
  except Exception as e: print(f"ERROR loading MDR LayoutReader: {e}"); self._model = None
952
  return self._model
953
+
954
  def determine_reading_order(self, layouts: list[MDRLayoutElement], size: tuple[int, int]) -> list[MDRLayoutElement]:
955
  w, h = size;
956
  if w<=0 or h<=0 or not layouts: return layouts;
 
977
  if len(orders) != len(bbox_list): print("MDR LayoutReader order mismatch"); return layouts # Fallback
978
  for i, order_idx in enumerate(orders): bbox_list[i].order = order_idx
979
  return self._apply_order(layouts, bbox_list)
980
+
981
  def _prepare_bboxes(self, layouts: list[MDRLayoutElement], w: int, h: int) -> list[_MDR_ReaderBBox] | None:
982
  line_h = self._estimate_line_h(layouts); bbox_list = []
983
  for i, l in enumerate(layouts):
 
985
  else: bbox_list.extend(self._gen_virtual(l, i, line_h, w, h))
986
  if len(bbox_list) > _MDR_MAX_LEN: print(f"Too many boxes ({len(bbox_list)}>{_MDR_MAX_LEN})"); return None
987
  bbox_list.sort(key=lambda b: (b.value[1], b.value[0])); return bbox_list
988
+
989
  def _apply_order(self, layouts: list[MDRLayoutElement], bbox_list: list[_MDR_ReaderBBox]) -> list[MDRLayoutElement]:
990
  layout_map = defaultdict(list); [layout_map[b.layout_index].append(b) for b in bbox_list]
991
  layout_orders = [(idx, self._median([b.order for b in bboxes])) for idx, bboxes in layout_map.items() if bboxes]
 
999
  else: frags.sort(key=lambda f: (f.rect.lt[1], f.rect.lt[0]))
1000
  for frag in frags: frag.order = nfo; nfo += 1
1001
  return sorted_layouts
1002
+
1003
  def _estimate_line_h(self, layouts: list[MDRLayoutElement]) -> float:
1004
  heights = [f.rect.size[1] for l in layouts for f in l.fragments if f.rect.size[1]>0]
1005
  return self._median(heights) if heights else 15.0
1006
+
1007
  def _gen_virtual(self, l: MDRLayoutElement, l_idx: int, line_h: float, pw: int, ph: int) -> Generator[_MDR_ReaderBBox, None, None]:
1008
  x0,y0,x1,y1 = l.rect.wrapper; lh,lw = y1-y0,x1-x0
1009
  if lh<=0 or lw<=0 or line_h<=0: yield _MDR_ReaderBBox(l_idx,-1,True,-1,(x0,y0,x1,y1)); return
 
1018
  ly0,ly1 = max(0,min(ph,cur_y)), max(0,min(ph,cur_y+act_line_h)); lx0,lx1 = max(0,min(pw,x0)), max(0,min(pw,x1))
1019
  if ly1>ly0 and lx1>lx0: yield _MDR_ReaderBBox(l_idx,-1,True,-1,(lx0,ly0,lx1,ly1))
1020
  cur_y += act_line_h
1021
+
1022
  def _median(self, nums: list[float|int]) -> float:
1023
  if not nums: return 0.0; s_nums = sorted(nums); n = len(s_nums)
1024
  return float(s_nums[n//2]) if n%2==1 else float((s_nums[n//2-1]+s_nums[n//2])/2.0)
1025
 
1026
+ # --- MDR LaTeX Extractor ---
1027
  class MDRLatexExtractor:
1028
  """Extracts LaTeX from formula images using pix2tex."""
1029
+
1030
  def __init__(self, model_path: str):
1031
  self._model_path = model_path; self._model: LatexOCR | None = None
1032
  self._device = "cuda" if torch.cuda.is_available() else "cpu"
1033
+
1034
  def extract(self, image: Image) -> str | None:
1035
  if LatexOCR is None: return None;
1036
  image = mdr_expand_image(image, 0.1); model = self._get_model()
 
1038
  try:
1039
  with torch.no_grad(): img_rgb = image.convert('RGB') if image.mode!='RGB' else image; latex = model(img_rgb); return latex if latex else None
1040
  except Exception as e: print(f"MDR LaTeX error: {e}"); return None
1041
+
1042
  def _get_model(self) -> LatexOCR | None:
1043
  if self._model is None and LatexOCR is not None:
1044
  mdr_ensure_directory(self._model_path); wp, rp, cp = Path(self._model_path)/"weights.pth", Path(self._model_path)/"image_resizer.pth", Path(self._model_path)/"config.yaml"
 
1047
  try: args = Munch({"config":str(cp), "checkpoint":str(wp), "device":self._device, "no_cuda":self._device=="cpu", "no_resize":False, "temperature":0.0}); self._model = LatexOCR(args); print(f"MDR LaTeX loaded on {self._device}.")
1048
  except Exception as e: print(f"ERROR initializing MDR LatexOCR: {e}"); self._model = None
1049
  return self._model
1050
+
1051
  def _download(self):
1052
  tag = "v0.0.1"; base = f"https://github.com/lukas-blecher/LaTeX-OCR/releases/download/{tag}/"; files = {"weights.pth": base+"weights.pth", "image_resizer.pth": base+"image_resizer.pth"}
1053
  mdr_ensure_directory(self._model_path); [mdr_download_model(url, Path(self._model_path)/name) for name, url in files.items() if not (Path(self._model_path)/name).exists()]
1054
 
1055
+ # --- MDR Table Parser ---
1056
  # Format names that MDRTableParser.parse_table_image forwards to the table
  # model through its `output_format` argument.
  MDRTableOutputFormat = Literal["latex", "markdown", "html"]
1057
+
1058
  class MDRTableParser:
1059
  """Parses table structure/content from images using StructTable model."""
1060
+
1061
  def __init__(self, device: Literal["cpu", "cuda"], model_path: str):
1062
  self._model: Any | None = None; self._model_path = mdr_ensure_directory(model_path)
1063
  self._device = device if torch.cuda.is_available() and device=="cuda" else "cpu"
1064
  self._disabled = self._device == "cpu"
1065
  if self._disabled: print("Warning: MDR Table parsing requires CUDA. Disabled.")
1066
+
1067
  def parse_table_image(self, image: Image, format: MDRTableLayoutParsedFormat) -> str | None:
1068
  if self._disabled: return None;
1069
  fmt: MDRTableOutputFormat | None = None
 
1078
  with torch.no_grad(): results = model([img_rgb], output_format=fmt)
1079
  return results[0] if results else None
1080
  except Exception as e: print(f"MDR Table parsing error: {e}"); return None
1081
+
1082
  def _get_model(self):
1083
  if self._model is None and not self._disabled:
1084
  try:
 
1091
  except Exception as e: print(f"ERROR loading MDR StructTable: {e}"); self._model=None
1092
  return self._model
1093
 
1094
+ # --- MDR Image Optimizer ---
1095
  # Small-rotation threshold for the image optimizer.
  # NOTE(review): the comparison that uses it is not visible in this excerpt —
  # presumably rotations below this value (radians?) are ignored; confirm.
  _MDR_TINY_ROTATION = 0.005
1096
+
1097
  @dataclass
1098
  class _MDR_RotationContext: to_origin: MDRRotationAdjuster; to_new: MDRRotationAdjuster; fragment_origin_rectangles: list[MDRRectangle]
1099
+
1100
  class MDRImageOptimizer:
1101
  """Handles image rotation detection and coordinate adjustments."""
1102
+
1103
  def __init__(self, raw_image: Image, adjust_points: bool):
1104
  self._raw = raw_image; self._image = raw_image; self._adjust_points = adjust_points
1105
  self._fragments: list[MDROcrFragment] = []; self._rotation: float = 0.0; self._rot_ctx: _MDR_RotationContext | None = None
1106
+
1107
  @property
1108
  def image(self) -> Image: return self._image
1109
+
1110
  @property
1111
  def adjusted_image(self) -> Image | None: return self._image if self._rot_ctx is not None else None
1112
+
1113
  @property
1114
  def rotation(self) -> float: return self._rotation
1115
+
1116
  @property
1117
  def image_np(self) -> np.ndarray: img_rgb = np.array(self._raw.convert("RGB")); return cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
1118
+
1119
  def receive_fragments(self, fragments: list[MDROcrFragment]):
1120
  self._fragments = fragments;
1121
  if not fragments: return;
 
1130
  to_new=MDRRotationAdjuster(orig_sz, new_sz, self._rotation, False),
1131
  to_origin=MDRRotationAdjuster(orig_sz, new_sz, self._rotation, True))
1132
  adj = self._rot_ctx.to_new; [setattr(f, 'rect', MDRRectangle(lt=adj.adjust(r.lt), rt=adj.adjust(r.rt), lb=adj.adjust(r.lb), rb=adj.adjust(r.rb))) for f in fragments if (r:=f.rect)]
1133
+
1134
  def finalize_layout_coords(self, layouts: list[MDRLayoutElement]):
1135
  if self._rot_ctx is None or self._adjust_points: return
1136
  if len(self._fragments) == len(self._rot_ctx.fragment_origin_rectangles): [setattr(f, 'rect', orig_r) for f, orig_r in zip(self._fragments, self._rot_ctx.fragment_origin_rectangles)]
1137
  adj = self._rot_ctx.to_origin; [setattr(l, 'rect', MDRRectangle(lt=adj.adjust(r.lt), rt=adj.adjust(r.rt), lb=adj.adjust(r.lb), rb=adj.adjust(r.rb))) for l in layouts if (r:=l.rect)]
1138
 
1139
+ # --- MDR Image Clipping ---
1140
  def mdr_clip_from_image(image: Image, rect: MDRRectangle, wrap_w: float = 0.0, wrap_h: float = 0.0) -> Image:
1141
  """Clips a potentially rotated rectangle from an image."""
1142
  try:
 
1152
  out_w, out_h = ceil(avg_w+wrap_w), ceil(avg_h+wrap_h)
1153
  return image.transform((out_w, out_h), PILTransform.AFFINE, p_mat, PILResampling.BICUBIC, fillcolor=(255,255,255))
1154
  except Exception as e: print(f"MDR Clipping error: {e}"); return new_image("RGB", (10,10), (255,255,255))
1155
+
1156
  def mdr_clip_layout(res: MDRExtractionResult, layout: MDRLayoutElement, wrap_w: float = 0.0, wrap_h: float = 0.0) -> Image:
1157
  """Clips a layout region from the MDRExtractionResult image."""
1158
  img = res.adjusted_image if res.adjusted_image else res.extracted_image
1159
  return mdr_clip_from_image(img, layout.rect, wrap_w, wrap_h)
1160
 
1161
+ # --- MDR Debug Plotting ---
1162
  _MDR_FRAG_COLOR = (0x49, 0xCF, 0xCB, 200); _MDR_LAYOUT_COLORS = { MDRLayoutClass.TITLE: (0x0A,0x12,0x2C,255), MDRLayoutClass.PLAIN_TEXT: (0x3C,0x67,0x90,255), MDRLayoutClass.ABANDON: (0xC0,0xBB,0xA9,180), MDRLayoutClass.FIGURE: (0x5B,0x91,0x3C,255), MDRLayoutClass.FIGURE_CAPTION: (0x77,0xB3,0x54,255), MDRLayoutClass.TABLE: (0x44,0x17,0x52,255), MDRLayoutClass.TABLE_CAPTION: (0x81,0x75,0xA0,255), MDRLayoutClass.TABLE_FOOTNOTE: (0xEF,0xB6,0xC9,255), MDRLayoutClass.ISOLATE_FORMULA: (0xFA,0x38,0x27,255), MDRLayoutClass.FORMULA_CAPTION: (0xFF,0x9D,0x24,255) }; _MDR_DEFAULT_COLOR = (0x80,0x80,0x80,255); _MDR_RGBA = tuple[int,int,int,int]
1163
+
1164
  def mdr_plot_layout(image: Image, layouts: Iterable[MDRLayoutElement]) -> None:
1165
  """Draws layout and fragment boxes onto an image for debugging."""
1166
  if not layouts: return;
1167
  try: l_font, f_font = load_default(size=25), load_default(size=15); draw = ImageDraw.Draw(image, mode="RGBA")
1168
  except Exception as e: print(f"MDR Plot init error: {e}"); return
1169
+
1170
  def _draw_num(pos: MDRPoint, num: int, font: FreeTypeFont, color: _MDR_RGBA):
1171
  try: x,y=pos; txt=str(num); txt_pos=(round(x)+3, round(y)+1); bbox=draw.textbbox(txt_pos,txt,font=font); bg_rect=(bbox[0]-2,bbox[1]-1,bbox[2]+2,bbox[3]+1); bg_color=(color[0],color[1],color[2],180); draw.rectangle(bg_rect,fill=bg_color); draw.text(txt_pos,txt,font=font,fill=(255,255,255,255))
1172
  except Exception as e: print(f"MDR Draw num error: {e}")
 
1178
  try: draw.polygon([p for p in f.rect], outline=_MDR_FRAG_COLOR, width=1)
1179
  except Exception as e: print(f"MDR Fragment draw error: {e}")
1180
 
1181
+ # --- MDR Extraction Engine ---
1182
  class MDRExtractionEngine:
1183
  """Core engine for extracting structured information from a document image."""
1184
+
1185
  def __init__(self, model_dir_path: str, device: Literal["cpu", "cuda"]="cpu", ocr_for_each_layouts: bool=True, extract_formula: bool=True, extract_table_format: MDRTableLayoutParsedFormat|None=None):
1186
+ self._model_dir = model_dir_path # Base directory for all models
1187
+ self._device = device if torch.cuda.is_available() else "cpu"
1188
  self._ocr_each = ocr_for_each_layouts; self._ext_formula = extract_formula; self._ext_table = extract_table_format
1189
  self._yolo: YOLOv10 | None = None
1190
+ # Initialize sub-modules, passing the main model_dir_path
1191
+ self._ocr_engine = MDROcrEngine(device=self._device, model_dir_path=os.path.join(self._model_dir, "onnx_ocr"))
1192
+ self._table_parser = MDRTableParser(device=self._device, model_path=os.path.join(self._model_dir, "struct_eqtable"))
1193
+ self._latex_extractor = MDRLatexExtractor(model_path=os.path.join(self._model_dir, "latex"))
1194
+ self._layout_reader = MDRLayoutReader(model_path=os.path.join(self._model_dir, "layoutreader"))
1195
  print(f"MDR Extraction Engine initialized on device: {self._device}")
1196
+
1197
+ # --- MODIFIED _get_yolo_model METHOD for HF ---
1198
  def _get_yolo_model(self) -> YOLOv10 | None:
1199
+ """Loads the YOLOv10 layout detection model using hf_hub_download."""
1200
  if self._yolo is None and YOLOv10 is not None:
1201
+ repo_id = "juliozhao/DocLayout-YOLO-DocStructBench"
1202
+ filename = "doclayout_yolo_docstructbench_imgsz1024.pt"
1203
+ # Use a subdirectory within the main model dir for YOLO cache via HF Hub
1204
+ yolo_cache_dir = Path(self._model_dir) / "yolo_hf_cache"
1205
+ mdr_ensure_directory(str(yolo_cache_dir)) # Ensure cache dir exists
1206
+
1207
+ print(f"Attempting to load YOLO model '{filename}' from repo '{repo_id}'...")
1208
+ print(f"Hugging Face Hub cache directory for YOLO: {yolo_cache_dir}")
1209
+
1210
+ try:
1211
+ # Download the model file using huggingface_hub, caching it
1212
+ yolo_model_filepath = hf_hub_download(
1213
+ repo_id=repo_id,
1214
+ filename=filename,
1215
+ cache_dir=yolo_cache_dir, # Cache within our designated structure
1216
+ local_files_only=False, # Allow download
1217
+ force_download=False, # Use cache if available
1218
+ )
1219
+ print(f"YOLO model file path: {yolo_model_filepath}")
1220
+
1221
+ # Load the model using the downloaded file path
1222
+ self._yolo = YOLOv10(yolo_model_filepath)
1223
+ print("MDR YOLOv10 model loaded successfully.")
1224
+
1225
+ except HfHubDownloadError as e:
1226
+ print(f"ERROR: Failed to download YOLO model from Hugging Face Hub: {e}")
1227
+ self._yolo = None
1228
+ except FileNotFoundError as e: # Catch if hf_hub_download fails finding file
1229
+ print(f"ERROR: YOLO model file not found via Hugging Face Hub: {e}")
1230
+ self._yolo = None
1231
+ except Exception as e:
1232
+ print(f"ERROR: Failed to load YOLOv10 model from {yolo_model_filepath}: {e}")
1233
+ self._yolo = None
1234
+
1235
+ elif YOLOv10 is None:
1236
+ print("MDR YOLOv10 class not available. Layout detection skipped.")
1237
+
1238
  return self._yolo
1239
+
1240
  def analyze_image(self, image: Image, adjust_points: bool=False) -> MDRExtractionResult:
1241
  """Analyzes a single page image to extract layout and content."""
1242
  print(" Engine: Analyzing image..."); optimizer = MDRImageOptimizer(image, adjust_points)
 
1256
  print(" Engine: Finalizing coords..."); optimizer.finalize_layout_coords(layouts)
1257
  print(" Engine: Analysis complete.")
1258
  return MDRExtractionResult(rotation=optimizer.rotation, layouts=layouts, extracted_image=image, adjusted_image=optimizer.adjusted_image)
1259
+
1260
  def _run_yolo_detection(self, img: Image, yolo: YOLOv10) -> Generator[MDRLayoutElement, None, None]:
1261
  img_rgb = img.convert("RGB"); res = yolo.predict(source=img_rgb, imgsz=1024, conf=0.2, device=self._device, verbose=False)
1262
  if not res or not hasattr(res[0], 'boxes') or res[0].boxes is None: return
 
1270
  if cls == MDRLayoutClass.TABLE: yield MDRTableLayoutElement(cls=cls, rect=rect, fragments=[], parsed=None)
1271
  elif cls == MDRLayoutClass.ISOLATE_FORMULA: yield MDRFormulaLayoutElement(cls=cls, rect=rect, fragments=[], latex=None)
1272
  elif cls in MDRPlainLayoutElement.__annotations__['cls'].__args__: yield MDRPlainLayoutElement(cls=cls, rect=rect, fragments=[])
1273
+
1274
  def _match_fragments_to_layouts(self, frags: list[MDROcrFragment], layouts: list[MDRLayoutElement]) -> list[MDRLayoutElement]:
1275
  if not frags or not layouts: return layouts
1276
  layout_polys = [(Polygon(l.rect) if l.rect.is_valid else None) for l in layouts]
 
1290
  layouts[best_idx].fragments.append(frag)
1291
  for l in layouts: l.fragments.sort(key=lambda f: (f.rect.lt[1], f.rect.lt[0]))
1292
  return layouts
1293
+
1294
  def _run_ocr_correction(self, img: Image, layouts: list[MDRLayoutElement]):
1295
  for i, l in enumerate(layouts):
1296
  if l.cls == MDRLayoutClass.FIGURE: continue
1297
  try: mdr_correct_layout_fragments(self._ocr_engine, img, l)
1298
  except Exception as e: print(f" Engine: OCR correction error layout {i}: {e}")
1299
+
1300
  def _parse_special_layouts(self, layouts: list[MDRLayoutElement], optimizer: MDRImageOptimizer):
1301
  img_to_clip = optimizer.image
1302
  for l in layouts:
 
1307
  try: t_img = mdr_clip_from_image(img_to_clip, l.rect); parsed = self._table_parser.parse_table_image(t_img, self._ext_table) if t_img.width>1 and t_img.height>1 else None
1308
  except Exception as e: print(f" Engine: Table parse error: {e}"); parsed = None
1309
  if parsed: l.parsed = (parsed, self._ext_table)
1310
+
1311
  def _should_keep_layout(self, l: MDRLayoutElement) -> bool:
1312
  if l.fragments and not all(mdr_is_whitespace(f.text) for f in l.fragments): return True
1313
  return l.cls in [MDRLayoutClass.FIGURE, MDRLayoutClass.TABLE, MDRLayoutClass.ISOLATE_FORMULA]
1314
 
1315
+ # --- MDR Page Section Linking ---
1316
  class _MDR_LinkedShape:
1317
  """Internal helper for managing layout linking across pages."""
1318
+
1319
  def __init__(self, layout: MDRLayoutElement): self.layout=layout; self.pre:list[MDRLayoutElement|None]=[None,None]; self.nex:list[MDRLayoutElement|None]=[None,None]
1320
+
1321
  @property
1322
  def distance2(self) -> float: x,y=self.layout.rect.lt; return x*x+y*y
1323
+
1324
  class MDRPageSection:
1325
  """Represents a page's layouts for framework detection via linking."""
1326
+
1327
  def __init__(self, page_index: int, layouts: Iterable[MDRLayoutElement]):
1328
  self._page_index = page_index; self._shapes = [_MDR_LinkedShape(l) for l in layouts]; self._shapes.sort(key=lambda s: (s.layout.rect.lt[1], s.layout.rect.lt[0]))
1329
+
1330
  @property
1331
  def page_index(self) -> int: return self._page_index
1332
+
1333
  def find_framework_elements(self) -> list[MDRLayoutElement]:
1334
  """Identifies framework layouts based on links to other pages."""
1335
  return [s.layout for s in self._shapes if any(s.pre) or any(s.nex)]
1336
+
1337
  def link_to_next(self, next_section: 'MDRPageSection', offset: int) -> None:
1338
  """Links matching shapes between this section and the next."""
1339
  if offset not in (1,2): return
 
1349
  r2_rel = self._relative_rect(orig_n_pt, s2.layout.rect); ovr = self._symmetric_iou(r1_rel, r2_rel)
1350
  if ovr > max_ovr: max_ovr, best_s2 = ovr, s2
1351
  if max_ovr >= 0.80 and best_s2 is not None: s1.nex[offset-1] = best_s2.layout; best_s2.pre[offset-1] = s1.layout # Link both ways
1352
+
1353
  def _shapes_match(self, s1: _MDR_LinkedShape, s2: _MDR_LinkedShape) -> bool:
1354
  l1, l2 = s1.layout, s2.layout; sz1, sz2 = l1.rect.size, l2.rect.size; thresh = 0.90
1355
  if mdr_similarity_ratio(sz1[0], sz2[0]) < thresh or mdr_similarity_ratio(sz1[1], sz2[1]) < thresh: return False
 
1364
  if max_sim > 0.75: matches += 1; if best_j != -1: used_f2[best_j] = True
1365
  max_c = max(c1, c2); rate_frags = matches / max_c
1366
  return self._check_match_threshold(rate_frags, max_c, (0.0, 0.45, 0.45, 0.6, 0.8, 0.95))
1367
+
1368
  def _fragment_sim(self, l1: MDRLayoutElement, l2: MDRLayoutElement, f1: MDROcrFragment, f2: MDROcrFragment) -> float:
1369
  r1_rel = self._relative_rect(l1.rect.lt, f1.rect); r2_rel = self._relative_rect(l2.rect.lt, f2.rect)
1370
  geom_sim = self._symmetric_iou(r1_rel, r2_rel); text_sim, _ = mdr_check_text_similarity(f1.text, f2.text)
1371
  return (geom_sim + text_sim) / 2.0
1372
+
1373
  def _find_origin_pair(self, matches_matrix: list[list[_MDR_LinkedShape]], next_shapes: list[_MDR_LinkedShape]) -> tuple[_MDR_LinkedShape, _MDR_LinkedShape] | None:
1374
  best_pair, min_dist2 = None, float('inf')
1375
  for i, s1 in enumerate(self._shapes):
 
1377
  if not match_list: continue
1378
  for s2 in match_list: dist2 = s1.distance2 + s2.distance2; if dist2 < min_dist2: min_dist2, best_pair = dist2, (s1, s2)
1379
  return best_pair
1380
+
1381
  def _check_match_threshold(self, rate: float, count: int, thresholds: Sequence[float]) -> bool:
1382
  if not thresholds: return False; idx = min(count, len(thresholds)-1); return rate >= thresholds[idx]
1383
+
1384
  def _relative_rect(self, origin: MDRPoint, rect: MDRRectangle) -> MDRRectangle:
1385
  ox, oy = origin; r=rect; return MDRRectangle(lt=(r.lt[0]-ox, r.lt[1]-oy), rt=(r.rt[0]-ox, r.rt[1]-oy), lb=(r.lb[0]-ox, r.lb[1]-oy), rb=(r.rb[0]-ox, r.rb[1]-oy))
1386
+
1387
  def _symmetric_iou(self, r1: MDRRectangle, r2: MDRRectangle) -> float:
1388
  try: p1, p2 = Polygon(r1), Polygon(r2);
1389
  except: return 0.0
 
1393
  if inter.is_empty or inter.area < 1e-6: return 0.0
1394
  union_area = union.area; return inter.area / union_area if union_area > 1e-6 else 1.0
1395
 
1396
+ # --- MDR Document Iterator ---
1397
  # Number of neighbouring pages on each side of a requested page that are also
  # scanned for context; also the number of results held back in the linking
  # queue before the oldest one is yielded.
  _MDR_CONTEXT_PAGES = 2 # Look behind/ahead pages for context
1398
+
1399
  @dataclass
1400
  class MDRProcessingParams:
1401
  """Parameters for processing a document."""
1402
  pdf: str | FitzDocument; page_indexes: Iterable[int] | None; report_progress: MDRProgressReportCallback | None
1403
+
1404
  class MDRDocumentIterator:
1405
  """Iterates through document pages, handles context, and calls the extraction engine."""
1406
+
1407
  def __init__(self, device: Literal["cpu", "cuda"], model_dir_path: str, ocr_level: MDROcrLevel, extract_formula: bool, extract_table_format: MDRTableLayoutParsedFormat | None, debug_dir_path: str | None):
1408
  self._debug_dir = debug_dir_path
1409
  self._engine = MDRExtractionEngine(device=device, model_dir_path=model_dir_path, ocr_for_each_layouts=(ocr_level==MDROcrLevel.OncePerLayout), extract_formula=extract_formula, extract_table_format=extract_table_format)
1410
+
1411
  def iterate_sections(self, params: MDRProcessingParams) -> Generator[tuple[int, MDRExtractionResult, list[MDRLayoutElement]], None, None]:
1412
  """Yields page index, extraction result, and content layouts for each requested page."""
1413
  for res, sec in self._process_and_link_sections(params):
1414
  framework = set(sec.find_framework_elements()); content = [l for l in res.layouts if l not in framework]; yield sec.page_index, res, content
1415
+
1416
  def _process_and_link_sections(self, params: MDRProcessingParams) -> Generator[tuple[MDRExtractionResult, MDRPageSection], None, None]:
1417
  queue: list[tuple[MDRExtractionResult, MDRPageSection]] = []
1418
  for page_idx, res in self._run_extraction_on_pages(params):
 
1423
  queue.append((res, cur_sec))
1424
  if len(queue) > _MDR_CONTEXT_PAGES: yield queue.pop(0)
1425
  for res, sec in queue: yield res, sec
1426
+
1427
  def _run_extraction_on_pages(self, params: MDRProcessingParams) -> Generator[tuple[int, MDRExtractionResult], None, None]:
1428
  if self._debug_dir: mdr_ensure_directory(self._debug_dir)
1429
  doc, should_close = None, False
 
1446
  except Exception as e: print(f" Iterator: Page {page_idx+1} processing error: {e}")
1447
  finally:
1448
  if should_close and doc: doc.close()
1449
+
1450
  def _get_page_ranges(self, doc: FitzDocument, idxs: Iterable[int]|None) -> tuple[Sequence[int], Sequence[int]]:
1451
  count = doc.page_count;
1452
  if idxs is None: all_p = list(range(count)); return all_p, all_p
 
1454
  for i in idxs:
1455
  if 0<=i<count: enable.add(i); [scan.add(j) for j in range(max(0, i-_MDR_CONTEXT_PAGES), min(count, i+_MDR_CONTEXT_PAGES+1))]
1456
  return sorted(list(scan)), sorted(list(enable))
1457
+
1458
  def _render_page_image(self, page: FitzPage, dpi: int) -> Image:
1459
  mat = FitzMatrix(dpi/72.0, dpi/72.0); pix = page.get_pixmap(matrix=mat, alpha=False)
1460
  return frombytes("RGB", (pix.width, pix.height), pix.samples)
1461
+
1462
  def _save_debug_plot(self, img: Image, idx: int, res: MDRExtractionResult, path: str):
1463
  try: plot_img = res.adjusted_image.copy() if res.adjusted_image else img.copy(); mdr_plot_layout(plot_img, res.layouts); plot_img.save(os.path.join(path, f"mdr_plot_page_{idx+1}.png"))
1464
  except Exception as e: print(f" Iterator: Plot generation error page {idx+1}: {e}")
1465
 
1466
+ # --- MagicDataReadiness Main Processor ---
1467
  class MagicPDFProcessor:
1468
  """
1469
  Main class for processing PDF documents to extract structured data blocks
1470
  using the MagicDataReadiness pipeline.
1471
  """
1472
+
1473
  def __init__(self, device: Literal["cpu", "cuda"]="cuda", model_dir_path: str="./mdr_models", ocr_level: MDROcrLevel=MDROcrLevel.Once, extract_formula: bool=True, extract_table_format: MDRExtractedTableFormat|None=None, debug_dir_path: str|None=None):
1474
  """
1475
  Initializes the MagicPDFProcessor.
 
1632
  print(" MagicDataReadiness PDF Processor - Example Usage")
1633
  print("="*60)
1634
 
1635
+ # --- 1. Configuration (!!! MODIFY THESE PATHS WHEN OUTSIDE HF !!!) ---
1636
  # Directory where models are stored or will be downloaded
1637
  # IMPORTANT: Create this directory or ensure it's writable!
1638
  MDR_MODEL_DIRECTORY = "./mdr_pipeline_models"
 
1645
  MDR_DEBUG_DIRECTORY = "./mdr_debug_output"
1646
 
1647
  # Specify device ('cuda' or 'cpu').
1648
+ MDR_DEVICE = "cpu"
1649
 
1650
  # Specify desired table format
1651
  MDR_TABLE_FORMAT = MDRExtractedTableFormat.MARKDOWN