ChaseHan committed on
Commit
49fbaa3
·
verified ·
1 Parent(s): 83d84a0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -97
app.py CHANGED
@@ -1,109 +1,126 @@
1
  import gradio as gr
2
  import cv2
3
  import numpy as np
4
- import os
5
  import tempfile
6
  from ultralytics import YOLO
 
 
 
7
 
8
- # Load the Latex2Layout model
9
- model_path = "latex2layout_object_detection_yolov8.pt"
10
- model = YOLO(model_path)
11
 
12
def detect_and_visualize(image):
    """
    Run Latex2Layout detection on an image and draw the detections on a copy.

    Args:
        image: The uploaded image as a NumPy array (rows x columns x channels),
            or None when nothing was uploaded.

    Returns:
        annotated_image: Copy of the input with one box and caption per
            detection, or None when no image was given.
        layout_annotations: Newline-joined "class x_center y_center width
            height" records normalized by the image size (YOLO format), or an
            error message when no image was given.
    """
    if image is None:
        return None, "Error: No image uploaded."

    # Run detection using the Latex2Layout model; a single image yields one result
    result = model(image)[0]

    # Draw on a copy so the caller's array is left untouched
    annotated_image = image.copy()
    img_height, img_width = image.shape[:2]
    layout_annotations = []

    for box in result.boxes:
        x1, y1, x2, y2 = map(int, box.xyxy[0].cpu().numpy())
        conf = float(box.conf[0])
        cls_id = int(box.cls[0])
        cls_name = result.names[cls_id]

        _draw_labeled_box(
            annotated_image,
            (x1, y1, x2, y2),
            f'{cls_name} {conf:.2f}',
            _class_color(cls_id),
        )

        # Convert pixel corners to YOLO format (center/size, normalized to 0..1)
        x_center = (x1 + x2) / (2 * img_width)
        y_center = (y1 + y2) / (2 * img_height)
        width = (x2 - x1) / img_width
        height = (y2 - y1) / img_height
        layout_annotations.append(
            f"{cls_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"
        )

    return annotated_image, "\n".join(layout_annotations)


def _class_color(cls_id):
    """Return a color tuple that is always the same for a given class id."""
    # Seeding with the class id keeps every box of one class the same color,
    # across boxes and across runs, instead of a new random color per box.
    rng = np.random.default_rng(cls_id)
    return tuple(int(channel) for channel in rng.integers(0, 255, 3))


def _draw_labeled_box(canvas, corners, label, color):
    """Draw one detection rectangle plus a filled caption strip onto canvas in place."""
    x1, y1, x2, y2 = corners
    font = cv2.FONT_HERSHEY_SIMPLEX
    cv2.rectangle(canvas, (x1, y1), (x2, y2), color, 2)
    (label_width, label_height), _ = cv2.getTextSize(label, font, 0.5, 1)
    # A box touching the top edge would push the caption above row 0; clamp the
    # strip so it is drawn just inside the image instead of being cut off.
    strip_top = max(y1 - label_height - 5, 0)
    strip_bottom = strip_top + label_height + 5
    cv2.rectangle(canvas, (x1, strip_top), (x1 + label_width, strip_bottom), color, -1)
    cv2.putText(canvas, label, (x1, strip_bottom - 5), font, 0.5, (255, 255, 255), 1)
62
-
63
def save_layout_annotations(layout_annotations_str):
    """
    Write the annotation text to a new temporary .txt file and return its path.

    Args:
        layout_annotations_str: Annotations string in YOLO format. An empty
            string or None means there is nothing to save.

    Returns:
        file_path: Path of the written file, or None when there was nothing
            to save. The file is created with delete=False, so it stays on
            disk after this function returns.
    """
    if not layout_annotations_str:
        return None

    # Write through the handle NamedTemporaryFile already opened. Reopening the
    # path with a second open() left the first handle unclosed.
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".txt", delete=False, encoding="utf-8"
    ) as temp_file:
        temp_file.write(layout_annotations_str)
    return temp_file.name
80
 
81
  # Custom CSS for styling
82
  custom_css = """
83
  .container { max-width: 1200px; margin: auto; }
84
  .button-primary { background-color: #4CAF50; color: white; }
85
- .button-secondary { background-color: #008CBA; color: white; }
86
  .gr-image { border: 2px solid #ddd; border-radius: 5px; }
87
- .gr-textbox { font-family: monospace; }
88
  """
89
 
90
- # Create Gradio interface with enhanced styling
91
  with gr.Blocks(
92
- title="Latex2Layout Detection",
93
  theme=gr.themes.Default(),
94
  css=custom_css
95
  ) as demo:
96
- # Header with instructions
97
  gr.Markdown(
98
  """
99
- # Latex2Layout Layout Detection
100
- Upload an image to detect layout elements using the **Latex2Layout** model. View the annotated image and download the results in YOLO format.
101
  """
102
  )
103
 
104
- # Main layout with two columns
105
  with gr.Row():
106
- # Input column
107
  with gr.Column(scale=1):
108
  input_image = gr.Image(
109
  label="Upload Image",
@@ -111,63 +128,43 @@ with gr.Blocks(
111
  height=400,
112
  elem_classes="gr-image"
113
  )
114
- detect_btn = gr.Button(
115
- "Start Detection",
 
 
 
 
 
116
  variant="primary",
117
  elem_classes="button-primary"
118
  )
119
- gr.Markdown("**Tip**: Upload a clear image for optimal detection results.")
120
 
121
- # Output column
122
  with gr.Column(scale=1):
123
  output_image = gr.Image(
124
- label="Detection Results",
125
  height=400,
126
  elem_classes="gr-image"
127
  )
128
- layout_annotations = gr.Textbox(
129
- label="Layout Annotations (YOLO Format)",
130
- lines=10,
131
- max_lines=15,
132
  elem_classes="gr-textbox"
133
  )
134
- download_btn = gr.Button(
135
- "Download Annotations",
136
- variant="secondary",
137
- elem_classes="button-secondary"
138
- )
139
- download_file = gr.File(
140
- label="Download File",
141
- interactive=False
142
- )
143
 
144
- # Example image button (optional)
145
- with gr.Row():
146
- gr.Button("Load Example Image").click(
147
- fn=lambda: cv2.imread("example_image.jpg"),
148
- outputs=input_image
149
- )
150
-
151
- # Event handlers
152
- detect_btn.click(
153
- fn=detect_and_visualize,
154
- inputs=input_image,
155
- outputs=[output_image, layout_annotations],
156
  _js="() => { document.querySelector('.button-primary').innerText = 'Processing...'; }",
157
  show_progress=True
158
  ).then(
159
- fn=lambda: gr.update(value="Start Detection"),
160
- outputs=detect_btn,
161
- _js="() => { document.querySelector('.button-primary').innerText = 'Start Detection'; }"
162
- )
163
-
164
- download_btn.click(
165
- fn=save_layout_annotations,
166
- inputs=layout_annotations,
167
- outputs=download_file
168
  )
169
 
170
-
171
  # Launch the application
172
  if __name__ == "__main__":
173
  demo.launch()
 
1
  import gradio as gr
2
  import cv2
3
  import numpy as np
 
4
  import tempfile
5
  from ultralytics import YOLO
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer
7
+ from PIL import Image
8
+ import torch
9
 
10
# Load the Latex2Layout model for layout detection
latex2layout_model_path = "latex2layout_object_detection_yolov8.pt"
latex2layout_model = YOLO(latex2layout_model_path)

# Qwen2.5-VL is a vision-language model. It has to be loaded through its own
# conditional-generation class together with an AutoProcessor (tokenizer plus
# image preprocessor); AutoModelForCausalLM with a bare tokenizer has no way of
# turning the image into model inputs.
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

# NOTE(review): switched to the "-Instruct" checkpoint id on the understanding
# that the 3B weights are published only under that name - confirm on the Hub.
qwen_model_path = "Qwen/Qwen2.5-VL-3B-Instruct"
qwen_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    qwen_model_path, device_map="auto"
)
qwen_processor = AutoProcessor.from_pretrained(qwen_model_path)
# Kept so any existing reference to qwen_tokenizer keeps working
qwen_tokenizer = qwen_processor.tokenizer


def _run_layout_detection(image):
    """Run Latex2Layout once and return (cls_name, cls_id, conf, (x1, y1, x2, y2)) tuples."""
    # NOTE(review): Ultralytics documents numpy input as BGR while Gradio hands
    # over RGB; the array is passed through unchanged here - confirm which
    # channel order the Latex2Layout weights expect.
    result = latex2layout_model(image)[0]
    detections = []
    for box in result.boxes:
        x1, y1, x2, y2 = map(int, box.xyxy[0].cpu().numpy())
        cls_id = int(box.cls[0])
        detections.append(
            (result.names[cls_id], cls_id, float(box.conf[0]), (x1, y1, x2, y2))
        )
    return detections


def _describe_layout(detections):
    """Turn detection tuples into the comma-separated layout description string."""
    parts = [
        f"{cls_name} at position ({x1}, {y1}, {x2}, {y2})"
        for cls_name, _cls_id, _conf, (x1, y1, x2, y2) in detections
    ]
    return ", ".join(parts) if parts else "No elements detected."


def _class_color(cls_id):
    """Return a color tuple that is always the same for a given class id."""
    # Seeding with the class id keeps every box of one class the same color,
    # across boxes and across runs, instead of a new random color per box.
    rng = np.random.default_rng(cls_id)
    return tuple(int(channel) for channel in rng.integers(0, 255, 3))


def _draw_labeled_box(canvas, corners, label, color):
    """Draw one detection rectangle plus a filled caption strip onto canvas in place."""
    x1, y1, x2, y2 = corners
    font = cv2.FONT_HERSHEY_SIMPLEX
    cv2.rectangle(canvas, (x1, y1), (x2, y2), color, 2)
    (label_width, label_height), _ = cv2.getTextSize(label, font, 0.5, 1)
    # A box touching the top edge would push the caption above row 0; clamp the
    # strip so it is drawn just inside the image instead of being cut off.
    strip_top = max(y1 - label_height - 5, 0)
    strip_bottom = strip_top + label_height + 5
    cv2.rectangle(canvas, (x1, strip_top), (x1 + label_width, strip_bottom), color, -1)
    cv2.putText(canvas, label, (x1, strip_bottom - 5), font, 0.5, (255, 255, 255), 1)


def detect_layout(image):
    """
    Detect layout elements in the image using the Latex2Layout model.

    Args:
        image: The uploaded image (numpy array), or None.

    Returns:
        layout_description: Comma-separated "<class> at position (x1, y1, x2, y2)"
            entries, "No elements detected." when the model finds nothing, or
            an error message when no image was provided.
    """
    if image is None:
        return "Error: No image provided."
    return _describe_layout(_run_layout_detection(image))


def process_image_and_question(image, question):
    """
    Process the image with Latex2Layout and answer the question using Qwen2.5-VL.

    Args:
        image: The uploaded image (numpy array as delivered by the Gradio
            image component), or None.
        question: The user's question (string).

    Returns:
        annotated_image: Copy of the image with detection boxes and captions,
            or None when the image or question is missing.
        response: Answer text generated by Qwen2.5-VL, or an error message
            when the image or question is missing.
    """
    if image is None or not question:
        return None, "Error: Please upload an image and provide a question."

    # Gradio hands numpy images over in RGB order, which is what PIL expects;
    # running a BGR->RGB conversion here would swap red and blue for the VLM.
    image_pil = Image.fromarray(image)

    # One detection pass feeds both the text description and the drawing
    detections = _run_layout_detection(image)
    layout_description = _describe_layout(detections)

    annotated_image = image.copy()
    for cls_name, cls_id, conf, corners in detections:
        _draw_labeled_box(
            annotated_image, corners, f'{cls_name} {conf:.2f}', _class_color(cls_id)
        )

    # Prepare input for Qwen2.5-VL
    input_text = f"Layout: {layout_description}\nQuestion: {question}"
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_pil},
                {"type": "text", "text": input_text},
            ],
        }
    ]

    # The chat template only inserts image placeholder tokens; the processor
    # call is what actually encodes the pixels alongside the prompt text.
    prompt = qwen_processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    model_inputs = qwen_processor(
        text=[prompt], images=[image_pil], padding=True, return_tensors="pt"
    ).to(qwen_model.device)

    with torch.no_grad():
        output_ids = qwen_model.generate(**model_inputs, max_new_tokens=100)

    # generate() returns prompt + continuation; decode only the new tokens
    prompt_length = model_inputs["input_ids"].shape[1]
    response = qwen_processor.batch_decode(
        output_ids[:, prompt_length:], skip_special_tokens=True
    )[0]

    return annotated_image, response
 
 
 
101
 
102
  # Custom CSS for styling
103
  custom_css = """
104
  .container { max-width: 1200px; margin: auto; }
105
  .button-primary { background-color: #4CAF50; color: white; }
 
106
  .gr-image { border: 2px solid #ddd; border-radius: 5px; }
107
+ .gr-textbox { font-family: Arial; }
108
  """
109
 
110
+ # Create Gradio interface
111
  with gr.Blocks(
112
+ title="Latex2Layout Visual Q&A",
113
  theme=gr.themes.Default(),
114
  css=custom_css
115
  ) as demo:
 
116
  gr.Markdown(
117
  """
118
+ # Latex2Layout Visual Q&A
119
+ Upload an image and ask a question about its layout. The **Latex2Layout** model detects elements, and **Qwen2.5-VL** provides answers based on the image and layout information.
120
  """
121
  )
122
 
 
123
  with gr.Row():
 
124
  with gr.Column(scale=1):
125
  input_image = gr.Image(
126
  label="Upload Image",
 
128
  height=400,
129
  elem_classes="gr-image"
130
  )
131
+ question_input = gr.Textbox(
132
+ label="Ask a Question",
133
+ placeholder="e.g., What elements are in the image?",
134
+ lines=2
135
+ )
136
+ submit_btn = gr.Button(
137
+ "Get Answer",
138
  variant="primary",
139
  elem_classes="button-primary"
140
  )
 
141
 
 
142
  with gr.Column(scale=1):
143
  output_image = gr.Image(
144
+ label="Detected Layout",
145
  height=400,
146
  elem_classes="gr-image"
147
  )
148
+ output_text = gr.Textbox(
149
+ label="Answer",
150
+ lines=5,
151
+ max_lines=10,
152
  elem_classes="gr-textbox"
153
  )
 
 
 
 
 
 
 
 
 
154
 
155
+ # Event handler
156
+ submit_btn.click(
157
+ fn=process_image_and_question,
158
+ inputs=[input_image, question_input],
159
+ outputs=[output_image, output_text],
 
 
 
 
 
 
 
160
  _js="() => { document.querySelector('.button-primary').innerText = 'Processing...'; }",
161
  show_progress=True
162
  ).then(
163
+ fn=lambda: gr.update(value="Get Answer"),
164
+ outputs=submit_btn,
165
+ _js="() => { document.querySelector('.button-primary').innerText = 'Get Answer'; }"
 
 
 
 
 
 
166
  )
167
 
 
168
  # Launch the application
169
  if __name__ == "__main__":
170
  demo.launch()