Update mdr_pdf_parser.py
Browse files- mdr_pdf_parser.py +247 -51
mdr_pdf_parser.py
CHANGED
@@ -38,15 +38,16 @@ from PIL.ImageOps import expand as pil_expand
|
|
38 |
from PIL import ImageDraw
|
39 |
from PIL.ImageFont import load_default, FreeTypeFont
|
40 |
from shapely.geometry import Polygon
|
41 |
-
import pyclipper
|
42 |
from unicodedata import category
|
43 |
from alphabet_detector import AlphabetDetector
|
44 |
-
from munch import Munch
|
45 |
-
from transformers import LayoutLMv3ForTokenClassification
|
46 |
-
import onnxruntime
|
|
|
|
|
47 |
|
48 |
-
# ---
|
49 |
-
# Ensure these are installed or available in your environment
|
50 |
try:
|
51 |
from doclayout_yolo import YOLOv10
|
52 |
except ImportError:
|
@@ -58,14 +59,14 @@ except ImportError:
|
|
58 |
print("Warning: Could not import LatexOCR from pix2tex.cli. LaTeX extraction will fail.")
|
59 |
LatexOCR = None
|
60 |
try:
|
61 |
-
# Dynamic import within MDRTableParser, assuming struct_eqtable.py exists or is installable
|
62 |
pass # from struct_eqtable import build_model
|
63 |
except ImportError:
|
64 |
print("Warning: Could not import build_model from struct_eqtable. Table parsing might fail.")
|
65 |
|
66 |
# --- MagicDataReadiness Core Components ---
|
67 |
|
68 |
-
# --- MDR Utilities
|
|
|
69 |
def mdr_download_model(url: str, file_path: Path):
|
70 |
"""Downloads a model file from a URL to a local path."""
|
71 |
try:
|
@@ -85,7 +86,7 @@ def mdr_download_model(url: str, file_path: Path):
|
|
85 |
if file_path.exists(): os.remove(file_path)
|
86 |
raise e
|
87 |
|
88 |
-
# --- MDR Utilities
|
89 |
def mdr_ensure_directory(path: str) -> str:
|
90 |
"""Ensures a directory exists, creating it if necessary."""
|
91 |
path = os.path.abspath(path)
|
@@ -146,79 +147,101 @@ def mdr_intersection_area(rect1: MDRRectangle, rect2: MDRRectangle) -> float:
|
|
146 |
return p1.intersection(p2).area
|
147 |
except: return 0.0
|
148 |
|
149 |
-
# --- MDR Data Structures
|
150 |
@dataclass
|
151 |
class MDROcrFragment:
|
152 |
"""Represents a fragment of text identified by OCR."""
|
153 |
order: int; text: str; rank: float; rect: MDRRectangle
|
|
|
154 |
class MDRLayoutClass(Enum):
|
155 |
"""Enumeration of different layout types identified."""
|
156 |
TITLE=0; PLAIN_TEXT=1; ABANDON=2; FIGURE=3; FIGURE_CAPTION=4; TABLE=5; TABLE_CAPTION=6; TABLE_FOOTNOTE=7; ISOLATE_FORMULA=8; FORMULA_CAPTION=9
|
|
|
157 |
class MDRTableLayoutParsedFormat(Enum):
|
158 |
"""Enumeration for formats of parsed table content."""
|
159 |
LATEX=auto(); MARKDOWN=auto(); HTML=auto()
|
|
|
160 |
@dataclass
|
161 |
class MDRBaseLayoutElement:
|
162 |
"""Base class for layout elements found on a page."""
|
163 |
rect: MDRRectangle; fragments: list[MDROcrFragment]
|
|
|
164 |
@dataclass
|
165 |
class MDRPlainLayoutElement(MDRBaseLayoutElement):
|
166 |
"""Layout element for plain text, titles, captions, figures, etc."""
|
167 |
cls: Literal[MDRLayoutClass.TITLE, MDRLayoutClass.PLAIN_TEXT, MDRLayoutClass.ABANDON, MDRLayoutClass.FIGURE, MDRLayoutClass.FIGURE_CAPTION, MDRLayoutClass.TABLE_CAPTION, MDRLayoutClass.TABLE_FOOTNOTE, MDRLayoutClass.FORMULA_CAPTION]
|
|
|
168 |
@dataclass
|
169 |
class MDRTableLayoutElement(MDRBaseLayoutElement):
|
170 |
"""Layout element specifically for tables."""
|
171 |
parsed: tuple[str, MDRTableLayoutParsedFormat] | None; cls: Literal[MDRLayoutClass.TABLE] = MDRLayoutClass.TABLE
|
|
|
172 |
@dataclass
|
173 |
class MDRFormulaLayoutElement(MDRBaseLayoutElement):
|
174 |
"""Layout element specifically for formulas."""
|
175 |
latex: str | None; cls: Literal[MDRLayoutClass.ISOLATE_FORMULA] = MDRLayoutClass.ISOLATE_FORMULA
|
|
|
176 |
MDRLayoutElement = MDRPlainLayoutElement | MDRTableLayoutElement | MDRFormulaLayoutElement # Type alias
|
|
|
177 |
@dataclass
|
178 |
class MDRExtractionResult:
|
179 |
"""Holds the complete result of extracting from a single page image."""
|
180 |
rotation: float; layouts: list[MDRLayoutElement]; extracted_image: Image; adjusted_image: Image | None
|
181 |
|
182 |
-
# --- MDR Data Structures
|
|
|
183 |
MDRProgressReportCallback: TypeAlias = Callable[[int, int], None]
|
|
|
184 |
class MDROcrLevel(Enum): Once=auto(); OncePerLayout=auto()
|
|
|
185 |
class MDRExtractedTableFormat(Enum): LATEX=auto(); MARKDOWN=auto(); HTML=auto(); DISABLE=auto()
|
|
|
186 |
class MDRTextKind(Enum): TITLE=0; PLAIN_TEXT=1; ABANDON=2
|
|
|
187 |
@dataclass
|
188 |
class MDRTextSpan:
|
189 |
"""Represents a span of text content within a block."""
|
190 |
content: str; rank: float; rect: MDRRectangle
|
|
|
191 |
@dataclass
|
192 |
class MDRBasicBlock:
|
193 |
"""Base class for structured blocks extracted from the document."""
|
194 |
rect: MDRRectangle; texts: list[MDRTextSpan]; font_size: float # Relative font size (0-1)
|
|
|
195 |
@dataclass
|
196 |
class MDRTextBlock(MDRBasicBlock):
|
197 |
"""A structured block containing text content."""
|
198 |
kind: MDRTextKind; has_paragraph_indentation: bool = False; last_line_touch_end: bool = False
|
|
|
199 |
class MDRTableFormat(Enum): LATEX=auto(); MARKDOWN=auto(); HTML=auto(); UNRECOGNIZABLE=auto()
|
|
|
200 |
@dataclass
|
201 |
class MDRTableBlock(MDRBasicBlock):
|
202 |
"""A structured block representing a table."""
|
203 |
content: str; format: MDRTableFormat; image: Image # Image clip of the table
|
|
|
204 |
@dataclass
|
205 |
class MDRFormulaBlock(MDRBasicBlock):
|
206 |
"""A structured block representing a formula."""
|
207 |
content: str | None; image: Image # Image clip of the formula
|
|
|
208 |
@dataclass
|
209 |
class MDRFigureBlock(MDRBasicBlock):
|
210 |
"""A structured block representing a figure/image."""
|
211 |
image: Image # Image clip of the figure
|
|
|
212 |
MDRAssetBlock = MDRTableBlock | MDRFormulaBlock | MDRFigureBlock # Type alias
|
|
|
213 |
MDRStructuredBlock = MDRTextBlock | MDRAssetBlock # Type alias
|
214 |
|
215 |
-
# --- MDR Utilities
|
216 |
def mdr_similarity_ratio(v1: float, v2: float) -> float:
|
217 |
"""Calculates the ratio of the smaller value to the larger value (0-1)."""
|
218 |
if v1==0 and v2==0: return 1.0;
|
219 |
if v1<0 or v2<0: return 0.0;
|
220 |
v1, v2 = (v2, v1) if v1 > v2 else (v1, v2);
|
221 |
return 1.0 if v2==0 else v1/v2
|
|
|
222 |
def mdr_intersection_bounds_size(r1: MDRRectangle, r2: MDRRectangle) -> tuple[float, float]:
|
223 |
"""Calculates width/height of the intersection bounding box."""
|
224 |
try:
|
@@ -228,18 +251,23 @@ def mdr_intersection_bounds_size(r1: MDRRectangle, r2: MDRRectangle) -> tuple[fl
|
|
228 |
if inter.is_empty: return 0.0, 0.0;
|
229 |
minx, miny, maxx, maxy = inter.bounds; return maxx-minx, maxy-miny
|
230 |
except: return 0.0, 0.0
|
|
|
231 |
_MDR_CJKA_PATTERN = re.compile(r"[\u4e00-\u9fff\u3040-\u309f\u30a0-\u30ff\uac00-\ud7a3\u0600-\u06ff]")
|
|
|
232 |
def mdr_contains_cjka(text: str):
|
233 |
"""Checks if text contains Chinese, Japanese, Korean, or Arabic chars."""
|
234 |
return bool(_MDR_CJKA_PATTERN.search(text)) if text else False
|
235 |
|
236 |
-
# --- MDR Text Processing
|
237 |
class _MDR_TokenPhase(Enum): Init=0; Letter=1; Character=2; Number=3; Space=4
|
|
|
238 |
_mdr_alphabet_detector = AlphabetDetector()
|
|
|
239 |
def _mdr_is_letter(char: str):
|
240 |
if not category(char).startswith("L"): return False
|
241 |
try: return _mdr_alphabet_detector.is_latin(char) or _mdr_alphabet_detector.is_cyrillic(char) or _mdr_alphabet_detector.is_greek(char) or _mdr_alphabet_detector.is_hebrew(char)
|
242 |
except: return False
|
|
|
243 |
def mdr_split_into_words(text: str):
|
244 |
"""Splits text into words, numbers, and individual non-alphanumeric chars."""
|
245 |
if not text: return;
|
@@ -259,6 +287,7 @@ def mdr_split_into_words(text: str):
|
|
259 |
if is_s: phase=_MDR_TokenPhase.Space
|
260 |
else: yield char; phase=_MDR_TokenPhase.Character
|
261 |
if phase in (_MDR_TokenPhase.Letter, _MDR_TokenPhase.Number): w=buf.getvalue(); yield w if w else None
|
|
|
262 |
def mdr_check_text_similarity(t1: str, t2: str) -> tuple[float, int]:
|
263 |
"""Calculates word-based similarity between two texts."""
|
264 |
w1, w2 = list(mdr_split_into_words(t1)), list(mdr_split_into_words(t2)); l1, l2 = len(w1), len(w2)
|
@@ -271,21 +300,25 @@ def mdr_check_text_similarity(t1: str, t2: str) -> tuple[float, int]:
|
|
271 |
if not taken[i] and word1==word2: taken[i]=True; matches+=1; break
|
272 |
mismatches = l2 - matches; return 1.0 - (mismatches/l2), l2
|
273 |
|
274 |
-
# --- MDR Geometry Processing
|
275 |
class MDRRotationAdjuster:
|
276 |
"""Adjusts point coordinates based on image rotation."""
|
|
|
277 |
def __init__(self, origin_size: tuple[int, int], new_size: tuple[int, int], rotation: float, to_origin_coordinate: bool):
|
278 |
fs, ts = (new_size, origin_size) if to_origin_coordinate else (origin_size, new_size)
|
279 |
self._rot = rotation if to_origin_coordinate else -rotation
|
280 |
self._c_off = (fs[0]/2.0, fs[1]/2.0); self._n_off = (ts[0]/2.0, ts[1]/2.0)
|
|
|
281 |
def adjust(self, point: MDRPoint) -> MDRPoint:
|
282 |
x, y = point[0]-self._c_off[0], point[1]-self._c_off[1]
|
283 |
if x!=0 or y!=0: cos_r, sin_r = cos(self._rot), sin(self._rot); x, y = x*cos_r-y*sin_r, x*sin_r+y*cos_r
|
284 |
return x+self._n_off[0], y+self._n_off[1]
|
|
|
285 |
def mdr_normalize_vertical_rotation(rot: float) -> float:
|
286 |
while rot >= pi: rot -= pi;
|
287 |
while rot < 0: rot += pi;
|
288 |
return rot
|
|
|
289 |
def _mdr_get_rectangle_angles(rect: MDRRectangle) -> tuple[list[float], list[float]] | None:
|
290 |
h_angs, v_angs = [], []
|
291 |
for i, (p1, p2) in enumerate(rect.segments):
|
@@ -297,11 +330,15 @@ def _mdr_get_rectangle_angles(rect: MDRRectangle) -> tuple[list[float], list[flo
|
|
297 |
else: v_angs.append(ang)
|
298 |
if not h_angs or not v_angs: return None
|
299 |
return h_angs, v_angs
|
|
|
300 |
def _mdr_normalize_horizontal_angles(rots: list[float]) -> list[float]: return rots
|
|
|
301 |
def _mdr_find_median(data: list[float]) -> float:
|
302 |
if not data: return 0.0; s_data = sorted(data); n = len(s_data);
|
303 |
return s_data[n//2] if n%2==1 else (s_data[n//2-1]+s_data[n//2])/2.0
|
|
|
304 |
def _mdr_find_mean(data: list[float]) -> float: return sum(data)/len(data) if data else 0.0
|
|
|
305 |
def mdr_calculate_image_rotation(frags: list[MDROcrFragment]) -> float:
|
306 |
all_h, all_v = [], [];
|
307 |
for f in frags:
|
@@ -314,6 +351,7 @@ def mdr_calculate_image_rotation(frags: list[MDROcrFragment]) -> float:
|
|
314 |
while rot_est >= pi/2: rot_est -= pi;
|
315 |
while rot_est < -pi/2: rot_est += pi;
|
316 |
return rot_est
|
|
|
317 |
def mdr_calculate_rectangle_rotation(rect: MDRRectangle) -> tuple[float, float]:
|
318 |
res = _mdr_get_rectangle_angles(rect);
|
319 |
if res is None: return 0.0, pi/2.0;
|
@@ -321,9 +359,10 @@ def mdr_calculate_rectangle_rotation(rect: MDRRectangle) -> tuple[float, float]:
|
|
321 |
h_rots = _mdr_normalize_horizontal_angles(h_rots); v_rots = [mdr_normalize_vertical_rotation(a) for a in v_rots]
|
322 |
return _mdr_find_mean(h_rots), _mdr_find_mean(v_rots)
|
323 |
|
324 |
-
# --- MDR ONNX OCR Internals
|
325 |
class _MDR_PredictBase:
|
326 |
"""Base class for ONNX model prediction components."""
|
|
|
327 |
def get_onnx_session(self, model_path: str, use_gpu: bool):
|
328 |
try:
|
329 |
sess_opts = onnxruntime.SessionOptions(); sess_opts.log_severity_level = 3
|
@@ -336,24 +375,32 @@ class _MDR_PredictBase:
|
|
336 |
if use_gpu and 'CUDAExecutionProvider' not in onnxruntime.get_available_providers():
|
337 |
print(" CUDAExecutionProvider not available. Check ONNXRuntime-GPU installation and CUDA setup.")
|
338 |
raise e
|
|
|
339 |
def get_output_name(self, sess: onnxruntime.InferenceSession) -> List[str]: return [n.name for n in sess.get_outputs()]
|
|
|
340 |
def get_input_name(self, sess: onnxruntime.InferenceSession) -> List[str]: return [n.name for n in sess.get_inputs()]
|
|
|
341 |
def get_input_feed(self, names: List[str], img_np: np.ndarray) -> Dict[str, np.ndarray]: return {name: img_np for name in names}
|
342 |
|
343 |
-
# --- MDR ONNX OCR Internals
|
344 |
class _MDR_NormalizeImage:
|
|
|
345 |
def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
|
346 |
self.scale = np.float32(eval(scale) if isinstance(scale, str) else (scale if scale is not None else 1.0 / 255.0))
|
347 |
mean = mean if mean is not None else [0.485, 0.456, 0.406]; std = std if std is not None else [0.229, 0.224, 0.225]
|
348 |
shape = (3, 1, 1) if order == 'chw' else (1, 1, 3); self.mean = np.array(mean).reshape(shape).astype('float32'); self.std = np.array(std).reshape(shape).astype('float32')
|
|
|
349 |
def __call__(self, data): img = data['image']; img = np.array(img) if isinstance(img, Image) else img; data['image'] = (img.astype('float32') * self.scale - self.mean) / self.std; return data
|
|
|
350 |
class _MDR_DetResizeForTest:
|
|
|
351 |
def __init__(self, **kwargs):
|
352 |
self.resize_type = 0; self.keep_ratio = False
|
353 |
if 'image_shape' in kwargs: self.image_shape = kwargs['image_shape']; self.resize_type = 1; self.keep_ratio = kwargs.get('keep_ratio', False)
|
354 |
elif 'limit_side_len' in kwargs: self.limit_side_len = kwargs['limit_side_len']; self.limit_type = kwargs.get('limit_type', 'min')
|
355 |
elif 'resize_long' in kwargs: self.resize_type = 2; self.resize_long = kwargs.get('resize_long', 960)
|
356 |
else: self.limit_side_len = 736; self.limit_type = 'min'
|
|
|
357 |
def __call__(self, data):
|
358 |
img = data['image']; src_h, src_w, _ = img.shape
|
359 |
if src_h + src_w < 64: img = self._pad(img)
|
@@ -362,21 +409,30 @@ class _MDR_DetResizeForTest:
|
|
362 |
else: img, ratios = self._resize1(img)
|
363 |
if img is None: return None
|
364 |
data['image'] = img; data['shape'] = np.array([src_h, src_w, ratios[0], ratios[1]]); return data
|
|
|
365 |
def _pad(self, im, v=0): h,w,c=im.shape; p=np.zeros((max(32,h),max(32,w),c),np.uint8)+v; p[:h,:w,:]=im; return p
|
|
|
366 |
def _resize1(self, img): rh,rw=self.image_shape; oh,ow=img.shape[:2]; if self.keep_ratio: rw=ow*rh/oh; N=math.ceil(rw/32); rw=N*32; r_h,r_w=float(rh)/oh,float(rw)/ow; img=cv2.resize(img,(int(rw),int(rh))); return img,[r_h,r_w]
|
|
|
367 |
def _resize0(self, img): lsl=self.limit_side_len; h,w,_=img.shape; r=1.0; if self.limit_type=='max': r=float(lsl)/max(h,w) if max(h,w)>lsl else 1.0; elif self.limit_type=='min': r=float(lsl)/min(h,w) if min(h,w)<lsl else 1.0; elif self.limit_type=='resize_long': r=float(lsl)/max(h,w); else: raise Exception('Unsupported'); rh,rw=int(h*r),int(w*r); rh=max(int(round(rh/32)*32),32); rw=max(int(round(rw/32)*32),32); if int(rw)<=0 or int(rh)<=0: return None,(None,None); img=cv2.resize(img,(int(rw),int(rh))); r_h,r_w=rh/float(h),rw/float(w); return img,[r_h,r_w]
|
|
|
368 |
def _resize2(self, img): h,w,_=img.shape; rl=self.resize_long; r=float(rl)/max(h,w); rh,rw=int(h*r),int(w*r); ms=128; rh=(rh+ms-1)//ms*ms; rw=(rw+ms-1)//ms*ms; img=cv2.resize(img,(int(rw),int(rh))); r_h,r_w=rh/float(h),rw/float(w); return img,[r_h,r_w]
|
|
|
369 |
class _MDR_ToCHWImage:
|
|
|
370 |
def __call__(self, data): img=data['image']; img=np.array(img) if isinstance(img,Image) else img; data['image']=img.transpose((2,0,1)); return data
|
|
|
371 |
class _MDR_KeepKeys:
|
|
|
372 |
def __init__(self, keep_keys, **kwargs): self.keep_keys=keep_keys
|
|
|
373 |
def __call__(self, data): return [data[key] for key in self.keep_keys]
|
374 |
|
375 |
-
# --- MDR ONNX OCR Internals (imaug.py) ---
|
376 |
def mdr_ocr_transform(data, ops=None):
|
377 |
ops = ops if ops is not None else [];
|
378 |
for op in ops: data = op(data); if data is None: return None;
|
379 |
return data
|
|
|
380 |
def mdr_ocr_create_operators(op_param_list, global_config=None):
|
381 |
ops = []
|
382 |
for operator in op_param_list:
|
@@ -388,11 +444,12 @@ def mdr_ocr_create_operators(op_param_list, global_config=None):
|
|
388 |
else: raise ValueError(f"Operator class '{op_class_name}' not found.")
|
389 |
return ops
|
390 |
|
391 |
-
# --- MDR ONNX OCR Internals (db_postprocess.py) ---
|
392 |
class _MDR_DBPostProcess:
|
|
|
393 |
def __init__(self, thresh=0.3, box_thresh=0.7, max_candidates=1000, unclip_ratio=1.5, use_dilation=False, score_mode="fast", box_type='quad', **kwargs):
|
394 |
self.thresh, self.box_thresh, self.max_cand = thresh, box_thresh, max_candidates; self.unclip_r, self.min_sz, self.score_m, self.box_t = unclip_ratio, 3, score_mode, box_type
|
395 |
assert score_mode in ["slow", "fast"]; self.dila_k = np.array([[1,1],[1,1]], dtype=np.uint8) if use_dilation else None
|
|
|
396 |
def _polygons_from_bitmap(self, pred, bmp, dw, dh):
|
397 |
h, w = bmp.shape; boxes, scores = [], []
|
398 |
contours, _ = cv2.findContours((bmp*255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
|
@@ -409,6 +466,7 @@ class _MDR_DBPostProcess:
|
|
409 |
box = np.array(box); box[:,0]=np.clip(np.round(box[:,0]/w*dw),0,dw); box[:,1]=np.clip(np.round(box[:,1]/h*dh),0,dh)
|
410 |
boxes.append(box.tolist()); scores.append(score)
|
411 |
return boxes, scores
|
|
|
412 |
def _boxes_from_bitmap(self, pred, bmp, dw, dh):
|
413 |
h, w = bmp.shape; contours, _ = cv2.findContours((bmp*255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
|
414 |
num_contours = min(len(contours), self.max_cand); boxes, scores = [], []
|
@@ -424,25 +482,30 @@ class _MDR_DBPostProcess:
|
|
424 |
box = np.array(box); box[:,0]=np.clip(np.round(box[:,0]/w*dw),0,dw); box[:,1]=np.clip(np.round(box[:,1]/h*dh),0,dh)
|
425 |
boxes.append(box.astype("int32")); scores.append(score)
|
426 |
return np.array(boxes, dtype="int32"), scores
|
|
|
427 |
def _unclip(self, box, ratio):
|
428 |
poly = Polygon(box); dist = poly.area*ratio/poly.length; offset = pyclipper.PyclipperOffset(); offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
|
429 |
expanded = offset.Execute(dist);
|
430 |
if not expanded: raise ValueError("Unclip failed"); return np.array(expanded[0])
|
|
|
431 |
def _get_mini_boxes(self, contour):
|
432 |
bb = cv2.minAreaRect(contour); pts = sorted(list(cv2.boxPoints(bb)), key=lambda x:x[0])
|
433 |
i1,i4 = (0,1) if pts[1][1]>pts[0][1] else (1,0); i2,i3 = (2,3) if pts[3][1]>pts[2][1] else (3,2)
|
434 |
box = [pts[i1], pts[i2], pts[i3], pts[i4]]; return box, min(bb[1])
|
|
|
435 |
def _box_score_fast(self, bmp, box):
|
436 |
h,w = bmp.shape[:2]; xmin=np.clip(np.floor(box[:,0].min()).astype("int32"),0,w-1); xmax=np.clip(np.ceil(box[:,0].max()).astype("int32"),0,w-1)
|
437 |
ymin=np.clip(np.floor(box[:,1].min()).astype("int32"),0,h-1); ymax=np.clip(np.ceil(box[:,1].max()).astype("int32"),0,h-1)
|
438 |
mask = np.zeros((ymax-ymin+1, xmax-xmin+1), dtype=np.uint8); box[:,0]-=xmin; box[:,1]-=ymin
|
439 |
cv2.fillPoly(mask, box.reshape(1,-1,2).astype("int32"), 1);
|
440 |
return cv2.mean(bmp[ymin:ymax+1, xmin:xmax+1], mask)[0] if np.sum(mask)>0 else 0.0
|
|
|
441 |
def _box_score_slow(self, bmp, contour): # Not used if fast
|
442 |
h,w = bmp.shape[:2]; contour = np.reshape(contour.copy(),(-1,2)); xmin=np.clip(np.min(contour[:,0]),0,w-1); xmax=np.clip(np.max(contour[:,0]),0,w-1)
|
443 |
ymin=np.clip(np.min(contour[:,1]),0,h-1); ymax=np.clip(np.max(contour[:,1]),0,h-1); mask=np.zeros((ymax-ymin+1,xmax-xmin+1),dtype=np.uint8)
|
444 |
contour[:,0]-=xmin; contour[:,1]-=ymin; cv2.fillPoly(mask, contour.reshape(1,-1,2).astype("int32"), 1);
|
445 |
return cv2.mean(bmp[ymin:ymax+1, xmin:xmax+1], mask)[0] if np.sum(mask)>0 else 0.0
|
|
|
446 |
def __call__(self, outs_dict, shape_list):
|
447 |
pred = outs_dict['maps'][:,0,:,:]; seg = pred > self.thresh; boxes_batch = []
|
448 |
for batch_idx in range(pred.shape[0]):
|
@@ -453,8 +516,8 @@ class _MDR_DBPostProcess:
|
|
453 |
boxes_batch.append({'points': boxes})
|
454 |
return boxes_batch
|
455 |
|
456 |
-
# --- MDR ONNX OCR Internals (predict_det.py) ---
|
457 |
class _MDR_TextDetector(_MDR_PredictBase):
|
|
|
458 |
def __init__(self, args):
|
459 |
super().__init__(); self.args = args
|
460 |
pre_ops = [{'DetResizeForTest': {'limit_side_len': args.det_limit_side_len, 'limit_type': args.det_limit_type}}, {'NormalizeImage': {'std': [0.229,0.224,0.225], 'mean': [0.485,0.456,0.406], 'scale': '1./255.', 'order': 'hwc'}}, {'ToCHWImage': None}, {'KeepKeys': {'keep_keys': ['image', 'shape']}}]
|
@@ -463,10 +526,15 @@ class _MDR_TextDetector(_MDR_PredictBase):
|
|
463 |
self.post_op = _MDR_DBPostProcess(**post_params)
|
464 |
self.sess = self.get_onnx_session(args.det_model_dir, args.use_gpu)
|
465 |
self.input_name = self.get_input_name(self.sess); self.output_name = self.get_output_name(self.sess)
|
|
|
466 |
def _order_pts(self, pts): r=np.zeros((4,2),dtype="float32"); s=pts.sum(axis=1); r[0]=pts[np.argmin(s)]; r[2]=pts[np.argmax(s)]; tmp=np.delete(pts,(np.argmin(s),np.argmax(s)),axis=0); d=np.diff(np.array(tmp),axis=1); r[1]=tmp[np.argmin(d)]; r[3]=tmp[np.argmax(d)]; return r
|
|
|
467 |
def _clip_pts(self, pts, h, w): pts[:,0]=np.clip(pts[:,0],0,w-1); pts[:,1]=np.clip(pts[:,1],0,h-1); return pts
|
|
|
468 |
def _filter_quad(self, boxes, shape): h,w=shape[0:2]; new_boxes=[]; for box in boxes: box=np.array(box) if isinstance(box,list) else box; box=self._order_pts(box); box=self._clip_pts(box,h,w); rw=int(np.linalg.norm(box[0]-box[1])); rh=int(np.linalg.norm(box[0]-box[3])); if rw<=3 or rh<=3: continue; new_boxes.append(box); return np.array(new_boxes)
|
|
|
469 |
def _filter_poly(self, boxes, shape): h,w=shape[0:2]; new_boxes=[]; for box in boxes: box=np.array(box) if isinstance(box,list) else box; box=self._clip_pts(box,h,w); if Polygon(box).area<10: continue; new_boxes.append(box); return np.array(new_boxes)
|
|
|
470 |
def __call__(self, img):
|
471 |
ori_im = img.copy(); data = {"image": img}; data = mdr_ocr_transform(data, self.pre_op)
|
472 |
if data is None: return None; img, shape_list = data;
|
@@ -475,25 +543,28 @@ class _MDR_TextDetector(_MDR_PredictBase):
|
|
475 |
preds = {"maps": outputs[0]}; post_res = self.post_op(preds, shape_list); boxes = post_res[0]['points']
|
476 |
return self._filter_poly(boxes, ori_im.shape) if self.args.det_box_type=='poly' else self._filter_quad(boxes, ori_im.shape)
|
477 |
|
478 |
-
# --- MDR ONNX OCR Internals (cls_postprocess.py) ---
|
479 |
class _MDR_ClsPostProcess:
|
|
|
480 |
def __init__(self, label_list=None, **kwargs): self.labels = label_list if label_list else {0:'0', 1:'180'}
|
|
|
481 |
def __call__(self, preds, label=None, *args, **kwargs):
|
482 |
preds = np.array(preds) if not isinstance(preds, np.ndarray) else preds; idxs = preds.argmax(axis=1)
|
483 |
return [(self.labels[idx], float(preds[i,idx])) for i,idx in enumerate(idxs)]
|
484 |
|
485 |
-
# --- MDR ONNX OCR Internals (predict_cls.py) ---
|
486 |
class _MDR_TextClassifier(_MDR_PredictBase):
|
|
|
487 |
def __init__(self, args):
|
488 |
super().__init__(); self.shape = tuple(map(int, args.cls_image_shape.split(','))) if isinstance(args.cls_image_shape, str) else args.cls_image_shape
|
489 |
self.batch_num = args.cls_batch_num; self.thresh = args.cls_thresh; self.post_op = _MDR_ClsPostProcess(label_list=args.label_list)
|
490 |
self.sess = self.get_onnx_session(args.cls_model_dir, args.use_gpu); self.input_name = self.get_input_name(self.sess); self.output_name = self.get_output_name(self.sess)
|
|
|
491 |
def _resize_norm(self, img):
|
492 |
imgC,imgH,imgW = self.shape; h,w = img.shape[:2]; r=w/float(h) if h>0 else 0; rw=int(math.ceil(imgH*r)); rw=min(rw,imgW)
|
493 |
resized = cv2.resize(img,(rw,imgH)); resized = resized.astype("float32")
|
494 |
if imgC==1: resized = resized/255.0; resized = resized[np.newaxis,:]
|
495 |
else: resized = resized.transpose((2,0,1))/255.0
|
496 |
resized -= 0.5; resized /= 0.5; padding = np.zeros((imgC,imgH,imgW),dtype=np.float32); padding[:,:,0:rw]=resized; return padding
|
|
|
497 |
def __call__(self, img_list):
|
498 |
if not img_list: return img_list, []; img_list_cp = copy.deepcopy(img_list); num = len(img_list_cp)
|
499 |
ratios = [img.shape[1]/float(img.shape[0]) if img.shape[0]>0 else 0 for img in img_list_cp]; indices = np.argsort(np.array(ratios))
|
@@ -509,8 +580,8 @@ class _MDR_TextClassifier(_MDR_PredictBase):
|
|
509 |
if "180" in label and score > self.thresh: img_list[orig_idx] = cv2.rotate(img_list[orig_idx], cv2.ROTATE_180)
|
510 |
return img_list, results
|
511 |
|
512 |
-
# --- MDR ONNX OCR Internals (rec_postprocess.py) ---
|
513 |
class _MDR_BaseRecLabelDecode:
|
|
|
514 |
def __init__(self, char_path=None, use_space=False):
|
515 |
self.beg, self.end, self.rev = "sos", "eos", False; self.chars = []
|
516 |
if char_path is None: self.chars = list("0123456789abcdefghijklmnopqrstuvwxyz")
|
@@ -521,9 +592,13 @@ class _MDR_BaseRecLabelDecode:
|
|
521 |
if any("\u0600"<=c<="\u06FF" for c in self.chars): self.rev=True
|
522 |
except FileNotFoundError: print(f"Warn: Dict not found {char_path}"); self.chars=list("0123456789abcdefghijklmnopqrstuvwxyz"); if use_space: self.chars.append(" ")
|
523 |
d_char = self.add_special_char(list(self.chars)); self.dict={c:i for i,c in enumerate(d_char)}; self.character=d_char
|
|
|
524 |
def add_special_char(self, chars): return chars
|
|
|
525 |
def get_ignored_tokens(self): return []
|
|
|
526 |
def _reverse(self, pred): res=[]; cur=""; for c in pred: if not re.search("[a-zA-Z0-9 :*./%+-]",c): res.extend([cur,c] if cur!="" else [c]); cur="" else: cur+=c; if cur!="": res.append(cur); return "".join(res[::-1])
|
|
|
527 |
def decode(self, idxs, probs=None, remove_dup=False):
|
528 |
res=[]; ignored=self.get_ignored_tokens(); bs=len(idxs)
|
529 |
for b_idx in range(bs):
|
@@ -537,6 +612,7 @@ class _MDR_BaseRecLabelDecode:
|
|
537 |
if self.rev: txt=self._reverse(txt)
|
538 |
res.append((txt, float(np.mean(conf_l))))
|
539 |
return res
|
|
|
540 |
class _MDR_CTCLabelDecode(_MDR_BaseRecLabelDecode):
|
541 |
def __init__(self, char_path=None, use_space=False, **kwargs): super().__init__(char_path, use_space)
|
542 |
def add_special_char(self, chars): return ["blank"]+chars
|
@@ -545,13 +621,14 @@ class _MDR_CTCLabelDecode(_MDR_BaseRecLabelDecode):
|
|
545 |
preds = preds[-1] if isinstance(preds,(tuple,list)) else preds; preds = np.array(preds) if not isinstance(preds,np.ndarray) else preds
|
546 |
idxs=preds.argmax(axis=2); probs=preds.max(axis=2); txt=self.decode(idxs, probs, remove_dup=True); return txt
|
547 |
|
548 |
-
# --- MDR ONNX OCR Internals (predict_rec.py) ---
|
549 |
class _MDR_TextRecognizer(_MDR_PredictBase):
|
|
|
550 |
def __init__(self, args):
|
551 |
super().__init__(); shape_str=getattr(args,'rec_image_shape',"3,48,320"); self.shape=tuple(map(int,shape_str.split(',')))
|
552 |
self.batch_num=getattr(args,'rec_batch_num',6); self.algo=getattr(args,'rec_algorithm','SVTR_LCNet')
|
553 |
self.post_op=_MDR_CTCLabelDecode(char_path=args.rec_char_dict_path, use_space=getattr(args,'use_space_char',True))
|
554 |
self.sess=self.get_onnx_session(args.rec_model_dir, args.use_gpu); self.input_name=self.get_input_name(self.sess); self.output_name=self.get_output_name(self.sess)
|
|
|
555 |
def _resize_norm(self, img, max_r):
|
556 |
imgC,imgH,imgW = self.shape; h,w = img.shape[:2];
|
557 |
if h==0 or w==0: return np.zeros((imgC,imgH,imgW),dtype=np.float32)
|
@@ -561,6 +638,7 @@ class _MDR_TextRecognizer(_MDR_PredictBase):
|
|
561 |
if len(resized.shape)==2: resized=resized[:,:,np.newaxis]
|
562 |
resized=resized.transpose((2,0,1))/255.0; resized-=0.5; resized/=0.5
|
563 |
padding=np.zeros((imgC,imgH,imgW),dtype=np.float32); padding[:,:,0:tw]=resized; return padding
|
|
|
564 |
def __call__(self, img_list):
|
565 |
if not img_list: return []; num=len(img_list); ratios=[img.shape[1]/float(img.shape[0]) if img.shape[0]>0 else 0 for img in img_list]
|
566 |
indices=np.argsort(np.array(ratios)); results=[["",0.0]]*num; batch_n=self.batch_num
|
@@ -574,8 +652,9 @@ class _MDR_TextRecognizer(_MDR_PredictBase):
|
|
574 |
for i in range(len(rec_out)): results[indices[start+i]]=rec_out[i]
|
575 |
return results
|
576 |
|
577 |
-
# --- MDR ONNX OCR System
|
578 |
class _MDR_TextSystem:
|
|
|
579 |
def __init__(self, args):
|
580 |
class ArgsObject: # Helper to access dict args with dot notation
|
581 |
def __init__(self, **entries): self.__dict__.update(entries)
|
@@ -587,11 +666,13 @@ class _MDR_TextSystem:
|
|
587 |
self.drop_score = getattr(args, 'drop_score', 0.5)
|
588 |
self.classifier = _MDR_TextClassifier(args) if self.use_cls else None
|
589 |
self.crop_idx = 0; self.save_crop = getattr(args, 'save_crop_res', False); self.crop_dir = getattr(args, 'crop_res_save_dir', "./output/mdr_crop_res")
|
|
|
590 |
def _sort_boxes(self, boxes):
|
591 |
if boxes is None or len(boxes)==0: return []
|
592 |
def key(box): min_y=min(p[1] for p in box); min_x=min(p[0] for p in box); return (min_y, min_x)
|
593 |
try: return list(sorted(boxes, key=key))
|
594 |
except: return list(boxes) # Fallback
|
|
|
595 |
def __call__(self, img, classify=True):
|
596 |
ori_im = img.copy(); boxes = self.detector(img)
|
597 |
if boxes is None or len(boxes)==0: return [], []
|
@@ -613,12 +694,13 @@ class _MDR_TextSystem:
|
|
613 |
if score >= self.drop_score: final_boxes.append(box); final_rec.append(res)
|
614 |
if self.save_crop: self._save_crops(crops, rec_res)
|
615 |
return final_boxes, final_rec
|
|
|
616 |
def _save_crops(self, crops, recs):
|
617 |
mdr_ensure_directory(self.crop_dir); num = len(crops)
|
618 |
for i in range(num): txt, score = recs[i]; safe=re.sub(r'\W+', '_', txt)[:20]; fname=f"crop_{self.crop_idx+i}_{safe}_{score:.2f}.jpg"; cv2.imwrite(os.path.join(self.crop_dir, fname), crops[i])
|
619 |
self.crop_idx += num
|
620 |
|
621 |
-
# --- MDR ONNX OCR Utilities
|
622 |
def mdr_get_rotated_crop(img, points):
|
623 |
"""Crops and perspective-transforms a quadrilateral region."""
|
624 |
pts = np.array(points, dtype="float32"); assert len(pts)==4
|
@@ -630,14 +712,17 @@ def mdr_get_rotated_crop(img, points):
|
|
630 |
dh, dw = dst.shape[0:2]
|
631 |
if dh>0 and dw>0 and dh*1.0/dw >= 1.5: dst = cv2.rotate(dst, cv2.ROTATE_90_CLOCKWISE)
|
632 |
return dst
|
|
|
633 |
def mdr_get_min_area_crop(img, points):
|
634 |
"""Crops the minimum area rectangle containing the points."""
|
635 |
bb = cv2.minAreaRect(np.array(points).astype(np.int32)); box_pts = cv2.boxPoints(bb)
|
636 |
return mdr_get_rotated_crop(img, box_pts)
|
637 |
|
638 |
-
# --- MDR Layout Processing
|
639 |
_MDR_INCLUDES_MIN_RATE = 0.99
|
|
|
640 |
class _MDR_OverlapMatrixContext:
|
|
|
641 |
def __init__(self, layouts: list[MDRLayoutElement]):
|
642 |
length = len(layouts); self.polys: list[Polygon|None] = []
|
643 |
for l in layouts:
|
@@ -651,6 +736,7 @@ class _MDR_OverlapMatrixContext:
|
|
651 |
p2 = self.polys[j];
|
652 |
if p2 is None: continue
|
653 |
r_ij = self._rate(p1, p2); r_ji = self._rate(p2, p1); self.matrix[i][j]=r_ij; self.matrix[j][i]=r_ji
|
|
|
654 |
def _rate(self, p1: Polygon, p2: Polygon) -> float: # Rate p1 covers p2
|
655 |
try: inter = p1.intersection(p2);
|
656 |
except: return 0.0
|
@@ -659,9 +745,11 @@ class _MDR_OverlapMatrixContext:
|
|
659 |
_, _, px1, py1 = p2.bounds; pw, ph = px1-p2.bounds[0], py1-p2.bounds[1]
|
660 |
if pw < 1e-6 or ph < 1e-6: return 0.0
|
661 |
wr = min(iw/pw, 1.0); hr = min(ih/ph, 1.0); return (wr+hr)/2.0
|
|
|
662 |
def others(self, idx: int):
|
663 |
for i, r in enumerate(self.matrix[idx]):
|
664 |
if i != idx and i not in self.removed: yield r
|
|
|
665 |
def includes(self, idx: int): # Layouts included BY idx
|
666 |
for i, r in enumerate(self.matrix[idx]):
|
667 |
if i != idx and i not in self.removed and r >= _MDR_INCLUDES_MIN_RATE:
|
@@ -723,8 +811,9 @@ def mdr_merge_fragments_into_lines(orig_frags: list[MDROcrFragment]) -> list[MDR
|
|
723 |
for i, f in enumerate(merged): f.order = i
|
724 |
return merged
|
725 |
|
726 |
-
# --- MDR Layout Processing
|
727 |
_MDR_CORRECTION_MIN_OVERLAP = 0.5
|
|
|
728 |
def mdr_correct_layout_fragments(ocr_engine: 'MDROcrEngine', source_img: Image, layout: MDRLayoutElement):
|
729 |
if not layout.fragments: return;
|
730 |
try:
|
@@ -759,9 +848,12 @@ def mdr_correct_layout_fragments(ocr_engine: 'MDROcrEngine', source_img: Image,
|
|
759 |
final = [n if n.rank >= o.rank else o for o, n in matched]; final.extend(unmatched_orig); final.extend(unmatched_new)
|
760 |
layout.fragments = final; layout.fragments.sort(key=lambda f: (f.rect.lt[1], f.rect.lt[0]))
|
761 |
|
762 |
-
# --- MDR OCR Engine
|
|
|
763 |
_MDR_OCR_MODELS = {"det": ("ppocrv4","det","det.onnx"), "cls": ("ppocrv4","cls","cls.onnx"), "rec": ("ppocrv4","rec","rec.onnx"), "keys": ("ch_ppocr_server_v2.0","ppocr_keys_v1.txt")}
|
|
|
764 |
_MDR_OCR_URL_BASE = "https://huggingface.co/moskize/OnnxOCR/resolve/main/"
|
|
|
765 |
@dataclass
|
766 |
class _MDR_ONNXParams: # Simplified container
|
767 |
use_gpu: bool; det_model_dir: str; cls_model_dir: str; rec_model_dir: str; rec_char_dict_path: str
|
@@ -772,14 +864,17 @@ class _MDR_ONNXParams: # Simplified container
|
|
772 |
|
773 |
class MDROcrEngine:
|
774 |
"""Handles OCR detection and recognition using ONNX models."""
|
|
|
775 |
def __init__(self, device: Literal["cpu", "cuda"], model_dir_path: str):
|
776 |
self._device = device; self._model_dir = mdr_ensure_directory(model_dir_path)
|
777 |
self._text_system: _MDR_TextSystem | None = None; self._onnx_params: _MDR_ONNXParams | None = None
|
778 |
self._ensure_models(); self._get_system() # Init on creation
|
|
|
779 |
def _ensure_models(self):
|
780 |
for key, parts in _MDR_OCR_MODELS.items():
|
781 |
fp = Path(self._model_dir) / Path(*parts)
|
782 |
if not fp.exists(): print(f"Downloading MDR OCR model: {fp.name}..."); url = _MDR_OCR_URL_BASE + "/".join(parts); mdr_download_model(url, fp)
|
|
|
783 |
def _get_system(self) -> _MDR_TextSystem | None:
|
784 |
if self._text_system is None:
|
785 |
paths = {k: str(Path(self._model_dir)/Path(*p)) for k,p in _MDR_OCR_MODELS.items()}
|
@@ -787,6 +882,7 @@ class MDROcrEngine:
|
|
787 |
try: self._text_system = _MDR_TextSystem(self._onnx_params); print(f"MDR OCR System initialized.")
|
788 |
except Exception as e: print(f"ERROR initializing MDR OCR System: {e}"); self._text_system = None
|
789 |
return self._text_system
|
|
|
790 |
def find_text_fragments(self, image_np: np.ndarray) -> Generator[MDROcrFragment, None, None]:
|
791 |
"""Finds and recognizes text fragments in a NumPy image (BGR)."""
|
792 |
system = self._get_system()
|
@@ -799,22 +895,26 @@ class MDROcrEngine:
|
|
799 |
if not txt or mdr_is_whitespace(txt) or conf < 0.1: continue
|
800 |
pts = [(float(p[0]), float(p[1])) for p in box_pts]
|
801 |
if len(pts)==4: r=MDRRectangle(lt=pts[0], rt=pts[1], rb=pts[2], lb=pts[3]); if r.is_valid and r.area>1: yield MDROcrFragment(order=-1, text=txt, rank=float(conf), rect=r)
|
|
|
802 |
def _preprocess(self, img: np.ndarray) -> np.ndarray:
|
803 |
if len(img.shape)==3 and img.shape[2]==4: a=img[:,:,3]/255.0; bg=(255,255,255); new=np.zeros_like(img[:,:,:3]); [setattr(new[:,:,i], 'flags.writeable', True) for i in range(3)]; [np.copyto(new[:,:,i], (bg[i]*(1-a)+img[:,:,i]*a)) for i in range(3)]; img=new.astype(np.uint8)
|
804 |
elif len(img.shape)==2: img=cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
805 |
elif not (len(img.shape)==3 and img.shape[2]==3): raise ValueError("Unsupported image format")
|
806 |
return img
|
807 |
|
808 |
-
# --- MDR Layout Reading Internals
|
809 |
_MDR_MAX_LEN = 510; _MDR_CLS_ID = 0; _MDR_SEP_ID = 2; _MDR_PAD_ID = 1
|
|
|
810 |
def mdr_boxes_to_reader_inputs(boxes: List[List[int]], max_len=_MDR_MAX_LEN) -> Dict[str, torch.Tensor]:
|
811 |
t_boxes = boxes[:max_len]; i_boxes = [[0,0,0,0]] + t_boxes + [[0,0,0,0]]
|
812 |
i_ids = [_MDR_CLS_ID] + [_MDR_PAD_ID]*len(t_boxes) + [_MDR_SEP_ID]
|
813 |
a_mask = [1]*len(i_ids); pad_len = (max_len+2) - len(i_ids)
|
814 |
if pad_len > 0: i_boxes.extend([[0,0,0,0]]*pad_len); i_ids.extend([_MDR_PAD_ID]*pad_len); a_mask.extend([0]*pad_len)
|
815 |
return {"bbox": torch.tensor([i_boxes]), "input_ids": torch.tensor([i_ids]), "attention_mask": torch.tensor([a_mask])}
|
|
|
816 |
def mdr_prepare_reader_inputs(inputs: Dict[str, torch.Tensor], model: LayoutLMv3ForTokenClassification) -> Dict[str, torch.Tensor]:
|
817 |
return {k: v.to(model.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
|
|
|
818 |
def mdr_parse_reader_logits(logits: torch.Tensor, length: int) -> List[int]:
|
819 |
if length == 0: return []; rel_logits = logits[1:length+1, :length]; orders = rel_logits.argmax(dim=1).tolist()
|
820 |
while True:
|
@@ -830,14 +930,17 @@ def mdr_parse_reader_logits(logits: torch.Tensor, length: int) -> List[int]:
|
|
830 |
orders[idx] = rel_logits[idx, :].argmax().item(); rel_logits[idx, order] = orig_logit
|
831 |
return orders
|
832 |
|
833 |
-
# --- MDR Layout Reading Engine
|
834 |
@dataclass
|
835 |
class _MDR_ReaderBBox: layout_index: int; fragment_index: int; virtual: bool; order: int; value: tuple[float, float, float, float]
|
|
|
836 |
class MDRLayoutReader:
|
837 |
"""Determines reading order of layout elements using LayoutLMv3."""
|
|
|
838 |
def __init__(self, model_path: str):
|
839 |
self._model_path = model_path; self._model: LayoutLMv3ForTokenClassification | None = None
|
840 |
self._device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
841 |
def _get_model(self) -> LayoutLMv3ForTokenClassification | None:
|
842 |
if self._model is None:
|
843 |
cache = mdr_ensure_directory(self._model_path); name = "microsoft/layoutlmv3-base"; h_path = os.path.join(cache, "models--hantian--layoutreader")
|
@@ -847,6 +950,7 @@ class MDRLayoutReader:
|
|
847 |
self._model.to(self._device); self._model.eval(); print(f"MDR LayoutReader loaded on {self._device}.")
|
848 |
except Exception as e: print(f"ERROR loading MDR LayoutReader: {e}"); self._model = None
|
849 |
return self._model
|
|
|
850 |
def determine_reading_order(self, layouts: list[MDRLayoutElement], size: tuple[int, int]) -> list[MDRLayoutElement]:
|
851 |
w, h = size;
|
852 |
if w<=0 or h<=0 or not layouts: return layouts;
|
@@ -873,6 +977,7 @@ class MDRLayoutReader:
|
|
873 |
if len(orders) != len(bbox_list): print("MDR LayoutReader order mismatch"); return layouts # Fallback
|
874 |
for i, order_idx in enumerate(orders): bbox_list[i].order = order_idx
|
875 |
return self._apply_order(layouts, bbox_list)
|
|
|
876 |
def _prepare_bboxes(self, layouts: list[MDRLayoutElement], w: int, h: int) -> list[_MDR_ReaderBBox] | None:
|
877 |
line_h = self._estimate_line_h(layouts); bbox_list = []
|
878 |
for i, l in enumerate(layouts):
|
@@ -880,6 +985,7 @@ class MDRLayoutReader:
|
|
880 |
else: bbox_list.extend(self._gen_virtual(l, i, line_h, w, h))
|
881 |
if len(bbox_list) > _MDR_MAX_LEN: print(f"Too many boxes ({len(bbox_list)}>{_MDR_MAX_LEN})"); return None
|
882 |
bbox_list.sort(key=lambda b: (b.value[1], b.value[0])); return bbox_list
|
|
|
883 |
def _apply_order(self, layouts: list[MDRLayoutElement], bbox_list: list[_MDR_ReaderBBox]) -> list[MDRLayoutElement]:
|
884 |
layout_map = defaultdict(list); [layout_map[b.layout_index].append(b) for b in bbox_list]
|
885 |
layout_orders = [(idx, self._median([b.order for b in bboxes])) for idx, bboxes in layout_map.items() if bboxes]
|
@@ -893,9 +999,11 @@ class MDRLayoutReader:
|
|
893 |
else: frags.sort(key=lambda f: (f.rect.lt[1], f.rect.lt[0]))
|
894 |
for frag in frags: frag.order = nfo; nfo += 1
|
895 |
return sorted_layouts
|
|
|
896 |
def _estimate_line_h(self, layouts: list[MDRLayoutElement]) -> float:
|
897 |
heights = [f.rect.size[1] for l in layouts for f in l.fragments if f.rect.size[1]>0]
|
898 |
return self._median(heights) if heights else 15.0
|
|
|
899 |
def _gen_virtual(self, l: MDRLayoutElement, l_idx: int, line_h: float, pw: int, ph: int) -> Generator[_MDR_ReaderBBox, None, None]:
|
900 |
x0,y0,x1,y1 = l.rect.wrapper; lh,lw = y1-y0,x1-x0
|
901 |
if lh<=0 or lw<=0 or line_h<=0: yield _MDR_ReaderBBox(l_idx,-1,True,-1,(x0,y0,x1,y1)); return
|
@@ -910,16 +1018,19 @@ class MDRLayoutReader:
|
|
910 |
ly0,ly1 = max(0,min(ph,cur_y)), max(0,min(ph,cur_y+act_line_h)); lx0,lx1 = max(0,min(pw,x0)), max(0,min(pw,x1))
|
911 |
if ly1>ly0 and lx1>lx0: yield _MDR_ReaderBBox(l_idx,-1,True,-1,(lx0,ly0,lx1,ly1))
|
912 |
cur_y += act_line_h
|
|
|
913 |
def _median(self, nums: list[float|int]) -> float:
|
914 |
if not nums: return 0.0; s_nums = sorted(nums); n = len(s_nums)
|
915 |
return float(s_nums[n//2]) if n%2==1 else float((s_nums[n//2-1]+s_nums[n//2])/2.0)
|
916 |
|
917 |
-
# --- MDR LaTeX Extractor
|
918 |
class MDRLatexExtractor:
|
919 |
"""Extracts LaTeX from formula images using pix2tex."""
|
|
|
920 |
def __init__(self, model_path: str):
|
921 |
self._model_path = model_path; self._model: LatexOCR | None = None
|
922 |
self._device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
923 |
def extract(self, image: Image) -> str | None:
|
924 |
if LatexOCR is None: return None;
|
925 |
image = mdr_expand_image(image, 0.1); model = self._get_model()
|
@@ -927,6 +1038,7 @@ class MDRLatexExtractor:
|
|
927 |
try:
|
928 |
with torch.no_grad(): img_rgb = image.convert('RGB') if image.mode!='RGB' else image; latex = model(img_rgb); return latex if latex else None
|
929 |
except Exception as e: print(f"MDR LaTeX error: {e}"); return None
|
|
|
930 |
def _get_model(self) -> LatexOCR | None:
|
931 |
if self._model is None and LatexOCR is not None:
|
932 |
mdr_ensure_directory(self._model_path); wp, rp, cp = Path(self._model_path)/"weights.pth", Path(self._model_path)/"image_resizer.pth", Path(self._model_path)/"config.yaml"
|
@@ -935,19 +1047,23 @@ class MDRLatexExtractor:
|
|
935 |
try: args = Munch({"config":str(cp), "checkpoint":str(wp), "device":self._device, "no_cuda":self._device=="cpu", "no_resize":False, "temperature":0.0}); self._model = LatexOCR(args); print(f"MDR LaTeX loaded on {self._device}.")
|
936 |
except Exception as e: print(f"ERROR initializing MDR LatexOCR: {e}"); self._model = None
|
937 |
return self._model
|
|
|
938 |
def _download(self):
|
939 |
tag = "v0.0.1"; base = f"https://github.com/lukas-blecher/LaTeX-OCR/releases/download/{tag}/"; files = {"weights.pth": base+"weights.pth", "image_resizer.pth": base+"image_resizer.pth"}
|
940 |
mdr_ensure_directory(self._model_path); [mdr_download_model(url, Path(self._model_path)/name) for name, url in files.items() if not (Path(self._model_path)/name).exists()]
|
941 |
|
942 |
-
# --- MDR Table Parser
|
943 |
MDRTableOutputFormat = Literal["latex", "markdown", "html"]
|
|
|
944 |
class MDRTableParser:
|
945 |
"""Parses table structure/content from images using StructTable model."""
|
|
|
946 |
def __init__(self, device: Literal["cpu", "cuda"], model_path: str):
|
947 |
self._model: Any | None = None; self._model_path = mdr_ensure_directory(model_path)
|
948 |
self._device = device if torch.cuda.is_available() and device=="cuda" else "cpu"
|
949 |
self._disabled = self._device == "cpu"
|
950 |
if self._disabled: print("Warning: MDR Table parsing requires CUDA. Disabled.")
|
|
|
951 |
def parse_table_image(self, image: Image, format: MDRTableLayoutParsedFormat) -> str | None:
|
952 |
if self._disabled: return None;
|
953 |
fmt: MDRTableOutputFormat | None = None
|
@@ -962,6 +1078,7 @@ class MDRTableParser:
|
|
962 |
with torch.no_grad(): results = model([img_rgb], output_format=fmt)
|
963 |
return results[0] if results else None
|
964 |
except Exception as e: print(f"MDR Table parsing error: {e}"); return None
|
|
|
965 |
def _get_model(self):
|
966 |
if self._model is None and not self._disabled:
|
967 |
try:
|
@@ -974,23 +1091,31 @@ class MDRTableParser:
|
|
974 |
except Exception as e: print(f"ERROR loading MDR StructTable: {e}"); self._model=None
|
975 |
return self._model
|
976 |
|
977 |
-
# --- MDR Image Optimizer
|
978 |
_MDR_TINY_ROTATION = 0.005
|
|
|
979 |
@dataclass
|
980 |
class _MDR_RotationContext: to_origin: MDRRotationAdjuster; to_new: MDRRotationAdjuster; fragment_origin_rectangles: list[MDRRectangle]
|
|
|
981 |
class MDRImageOptimizer:
|
982 |
"""Handles image rotation detection and coordinate adjustments."""
|
|
|
983 |
def __init__(self, raw_image: Image, adjust_points: bool):
|
984 |
self._raw = raw_image; self._image = raw_image; self._adjust_points = adjust_points
|
985 |
self._fragments: list[MDROcrFragment] = []; self._rotation: float = 0.0; self._rot_ctx: _MDR_RotationContext | None = None
|
|
|
986 |
@property
|
987 |
def image(self) -> Image: return self._image
|
|
|
988 |
@property
|
989 |
def adjusted_image(self) -> Image | None: return self._image if self._rot_ctx is not None else None
|
|
|
990 |
@property
|
991 |
def rotation(self) -> float: return self._rotation
|
|
|
992 |
@property
|
993 |
def image_np(self) -> np.ndarray: img_rgb = np.array(self._raw.convert("RGB")); return cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
|
|
|
994 |
def receive_fragments(self, fragments: list[MDROcrFragment]):
|
995 |
self._fragments = fragments;
|
996 |
if not fragments: return;
|
@@ -1005,12 +1130,13 @@ class MDRImageOptimizer:
|
|
1005 |
to_new=MDRRotationAdjuster(orig_sz, new_sz, self._rotation, False),
|
1006 |
to_origin=MDRRotationAdjuster(orig_sz, new_sz, self._rotation, True))
|
1007 |
adj = self._rot_ctx.to_new; [setattr(f, 'rect', MDRRectangle(lt=adj.adjust(r.lt), rt=adj.adjust(r.rt), lb=adj.adjust(r.lb), rb=adj.adjust(r.rb))) for f in fragments if (r:=f.rect)]
|
|
|
1008 |
def finalize_layout_coords(self, layouts: list[MDRLayoutElement]):
|
1009 |
if self._rot_ctx is None or self._adjust_points: return
|
1010 |
if len(self._fragments) == len(self._rot_ctx.fragment_origin_rectangles): [setattr(f, 'rect', orig_r) for f, orig_r in zip(self._fragments, self._rot_ctx.fragment_origin_rectangles)]
|
1011 |
adj = self._rot_ctx.to_origin; [setattr(l, 'rect', MDRRectangle(lt=adj.adjust(r.lt), rt=adj.adjust(r.rt), lb=adj.adjust(r.lb), rb=adj.adjust(r.rb))) for l in layouts if (r:=l.rect)]
|
1012 |
|
1013 |
-
# --- MDR Image Clipping
|
1014 |
def mdr_clip_from_image(image: Image, rect: MDRRectangle, wrap_w: float = 0.0, wrap_h: float = 0.0) -> Image:
|
1015 |
"""Clips a potentially rotated rectangle from an image."""
|
1016 |
try:
|
@@ -1026,18 +1152,21 @@ def mdr_clip_from_image(image: Image, rect: MDRRectangle, wrap_w: float = 0.0, w
|
|
1026 |
out_w, out_h = ceil(avg_w+wrap_w), ceil(avg_h+wrap_h)
|
1027 |
return image.transform((out_w, out_h), PILTransform.AFFINE, p_mat, PILResampling.BICUBIC, fillcolor=(255,255,255))
|
1028 |
except Exception as e: print(f"MDR Clipping error: {e}"); return new_image("RGB", (10,10), (255,255,255))
|
|
|
1029 |
def mdr_clip_layout(res: MDRExtractionResult, layout: MDRLayoutElement, wrap_w: float = 0.0, wrap_h: float = 0.0) -> Image:
|
1030 |
"""Clips a layout region from the MDRExtractionResult image."""
|
1031 |
img = res.adjusted_image if res.adjusted_image else res.extracted_image
|
1032 |
return mdr_clip_from_image(img, layout.rect, wrap_w, wrap_h)
|
1033 |
|
1034 |
-
# --- MDR Debug Plotting
|
1035 |
_MDR_FRAG_COLOR = (0x49, 0xCF, 0xCB, 200); _MDR_LAYOUT_COLORS = { MDRLayoutClass.TITLE: (0x0A,0x12,0x2C,255), MDRLayoutClass.PLAIN_TEXT: (0x3C,0x67,0x90,255), MDRLayoutClass.ABANDON: (0xC0,0xBB,0xA9,180), MDRLayoutClass.FIGURE: (0x5B,0x91,0x3C,255), MDRLayoutClass.FIGURE_CAPTION: (0x77,0xB3,0x54,255), MDRLayoutClass.TABLE: (0x44,0x17,0x52,255), MDRLayoutClass.TABLE_CAPTION: (0x81,0x75,0xA0,255), MDRLayoutClass.TABLE_FOOTNOTE: (0xEF,0xB6,0xC9,255), MDRLayoutClass.ISOLATE_FORMULA: (0xFA,0x38,0x27,255), MDRLayoutClass.FORMULA_CAPTION: (0xFF,0x9D,0x24,255) }; _MDR_DEFAULT_COLOR = (0x80,0x80,0x80,255); _MDR_RGBA = tuple[int,int,int,int]
|
|
|
1036 |
def mdr_plot_layout(image: Image, layouts: Iterable[MDRLayoutElement]) -> None:
|
1037 |
"""Draws layout and fragment boxes onto an image for debugging."""
|
1038 |
if not layouts: return;
|
1039 |
try: l_font, f_font = load_default(size=25), load_default(size=15); draw = ImageDraw.Draw(image, mode="RGBA")
|
1040 |
except Exception as e: print(f"MDR Plot init error: {e}"); return
|
|
|
1041 |
def _draw_num(pos: MDRPoint, num: int, font: FreeTypeFont, color: _MDR_RGBA):
|
1042 |
try: x,y=pos; txt=str(num); txt_pos=(round(x)+3, round(y)+1); bbox=draw.textbbox(txt_pos,txt,font=font); bg_rect=(bbox[0]-2,bbox[1]-1,bbox[2]+2,bbox[3]+1); bg_color=(color[0],color[1],color[2],180); draw.rectangle(bg_rect,fill=bg_color); draw.text(txt_pos,txt,font=font,fill=(255,255,255,255))
|
1043 |
except Exception as e: print(f"MDR Draw num error: {e}")
|
@@ -1049,26 +1178,65 @@ def mdr_plot_layout(image: Image, layouts: Iterable[MDRLayoutElement]) -> None:
|
|
1049 |
try: draw.polygon([p for p in f.rect], outline=_MDR_FRAG_COLOR, width=1)
|
1050 |
except Exception as e: print(f"MDR Fragment draw error: {e}")
|
1051 |
|
1052 |
-
# --- MDR Extraction Engine
|
1053 |
class MDRExtractionEngine:
|
1054 |
"""Core engine for extracting structured information from a document image."""
|
|
|
1055 |
def __init__(self, model_dir_path: str, device: Literal["cpu", "cuda"]="cpu", ocr_for_each_layouts: bool=True, extract_formula: bool=True, extract_table_format: MDRTableLayoutParsedFormat|None=None):
|
1056 |
-
self._model_dir = model_dir_path
|
|
|
1057 |
self._ocr_each = ocr_for_each_layouts; self._ext_formula = extract_formula; self._ext_table = extract_table_format
|
1058 |
self._yolo: YOLOv10 | None = None
|
1059 |
-
|
1060 |
-
self.
|
1061 |
-
self.
|
1062 |
-
self.
|
|
|
1063 |
print(f"MDR Extraction Engine initialized on device: {self._device}")
|
|
|
|
|
1064 |
def _get_yolo_model(self) -> YOLOv10 | None:
|
|
|
1065 |
if self._yolo is None and YOLOv10 is not None:
|
1066 |
-
|
1067 |
-
|
1068 |
-
|
1069 |
-
|
1070 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1071 |
return self._yolo
|
|
|
1072 |
def analyze_image(self, image: Image, adjust_points: bool=False) -> MDRExtractionResult:
|
1073 |
"""Analyzes a single page image to extract layout and content."""
|
1074 |
print(" Engine: Analyzing image..."); optimizer = MDRImageOptimizer(image, adjust_points)
|
@@ -1088,6 +1256,7 @@ class MDRExtractionEngine:
|
|
1088 |
print(" Engine: Finalizing coords..."); optimizer.finalize_layout_coords(layouts)
|
1089 |
print(" Engine: Analysis complete.")
|
1090 |
return MDRExtractionResult(rotation=optimizer.rotation, layouts=layouts, extracted_image=image, adjusted_image=optimizer.adjusted_image)
|
|
|
1091 |
def _run_yolo_detection(self, img: Image, yolo: YOLOv10) -> Generator[MDRLayoutElement, None, None]:
|
1092 |
img_rgb = img.convert("RGB"); res = yolo.predict(source=img_rgb, imgsz=1024, conf=0.2, device=self._device, verbose=False)
|
1093 |
if not res or not hasattr(res[0], 'boxes') or res[0].boxes is None: return
|
@@ -1101,6 +1270,7 @@ class MDRExtractionEngine:
|
|
1101 |
if cls == MDRLayoutClass.TABLE: yield MDRTableLayoutElement(cls=cls, rect=rect, fragments=[], parsed=None)
|
1102 |
elif cls == MDRLayoutClass.ISOLATE_FORMULA: yield MDRFormulaLayoutElement(cls=cls, rect=rect, fragments=[], latex=None)
|
1103 |
elif cls in MDRPlainLayoutElement.__annotations__['cls'].__args__: yield MDRPlainLayoutElement(cls=cls, rect=rect, fragments=[])
|
|
|
1104 |
def _match_fragments_to_layouts(self, frags: list[MDROcrFragment], layouts: list[MDRLayoutElement]) -> list[MDRLayoutElement]:
|
1105 |
if not frags or not layouts: return layouts
|
1106 |
layout_polys = [(Polygon(l.rect) if l.rect.is_valid else None) for l in layouts]
|
@@ -1120,11 +1290,13 @@ class MDRExtractionEngine:
|
|
1120 |
layouts[best_idx].fragments.append(frag)
|
1121 |
for l in layouts: l.fragments.sort(key=lambda f: (f.rect.lt[1], f.rect.lt[0]))
|
1122 |
return layouts
|
|
|
1123 |
def _run_ocr_correction(self, img: Image, layouts: list[MDRLayoutElement]):
|
1124 |
for i, l in enumerate(layouts):
|
1125 |
if l.cls == MDRLayoutClass.FIGURE: continue
|
1126 |
try: mdr_correct_layout_fragments(self._ocr_engine, img, l)
|
1127 |
except Exception as e: print(f" Engine: OCR correction error layout {i}: {e}")
|
|
|
1128 |
def _parse_special_layouts(self, layouts: list[MDRLayoutElement], optimizer: MDRImageOptimizer):
|
1129 |
img_to_clip = optimizer.image
|
1130 |
for l in layouts:
|
@@ -1135,25 +1307,33 @@ class MDRExtractionEngine:
|
|
1135 |
try: t_img = mdr_clip_from_image(img_to_clip, l.rect); parsed = self._table_parser.parse_table_image(t_img, self._ext_table) if t_img.width>1 and t_img.height>1 else None
|
1136 |
except Exception as e: print(f" Engine: Table parse error: {e}"); parsed = None
|
1137 |
if parsed: l.parsed = (parsed, self._ext_table)
|
|
|
1138 |
def _should_keep_layout(self, l: MDRLayoutElement) -> bool:
|
1139 |
if l.fragments and not all(mdr_is_whitespace(f.text) for f in l.fragments): return True
|
1140 |
return l.cls in [MDRLayoutClass.FIGURE, MDRLayoutClass.TABLE, MDRLayoutClass.ISOLATE_FORMULA]
|
1141 |
|
1142 |
-
# --- MDR Page Section Linking
|
1143 |
class _MDR_LinkedShape:
|
1144 |
"""Internal helper for managing layout linking across pages."""
|
|
|
1145 |
def __init__(self, layout: MDRLayoutElement): self.layout=layout; self.pre:list[MDRLayoutElement|None]=[None,None]; self.nex:list[MDRLayoutElement|None]=[None,None]
|
|
|
1146 |
@property
|
1147 |
def distance2(self) -> float: x,y=self.layout.rect.lt; return x*x+y*y
|
|
|
1148 |
class MDRPageSection:
|
1149 |
"""Represents a page's layouts for framework detection via linking."""
|
|
|
1150 |
def __init__(self, page_index: int, layouts: Iterable[MDRLayoutElement]):
|
1151 |
self._page_index = page_index; self._shapes = [_MDR_LinkedShape(l) for l in layouts]; self._shapes.sort(key=lambda s: (s.layout.rect.lt[1], s.layout.rect.lt[0]))
|
|
|
1152 |
@property
|
1153 |
def page_index(self) -> int: return self._page_index
|
|
|
1154 |
def find_framework_elements(self) -> list[MDRLayoutElement]:
|
1155 |
"""Identifies framework layouts based on links to other pages."""
|
1156 |
return [s.layout for s in self._shapes if any(s.pre) or any(s.nex)]
|
|
|
1157 |
def link_to_next(self, next_section: 'MDRPageSection', offset: int) -> None:
|
1158 |
"""Links matching shapes between this section and the next."""
|
1159 |
if offset not in (1,2): return
|
@@ -1169,6 +1349,7 @@ class MDRPageSection:
|
|
1169 |
r2_rel = self._relative_rect(orig_n_pt, s2.layout.rect); ovr = self._symmetric_iou(r1_rel, r2_rel)
|
1170 |
if ovr > max_ovr: max_ovr, best_s2 = ovr, s2
|
1171 |
if max_ovr >= 0.80 and best_s2 is not None: s1.nex[offset-1] = best_s2.layout; best_s2.pre[offset-1] = s1.layout # Link both ways
|
|
|
1172 |
def _shapes_match(self, s1: _MDR_LinkedShape, s2: _MDR_LinkedShape) -> bool:
|
1173 |
l1, l2 = s1.layout, s2.layout; sz1, sz2 = l1.rect.size, l2.rect.size; thresh = 0.90
|
1174 |
if mdr_similarity_ratio(sz1[0], sz2[0]) < thresh or mdr_similarity_ratio(sz1[1], sz2[1]) < thresh: return False
|
@@ -1183,10 +1364,12 @@ class MDRPageSection:
|
|
1183 |
if max_sim > 0.75: matches += 1; if best_j != -1: used_f2[best_j] = True
|
1184 |
max_c = max(c1, c2); rate_frags = matches / max_c
|
1185 |
return self._check_match_threshold(rate_frags, max_c, (0.0, 0.45, 0.45, 0.6, 0.8, 0.95))
|
|
|
1186 |
def _fragment_sim(self, l1: MDRLayoutElement, l2: MDRLayoutElement, f1: MDROcrFragment, f2: MDROcrFragment) -> float:
|
1187 |
r1_rel = self._relative_rect(l1.rect.lt, f1.rect); r2_rel = self._relative_rect(l2.rect.lt, f2.rect)
|
1188 |
geom_sim = self._symmetric_iou(r1_rel, r2_rel); text_sim, _ = mdr_check_text_similarity(f1.text, f2.text)
|
1189 |
return (geom_sim + text_sim) / 2.0
|
|
|
1190 |
def _find_origin_pair(self, matches_matrix: list[list[_MDR_LinkedShape]], next_shapes: list[_MDR_LinkedShape]) -> tuple[_MDR_LinkedShape, _MDR_LinkedShape] | None:
|
1191 |
best_pair, min_dist2 = None, float('inf')
|
1192 |
for i, s1 in enumerate(self._shapes):
|
@@ -1194,10 +1377,13 @@ class MDRPageSection:
|
|
1194 |
if not match_list: continue
|
1195 |
for s2 in match_list: dist2 = s1.distance2 + s2.distance2; if dist2 < min_dist2: min_dist2, best_pair = dist2, (s1, s2)
|
1196 |
return best_pair
|
|
|
1197 |
def _check_match_threshold(self, rate: float, count: int, thresholds: Sequence[float]) -> bool:
|
1198 |
if not thresholds: return False; idx = min(count, len(thresholds)-1); return rate >= thresholds[idx]
|
|
|
1199 |
def _relative_rect(self, origin: MDRPoint, rect: MDRRectangle) -> MDRRectangle:
|
1200 |
ox, oy = origin; r=rect; return MDRRectangle(lt=(r.lt[0]-ox, r.lt[1]-oy), rt=(r.rt[0]-ox, r.rt[1]-oy), lb=(r.lb[0]-ox, r.lb[1]-oy), rb=(r.rb[0]-ox, r.rb[1]-oy))
|
|
|
1201 |
def _symmetric_iou(self, r1: MDRRectangle, r2: MDRRectangle) -> float:
|
1202 |
try: p1, p2 = Polygon(r1), Polygon(r2);
|
1203 |
except: return 0.0
|
@@ -1207,21 +1393,26 @@ class MDRPageSection:
|
|
1207 |
if inter.is_empty or inter.area < 1e-6: return 0.0
|
1208 |
union_area = union.area; return inter.area / union_area if union_area > 1e-6 else 1.0
|
1209 |
|
1210 |
-
# --- MDR Document Iterator
|
1211 |
_MDR_CONTEXT_PAGES = 2 # Look behind/ahead pages for context
|
|
|
1212 |
@dataclass
|
1213 |
class MDRProcessingParams:
|
1214 |
"""Parameters for processing a document."""
|
1215 |
pdf: str | FitzDocument; page_indexes: Iterable[int] | None; report_progress: MDRProgressReportCallback | None
|
|
|
1216 |
class MDRDocumentIterator:
|
1217 |
"""Iterates through document pages, handles context, and calls the extraction engine."""
|
|
|
1218 |
def __init__(self, device: Literal["cpu", "cuda"], model_dir_path: str, ocr_level: MDROcrLevel, extract_formula: bool, extract_table_format: MDRTableLayoutParsedFormat | None, debug_dir_path: str | None):
|
1219 |
self._debug_dir = debug_dir_path
|
1220 |
self._engine = MDRExtractionEngine(device=device, model_dir_path=model_dir_path, ocr_for_each_layouts=(ocr_level==MDROcrLevel.OncePerLayout), extract_formula=extract_formula, extract_table_format=extract_table_format)
|
|
|
1221 |
def iterate_sections(self, params: MDRProcessingParams) -> Generator[tuple[int, MDRExtractionResult, list[MDRLayoutElement]], None, None]:
|
1222 |
"""Yields page index, extraction result, and content layouts for each requested page."""
|
1223 |
for res, sec in self._process_and_link_sections(params):
|
1224 |
framework = set(sec.find_framework_elements()); content = [l for l in res.layouts if l not in framework]; yield sec.page_index, res, content
|
|
|
1225 |
def _process_and_link_sections(self, params: MDRProcessingParams) -> Generator[tuple[MDRExtractionResult, MDRPageSection], None, None]:
|
1226 |
queue: list[tuple[MDRExtractionResult, MDRPageSection]] = []
|
1227 |
for page_idx, res in self._run_extraction_on_pages(params):
|
@@ -1232,6 +1423,7 @@ class MDRDocumentIterator:
|
|
1232 |
queue.append((res, cur_sec))
|
1233 |
if len(queue) > _MDR_CONTEXT_PAGES: yield queue.pop(0)
|
1234 |
for res, sec in queue: yield res, sec
|
|
|
1235 |
def _run_extraction_on_pages(self, params: MDRProcessingParams) -> Generator[tuple[int, MDRExtractionResult], None, None]:
|
1236 |
if self._debug_dir: mdr_ensure_directory(self._debug_dir)
|
1237 |
doc, should_close = None, False
|
@@ -1254,6 +1446,7 @@ class MDRDocumentIterator:
|
|
1254 |
except Exception as e: print(f" Iterator: Page {page_idx+1} processing error: {e}")
|
1255 |
finally:
|
1256 |
if should_close and doc: doc.close()
|
|
|
1257 |
def _get_page_ranges(self, doc: FitzDocument, idxs: Iterable[int]|None) -> tuple[Sequence[int], Sequence[int]]:
|
1258 |
count = doc.page_count;
|
1259 |
if idxs is None: all_p = list(range(count)); return all_p, all_p
|
@@ -1261,19 +1454,22 @@ class MDRDocumentIterator:
|
|
1261 |
for i in idxs:
|
1262 |
if 0<=i<count: enable.add(i); [scan.add(j) for j in range(max(0, i-_MDR_CONTEXT_PAGES), min(count, i+_MDR_CONTEXT_PAGES+1))]
|
1263 |
return sorted(list(scan)), sorted(list(enable))
|
|
|
1264 |
def _render_page_image(self, page: FitzPage, dpi: int) -> Image:
|
1265 |
mat = FitzMatrix(dpi/72.0, dpi/72.0); pix = page.get_pixmap(matrix=mat, alpha=False)
|
1266 |
return frombytes("RGB", (pix.width, pix.height), pix.samples)
|
|
|
1267 |
def _save_debug_plot(self, img: Image, idx: int, res: MDRExtractionResult, path: str):
|
1268 |
try: plot_img = res.adjusted_image.copy() if res.adjusted_image else img.copy(); mdr_plot_layout(plot_img, res.layouts); plot_img.save(os.path.join(path, f"mdr_plot_page_{idx+1}.png"))
|
1269 |
except Exception as e: print(f" Iterator: Plot generation error page {idx+1}: {e}")
|
1270 |
|
1271 |
-
# --- MagicDataReadiness Main Processor
|
1272 |
class MagicPDFProcessor:
|
1273 |
"""
|
1274 |
Main class for processing PDF documents to extract structured data blocks
|
1275 |
using the MagicDataReadiness pipeline.
|
1276 |
"""
|
|
|
1277 |
def __init__(self, device: Literal["cpu", "cuda"]="cuda", model_dir_path: str="./mdr_models", ocr_level: MDROcrLevel=MDROcrLevel.Once, extract_formula: bool=True, extract_table_format: MDRExtractedTableFormat|None=None, debug_dir_path: str|None=None):
|
1278 |
"""
|
1279 |
Initializes the MagicPDFProcessor.
|
@@ -1436,7 +1632,7 @@ if __name__ == '__main__':
|
|
1436 |
print(" MagicDataReadiness PDF Processor - Example Usage")
|
1437 |
print("="*60)
|
1438 |
|
1439 |
-
# --- 1. Configuration (!!! MODIFY THESE PATHS !!!) ---
|
1440 |
# Directory where models are stored or will be downloaded
|
1441 |
# IMPORTANT: Create this directory or ensure it's writable!
|
1442 |
MDR_MODEL_DIRECTORY = "./mdr_pipeline_models"
|
@@ -1449,7 +1645,7 @@ if __name__ == '__main__':
|
|
1449 |
MDR_DEBUG_DIRECTORY = "./mdr_debug_output"
|
1450 |
|
1451 |
# Specify device ('cuda' or 'cpu').
|
1452 |
-
MDR_DEVICE = "
|
1453 |
|
1454 |
# Specify desired table format
|
1455 |
MDR_TABLE_FORMAT = MDRExtractedTableFormat.MARKDOWN
|
|
|
38 |
from PIL import ImageDraw
|
39 |
from PIL.ImageFont import load_default, FreeTypeFont
|
40 |
from shapely.geometry import Polygon
|
41 |
+
import pyclipper
|
42 |
from unicodedata import category
|
43 |
from alphabet_detector import AlphabetDetector
|
44 |
+
from munch import Munch
|
45 |
+
from transformers import LayoutLMv3ForTokenClassification
|
46 |
+
import onnxruntime
|
47 |
+
# --- HUGGING FACE HUB IMPORT ONLY BECAUSE RUNNING IN SPACES NOT NECESSARY IN PROD ---
|
48 |
+
from huggingface_hub import hf_hub_download, HfHubDownloadError
|
49 |
|
50 |
+
# --- External Dependencies ---
|
|
|
51 |
try:
|
52 |
from doclayout_yolo import YOLOv10
|
53 |
except ImportError:
|
|
|
59 |
print("Warning: Could not import LatexOCR from pix2tex.cli. LaTeX extraction will fail.")
|
60 |
LatexOCR = None
|
61 |
try:
|
|
|
62 |
pass # from struct_eqtable import build_model
|
63 |
except ImportError:
|
64 |
print("Warning: Could not import build_model from struct_eqtable. Table parsing might fail.")
|
65 |
|
66 |
# --- MagicDataReadiness Core Components ---
|
67 |
|
68 |
+
# --- MDR Utilities ---
|
69 |
+
|
70 |
def mdr_download_model(url: str, file_path: Path):
|
71 |
"""Downloads a model file from a URL to a local path."""
|
72 |
try:
|
|
|
86 |
if file_path.exists(): os.remove(file_path)
|
87 |
raise e
|
88 |
|
89 |
+
# --- MDR Utilities ---
|
90 |
def mdr_ensure_directory(path: str) -> str:
|
91 |
"""Ensures a directory exists, creating it if necessary."""
|
92 |
path = os.path.abspath(path)
|
|
|
147 |
return p1.intersection(p2).area
|
148 |
except: return 0.0
|
149 |
|
150 |
+
# --- MDR Data Structures ---
|
151 |
@dataclass
|
152 |
class MDROcrFragment:
|
153 |
"""Represents a fragment of text identified by OCR."""
|
154 |
order: int; text: str; rank: float; rect: MDRRectangle
|
155 |
+
|
156 |
class MDRLayoutClass(Enum):
|
157 |
"""Enumeration of different layout types identified."""
|
158 |
TITLE=0; PLAIN_TEXT=1; ABANDON=2; FIGURE=3; FIGURE_CAPTION=4; TABLE=5; TABLE_CAPTION=6; TABLE_FOOTNOTE=7; ISOLATE_FORMULA=8; FORMULA_CAPTION=9
|
159 |
+
|
160 |
class MDRTableLayoutParsedFormat(Enum):
|
161 |
"""Enumeration for formats of parsed table content."""
|
162 |
LATEX=auto(); MARKDOWN=auto(); HTML=auto()
|
163 |
+
|
164 |
@dataclass
|
165 |
class MDRBaseLayoutElement:
|
166 |
"""Base class for layout elements found on a page."""
|
167 |
rect: MDRRectangle; fragments: list[MDROcrFragment]
|
168 |
+
|
169 |
@dataclass
|
170 |
class MDRPlainLayoutElement(MDRBaseLayoutElement):
|
171 |
"""Layout element for plain text, titles, captions, figures, etc."""
|
172 |
cls: Literal[MDRLayoutClass.TITLE, MDRLayoutClass.PLAIN_TEXT, MDRLayoutClass.ABANDON, MDRLayoutClass.FIGURE, MDRLayoutClass.FIGURE_CAPTION, MDRLayoutClass.TABLE_CAPTION, MDRLayoutClass.TABLE_FOOTNOTE, MDRLayoutClass.FORMULA_CAPTION]
|
173 |
+
|
174 |
@dataclass
|
175 |
class MDRTableLayoutElement(MDRBaseLayoutElement):
|
176 |
"""Layout element specifically for tables."""
|
177 |
parsed: tuple[str, MDRTableLayoutParsedFormat] | None; cls: Literal[MDRLayoutClass.TABLE] = MDRLayoutClass.TABLE
|
178 |
+
|
179 |
@dataclass
|
180 |
class MDRFormulaLayoutElement(MDRBaseLayoutElement):
|
181 |
"""Layout element specifically for formulas."""
|
182 |
latex: str | None; cls: Literal[MDRLayoutClass.ISOLATE_FORMULA] = MDRLayoutClass.ISOLATE_FORMULA
|
183 |
+
|
184 |
MDRLayoutElement = MDRPlainLayoutElement | MDRTableLayoutElement | MDRFormulaLayoutElement # Type alias
|
185 |
+
|
186 |
@dataclass
|
187 |
class MDRExtractionResult:
|
188 |
"""Holds the complete result of extracting from a single page image."""
|
189 |
rotation: float; layouts: list[MDRLayoutElement]; extracted_image: Image; adjusted_image: Image | None
|
190 |
|
191 |
+
# --- MDR Data Structures ---
|
192 |
+
|
193 |
MDRProgressReportCallback: TypeAlias = Callable[[int, int], None]
|
194 |
+
|
195 |
class MDROcrLevel(Enum): Once=auto(); OncePerLayout=auto()
|
196 |
+
|
197 |
class MDRExtractedTableFormat(Enum): LATEX=auto(); MARKDOWN=auto(); HTML=auto(); DISABLE=auto()
|
198 |
+
|
199 |
class MDRTextKind(Enum): TITLE=0; PLAIN_TEXT=1; ABANDON=2
|
200 |
+
|
201 |
@dataclass
|
202 |
class MDRTextSpan:
|
203 |
"""Represents a span of text content within a block."""
|
204 |
content: str; rank: float; rect: MDRRectangle
|
205 |
+
|
206 |
@dataclass
|
207 |
class MDRBasicBlock:
|
208 |
"""Base class for structured blocks extracted from the document."""
|
209 |
rect: MDRRectangle; texts: list[MDRTextSpan]; font_size: float # Relative font size (0-1)
|
210 |
+
|
211 |
@dataclass
|
212 |
class MDRTextBlock(MDRBasicBlock):
|
213 |
"""A structured block containing text content."""
|
214 |
kind: MDRTextKind; has_paragraph_indentation: bool = False; last_line_touch_end: bool = False
|
215 |
+
|
216 |
class MDRTableFormat(Enum): LATEX=auto(); MARKDOWN=auto(); HTML=auto(); UNRECOGNIZABLE=auto()
|
217 |
+
|
218 |
@dataclass
|
219 |
class MDRTableBlock(MDRBasicBlock):
|
220 |
"""A structured block representing a table."""
|
221 |
content: str; format: MDRTableFormat; image: Image # Image clip of the table
|
222 |
+
|
223 |
@dataclass
|
224 |
class MDRFormulaBlock(MDRBasicBlock):
|
225 |
"""A structured block representing a formula."""
|
226 |
content: str | None; image: Image # Image clip of the formula
|
227 |
+
|
228 |
@dataclass
|
229 |
class MDRFigureBlock(MDRBasicBlock):
|
230 |
"""A structured block representing a figure/image."""
|
231 |
image: Image # Image clip of the figure
|
232 |
+
|
233 |
MDRAssetBlock = MDRTableBlock | MDRFormulaBlock | MDRFigureBlock # Type alias
|
234 |
+
|
235 |
MDRStructuredBlock = MDRTextBlock | MDRAssetBlock # Type alias
|
236 |
|
237 |
+
# --- MDR Utilities ---
|
238 |
def mdr_similarity_ratio(v1: float, v2: float) -> float:
|
239 |
"""Calculates the ratio of the smaller value to the larger value (0-1)."""
|
240 |
if v1==0 and v2==0: return 1.0;
|
241 |
if v1<0 or v2<0: return 0.0;
|
242 |
v1, v2 = (v2, v1) if v1 > v2 else (v1, v2);
|
243 |
return 1.0 if v2==0 else v1/v2
|
244 |
+
|
245 |
def mdr_intersection_bounds_size(r1: MDRRectangle, r2: MDRRectangle) -> tuple[float, float]:
|
246 |
"""Calculates width/height of the intersection bounding box."""
|
247 |
try:
|
|
|
251 |
if inter.is_empty: return 0.0, 0.0;
|
252 |
minx, miny, maxx, maxy = inter.bounds; return maxx-minx, maxy-miny
|
253 |
except: return 0.0, 0.0
|
254 |
+
|
255 |
_MDR_CJKA_PATTERN = re.compile(r"[\u4e00-\u9fff\u3040-\u309f\u30a0-\u30ff\uac00-\ud7a3\u0600-\u06ff]")
|
256 |
+
|
257 |
def mdr_contains_cjka(text: str):
|
258 |
"""Checks if text contains Chinese, Japanese, Korean, or Arabic chars."""
|
259 |
return bool(_MDR_CJKA_PATTERN.search(text)) if text else False
|
260 |
|
261 |
+
# --- MDR Text Processing ---
|
262 |
class _MDR_TokenPhase(Enum): Init=0; Letter=1; Character=2; Number=3; Space=4
|
263 |
+
|
264 |
_mdr_alphabet_detector = AlphabetDetector()
|
265 |
+
|
266 |
def _mdr_is_letter(char: str):
|
267 |
if not category(char).startswith("L"): return False
|
268 |
try: return _mdr_alphabet_detector.is_latin(char) or _mdr_alphabet_detector.is_cyrillic(char) or _mdr_alphabet_detector.is_greek(char) or _mdr_alphabet_detector.is_hebrew(char)
|
269 |
except: return False
|
270 |
+
|
271 |
def mdr_split_into_words(text: str):
|
272 |
"""Splits text into words, numbers, and individual non-alphanumeric chars."""
|
273 |
if not text: return;
|
|
|
287 |
if is_s: phase=_MDR_TokenPhase.Space
|
288 |
else: yield char; phase=_MDR_TokenPhase.Character
|
289 |
if phase in (_MDR_TokenPhase.Letter, _MDR_TokenPhase.Number): w=buf.getvalue(); yield w if w else None
|
290 |
+
|
291 |
def mdr_check_text_similarity(t1: str, t2: str) -> tuple[float, int]:
|
292 |
"""Calculates word-based similarity between two texts."""
|
293 |
w1, w2 = list(mdr_split_into_words(t1)), list(mdr_split_into_words(t2)); l1, l2 = len(w1), len(w2)
|
|
|
300 |
if not taken[i] and word1==word2: taken[i]=True; matches+=1; break
|
301 |
mismatches = l2 - matches; return 1.0 - (mismatches/l2), l2
|
302 |
|
303 |
+
# --- MDR Geometry Processing ---
|
304 |
class MDRRotationAdjuster:
|
305 |
"""Adjusts point coordinates based on image rotation."""
|
306 |
+
|
307 |
def __init__(self, origin_size: tuple[int, int], new_size: tuple[int, int], rotation: float, to_origin_coordinate: bool):
|
308 |
fs, ts = (new_size, origin_size) if to_origin_coordinate else (origin_size, new_size)
|
309 |
self._rot = rotation if to_origin_coordinate else -rotation
|
310 |
self._c_off = (fs[0]/2.0, fs[1]/2.0); self._n_off = (ts[0]/2.0, ts[1]/2.0)
|
311 |
+
|
312 |
def adjust(self, point: MDRPoint) -> MDRPoint:
|
313 |
x, y = point[0]-self._c_off[0], point[1]-self._c_off[1]
|
314 |
if x!=0 or y!=0: cos_r, sin_r = cos(self._rot), sin(self._rot); x, y = x*cos_r-y*sin_r, x*sin_r+y*cos_r
|
315 |
return x+self._n_off[0], y+self._n_off[1]
|
316 |
+
|
317 |
def mdr_normalize_vertical_rotation(rot: float) -> float:
|
318 |
while rot >= pi: rot -= pi;
|
319 |
while rot < 0: rot += pi;
|
320 |
return rot
|
321 |
+
|
322 |
def _mdr_get_rectangle_angles(rect: MDRRectangle) -> tuple[list[float], list[float]] | None:
|
323 |
h_angs, v_angs = [], []
|
324 |
for i, (p1, p2) in enumerate(rect.segments):
|
|
|
330 |
else: v_angs.append(ang)
|
331 |
if not h_angs or not v_angs: return None
|
332 |
return h_angs, v_angs
|
333 |
+
|
334 |
def _mdr_normalize_horizontal_angles(rots: list[float]) -> list[float]: return rots
|
335 |
+
|
336 |
def _mdr_find_median(data: list[float]) -> float:
|
337 |
if not data: return 0.0; s_data = sorted(data); n = len(s_data);
|
338 |
return s_data[n//2] if n%2==1 else (s_data[n//2-1]+s_data[n//2])/2.0
|
339 |
+
|
340 |
def _mdr_find_mean(data: list[float]) -> float: return sum(data)/len(data) if data else 0.0
|
341 |
+
|
342 |
def mdr_calculate_image_rotation(frags: list[MDROcrFragment]) -> float:
|
343 |
all_h, all_v = [], [];
|
344 |
for f in frags:
|
|
|
351 |
while rot_est >= pi/2: rot_est -= pi;
|
352 |
while rot_est < -pi/2: rot_est += pi;
|
353 |
return rot_est
|
354 |
+
|
355 |
def mdr_calculate_rectangle_rotation(rect: MDRRectangle) -> tuple[float, float]:
|
356 |
res = _mdr_get_rectangle_angles(rect);
|
357 |
if res is None: return 0.0, pi/2.0;
|
|
|
359 |
h_rots = _mdr_normalize_horizontal_angles(h_rots); v_rots = [mdr_normalize_vertical_rotation(a) for a in v_rots]
|
360 |
return _mdr_find_mean(h_rots), _mdr_find_mean(v_rots)
|
361 |
|
362 |
+
# --- MDR ONNX OCR Internals ---
|
363 |
class _MDR_PredictBase:
|
364 |
"""Base class for ONNX model prediction components."""
|
365 |
+
|
366 |
def get_onnx_session(self, model_path: str, use_gpu: bool):
|
367 |
try:
|
368 |
sess_opts = onnxruntime.SessionOptions(); sess_opts.log_severity_level = 3
|
|
|
375 |
if use_gpu and 'CUDAExecutionProvider' not in onnxruntime.get_available_providers():
|
376 |
print(" CUDAExecutionProvider not available. Check ONNXRuntime-GPU installation and CUDA setup.")
|
377 |
raise e
|
378 |
+
|
379 |
def get_output_name(self, sess: onnxruntime.InferenceSession) -> List[str]: return [n.name for n in sess.get_outputs()]
|
380 |
+
|
381 |
def get_input_name(self, sess: onnxruntime.InferenceSession) -> List[str]: return [n.name for n in sess.get_inputs()]
|
382 |
+
|
383 |
def get_input_feed(self, names: List[str], img_np: np.ndarray) -> Dict[str, np.ndarray]: return {name: img_np for name in names}
|
384 |
|
385 |
+
# --- MDR ONNX OCR Internals ---
|
386 |
class _MDR_NormalizeImage:
|
387 |
+
|
388 |
def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
|
389 |
self.scale = np.float32(eval(scale) if isinstance(scale, str) else (scale if scale is not None else 1.0 / 255.0))
|
390 |
mean = mean if mean is not None else [0.485, 0.456, 0.406]; std = std if std is not None else [0.229, 0.224, 0.225]
|
391 |
shape = (3, 1, 1) if order == 'chw' else (1, 1, 3); self.mean = np.array(mean).reshape(shape).astype('float32'); self.std = np.array(std).reshape(shape).astype('float32')
|
392 |
+
|
393 |
def __call__(self, data): img = data['image']; img = np.array(img) if isinstance(img, Image) else img; data['image'] = (img.astype('float32') * self.scale - self.mean) / self.std; return data
|
394 |
+
|
395 |
class _MDR_DetResizeForTest:
|
396 |
+
|
397 |
def __init__(self, **kwargs):
|
398 |
self.resize_type = 0; self.keep_ratio = False
|
399 |
if 'image_shape' in kwargs: self.image_shape = kwargs['image_shape']; self.resize_type = 1; self.keep_ratio = kwargs.get('keep_ratio', False)
|
400 |
elif 'limit_side_len' in kwargs: self.limit_side_len = kwargs['limit_side_len']; self.limit_type = kwargs.get('limit_type', 'min')
|
401 |
elif 'resize_long' in kwargs: self.resize_type = 2; self.resize_long = kwargs.get('resize_long', 960)
|
402 |
else: self.limit_side_len = 736; self.limit_type = 'min'
|
403 |
+
|
404 |
def __call__(self, data):
|
405 |
img = data['image']; src_h, src_w, _ = img.shape
|
406 |
if src_h + src_w < 64: img = self._pad(img)
|
|
|
409 |
else: img, ratios = self._resize1(img)
|
410 |
if img is None: return None
|
411 |
data['image'] = img; data['shape'] = np.array([src_h, src_w, ratios[0], ratios[1]]); return data
|
412 |
+
|
413 |
def _pad(self, im, v=0): h,w,c=im.shape; p=np.zeros((max(32,h),max(32,w),c),np.uint8)+v; p[:h,:w,:]=im; return p
|
414 |
+
|
415 |
def _resize1(self, img): rh,rw=self.image_shape; oh,ow=img.shape[:2]; if self.keep_ratio: rw=ow*rh/oh; N=math.ceil(rw/32); rw=N*32; r_h,r_w=float(rh)/oh,float(rw)/ow; img=cv2.resize(img,(int(rw),int(rh))); return img,[r_h,r_w]
|
416 |
+
|
417 |
def _resize0(self, img): lsl=self.limit_side_len; h,w,_=img.shape; r=1.0; if self.limit_type=='max': r=float(lsl)/max(h,w) if max(h,w)>lsl else 1.0; elif self.limit_type=='min': r=float(lsl)/min(h,w) if min(h,w)<lsl else 1.0; elif self.limit_type=='resize_long': r=float(lsl)/max(h,w); else: raise Exception('Unsupported'); rh,rw=int(h*r),int(w*r); rh=max(int(round(rh/32)*32),32); rw=max(int(round(rw/32)*32),32); if int(rw)<=0 or int(rh)<=0: return None,(None,None); img=cv2.resize(img,(int(rw),int(rh))); r_h,r_w=rh/float(h),rw/float(w); return img,[r_h,r_w]
|
418 |
+
|
419 |
def _resize2(self, img): h,w,_=img.shape; rl=self.resize_long; r=float(rl)/max(h,w); rh,rw=int(h*r),int(w*r); ms=128; rh=(rh+ms-1)//ms*ms; rw=(rw+ms-1)//ms*ms; img=cv2.resize(img,(int(rw),int(rh))); r_h,r_w=rh/float(h),rw/float(w); return img,[r_h,r_w]
|
420 |
+
|
421 |
class _MDR_ToCHWImage:
|
422 |
+
|
423 |
def __call__(self, data): img=data['image']; img=np.array(img) if isinstance(img,Image) else img; data['image']=img.transpose((2,0,1)); return data
|
424 |
+
|
425 |
class _MDR_KeepKeys:
|
426 |
+
|
427 |
def __init__(self, keep_keys, **kwargs): self.keep_keys=keep_keys
|
428 |
+
|
429 |
def __call__(self, data): return [data[key] for key in self.keep_keys]
|
430 |
|
|
|
431 |
def mdr_ocr_transform(data, ops=None):
|
432 |
ops = ops if ops is not None else [];
|
433 |
for op in ops: data = op(data); if data is None: return None;
|
434 |
return data
|
435 |
+
|
436 |
def mdr_ocr_create_operators(op_param_list, global_config=None):
|
437 |
ops = []
|
438 |
for operator in op_param_list:
|
|
|
444 |
else: raise ValueError(f"Operator class '{op_class_name}' not found.")
|
445 |
return ops
|
446 |
|
|
|
447 |
class _MDR_DBPostProcess:
|
448 |
+
|
449 |
def __init__(self, thresh=0.3, box_thresh=0.7, max_candidates=1000, unclip_ratio=1.5, use_dilation=False, score_mode="fast", box_type='quad', **kwargs):
|
450 |
self.thresh, self.box_thresh, self.max_cand = thresh, box_thresh, max_candidates; self.unclip_r, self.min_sz, self.score_m, self.box_t = unclip_ratio, 3, score_mode, box_type
|
451 |
assert score_mode in ["slow", "fast"]; self.dila_k = np.array([[1,1],[1,1]], dtype=np.uint8) if use_dilation else None
|
452 |
+
|
453 |
def _polygons_from_bitmap(self, pred, bmp, dw, dh):
|
454 |
h, w = bmp.shape; boxes, scores = [], []
|
455 |
contours, _ = cv2.findContours((bmp*255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
|
|
|
466 |
box = np.array(box); box[:,0]=np.clip(np.round(box[:,0]/w*dw),0,dw); box[:,1]=np.clip(np.round(box[:,1]/h*dh),0,dh)
|
467 |
boxes.append(box.tolist()); scores.append(score)
|
468 |
return boxes, scores
|
469 |
+
|
470 |
def _boxes_from_bitmap(self, pred, bmp, dw, dh):
|
471 |
h, w = bmp.shape; contours, _ = cv2.findContours((bmp*255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
|
472 |
num_contours = min(len(contours), self.max_cand); boxes, scores = [], []
|
|
|
482 |
box = np.array(box); box[:,0]=np.clip(np.round(box[:,0]/w*dw),0,dw); box[:,1]=np.clip(np.round(box[:,1]/h*dh),0,dh)
|
483 |
boxes.append(box.astype("int32")); scores.append(score)
|
484 |
return np.array(boxes, dtype="int32"), scores
|
485 |
+
|
486 |
def _unclip(self, box, ratio):
|
487 |
poly = Polygon(box); dist = poly.area*ratio/poly.length; offset = pyclipper.PyclipperOffset(); offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
|
488 |
expanded = offset.Execute(dist);
|
489 |
if not expanded: raise ValueError("Unclip failed"); return np.array(expanded[0])
|
490 |
+
|
491 |
def _get_mini_boxes(self, contour):
|
492 |
bb = cv2.minAreaRect(contour); pts = sorted(list(cv2.boxPoints(bb)), key=lambda x:x[0])
|
493 |
i1,i4 = (0,1) if pts[1][1]>pts[0][1] else (1,0); i2,i3 = (2,3) if pts[3][1]>pts[2][1] else (3,2)
|
494 |
box = [pts[i1], pts[i2], pts[i3], pts[i4]]; return box, min(bb[1])
|
495 |
+
|
496 |
def _box_score_fast(self, bmp, box):
|
497 |
h,w = bmp.shape[:2]; xmin=np.clip(np.floor(box[:,0].min()).astype("int32"),0,w-1); xmax=np.clip(np.ceil(box[:,0].max()).astype("int32"),0,w-1)
|
498 |
ymin=np.clip(np.floor(box[:,1].min()).astype("int32"),0,h-1); ymax=np.clip(np.ceil(box[:,1].max()).astype("int32"),0,h-1)
|
499 |
mask = np.zeros((ymax-ymin+1, xmax-xmin+1), dtype=np.uint8); box[:,0]-=xmin; box[:,1]-=ymin
|
500 |
cv2.fillPoly(mask, box.reshape(1,-1,2).astype("int32"), 1);
|
501 |
return cv2.mean(bmp[ymin:ymax+1, xmin:xmax+1], mask)[0] if np.sum(mask)>0 else 0.0
|
502 |
+
|
503 |
def _box_score_slow(self, bmp, contour): # Not used if fast
|
504 |
h,w = bmp.shape[:2]; contour = np.reshape(contour.copy(),(-1,2)); xmin=np.clip(np.min(contour[:,0]),0,w-1); xmax=np.clip(np.max(contour[:,0]),0,w-1)
|
505 |
ymin=np.clip(np.min(contour[:,1]),0,h-1); ymax=np.clip(np.max(contour[:,1]),0,h-1); mask=np.zeros((ymax-ymin+1,xmax-xmin+1),dtype=np.uint8)
|
506 |
contour[:,0]-=xmin; contour[:,1]-=ymin; cv2.fillPoly(mask, contour.reshape(1,-1,2).astype("int32"), 1);
|
507 |
return cv2.mean(bmp[ymin:ymax+1, xmin:xmax+1], mask)[0] if np.sum(mask)>0 else 0.0
|
508 |
+
|
509 |
def __call__(self, outs_dict, shape_list):
|
510 |
pred = outs_dict['maps'][:,0,:,:]; seg = pred > self.thresh; boxes_batch = []
|
511 |
for batch_idx in range(pred.shape[0]):
|
|
|
516 |
boxes_batch.append({'points': boxes})
|
517 |
return boxes_batch
|
518 |
|
|
|
519 |
class _MDR_TextDetector(_MDR_PredictBase):
|
520 |
+
|
521 |
def __init__(self, args):
|
522 |
super().__init__(); self.args = args
|
523 |
pre_ops = [{'DetResizeForTest': {'limit_side_len': args.det_limit_side_len, 'limit_type': args.det_limit_type}}, {'NormalizeImage': {'std': [0.229,0.224,0.225], 'mean': [0.485,0.456,0.406], 'scale': '1./255.', 'order': 'hwc'}}, {'ToCHWImage': None}, {'KeepKeys': {'keep_keys': ['image', 'shape']}}]
|
|
|
526 |
self.post_op = _MDR_DBPostProcess(**post_params)
|
527 |
self.sess = self.get_onnx_session(args.det_model_dir, args.use_gpu)
|
528 |
self.input_name = self.get_input_name(self.sess); self.output_name = self.get_output_name(self.sess)
|
529 |
+
|
530 |
def _order_pts(self, pts): r=np.zeros((4,2),dtype="float32"); s=pts.sum(axis=1); r[0]=pts[np.argmin(s)]; r[2]=pts[np.argmax(s)]; tmp=np.delete(pts,(np.argmin(s),np.argmax(s)),axis=0); d=np.diff(np.array(tmp),axis=1); r[1]=tmp[np.argmin(d)]; r[3]=tmp[np.argmax(d)]; return r
|
531 |
+
|
532 |
def _clip_pts(self, pts, h, w): pts[:,0]=np.clip(pts[:,0],0,w-1); pts[:,1]=np.clip(pts[:,1],0,h-1); return pts
|
533 |
+
|
534 |
def _filter_quad(self, boxes, shape): h,w=shape[0:2]; new_boxes=[]; for box in boxes: box=np.array(box) if isinstance(box,list) else box; box=self._order_pts(box); box=self._clip_pts(box,h,w); rw=int(np.linalg.norm(box[0]-box[1])); rh=int(np.linalg.norm(box[0]-box[3])); if rw<=3 or rh<=3: continue; new_boxes.append(box); return np.array(new_boxes)
|
535 |
+
|
536 |
def _filter_poly(self, boxes, shape): h,w=shape[0:2]; new_boxes=[]; for box in boxes: box=np.array(box) if isinstance(box,list) else box; box=self._clip_pts(box,h,w); if Polygon(box).area<10: continue; new_boxes.append(box); return np.array(new_boxes)
|
537 |
+
|
538 |
def __call__(self, img):
|
539 |
ori_im = img.copy(); data = {"image": img}; data = mdr_ocr_transform(data, self.pre_op)
|
540 |
if data is None: return None; img, shape_list = data;
|
|
|
543 |
preds = {"maps": outputs[0]}; post_res = self.post_op(preds, shape_list); boxes = post_res[0]['points']
|
544 |
return self._filter_poly(boxes, ori_im.shape) if self.args.det_box_type=='poly' else self._filter_quad(boxes, ori_im.shape)
|
545 |
|
|
|
546 |
class _MDR_ClsPostProcess:
|
547 |
+
|
548 |
def __init__(self, label_list=None, **kwargs): self.labels = label_list if label_list else {0:'0', 1:'180'}
|
549 |
+
|
550 |
def __call__(self, preds, label=None, *args, **kwargs):
|
551 |
preds = np.array(preds) if not isinstance(preds, np.ndarray) else preds; idxs = preds.argmax(axis=1)
|
552 |
return [(self.labels[idx], float(preds[i,idx])) for i,idx in enumerate(idxs)]
|
553 |
|
|
|
554 |
class _MDR_TextClassifier(_MDR_PredictBase):
|
555 |
+
|
556 |
def __init__(self, args):
|
557 |
super().__init__(); self.shape = tuple(map(int, args.cls_image_shape.split(','))) if isinstance(args.cls_image_shape, str) else args.cls_image_shape
|
558 |
self.batch_num = args.cls_batch_num; self.thresh = args.cls_thresh; self.post_op = _MDR_ClsPostProcess(label_list=args.label_list)
|
559 |
self.sess = self.get_onnx_session(args.cls_model_dir, args.use_gpu); self.input_name = self.get_input_name(self.sess); self.output_name = self.get_output_name(self.sess)
|
560 |
+
|
561 |
def _resize_norm(self, img):
|
562 |
imgC,imgH,imgW = self.shape; h,w = img.shape[:2]; r=w/float(h) if h>0 else 0; rw=int(math.ceil(imgH*r)); rw=min(rw,imgW)
|
563 |
resized = cv2.resize(img,(rw,imgH)); resized = resized.astype("float32")
|
564 |
if imgC==1: resized = resized/255.0; resized = resized[np.newaxis,:]
|
565 |
else: resized = resized.transpose((2,0,1))/255.0
|
566 |
resized -= 0.5; resized /= 0.5; padding = np.zeros((imgC,imgH,imgW),dtype=np.float32); padding[:,:,0:rw]=resized; return padding
|
567 |
+
|
568 |
def __call__(self, img_list):
|
569 |
if not img_list: return img_list, []; img_list_cp = copy.deepcopy(img_list); num = len(img_list_cp)
|
570 |
ratios = [img.shape[1]/float(img.shape[0]) if img.shape[0]>0 else 0 for img in img_list_cp]; indices = np.argsort(np.array(ratios))
|
|
|
580 |
if "180" in label and score > self.thresh: img_list[orig_idx] = cv2.rotate(img_list[orig_idx], cv2.ROTATE_180)
|
581 |
return img_list, results
|
582 |
|
|
|
583 |
class _MDR_BaseRecLabelDecode:
|
584 |
+
|
585 |
def __init__(self, char_path=None, use_space=False):
|
586 |
self.beg, self.end, self.rev = "sos", "eos", False; self.chars = []
|
587 |
if char_path is None: self.chars = list("0123456789abcdefghijklmnopqrstuvwxyz")
|
|
|
592 |
if any("\u0600"<=c<="\u06FF" for c in self.chars): self.rev=True
|
593 |
except FileNotFoundError: print(f"Warn: Dict not found {char_path}"); self.chars=list("0123456789abcdefghijklmnopqrstuvwxyz"); if use_space: self.chars.append(" ")
|
594 |
d_char = self.add_special_char(list(self.chars)); self.dict={c:i for i,c in enumerate(d_char)}; self.character=d_char
|
595 |
+
|
596 |
def add_special_char(self, chars): return chars
|
597 |
+
|
598 |
def get_ignored_tokens(self): return []
|
599 |
+
|
600 |
def _reverse(self, pred): res=[]; cur=""; for c in pred: if not re.search("[a-zA-Z0-9 :*./%+-]",c): res.extend([cur,c] if cur!="" else [c]); cur="" else: cur+=c; if cur!="": res.append(cur); return "".join(res[::-1])
|
601 |
+
|
602 |
def decode(self, idxs, probs=None, remove_dup=False):
|
603 |
res=[]; ignored=self.get_ignored_tokens(); bs=len(idxs)
|
604 |
for b_idx in range(bs):
|
|
|
612 |
if self.rev: txt=self._reverse(txt)
|
613 |
res.append((txt, float(np.mean(conf_l))))
|
614 |
return res
|
615 |
+
|
616 |
class _MDR_CTCLabelDecode(_MDR_BaseRecLabelDecode):
|
617 |
def __init__(self, char_path=None, use_space=False, **kwargs): super().__init__(char_path, use_space)
|
618 |
def add_special_char(self, chars): return ["blank"]+chars
|
|
|
621 |
preds = preds[-1] if isinstance(preds,(tuple,list)) else preds; preds = np.array(preds) if not isinstance(preds,np.ndarray) else preds
|
622 |
idxs=preds.argmax(axis=2); probs=preds.max(axis=2); txt=self.decode(idxs, probs, remove_dup=True); return txt
|
623 |
|
|
|
624 |
class _MDR_TextRecognizer(_MDR_PredictBase):
|
625 |
+
|
626 |
def __init__(self, args):
|
627 |
super().__init__(); shape_str=getattr(args,'rec_image_shape',"3,48,320"); self.shape=tuple(map(int,shape_str.split(',')))
|
628 |
self.batch_num=getattr(args,'rec_batch_num',6); self.algo=getattr(args,'rec_algorithm','SVTR_LCNet')
|
629 |
self.post_op=_MDR_CTCLabelDecode(char_path=args.rec_char_dict_path, use_space=getattr(args,'use_space_char',True))
|
630 |
self.sess=self.get_onnx_session(args.rec_model_dir, args.use_gpu); self.input_name=self.get_input_name(self.sess); self.output_name=self.get_output_name(self.sess)
|
631 |
+
|
632 |
def _resize_norm(self, img, max_r):
|
633 |
imgC,imgH,imgW = self.shape; h,w = img.shape[:2];
|
634 |
if h==0 or w==0: return np.zeros((imgC,imgH,imgW),dtype=np.float32)
|
|
|
638 |
if len(resized.shape)==2: resized=resized[:,:,np.newaxis]
|
639 |
resized=resized.transpose((2,0,1))/255.0; resized-=0.5; resized/=0.5
|
640 |
padding=np.zeros((imgC,imgH,imgW),dtype=np.float32); padding[:,:,0:tw]=resized; return padding
|
641 |
+
|
642 |
def __call__(self, img_list):
|
643 |
if not img_list: return []; num=len(img_list); ratios=[img.shape[1]/float(img.shape[0]) if img.shape[0]>0 else 0 for img in img_list]
|
644 |
indices=np.argsort(np.array(ratios)); results=[["",0.0]]*num; batch_n=self.batch_num
|
|
|
652 |
for i in range(len(rec_out)): results[indices[start+i]]=rec_out[i]
|
653 |
return results
|
654 |
|
655 |
+
# --- MDR ONNX OCR System ---
|
656 |
class _MDR_TextSystem:
|
657 |
+
|
658 |
def __init__(self, args):
|
659 |
class ArgsObject: # Helper to access dict args with dot notation
|
660 |
def __init__(self, **entries): self.__dict__.update(entries)
|
|
|
666 |
self.drop_score = getattr(args, 'drop_score', 0.5)
|
667 |
self.classifier = _MDR_TextClassifier(args) if self.use_cls else None
|
668 |
self.crop_idx = 0; self.save_crop = getattr(args, 'save_crop_res', False); self.crop_dir = getattr(args, 'crop_res_save_dir', "./output/mdr_crop_res")
|
669 |
+
|
670 |
def _sort_boxes(self, boxes):
|
671 |
if boxes is None or len(boxes)==0: return []
|
672 |
def key(box): min_y=min(p[1] for p in box); min_x=min(p[0] for p in box); return (min_y, min_x)
|
673 |
try: return list(sorted(boxes, key=key))
|
674 |
except: return list(boxes) # Fallback
|
675 |
+
|
676 |
def __call__(self, img, classify=True):
|
677 |
ori_im = img.copy(); boxes = self.detector(img)
|
678 |
if boxes is None or len(boxes)==0: return [], []
|
|
|
694 |
if score >= self.drop_score: final_boxes.append(box); final_rec.append(res)
|
695 |
if self.save_crop: self._save_crops(crops, rec_res)
|
696 |
return final_boxes, final_rec
|
697 |
+
|
698 |
def _save_crops(self, crops, recs):
|
699 |
mdr_ensure_directory(self.crop_dir); num = len(crops)
|
700 |
for i in range(num): txt, score = recs[i]; safe=re.sub(r'\W+', '_', txt)[:20]; fname=f"crop_{self.crop_idx+i}_{safe}_{score:.2f}.jpg"; cv2.imwrite(os.path.join(self.crop_dir, fname), crops[i])
|
701 |
self.crop_idx += num
|
702 |
|
703 |
+
# --- MDR ONNX OCR Utilities ---
|
704 |
def mdr_get_rotated_crop(img, points):
|
705 |
"""Crops and perspective-transforms a quadrilateral region."""
|
706 |
pts = np.array(points, dtype="float32"); assert len(pts)==4
|
|
|
712 |
dh, dw = dst.shape[0:2]
|
713 |
if dh>0 and dw>0 and dh*1.0/dw >= 1.5: dst = cv2.rotate(dst, cv2.ROTATE_90_CLOCKWISE)
|
714 |
return dst
|
715 |
+
|
716 |
def mdr_get_min_area_crop(img, points):
|
717 |
"""Crops the minimum area rectangle containing the points."""
|
718 |
bb = cv2.minAreaRect(np.array(points).astype(np.int32)); box_pts = cv2.boxPoints(bb)
|
719 |
return mdr_get_rotated_crop(img, box_pts)
|
720 |
|
721 |
+
# --- MDR Layout Processing ---
|
722 |
_MDR_INCLUDES_MIN_RATE = 0.99
|
723 |
+
|
724 |
class _MDR_OverlapMatrixContext:
|
725 |
+
|
726 |
def __init__(self, layouts: list[MDRLayoutElement]):
|
727 |
length = len(layouts); self.polys: list[Polygon|None] = []
|
728 |
for l in layouts:
|
|
|
736 |
p2 = self.polys[j];
|
737 |
if p2 is None: continue
|
738 |
r_ij = self._rate(p1, p2); r_ji = self._rate(p2, p1); self.matrix[i][j]=r_ij; self.matrix[j][i]=r_ji
|
739 |
+
|
740 |
def _rate(self, p1: Polygon, p2: Polygon) -> float: # Rate p1 covers p2
|
741 |
try: inter = p1.intersection(p2);
|
742 |
except: return 0.0
|
|
|
745 |
_, _, px1, py1 = p2.bounds; pw, ph = px1-p2.bounds[0], py1-p2.bounds[1]
|
746 |
if pw < 1e-6 or ph < 1e-6: return 0.0
|
747 |
wr = min(iw/pw, 1.0); hr = min(ih/ph, 1.0); return (wr+hr)/2.0
|
748 |
+
|
749 |
def others(self, idx: int):
|
750 |
for i, r in enumerate(self.matrix[idx]):
|
751 |
if i != idx and i not in self.removed: yield r
|
752 |
+
|
753 |
def includes(self, idx: int): # Layouts included BY idx
|
754 |
for i, r in enumerate(self.matrix[idx]):
|
755 |
if i != idx and i not in self.removed and r >= _MDR_INCLUDES_MIN_RATE:
|
|
|
811 |
for i, f in enumerate(merged): f.order = i
|
812 |
return merged
|
813 |
|
814 |
+
# --- MDR Layout Processing ---
|
815 |
_MDR_CORRECTION_MIN_OVERLAP = 0.5
|
816 |
+
|
817 |
def mdr_correct_layout_fragments(ocr_engine: 'MDROcrEngine', source_img: Image, layout: MDRLayoutElement):
|
818 |
if not layout.fragments: return;
|
819 |
try:
|
|
|
848 |
final = [n if n.rank >= o.rank else o for o, n in matched]; final.extend(unmatched_orig); final.extend(unmatched_new)
|
849 |
layout.fragments = final; layout.fragments.sort(key=lambda f: (f.rect.lt[1], f.rect.lt[0]))
|
850 |
|
851 |
+
# --- MDR OCR Engine ---
|
852 |
+
|
853 |
_MDR_OCR_MODELS = {"det": ("ppocrv4","det","det.onnx"), "cls": ("ppocrv4","cls","cls.onnx"), "rec": ("ppocrv4","rec","rec.onnx"), "keys": ("ch_ppocr_server_v2.0","ppocr_keys_v1.txt")}
|
854 |
+
|
855 |
_MDR_OCR_URL_BASE = "https://huggingface.co/moskize/OnnxOCR/resolve/main/"
|
856 |
+
|
857 |
@dataclass
|
858 |
class _MDR_ONNXParams: # Simplified container
|
859 |
use_gpu: bool; det_model_dir: str; cls_model_dir: str; rec_model_dir: str; rec_char_dict_path: str
|
|
|
864 |
|
865 |
class MDROcrEngine:
|
866 |
"""Handles OCR detection and recognition using ONNX models."""
|
867 |
+
|
868 |
def __init__(self, device: Literal["cpu", "cuda"], model_dir_path: str):
|
869 |
self._device = device; self._model_dir = mdr_ensure_directory(model_dir_path)
|
870 |
self._text_system: _MDR_TextSystem | None = None; self._onnx_params: _MDR_ONNXParams | None = None
|
871 |
self._ensure_models(); self._get_system() # Init on creation
|
872 |
+
|
873 |
def _ensure_models(self):
|
874 |
for key, parts in _MDR_OCR_MODELS.items():
|
875 |
fp = Path(self._model_dir) / Path(*parts)
|
876 |
if not fp.exists(): print(f"Downloading MDR OCR model: {fp.name}..."); url = _MDR_OCR_URL_BASE + "/".join(parts); mdr_download_model(url, fp)
|
877 |
+
|
878 |
def _get_system(self) -> _MDR_TextSystem | None:
|
879 |
if self._text_system is None:
|
880 |
paths = {k: str(Path(self._model_dir)/Path(*p)) for k,p in _MDR_OCR_MODELS.items()}
|
|
|
882 |
try: self._text_system = _MDR_TextSystem(self._onnx_params); print(f"MDR OCR System initialized.")
|
883 |
except Exception as e: print(f"ERROR initializing MDR OCR System: {e}"); self._text_system = None
|
884 |
return self._text_system
|
885 |
+
|
886 |
def find_text_fragments(self, image_np: np.ndarray) -> Generator[MDROcrFragment, None, None]:
|
887 |
"""Finds and recognizes text fragments in a NumPy image (BGR)."""
|
888 |
system = self._get_system()
|
|
|
895 |
if not txt or mdr_is_whitespace(txt) or conf < 0.1: continue
|
896 |
pts = [(float(p[0]), float(p[1])) for p in box_pts]
|
897 |
if len(pts)==4: r=MDRRectangle(lt=pts[0], rt=pts[1], rb=pts[2], lb=pts[3]); if r.is_valid and r.area>1: yield MDROcrFragment(order=-1, text=txt, rank=float(conf), rect=r)
|
898 |
+
|
899 |
def _preprocess(self, img: np.ndarray) -> np.ndarray:
|
900 |
if len(img.shape)==3 and img.shape[2]==4: a=img[:,:,3]/255.0; bg=(255,255,255); new=np.zeros_like(img[:,:,:3]); [setattr(new[:,:,i], 'flags.writeable', True) for i in range(3)]; [np.copyto(new[:,:,i], (bg[i]*(1-a)+img[:,:,i]*a)) for i in range(3)]; img=new.astype(np.uint8)
|
901 |
elif len(img.shape)==2: img=cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
902 |
elif not (len(img.shape)==3 and img.shape[2]==3): raise ValueError("Unsupported image format")
|
903 |
return img
|
904 |
|
905 |
+
# --- MDR Layout Reading Internals ---
|
906 |
_MDR_MAX_LEN = 510; _MDR_CLS_ID = 0; _MDR_SEP_ID = 2; _MDR_PAD_ID = 1
|
907 |
+
|
908 |
def mdr_boxes_to_reader_inputs(boxes: List[List[int]], max_len=_MDR_MAX_LEN) -> Dict[str, torch.Tensor]:
|
909 |
t_boxes = boxes[:max_len]; i_boxes = [[0,0,0,0]] + t_boxes + [[0,0,0,0]]
|
910 |
i_ids = [_MDR_CLS_ID] + [_MDR_PAD_ID]*len(t_boxes) + [_MDR_SEP_ID]
|
911 |
a_mask = [1]*len(i_ids); pad_len = (max_len+2) - len(i_ids)
|
912 |
if pad_len > 0: i_boxes.extend([[0,0,0,0]]*pad_len); i_ids.extend([_MDR_PAD_ID]*pad_len); a_mask.extend([0]*pad_len)
|
913 |
return {"bbox": torch.tensor([i_boxes]), "input_ids": torch.tensor([i_ids]), "attention_mask": torch.tensor([a_mask])}
|
914 |
+
|
915 |
def mdr_prepare_reader_inputs(inputs: Dict[str, torch.Tensor], model: LayoutLMv3ForTokenClassification) -> Dict[str, torch.Tensor]:
|
916 |
return {k: v.to(model.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
|
917 |
+
|
918 |
def mdr_parse_reader_logits(logits: torch.Tensor, length: int) -> List[int]:
|
919 |
if length == 0: return []; rel_logits = logits[1:length+1, :length]; orders = rel_logits.argmax(dim=1).tolist()
|
920 |
while True:
|
|
|
930 |
orders[idx] = rel_logits[idx, :].argmax().item(); rel_logits[idx, order] = orig_logit
|
931 |
return orders
|
932 |
|
933 |
+
# --- MDR Layout Reading Engine ---
|
934 |
@dataclass
|
935 |
class _MDR_ReaderBBox: layout_index: int; fragment_index: int; virtual: bool; order: int; value: tuple[float, float, float, float]
|
936 |
+
|
937 |
class MDRLayoutReader:
|
938 |
"""Determines reading order of layout elements using LayoutLMv3."""
|
939 |
+
|
940 |
def __init__(self, model_path: str):
|
941 |
self._model_path = model_path; self._model: LayoutLMv3ForTokenClassification | None = None
|
942 |
self._device = "cuda" if torch.cuda.is_available() else "cpu"
|
943 |
+
|
944 |
def _get_model(self) -> LayoutLMv3ForTokenClassification | None:
|
945 |
if self._model is None:
|
946 |
cache = mdr_ensure_directory(self._model_path); name = "microsoft/layoutlmv3-base"; h_path = os.path.join(cache, "models--hantian--layoutreader")
|
|
|
950 |
self._model.to(self._device); self._model.eval(); print(f"MDR LayoutReader loaded on {self._device}.")
|
951 |
except Exception as e: print(f"ERROR loading MDR LayoutReader: {e}"); self._model = None
|
952 |
return self._model
|
953 |
+
|
954 |
def determine_reading_order(self, layouts: list[MDRLayoutElement], size: tuple[int, int]) -> list[MDRLayoutElement]:
|
955 |
w, h = size;
|
956 |
if w<=0 or h<=0 or not layouts: return layouts;
|
|
|
977 |
if len(orders) != len(bbox_list): print("MDR LayoutReader order mismatch"); return layouts # Fallback
|
978 |
for i, order_idx in enumerate(orders): bbox_list[i].order = order_idx
|
979 |
return self._apply_order(layouts, bbox_list)
|
980 |
+
|
981 |
def _prepare_bboxes(self, layouts: list[MDRLayoutElement], w: int, h: int) -> list[_MDR_ReaderBBox] | None:
|
982 |
line_h = self._estimate_line_h(layouts); bbox_list = []
|
983 |
for i, l in enumerate(layouts):
|
|
|
985 |
else: bbox_list.extend(self._gen_virtual(l, i, line_h, w, h))
|
986 |
if len(bbox_list) > _MDR_MAX_LEN: print(f"Too many boxes ({len(bbox_list)}>{_MDR_MAX_LEN})"); return None
|
987 |
bbox_list.sort(key=lambda b: (b.value[1], b.value[0])); return bbox_list
|
988 |
+
|
989 |
def _apply_order(self, layouts: list[MDRLayoutElement], bbox_list: list[_MDR_ReaderBBox]) -> list[MDRLayoutElement]:
|
990 |
layout_map = defaultdict(list); [layout_map[b.layout_index].append(b) for b in bbox_list]
|
991 |
layout_orders = [(idx, self._median([b.order for b in bboxes])) for idx, bboxes in layout_map.items() if bboxes]
|
|
|
999 |
else: frags.sort(key=lambda f: (f.rect.lt[1], f.rect.lt[0]))
|
1000 |
for frag in frags: frag.order = nfo; nfo += 1
|
1001 |
return sorted_layouts
|
1002 |
+
|
1003 |
def _estimate_line_h(self, layouts: list[MDRLayoutElement]) -> float:
|
1004 |
heights = [f.rect.size[1] for l in layouts for f in l.fragments if f.rect.size[1]>0]
|
1005 |
return self._median(heights) if heights else 15.0
|
1006 |
+
|
1007 |
def _gen_virtual(self, l: MDRLayoutElement, l_idx: int, line_h: float, pw: int, ph: int) -> Generator[_MDR_ReaderBBox, None, None]:
|
1008 |
x0,y0,x1,y1 = l.rect.wrapper; lh,lw = y1-y0,x1-x0
|
1009 |
if lh<=0 or lw<=0 or line_h<=0: yield _MDR_ReaderBBox(l_idx,-1,True,-1,(x0,y0,x1,y1)); return
|
|
|
1018 |
ly0,ly1 = max(0,min(ph,cur_y)), max(0,min(ph,cur_y+act_line_h)); lx0,lx1 = max(0,min(pw,x0)), max(0,min(pw,x1))
|
1019 |
if ly1>ly0 and lx1>lx0: yield _MDR_ReaderBBox(l_idx,-1,True,-1,(lx0,ly0,lx1,ly1))
|
1020 |
cur_y += act_line_h
|
1021 |
+
|
1022 |
def _median(self, nums: list[float|int]) -> float:
|
1023 |
if not nums: return 0.0; s_nums = sorted(nums); n = len(s_nums)
|
1024 |
return float(s_nums[n//2]) if n%2==1 else float((s_nums[n//2-1]+s_nums[n//2])/2.0)
|
1025 |
|
1026 |
+
# --- MDR LaTeX Extractor ---
|
1027 |
class MDRLatexExtractor:
|
1028 |
"""Extracts LaTeX from formula images using pix2tex."""
|
1029 |
+
|
1030 |
def __init__(self, model_path: str):
|
1031 |
self._model_path = model_path; self._model: LatexOCR | None = None
|
1032 |
self._device = "cuda" if torch.cuda.is_available() else "cpu"
|
1033 |
+
|
1034 |
def extract(self, image: Image) -> str | None:
|
1035 |
if LatexOCR is None: return None;
|
1036 |
image = mdr_expand_image(image, 0.1); model = self._get_model()
|
|
|
1038 |
try:
|
1039 |
with torch.no_grad(): img_rgb = image.convert('RGB') if image.mode!='RGB' else image; latex = model(img_rgb); return latex if latex else None
|
1040 |
except Exception as e: print(f"MDR LaTeX error: {e}"); return None
|
1041 |
+
|
1042 |
def _get_model(self) -> LatexOCR | None:
|
1043 |
if self._model is None and LatexOCR is not None:
|
1044 |
mdr_ensure_directory(self._model_path); wp, rp, cp = Path(self._model_path)/"weights.pth", Path(self._model_path)/"image_resizer.pth", Path(self._model_path)/"config.yaml"
|
|
|
1047 |
try: args = Munch({"config":str(cp), "checkpoint":str(wp), "device":self._device, "no_cuda":self._device=="cpu", "no_resize":False, "temperature":0.0}); self._model = LatexOCR(args); print(f"MDR LaTeX loaded on {self._device}.")
|
1048 |
except Exception as e: print(f"ERROR initializing MDR LatexOCR: {e}"); self._model = None
|
1049 |
return self._model
|
1050 |
+
|
1051 |
def _download(self):
|
1052 |
tag = "v0.0.1"; base = f"https://github.com/lukas-blecher/LaTeX-OCR/releases/download/{tag}/"; files = {"weights.pth": base+"weights.pth", "image_resizer.pth": base+"image_resizer.pth"}
|
1053 |
mdr_ensure_directory(self._model_path); [mdr_download_model(url, Path(self._model_path)/name) for name, url in files.items() if not (Path(self._model_path)/name).exists()]
|
1054 |
|
1055 |
+
# --- MDR Table Parser ---
|
1056 |
MDRTableOutputFormat = Literal["latex", "markdown", "html"]
|
1057 |
+
|
1058 |
class MDRTableParser:
|
1059 |
"""Parses table structure/content from images using StructTable model."""
|
1060 |
+
|
1061 |
def __init__(self, device: Literal["cpu", "cuda"], model_path: str):
|
1062 |
self._model: Any | None = None; self._model_path = mdr_ensure_directory(model_path)
|
1063 |
self._device = device if torch.cuda.is_available() and device=="cuda" else "cpu"
|
1064 |
self._disabled = self._device == "cpu"
|
1065 |
if self._disabled: print("Warning: MDR Table parsing requires CUDA. Disabled.")
|
1066 |
+
|
1067 |
def parse_table_image(self, image: Image, format: MDRTableLayoutParsedFormat) -> str | None:
|
1068 |
if self._disabled: return None;
|
1069 |
fmt: MDRTableOutputFormat | None = None
|
|
|
1078 |
with torch.no_grad(): results = model([img_rgb], output_format=fmt)
|
1079 |
return results[0] if results else None
|
1080 |
except Exception as e: print(f"MDR Table parsing error: {e}"); return None
|
1081 |
+
|
1082 |
def _get_model(self):
|
1083 |
if self._model is None and not self._disabled:
|
1084 |
try:
|
|
|
1091 |
except Exception as e: print(f"ERROR loading MDR StructTable: {e}"); self._model=None
|
1092 |
return self._model
|
1093 |
|
1094 |
+
# --- MDR Image Optimizer ---
|
1095 |
_MDR_TINY_ROTATION = 0.005
|
1096 |
+
|
1097 |
@dataclass
|
1098 |
class _MDR_RotationContext: to_origin: MDRRotationAdjuster; to_new: MDRRotationAdjuster; fragment_origin_rectangles: list[MDRRectangle]
|
1099 |
+
|
1100 |
class MDRImageOptimizer:
|
1101 |
"""Handles image rotation detection and coordinate adjustments."""
|
1102 |
+
|
1103 |
def __init__(self, raw_image: Image, adjust_points: bool):
|
1104 |
self._raw = raw_image; self._image = raw_image; self._adjust_points = adjust_points
|
1105 |
self._fragments: list[MDROcrFragment] = []; self._rotation: float = 0.0; self._rot_ctx: _MDR_RotationContext | None = None
|
1106 |
+
|
1107 |
@property
|
1108 |
def image(self) -> Image: return self._image
|
1109 |
+
|
1110 |
@property
|
1111 |
def adjusted_image(self) -> Image | None: return self._image if self._rot_ctx is not None else None
|
1112 |
+
|
1113 |
@property
|
1114 |
def rotation(self) -> float: return self._rotation
|
1115 |
+
|
1116 |
@property
|
1117 |
def image_np(self) -> np.ndarray: img_rgb = np.array(self._raw.convert("RGB")); return cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
|
1118 |
+
|
1119 |
def receive_fragments(self, fragments: list[MDROcrFragment]):
|
1120 |
self._fragments = fragments;
|
1121 |
if not fragments: return;
|
|
|
1130 |
to_new=MDRRotationAdjuster(orig_sz, new_sz, self._rotation, False),
|
1131 |
to_origin=MDRRotationAdjuster(orig_sz, new_sz, self._rotation, True))
|
1132 |
adj = self._rot_ctx.to_new; [setattr(f, 'rect', MDRRectangle(lt=adj.adjust(r.lt), rt=adj.adjust(r.rt), lb=adj.adjust(r.lb), rb=adj.adjust(r.rb))) for f in fragments if (r:=f.rect)]
|
1133 |
+
|
1134 |
def finalize_layout_coords(self, layouts: list[MDRLayoutElement]):
|
1135 |
if self._rot_ctx is None or self._adjust_points: return
|
1136 |
if len(self._fragments) == len(self._rot_ctx.fragment_origin_rectangles): [setattr(f, 'rect', orig_r) for f, orig_r in zip(self._fragments, self._rot_ctx.fragment_origin_rectangles)]
|
1137 |
adj = self._rot_ctx.to_origin; [setattr(l, 'rect', MDRRectangle(lt=adj.adjust(r.lt), rt=adj.adjust(r.rt), lb=adj.adjust(r.lb), rb=adj.adjust(r.rb))) for l in layouts if (r:=l.rect)]
|
1138 |
|
1139 |
+
# --- MDR Image Clipping ---
|
1140 |
def mdr_clip_from_image(image: Image, rect: MDRRectangle, wrap_w: float = 0.0, wrap_h: float = 0.0) -> Image:
|
1141 |
"""Clips a potentially rotated rectangle from an image."""
|
1142 |
try:
|
|
|
1152 |
out_w, out_h = ceil(avg_w+wrap_w), ceil(avg_h+wrap_h)
|
1153 |
return image.transform((out_w, out_h), PILTransform.AFFINE, p_mat, PILResampling.BICUBIC, fillcolor=(255,255,255))
|
1154 |
except Exception as e: print(f"MDR Clipping error: {e}"); return new_image("RGB", (10,10), (255,255,255))
|
1155 |
+
|
1156 |
def mdr_clip_layout(res: MDRExtractionResult, layout: MDRLayoutElement, wrap_w: float = 0.0, wrap_h: float = 0.0) -> Image:
|
1157 |
"""Clips a layout region from the MDRExtractionResult image."""
|
1158 |
img = res.adjusted_image if res.adjusted_image else res.extracted_image
|
1159 |
return mdr_clip_from_image(img, layout.rect, wrap_w, wrap_h)
|
1160 |
|
1161 |
+
# --- MDR Debug Plotting ---
|
1162 |
_MDR_FRAG_COLOR = (0x49, 0xCF, 0xCB, 200); _MDR_LAYOUT_COLORS = { MDRLayoutClass.TITLE: (0x0A,0x12,0x2C,255), MDRLayoutClass.PLAIN_TEXT: (0x3C,0x67,0x90,255), MDRLayoutClass.ABANDON: (0xC0,0xBB,0xA9,180), MDRLayoutClass.FIGURE: (0x5B,0x91,0x3C,255), MDRLayoutClass.FIGURE_CAPTION: (0x77,0xB3,0x54,255), MDRLayoutClass.TABLE: (0x44,0x17,0x52,255), MDRLayoutClass.TABLE_CAPTION: (0x81,0x75,0xA0,255), MDRLayoutClass.TABLE_FOOTNOTE: (0xEF,0xB6,0xC9,255), MDRLayoutClass.ISOLATE_FORMULA: (0xFA,0x38,0x27,255), MDRLayoutClass.FORMULA_CAPTION: (0xFF,0x9D,0x24,255) }; _MDR_DEFAULT_COLOR = (0x80,0x80,0x80,255); _MDR_RGBA = tuple[int,int,int,int]
|
1163 |
+
|
1164 |
def mdr_plot_layout(image: Image, layouts: Iterable[MDRLayoutElement]) -> None:
|
1165 |
"""Draws layout and fragment boxes onto an image for debugging."""
|
1166 |
if not layouts: return;
|
1167 |
try: l_font, f_font = load_default(size=25), load_default(size=15); draw = ImageDraw.Draw(image, mode="RGBA")
|
1168 |
except Exception as e: print(f"MDR Plot init error: {e}"); return
|
1169 |
+
|
1170 |
def _draw_num(pos: MDRPoint, num: int, font: FreeTypeFont, color: _MDR_RGBA):
|
1171 |
try: x,y=pos; txt=str(num); txt_pos=(round(x)+3, round(y)+1); bbox=draw.textbbox(txt_pos,txt,font=font); bg_rect=(bbox[0]-2,bbox[1]-1,bbox[2]+2,bbox[3]+1); bg_color=(color[0],color[1],color[2],180); draw.rectangle(bg_rect,fill=bg_color); draw.text(txt_pos,txt,font=font,fill=(255,255,255,255))
|
1172 |
except Exception as e: print(f"MDR Draw num error: {e}")
|
|
|
1178 |
try: draw.polygon([p for p in f.rect], outline=_MDR_FRAG_COLOR, width=1)
|
1179 |
except Exception as e: print(f"MDR Fragment draw error: {e}")
|
1180 |
|
1181 |
+
# --- MDR Extraction Engine ---
|
1182 |
class MDRExtractionEngine:
|
1183 |
"""Core engine for extracting structured information from a document image."""
|
1184 |
+
|
1185 |
def __init__(self, model_dir_path: str, device: Literal["cpu", "cuda"]="cpu", ocr_for_each_layouts: bool=True, extract_formula: bool=True, extract_table_format: MDRTableLayoutParsedFormat|None=None):
|
1186 |
+
self._model_dir = model_dir_path # Base directory for all models
|
1187 |
+
self._device = device if torch.cuda.is_available() else "cpu"
|
1188 |
self._ocr_each = ocr_for_each_layouts; self._ext_formula = extract_formula; self._ext_table = extract_table_format
|
1189 |
self._yolo: YOLOv10 | None = None
|
1190 |
+
# Initialize sub-modules, passing the main model_dir_path
|
1191 |
+
self._ocr_engine = MDROcrEngine(device=self._device, model_dir_path=os.path.join(self._model_dir, "onnx_ocr"))
|
1192 |
+
self._table_parser = MDRTableParser(device=self._device, model_path=os.path.join(self._model_dir, "struct_eqtable"))
|
1193 |
+
self._latex_extractor = MDRLatexExtractor(model_path=os.path.join(self._model_dir, "latex"))
|
1194 |
+
self._layout_reader = MDRLayoutReader(model_path=os.path.join(self._model_dir, "layoutreader"))
|
1195 |
print(f"MDR Extraction Engine initialized on device: {self._device}")
|
1196 |
+
|
1197 |
+
# --- MODIFIED _get_yolo_model METHOD for HF ---
|
1198 |
def _get_yolo_model(self) -> YOLOv10 | None:
|
1199 |
+
"""Loads the YOLOv10 layout detection model using hf_hub_download."""
|
1200 |
if self._yolo is None and YOLOv10 is not None:
|
1201 |
+
repo_id = "juliozhao/DocLayout-YOLO-DocStructBench"
|
1202 |
+
filename = "doclayout_yolo_docstructbench_imgsz1024.pt"
|
1203 |
+
# Use a subdirectory within the main model dir for YOLO cache via HF Hub
|
1204 |
+
yolo_cache_dir = Path(self._model_dir) / "yolo_hf_cache"
|
1205 |
+
mdr_ensure_directory(str(yolo_cache_dir)) # Ensure cache dir exists
|
1206 |
+
|
1207 |
+
print(f"Attempting to load YOLO model '{filename}' from repo '{repo_id}'...")
|
1208 |
+
print(f"Hugging Face Hub cache directory for YOLO: {yolo_cache_dir}")
|
1209 |
+
|
1210 |
+
try:
|
1211 |
+
# Download the model file using huggingface_hub, caching it
|
1212 |
+
yolo_model_filepath = hf_hub_download(
|
1213 |
+
repo_id=repo_id,
|
1214 |
+
filename=filename,
|
1215 |
+
cache_dir=yolo_cache_dir, # Cache within our designated structure
|
1216 |
+
local_files_only=False, # Allow download
|
1217 |
+
force_download=False, # Use cache if available
|
1218 |
+
)
|
1219 |
+
print(f"YOLO model file path: {yolo_model_filepath}")
|
1220 |
+
|
1221 |
+
# Load the model using the downloaded file path
|
1222 |
+
self._yolo = YOLOv10(yolo_model_filepath)
|
1223 |
+
print("MDR YOLOv10 model loaded successfully.")
|
1224 |
+
|
1225 |
+
except HfHubDownloadError as e:
|
1226 |
+
print(f"ERROR: Failed to download YOLO model from Hugging Face Hub: {e}")
|
1227 |
+
self._yolo = None
|
1228 |
+
except FileNotFoundError as e: # Catch if hf_hub_download fails finding file
|
1229 |
+
print(f"ERROR: YOLO model file not found via Hugging Face Hub: {e}")
|
1230 |
+
self._yolo = None
|
1231 |
+
except Exception as e:
|
1232 |
+
print(f"ERROR: Failed to load YOLOv10 model from {yolo_model_filepath}: {e}")
|
1233 |
+
self._yolo = None
|
1234 |
+
|
1235 |
+
elif YOLOv10 is None:
|
1236 |
+
print("MDR YOLOv10 class not available. Layout detection skipped.")
|
1237 |
+
|
1238 |
return self._yolo
|
1239 |
+
|
1240 |
def analyze_image(self, image: Image, adjust_points: bool=False) -> MDRExtractionResult:
|
1241 |
"""Analyzes a single page image to extract layout and content."""
|
1242 |
print(" Engine: Analyzing image..."); optimizer = MDRImageOptimizer(image, adjust_points)
|
|
|
1256 |
print(" Engine: Finalizing coords..."); optimizer.finalize_layout_coords(layouts)
|
1257 |
print(" Engine: Analysis complete.")
|
1258 |
return MDRExtractionResult(rotation=optimizer.rotation, layouts=layouts, extracted_image=image, adjusted_image=optimizer.adjusted_image)
|
1259 |
+
|
1260 |
def _run_yolo_detection(self, img: Image, yolo: YOLOv10) -> Generator[MDRLayoutElement, None, None]:
|
1261 |
img_rgb = img.convert("RGB"); res = yolo.predict(source=img_rgb, imgsz=1024, conf=0.2, device=self._device, verbose=False)
|
1262 |
if not res or not hasattr(res[0], 'boxes') or res[0].boxes is None: return
|
|
|
1270 |
if cls == MDRLayoutClass.TABLE: yield MDRTableLayoutElement(cls=cls, rect=rect, fragments=[], parsed=None)
|
1271 |
elif cls == MDRLayoutClass.ISOLATE_FORMULA: yield MDRFormulaLayoutElement(cls=cls, rect=rect, fragments=[], latex=None)
|
1272 |
elif cls in MDRPlainLayoutElement.__annotations__['cls'].__args__: yield MDRPlainLayoutElement(cls=cls, rect=rect, fragments=[])
|
1273 |
+
|
1274 |
def _match_fragments_to_layouts(self, frags: list[MDROcrFragment], layouts: list[MDRLayoutElement]) -> list[MDRLayoutElement]:
|
1275 |
if not frags or not layouts: return layouts
|
1276 |
layout_polys = [(Polygon(l.rect) if l.rect.is_valid else None) for l in layouts]
|
|
|
1290 |
layouts[best_idx].fragments.append(frag)
|
1291 |
for l in layouts: l.fragments.sort(key=lambda f: (f.rect.lt[1], f.rect.lt[0]))
|
1292 |
return layouts
|
1293 |
+
|
1294 |
def _run_ocr_correction(self, img: Image, layouts: list[MDRLayoutElement]):
|
1295 |
for i, l in enumerate(layouts):
|
1296 |
if l.cls == MDRLayoutClass.FIGURE: continue
|
1297 |
try: mdr_correct_layout_fragments(self._ocr_engine, img, l)
|
1298 |
except Exception as e: print(f" Engine: OCR correction error layout {i}: {e}")
|
1299 |
+
|
1300 |
def _parse_special_layouts(self, layouts: list[MDRLayoutElement], optimizer: MDRImageOptimizer):
|
1301 |
img_to_clip = optimizer.image
|
1302 |
for l in layouts:
|
|
|
1307 |
try: t_img = mdr_clip_from_image(img_to_clip, l.rect); parsed = self._table_parser.parse_table_image(t_img, self._ext_table) if t_img.width>1 and t_img.height>1 else None
|
1308 |
except Exception as e: print(f" Engine: Table parse error: {e}"); parsed = None
|
1309 |
if parsed: l.parsed = (parsed, self._ext_table)
|
1310 |
+
|
1311 |
def _should_keep_layout(self, l: MDRLayoutElement) -> bool:
|
1312 |
if l.fragments and not all(mdr_is_whitespace(f.text) for f in l.fragments): return True
|
1313 |
return l.cls in [MDRLayoutClass.FIGURE, MDRLayoutClass.TABLE, MDRLayoutClass.ISOLATE_FORMULA]
|
1314 |
|
1315 |
+
# --- MDR Page Section Linking ---
|
1316 |
class _MDR_LinkedShape:
|
1317 |
"""Internal helper for managing layout linking across pages."""
|
1318 |
+
|
1319 |
def __init__(self, layout: MDRLayoutElement): self.layout=layout; self.pre:list[MDRLayoutElement|None]=[None,None]; self.nex:list[MDRLayoutElement|None]=[None,None]
|
1320 |
+
|
1321 |
@property
|
1322 |
def distance2(self) -> float: x,y=self.layout.rect.lt; return x*x+y*y
|
1323 |
+
|
1324 |
class MDRPageSection:
|
1325 |
"""Represents a page's layouts for framework detection via linking."""
|
1326 |
+
|
1327 |
def __init__(self, page_index: int, layouts: Iterable[MDRLayoutElement]):
|
1328 |
self._page_index = page_index; self._shapes = [_MDR_LinkedShape(l) for l in layouts]; self._shapes.sort(key=lambda s: (s.layout.rect.lt[1], s.layout.rect.lt[0]))
|
1329 |
+
|
1330 |
@property
|
1331 |
def page_index(self) -> int: return self._page_index
|
1332 |
+
|
1333 |
def find_framework_elements(self) -> list[MDRLayoutElement]:
|
1334 |
"""Identifies framework layouts based on links to other pages."""
|
1335 |
return [s.layout for s in self._shapes if any(s.pre) or any(s.nex)]
|
1336 |
+
|
1337 |
def link_to_next(self, next_section: 'MDRPageSection', offset: int) -> None:
|
1338 |
"""Links matching shapes between this section and the next."""
|
1339 |
if offset not in (1,2): return
|
|
|
1349 |
r2_rel = self._relative_rect(orig_n_pt, s2.layout.rect); ovr = self._symmetric_iou(r1_rel, r2_rel)
|
1350 |
if ovr > max_ovr: max_ovr, best_s2 = ovr, s2
|
1351 |
if max_ovr >= 0.80 and best_s2 is not None: s1.nex[offset-1] = best_s2.layout; best_s2.pre[offset-1] = s1.layout # Link both ways
|
1352 |
+
|
1353 |
def _shapes_match(self, s1: _MDR_LinkedShape, s2: _MDR_LinkedShape) -> bool:
|
1354 |
l1, l2 = s1.layout, s2.layout; sz1, sz2 = l1.rect.size, l2.rect.size; thresh = 0.90
|
1355 |
if mdr_similarity_ratio(sz1[0], sz2[0]) < thresh or mdr_similarity_ratio(sz1[1], sz2[1]) < thresh: return False
|
|
|
1364 |
if max_sim > 0.75: matches += 1; if best_j != -1: used_f2[best_j] = True
|
1365 |
max_c = max(c1, c2); rate_frags = matches / max_c
|
1366 |
return self._check_match_threshold(rate_frags, max_c, (0.0, 0.45, 0.45, 0.6, 0.8, 0.95))
|
1367 |
+
|
1368 |
def _fragment_sim(self, l1: MDRLayoutElement, l2: MDRLayoutElement, f1: MDROcrFragment, f2: MDROcrFragment) -> float:
|
1369 |
r1_rel = self._relative_rect(l1.rect.lt, f1.rect); r2_rel = self._relative_rect(l2.rect.lt, f2.rect)
|
1370 |
geom_sim = self._symmetric_iou(r1_rel, r2_rel); text_sim, _ = mdr_check_text_similarity(f1.text, f2.text)
|
1371 |
return (geom_sim + text_sim) / 2.0
|
1372 |
+
|
1373 |
def _find_origin_pair(self, matches_matrix: list[list[_MDR_LinkedShape]], next_shapes: list[_MDR_LinkedShape]) -> tuple[_MDR_LinkedShape, _MDR_LinkedShape] | None:
|
1374 |
best_pair, min_dist2 = None, float('inf')
|
1375 |
for i, s1 in enumerate(self._shapes):
|
|
|
1377 |
if not match_list: continue
|
1378 |
for s2 in match_list: dist2 = s1.distance2 + s2.distance2; if dist2 < min_dist2: min_dist2, best_pair = dist2, (s1, s2)
|
1379 |
return best_pair
|
1380 |
+
|
1381 |
def _check_match_threshold(self, rate: float, count: int, thresholds: Sequence[float]) -> bool:
|
1382 |
if not thresholds: return False; idx = min(count, len(thresholds)-1); return rate >= thresholds[idx]
|
1383 |
+
|
1384 |
def _relative_rect(self, origin: MDRPoint, rect: MDRRectangle) -> MDRRectangle:
|
1385 |
ox, oy = origin; r=rect; return MDRRectangle(lt=(r.lt[0]-ox, r.lt[1]-oy), rt=(r.rt[0]-ox, r.rt[1]-oy), lb=(r.lb[0]-ox, r.lb[1]-oy), rb=(r.rb[0]-ox, r.rb[1]-oy))
|
1386 |
+
|
1387 |
def _symmetric_iou(self, r1: MDRRectangle, r2: MDRRectangle) -> float:
|
1388 |
try: p1, p2 = Polygon(r1), Polygon(r2);
|
1389 |
except: return 0.0
|
|
|
1393 |
if inter.is_empty or inter.area < 1e-6: return 0.0
|
1394 |
union_area = union.area; return inter.area / union_area if union_area > 1e-6 else 1.0
|
1395 |
|
1396 |
+
# --- MDR Document Iterator ---
|
1397 |
_MDR_CONTEXT_PAGES = 2 # Look behind/ahead pages for context
|
1398 |
+
|
1399 |
@dataclass
|
1400 |
class MDRProcessingParams:
|
1401 |
"""Parameters for processing a document."""
|
1402 |
pdf: str | FitzDocument; page_indexes: Iterable[int] | None; report_progress: MDRProgressReportCallback | None
|
1403 |
+
|
1404 |
class MDRDocumentIterator:
|
1405 |
"""Iterates through document pages, handles context, and calls the extraction engine."""
|
1406 |
+
|
1407 |
def __init__(self, device: Literal["cpu", "cuda"], model_dir_path: str, ocr_level: MDROcrLevel, extract_formula: bool, extract_table_format: MDRTableLayoutParsedFormat | None, debug_dir_path: str | None):
|
1408 |
self._debug_dir = debug_dir_path
|
1409 |
self._engine = MDRExtractionEngine(device=device, model_dir_path=model_dir_path, ocr_for_each_layouts=(ocr_level==MDROcrLevel.OncePerLayout), extract_formula=extract_formula, extract_table_format=extract_table_format)
|
1410 |
+
|
1411 |
def iterate_sections(self, params: MDRProcessingParams) -> Generator[tuple[int, MDRExtractionResult, list[MDRLayoutElement]], None, None]:
|
1412 |
"""Yields page index, extraction result, and content layouts for each requested page."""
|
1413 |
for res, sec in self._process_and_link_sections(params):
|
1414 |
framework = set(sec.find_framework_elements()); content = [l for l in res.layouts if l not in framework]; yield sec.page_index, res, content
|
1415 |
+
|
1416 |
def _process_and_link_sections(self, params: MDRProcessingParams) -> Generator[tuple[MDRExtractionResult, MDRPageSection], None, None]:
|
1417 |
queue: list[tuple[MDRExtractionResult, MDRPageSection]] = []
|
1418 |
for page_idx, res in self._run_extraction_on_pages(params):
|
|
|
1423 |
queue.append((res, cur_sec))
|
1424 |
if len(queue) > _MDR_CONTEXT_PAGES: yield queue.pop(0)
|
1425 |
for res, sec in queue: yield res, sec
|
1426 |
+
|
1427 |
def _run_extraction_on_pages(self, params: MDRProcessingParams) -> Generator[tuple[int, MDRExtractionResult], None, None]:
|
1428 |
if self._debug_dir: mdr_ensure_directory(self._debug_dir)
|
1429 |
doc, should_close = None, False
|
|
|
1446 |
except Exception as e: print(f" Iterator: Page {page_idx+1} processing error: {e}")
|
1447 |
finally:
|
1448 |
if should_close and doc: doc.close()
|
1449 |
+
|
1450 |
def _get_page_ranges(self, doc: FitzDocument, idxs: Iterable[int]|None) -> tuple[Sequence[int], Sequence[int]]:
|
1451 |
count = doc.page_count;
|
1452 |
if idxs is None: all_p = list(range(count)); return all_p, all_p
|
|
|
1454 |
for i in idxs:
|
1455 |
if 0<=i<count: enable.add(i); [scan.add(j) for j in range(max(0, i-_MDR_CONTEXT_PAGES), min(count, i+_MDR_CONTEXT_PAGES+1))]
|
1456 |
return sorted(list(scan)), sorted(list(enable))
|
1457 |
+
|
1458 |
def _render_page_image(self, page: FitzPage, dpi: int) -> Image:
|
1459 |
mat = FitzMatrix(dpi/72.0, dpi/72.0); pix = page.get_pixmap(matrix=mat, alpha=False)
|
1460 |
return frombytes("RGB", (pix.width, pix.height), pix.samples)
|
1461 |
+
|
1462 |
def _save_debug_plot(self, img: Image, idx: int, res: MDRExtractionResult, path: str):
|
1463 |
try: plot_img = res.adjusted_image.copy() if res.adjusted_image else img.copy(); mdr_plot_layout(plot_img, res.layouts); plot_img.save(os.path.join(path, f"mdr_plot_page_{idx+1}.png"))
|
1464 |
except Exception as e: print(f" Iterator: Plot generation error page {idx+1}: {e}")
|
1465 |
|
1466 |
+
# --- MagicDataReadiness Main Processor ---
|
1467 |
class MagicPDFProcessor:
|
1468 |
"""
|
1469 |
Main class for processing PDF documents to extract structured data blocks
|
1470 |
using the MagicDataReadiness pipeline.
|
1471 |
"""
|
1472 |
+
|
1473 |
def __init__(self, device: Literal["cpu", "cuda"]="cuda", model_dir_path: str="./mdr_models", ocr_level: MDROcrLevel=MDROcrLevel.Once, extract_formula: bool=True, extract_table_format: MDRExtractedTableFormat|None=None, debug_dir_path: str|None=None):
|
1474 |
"""
|
1475 |
Initializes the MagicPDFProcessor.
|
|
|
1632 |
print(" MagicDataReadiness PDF Processor - Example Usage")
|
1633 |
print("="*60)
|
1634 |
|
1635 |
+
# --- 1. Configuration (!!! MODIFY THESE PATHS WHEN OUTSIDE HF !!!) ---
|
1636 |
# Directory where models are stored or will be downloaded
|
1637 |
# IMPORTANT: Create this directory or ensure it's writable!
|
1638 |
MDR_MODEL_DIRECTORY = "./mdr_pipeline_models"
|
|
|
1645 |
MDR_DEBUG_DIRECTORY = "./mdr_debug_output"
|
1646 |
|
1647 |
# Specify device ('cuda' or 'cpu').
|
1648 |
+
MDR_DEVICE = "cpu"
|
1649 |
|
1650 |
# Specify desired table format
|
1651 |
MDR_TABLE_FORMAT = MDRExtractedTableFormat.MARKDOWN
|