deb113 committed on
Commit 84be540 · verified · 1 Parent(s): 683a960

Create app.py

Files changed (1)
  1. app.py +181 -0
app.py ADDED
@@ -0,0 +1,181 @@
+ import gradio as gr
+ import torch
+ import cv2
+ import numpy as np
+ import supervision as sv
+ from ultralytics import YOLO
+ from PIL import Image
+ import requests
+ import io
+ import os
+ import matplotlib.pyplot as plt
+ import pandas as pd
+ from pathlib import Path
+ import json
+
+ # Create directories if they don't exist
+ os.makedirs("models", exist_ok=True)
+
+ # Download model if it doesn't exist
+ model_path = "models/yolov11x_best.pt"
+ if not os.path.exists(model_path):
+     url = "https://github.com/moured/YOLOv11-Document-Layout-Analysis/releases/download/doclaynet_weights/yolov11x_best.pt"
+     print(f"Downloading model from {url}...")
+     r = requests.get(url)
+     with open(model_path, 'wb') as f:
+         f.write(r.content)
+     print(f"Model downloaded to {model_path}")
+
+ # Load the model
+ model = YOLO(model_path)
+ print("Model loaded successfully!")
+
+ # Define classes (from DocLayNet dataset)
+ CLASSES = ["Caption", "Footnote", "Formula", "List-item", "Page-footer", "Page-header",
+            "Picture", "Section-header", "Table", "Text", "Title"]
+
+ # Define visual elements we want to extract
+ VISUAL_ELEMENTS = ["Picture", "Caption", "Table", "Formula"]
+
+ # Define colors for visualization
+ COLORS = sv.ColorPalette.default()
+
+ # Set up the annotator
+ # (passing labels to BoxAnnotator.annotate below relies on an older supervision API;
+ #  recent releases draw labels with a separate sv.LabelAnnotator)
+ box_annotator = sv.BoxAnnotator(color=COLORS)
+
+ def predict_layout(image):
+     if image is None:
+         return None, None, None
+
+     # Convert to numpy array if it's not already
+     if isinstance(image, np.ndarray):
+         img = image
+     else:
+         img = np.array(image)
+
+     # Get image dimensions
+     img_height, img_width = img.shape[:2]
+
+     # Run inference
+     results = model(img)[0]
+
+     # Format detections
+     detections = sv.Detections.from_ultralytics(results)
+
+     # Get class names
+     class_ids = detections.class_id
+     labels = [f"{CLASSES[class_id]} {confidence:.2f}"
+               for class_id, confidence in zip(class_ids, detections.confidence)]
+
+     # Get annotated frame
+     annotated_image = box_annotator.annotate(
+         scene=img.copy(),
+         detections=detections,
+         labels=labels
+     )
+
+     # Extract bounding boxes for all visual elements
+     boxes_data = []
+     for class_id, xyxy, confidence in zip(detections.class_id, detections.xyxy, detections.confidence):
+         class_name = CLASSES[class_id]
+
+         # Include all visual elements (Pictures, Captions, Tables, Formulas)
+         # You can add or remove classes based on what you consider "visual elements"
+         if class_name in VISUAL_ELEMENTS:
+             x1, y1, x2, y2 = map(int, xyxy)
+             width = x2 - x1
+             height = y2 - y1
+
+             boxes_data.append({
+                 "class": class_name,
+                 "confidence": float(confidence),
+                 "x1": int(x1),
+                 "y1": int(y1),
+                 "x2": int(x2),
+                 "y2": int(y2),
+                 "width": int(width),
+                 "height": int(height)
+             })
+
+     # Create DataFrame for display
+     if boxes_data:
+         df = pd.DataFrame(boxes_data)
+         df = df[["class", "confidence", "x1", "y1", "x2", "y2", "width", "height"]]
+     else:
+         df = pd.DataFrame(columns=["class", "confidence", "x1", "y1", "x2", "y2", "width", "height"])
+
+     # Convert to JSON for download
+     json_data = json.dumps(boxes_data, indent=2)
+
+     return annotated_image, df, json_data
+
+ # Save the JSON to disk so gr.File can serve it as a download
+ def download_json(json_data):
+     if not json_data:
+         return None
+     output_path = "visual_elements.json"
+     with open(output_path, "w") as f:
+         f.write(json_data if isinstance(json_data, str) else json.dumps(json_data, indent=2))
+     return output_path
+
+ # Set up the Gradio interface
+ with gr.Blocks() as demo:
+     gr.Markdown("# YOLOv11x Document Layout Analysis for Visual Elements")
+     gr.Markdown("Upload a document image to extract visual elements including diagrams, tables, formulas, and captions.")
+
+     with gr.Row():
+         with gr.Column():
+             input_image = gr.Image(label="Input Document")
+             analyze_btn = gr.Button("Analyze Layout", variant="primary")
+
+         with gr.Column():
+             output_image = gr.Image(label="Detected Layout")
+
+     with gr.Row():
+         with gr.Column():
+             output_table = gr.DataFrame(label="Visual Elements Bounding Boxes")
+             json_output = gr.JSON(label="JSON Output")
+             download_btn = gr.Button("Download JSON")
+             json_file = gr.File(label="Download JSON File", interactive=False)
+
+     analyze_btn.click(
+         fn=predict_layout,
+         inputs=input_image,
+         outputs=[output_image, output_table, json_output]
+     )
+
+     download_btn.click(
+         fn=download_json,
+         inputs=[json_output],
+         outputs=[json_file]
+     )
+
+     gr.Markdown("## Detected Visual Elements")
+     gr.Markdown("""
+     This application detects and extracts coordinates for the following visual elements:
+
+     - **Pictures**: Diagrams, photos, illustrations, flowcharts, etc.
+     - **Tables**: Structured data presented in rows and columns
+     - **Formulas**: Mathematical equations and expressions
+     - **Captions**: Text describing pictures or tables
+
+     For each element, the system returns:
+     - Element type (class)
+     - Confidence score (0-1)
+     - Coordinates (x1, y1, x2, y2)
+     - Width and height in pixels
+     """)
+
+     gr.Markdown("## About")
+     gr.Markdown("""
+     This demo uses YOLOv11x for document layout analysis, with a focus on extracting visual elements.
+     Model from [moured/YOLOv11-Document-Layout-Analysis](https://github.com/moured/YOLOv11-Document-Layout-Analysis)
+     """)
+
+     # Add example images
+     gr.Examples(
+         examples=[
+             "https://raw.githubusercontent.com/moured/YOLOv11-Document-Layout-Analysis/main/assets/sample1.png",
+             "https://raw.githubusercontent.com/moured/YOLOv11-Document-Layout-Analysis/main/assets/sample2.png",
+         ],
+         inputs=input_image
+     )
+
+ demo.launch()