ChaseHan committed on
Commit
6a41fcf
·
verified ·
1 Parent(s): f56f5ba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +117 -113
app.py CHANGED
@@ -9,11 +9,14 @@ import base64
9
  from openai import OpenAI
10
  from ultralytics import YOLO
11
 
12
- # Load the Latex2Layout model
13
  model_path = "latex2layout_object_detection_yolov8.pt"
 
 
14
  if not os.path.exists(model_path):
15
  raise FileNotFoundError(f"Model file not found at {model_path}")
16
 
 
17
  try:
18
  model = YOLO(model_path)
19
  except Exception as e:
@@ -21,191 +24,192 @@ except Exception as e:
21
 
22
  # Qwen API configuration
23
  QWEN_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
24
- QWEN_MODEL_ID = "qwen2.5-vl-3b-instruct"
 
 
 
 
25
 
26
  def encode_image(image_array):
27
  """
28
- Encode numpy array image to base64 string.
29
-
30
  Args:
31
- image_array: numpy array of the image
32
-
33
  Returns:
34
- base64 encoded string of the image
35
  """
36
- # Convert numpy array to PIL Image
37
- pil_image = Image.fromarray(image_array)
38
-
39
- # Convert PIL Image to bytes
40
- img_byte_arr = io.BytesIO()
41
- pil_image.save(img_byte_arr, format='PNG')
42
- img_byte_arr = img_byte_arr.getvalue()
43
-
44
- # Encode to base64
45
- return base64.b64encode(img_byte_arr).decode("utf-8")
46
-
47
- def detect_layout(image):
48
  """
49
- Perform layout detection on the uploaded image using local YOLO model.
50
-
51
  Args:
52
- image: The uploaded image as a numpy array
53
-
 
54
  Returns:
55
- annotated_image: Image with detection boxes
56
- layout_info: Layout detection results
 
57
  """
58
- if image is None:
59
- return None, "Error: No image uploaded."
60
-
61
  try:
62
- # Run detection using local YOLO model
63
  results = model(image)
64
  result = results[0]
65
-
66
- # Create a copy of the image for visualization
67
  annotated_image = image.copy()
68
  layout_info = []
69
-
70
- # Draw detection results
71
  for box in result.boxes:
72
- # Get bounding box coordinates
73
- x1, y1, x2, y2 = map(int, box.xyxy[0].cpu().numpy())
74
  conf = float(box.conf[0])
 
 
 
 
 
 
75
  cls_id = int(box.cls[0])
76
  cls_name = result.names[cls_id]
77
-
78
- # Generate a color for each class
79
  color = tuple(np.random.randint(0, 255, 3).tolist())
80
-
81
  # Draw bounding box and label
82
  cv2.rectangle(annotated_image, (x1, y1), (x2, y2), color, 2)
83
- label = f'{cls_name} {conf:.2f}'
84
  (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
85
- cv2.rectangle(annotated_image, (x1, y1-label_height-5), (x1+label_width, y1), color, -1)
86
- cv2.putText(annotated_image, label, (x1, y1-5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
87
-
88
- # Add detection to layout info
89
  layout_info.append({
90
- 'bbox': [x1, y1, x2, y2],
91
- 'class': cls_name,
92
- 'confidence': conf
93
  })
94
-
95
- # Format layout information for Qwen
96
- layout_info_str = json.dumps(layout_info, indent=2)
97
-
98
  return annotated_image, layout_info_str
99
-
100
  except Exception as e:
101
  return None, f"Error during layout detection: {str(e)}"
102
 
103
- def qa_about_layout(image, question, layout_info, api_key):
104
  """
105
- Answer questions about the layout using Qwen2.5-VL API.
106
-
107
  Args:
108
- image: The uploaded image
109
- question: User's question about the layout
110
- layout_info: Layout detection results from YOLO
111
- api_key: User's Qwen API key
112
-
 
113
  Returns:
114
- answer: Qwen's answer to the question
115
  """
116
- if image is None or not question:
117
- return "Please upload an image and ask a question."
118
-
119
- if not layout_info:
120
- return "No layout information available. Please detect layout first."
121
-
122
  if not api_key:
123
- return "Please enter your Qwen API key."
124
-
 
 
125
  try:
126
  # Encode image to base64
127
  base64_image = encode_image(image)
128
-
 
 
 
 
 
129
  # Initialize OpenAI client for Qwen API
130
- client = OpenAI(
131
- api_key=api_key,
132
- base_url=QWEN_BASE_URL,
133
- )
134
-
135
- # Prepare system prompt with layout information
136
- system_prompt = f"""You are a helpful assistant specialized in analyzing document layouts.
137
- The following layout information has been detected in the image:
138
- {layout_info}
139
-
140
- Please answer questions about the layout based on this information and the image."""
141
-
142
- # Prepare messages for API call
143
  messages = [
144
- {
145
- "role": "system",
146
- "content": [{"type": "text", "text": system_prompt}]
147
- },
148
  {
149
  "role": "user",
150
  "content": [
151
- {
152
- "type": "image_url",
153
- "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
154
- },
155
  {"type": "text", "text": question},
156
  ],
157
- }
158
  ]
159
-
160
  # Call Qwen API
161
- completion = client.chat.completions.create(
162
- model=QWEN_MODEL_ID,
163
- messages=messages,
164
- )
165
-
166
  return completion.choices[0].message.content
167
-
168
  except Exception as e:
169
  return f"Error during QA: {str(e)}"
170
 
171
- # Create Gradio interface
172
  with gr.Blocks(title="Latex2Layout QA System") as demo:
173
  gr.Markdown("# Latex2Layout QA System")
174
- gr.Markdown("Upload an image, detect layout elements, and ask questions about the layout.")
175
-
176
  with gr.Row():
177
  with gr.Column(scale=1):
178
  input_image = gr.Image(label="Upload Image", type="numpy")
179
  detect_btn = gr.Button("Detect Layout")
180
- gr.Markdown("**Tip**: Upload a clear image for optimal detection results.")
181
-
182
  with gr.Column(scale=1):
183
- output_image = gr.Image(label="Detection Results")
184
- layout_info = gr.Textbox(label="Layout Information", lines=10)
185
-
186
  with gr.Row():
187
  with gr.Column(scale=1):
188
  api_key_input = gr.Textbox(
189
  label="Qwen API Key",
190
- placeholder="Enter your Qwen API key here",
191
  type="password"
192
  )
193
- question_input = gr.Textbox(label="Ask a question about the layout")
 
 
 
 
 
194
  qa_btn = gr.Button("Ask Question")
195
-
196
  with gr.Column(scale=1):
197
- answer_output = gr.Textbox(label="Answer", lines=5)
198
-
199
  # Event handlers
200
  detect_btn.click(
201
  fn=detect_layout,
202
  inputs=[input_image],
203
  outputs=[output_image, layout_info]
204
  )
205
-
206
  qa_btn.click(
207
  fn=qa_about_layout,
208
- inputs=[input_image, question_input, layout_info, api_key_input],
209
  outputs=[answer_output]
210
  )
211
 
 
9
  from openai import OpenAI
10
  from ultralytics import YOLO
11
 
12
+ # Define the Latex2Layout model path
13
  model_path = "latex2layout_object_detection_yolov8.pt"
14
+
15
+ # Verify model file existence
16
  if not os.path.exists(model_path):
17
  raise FileNotFoundError(f"Model file not found at {model_path}")
18
 
19
+ # Load the Latex2Layout model with error handling
20
  try:
21
  model = YOLO(model_path)
22
  except Exception as e:
 
24
 
25
# Qwen API configuration
QWEN_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"

# Maps the human-readable model names shown in the UI dropdown to the
# DashScope model identifiers expected by the API.
QWEN_MODELS = {
    "Qwen2.5-VL-3B-Instruct": "qwen2.5-vl-3b-instruct",
    "Qwen2.5-VL-7B-Instruct": "qwen2.5-vl-7b-instruct",
    "Qwen2.5-VL-14B-Instruct": "qwen2.5-vl-14b-instruct",
}
32
 
33
def encode_image(image_array):
    """
    Convert a numpy array image to a base64-encoded PNG string.

    Args:
        image_array: Numpy array representing the image (as produced by the
            Gradio ``numpy``-typed image component).

    Returns:
        str: Base64-encoded string of the PNG-serialized image.

    Raises:
        ValueError: If the array cannot be converted to an image or
            serialized to PNG.
    """
    try:
        pil_image = Image.fromarray(image_array)
        img_byte_arr = io.BytesIO()
        pil_image.save(img_byte_arr, format='PNG')
        return base64.b64encode(img_byte_arr.getvalue()).decode("utf-8")
    except Exception as e:
        # Chain the original exception (`from e`) so the root cause is
        # preserved in tracebacks instead of being swallowed.
        raise ValueError(f"Failed to encode image: {e}") from e
50
+
51
def detect_layout(image, confidence_threshold=0.5):
    """
    Detect layout elements in the uploaded image using the Latex2Layout model.

    Args:
        image: Uploaded image as a numpy array.
        confidence_threshold: Minimum confidence score to retain detections
            (default: 0.5).

    Returns:
        tuple: (annotated_image, layout_info_str)
            - annotated_image: Copy of the input with bounding boxes drawn for
              detections at or above the threshold, or None on error.
            - layout_info_str: JSON string of the kept detections, or an
              error / "nothing detected" message.
    """
    if image is None or not isinstance(image, np.ndarray):
        return None, "Error: No image uploaded or invalid image format."

    try:
        # Run the YOLO model; the first result corresponds to the single image.
        results = model(image)
        result = results[0]

        annotated_image = image.copy()
        layout_info = []

        # Process detections
        for box in result.boxes:
            conf = float(box.conf[0])
            # Filter out detections below the confidence threshold
            if conf < confidence_threshold:
                continue

            # Extract and convert bounding box coordinates
            x1, y1, x2, y2 = map(int, box.xyxy[0].cpu().numpy())
            cls_id = int(box.cls[0])
            cls_name = result.names[cls_id]

            # Assign a random color per detection for visualization
            color = tuple(np.random.randint(0, 255, 3).tolist())

            # Draw bounding box plus a filled background behind the label text
            cv2.rectangle(annotated_image, (x1, y1), (x2, y2), color, 2)
            label = f"{cls_name} {conf:.2f}"
            (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
            cv2.rectangle(annotated_image, (x1, y1 - label_height - 5), (x1 + label_width, y1), color, -1)
            cv2.putText(annotated_image, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

            # Store layout information
            layout_info.append({
                "bbox": [x1, y1, x2, y2],
                "class": cls_name,
                "confidence": conf
            })

        # Report the actual threshold in use (the previous message
        # hard-coded "0.5" even when a different threshold was passed).
        layout_info_str = (
            json.dumps(layout_info, indent=2)
            if layout_info
            else f"No layout elements detected with confidence >= {confidence_threshold}."
        )
        return annotated_image, layout_info_str

    except Exception as e:
        return None, f"Error during layout detection: {str(e)}"
109
 
110
def qa_about_layout(image, question, layout_info, api_key, model_name):
    """
    Answer layout-related questions using the Qwen API.

    Args:
        image: Uploaded image as a numpy array.
        question: User's question about the layout.
        layout_info: JSON string of layout detection results.
        api_key: User's Qwen API key.
        model_name: Selected Qwen model name from the dropdown.

    Returns:
        str: Qwen's response to the question, or an error message.
    """
    # Validate inputs up front so no work is wasted on a doomed request.
    if image is None or not isinstance(image, np.ndarray):
        return "Error: Please upload a valid image."
    if not question:
        return "Error: Please enter a question."
    if not api_key:
        return "Error: Please provide a Qwen API key."
    if not layout_info:
        return "Error: No layout information available. Detect layout first."

    # Resolve the model ID before the comparatively expensive image
    # encoding, so an invalid selection fails fast.
    model_id = QWEN_MODELS.get(model_name)
    if not model_id:
        return "Error: Invalid Qwen model selected."

    try:
        # Encode image to base64
        base64_image = encode_image(image)

        # Initialize OpenAI client for Qwen API
        client = OpenAI(api_key=api_key, base_url=QWEN_BASE_URL)

        # Construct system prompt with layout info
        system_prompt = f"""You are an assistant specialized in document layout analysis.
The following layout elements were detected in the image (confidence >= 0.5):
{layout_info}

Use this information and the image to answer layout-related questions."""

        # Prepare API request messages
        messages = [
            {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{base64_image}"}},
                    {"type": "text", "text": question},
                ],
            },
        ]

        # Call Qwen API
        completion = client.chat.completions.create(model=model_id, messages=messages)
        return completion.choices[0].message.content

    except Exception as e:
        return f"Error during QA: {str(e)}"
170
 
171
# Build Gradio interface
with gr.Blocks(title="Latex2Layout QA System") as demo:
    gr.Markdown("# Latex2Layout QA System")
    gr.Markdown("Upload an image to detect layout elements and ask questions about the layout using Qwen models.")

    # Row 1: image upload on the left, detection output on the right.
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(label="Upload Image", type="numpy")
            detect_btn = gr.Button("Detect Layout")
            gr.Markdown("**Tip**: Use clear images for best results.")

        with gr.Column(scale=1):
            output_image = gr.Image(label="Detected Layout")
            layout_info = gr.Textbox(label="Layout Information", lines=10, interactive=False)

    # Row 2: API credentials / question on the left, answer on the right.
    with gr.Row():
        with gr.Column(scale=1):
            api_key_input = gr.Textbox(
                label="Qwen API Key",
                placeholder="Enter your Qwen API key",
                type="password"
            )
            model_select = gr.Dropdown(
                label="Select Qwen Model",
                choices=list(QWEN_MODELS.keys()),
                value="Qwen2.5-VL-3B-Instruct"
            )
            question_input = gr.Textbox(label="Ask About the Layout", placeholder="e.g., 'Where is the heading?'")
            qa_btn = gr.Button("Ask Question")

        with gr.Column(scale=1):
            answer_output = gr.Textbox(label="Answer", lines=5, interactive=False)

    # Event handlers: wire the buttons to the detection and QA functions.
    detect_btn.click(
        fn=detect_layout,
        inputs=[input_image],
        outputs=[output_image, layout_info]
    )
    qa_btn.click(
        fn=qa_about_layout,
        inputs=[input_image, question_input, layout_info, api_key_input, model_select],
        outputs=[answer_output]
    )
215