File size: 4,542 Bytes
a7b8c18 2e79922 a7b8c18 2e79922 a7b8c18 2e79922 a7b8c18 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
from pathlib import Path
from typing import Optional, Union
import numpy as np
from ultralytics import YOLO
from ultralyticsplus import YOLO as YOLO2
class TableDetector:
    """
    Detect tables in document images using a YOLO model.

    Attributes:
        model_path: Location of the YOLO weights handed to ``YOLO``.
        model: The model object created in ``__init__`` and used by ``detect``.
        min_conf: Confidence threshold forwarded to ``predict``.
        iou: IoU threshold forwarded to ``predict``.
    """

    # Weights location used when the caller does not supply ``model_path``.
    DEFAULT_MODEL_PATH = 'src/models/table-detection-and-extraction.pt'

    def __init__(
        self,
        confidence: float = 0.50,
        iou_threshold: float = 0.45,
        model_path: Optional[Union[str, Path]] = None,
    ) -> None:
        """
        Initialize the TableDetector with model and parameters.

        Args:
            confidence: Confidence threshold for detection.
            iou_threshold: IoU threshold for NMS.
            model_path: Path to the YOLO model weights. When ``None``
                (the default) ``DEFAULT_MODEL_PATH`` is used. It is the last
                parameter so existing positional callers keep working.
        """
        self.model_path = (
            self.DEFAULT_MODEL_PATH if model_path is None else model_path
        )
        self.model = YOLO(str(self.model_path))
        self.min_conf = confidence
        self.iou = iou_threshold

    def load_model(self):
        """
        Build and return a model from 'foduucom/table-detection-and-extraction'.

        The model is created through the ``ultralyticsplus`` ``YOLO`` wrapper
        and is only returned: it is not stored on the instance, and the
        overrides below are fixed values rather than the thresholds given to
        ``__init__``.

        Returns:
            The configured model object.
        """
        model = YOLO2('foduucom/table-detection-and-extraction')
        model.overrides['conf'] = 0.25  # NMS confidence threshold
        model.overrides['iou'] = 0.45  # NMS IoU threshold
        model.overrides['agnostic_nms'] = False  # NMS class-agnostic
        model.overrides['max_det'] = 1000
        return model

    def detect(self, image_path: Union[str, Path]) -> Optional[list]:
        """
        Detect tables in the given image and return the largest one.

        Args:
            image_path: Path to the input image.

        Returns:
            ``None`` if the model produced no results at all, an empty list
            if no boxes survived, otherwise a one-element list holding the
            largest merged box as an integer ``[x1, y1, x2, y2]`` array.
        """
        results = self.model.predict(
            str(image_path), verbose=False, iou=self.iou, conf=self.min_conf
        )
        if not results:
            return None

        # Move the tensor to host memory first: calling .numpy() directly
        # fails when inference ran on a GPU, and .cpu() is harmless otherwise.
        boxes = results[0].boxes.xyxy.cpu().numpy()
        merged = self.merge_boxes(boxes)
        if len(merged) == 0:
            return []

        # max() returns the first largest element, which is what sorting in
        # descending order and taking index 0 would give.
        largest = max(merged, key=lambda b: (b[2] - b[0]) * (b[3] - b[1]))
        return [largest]

    def merge_boxes(self, boxes: np.ndarray, overlap_threshold: float = 35) -> np.ndarray:
        """
        Merge overlapping bounding boxes by keeping the larger one.

        Boxes are visited from largest to smallest area. A box is dropped
        when it overlaps an already kept (hence larger or equal) box by more
        than ``overlap_threshold``; otherwise it is kept.

        Args:
            boxes: Array of ``[x1, y1, x2, y2, ...]`` rows.
            overlap_threshold: Overlap percentage (0-100, as returned by
                ``_calculate_overlap``) above which the smaller box is dropped.

        Returns:
            Integer array of the kept boxes, ordered by descending area.
            Coordinates are truncated by the integer cast. An empty input
            yields an empty array of shape ``(0, 4)``.
        """
        boxes = np.asarray(boxes)
        if boxes.size == 0:
            # Covers shape (0,) as well as (0, 4); column indexing below
            # would raise on a 1-D empty array.
            return np.empty((0, 4), dtype=int)

        # Sort boxes by area in descending order
        areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
        boxes = boxes[np.argsort(-areas)]

        kept = []
        for box in boxes:
            # Every kept box is at least as large as `box`, so an overlap
            # always resolves in favour of the box that is already kept.
            overlaps_kept = any(
                self._calculate_overlap(box, kept_box) > overlap_threshold
                for kept_box in kept
            )
            if not overlaps_kept:
                kept.append(box)
        return np.array(kept).astype(int)

    @staticmethod
    def _calculate_overlap(box1: np.ndarray, box2: np.ndarray) -> float:
        """
        Calculate the percentage overlap between two boxes.

        The overlap is the intersection area divided by the area of the
        smaller box (not the union, so this is not IoU), times 100.

        Args:
            box1: First bounding box as ``[x1, y1, x2, y2]``.
            box2: Second bounding box as ``[x1, y1, x2, y2]``.

        Returns:
            Overlap percentage as a Python float; ``0.0`` when the boxes do
            not intersect or the smaller box has no positive area.
        """
        x_left = max(box1[0], box2[0])
        y_top = max(box1[1], box2[1])
        x_right = min(box1[2], box2[2])
        y_bottom = min(box1[3], box2[3])
        if x_right < x_left or y_bottom < y_top:
            return 0.0

        intersection_area = (x_right - x_left) * (y_bottom - y_top)
        box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
        box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
        min_area = min(box1_area, box2_area)
        if min_area <= 0:
            # Degenerate box: avoid dividing by zero.
            return 0.0
        return float(intersection_area / min_area * 100)