Update mdr_pdf_parser.py
Browse files- mdr_pdf_parser.py +182 -123
mdr_pdf_parser.py
CHANGED
@@ -46,6 +46,7 @@ from transformers import LayoutLMv3ForTokenClassification
|
|
46 |
import onnxruntime
|
47 |
# --- HUGGING FACE HUB IMPORT ONLY BECAUSE RUNNING IN SPACES NOT NECESSARY IN PROD ---
|
48 |
from huggingface_hub import hf_hub_download, HfHubDownloadError
|
|
|
49 |
|
50 |
# --- External Dependencies ---
|
51 |
try:
|
@@ -59,7 +60,7 @@ except ImportError:
|
|
59 |
print("Warning: Could not import LatexOCR from pix2tex.cli. LaTeX extraction will fail.")
|
60 |
LatexOCR = None
|
61 |
try:
|
62 |
-
pass # from struct_eqtable import build_model
|
63 |
except ImportError:
|
64 |
print("Warning: Could not import build_model from struct_eqtable. Table parsing might fail.")
|
65 |
|
@@ -303,12 +304,12 @@ def mdr_check_text_similarity(t1: str, t2: str) -> tuple[float, int]:
|
|
303 |
# --- MDR Geometry Processing ---
|
304 |
class MDRRotationAdjuster:
|
305 |
"""Adjusts point coordinates based on image rotation."""
|
306 |
-
|
307 |
def __init__(self, origin_size: tuple[int, int], new_size: tuple[int, int], rotation: float, to_origin_coordinate: bool):
|
308 |
fs, ts = (new_size, origin_size) if to_origin_coordinate else (origin_size, new_size)
|
309 |
self._rot = rotation if to_origin_coordinate else -rotation
|
310 |
self._c_off = (fs[0]/2.0, fs[1]/2.0); self._n_off = (ts[0]/2.0, ts[1]/2.0)
|
311 |
-
|
312 |
def adjust(self, point: MDRPoint) -> MDRPoint:
|
313 |
x, y = point[0]-self._c_off[0], point[1]-self._c_off[1]
|
314 |
if x!=0 or y!=0: cos_r, sin_r = cos(self._rot), sin(self._rot); x, y = x*cos_r-y*sin_r, x*sin_r+y*cos_r
|
@@ -362,7 +363,7 @@ def mdr_calculate_rectangle_rotation(rect: MDRRectangle) -> tuple[float, float]:
|
|
362 |
# --- MDR ONNX OCR Internals ---
|
363 |
class _MDR_PredictBase:
|
364 |
"""Base class for ONNX model prediction components."""
|
365 |
-
|
366 |
def get_onnx_session(self, model_path: str, use_gpu: bool):
|
367 |
try:
|
368 |
sess_opts = onnxruntime.SessionOptions(); sess_opts.log_severity_level = 3
|
@@ -375,32 +376,32 @@ class _MDR_PredictBase:
|
|
375 |
if use_gpu and 'CUDAExecutionProvider' not in onnxruntime.get_available_providers():
|
376 |
print(" CUDAExecutionProvider not available. Check ONNXRuntime-GPU installation and CUDA setup.")
|
377 |
raise e
|
378 |
-
|
379 |
def get_output_name(self, sess: onnxruntime.InferenceSession) -> List[str]: return [n.name for n in sess.get_outputs()]
|
380 |
-
|
381 |
def get_input_name(self, sess: onnxruntime.InferenceSession) -> List[str]: return [n.name for n in sess.get_inputs()]
|
382 |
-
|
383 |
def get_input_feed(self, names: List[str], img_np: np.ndarray) -> Dict[str, np.ndarray]: return {name: img_np for name in names}
|
384 |
|
385 |
# --- MDR ONNX OCR Internals ---
|
386 |
class _MDR_NormalizeImage:
|
387 |
-
|
388 |
def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
|
389 |
self.scale = np.float32(eval(scale) if isinstance(scale, str) else (scale if scale is not None else 1.0 / 255.0))
|
390 |
mean = mean if mean is not None else [0.485, 0.456, 0.406]; std = std if std is not None else [0.229, 0.224, 0.225]
|
391 |
shape = (3, 1, 1) if order == 'chw' else (1, 1, 3); self.mean = np.array(mean).reshape(shape).astype('float32'); self.std = np.array(std).reshape(shape).astype('float32')
|
392 |
-
|
393 |
def __call__(self, data): img = data['image']; img = np.array(img) if isinstance(img, Image) else img; data['image'] = (img.astype('float32') * self.scale - self.mean) / self.std; return data
|
394 |
|
395 |
class _MDR_DetResizeForTest:
|
396 |
-
|
397 |
def __init__(self, **kwargs):
|
398 |
self.resize_type = 0; self.keep_ratio = False
|
399 |
if 'image_shape' in kwargs: self.image_shape = kwargs['image_shape']; self.resize_type = 1; self.keep_ratio = kwargs.get('keep_ratio', False)
|
400 |
elif 'limit_side_len' in kwargs: self.limit_side_len = kwargs['limit_side_len']; self.limit_type = kwargs.get('limit_type', 'min')
|
401 |
elif 'resize_long' in kwargs: self.resize_type = 2; self.resize_long = kwargs.get('resize_long', 960)
|
402 |
else: self.limit_side_len = 736; self.limit_type = 'min'
|
403 |
-
|
404 |
def __call__(self, data):
|
405 |
img = data['image']; src_h, src_w, _ = img.shape
|
406 |
if src_h + src_w < 64: img = self._pad(img)
|
@@ -409,23 +410,23 @@ class _MDR_DetResizeForTest:
|
|
409 |
else: img, ratios = self._resize1(img)
|
410 |
if img is None: return None
|
411 |
data['image'] = img; data['shape'] = np.array([src_h, src_w, ratios[0], ratios[1]]); return data
|
412 |
-
|
413 |
def _pad(self, im, v=0): h,w,c=im.shape; p=np.zeros((max(32,h),max(32,w),c),np.uint8)+v; p[:h,:w,:]=im; return p
|
414 |
-
|
415 |
-
def _resize1(self, img): rh,rw=self.image_shape; oh,ow=img.shape[:2]; if self.keep_ratio: rw=ow*rh/oh; N=
|
416 |
-
|
417 |
def _resize0(self, img): lsl=self.limit_side_len; h,w,_=img.shape; r=1.0; if self.limit_type=='max': r=float(lsl)/max(h,w) if max(h,w)>lsl else 1.0; elif self.limit_type=='min': r=float(lsl)/min(h,w) if min(h,w)<lsl else 1.0; elif self.limit_type=='resize_long': r=float(lsl)/max(h,w); else: raise Exception('Unsupported'); rh,rw=int(h*r),int(w*r); rh=max(int(round(rh/32)*32),32); rw=max(int(round(rw/32)*32),32); if int(rw)<=0 or int(rh)<=0: return None,(None,None); img=cv2.resize(img,(int(rw),int(rh))); r_h,r_w=rh/float(h),rw/float(w); return img,[r_h,r_w]
|
418 |
-
|
419 |
def _resize2(self, img): h,w,_=img.shape; rl=self.resize_long; r=float(rl)/max(h,w); rh,rw=int(h*r),int(w*r); ms=128; rh=(rh+ms-1)//ms*ms; rw=(rw+ms-1)//ms*ms; img=cv2.resize(img,(int(rw),int(rh))); r_h,r_w=rh/float(h),rw/float(w); return img,[r_h,r_w]
|
420 |
|
421 |
class _MDR_ToCHWImage:
|
422 |
-
|
423 |
def __call__(self, data): img=data['image']; img=np.array(img) if isinstance(img,Image) else img; data['image']=img.transpose((2,0,1)); return data
|
424 |
|
425 |
class _MDR_KeepKeys:
|
426 |
-
|
427 |
def __init__(self, keep_keys, **kwargs): self.keep_keys=keep_keys
|
428 |
-
|
429 |
def __call__(self, data): return [data[key] for key in self.keep_keys]
|
430 |
|
431 |
def mdr_ocr_transform(data, ops=None):
|
@@ -445,11 +446,11 @@ def mdr_ocr_create_operators(op_param_list, global_config=None):
|
|
445 |
return ops
|
446 |
|
447 |
class _MDR_DBPostProcess:
|
448 |
-
|
449 |
def __init__(self, thresh=0.3, box_thresh=0.7, max_candidates=1000, unclip_ratio=1.5, use_dilation=False, score_mode="fast", box_type='quad', **kwargs):
|
450 |
self.thresh, self.box_thresh, self.max_cand = thresh, box_thresh, max_candidates; self.unclip_r, self.min_sz, self.score_m, self.box_t = unclip_ratio, 3, score_mode, box_type
|
451 |
assert score_mode in ["slow", "fast"]; self.dila_k = np.array([[1,1],[1,1]], dtype=np.uint8) if use_dilation else None
|
452 |
-
|
453 |
def _polygons_from_bitmap(self, pred, bmp, dw, dh):
|
454 |
h, w = bmp.shape; boxes, scores = [], []
|
455 |
contours, _ = cv2.findContours((bmp*255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
|
@@ -466,7 +467,7 @@ class _MDR_DBPostProcess:
|
|
466 |
box = np.array(box); box[:,0]=np.clip(np.round(box[:,0]/w*dw),0,dw); box[:,1]=np.clip(np.round(box[:,1]/h*dh),0,dh)
|
467 |
boxes.append(box.tolist()); scores.append(score)
|
468 |
return boxes, scores
|
469 |
-
|
470 |
def _boxes_from_bitmap(self, pred, bmp, dw, dh):
|
471 |
h, w = bmp.shape; contours, _ = cv2.findContours((bmp*255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
|
472 |
num_contours = min(len(contours), self.max_cand); boxes, scores = [], []
|
@@ -482,30 +483,30 @@ class _MDR_DBPostProcess:
|
|
482 |
box = np.array(box); box[:,0]=np.clip(np.round(box[:,0]/w*dw),0,dw); box[:,1]=np.clip(np.round(box[:,1]/h*dh),0,dh)
|
483 |
boxes.append(box.astype("int32")); scores.append(score)
|
484 |
return np.array(boxes, dtype="int32"), scores
|
485 |
-
|
486 |
def _unclip(self, box, ratio):
|
487 |
poly = Polygon(box); dist = poly.area*ratio/poly.length; offset = pyclipper.PyclipperOffset(); offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
|
488 |
expanded = offset.Execute(dist);
|
489 |
if not expanded: raise ValueError("Unclip failed"); return np.array(expanded[0])
|
490 |
-
|
491 |
def _get_mini_boxes(self, contour):
|
492 |
bb = cv2.minAreaRect(contour); pts = sorted(list(cv2.boxPoints(bb)), key=lambda x:x[0])
|
493 |
i1,i4 = (0,1) if pts[1][1]>pts[0][1] else (1,0); i2,i3 = (2,3) if pts[3][1]>pts[2][1] else (3,2)
|
494 |
box = [pts[i1], pts[i2], pts[i3], pts[i4]]; return box, min(bb[1])
|
495 |
-
|
496 |
def _box_score_fast(self, bmp, box):
|
497 |
h,w = bmp.shape[:2]; xmin=np.clip(np.floor(box[:,0].min()).astype("int32"),0,w-1); xmax=np.clip(np.ceil(box[:,0].max()).astype("int32"),0,w-1)
|
498 |
ymin=np.clip(np.floor(box[:,1].min()).astype("int32"),0,h-1); ymax=np.clip(np.ceil(box[:,1].max()).astype("int32"),0,h-1)
|
499 |
mask = np.zeros((ymax-ymin+1, xmax-xmin+1), dtype=np.uint8); box[:,0]-=xmin; box[:,1]-=ymin
|
500 |
cv2.fillPoly(mask, box.reshape(1,-1,2).astype("int32"), 1);
|
501 |
return cv2.mean(bmp[ymin:ymax+1, xmin:xmax+1], mask)[0] if np.sum(mask)>0 else 0.0
|
502 |
-
|
503 |
def _box_score_slow(self, bmp, contour): # Not used if fast
|
504 |
h,w = bmp.shape[:2]; contour = np.reshape(contour.copy(),(-1,2)); xmin=np.clip(np.min(contour[:,0]),0,w-1); xmax=np.clip(np.max(contour[:,0]),0,w-1)
|
505 |
ymin=np.clip(np.min(contour[:,1]),0,h-1); ymax=np.clip(np.max(contour[:,1]),0,h-1); mask=np.zeros((ymax-ymin+1,xmax-xmin+1),dtype=np.uint8)
|
506 |
contour[:,0]-=xmin; contour[:,1]-=ymin; cv2.fillPoly(mask, contour.reshape(1,-1,2).astype("int32"), 1);
|
507 |
return cv2.mean(bmp[ymin:ymax+1, xmin:xmax+1], mask)[0] if np.sum(mask)>0 else 0.0
|
508 |
-
|
509 |
def __call__(self, outs_dict, shape_list):
|
510 |
pred = outs_dict['maps'][:,0,:,:]; seg = pred > self.thresh; boxes_batch = []
|
511 |
for batch_idx in range(pred.shape[0]):
|
@@ -517,7 +518,7 @@ class _MDR_DBPostProcess:
|
|
517 |
return boxes_batch
|
518 |
|
519 |
class _MDR_TextDetector(_MDR_PredictBase):
|
520 |
-
|
521 |
def __init__(self, args):
|
522 |
super().__init__(); self.args = args
|
523 |
pre_ops = [{'DetResizeForTest': {'limit_side_len': args.det_limit_side_len, 'limit_type': args.det_limit_type}}, {'NormalizeImage': {'std': [0.229,0.224,0.225], 'mean': [0.485,0.456,0.406], 'scale': '1./255.', 'order': 'hwc'}}, {'ToCHWImage': None}, {'KeepKeys': {'keep_keys': ['image', 'shape']}}]
|
@@ -526,15 +527,15 @@ class _MDR_TextDetector(_MDR_PredictBase):
|
|
526 |
self.post_op = _MDR_DBPostProcess(**post_params)
|
527 |
self.sess = self.get_onnx_session(args.det_model_dir, args.use_gpu)
|
528 |
self.input_name = self.get_input_name(self.sess); self.output_name = self.get_output_name(self.sess)
|
529 |
-
|
530 |
def _order_pts(self, pts): r=np.zeros((4,2),dtype="float32"); s=pts.sum(axis=1); r[0]=pts[np.argmin(s)]; r[2]=pts[np.argmax(s)]; tmp=np.delete(pts,(np.argmin(s),np.argmax(s)),axis=0); d=np.diff(np.array(tmp),axis=1); r[1]=tmp[np.argmin(d)]; r[3]=tmp[np.argmax(d)]; return r
|
531 |
-
|
532 |
def _clip_pts(self, pts, h, w): pts[:,0]=np.clip(pts[:,0],0,w-1); pts[:,1]=np.clip(pts[:,1],0,h-1); return pts
|
533 |
-
|
534 |
def _filter_quad(self, boxes, shape): h,w=shape[0:2]; new_boxes=[]; for box in boxes: box=np.array(box) if isinstance(box,list) else box; box=self._order_pts(box); box=self._clip_pts(box,h,w); rw=int(np.linalg.norm(box[0]-box[1])); rh=int(np.linalg.norm(box[0]-box[3])); if rw<=3 or rh<=3: continue; new_boxes.append(box); return np.array(new_boxes)
|
535 |
-
|
536 |
def _filter_poly(self, boxes, shape): h,w=shape[0:2]; new_boxes=[]; for box in boxes: box=np.array(box) if isinstance(box,list) else box; box=self._clip_pts(box,h,w); if Polygon(box).area<10: continue; new_boxes.append(box); return np.array(new_boxes)
|
537 |
-
|
538 |
def __call__(self, img):
|
539 |
ori_im = img.copy(); data = {"image": img}; data = mdr_ocr_transform(data, self.pre_op)
|
540 |
if data is None: return None; img, shape_list = data;
|
@@ -544,27 +545,27 @@ class _MDR_TextDetector(_MDR_PredictBase):
|
|
544 |
return self._filter_poly(boxes, ori_im.shape) if self.args.det_box_type=='poly' else self._filter_quad(boxes, ori_im.shape)
|
545 |
|
546 |
class _MDR_ClsPostProcess:
|
547 |
-
|
548 |
def __init__(self, label_list=None, **kwargs): self.labels = label_list if label_list else {0:'0', 1:'180'}
|
549 |
-
|
550 |
def __call__(self, preds, label=None, *args, **kwargs):
|
551 |
preds = np.array(preds) if not isinstance(preds, np.ndarray) else preds; idxs = preds.argmax(axis=1)
|
552 |
return [(self.labels[idx], float(preds[i,idx])) for i,idx in enumerate(idxs)]
|
553 |
|
554 |
class _MDR_TextClassifier(_MDR_PredictBase):
|
555 |
-
|
556 |
def __init__(self, args):
|
557 |
super().__init__(); self.shape = tuple(map(int, args.cls_image_shape.split(','))) if isinstance(args.cls_image_shape, str) else args.cls_image_shape
|
558 |
self.batch_num = args.cls_batch_num; self.thresh = args.cls_thresh; self.post_op = _MDR_ClsPostProcess(label_list=args.label_list)
|
559 |
self.sess = self.get_onnx_session(args.cls_model_dir, args.use_gpu); self.input_name = self.get_input_name(self.sess); self.output_name = self.get_output_name(self.sess)
|
560 |
-
|
561 |
def _resize_norm(self, img):
|
562 |
-
imgC,imgH,imgW = self.shape; h,w = img.shape[:2]; r=w/float(h) if h>0 else 0; rw=int(
|
563 |
resized = cv2.resize(img,(rw,imgH)); resized = resized.astype("float32")
|
564 |
if imgC==1: resized = resized/255.0; resized = resized[np.newaxis,:]
|
565 |
else: resized = resized.transpose((2,0,1))/255.0
|
566 |
resized -= 0.5; resized /= 0.5; padding = np.zeros((imgC,imgH,imgW),dtype=np.float32); padding[:,:,0:rw]=resized; return padding
|
567 |
-
|
568 |
def __call__(self, img_list):
|
569 |
if not img_list: return img_list, []; img_list_cp = copy.deepcopy(img_list); num = len(img_list_cp)
|
570 |
ratios = [img.shape[1]/float(img.shape[0]) if img.shape[0]>0 else 0 for img in img_list_cp]; indices = np.argsort(np.array(ratios))
|
@@ -581,7 +582,7 @@ class _MDR_TextClassifier(_MDR_PredictBase):
|
|
581 |
return img_list, results
|
582 |
|
583 |
class _MDR_BaseRecLabelDecode:
|
584 |
-
|
585 |
def __init__(self, char_path=None, use_space=False):
|
586 |
self.beg, self.end, self.rev = "sos", "eos", False; self.chars = []
|
587 |
if char_path is None: self.chars = list("0123456789abcdefghijklmnopqrstuvwxyz")
|
@@ -592,13 +593,13 @@ class _MDR_BaseRecLabelDecode:
|
|
592 |
if any("\u0600"<=c<="\u06FF" for c in self.chars): self.rev=True
|
593 |
except FileNotFoundError: print(f"Warn: Dict not found {char_path}"); self.chars=list("0123456789abcdefghijklmnopqrstuvwxyz"); if use_space: self.chars.append(" ")
|
594 |
d_char = self.add_special_char(list(self.chars)); self.dict={c:i for i,c in enumerate(d_char)}; self.character=d_char
|
595 |
-
|
596 |
def add_special_char(self, chars): return chars
|
597 |
-
|
598 |
def get_ignored_tokens(self): return []
|
599 |
-
|
600 |
def _reverse(self, pred): res=[]; cur=""; for c in pred: if not re.search("[a-zA-Z0-9 :*./%+-]",c): res.extend([cur,c] if cur!="" else [c]); cur="" else: cur+=c; if cur!="": res.append(cur); return "".join(res[::-1])
|
601 |
-
|
602 |
def decode(self, idxs, probs=None, remove_dup=False):
|
603 |
res=[]; ignored=self.get_ignored_tokens(); bs=len(idxs)
|
604 |
for b_idx in range(bs):
|
@@ -628,17 +629,17 @@ class _MDR_TextRecognizer(_MDR_PredictBase):
|
|
628 |
self.batch_num=getattr(args,'rec_batch_num',6); self.algo=getattr(args,'rec_algorithm','SVTR_LCNet')
|
629 |
self.post_op=_MDR_CTCLabelDecode(char_path=args.rec_char_dict_path, use_space=getattr(args,'use_space_char',True))
|
630 |
self.sess=self.get_onnx_session(args.rec_model_dir, args.use_gpu); self.input_name=self.get_input_name(self.sess); self.output_name=self.get_output_name(self.sess)
|
631 |
-
|
632 |
def _resize_norm(self, img, max_r):
|
633 |
imgC,imgH,imgW = self.shape; h,w = img.shape[:2];
|
634 |
if h==0 or w==0: return np.zeros((imgC,imgH,imgW),dtype=np.float32)
|
635 |
-
r=w/float(h); tw=min(imgW, int(
|
636 |
resized=cv2.resize(img,(tw,imgH)); resized=resized.astype("float32")
|
637 |
if imgC==1 and len(resized.shape)==3: resized=cv2.cvtColor(resized,cv2.COLOR_BGR2GRAY); resized=resized[:,:,np.newaxis]
|
638 |
if len(resized.shape)==2: resized=resized[:,:,np.newaxis]
|
639 |
resized=resized.transpose((2,0,1))/255.0; resized-=0.5; resized/=0.5
|
640 |
padding=np.zeros((imgC,imgH,imgW),dtype=np.float32); padding[:,:,0:tw]=resized; return padding
|
641 |
-
|
642 |
def __call__(self, img_list):
|
643 |
if not img_list: return []; num=len(img_list); ratios=[img.shape[1]/float(img.shape[0]) if img.shape[0]>0 else 0 for img in img_list]
|
644 |
indices=np.argsort(np.array(ratios)); results=[["",0.0]]*num; batch_n=self.batch_num
|
@@ -654,7 +655,7 @@ class _MDR_TextRecognizer(_MDR_PredictBase):
|
|
654 |
|
655 |
# --- MDR ONNX OCR System ---
|
656 |
class _MDR_TextSystem:
|
657 |
-
|
658 |
def __init__(self, args):
|
659 |
class ArgsObject: # Helper to access dict args with dot notation
|
660 |
def __init__(self, **entries): self.__dict__.update(entries)
|
@@ -666,13 +667,13 @@ class _MDR_TextSystem:
|
|
666 |
self.drop_score = getattr(args, 'drop_score', 0.5)
|
667 |
self.classifier = _MDR_TextClassifier(args) if self.use_cls else None
|
668 |
self.crop_idx = 0; self.save_crop = getattr(args, 'save_crop_res', False); self.crop_dir = getattr(args, 'crop_res_save_dir', "./output/mdr_crop_res")
|
669 |
-
|
670 |
def _sort_boxes(self, boxes):
|
671 |
if boxes is None or len(boxes)==0: return []
|
672 |
def key(box): min_y=min(p[1] for p in box); min_x=min(p[0] for p in box); return (min_y, min_x)
|
673 |
try: return list(sorted(boxes, key=key))
|
674 |
except: return list(boxes) # Fallback
|
675 |
-
|
676 |
def __call__(self, img, classify=True):
|
677 |
ori_im = img.copy(); boxes = self.detector(img)
|
678 |
if boxes is None or len(boxes)==0: return [], []
|
@@ -694,7 +695,7 @@ class _MDR_TextSystem:
|
|
694 |
if score >= self.drop_score: final_boxes.append(box); final_rec.append(res)
|
695 |
if self.save_crop: self._save_crops(crops, rec_res)
|
696 |
return final_boxes, final_rec
|
697 |
-
|
698 |
def _save_crops(self, crops, recs):
|
699 |
mdr_ensure_directory(self.crop_dir); num = len(crops)
|
700 |
for i in range(num): txt, score = recs[i]; safe=re.sub(r'\W+', '_', txt)[:20]; fname=f"crop_{self.crop_idx+i}_{safe}_{score:.2f}.jpg"; cv2.imwrite(os.path.join(self.crop_dir, fname), crops[i])
|
@@ -722,7 +723,7 @@ def mdr_get_min_area_crop(img, points):
|
|
722 |
_MDR_INCLUDES_MIN_RATE = 0.99
|
723 |
|
724 |
class _MDR_OverlapMatrixContext:
|
725 |
-
|
726 |
def __init__(self, layouts: list[MDRLayoutElement]):
|
727 |
length = len(layouts); self.polys: list[Polygon|None] = []
|
728 |
for l in layouts:
|
@@ -736,7 +737,7 @@ class _MDR_OverlapMatrixContext:
|
|
736 |
p2 = self.polys[j];
|
737 |
if p2 is None: continue
|
738 |
r_ij = self._rate(p1, p2); r_ji = self._rate(p2, p1); self.matrix[i][j]=r_ij; self.matrix[j][i]=r_ji
|
739 |
-
|
740 |
def _rate(self, p1: Polygon, p2: Polygon) -> float: # Rate p1 covers p2
|
741 |
try: inter = p1.intersection(p2);
|
742 |
except: return 0.0
|
@@ -745,11 +746,11 @@ class _MDR_OverlapMatrixContext:
|
|
745 |
_, _, px1, py1 = p2.bounds; pw, ph = px1-p2.bounds[0], py1-p2.bounds[1]
|
746 |
if pw < 1e-6 or ph < 1e-6: return 0.0
|
747 |
wr = min(iw/pw, 1.0); hr = min(ih/ph, 1.0); return (wr+hr)/2.0
|
748 |
-
|
749 |
def others(self, idx: int):
|
750 |
for i, r in enumerate(self.matrix[idx]):
|
751 |
if i != idx and i not in self.removed: yield r
|
752 |
-
|
753 |
def includes(self, idx: int): # Layouts included BY idx
|
754 |
for i, r in enumerate(self.matrix[idx]):
|
755 |
if i != idx and i not in self.removed and r >= _MDR_INCLUDES_MIN_RATE:
|
@@ -864,17 +865,17 @@ class _MDR_ONNXParams: # Simplified container
|
|
864 |
|
865 |
class MDROcrEngine:
|
866 |
"""Handles OCR detection and recognition using ONNX models."""
|
867 |
-
|
868 |
def __init__(self, device: Literal["cpu", "cuda"], model_dir_path: str):
|
869 |
self._device = device; self._model_dir = mdr_ensure_directory(model_dir_path)
|
870 |
self._text_system: _MDR_TextSystem | None = None; self._onnx_params: _MDR_ONNXParams | None = None
|
871 |
self._ensure_models(); self._get_system() # Init on creation
|
872 |
-
|
873 |
def _ensure_models(self):
|
874 |
for key, parts in _MDR_OCR_MODELS.items():
|
875 |
fp = Path(self._model_dir) / Path(*parts)
|
876 |
if not fp.exists(): print(f"Downloading MDR OCR model: {fp.name}..."); url = _MDR_OCR_URL_BASE + "/".join(parts); mdr_download_model(url, fp)
|
877 |
-
|
878 |
def _get_system(self) -> _MDR_TextSystem | None:
|
879 |
if self._text_system is None:
|
880 |
paths = {k: str(Path(self._model_dir)/Path(*p)) for k,p in _MDR_OCR_MODELS.items()}
|
@@ -882,7 +883,7 @@ class MDROcrEngine:
|
|
882 |
try: self._text_system = _MDR_TextSystem(self._onnx_params); print(f"MDR OCR System initialized.")
|
883 |
except Exception as e: print(f"ERROR initializing MDR OCR System: {e}"); self._text_system = None
|
884 |
return self._text_system
|
885 |
-
|
886 |
def find_text_fragments(self, image_np: np.ndarray) -> Generator[MDROcrFragment, None, None]:
|
887 |
"""Finds and recognizes text fragments in a NumPy image (BGR)."""
|
888 |
system = self._get_system()
|
@@ -895,7 +896,7 @@ class MDROcrEngine:
|
|
895 |
if not txt or mdr_is_whitespace(txt) or conf < 0.1: continue
|
896 |
pts = [(float(p[0]), float(p[1])) for p in box_pts]
|
897 |
if len(pts)==4: r=MDRRectangle(lt=pts[0], rt=pts[1], rb=pts[2], lb=pts[3]); if r.is_valid and r.area>1: yield MDROcrFragment(order=-1, text=txt, rank=float(conf), rect=r)
|
898 |
-
|
899 |
def _preprocess(self, img: np.ndarray) -> np.ndarray:
|
900 |
if len(img.shape)==3 and img.shape[2]==4: a=img[:,:,3]/255.0; bg=(255,255,255); new=np.zeros_like(img[:,:,:3]); [setattr(new[:,:,i], 'flags.writeable', True) for i in range(3)]; [np.copyto(new[:,:,i], (bg[i]*(1-a)+img[:,:,i]*a)) for i in range(3)]; img=new.astype(np.uint8)
|
901 |
elif len(img.shape)==2: img=cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
@@ -936,11 +937,11 @@ class _MDR_ReaderBBox: layout_index: int; fragment_index: int; virtual: bool; or
|
|
936 |
|
937 |
class MDRLayoutReader:
|
938 |
"""Determines reading order of layout elements using LayoutLMv3."""
|
939 |
-
|
940 |
def __init__(self, model_path: str):
|
941 |
self._model_path = model_path; self._model: LayoutLMv3ForTokenClassification | None = None
|
942 |
self._device = "cuda" if torch.cuda.is_available() else "cpu"
|
943 |
-
|
944 |
def _get_model(self) -> LayoutLMv3ForTokenClassification | None:
|
945 |
if self._model is None:
|
946 |
cache = mdr_ensure_directory(self._model_path); name = "microsoft/layoutlmv3-base"; h_path = os.path.join(cache, "models--hantian--layoutreader")
|
@@ -950,7 +951,7 @@ class MDRLayoutReader:
|
|
950 |
self._model.to(self._device); self._model.eval(); print(f"MDR LayoutReader loaded on {self._device}.")
|
951 |
except Exception as e: print(f"ERROR loading MDR LayoutReader: {e}"); self._model = None
|
952 |
return self._model
|
953 |
-
|
954 |
def determine_reading_order(self, layouts: list[MDRLayoutElement], size: tuple[int, int]) -> list[MDRLayoutElement]:
|
955 |
w, h = size;
|
956 |
if w<=0 or h<=0 or not layouts: return layouts;
|
@@ -977,7 +978,7 @@ class MDRLayoutReader:
|
|
977 |
if len(orders) != len(bbox_list): print("MDR LayoutReader order mismatch"); return layouts # Fallback
|
978 |
for i, order_idx in enumerate(orders): bbox_list[i].order = order_idx
|
979 |
return self._apply_order(layouts, bbox_list)
|
980 |
-
|
981 |
def _prepare_bboxes(self, layouts: list[MDRLayoutElement], w: int, h: int) -> list[_MDR_ReaderBBox] | None:
|
982 |
line_h = self._estimate_line_h(layouts); bbox_list = []
|
983 |
for i, l in enumerate(layouts):
|
@@ -985,7 +986,7 @@ class MDRLayoutReader:
|
|
985 |
else: bbox_list.extend(self._gen_virtual(l, i, line_h, w, h))
|
986 |
if len(bbox_list) > _MDR_MAX_LEN: print(f"Too many boxes ({len(bbox_list)}>{_MDR_MAX_LEN})"); return None
|
987 |
bbox_list.sort(key=lambda b: (b.value[1], b.value[0])); return bbox_list
|
988 |
-
|
989 |
def _apply_order(self, layouts: list[MDRLayoutElement], bbox_list: list[_MDR_ReaderBBox]) -> list[MDRLayoutElement]:
|
990 |
layout_map = defaultdict(list); [layout_map[b.layout_index].append(b) for b in bbox_list]
|
991 |
layout_orders = [(idx, self._median([b.order for b in bboxes])) for idx, bboxes in layout_map.items() if bboxes]
|
@@ -999,11 +1000,11 @@ class MDRLayoutReader:
|
|
999 |
else: frags.sort(key=lambda f: (f.rect.lt[1], f.rect.lt[0]))
|
1000 |
for frag in frags: frag.order = nfo; nfo += 1
|
1001 |
return sorted_layouts
|
1002 |
-
|
1003 |
def _estimate_line_h(self, layouts: list[MDRLayoutElement]) -> float:
|
1004 |
heights = [f.rect.size[1] for l in layouts for f in l.fragments if f.rect.size[1]>0]
|
1005 |
return self._median(heights) if heights else 15.0
|
1006 |
-
|
1007 |
def _gen_virtual(self, l: MDRLayoutElement, l_idx: int, line_h: float, pw: int, ph: int) -> Generator[_MDR_ReaderBBox, None, None]:
|
1008 |
x0,y0,x1,y1 = l.rect.wrapper; lh,lw = y1-y0,x1-x0
|
1009 |
if lh<=0 or lw<=0 or line_h<=0: yield _MDR_ReaderBBox(l_idx,-1,True,-1,(x0,y0,x1,y1)); return
|
@@ -1018,7 +1019,7 @@ class MDRLayoutReader:
|
|
1018 |
ly0,ly1 = max(0,min(ph,cur_y)), max(0,min(ph,cur_y+act_line_h)); lx0,lx1 = max(0,min(pw,x0)), max(0,min(pw,x1))
|
1019 |
if ly1>ly0 and lx1>lx0: yield _MDR_ReaderBBox(l_idx,-1,True,-1,(lx0,ly0,lx1,ly1))
|
1020 |
cur_y += act_line_h
|
1021 |
-
|
1022 |
def _median(self, nums: list[float|int]) -> float:
|
1023 |
if not nums: return 0.0; s_nums = sorted(nums); n = len(s_nums)
|
1024 |
return float(s_nums[n//2]) if n%2==1 else float((s_nums[n//2-1]+s_nums[n//2])/2.0)
|
@@ -1026,11 +1027,11 @@ class MDRLayoutReader:
|
|
1026 |
# --- MDR LaTeX Extractor ---
|
1027 |
class MDRLatexExtractor:
|
1028 |
"""Extracts LaTeX from formula images using pix2tex."""
|
1029 |
-
|
1030 |
def __init__(self, model_path: str):
|
1031 |
self._model_path = model_path; self._model: LatexOCR | None = None
|
1032 |
self._device = "cuda" if torch.cuda.is_available() else "cpu"
|
1033 |
-
|
1034 |
def extract(self, image: Image) -> str | None:
|
1035 |
if LatexOCR is None: return None;
|
1036 |
image = mdr_expand_image(image, 0.1); model = self._get_model()
|
@@ -1038,7 +1039,7 @@ class MDRLatexExtractor:
|
|
1038 |
try:
|
1039 |
with torch.no_grad(): img_rgb = image.convert('RGB') if image.mode!='RGB' else image; latex = model(img_rgb); return latex if latex else None
|
1040 |
except Exception as e: print(f"MDR LaTeX error: {e}"); return None
|
1041 |
-
|
1042 |
def _get_model(self) -> LatexOCR | None:
|
1043 |
if self._model is None and LatexOCR is not None:
|
1044 |
mdr_ensure_directory(self._model_path); wp, rp, cp = Path(self._model_path)/"weights.pth", Path(self._model_path)/"image_resizer.pth", Path(self._model_path)/"config.yaml"
|
@@ -1047,7 +1048,7 @@ class MDRLatexExtractor:
|
|
1047 |
try: args = Munch({"config":str(cp), "checkpoint":str(wp), "device":self._device, "no_cuda":self._device=="cpu", "no_resize":False, "temperature":0.0}); self._model = LatexOCR(args); print(f"MDR LaTeX loaded on {self._device}.")
|
1048 |
except Exception as e: print(f"ERROR initializing MDR LatexOCR: {e}"); self._model = None
|
1049 |
return self._model
|
1050 |
-
|
1051 |
def _download(self):
|
1052 |
tag = "v0.0.1"; base = f"https://github.com/lukas-blecher/LaTeX-OCR/releases/download/{tag}/"; files = {"weights.pth": base+"weights.pth", "image_resizer.pth": base+"image_resizer.pth"}
|
1053 |
mdr_ensure_directory(self._model_path); [mdr_download_model(url, Path(self._model_path)/name) for name, url in files.items() if not (Path(self._model_path)/name).exists()]
|
@@ -1057,13 +1058,13 @@ MDRTableOutputFormat = Literal["latex", "markdown", "html"]
|
|
1057 |
|
1058 |
class MDRTableParser:
|
1059 |
"""Parses table structure/content from images using StructTable model."""
|
1060 |
-
|
1061 |
def __init__(self, device: Literal["cpu", "cuda"], model_path: str):
|
1062 |
self._model: Any | None = None; self._model_path = mdr_ensure_directory(model_path)
|
1063 |
self._device = device if torch.cuda.is_available() and device=="cuda" else "cpu"
|
1064 |
self._disabled = self._device == "cpu"
|
1065 |
if self._disabled: print("Warning: MDR Table parsing requires CUDA. Disabled.")
|
1066 |
-
|
1067 |
def parse_table_image(self, image: Image, format: MDRTableLayoutParsedFormat) -> str | None:
|
1068 |
if self._disabled: return None;
|
1069 |
fmt: MDRTableOutputFormat | None = None
|
@@ -1078,7 +1079,7 @@ class MDRTableParser:
|
|
1078 |
with torch.no_grad(): results = model([img_rgb], output_format=fmt)
|
1079 |
return results[0] if results else None
|
1080 |
except Exception as e: print(f"MDR Table parsing error: {e}"); return None
|
1081 |
-
|
1082 |
def _get_model(self):
|
1083 |
if self._model is None and not self._disabled:
|
1084 |
try:
|
@@ -1099,23 +1100,23 @@ class _MDR_RotationContext: to_origin: MDRRotationAdjuster; to_new: MDRRotationA
|
|
1099 |
|
1100 |
class MDRImageOptimizer:
|
1101 |
"""Handles image rotation detection and coordinate adjustments."""
|
1102 |
-
|
1103 |
def __init__(self, raw_image: Image, adjust_points: bool):
|
1104 |
self._raw = raw_image; self._image = raw_image; self._adjust_points = adjust_points
|
1105 |
self._fragments: list[MDROcrFragment] = []; self._rotation: float = 0.0; self._rot_ctx: _MDR_RotationContext | None = None
|
1106 |
-
|
1107 |
@property
|
1108 |
def image(self) -> Image: return self._image
|
1109 |
-
|
1110 |
@property
|
1111 |
def adjusted_image(self) -> Image | None: return self._image if self._rot_ctx is not None else None
|
1112 |
-
|
1113 |
@property
|
1114 |
def rotation(self) -> float: return self._rotation
|
1115 |
-
|
1116 |
@property
|
1117 |
def image_np(self) -> np.ndarray: img_rgb = np.array(self._raw.convert("RGB")); return cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
|
1118 |
-
|
1119 |
def receive_fragments(self, fragments: list[MDROcrFragment]):
|
1120 |
self._fragments = fragments;
|
1121 |
if not fragments: return;
|
@@ -1130,7 +1131,7 @@ class MDRImageOptimizer:
|
|
1130 |
to_new=MDRRotationAdjuster(orig_sz, new_sz, self._rotation, False),
|
1131 |
to_origin=MDRRotationAdjuster(orig_sz, new_sz, self._rotation, True))
|
1132 |
adj = self._rot_ctx.to_new; [setattr(f, 'rect', MDRRectangle(lt=adj.adjust(r.lt), rt=adj.adjust(r.rt), lb=adj.adjust(r.lb), rb=adj.adjust(r.rb))) for f in fragments if (r:=f.rect)]
|
1133 |
-
|
1134 |
def finalize_layout_coords(self, layouts: list[MDRLayoutElement]):
|
1135 |
if self._rot_ctx is None or self._adjust_points: return
|
1136 |
if len(self._fragments) == len(self._rot_ctx.fragment_origin_rectangles): [setattr(f, 'rect', orig_r) for f, orig_r in zip(self._fragments, self._rot_ctx.fragment_origin_rectangles)]
|
@@ -1166,7 +1167,7 @@ def mdr_plot_layout(image: Image, layouts: Iterable[MDRLayoutElement]) -> None:
|
|
1166 |
if not layouts: return;
|
1167 |
try: l_font, f_font = load_default(size=25), load_default(size=15); draw = ImageDraw.Draw(image, mode="RGBA")
|
1168 |
except Exception as e: print(f"MDR Plot init error: {e}"); return
|
1169 |
-
|
1170 |
def _draw_num(pos: MDRPoint, num: int, font: FreeTypeFont, color: _MDR_RGBA):
|
1171 |
try: x,y=pos; txt=str(num); txt_pos=(round(x)+3, round(y)+1); bbox=draw.textbbox(txt_pos,txt,font=font); bg_rect=(bbox[0]-2,bbox[1]-1,bbox[2]+2,bbox[3]+1); bg_color=(color[0],color[1],color[2],180); draw.rectangle(bg_rect,fill=bg_color); draw.text(txt_pos,txt,font=font,fill=(255,255,255,255))
|
1172 |
except Exception as e: print(f"MDR Draw num error: {e}")
|
@@ -1181,7 +1182,7 @@ def mdr_plot_layout(image: Image, layouts: Iterable[MDRLayoutElement]) -> None:
|
|
1181 |
# --- MDR Extraction Engine ---
|
1182 |
class MDRExtractionEngine:
|
1183 |
"""Core engine for extracting structured information from a document image."""
|
1184 |
-
|
1185 |
def __init__(self, model_dir_path: str, device: Literal["cpu", "cuda"]="cpu", ocr_for_each_layouts: bool=True, extract_formula: bool=True, extract_table_format: MDRTableLayoutParsedFormat|None=None):
|
1186 |
self._model_dir = model_dir_path # Base directory for all models
|
1187 |
self._device = device if torch.cuda.is_available() else "cpu"
|
@@ -1236,7 +1237,7 @@ class MDRExtractionEngine:
|
|
1236 |
print("MDR YOLOv10 class not available. Layout detection skipped.")
|
1237 |
|
1238 |
return self._yolo
|
1239 |
-
|
1240 |
def analyze_image(self, image: Image, adjust_points: bool=False) -> MDRExtractionResult:
|
1241 |
"""Analyzes a single page image to extract layout and content."""
|
1242 |
print(" Engine: Analyzing image..."); optimizer = MDRImageOptimizer(image, adjust_points)
|
@@ -1256,7 +1257,7 @@ class MDRExtractionEngine:
|
|
1256 |
print(" Engine: Finalizing coords..."); optimizer.finalize_layout_coords(layouts)
|
1257 |
print(" Engine: Analysis complete.")
|
1258 |
return MDRExtractionResult(rotation=optimizer.rotation, layouts=layouts, extracted_image=image, adjusted_image=optimizer.adjusted_image)
|
1259 |
-
|
1260 |
def _run_yolo_detection(self, img: Image, yolo: YOLOv10) -> Generator[MDRLayoutElement, None, None]:
|
1261 |
img_rgb = img.convert("RGB"); res = yolo.predict(source=img_rgb, imgsz=1024, conf=0.2, device=self._device, verbose=False)
|
1262 |
if not res or not hasattr(res[0], 'boxes') or res[0].boxes is None: return
|
@@ -1270,7 +1271,7 @@ class MDRExtractionEngine:
|
|
1270 |
if cls == MDRLayoutClass.TABLE: yield MDRTableLayoutElement(cls=cls, rect=rect, fragments=[], parsed=None)
|
1271 |
elif cls == MDRLayoutClass.ISOLATE_FORMULA: yield MDRFormulaLayoutElement(cls=cls, rect=rect, fragments=[], latex=None)
|
1272 |
elif cls in MDRPlainLayoutElement.__annotations__['cls'].__args__: yield MDRPlainLayoutElement(cls=cls, rect=rect, fragments=[])
|
1273 |
-
|
1274 |
def _match_fragments_to_layouts(self, frags: list[MDROcrFragment], layouts: list[MDRLayoutElement]) -> list[MDRLayoutElement]:
|
1275 |
if not frags or not layouts: return layouts
|
1276 |
layout_polys = [(Polygon(l.rect) if l.rect.is_valid else None) for l in layouts]
|
@@ -1290,13 +1291,13 @@ class MDRExtractionEngine:
|
|
1290 |
layouts[best_idx].fragments.append(frag)
|
1291 |
for l in layouts: l.fragments.sort(key=lambda f: (f.rect.lt[1], f.rect.lt[0]))
|
1292 |
return layouts
|
1293 |
-
|
1294 |
def _run_ocr_correction(self, img: Image, layouts: list[MDRLayoutElement]):
|
1295 |
for i, l in enumerate(layouts):
|
1296 |
if l.cls == MDRLayoutClass.FIGURE: continue
|
1297 |
try: mdr_correct_layout_fragments(self._ocr_engine, img, l)
|
1298 |
except Exception as e: print(f" Engine: OCR correction error layout {i}: {e}")
|
1299 |
-
|
1300 |
def _parse_special_layouts(self, layouts: list[MDRLayoutElement], optimizer: MDRImageOptimizer):
|
1301 |
img_to_clip = optimizer.image
|
1302 |
for l in layouts:
|
@@ -1307,7 +1308,7 @@ class MDRExtractionEngine:
|
|
1307 |
try: t_img = mdr_clip_from_image(img_to_clip, l.rect); parsed = self._table_parser.parse_table_image(t_img, self._ext_table) if t_img.width>1 and t_img.height>1 else None
|
1308 |
except Exception as e: print(f" Engine: Table parse error: {e}"); parsed = None
|
1309 |
if parsed: l.parsed = (parsed, self._ext_table)
|
1310 |
-
|
1311 |
def _should_keep_layout(self, l: MDRLayoutElement) -> bool:
|
1312 |
if l.fragments and not all(mdr_is_whitespace(f.text) for f in l.fragments): return True
|
1313 |
return l.cls in [MDRLayoutClass.FIGURE, MDRLayoutClass.TABLE, MDRLayoutClass.ISOLATE_FORMULA]
|
@@ -1315,25 +1316,25 @@ class MDRExtractionEngine:
|
|
1315 |
# --- MDR Page Section Linking ---
|
1316 |
class _MDR_LinkedShape:
|
1317 |
"""Internal helper for managing layout linking across pages."""
|
1318 |
-
|
1319 |
def __init__(self, layout: MDRLayoutElement): self.layout=layout; self.pre:list[MDRLayoutElement|None]=[None,None]; self.nex:list[MDRLayoutElement|None]=[None,None]
|
1320 |
-
|
1321 |
@property
|
1322 |
def distance2(self) -> float: x,y=self.layout.rect.lt; return x*x+y*y
|
1323 |
|
1324 |
class MDRPageSection:
|
1325 |
"""Represents a page's layouts for framework detection via linking."""
|
1326 |
-
|
1327 |
def __init__(self, page_index: int, layouts: Iterable[MDRLayoutElement]):
|
1328 |
self._page_index = page_index; self._shapes = [_MDR_LinkedShape(l) for l in layouts]; self._shapes.sort(key=lambda s: (s.layout.rect.lt[1], s.layout.rect.lt[0]))
|
1329 |
-
|
1330 |
@property
|
1331 |
def page_index(self) -> int: return self._page_index
|
1332 |
-
|
1333 |
def find_framework_elements(self) -> list[MDRLayoutElement]:
|
1334 |
"""Identifies framework layouts based on links to other pages."""
|
1335 |
return [s.layout for s in self._shapes if any(s.pre) or any(s.nex)]
|
1336 |
-
|
1337 |
def link_to_next(self, next_section: 'MDRPageSection', offset: int) -> None:
|
1338 |
"""Links matching shapes between this section and the next."""
|
1339 |
if offset not in (1,2): return
|
@@ -1364,12 +1365,12 @@ class MDRPageSection:
|
|
1364 |
if max_sim > 0.75: matches += 1; if best_j != -1: used_f2[best_j] = True
|
1365 |
max_c = max(c1, c2); rate_frags = matches / max_c
|
1366 |
return self._check_match_threshold(rate_frags, max_c, (0.0, 0.45, 0.45, 0.6, 0.8, 0.95))
|
1367 |
-
|
1368 |
def _fragment_sim(self, l1: MDRLayoutElement, l2: MDRLayoutElement, f1: MDROcrFragment, f2: MDROcrFragment) -> float:
|
1369 |
r1_rel = self._relative_rect(l1.rect.lt, f1.rect); r2_rel = self._relative_rect(l2.rect.lt, f2.rect)
|
1370 |
geom_sim = self._symmetric_iou(r1_rel, r2_rel); text_sim, _ = mdr_check_text_similarity(f1.text, f2.text)
|
1371 |
return (geom_sim + text_sim) / 2.0
|
1372 |
-
|
1373 |
def _find_origin_pair(self, matches_matrix: list[list[_MDR_LinkedShape]], next_shapes: list[_MDR_LinkedShape]) -> tuple[_MDR_LinkedShape, _MDR_LinkedShape] | None:
|
1374 |
best_pair, min_dist2 = None, float('inf')
|
1375 |
for i, s1 in enumerate(self._shapes):
|
@@ -1377,13 +1378,13 @@ class MDRPageSection:
|
|
1377 |
if not match_list: continue
|
1378 |
for s2 in match_list: dist2 = s1.distance2 + s2.distance2; if dist2 < min_dist2: min_dist2, best_pair = dist2, (s1, s2)
|
1379 |
return best_pair
|
1380 |
-
|
1381 |
def _check_match_threshold(self, rate: float, count: int, thresholds: Sequence[float]) -> bool:
|
1382 |
if not thresholds: return False; idx = min(count, len(thresholds)-1); return rate >= thresholds[idx]
|
1383 |
-
|
1384 |
def _relative_rect(self, origin: MDRPoint, rect: MDRRectangle) -> MDRRectangle:
|
1385 |
ox, oy = origin; r=rect; return MDRRectangle(lt=(r.lt[0]-ox, r.lt[1]-oy), rt=(r.rt[0]-ox, r.rt[1]-oy), lb=(r.lb[0]-ox, r.lb[1]-oy), rb=(r.rb[0]-ox, r.rb[1]-oy))
|
1386 |
-
|
1387 |
def _symmetric_iou(self, r1: MDRRectangle, r2: MDRRectangle) -> float:
|
1388 |
try: p1, p2 = Polygon(r1), Polygon(r2);
|
1389 |
except: return 0.0
|
@@ -1400,19 +1401,19 @@ _MDR_CONTEXT_PAGES = 2 # Look behind/ahead pages for context
|
|
1400 |
class MDRProcessingParams:
|
1401 |
"""Parameters for processing a document."""
|
1402 |
pdf: str | FitzDocument; page_indexes: Iterable[int] | None; report_progress: MDRProgressReportCallback | None
|
1403 |
-
|
1404 |
class MDRDocumentIterator:
|
1405 |
"""Iterates through document pages, handles context, and calls the extraction engine."""
|
1406 |
-
|
1407 |
def __init__(self, device: Literal["cpu", "cuda"], model_dir_path: str, ocr_level: MDROcrLevel, extract_formula: bool, extract_table_format: MDRTableLayoutParsedFormat | None, debug_dir_path: str | None):
|
1408 |
self._debug_dir = debug_dir_path
|
1409 |
self._engine = MDRExtractionEngine(device=device, model_dir_path=model_dir_path, ocr_for_each_layouts=(ocr_level==MDROcrLevel.OncePerLayout), extract_formula=extract_formula, extract_table_format=extract_table_format)
|
1410 |
-
|
1411 |
def iterate_sections(self, params: MDRProcessingParams) -> Generator[tuple[int, MDRExtractionResult, list[MDRLayoutElement]], None, None]:
|
1412 |
"""Yields page index, extraction result, and content layouts for each requested page."""
|
1413 |
for res, sec in self._process_and_link_sections(params):
|
1414 |
framework = set(sec.find_framework_elements()); content = [l for l in res.layouts if l not in framework]; yield sec.page_index, res, content
|
1415 |
-
|
1416 |
def _process_and_link_sections(self, params: MDRProcessingParams) -> Generator[tuple[MDRExtractionResult, MDRPageSection], None, None]:
|
1417 |
queue: list[tuple[MDRExtractionResult, MDRPageSection]] = []
|
1418 |
for page_idx, res in self._run_extraction_on_pages(params):
|
@@ -1423,7 +1424,7 @@ class MDRDocumentIterator:
|
|
1423 |
queue.append((res, cur_sec))
|
1424 |
if len(queue) > _MDR_CONTEXT_PAGES: yield queue.pop(0)
|
1425 |
for res, sec in queue: yield res, sec
|
1426 |
-
|
1427 |
def _run_extraction_on_pages(self, params: MDRProcessingParams) -> Generator[tuple[int, MDRExtractionResult], None, None]:
|
1428 |
if self._debug_dir: mdr_ensure_directory(self._debug_dir)
|
1429 |
doc, should_close = None, False
|
@@ -1446,7 +1447,7 @@ class MDRDocumentIterator:
|
|
1446 |
except Exception as e: print(f" Iterator: Page {page_idx+1} processing error: {e}")
|
1447 |
finally:
|
1448 |
if should_close and doc: doc.close()
|
1449 |
-
|
1450 |
def _get_page_ranges(self, doc: FitzDocument, idxs: Iterable[int]|None) -> tuple[Sequence[int], Sequence[int]]:
|
1451 |
count = doc.page_count;
|
1452 |
if idxs is None: all_p = list(range(count)); return all_p, all_p
|
@@ -1454,11 +1455,11 @@ class MDRDocumentIterator:
|
|
1454 |
for i in idxs:
|
1455 |
if 0<=i<count: enable.add(i); [scan.add(j) for j in range(max(0, i-_MDR_CONTEXT_PAGES), min(count, i+_MDR_CONTEXT_PAGES+1))]
|
1456 |
return sorted(list(scan)), sorted(list(enable))
|
1457 |
-
|
1458 |
def _render_page_image(self, page: FitzPage, dpi: int) -> Image:
|
1459 |
mat = FitzMatrix(dpi/72.0, dpi/72.0); pix = page.get_pixmap(matrix=mat, alpha=False)
|
1460 |
return frombytes("RGB", (pix.width, pix.height), pix.samples)
|
1461 |
-
|
1462 |
def _save_debug_plot(self, img: Image, idx: int, res: MDRExtractionResult, path: str):
|
1463 |
try: plot_img = res.adjusted_image.copy() if res.adjusted_image else img.copy(); mdr_plot_layout(plot_img, res.layouts); plot_img.save(os.path.join(path, f"mdr_plot_page_{idx+1}.png"))
|
1464 |
except Exception as e: print(f" Iterator: Plot generation error page {idx+1}: {e}")
|
@@ -1469,7 +1470,7 @@ class MagicPDFProcessor:
|
|
1469 |
Main class for processing PDF documents to extract structured data blocks
|
1470 |
using the MagicDataReadiness pipeline.
|
1471 |
"""
|
1472 |
-
|
1473 |
def __init__(self, device: Literal["cpu", "cuda"]="cuda", model_dir_path: str="./mdr_models", ocr_level: MDROcrLevel=MDROcrLevel.Once, extract_formula: bool=True, extract_table_format: MDRExtractedTableFormat|None=None, debug_dir_path: str|None=None):
|
1474 |
"""
|
1475 |
Initializes the MagicPDFProcessor.
|
@@ -1546,22 +1547,68 @@ class MagicPDFProcessor:
|
|
1546 |
self._assign_relative_font_sizes(temp_store);
|
1547 |
return [block for _, block in temp_store]
|
1548 |
|
|
|
1549 |
def _analyze_paragraph_structure(self, blocks: list[MDRStructuredBlock]):
|
1550 |
-
"""
|
1551 |
-
|
1552 |
-
|
1553 |
-
|
1554 |
-
|
1555 |
-
|
1556 |
-
|
1557 |
-
|
1558 |
-
|
1559 |
-
|
1560 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1561 |
except Exception as e:
|
|
|
1562 |
print(f"Warn: Error calculating paragraph structure for block: {e}")
|
1563 |
-
|
|
|
|
|
1564 |
|
|
|
1565 |
|
1566 |
def _calculate_text_range(self, blocks_iter: Iterable[MDRStructuredBlock]) -> tuple[float, float, float]:
|
1567 |
"""Calculates average line height and min/max x-coordinates for text."""
|
@@ -1639,7 +1686,19 @@ if __name__ == '__main__':
|
|
1639 |
|
1640 |
# Path to the PDF file you want to process
|
1641 |
# IMPORTANT: Place a PDF file here for testing!
|
|
|
1642 |
MDR_INPUT_PDF = "example_input.pdf" # <--- CHANGE THIS
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1643 |
|
1644 |
# Optional: Directory to save debug plots (set to None to disable)
|
1645 |
MDR_DEBUG_DIRECTORY = "./mdr_debug_output"
|
@@ -1718,7 +1777,7 @@ if __name__ == '__main__':
|
|
1718 |
info = f" - Block {block_idx+1}: {type(block).__name__}"
|
1719 |
if isinstance(block, MDRTextBlock):
|
1720 |
preview = block.texts[0].content[:70].replace('\n',' ') + "..." if block.texts else "[EMPTY]"
|
1721 |
-
info += f" (Kind: {block.kind.name}, FontSz: {block.font_size:.2f}) | Text: '{preview}'"
|
1722 |
elif isinstance(block, MDRTableBlock):
|
1723 |
info += f" (Format: {block.format.name}, HasContent: {bool(block.content)}, FontSz: {block.font_size:.2f})"
|
1724 |
# if block.content: print(f" Content:\n{block.content}") # Uncomment to see content
|
|
|
46 |
import onnxruntime
|
47 |
# --- HUGGING FACE HUB IMPORT ONLY BECAUSE RUNNING IN SPACES NOT NECESSARY IN PROD ---
|
48 |
from huggingface_hub import hf_hub_download, HfHubDownloadError
|
49 |
+
import time # Added for example usage timing
|
50 |
|
51 |
# --- External Dependencies ---
|
52 |
try:
|
|
|
60 |
print("Warning: Could not import LatexOCR from pix2tex.cli. LaTeX extraction will fail.")
|
61 |
LatexOCR = None
|
62 |
try:
|
63 |
+
pass # from struct_eqtable import build_model # Keep commented as per original
|
64 |
except ImportError:
|
65 |
print("Warning: Could not import build_model from struct_eqtable. Table parsing might fail.")
|
66 |
|
|
|
304 |
# --- MDR Geometry Processing ---
|
305 |
class MDRRotationAdjuster:
|
306 |
"""Adjusts point coordinates based on image rotation."""
|
307 |
+
|
308 |
def __init__(self, origin_size: tuple[int, int], new_size: tuple[int, int], rotation: float, to_origin_coordinate: bool):
|
309 |
fs, ts = (new_size, origin_size) if to_origin_coordinate else (origin_size, new_size)
|
310 |
self._rot = rotation if to_origin_coordinate else -rotation
|
311 |
self._c_off = (fs[0]/2.0, fs[1]/2.0); self._n_off = (ts[0]/2.0, ts[1]/2.0)
|
312 |
+
|
313 |
def adjust(self, point: MDRPoint) -> MDRPoint:
|
314 |
x, y = point[0]-self._c_off[0], point[1]-self._c_off[1]
|
315 |
if x!=0 or y!=0: cos_r, sin_r = cos(self._rot), sin(self._rot); x, y = x*cos_r-y*sin_r, x*sin_r+y*cos_r
|
|
|
363 |
# --- MDR ONNX OCR Internals ---
|
364 |
class _MDR_PredictBase:
|
365 |
"""Base class for ONNX model prediction components."""
|
366 |
+
|
367 |
def get_onnx_session(self, model_path: str, use_gpu: bool):
|
368 |
try:
|
369 |
sess_opts = onnxruntime.SessionOptions(); sess_opts.log_severity_level = 3
|
|
|
376 |
if use_gpu and 'CUDAExecutionProvider' not in onnxruntime.get_available_providers():
|
377 |
print(" CUDAExecutionProvider not available. Check ONNXRuntime-GPU installation and CUDA setup.")
|
378 |
raise e
|
379 |
+
|
380 |
def get_output_name(self, sess: onnxruntime.InferenceSession) -> List[str]: return [n.name for n in sess.get_outputs()]
|
381 |
+
|
382 |
def get_input_name(self, sess: onnxruntime.InferenceSession) -> List[str]: return [n.name for n in sess.get_inputs()]
|
383 |
+
|
384 |
def get_input_feed(self, names: List[str], img_np: np.ndarray) -> Dict[str, np.ndarray]: return {name: img_np for name in names}
|
385 |
|
386 |
# --- MDR ONNX OCR Internals ---
|
387 |
class _MDR_NormalizeImage:
|
388 |
+
|
389 |
def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
|
390 |
self.scale = np.float32(eval(scale) if isinstance(scale, str) else (scale if scale is not None else 1.0 / 255.0))
|
391 |
mean = mean if mean is not None else [0.485, 0.456, 0.406]; std = std if std is not None else [0.229, 0.224, 0.225]
|
392 |
shape = (3, 1, 1) if order == 'chw' else (1, 1, 3); self.mean = np.array(mean).reshape(shape).astype('float32'); self.std = np.array(std).reshape(shape).astype('float32')
|
393 |
+
|
394 |
def __call__(self, data): img = data['image']; img = np.array(img) if isinstance(img, Image) else img; data['image'] = (img.astype('float32') * self.scale - self.mean) / self.std; return data
|
395 |
|
396 |
class _MDR_DetResizeForTest:
|
397 |
+
|
398 |
def __init__(self, **kwargs):
|
399 |
self.resize_type = 0; self.keep_ratio = False
|
400 |
if 'image_shape' in kwargs: self.image_shape = kwargs['image_shape']; self.resize_type = 1; self.keep_ratio = kwargs.get('keep_ratio', False)
|
401 |
elif 'limit_side_len' in kwargs: self.limit_side_len = kwargs['limit_side_len']; self.limit_type = kwargs.get('limit_type', 'min')
|
402 |
elif 'resize_long' in kwargs: self.resize_type = 2; self.resize_long = kwargs.get('resize_long', 960)
|
403 |
else: self.limit_side_len = 736; self.limit_type = 'min'
|
404 |
+
|
405 |
def __call__(self, data):
|
406 |
img = data['image']; src_h, src_w, _ = img.shape
|
407 |
if src_h + src_w < 64: img = self._pad(img)
|
|
|
410 |
else: img, ratios = self._resize1(img)
|
411 |
if img is None: return None
|
412 |
data['image'] = img; data['shape'] = np.array([src_h, src_w, ratios[0], ratios[1]]); return data
|
413 |
+
|
414 |
def _pad(self, im, v=0): h,w,c=im.shape; p=np.zeros((max(32,h),max(32,w),c),np.uint8)+v; p[:h,:w,:]=im; return p
|
415 |
+
|
416 |
+
def _resize1(self, img): rh,rw=self.image_shape; oh,ow=img.shape[:2]; if self.keep_ratio: rw=ow*rh/oh; N=ceil(rw/32); rw=N*32; r_h,r_w=float(rh)/oh,float(rw)/ow; img=cv2.resize(img,(int(rw),int(rh))); return img,[r_h,r_w]
|
417 |
+
|
418 |
def _resize0(self, img): lsl=self.limit_side_len; h,w,_=img.shape; r=1.0; if self.limit_type=='max': r=float(lsl)/max(h,w) if max(h,w)>lsl else 1.0; elif self.limit_type=='min': r=float(lsl)/min(h,w) if min(h,w)<lsl else 1.0; elif self.limit_type=='resize_long': r=float(lsl)/max(h,w); else: raise Exception('Unsupported'); rh,rw=int(h*r),int(w*r); rh=max(int(round(rh/32)*32),32); rw=max(int(round(rw/32)*32),32); if int(rw)<=0 or int(rh)<=0: return None,(None,None); img=cv2.resize(img,(int(rw),int(rh))); r_h,r_w=rh/float(h),rw/float(w); return img,[r_h,r_w]
|
419 |
+
|
420 |
def _resize2(self, img): h,w,_=img.shape; rl=self.resize_long; r=float(rl)/max(h,w); rh,rw=int(h*r),int(w*r); ms=128; rh=(rh+ms-1)//ms*ms; rw=(rw+ms-1)//ms*ms; img=cv2.resize(img,(int(rw),int(rh))); r_h,r_w=rh/float(h),rw/float(w); return img,[r_h,r_w]
|
421 |
|
422 |
class _MDR_ToCHWImage:
|
423 |
+
|
424 |
def __call__(self, data): img=data['image']; img=np.array(img) if isinstance(img,Image) else img; data['image']=img.transpose((2,0,1)); return data
|
425 |
|
426 |
class _MDR_KeepKeys:
|
427 |
+
|
428 |
def __init__(self, keep_keys, **kwargs): self.keep_keys=keep_keys
|
429 |
+
|
430 |
def __call__(self, data): return [data[key] for key in self.keep_keys]
|
431 |
|
432 |
def mdr_ocr_transform(data, ops=None):
|
|
|
446 |
return ops
|
447 |
|
448 |
class _MDR_DBPostProcess:
|
449 |
+
|
450 |
def __init__(self, thresh=0.3, box_thresh=0.7, max_candidates=1000, unclip_ratio=1.5, use_dilation=False, score_mode="fast", box_type='quad', **kwargs):
|
451 |
self.thresh, self.box_thresh, self.max_cand = thresh, box_thresh, max_candidates; self.unclip_r, self.min_sz, self.score_m, self.box_t = unclip_ratio, 3, score_mode, box_type
|
452 |
assert score_mode in ["slow", "fast"]; self.dila_k = np.array([[1,1],[1,1]], dtype=np.uint8) if use_dilation else None
|
453 |
+
|
454 |
def _polygons_from_bitmap(self, pred, bmp, dw, dh):
|
455 |
h, w = bmp.shape; boxes, scores = [], []
|
456 |
contours, _ = cv2.findContours((bmp*255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
|
|
|
467 |
box = np.array(box); box[:,0]=np.clip(np.round(box[:,0]/w*dw),0,dw); box[:,1]=np.clip(np.round(box[:,1]/h*dh),0,dh)
|
468 |
boxes.append(box.tolist()); scores.append(score)
|
469 |
return boxes, scores
|
470 |
+
|
471 |
def _boxes_from_bitmap(self, pred, bmp, dw, dh):
|
472 |
h, w = bmp.shape; contours, _ = cv2.findContours((bmp*255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
|
473 |
num_contours = min(len(contours), self.max_cand); boxes, scores = [], []
|
|
|
483 |
box = np.array(box); box[:,0]=np.clip(np.round(box[:,0]/w*dw),0,dw); box[:,1]=np.clip(np.round(box[:,1]/h*dh),0,dh)
|
484 |
boxes.append(box.astype("int32")); scores.append(score)
|
485 |
return np.array(boxes, dtype="int32"), scores
|
486 |
+
|
487 |
def _unclip(self, box, ratio):
|
488 |
poly = Polygon(box); dist = poly.area*ratio/poly.length; offset = pyclipper.PyclipperOffset(); offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
|
489 |
expanded = offset.Execute(dist);
|
490 |
if not expanded: raise ValueError("Unclip failed"); return np.array(expanded[0])
|
491 |
+
|
492 |
def _get_mini_boxes(self, contour):
|
493 |
bb = cv2.minAreaRect(contour); pts = sorted(list(cv2.boxPoints(bb)), key=lambda x:x[0])
|
494 |
i1,i4 = (0,1) if pts[1][1]>pts[0][1] else (1,0); i2,i3 = (2,3) if pts[3][1]>pts[2][1] else (3,2)
|
495 |
box = [pts[i1], pts[i2], pts[i3], pts[i4]]; return box, min(bb[1])
|
496 |
+
|
497 |
def _box_score_fast(self, bmp, box):
|
498 |
h,w = bmp.shape[:2]; xmin=np.clip(np.floor(box[:,0].min()).astype("int32"),0,w-1); xmax=np.clip(np.ceil(box[:,0].max()).astype("int32"),0,w-1)
|
499 |
ymin=np.clip(np.floor(box[:,1].min()).astype("int32"),0,h-1); ymax=np.clip(np.ceil(box[:,1].max()).astype("int32"),0,h-1)
|
500 |
mask = np.zeros((ymax-ymin+1, xmax-xmin+1), dtype=np.uint8); box[:,0]-=xmin; box[:,1]-=ymin
|
501 |
cv2.fillPoly(mask, box.reshape(1,-1,2).astype("int32"), 1);
|
502 |
return cv2.mean(bmp[ymin:ymax+1, xmin:xmax+1], mask)[0] if np.sum(mask)>0 else 0.0
|
503 |
+
|
504 |
def _box_score_slow(self, bmp, contour): # Not used if fast
|
505 |
h,w = bmp.shape[:2]; contour = np.reshape(contour.copy(),(-1,2)); xmin=np.clip(np.min(contour[:,0]),0,w-1); xmax=np.clip(np.max(contour[:,0]),0,w-1)
|
506 |
ymin=np.clip(np.min(contour[:,1]),0,h-1); ymax=np.clip(np.max(contour[:,1]),0,h-1); mask=np.zeros((ymax-ymin+1,xmax-xmin+1),dtype=np.uint8)
|
507 |
contour[:,0]-=xmin; contour[:,1]-=ymin; cv2.fillPoly(mask, contour.reshape(1,-1,2).astype("int32"), 1);
|
508 |
return cv2.mean(bmp[ymin:ymax+1, xmin:xmax+1], mask)[0] if np.sum(mask)>0 else 0.0
|
509 |
+
|
510 |
def __call__(self, outs_dict, shape_list):
|
511 |
pred = outs_dict['maps'][:,0,:,:]; seg = pred > self.thresh; boxes_batch = []
|
512 |
for batch_idx in range(pred.shape[0]):
|
|
|
518 |
return boxes_batch
|
519 |
|
520 |
class _MDR_TextDetector(_MDR_PredictBase):
|
521 |
+
|
522 |
def __init__(self, args):
|
523 |
super().__init__(); self.args = args
|
524 |
pre_ops = [{'DetResizeForTest': {'limit_side_len': args.det_limit_side_len, 'limit_type': args.det_limit_type}}, {'NormalizeImage': {'std': [0.229,0.224,0.225], 'mean': [0.485,0.456,0.406], 'scale': '1./255.', 'order': 'hwc'}}, {'ToCHWImage': None}, {'KeepKeys': {'keep_keys': ['image', 'shape']}}]
|
|
|
527 |
self.post_op = _MDR_DBPostProcess(**post_params)
|
528 |
self.sess = self.get_onnx_session(args.det_model_dir, args.use_gpu)
|
529 |
self.input_name = self.get_input_name(self.sess); self.output_name = self.get_output_name(self.sess)
|
530 |
+
|
531 |
def _order_pts(self, pts): r=np.zeros((4,2),dtype="float32"); s=pts.sum(axis=1); r[0]=pts[np.argmin(s)]; r[2]=pts[np.argmax(s)]; tmp=np.delete(pts,(np.argmin(s),np.argmax(s)),axis=0); d=np.diff(np.array(tmp),axis=1); r[1]=tmp[np.argmin(d)]; r[3]=tmp[np.argmax(d)]; return r
|
532 |
+
|
533 |
def _clip_pts(self, pts, h, w): pts[:,0]=np.clip(pts[:,0],0,w-1); pts[:,1]=np.clip(pts[:,1],0,h-1); return pts
|
534 |
+
|
535 |
def _filter_quad(self, boxes, shape): h,w=shape[0:2]; new_boxes=[]; for box in boxes: box=np.array(box) if isinstance(box,list) else box; box=self._order_pts(box); box=self._clip_pts(box,h,w); rw=int(np.linalg.norm(box[0]-box[1])); rh=int(np.linalg.norm(box[0]-box[3])); if rw<=3 or rh<=3: continue; new_boxes.append(box); return np.array(new_boxes)
|
536 |
+
|
537 |
def _filter_poly(self, boxes, shape): h,w=shape[0:2]; new_boxes=[]; for box in boxes: box=np.array(box) if isinstance(box,list) else box; box=self._clip_pts(box,h,w); if Polygon(box).area<10: continue; new_boxes.append(box); return np.array(new_boxes)
|
538 |
+
|
539 |
def __call__(self, img):
|
540 |
ori_im = img.copy(); data = {"image": img}; data = mdr_ocr_transform(data, self.pre_op)
|
541 |
if data is None: return None; img, shape_list = data;
|
|
|
545 |
return self._filter_poly(boxes, ori_im.shape) if self.args.det_box_type=='poly' else self._filter_quad(boxes, ori_im.shape)
|
546 |
|
547 |
class _MDR_ClsPostProcess:
|
548 |
+
|
549 |
def __init__(self, label_list=None, **kwargs): self.labels = label_list if label_list else {0:'0', 1:'180'}
|
550 |
+
|
551 |
def __call__(self, preds, label=None, *args, **kwargs):
|
552 |
preds = np.array(preds) if not isinstance(preds, np.ndarray) else preds; idxs = preds.argmax(axis=1)
|
553 |
return [(self.labels[idx], float(preds[i,idx])) for i,idx in enumerate(idxs)]
|
554 |
|
555 |
class _MDR_TextClassifier(_MDR_PredictBase):
|
556 |
+
|
557 |
def __init__(self, args):
|
558 |
super().__init__(); self.shape = tuple(map(int, args.cls_image_shape.split(','))) if isinstance(args.cls_image_shape, str) else args.cls_image_shape
|
559 |
self.batch_num = args.cls_batch_num; self.thresh = args.cls_thresh; self.post_op = _MDR_ClsPostProcess(label_list=args.label_list)
|
560 |
self.sess = self.get_onnx_session(args.cls_model_dir, args.use_gpu); self.input_name = self.get_input_name(self.sess); self.output_name = self.get_output_name(self.sess)
|
561 |
+
|
562 |
def _resize_norm(self, img):
|
563 |
+
imgC,imgH,imgW = self.shape; h,w = img.shape[:2]; r=w/float(h) if h>0 else 0; rw=int(ceil(imgH*r)); rw=min(rw,imgW)
|
564 |
resized = cv2.resize(img,(rw,imgH)); resized = resized.astype("float32")
|
565 |
if imgC==1: resized = resized/255.0; resized = resized[np.newaxis,:]
|
566 |
else: resized = resized.transpose((2,0,1))/255.0
|
567 |
resized -= 0.5; resized /= 0.5; padding = np.zeros((imgC,imgH,imgW),dtype=np.float32); padding[:,:,0:rw]=resized; return padding
|
568 |
+
|
569 |
def __call__(self, img_list):
|
570 |
if not img_list: return img_list, []; img_list_cp = copy.deepcopy(img_list); num = len(img_list_cp)
|
571 |
ratios = [img.shape[1]/float(img.shape[0]) if img.shape[0]>0 else 0 for img in img_list_cp]; indices = np.argsort(np.array(ratios))
|
|
|
582 |
return img_list, results
|
583 |
|
584 |
class _MDR_BaseRecLabelDecode:
|
585 |
+
|
586 |
def __init__(self, char_path=None, use_space=False):
|
587 |
self.beg, self.end, self.rev = "sos", "eos", False; self.chars = []
|
588 |
if char_path is None: self.chars = list("0123456789abcdefghijklmnopqrstuvwxyz")
|
|
|
593 |
if any("\u0600"<=c<="\u06FF" for c in self.chars): self.rev=True
|
594 |
except FileNotFoundError: print(f"Warn: Dict not found {char_path}"); self.chars=list("0123456789abcdefghijklmnopqrstuvwxyz"); if use_space: self.chars.append(" ")
|
595 |
d_char = self.add_special_char(list(self.chars)); self.dict={c:i for i,c in enumerate(d_char)}; self.character=d_char
|
596 |
+
|
597 |
def add_special_char(self, chars): return chars
|
598 |
+
|
599 |
def get_ignored_tokens(self): return []
|
600 |
+
|
601 |
def _reverse(self, pred): res=[]; cur=""; for c in pred: if not re.search("[a-zA-Z0-9 :*./%+-]",c): res.extend([cur,c] if cur!="" else [c]); cur="" else: cur+=c; if cur!="": res.append(cur); return "".join(res[::-1])
|
602 |
+
|
603 |
def decode(self, idxs, probs=None, remove_dup=False):
|
604 |
res=[]; ignored=self.get_ignored_tokens(); bs=len(idxs)
|
605 |
for b_idx in range(bs):
|
|
|
629 |
self.batch_num=getattr(args,'rec_batch_num',6); self.algo=getattr(args,'rec_algorithm','SVTR_LCNet')
|
630 |
self.post_op=_MDR_CTCLabelDecode(char_path=args.rec_char_dict_path, use_space=getattr(args,'use_space_char',True))
|
631 |
self.sess=self.get_onnx_session(args.rec_model_dir, args.use_gpu); self.input_name=self.get_input_name(self.sess); self.output_name=self.get_output_name(self.sess)
|
632 |
+
|
633 |
def _resize_norm(self, img, max_r):
|
634 |
imgC,imgH,imgW = self.shape; h,w = img.shape[:2];
|
635 |
if h==0 or w==0: return np.zeros((imgC,imgH,imgW),dtype=np.float32)
|
636 |
+
r=w/float(h); tw=min(imgW, int(ceil(imgH*max(r,max_r)))); tw=max(1,tw)
|
637 |
resized=cv2.resize(img,(tw,imgH)); resized=resized.astype("float32")
|
638 |
if imgC==1 and len(resized.shape)==3: resized=cv2.cvtColor(resized,cv2.COLOR_BGR2GRAY); resized=resized[:,:,np.newaxis]
|
639 |
if len(resized.shape)==2: resized=resized[:,:,np.newaxis]
|
640 |
resized=resized.transpose((2,0,1))/255.0; resized-=0.5; resized/=0.5
|
641 |
padding=np.zeros((imgC,imgH,imgW),dtype=np.float32); padding[:,:,0:tw]=resized; return padding
|
642 |
+
|
643 |
def __call__(self, img_list):
|
644 |
if not img_list: return []; num=len(img_list); ratios=[img.shape[1]/float(img.shape[0]) if img.shape[0]>0 else 0 for img in img_list]
|
645 |
indices=np.argsort(np.array(ratios)); results=[["",0.0]]*num; batch_n=self.batch_num
|
|
|
655 |
|
656 |
# --- MDR ONNX OCR System ---
|
657 |
class _MDR_TextSystem:
|
658 |
+
|
659 |
def __init__(self, args):
|
660 |
class ArgsObject: # Helper to access dict args with dot notation
|
661 |
def __init__(self, **entries): self.__dict__.update(entries)
|
|
|
667 |
self.drop_score = getattr(args, 'drop_score', 0.5)
|
668 |
self.classifier = _MDR_TextClassifier(args) if self.use_cls else None
|
669 |
self.crop_idx = 0; self.save_crop = getattr(args, 'save_crop_res', False); self.crop_dir = getattr(args, 'crop_res_save_dir', "./output/mdr_crop_res")
|
670 |
+
|
671 |
def _sort_boxes(self, boxes):
|
672 |
if boxes is None or len(boxes)==0: return []
|
673 |
def key(box): min_y=min(p[1] for p in box); min_x=min(p[0] for p in box); return (min_y, min_x)
|
674 |
try: return list(sorted(boxes, key=key))
|
675 |
except: return list(boxes) # Fallback
|
676 |
+
|
677 |
def __call__(self, img, classify=True):
|
678 |
ori_im = img.copy(); boxes = self.detector(img)
|
679 |
if boxes is None or len(boxes)==0: return [], []
|
|
|
695 |
if score >= self.drop_score: final_boxes.append(box); final_rec.append(res)
|
696 |
if self.save_crop: self._save_crops(crops, rec_res)
|
697 |
return final_boxes, final_rec
|
698 |
+
|
699 |
def _save_crops(self, crops, recs):
|
700 |
mdr_ensure_directory(self.crop_dir); num = len(crops)
|
701 |
for i in range(num): txt, score = recs[i]; safe=re.sub(r'\W+', '_', txt)[:20]; fname=f"crop_{self.crop_idx+i}_{safe}_{score:.2f}.jpg"; cv2.imwrite(os.path.join(self.crop_dir, fname), crops[i])
|
|
|
723 |
_MDR_INCLUDES_MIN_RATE = 0.99
|
724 |
|
725 |
class _MDR_OverlapMatrixContext:
|
726 |
+
|
727 |
def __init__(self, layouts: list[MDRLayoutElement]):
|
728 |
length = len(layouts); self.polys: list[Polygon|None] = []
|
729 |
for l in layouts:
|
|
|
737 |
p2 = self.polys[j];
|
738 |
if p2 is None: continue
|
739 |
r_ij = self._rate(p1, p2); r_ji = self._rate(p2, p1); self.matrix[i][j]=r_ij; self.matrix[j][i]=r_ji
|
740 |
+
|
741 |
def _rate(self, p1: Polygon, p2: Polygon) -> float: # Rate p1 covers p2
|
742 |
try: inter = p1.intersection(p2);
|
743 |
except: return 0.0
|
|
|
746 |
_, _, px1, py1 = p2.bounds; pw, ph = px1-p2.bounds[0], py1-p2.bounds[1]
|
747 |
if pw < 1e-6 or ph < 1e-6: return 0.0
|
748 |
wr = min(iw/pw, 1.0); hr = min(ih/ph, 1.0); return (wr+hr)/2.0
|
749 |
+
|
750 |
def others(self, idx: int):
|
751 |
for i, r in enumerate(self.matrix[idx]):
|
752 |
if i != idx and i not in self.removed: yield r
|
753 |
+
|
754 |
def includes(self, idx: int): # Layouts included BY idx
|
755 |
for i, r in enumerate(self.matrix[idx]):
|
756 |
if i != idx and i not in self.removed and r >= _MDR_INCLUDES_MIN_RATE:
|
|
|
865 |
|
866 |
class MDROcrEngine:
|
867 |
"""Handles OCR detection and recognition using ONNX models."""
|
868 |
+
|
869 |
def __init__(self, device: Literal["cpu", "cuda"], model_dir_path: str):
|
870 |
self._device = device; self._model_dir = mdr_ensure_directory(model_dir_path)
|
871 |
self._text_system: _MDR_TextSystem | None = None; self._onnx_params: _MDR_ONNXParams | None = None
|
872 |
self._ensure_models(); self._get_system() # Init on creation
|
873 |
+
|
874 |
def _ensure_models(self):
|
875 |
for key, parts in _MDR_OCR_MODELS.items():
|
876 |
fp = Path(self._model_dir) / Path(*parts)
|
877 |
if not fp.exists(): print(f"Downloading MDR OCR model: {fp.name}..."); url = _MDR_OCR_URL_BASE + "/".join(parts); mdr_download_model(url, fp)
|
878 |
+
|
879 |
def _get_system(self) -> _MDR_TextSystem | None:
|
880 |
if self._text_system is None:
|
881 |
paths = {k: str(Path(self._model_dir)/Path(*p)) for k,p in _MDR_OCR_MODELS.items()}
|
|
|
883 |
try: self._text_system = _MDR_TextSystem(self._onnx_params); print(f"MDR OCR System initialized.")
|
884 |
except Exception as e: print(f"ERROR initializing MDR OCR System: {e}"); self._text_system = None
|
885 |
return self._text_system
|
886 |
+
|
887 |
def find_text_fragments(self, image_np: np.ndarray) -> Generator[MDROcrFragment, None, None]:
|
888 |
"""Finds and recognizes text fragments in a NumPy image (BGR)."""
|
889 |
system = self._get_system()
|
|
|
896 |
if not txt or mdr_is_whitespace(txt) or conf < 0.1: continue
|
897 |
pts = [(float(p[0]), float(p[1])) for p in box_pts]
|
898 |
if len(pts)==4: r=MDRRectangle(lt=pts[0], rt=pts[1], rb=pts[2], lb=pts[3]); if r.is_valid and r.area>1: yield MDROcrFragment(order=-1, text=txt, rank=float(conf), rect=r)
|
899 |
+
|
900 |
def _preprocess(self, img: np.ndarray) -> np.ndarray:
|
901 |
if len(img.shape)==3 and img.shape[2]==4: a=img[:,:,3]/255.0; bg=(255,255,255); new=np.zeros_like(img[:,:,:3]); [setattr(new[:,:,i], 'flags.writeable', True) for i in range(3)]; [np.copyto(new[:,:,i], (bg[i]*(1-a)+img[:,:,i]*a)) for i in range(3)]; img=new.astype(np.uint8)
|
902 |
elif len(img.shape)==2: img=cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
|
|
937 |
|
938 |
class MDRLayoutReader:
|
939 |
"""Determines reading order of layout elements using LayoutLMv3."""
|
940 |
+
|
941 |
def __init__(self, model_path: str):
|
942 |
self._model_path = model_path; self._model: LayoutLMv3ForTokenClassification | None = None
|
943 |
self._device = "cuda" if torch.cuda.is_available() else "cpu"
|
944 |
+
|
945 |
def _get_model(self) -> LayoutLMv3ForTokenClassification | None:
|
946 |
if self._model is None:
|
947 |
cache = mdr_ensure_directory(self._model_path); name = "microsoft/layoutlmv3-base"; h_path = os.path.join(cache, "models--hantian--layoutreader")
|
|
|
951 |
self._model.to(self._device); self._model.eval(); print(f"MDR LayoutReader loaded on {self._device}.")
|
952 |
except Exception as e: print(f"ERROR loading MDR LayoutReader: {e}"); self._model = None
|
953 |
return self._model
|
954 |
+
|
955 |
def determine_reading_order(self, layouts: list[MDRLayoutElement], size: tuple[int, int]) -> list[MDRLayoutElement]:
|
956 |
w, h = size;
|
957 |
if w<=0 or h<=0 or not layouts: return layouts;
|
|
|
978 |
if len(orders) != len(bbox_list): print("MDR LayoutReader order mismatch"); return layouts # Fallback
|
979 |
for i, order_idx in enumerate(orders): bbox_list[i].order = order_idx
|
980 |
return self._apply_order(layouts, bbox_list)
|
981 |
+
|
982 |
def _prepare_bboxes(self, layouts: list[MDRLayoutElement], w: int, h: int) -> list[_MDR_ReaderBBox] | None:
|
983 |
line_h = self._estimate_line_h(layouts); bbox_list = []
|
984 |
for i, l in enumerate(layouts):
|
|
|
986 |
else: bbox_list.extend(self._gen_virtual(l, i, line_h, w, h))
|
987 |
if len(bbox_list) > _MDR_MAX_LEN: print(f"Too many boxes ({len(bbox_list)}>{_MDR_MAX_LEN})"); return None
|
988 |
bbox_list.sort(key=lambda b: (b.value[1], b.value[0])); return bbox_list
|
989 |
+
|
990 |
def _apply_order(self, layouts: list[MDRLayoutElement], bbox_list: list[_MDR_ReaderBBox]) -> list[MDRLayoutElement]:
|
991 |
layout_map = defaultdict(list); [layout_map[b.layout_index].append(b) for b in bbox_list]
|
992 |
layout_orders = [(idx, self._median([b.order for b in bboxes])) for idx, bboxes in layout_map.items() if bboxes]
|
|
|
1000 |
else: frags.sort(key=lambda f: (f.rect.lt[1], f.rect.lt[0]))
|
1001 |
for frag in frags: frag.order = nfo; nfo += 1
|
1002 |
return sorted_layouts
|
1003 |
+
|
1004 |
def _estimate_line_h(self, layouts: list[MDRLayoutElement]) -> float:
|
1005 |
heights = [f.rect.size[1] for l in layouts for f in l.fragments if f.rect.size[1]>0]
|
1006 |
return self._median(heights) if heights else 15.0
|
1007 |
+
|
1008 |
def _gen_virtual(self, l: MDRLayoutElement, l_idx: int, line_h: float, pw: int, ph: int) -> Generator[_MDR_ReaderBBox, None, None]:
|
1009 |
x0,y0,x1,y1 = l.rect.wrapper; lh,lw = y1-y0,x1-x0
|
1010 |
if lh<=0 or lw<=0 or line_h<=0: yield _MDR_ReaderBBox(l_idx,-1,True,-1,(x0,y0,x1,y1)); return
|
|
|
1019 |
ly0,ly1 = max(0,min(ph,cur_y)), max(0,min(ph,cur_y+act_line_h)); lx0,lx1 = max(0,min(pw,x0)), max(0,min(pw,x1))
|
1020 |
if ly1>ly0 and lx1>lx0: yield _MDR_ReaderBBox(l_idx,-1,True,-1,(lx0,ly0,lx1,ly1))
|
1021 |
cur_y += act_line_h
|
1022 |
+
|
1023 |
def _median(self, nums: list[float|int]) -> float:
|
1024 |
if not nums: return 0.0; s_nums = sorted(nums); n = len(s_nums)
|
1025 |
return float(s_nums[n//2]) if n%2==1 else float((s_nums[n//2-1]+s_nums[n//2])/2.0)
|
|
|
1027 |
# --- MDR LaTeX Extractor ---
|
1028 |
class MDRLatexExtractor:
|
1029 |
"""Extracts LaTeX from formula images using pix2tex."""
|
1030 |
+
|
1031 |
def __init__(self, model_path: str):
|
1032 |
self._model_path = model_path; self._model: LatexOCR | None = None
|
1033 |
self._device = "cuda" if torch.cuda.is_available() else "cpu"
|
1034 |
+
|
1035 |
def extract(self, image: Image) -> str | None:
|
1036 |
if LatexOCR is None: return None;
|
1037 |
image = mdr_expand_image(image, 0.1); model = self._get_model()
|
|
|
1039 |
try:
|
1040 |
with torch.no_grad(): img_rgb = image.convert('RGB') if image.mode!='RGB' else image; latex = model(img_rgb); return latex if latex else None
|
1041 |
except Exception as e: print(f"MDR LaTeX error: {e}"); return None
|
1042 |
+
|
1043 |
def _get_model(self) -> LatexOCR | None:
|
1044 |
if self._model is None and LatexOCR is not None:
|
1045 |
mdr_ensure_directory(self._model_path); wp, rp, cp = Path(self._model_path)/"weights.pth", Path(self._model_path)/"image_resizer.pth", Path(self._model_path)/"config.yaml"
|
|
|
1048 |
try: args = Munch({"config":str(cp), "checkpoint":str(wp), "device":self._device, "no_cuda":self._device=="cpu", "no_resize":False, "temperature":0.0}); self._model = LatexOCR(args); print(f"MDR LaTeX loaded on {self._device}.")
|
1049 |
except Exception as e: print(f"ERROR initializing MDR LatexOCR: {e}"); self._model = None
|
1050 |
return self._model
|
1051 |
+
|
1052 |
def _download(self):
|
1053 |
tag = "v0.0.1"; base = f"https://github.com/lukas-blecher/LaTeX-OCR/releases/download/{tag}/"; files = {"weights.pth": base+"weights.pth", "image_resizer.pth": base+"image_resizer.pth"}
|
1054 |
mdr_ensure_directory(self._model_path); [mdr_download_model(url, Path(self._model_path)/name) for name, url in files.items() if not (Path(self._model_path)/name).exists()]
|
|
|
1058 |
|
1059 |
class MDRTableParser:
|
1060 |
"""Parses table structure/content from images using StructTable model."""
|
1061 |
+
|
1062 |
def __init__(self, device: Literal["cpu", "cuda"], model_path: str):
|
1063 |
self._model: Any | None = None; self._model_path = mdr_ensure_directory(model_path)
|
1064 |
self._device = device if torch.cuda.is_available() and device=="cuda" else "cpu"
|
1065 |
self._disabled = self._device == "cpu"
|
1066 |
if self._disabled: print("Warning: MDR Table parsing requires CUDA. Disabled.")
|
1067 |
+
|
1068 |
def parse_table_image(self, image: Image, format: MDRTableLayoutParsedFormat) -> str | None:
|
1069 |
if self._disabled: return None;
|
1070 |
fmt: MDRTableOutputFormat | None = None
|
|
|
1079 |
with torch.no_grad(): results = model([img_rgb], output_format=fmt)
|
1080 |
return results[0] if results else None
|
1081 |
except Exception as e: print(f"MDR Table parsing error: {e}"); return None
|
1082 |
+
|
1083 |
def _get_model(self):
|
1084 |
if self._model is None and not self._disabled:
|
1085 |
try:
|
|
|
1100 |
|
1101 |
class MDRImageOptimizer:
|
1102 |
"""Handles image rotation detection and coordinate adjustments."""
|
1103 |
+
|
1104 |
def __init__(self, raw_image: Image, adjust_points: bool):
|
1105 |
self._raw = raw_image; self._image = raw_image; self._adjust_points = adjust_points
|
1106 |
self._fragments: list[MDROcrFragment] = []; self._rotation: float = 0.0; self._rot_ctx: _MDR_RotationContext | None = None
|
1107 |
+
|
1108 |
@property
|
1109 |
def image(self) -> Image: return self._image
|
1110 |
+
|
1111 |
@property
|
1112 |
def adjusted_image(self) -> Image | None: return self._image if self._rot_ctx is not None else None
|
1113 |
+
|
1114 |
@property
|
1115 |
def rotation(self) -> float: return self._rotation
|
1116 |
+
|
1117 |
@property
|
1118 |
def image_np(self) -> np.ndarray: img_rgb = np.array(self._raw.convert("RGB")); return cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
|
1119 |
+
|
1120 |
def receive_fragments(self, fragments: list[MDROcrFragment]):
|
1121 |
self._fragments = fragments;
|
1122 |
if not fragments: return;
|
|
|
1131 |
to_new=MDRRotationAdjuster(orig_sz, new_sz, self._rotation, False),
|
1132 |
to_origin=MDRRotationAdjuster(orig_sz, new_sz, self._rotation, True))
|
1133 |
adj = self._rot_ctx.to_new; [setattr(f, 'rect', MDRRectangle(lt=adj.adjust(r.lt), rt=adj.adjust(r.rt), lb=adj.adjust(r.lb), rb=adj.adjust(r.rb))) for f in fragments if (r:=f.rect)]
|
1134 |
+
|
1135 |
def finalize_layout_coords(self, layouts: list[MDRLayoutElement]):
|
1136 |
if self._rot_ctx is None or self._adjust_points: return
|
1137 |
if len(self._fragments) == len(self._rot_ctx.fragment_origin_rectangles): [setattr(f, 'rect', orig_r) for f, orig_r in zip(self._fragments, self._rot_ctx.fragment_origin_rectangles)]
|
|
|
1167 |
if not layouts: return;
|
1168 |
try: l_font, f_font = load_default(size=25), load_default(size=15); draw = ImageDraw.Draw(image, mode="RGBA")
|
1169 |
except Exception as e: print(f"MDR Plot init error: {e}"); return
|
1170 |
+
|
1171 |
def _draw_num(pos: MDRPoint, num: int, font: FreeTypeFont, color: _MDR_RGBA):
|
1172 |
try: x,y=pos; txt=str(num); txt_pos=(round(x)+3, round(y)+1); bbox=draw.textbbox(txt_pos,txt,font=font); bg_rect=(bbox[0]-2,bbox[1]-1,bbox[2]+2,bbox[3]+1); bg_color=(color[0],color[1],color[2],180); draw.rectangle(bg_rect,fill=bg_color); draw.text(txt_pos,txt,font=font,fill=(255,255,255,255))
|
1173 |
except Exception as e: print(f"MDR Draw num error: {e}")
|
|
|
1182 |
# --- MDR Extraction Engine ---
|
1183 |
class MDRExtractionEngine:
|
1184 |
"""Core engine for extracting structured information from a document image."""
|
1185 |
+
|
1186 |
def __init__(self, model_dir_path: str, device: Literal["cpu", "cuda"]="cpu", ocr_for_each_layouts: bool=True, extract_formula: bool=True, extract_table_format: MDRTableLayoutParsedFormat|None=None):
|
1187 |
self._model_dir = model_dir_path # Base directory for all models
|
1188 |
self._device = device if torch.cuda.is_available() else "cpu"
|
|
|
1237 |
print("MDR YOLOv10 class not available. Layout detection skipped.")
|
1238 |
|
1239 |
return self._yolo
|
1240 |
+
|
1241 |
def analyze_image(self, image: Image, adjust_points: bool=False) -> MDRExtractionResult:
|
1242 |
"""Analyzes a single page image to extract layout and content."""
|
1243 |
print(" Engine: Analyzing image..."); optimizer = MDRImageOptimizer(image, adjust_points)
|
|
|
1257 |
print(" Engine: Finalizing coords..."); optimizer.finalize_layout_coords(layouts)
|
1258 |
print(" Engine: Analysis complete.")
|
1259 |
return MDRExtractionResult(rotation=optimizer.rotation, layouts=layouts, extracted_image=image, adjusted_image=optimizer.adjusted_image)
|
1260 |
+
|
1261 |
def _run_yolo_detection(self, img: Image, yolo: YOLOv10) -> Generator[MDRLayoutElement, None, None]:
|
1262 |
img_rgb = img.convert("RGB"); res = yolo.predict(source=img_rgb, imgsz=1024, conf=0.2, device=self._device, verbose=False)
|
1263 |
if not res or not hasattr(res[0], 'boxes') or res[0].boxes is None: return
|
|
|
1271 |
if cls == MDRLayoutClass.TABLE: yield MDRTableLayoutElement(cls=cls, rect=rect, fragments=[], parsed=None)
|
1272 |
elif cls == MDRLayoutClass.ISOLATE_FORMULA: yield MDRFormulaLayoutElement(cls=cls, rect=rect, fragments=[], latex=None)
|
1273 |
elif cls in MDRPlainLayoutElement.__annotations__['cls'].__args__: yield MDRPlainLayoutElement(cls=cls, rect=rect, fragments=[])
|
1274 |
+
|
1275 |
def _match_fragments_to_layouts(self, frags: list[MDROcrFragment], layouts: list[MDRLayoutElement]) -> list[MDRLayoutElement]:
|
1276 |
if not frags or not layouts: return layouts
|
1277 |
layout_polys = [(Polygon(l.rect) if l.rect.is_valid else None) for l in layouts]
|
|
|
1291 |
layouts[best_idx].fragments.append(frag)
|
1292 |
for l in layouts: l.fragments.sort(key=lambda f: (f.rect.lt[1], f.rect.lt[0]))
|
1293 |
return layouts
|
1294 |
+
|
1295 |
def _run_ocr_correction(self, img: Image, layouts: list[MDRLayoutElement]):
|
1296 |
for i, l in enumerate(layouts):
|
1297 |
if l.cls == MDRLayoutClass.FIGURE: continue
|
1298 |
try: mdr_correct_layout_fragments(self._ocr_engine, img, l)
|
1299 |
except Exception as e: print(f" Engine: OCR correction error layout {i}: {e}")
|
1300 |
+
|
1301 |
def _parse_special_layouts(self, layouts: list[MDRLayoutElement], optimizer: MDRImageOptimizer):
|
1302 |
img_to_clip = optimizer.image
|
1303 |
for l in layouts:
|
|
|
1308 |
try: t_img = mdr_clip_from_image(img_to_clip, l.rect); parsed = self._table_parser.parse_table_image(t_img, self._ext_table) if t_img.width>1 and t_img.height>1 else None
|
1309 |
except Exception as e: print(f" Engine: Table parse error: {e}"); parsed = None
|
1310 |
if parsed: l.parsed = (parsed, self._ext_table)
|
1311 |
+
|
1312 |
def _should_keep_layout(self, l: MDRLayoutElement) -> bool:
|
1313 |
if l.fragments and not all(mdr_is_whitespace(f.text) for f in l.fragments): return True
|
1314 |
return l.cls in [MDRLayoutClass.FIGURE, MDRLayoutClass.TABLE, MDRLayoutClass.ISOLATE_FORMULA]
|
|
|
1316 |
# --- MDR Page Section Linking ---
|
1317 |
class _MDR_LinkedShape:
|
1318 |
"""Internal helper for managing layout linking across pages."""
|
1319 |
+
|
1320 |
def __init__(self, layout: MDRLayoutElement): self.layout=layout; self.pre:list[MDRLayoutElement|None]=[None,None]; self.nex:list[MDRLayoutElement|None]=[None,None]
|
1321 |
+
|
1322 |
@property
|
1323 |
def distance2(self) -> float: x,y=self.layout.rect.lt; return x*x+y*y
|
1324 |
|
1325 |
class MDRPageSection:
|
1326 |
"""Represents a page's layouts for framework detection via linking."""
|
1327 |
+
|
1328 |
def __init__(self, page_index: int, layouts: Iterable[MDRLayoutElement]):
|
1329 |
self._page_index = page_index; self._shapes = [_MDR_LinkedShape(l) for l in layouts]; self._shapes.sort(key=lambda s: (s.layout.rect.lt[1], s.layout.rect.lt[0]))
|
1330 |
+
|
1331 |
@property
|
1332 |
def page_index(self) -> int: return self._page_index
|
1333 |
+
|
1334 |
def find_framework_elements(self) -> list[MDRLayoutElement]:
|
1335 |
"""Identifies framework layouts based on links to other pages."""
|
1336 |
return [s.layout for s in self._shapes if any(s.pre) or any(s.nex)]
|
1337 |
+
|
1338 |
def link_to_next(self, next_section: 'MDRPageSection', offset: int) -> None:
|
1339 |
"""Links matching shapes between this section and the next."""
|
1340 |
if offset not in (1,2): return
|
|
|
1365 |
if max_sim > 0.75: matches += 1; if best_j != -1: used_f2[best_j] = True
|
1366 |
max_c = max(c1, c2); rate_frags = matches / max_c
|
1367 |
return self._check_match_threshold(rate_frags, max_c, (0.0, 0.45, 0.45, 0.6, 0.8, 0.95))
|
1368 |
+
|
1369 |
def _fragment_sim(self, l1: MDRLayoutElement, l2: MDRLayoutElement, f1: MDROcrFragment, f2: MDROcrFragment) -> float:
|
1370 |
r1_rel = self._relative_rect(l1.rect.lt, f1.rect); r2_rel = self._relative_rect(l2.rect.lt, f2.rect)
|
1371 |
geom_sim = self._symmetric_iou(r1_rel, r2_rel); text_sim, _ = mdr_check_text_similarity(f1.text, f2.text)
|
1372 |
return (geom_sim + text_sim) / 2.0
|
1373 |
+
|
1374 |
def _find_origin_pair(self, matches_matrix: list[list[_MDR_LinkedShape]], next_shapes: list[_MDR_LinkedShape]) -> tuple[_MDR_LinkedShape, _MDR_LinkedShape] | None:
|
1375 |
best_pair, min_dist2 = None, float('inf')
|
1376 |
for i, s1 in enumerate(self._shapes):
|
|
|
1378 |
if not match_list: continue
|
1379 |
for s2 in match_list: dist2 = s1.distance2 + s2.distance2; if dist2 < min_dist2: min_dist2, best_pair = dist2, (s1, s2)
|
1380 |
return best_pair
|
1381 |
+
|
1382 |
def _check_match_threshold(self, rate: float, count: int, thresholds: Sequence[float]) -> bool:
|
1383 |
if not thresholds: return False; idx = min(count, len(thresholds)-1); return rate >= thresholds[idx]
|
1384 |
+
|
1385 |
def _relative_rect(self, origin: MDRPoint, rect: MDRRectangle) -> MDRRectangle:
|
1386 |
ox, oy = origin; r=rect; return MDRRectangle(lt=(r.lt[0]-ox, r.lt[1]-oy), rt=(r.rt[0]-ox, r.rt[1]-oy), lb=(r.lb[0]-ox, r.lb[1]-oy), rb=(r.rb[0]-ox, r.rb[1]-oy))
|
1387 |
+
|
1388 |
def _symmetric_iou(self, r1: MDRRectangle, r2: MDRRectangle) -> float:
|
1389 |
try: p1, p2 = Polygon(r1), Polygon(r2);
|
1390 |
except: return 0.0
|
|
|
1401 |
class MDRProcessingParams:
|
1402 |
"""Parameters for processing a document."""
|
1403 |
pdf: str | FitzDocument; page_indexes: Iterable[int] | None; report_progress: MDRProgressReportCallback | None
|
1404 |
+
|
1405 |
class MDRDocumentIterator:
|
1406 |
"""Iterates through document pages, handles context, and calls the extraction engine."""
|
1407 |
+
|
1408 |
def __init__(self, device: Literal["cpu", "cuda"], model_dir_path: str, ocr_level: MDROcrLevel, extract_formula: bool, extract_table_format: MDRTableLayoutParsedFormat | None, debug_dir_path: str | None):
|
1409 |
self._debug_dir = debug_dir_path
|
1410 |
self._engine = MDRExtractionEngine(device=device, model_dir_path=model_dir_path, ocr_for_each_layouts=(ocr_level==MDROcrLevel.OncePerLayout), extract_formula=extract_formula, extract_table_format=extract_table_format)
|
1411 |
+
|
1412 |
def iterate_sections(self, params: MDRProcessingParams) -> Generator[tuple[int, MDRExtractionResult, list[MDRLayoutElement]], None, None]:
|
1413 |
"""Yields page index, extraction result, and content layouts for each requested page."""
|
1414 |
for res, sec in self._process_and_link_sections(params):
|
1415 |
framework = set(sec.find_framework_elements()); content = [l for l in res.layouts if l not in framework]; yield sec.page_index, res, content
|
1416 |
+
|
1417 |
def _process_and_link_sections(self, params: MDRProcessingParams) -> Generator[tuple[MDRExtractionResult, MDRPageSection], None, None]:
|
1418 |
queue: list[tuple[MDRExtractionResult, MDRPageSection]] = []
|
1419 |
for page_idx, res in self._run_extraction_on_pages(params):
|
|
|
1424 |
queue.append((res, cur_sec))
|
1425 |
if len(queue) > _MDR_CONTEXT_PAGES: yield queue.pop(0)
|
1426 |
for res, sec in queue: yield res, sec
|
1427 |
+
|
1428 |
def _run_extraction_on_pages(self, params: MDRProcessingParams) -> Generator[tuple[int, MDRExtractionResult], None, None]:
|
1429 |
if self._debug_dir: mdr_ensure_directory(self._debug_dir)
|
1430 |
doc, should_close = None, False
|
|
|
1447 |
except Exception as e: print(f" Iterator: Page {page_idx+1} processing error: {e}")
|
1448 |
finally:
|
1449 |
if should_close and doc: doc.close()
|
1450 |
+
|
1451 |
def _get_page_ranges(self, doc: FitzDocument, idxs: Iterable[int]|None) -> tuple[Sequence[int], Sequence[int]]:
|
1452 |
count = doc.page_count;
|
1453 |
if idxs is None: all_p = list(range(count)); return all_p, all_p
|
|
|
1455 |
for i in idxs:
|
1456 |
if 0<=i<count: enable.add(i); [scan.add(j) for j in range(max(0, i-_MDR_CONTEXT_PAGES), min(count, i+_MDR_CONTEXT_PAGES+1))]
|
1457 |
return sorted(list(scan)), sorted(list(enable))
|
1458 |
+
|
1459 |
def _render_page_image(self, page: FitzPage, dpi: int) -> Image:
|
1460 |
mat = FitzMatrix(dpi/72.0, dpi/72.0); pix = page.get_pixmap(matrix=mat, alpha=False)
|
1461 |
return frombytes("RGB", (pix.width, pix.height), pix.samples)
|
1462 |
+
|
1463 |
def _save_debug_plot(self, img: Image, idx: int, res: MDRExtractionResult, path: str):
|
1464 |
try: plot_img = res.adjusted_image.copy() if res.adjusted_image else img.copy(); mdr_plot_layout(plot_img, res.layouts); plot_img.save(os.path.join(path, f"mdr_plot_page_{idx+1}.png"))
|
1465 |
except Exception as e: print(f" Iterator: Plot generation error page {idx+1}: {e}")
|
|
|
1470 |
Main class for processing PDF documents to extract structured data blocks
|
1471 |
using the MagicDataReadiness pipeline.
|
1472 |
"""
|
1473 |
+
|
1474 |
def __init__(self, device: Literal["cpu", "cuda"]="cuda", model_dir_path: str="./mdr_models", ocr_level: MDROcrLevel=MDROcrLevel.Once, extract_formula: bool=True, extract_table_format: MDRExtractedTableFormat|None=None, debug_dir_path: str|None=None):
|
1475 |
"""
|
1476 |
Initializes the MagicPDFProcessor.
|
|
|
1547 |
self._assign_relative_font_sizes(temp_store);
|
1548 |
return [block for _, block in temp_store]
|
1549 |
|
1550 |
+
# --- START REFACTORED METHOD ---
|
1551 |
def _analyze_paragraph_structure(self, blocks: list[MDRStructuredBlock]):
|
1552 |
+
"""
|
1553 |
+
Calculates indentation and line-end heuristics for MDRTextBlocks
|
1554 |
+
based on page-level text boundaries and average line height.
|
1555 |
+
"""
|
1556 |
+
# Define constants for clarity and maintainability
|
1557 |
+
MIN_VALID_HEIGHT = 1e-6
|
1558 |
+
# Heuristic: Indent if first line starts more than 1.0 * avg line height from page text start
|
1559 |
+
INDENTATION_THRESHOLD_FACTOR = 1.0
|
1560 |
+
# Heuristic: Last line touches end if it ends less than 1.0 * avg line height from page text end
|
1561 |
+
LINE_END_THRESHOLD_FACTOR = 1.0
|
1562 |
+
|
1563 |
+
# Calculate average line height and text boundaries for the relevant text blocks on the page
|
1564 |
+
page_avg_line_height, page_min_x, page_max_x = self._calculate_text_range(
|
1565 |
+
(b for b in blocks if isinstance(b, MDRTextBlock) and b.kind != MDRTextKind.ABANDON)
|
1566 |
+
)
|
1567 |
+
|
1568 |
+
# Avoid calculations if page metrics are invalid (e.g., no text, zero height)
|
1569 |
+
if page_avg_line_height <= MIN_VALID_HEIGHT:
|
1570 |
+
return
|
1571 |
+
|
1572 |
+
# Iterate through each block to determine its paragraph properties
|
1573 |
+
for block in blocks:
|
1574 |
+
# Process only valid text blocks with actual text content
|
1575 |
+
if not isinstance(block, MDRTextBlock) or block.kind == MDRTextKind.ABANDON or not block.texts:
|
1576 |
+
continue
|
1577 |
+
|
1578 |
+
# Use calculated page-level metrics for consistency in thresholds
|
1579 |
+
avg_line_height = page_avg_line_height
|
1580 |
+
page_text_start_x = page_min_x
|
1581 |
+
page_text_end_x = page_max_x
|
1582 |
+
|
1583 |
+
# Get the first and last text spans (assumed to be lines after merging) within the block
|
1584 |
+
first_text_span = block.texts[0]
|
1585 |
+
last_text_span = block.texts[-1]
|
1586 |
+
|
1587 |
+
try:
|
1588 |
+
# --- Calculate Indentation ---
|
1589 |
+
# Estimate the starting x-coordinate of the first line (average of left top/bottom)
|
1590 |
+
first_line_start_x = (first_text_span.rect.lt[0] + first_text_span.rect.lb[0]) / 2.0
|
1591 |
+
# Calculate the difference between the first line's start and the page's text start boundary
|
1592 |
+
indentation_delta = first_line_start_x - page_text_start_x
|
1593 |
+
# Determine indentation based on the heuristic threshold relative to average line height
|
1594 |
+
block.has_paragraph_indentation = indentation_delta > (avg_line_height * INDENTATION_THRESHOLD_FACTOR)
|
1595 |
+
|
1596 |
+
# --- Calculate Last Line End ---
|
1597 |
+
# Estimate the ending x-coordinate of the last line (average of right top/bottom)
|
1598 |
+
last_line_end_x = (last_text_span.rect.rt[0] + last_text_span.rect.rb[0]) / 2.0
|
1599 |
+
# Calculate the difference between the page's text end boundary and the last line's end
|
1600 |
+
line_end_delta = page_text_end_x - last_line_end_x
|
1601 |
+
# Determine if the last line reaches near the end based on the heuristic threshold
|
1602 |
+
block.last_line_touch_end = line_end_delta < (avg_line_height * LINE_END_THRESHOLD_FACTOR)
|
1603 |
+
|
1604 |
except Exception as e:
|
1605 |
+
# Handle potential errors during calculation (e.g., invalid rect data)
|
1606 |
print(f"Warn: Error calculating paragraph structure for block: {e}")
|
1607 |
+
# Default to False if calculation fails to ensure attributes are set
|
1608 |
+
block.has_paragraph_indentation = False
|
1609 |
+
block.last_line_touch_end = False # Removed semicolon from original
|
1610 |
|
1611 |
+
# --- END REFACTORED METHOD ---
|
1612 |
|
1613 |
def _calculate_text_range(self, blocks_iter: Iterable[MDRStructuredBlock]) -> tuple[float, float, float]:
|
1614 |
"""Calculates average line height and min/max x-coordinates for text."""
|
|
|
1686 |
|
1687 |
# Path to the PDF file you want to process
|
1688 |
# IMPORTANT: Place a PDF file here for testing!
|
1689 |
+
# Create a dummy PDF if it doesn't exist for the example to run
|
1690 |
MDR_INPUT_PDF = "example_input.pdf" # <--- CHANGE THIS
|
1691 |
+
if not Path(MDR_INPUT_PDF).exists():
|
1692 |
+
try:
|
1693 |
+
print(f"Creating dummy PDF: {MDR_INPUT_PDF}")
|
1694 |
+
doc = fitz.new_document()
|
1695 |
+
page = doc.new_page()
|
1696 |
+
page.insert_text((72, 72), "This is a dummy PDF for testing.")
|
1697 |
+
doc.save(MDR_INPUT_PDF)
|
1698 |
+
doc.close()
|
1699 |
+
except Exception as e:
|
1700 |
+
print(f"Warning: Could not create dummy PDF: {e}")
|
1701 |
+
|
1702 |
|
1703 |
# Optional: Directory to save debug plots (set to None to disable)
|
1704 |
MDR_DEBUG_DIRECTORY = "./mdr_debug_output"
|
|
|
1777 |
info = f" - Block {block_idx+1}: {type(block).__name__}"
|
1778 |
if isinstance(block, MDRTextBlock):
|
1779 |
preview = block.texts[0].content[:70].replace('\n',' ') + "..." if block.texts else "[EMPTY]"
|
1780 |
+
info += f" (Kind: {block.kind.name}, FontSz: {block.font_size:.2f}, Indent: {block.has_paragraph_indentation}, EndTouch: {block.last_line_touch_end}) | Text: '{preview}'" # Added indent/endtouch
|
1781 |
elif isinstance(block, MDRTableBlock):
|
1782 |
info += f" (Format: {block.format.name}, HasContent: {bool(block.content)}, FontSz: {block.font_size:.2f})"
|
1783 |
# if block.content: print(f" Content:\n{block.content}") # Uncomment to see content
|