Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
import json
import os
import tempfile
from pathlib import Path

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests
import supervision as sv
import torch
from PIL import Image
from ultralytics import YOLO
|
15 |
+
|
16 |
+
# Create directories if they don't exist
os.makedirs("models", exist_ok=True)

# Download the model weights once; later starts reuse the cached file.
model_path = "models/yolov11x_best.pt"
if not os.path.exists(model_path):
    url = "https://github.com/moured/YOLOv11-Document-Layout-Analysis/releases/download/doclaynet_weights/yolov11x_best.pt"
    print(f"Downloading model from {url}...")
    # Stream into a sibling ".part" file and rename it only after the transfer
    # has finished. Because the existence of model_path is the only cache
    # check, an HTTP error page or an interrupted download must never be left
    # at model_path, or every later start would try to load a broken file.
    partial_path = model_path + ".part"
    with requests.get(url, stream=True, timeout=60) as r:
        r.raise_for_status()
        with open(partial_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=1024 * 1024):
                f.write(chunk)
    os.replace(partial_path, model_path)
    print(f"Model downloaded to {model_path}")

# Load the model
model = YOLO(model_path)
print("Model loaded successfully!")

# Class names indexed by the class id the model predicts (DocLayNet dataset)
CLASSES = [
    "Caption", "Footnote", "Formula", "List-item", "Page-footer", "Page-header",
    "Picture", "Section-header", "Table", "Text", "Title",
]

# Subset of CLASSES that is reported in the table / JSON output
VISUAL_ELEMENTS = ["Picture", "Caption", "Table", "Formula"]

# Colors and annotator used to draw the detections.
# NOTE(review): `ColorPalette.default()` and passing `labels=` to
# `BoxAnnotator.annotate` appear to belong to an older supervision API;
# confirm against the supervision version this app is pinned to.
COLORS = sv.ColorPalette.default()
box_annotator = sv.BoxAnnotator(color=COLORS)
|
45 |
+
|
46 |
+
def predict_layout(image):
    """Detect layout elements in a document image.

    Args:
        image: The uploaded image, either a numpy array or any object that
            ``np.array`` can convert (for example a PIL image). ``None``
            means nothing was uploaded.

    Returns:
        A 3-tuple ``(annotated_image, df, json_data)``:

        * ``annotated_image`` - copy of the input with every detection drawn
          and labelled as ``"<class> <confidence>"``.
        * ``df`` - ``pandas.DataFrame`` with one row per detection whose class
          is listed in ``VISUAL_ELEMENTS`` (columns: class, confidence, x1,
          y1, x2, y2, width, height).
        * ``json_data`` - the same rows serialised as an indented JSON string.

        ``(None, None, None)`` is returned when ``image`` is ``None``.
    """
    if image is None:
        return None, None, None

    img = image if isinstance(image, np.ndarray) else np.array(image)

    # Ultralytics documents numpy inputs as BGR, while Gradio's image input is
    # expected to deliver RGB by default (NOTE(review): confirm if the
    # component's type/image_mode is ever changed). Hand the model a BGR copy
    # (this also drops an alpha channel if one is present) and keep `img`
    # untouched so the annotated output is displayed with correct colors.
    if img.ndim == 3 and img.shape[2] >= 3:
        model_input = np.ascontiguousarray(img[..., 2::-1])
    else:
        model_input = img

    # Run inference; a single image yields a single result object.
    results = model(model_input)[0]
    detections = sv.Detections.from_ultralytics(results)

    labels = [
        f"{CLASSES[class_id]} {confidence:.2f}"
        for class_id, confidence in zip(detections.class_id, detections.confidence)
    ]

    annotated_image = box_annotator.annotate(
        scene=img.copy(),
        detections=detections,
        labels=labels,
    )

    # Collect bounding boxes for the visual elements only; every detection is
    # still drawn on the annotated image above.
    boxes_data = []
    for class_id, xyxy, confidence in zip(
        detections.class_id, detections.xyxy, detections.confidence
    ):
        class_name = CLASSES[class_id]
        if class_name not in VISUAL_ELEMENTS:
            continue

        # Truncate to plain ints so the values are JSON-serialisable.
        x1, y1, x2, y2 = map(int, xyxy)
        boxes_data.append({
            "class": class_name,
            "confidence": float(confidence),
            "x1": x1,
            "y1": y1,
            "x2": x2,
            "y2": y2,
            "width": x2 - x1,
            "height": y2 - y1,
        })

    # Passing `columns` keeps the header stable even when nothing was found.
    columns = ["class", "confidence", "x1", "y1", "x2", "y2", "width", "height"]
    df = pd.DataFrame(boxes_data, columns=columns)

    # Convert to JSON for display / download
    json_data = json.dumps(boxes_data, indent=2)

    return annotated_image, df, json_data
|
111 |
+
|
112 |
+
# Function to download JSON
|
113 |
+
def download_json(json_data):
    """Write the detection results to a temporary ``.json`` file.

    The result is wired to a file-download component, which needs a path on
    disk rather than the JSON payload itself.

    Args:
        json_data: The detection results, either as an already parsed
            list/dict or as a JSON string. Falsy values (``None``, empty
            list, empty string) mean there is nothing to download.

    Returns:
        Path of the written ``.json`` file as a string, or ``None`` when
        there is nothing to download.
    """
    if not json_data:
        return None

    # A string is assumed to already be serialised JSON and is written as-is.
    if isinstance(json_data, str):
        text = json_data
    else:
        text = json.dumps(json_data, indent=2)

    # delete=False: the file must outlive this function so it can be served.
    with tempfile.NamedTemporaryFile(
        mode="w",
        suffix=".json",
        prefix="layout_",
        delete=False,
        encoding="utf-8",
    ) as handle:
        handle.write(text)
    return handle.name
|
117 |
+
|
118 |
+
# Set up the Gradio interface
|
119 |
+
# Page layout: input image and annotated result side by side, then the
# bounding-box table, raw JSON and download controls, then static help text.
with gr.Blocks() as demo:
    gr.Markdown("# YOLOv11x Document Layout Analysis for Visual Elements")
    gr.Markdown("Upload a document image to extract visual elements including diagrams, tables, formulas, and captions.")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Document")
            analyze_btn = gr.Button("Analyze Layout", variant="primary")

        with gr.Column():
            output_image = gr.Image(label="Detected Layout")

    with gr.Row():
        with gr.Column():
            output_table = gr.DataFrame(label="Visual Elements Bounding Boxes")
            json_output = gr.JSON(label="JSON Output")
            download_btn = gr.Button("Download JSON")
            json_file = gr.File(label="Download JSON File", interactive=False)

    # Run detection on the uploaded image and fill all three result widgets.
    analyze_btn.click(
        fn=predict_layout,
        inputs=input_image,
        outputs=[output_image, output_table, json_output]
    )

    # Hand the currently shown JSON to download_json and expose its result
    # through the file component.
    download_btn.click(
        fn=download_json,
        inputs=[json_output],
        outputs=[json_file]
    )

    gr.Markdown("## Detected Visual Elements")
    gr.Markdown("""
    This application detects and extracts coordinates for the following visual elements:

    - **Pictures**: Diagrams, photos, illustrations, flowcharts, etc.
    - **Tables**: Structured data presented in rows and columns
    - **Formulas**: Mathematical equations and expressions
    - **Captions**: Text describing pictures or tables

    For each element, the system returns:
    - Element type (class)
    - Confidence score (0-1)
    - Coordinates (x1, y1, x2, y2)
    - Width and height in pixels
    """)

    gr.Markdown("## About")
    gr.Markdown("""
    This demo uses YOLOv11x for document layout analysis, with a focus on extracting visual elements.
    Model from [moured/YOLOv11-Document-Layout-Analysis](https://github.com/moured/YOLOv11-Document-Layout-Analysis)
    """)

    # Add example images
    gr.Examples(
        examples=[
            "https://raw.githubusercontent.com/moured/YOLOv11-Document-Layout-Analysis/main/assets/sample1.png",
            "https://raw.githubusercontent.com/moured/YOLOv11-Document-Layout-Analysis/main/assets/sample2.png",
        ],
        inputs=input_image
    )

# Start the server only when run as a script, so that importing this module
# (for example from a test or a reloader) does not launch the app.
if __name__ == "__main__":
    demo.launch()
|