ChaseHan committed on
Commit
15de4c7
·
verified ·
1 Parent(s): 49fbaa3
Files changed (1) hide show
  1. app.py +99 -75
app.py CHANGED
@@ -1,126 +1,142 @@
1
  import gradio as gr
2
  import cv2
3
  import numpy as np
4
- import tempfile
5
  from ultralytics import YOLO
6
- from transformers import AutoModelForCausalLM, AutoTokenizer
7
- from PIL import Image
8
- import torch
9
 
10
- # Load the Latex2Layout model for layout detection
11
- latex2layout_model_path = "latex2layout_object_detection_yolov8.pt"
12
- latex2layout_model = YOLO(latex2layout_model_path)
13
-
14
- # Download and load the Qwen2.5-VL-3B model
15
- qwen_model_path = "Qwen/Qwen2.5-VL-3B"
16
- qwen_model = AutoModelForCausalLM.from_pretrained(qwen_model_path, device_map="auto", trust_remote_code=True)
17
- qwen_tokenizer = AutoTokenizer.from_pretrained(qwen_model_path)
18
 
def detect_layout(image):
    """
    Summarise the layout elements that the Latex2Layout model finds in an image.

    Args:
        image: The uploaded image (numpy array)

    Returns:
        A comma-separated string with one "<class> at position (x1, y1, x2, y2)"
        entry per detected box, "No elements detected." when the model reports
        no boxes, or an error message when no image is supplied.
    """
    if image is None:
        return "Error: No image provided."

    # The model call returns a sequence of results; only the first is used.
    detection = latex2layout_model(image)[0]

    entries = []
    for det_box in detection.boxes:
        # Box corners are truncated to integer pixel coordinates.
        left, top, right, bottom = (int(v) for v in det_box.xyxy[0].cpu().numpy())
        label = detection.names[int(det_box.cls[0])]
        entries.append(f"{label} at position ({left}, {top}, {right}, {bottom})")

    if not entries:
        return "No elements detected."
    return ", ".join(entries)
 
44
 
def process_image_and_question(image, question):
    """
    Annotate the image with Latex2Layout detections and answer a question with Qwen2.5-VL.

    The detector is run exactly once per call; the same result is used both to
    draw the boxes and to build the textual layout description that is placed
    in the prompt.

    Args:
        image: The uploaded image (numpy array)
        question: The user's question (string)

    Returns:
        annotated_image: Copy of the image with detection boxes and labels
            drawn on it (None when the input is rejected)
        response: Text decoded from the tokens generated by the model, or an
            error message when the image or the question is missing
    """
    if image is None or not question:
        return None, "Error: Please upload an image and provide a question."

    # NOTE(review): this swaps the first and third channels. If the UI already
    # delivers RGB arrays the PIL image ends up channel-swapped - confirm the
    # channel order of the incoming array.
    image_pil = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

    # Single detection pass: draw each box and collect its description in the
    # same loop instead of running the model a second time.
    annotated_image = image.copy()
    results = latex2layout_model(image)[0]
    layout_entries = []
    for box in results.boxes:
        x1, y1, x2, y2 = map(int, box.xyxy[0].cpu().numpy())
        conf = float(box.conf[0])
        cls_id = int(box.cls[0])
        cls_name = results.names[cls_id]

        # A random colour is drawn for every box.
        color = tuple(np.random.randint(0, 255, 3).tolist())
        cv2.rectangle(annotated_image, (x1, y1), (x2, y2), color, 2)
        label = f'{cls_name} {conf:.2f}'
        (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
        # Filled background strip directly above the box, then the label text.
        cv2.rectangle(annotated_image, (x1, y1-label_height-5), (x1+label_width, y1), color, -1)
        cv2.putText(annotated_image, label, (x1, y1-5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

        layout_entries.append(f"{cls_name} at position ({x1}, {y1}, {x2}, {y2})")

    layout_description = ", ".join(layout_entries) if layout_entries else "No elements detected."

    # Prepare input for Qwen2.5-VL
    input_text = f"Layout: {layout_description}\nQuestion: {question}"
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_pil},
                {"type": "text", "text": input_text},
            ],
        }
    ]

    # NOTE(review): only a tokenizer is applied below; whether the image entry
    # in `messages` actually reaches the model depends on the chat template -
    # verify that the model receives the pixels and not just the text.
    inputs = qwen_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    model_inputs = qwen_tokenizer([inputs], return_tensors="pt").to(qwen_model.device)
    with torch.no_grad():
        output_ids = qwen_model.generate(**model_inputs, max_new_tokens=100)

    # Drop the prompt tokens so only the newly generated continuation is decoded.
    prompt_length = len(model_inputs["input_ids"][0])
    response = qwen_tokenizer.decode(output_ids[0][prompt_length:], skip_special_tokens=True)

    return annotated_image, response
 
101
 
102
  # Custom CSS for styling
103
  custom_css = """
104
  .container { max-width: 1200px; margin: auto; }
105
  .button-primary { background-color: #4CAF50; color: white; }
106
  .gr-image { border: 2px solid #ddd; border-radius: 5px; }
107
- .gr-textbox { font-family: Arial; }
108
  """
109
 
110
  # Create Gradio interface
111
  with gr.Blocks(
112
- title="Latex2Layout Visual Q&A",
113
  theme=gr.themes.Default(),
114
  css=custom_css
115
  ) as demo:
116
  gr.Markdown(
117
  """
118
- # Latex2Layout Visual Q&A
119
- Upload an image and ask a question about its layout. The **Latex2Layout** model detects elements, and **Qwen2.5-VL** provides answers based on the image and layout information.
120
  """
121
  )
122
 
 
 
 
 
 
 
 
 
123
  with gr.Row():
 
124
  with gr.Column(scale=1):
125
  input_image = gr.Image(
126
  label="Upload Image",
@@ -130,22 +146,30 @@ with gr.Blocks(
130
  )
131
  question_input = gr.Textbox(
132
  label="Ask a Question",
133
- placeholder="e.g., What elements are in the image?",
134
  lines=2
135
  )
136
  submit_btn = gr.Button(
137
- "Get Answer",
138
  variant="primary",
139
  elem_classes="button-primary"
140
  )
 
141
 
 
142
  with gr.Column(scale=1):
143
  output_image = gr.Image(
144
  label="Detected Layout",
145
  height=400,
146
  elem_classes="gr-image"
147
  )
148
- output_text = gr.Textbox(
 
 
 
 
 
 
149
  label="Answer",
150
  lines=5,
151
  max_lines=10,
@@ -155,14 +179,14 @@ with gr.Blocks(
155
  # Event handler
156
  submit_btn.click(
157
  fn=process_image_and_question,
158
- inputs=[input_image, question_input],
159
- outputs=[output_image, output_text],
160
  _js="() => { document.querySelector('.button-primary').innerText = 'Processing...'; }",
161
  show_progress=True
162
  ).then(
163
- fn=lambda: gr.update(value="Get Answer"),
164
  outputs=submit_btn,
165
- _js="() => { document.querySelector('.button-primary').innerText = 'Get Answer'; }"
166
  )
167
 
168
  # Launch the application
 
1
  import gradio as gr
2
  import cv2
3
  import numpy as np
4
+ import requests
5
  from ultralytics import YOLO
 
 
 
6
 
7
+ # Load the Latex2Layout model
8
+ model_path = "latex2layout_object_detection_yolov8.pt"
9
+ latex2layout_model = YOLO(model_path)
 
 
 
 
 
10
 
def detect_layout(image):
    """
    Perform layout detection on the uploaded image using the Latex2Layout model.

    Each detected box is drawn on a copy of the image together with a
    "<class> <confidence>" label. The label is normally placed directly above
    the box; when there is no room above (box touching the top edge) it is
    drawn just inside the box instead, and it is shifted left when it would
    run past the right edge, so the label is not clipped by the image border.

    Args:
        image: The uploaded image (numpy array)

    Returns:
        annotated_image: Image with detection boxes drawn (None if no image
            was supplied)
        layout_info: Text description of detected layout elements, or an
            error message if no image was supplied
    """
    if image is None:
        return None, "Error: No image uploaded."

    # The model call returns a sequence of results; only the first is used.
    result = latex2layout_model(image)[0]

    # Draw on a copy so the caller's array is left untouched.
    annotated_image = image.copy()
    layout_annotations = []

    # Image width is needed to keep labels inside the canvas.
    img_height, img_width = image.shape[:2]

    for box in result.boxes:
        x1, y1, x2, y2 = map(int, box.xyxy[0].cpu().numpy())
        conf = float(box.conf[0])
        cls_name = result.names[int(box.cls[0])]

        # A random colour is drawn for every box.
        color = tuple(np.random.randint(0, 255, 3).tolist())
        cv2.rectangle(annotated_image, (x1, y1), (x2, y2), color, 2)

        label = f'{cls_name} {conf:.2f}'
        (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)

        # Default placement is a strip ending at the top edge of the box.
        label_top = y1 - label_height - 5
        if label_top < 0:
            # Not enough room above the box: start the strip at the box top.
            label_top = max(y1, 0)
        label_bottom = label_top + label_height + 5
        # Shift left if the strip would overflow the right edge; never go below 0.
        label_left = max(0, min(x1, img_width - label_width))

        cv2.rectangle(
            annotated_image,
            (label_left, label_top),
            (label_left + label_width, label_bottom),
            color,
            -1,
        )
        cv2.putText(
            annotated_image,
            label,
            (label_left, label_bottom - 5),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            (255, 255, 255),
            1,
        )

        # Format layout info for Qwen2.5-VL
        layout_annotations.append(f"{cls_name} at position ({x1},{y1},{x2},{y2}) with confidence {conf:.2f}")

    if layout_annotations:
        layout_info = "Detected layout elements: " + "; ".join(layout_annotations)
    else:
        layout_info = "No layout elements detected."
    return annotated_image, layout_info
56
 
def call_qwen_vl_api(api_url, image, layout_info, question):
    """
    Call the Qwen2.5-VL API with the image, layout info, and user question.

    All failure modes are reported as an "Error: ..." string rather than an
    exception so the result can be shown directly in the UI.

    Args:
        api_url: The URL of the Qwen2.5-VL API
        image: The uploaded image (numpy array)
        layout_info: Text description of detected layout elements
        question: User's question about the image and layout

    Returns:
        answer: The "answer" field of the API's JSON response, or an error
            message when an input is missing, the request fails, or the
            response is not a JSON object containing an answer
    """
    if not api_url:
        return "Error: Please provide a valid Qwen2.5-VL API URL."
    if not question:
        return "Error: Please enter a question."
    if image is None:
        # Without this guard the payload construction below would raise.
        return "Error: No image uploaded."

    # NOTE(review): the image is sent as a nested list of pixel values, which
    # is a placeholder encoding - adjust to whatever the target API expects
    # (e.g. base64) once its spec is known.
    payload = {
        "image": image.tolist(),
        "prompt": f"{layout_info}\n\nQuestion: {question}",
    }

    try:
        response = requests.post(api_url, json=payload, timeout=30)
        response.raise_for_status()  # Raise an error for bad status codes
        data = response.json()
    except requests.exceptions.RequestException as e:
        return f"Error: API call failed - {str(e)}"
    except ValueError as e:
        # Body was not valid JSON (raised as a plain ValueError by some
        # versions of the HTTP client).
        return f"Error: API returned a non-JSON response - {e}"

    # A JSON array/string/number has no "answer" key to look up.
    if not isinstance(data, dict):
        return "Error: Unexpected response format from API."
    return data.get("answer", "Error: No answer received from API.")
87
+
def process_image_and_question(api_url, image, question):
    """
    Run Latex2Layout detection on the image, then put the question to the Qwen2.5-VL API.

    Args:
        api_url: Qwen2.5-VL API URL
        image: Uploaded image
        question: User's question

    Returns:
        A 3-tuple (annotated_image, layout_info, answer): the image with
        detection boxes, the detected layout description, and the API's
        response to the question. When detection yields no image the tuple is
        (None, <message from detection>, "Error: Detection failed.").
    """
    boxed_image, description = detect_layout(image)

    # Detection reports failure by returning None for the image; its second
    # value is passed through as the layout text in that case.
    if boxed_image is None:
        return None, description, "Error: Detection failed."

    # The original (un-annotated) image is what gets sent to the API.
    reply = call_qwen_vl_api(api_url, image, description, question)
    return boxed_image, description, reply
108
 
109
  # Custom CSS for styling
110
  custom_css = """
111
  .container { max-width: 1200px; margin: auto; }
112
  .button-primary { background-color: #4CAF50; color: white; }
113
  .gr-image { border: 2px solid #ddd; border-radius: 5px; }
114
+ .gr-textbox { font-family: monospace; }
115
  """
116
 
117
  # Create Gradio interface
118
  with gr.Blocks(
119
+ title="Latex2Layout Detection & QA",
120
  theme=gr.themes.Default(),
121
  css=custom_css
122
  ) as demo:
123
  gr.Markdown(
124
  """
125
+ # Latex2Layout Layout Detection & Q&A
126
+ Upload an image to detect layout elements using the **Latex2Layout** model, then ask questions about the layout and image content using the Qwen2.5-VL API.
127
  """
128
  )
129
 
130
+ # API URL input
131
+ api_url_input = gr.Textbox(
132
+ label="Qwen2.5-VL API URL",
133
+ placeholder="Enter the Qwen2.5-VL API URL here",
134
+ value=""
135
+ )
136
+
137
+ # Main layout
138
  with gr.Row():
139
+ # Input column
140
  with gr.Column(scale=1):
141
  input_image = gr.Image(
142
  label="Upload Image",
 
146
  )
147
  question_input = gr.Textbox(
148
  label="Ask a Question",
149
+ placeholder="e.g., What is the layout structure of this image?",
150
  lines=2
151
  )
152
  submit_btn = gr.Button(
153
+ "Detect & Ask",
154
  variant="primary",
155
  elem_classes="button-primary"
156
  )
157
+ gr.Markdown("**Tip**: Provide a clear image and specific question for best results.")
158
 
159
+ # Output column
160
  with gr.Column(scale=1):
161
  output_image = gr.Image(
162
  label="Detected Layout",
163
  height=400,
164
  elem_classes="gr-image"
165
  )
166
+ layout_output = gr.Textbox(
167
+ label="Layout Information",
168
+ lines=5,
169
+ max_lines=10,
170
+ elem_classes="gr-textbox"
171
+ )
172
+ answer_output = gr.Textbox(
173
  label="Answer",
174
  lines=5,
175
  max_lines=10,
 
179
  # Event handler
180
  submit_btn.click(
181
  fn=process_image_and_question,
182
+ inputs=[api_url_input, input_image, question_input],
183
+ outputs=[output_image, layout_output, answer_output],
184
  _js="() => { document.querySelector('.button-primary').innerText = 'Processing...'; }",
185
  show_progress=True
186
  ).then(
187
+ fn=lambda: gr.update(value="Detect & Ask"),
188
  outputs=submit_btn,
189
+ _js="() => { document.querySelector('.button-primary').innerText = 'Detect & Ask'; }"
190
  )
191
 
192
  # Launch the application