deb113 committed
Commit f416353 · verified · 1 Parent(s): 43ad293

Update app.py

Files changed (1): app.py +205 -179
app.py CHANGED
@@ -1,197 +1,223 @@
-import gradio as gr
 import torch
-import cv2
 import numpy as np
-import supervision as sv
-from ultralytics import YOLO
 from PIL import Image
-import requests
-import io
-import os
-import matplotlib.pyplot as plt
-import pandas as pd
-from pathlib import Path
-import json
-
-# Create directories if they don't exist
 os.makedirs("models", exist_ok=True)
-
-# Download model if it doesn't exist
-model_path = "models/yolov8n-doclaynet.pt"
-if not os.path.exists(model_path):
-    url = "https://huggingface.co/datasets/awsaf49/yolov8-doclaynet/resolve/main/yolov8n-doclaynet.pt"
-    print(f"Downloading smaller model from {url}...")
-    r = requests.get(url)
-    with open(model_path, 'wb') as f:
-        f.write(r.content)
-    print(f"Model downloaded to {model_path}")
-
-# Load the model
-model = YOLO(model_path)
-print("Model loaded successfully!")
-
-# Define classes (from DocLayNet dataset)
-CLASSES = ["Caption", "Footnote", "Formula", "List-item", "Page-footer", "Page-header",
-           "Picture", "Section-header", "Table", "Text", "Title"]
-
-# Define visual elements we want to extract
-VISUAL_ELEMENTS = ["Picture", "Caption", "Table", "Formula"]
-
-# Define colors for visualization - Fix for ColorPalette issue
-try:
-    # Try newer versions approach
-    COLORS = sv.ColorPalette.default()
-except (AttributeError, TypeError):
     try:
-        # Try alternate approach for some versions
-        COLORS = sv.ColorPalette.from_hex(["#FF0000", "#00FF00", "#0000FF", "#FFFF00", "#FF00FF", "#00FFFF",
-                                           "#FFA500", "#800080", "#008000", "#000080", "#808080"])
-    except (AttributeError, TypeError):
-        # Fallback for older versions or different API
-        COLORS = sv.ColorPalette(11)  # Create a color palette with 11 colors (one for each class)
-
-# Set up the annotator
-box_annotator = sv.BoxAnnotator(color=COLORS)
-
-def predict_layout(image):
-    if image is None:
-        return None, None, None
-
-    # Convert to numpy array if it's not already
-    if isinstance(image, np.ndarray):
-        img = image
-    else:
-        img = np.array(image)
-
-    # Get image dimensions
-    img_height, img_width = img.shape[:2]
-
-    # Run inference
-    results = model(img)[0]
-
-    # Format detections
     try:
-        # Try with newer supervision versions
-        detections = sv.Detections.from_ultralytics(results)
-    except (TypeError, AttributeError):
-        # Fallback for older versions
-        boxes = results.boxes.xyxy.cpu().numpy()
-        confidence = results.boxes.conf.cpu().numpy()
-        class_ids = results.boxes.cls.cpu().numpy().astype(int)
-
-        # Create Detections object manually
-        detections = sv.Detections(
-            xyxy=boxes,
-            confidence=confidence,
-            class_id=class_ids
        )
-
-    # Get class names
-    class_ids = detections.class_id
-    labels = [f"{CLASSES[class_id]} {confidence:.2f}"
-              for class_id, confidence in zip(class_ids, detections.confidence)]
-
-    # Get annotated frame
-    annotated_image = box_annotator.annotate(
-        scene=img.copy(),
-        detections=detections,
-        labels=labels
-    )
-
-    # Extract bounding boxes for all visual elements
-    boxes_data = []
-    for i, (class_id, xyxy, confidence) in enumerate(zip(detections.class_id, detections.xyxy, detections.confidence)):
-        class_name = CLASSES[class_id]
-
-        # Include all visual elements (Pictures, Captions, Tables, Formulas)
-        if class_name in VISUAL_ELEMENTS:
-            x1, y1, x2, y2 = map(int, xyxy)
-            width = x2 - x1
-            height = y2 - y1
-
-            boxes_data.append({
-                "class": class_name,
-                "confidence": float(confidence),
-                "x1": int(x1),
-                "y1": int(y1),
-                "x2": int(x2),
-                "y2": int(y2),
-                "width": int(width),
-                "height": int(height)
-            })
-
-    # Create DataFrame for display
-    if boxes_data:
-        df = pd.DataFrame(boxes_data)
-        df = df[["class", "confidence", "x1", "y1", "x2", "y2", "width", "height"]]
-    else:
-        df = pd.DataFrame(columns=["class", "confidence", "x1", "y1", "x2", "y2", "width", "height"])
-
-    # Convert to JSON for download
-    json_data = json.dumps(boxes_data, indent=2)
-
-    return annotated_image, df, json_data
-
-# Function to download JSON
-def download_json(json_data):
-    if not json_data:
-        return None
-    return json_data
-
-# Set up the Gradio interface
-with gr.Blocks() as demo:
-    gr.Markdown("# Document Layout Analysis for Visual Elements (YOLOv8n)")
-    gr.Markdown("Upload a document image to extract visual elements including diagrams, tables, formulas, and captions.")
-
-    with gr.Row():
-        with gr.Column():
-            input_image = gr.Image(label="Input Document")
-            analyze_btn = gr.Button("Analyze Layout", variant="primary")
-
-        with gr.Column():
-            output_image = gr.Image(label="Detected Layout")
-
-    with gr.Row():
-        with gr.Column():
-            output_table = gr.DataFrame(label="Visual Elements Bounding Boxes")
-            json_output = gr.JSON(label="JSON Output")
-            download_btn = gr.Button("Download JSON")
-            json_file = gr.File(label="Download JSON File", interactive=False)
-
-    analyze_btn.click(
-        fn=predict_layout,
-        inputs=input_image,
-        outputs=[output_image, output_table, json_output]
-    )
-
-    download_btn.click(
-        fn=download_json,
-        inputs=[json_output],
-        outputs=[json_file]
-    )
-
-    gr.Markdown("## Detected Visual Elements")
-    gr.Markdown("""
-    This application detects and extracts coordinates for the following visual elements:
-
-    - **Pictures**: Diagrams, photos, illustrations, flowcharts, etc.
-    - **Tables**: Structured data presented in rows and columns
-    - **Formulas**: Mathematical equations and expressions
-    - **Captions**: Text describing pictures or tables
-
-    For each element, the system returns:
-    - Element type (class)
-    - Confidence score (0-1)
-    - Coordinates (x1, y1, x2, y2)
-    - Width and height in pixels
-    """)
-
-    gr.Markdown("## About")
-    gr.Markdown("""
-    This demo uses YOLOv8n for document layout analysis, with a focus on extracting visual elements.
-    The model is a smaller, more efficient version trained on the DocLayNet dataset.
-    """)
-
 if __name__ == "__main__":
-    # Specify a lower queue_size and a maximum number of connections to limit memory use
-    demo.launch(share=True, max_threads=1, queue_size=5)
+import os
+os.environ["GRADIO_TEMP_DIR"] = "./tmp"
+
+import sys
 import torch
+import torchvision
+import gradio as gr
 import numpy as np
 from PIL import Image
+from huggingface_hub import snapshot_download
+from visualization import visualize_bbox
+
+# Create necessary directories
+os.makedirs("tmp", exist_ok=True)
 os.makedirs("models", exist_ok=True)
+
+# Define class mapping
+id_to_names = {
+    0: 'title',
+    1: 'plain text',
+    2: 'abandon',
+    3: 'figure',
+    4: 'figure_caption',
+    5: 'table',
+    6: 'table_caption',
+    7: 'table_footnote',
+    8: 'isolate_formula',
+    9: 'formula_caption'
+}
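+# Integer class ids predicted by the model are translated through this table
+# when detections are visualized and tabulated in recognize_image below.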
+
+# Visual elements for extraction (can be customized)
+VISUAL_ELEMENTS = ['figure', 'table', 'figure_caption', 'table_caption', 'isolate_formula']
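+# Leaving VISUAL_ELEMENTS empty disables filtering and reports every class:
+# recognize_image checks "not VISUAL_ELEMENTS or class_name in VISUAL_ELEMENTS".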
+
+def load_model():
+    """Load the DocLayout-YOLO model from Hugging Face"""
     try:
+        # Download model weights if they don't exist
+        model_dir = snapshot_download(
+            'juliozhao/DocLayout-YOLO-DocStructBench',
+            local_dir='./models/DocLayout-YOLO-DocStructBench'
+        )
+
+        # Select device
+        device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        print(f"Using device: {device}")
+
+        # Import and load the model
+        from doclayout_yolo import YOLOv10
+        model = YOLOv10(os.path.join(
+            os.path.dirname(__file__),
+            "models",
+            "DocLayout-YOLO-DocStructBench",
+            "doclayout_yolo_docstructbench_imgsz1024.pt"
+        ))
+
+        return model, device
+
+    except Exception as e:
+        print(f"Error loading model: {e}")
+        return None, 'cpu'
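+# If loading fails, model stays None and each recognize_image call lands in
+# its except branch, returning (None, None) instead of crashing the app.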
+
+def recognize_image(input_img, conf_threshold, iou_threshold):
+    """Process input image and detect document elements"""
+    if input_img is None:
+        return None, None
+
     try:
+        # Load model (global model if already loaded)
+        global model, device
+
+        # Run prediction
+        det_res = model.predict(
+            input_img,
+            imgsz=1024,
+            conf=conf_threshold,
+            device=device,
+        )[0]
+
+        # Extract detection results (det_res.boxes holds the xyxy, cls and conf tensors)
+        boxes = det_res.boxes.xyxy
+        classes = det_res.boxes.cls
+        scores = det_res.boxes.conf
+
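+        # torchvision.ops.nms is class-agnostic: boxes are compared across all
+        # classes, and the lower-scoring box of any pair whose IoU exceeds the
+        # threshold is suppressed.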
+        # Apply non-maximum suppression (boxes and scores are already tensors)
+        indices = torchvision.ops.nms(
+            boxes=boxes,
+            scores=scores,
+            iou_threshold=iou_threshold
        )
+
+        boxes, scores, classes = boxes[indices], scores[indices], classes[indices]
+
+        # Handle single detection case (unsqueeze keeps everything a tensor, even on GPU)
+        if boxes.ndim == 1:
+            boxes = boxes.unsqueeze(0)
+            scores = scores.unsqueeze(0)
+            classes = classes.unsqueeze(0)
+
+        # Visualize results
+        vis_result = visualize_bbox(input_img, boxes, classes, scores, id_to_names)
+
+        # Create DataFrame for extraction
+        elements_data = []
+        for i, (box, cls_id, score) in enumerate(zip(boxes, classes, scores)):
+            class_name = id_to_names[int(cls_id)]
+
+            # Only extract visual elements if specified
+            if not VISUAL_ELEMENTS or class_name in VISUAL_ELEMENTS:
+                x1, y1, x2, y2 = map(int, box)
+                width = x2 - x1
+                height = y2 - y1
+
+                elements_data.append({
+                    "class": class_name,
+                    "confidence": float(score),
+                    "x1": x1,
+                    "y1": y1,
+                    "x2": x2,
+                    "y2": y2,
+                    "width": width,
+                    "height": height
+                })
+
+        # Convert to DataFrame for display
+        import pandas as pd
+        if elements_data:
+            df = pd.DataFrame(elements_data)
+            df = df[["class", "confidence", "x1", "y1", "x2", "y2", "width", "height"]]
+        else:
+            df = pd.DataFrame(columns=["class", "confidence", "x1", "y1", "x2", "y2", "width", "height"])
+
+        return vis_result, df
+
+    except Exception as e:
+        print(f"Error processing image: {e}")
+        import traceback
+        traceback.print_exc()
+        return None, None
+
+def gradio_reset():
+    """Reset the UI"""
+    return gr.update(value=None), gr.update(value=None), gr.update(value=None)
+
+# Create basic HTML header
+header_html = """
+<div style="text-align: center; max-width: 900px; margin: 0 auto;">
+  <div>
+    <h1 style="font-weight: 900; margin-bottom: 7px;">
+      Document Layout Analysis
+    </h1>
+    <p style="margin-top: 7px; font-size: 94%;">
+      Detect and extract structured elements from document images using DocLayout-YOLO
+    </p>
+  </div>
+</div>
+"""
+
+# Main execution
 if __name__ == "__main__":
+    # Load model
+    model, device = load_model()
+
+    # Create Gradio interface
+    with gr.Blocks() as demo:
+        gr.HTML(header_html)
+
+        with gr.Row():
+            with gr.Column():
+                input_img = gr.Image(label="Upload Document Image", interactive=True)
+
+                with gr.Row():
+                    clear_btn = gr.Button(value="Clear")
+                    predict_btn = gr.Button(value="Detect Elements", interactive=True, variant="primary")
+
+                with gr.Row():
+                    conf_threshold = gr.Slider(
+                        label="Confidence Threshold",
+                        minimum=0.0,
+                        maximum=1.0,
+                        step=0.05,
+                        value=0.25,
+                    )
+
+                    iou_threshold = gr.Slider(
+                        label="NMS IOU Threshold",
+                        minimum=0.0,
+                        maximum=1.0,
+                        step=0.05,
+                        value=0.45,
+                    )
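+                    # The confidence slider filters boxes at predict time; the
+                    # IoU slider sets how much overlap NMS tolerates before
+                    # discarding the weaker of two boxes.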
+
+            with gr.Column():
+                output_img = gr.Image(label="Detection Result", interactive=False)
+                output_table = gr.DataFrame(label="Detected Visual Elements")
+
+        with gr.Row():
+            gr.Markdown("""
+            ## Detected Elements
+            This application detects and extracts the following elements from document images:
+
+            - **Title**: Document and section titles
+            - **Plain Text**: Regular paragraph text
+            - **Figure**: Images, charts, diagrams, etc.
+            - **Figure Caption**: Text describing figures
+            - **Table**: Tabular data structures
+            - **Table Caption**: Text describing tables
+            - **Table Footnote**: Notes below tables
+            - **Formula**: Mathematical equations
+            - **Formula Caption**: Text describing formulas
+
+            For each element, the system returns coordinates and confidence scores.
+            """)
+
+        # Connect events
+        clear_btn.click(gradio_reset, inputs=None, outputs=[input_img, output_img, output_table])
+        predict_btn.click(
+            recognize_image,
+            inputs=[input_img, conf_threshold, iou_threshold],
+            outputs=[output_img, output_table]
+        )
+
+    # Launch the interface
+    demo.launch(share=True, server_name="0.0.0.0", server_port=7860)
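+    # share=True additionally requests a public gradio.live link; the local
+    # server itself binds to 0.0.0.0:7860.

A quick way to exercise the new pipeline without the UI (a sketch, assuming the dependencies above are installed; sample_page.png is a hypothetical test file):

    import numpy as np
    from PIL import Image
    import app  # this commit's module; importing does not launch the UI

    # Set the module globals that recognize_image expects
    app.model, app.device = app.load_model()
    img = np.array(Image.open("sample_page.png"))
    vis, df = app.recognize_image(img, conf_threshold=0.25, iou_threshold=0.45)
    print(df)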