import os

# Route Gradio's temporary files into a local folder; set before gradio is imported.
os.environ["GRADIO_TEMP_DIR"] = "./tmp"

import traceback

import torch
import torchvision
import gradio as gr
import numpy as np
import pandas as pd
from huggingface_hub import snapshot_download

from visualization import visualize_bbox

# Local folders for Gradio temp files and downloaded model weights.
os.makedirs("tmp", exist_ok=True)
os.makedirs("models", exist_ok=True)

# Class ids produced by DocLayout-YOLO (DocStructBench) and their layout labels.
id_to_names = {
    0: 'title',
    1: 'plain text',
    2: 'abandon',
    3: 'figure',
    4: 'figure_caption',
    5: 'table',
    6: 'table_caption',
    7: 'table_footnote',
    8: 'isolate_formula',
    9: 'formula_caption'
}
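# Note: 'abandon' marks regions meant to be discarded downstream
# (e.g., headers, footers, page numbers).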

# Classes included in the results table; an empty list would disable the filter.
VISUAL_ELEMENTS = ['figure', 'table', 'figure_caption', 'table_caption', 'isolate_formula']


def load_model():
    """Load the DocLayout-YOLO model from Hugging Face."""
    try:
        # Download the pretrained checkpoint (or reuse the cached local copy).
        model_dir = snapshot_download(
            'juliozhao/DocLayout-YOLO-DocStructBench',
            local_dir='./models/DocLayout-YOLO-DocStructBench'
        )

        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {device}")

        # Imported lazily so a missing package is reported by the except clause below.
        from doclayout_yolo import YOLOv10
        model = YOLOv10(os.path.join(model_dir, "doclayout_yolo_docstructbench_imgsz1024.pt"))

        return model, device

    except Exception as e:
        print(f"Error loading model: {e}")
        return None, 'cpu'
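
# Standalone use outside the Gradio UI -- a minimal sketch; "sample.png" is a
# hypothetical local image file:
#
#   model, device = load_model()
#   if model is not None:
#       result = model.predict("sample.png", imgsz=1024, conf=0.25, device=device)[0]
#       print(result.boxes.xyxy)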


def recognize_image(input_img, conf_threshold, iou_threshold):
    """Detect document elements in the input image and tabulate the results."""
    if input_img is None:
        return None, None
    if model is None:
        print("Model is not loaded; cannot run detection.")
        return None, None

    try:
        # Run detection at the 1024px input size the checkpoint was trained with.
        det_res = model.predict(
            input_img,
            imgsz=1024,
            conf=conf_threshold,
            device=device,
        )[0]

        # Move results to CPU so the NMS and NumPy calls below work on any device.
        boxes = det_res.boxes.xyxy.cpu()
        classes = det_res.boxes.cls.cpu()
        scores = det_res.boxes.conf.cpu()

        # Apply non-maximum suppression to drop overlapping detections.
        indices = torchvision.ops.nms(
            boxes=boxes,
            scores=scores,
            iou_threshold=iou_threshold
        )
        boxes, scores, classes = boxes[indices], scores[indices], classes[indices]

        # A single surviving detection comes back 1-D; restore the leading axis.
        if len(boxes.shape) == 1:
            boxes = np.expand_dims(boxes, 0)
            scores = np.expand_dims(scores, 0)
            classes = np.expand_dims(classes, 0)

        # Draw the labelled boxes onto the input image.
        vis_result = visualize_bbox(input_img, boxes, classes, scores, id_to_names)

        # Keep only the classes listed in VISUAL_ELEMENTS for the results table.
        elements_data = []
        for box, cls_id, score in zip(boxes, classes, scores):
            class_name = id_to_names[int(cls_id)]
            if not VISUAL_ELEMENTS or class_name in VISUAL_ELEMENTS:
                x1, y1, x2, y2 = map(int, box)
                elements_data.append({
                    "class": class_name,
                    "confidence": float(score),
                    "x1": x1,
                    "y1": y1,
                    "x2": x2,
                    "y2": y2,
                    "width": x2 - x1,
                    "height": y2 - y1
                })

        columns = ["class", "confidence", "x1", "y1", "x2", "y2", "width", "height"]
        df = pd.DataFrame(elements_data, columns=columns)

        return vis_result, df

    except Exception as e:
        print(f"Error processing image: {e}")
        traceback.print_exc()
        return None, None
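
# Quick check without launching the UI -- a sketch that assumes load_model() succeeded
# and that visualize_bbox accepts the same image types as model.predict; "page.png"
# is a hypothetical input file:
#
#   model, device = load_model()
#   vis, table = recognize_image("page.png", conf_threshold=0.25, iou_threshold=0.45)
#   print(table)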


def gradio_reset():
    """Reset the UI"""
    return gr.update(value=None), gr.update(value=None), gr.update(value=None)


header_html = """
<div style="text-align: center; max-width: 900px; margin: 0 auto;">
    <div>
        <h1 style="font-weight: 900; margin-bottom: 7px;">
            Document Layout Analysis
        </h1>
        <p style="margin-top: 7px; font-size: 94%;">
            Detect and extract structured elements from document images using DocLayout-YOLO
        </p>
    </div>
</div>
"""


if __name__ == "__main__":
    model, device = load_model()

    with gr.Blocks() as demo:
        gr.HTML(header_html)

        with gr.Row():
            with gr.Column():
                input_img = gr.Image(label="Upload Document Image", interactive=True)

                with gr.Row():
                    clear_btn = gr.Button(value="Clear")
                    predict_btn = gr.Button(value="Detect Elements", interactive=True, variant="primary")

                with gr.Row():
                    conf_threshold = gr.Slider(
                        label="Confidence Threshold",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.05,
                        value=0.25,
                    )
                    iou_threshold = gr.Slider(
                        label="NMS IoU Threshold",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.05,
                        value=0.45,
                    )

            with gr.Column():
                output_img = gr.Image(label="Detection Result", interactive=False)
                output_table = gr.DataFrame(label="Detected Visual Elements")

        with gr.Row():
            gr.Markdown("""
## Detected Elements

This application detects and extracts the following elements from document images:

- **Title**: Document and section titles
- **Plain Text**: Regular paragraph text
- **Figure**: Images, charts, diagrams, etc.
- **Figure Caption**: Text describing figures
- **Table**: Tabular data structures
- **Table Caption**: Text describing tables
- **Table Footnote**: Notes below tables
- **Formula**: Mathematical equations
- **Formula Caption**: Text describing formulas

For each element, the system returns coordinates and confidence scores.
""")

        # Wire the buttons to the reset and detection callbacks.
        clear_btn.click(gradio_reset, inputs=None, outputs=[input_img, output_img, output_table])
        predict_btn.click(
            recognize_image,
            inputs=[input_img, conf_threshold, iou_threshold],
            outputs=[output_img, output_table]
        )

    # share=True requests a public gradio.live link; "0.0.0.0" listens on all interfaces.
    demo.launch(share=True, server_name="0.0.0.0", server_port=7860)