ChaseHan committed on
Commit
f56f5ba
·
verified ·
1 Parent(s): 74651af

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -28
app.py CHANGED
@@ -2,15 +2,22 @@ import gradio as gr
2
  import cv2
3
  import numpy as np
4
  import os
5
- import requests
6
  import json
7
  from PIL import Image
8
  import io
9
  import base64
10
  from openai import OpenAI
 
11
 
12
- # API endpoints
13
- YOLO_API_ENDPOINT = "https://api.example.com/yolo" # Replace with actual YOLO API endpoint
 
 
 
 
 
 
 
14
 
15
  # Qwen API configuration
16
  QWEN_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
@@ -39,7 +46,7 @@ def encode_image(image_array):
39
 
40
  def detect_layout(image):
41
  """
42
- Perform layout detection on the uploaded image using YOLO API.
43
 
44
  Args:
45
  image: The uploaded image as a numpy array
@@ -51,46 +58,44 @@ def detect_layout(image):
51
  if image is None:
52
  return None, "Error: No image uploaded."
53
 
54
- # Convert numpy array to PIL Image
55
- pil_image = Image.fromarray(image)
56
-
57
- # Convert PIL Image to bytes for API request
58
- img_byte_arr = io.BytesIO()
59
- pil_image.save(img_byte_arr, format='PNG')
60
- img_byte_arr = img_byte_arr.getvalue()
61
-
62
- # Prepare API request
63
- files = {'image': ('image.png', img_byte_arr, 'image/png')}
64
-
65
  try:
66
- # Call YOLO API
67
- response = requests.post(YOLO_API_ENDPOINT, files=files)
68
- response.raise_for_status()
69
- detection_results = response.json()
70
 
71
  # Create a copy of the image for visualization
72
  annotated_image = image.copy()
 
73
 
74
  # Draw detection results
75
- for detection in detection_results:
76
- x1, y1, x2, y2 = detection['bbox']
77
- cls_name = detection['class']
78
- conf = detection['confidence']
 
 
79
 
80
  # Generate a color for each class
81
  color = tuple(np.random.randint(0, 255, 3).tolist())
82
 
83
  # Draw bounding box and label
84
- cv2.rectangle(annotated_image, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
85
  label = f'{cls_name} {conf:.2f}'
86
  (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
87
- cv2.rectangle(annotated_image, (int(x1), int(y1)-label_height-5), (int(x1)+label_width, int(y1)), color, -1)
88
- cv2.putText(annotated_image, label, (int(x1), int(y1)-5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
 
 
 
 
 
 
 
89
 
90
  # Format layout information for Qwen
91
- layout_info = json.dumps(detection_results, indent=2)
92
 
93
- return annotated_image, layout_info
94
 
95
  except Exception as e:
96
  return None, f"Error during layout detection: {str(e)}"
 
2
  import cv2
3
  import numpy as np
4
  import os
 
5
  import json
6
  from PIL import Image
7
  import io
8
  import base64
9
  from openai import OpenAI
10
+ from ultralytics import YOLO
11
 
12
+ # Load the Latex2Layout model
13
+ model_path = "latex2layout_object_detection_yolov8.pt"
14
+ if not os.path.exists(model_path):
15
+ raise FileNotFoundError(f"Model file not found at {model_path}")
16
+
17
+ try:
18
+ model = YOLO(model_path)
19
+ except Exception as e:
20
+ raise RuntimeError(f"Failed to load Latex2Layout model: {e}")
21
 
22
  # Qwen API configuration
23
  QWEN_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
 
46
 
47
  def detect_layout(image):
48
  """
49
+ Perform layout detection on the uploaded image using the local YOLO model.
50
 
51
  Args:
52
  image: The uploaded image as a numpy array
 
58
  if image is None:
59
  return None, "Error: No image uploaded."
60
 
 
 
 
 
 
 
 
 
 
 
 
61
  try:
62
+ # Run detection using local YOLO model
63
+ results = model(image)
64
+ result = results[0]
 
65
 
66
  # Create a copy of the image for visualization
67
  annotated_image = image.copy()
68
+ layout_info = []
69
 
70
  # Draw detection results
71
+ for box in result.boxes:
72
+ # Get bounding box coordinates
73
+ x1, y1, x2, y2 = map(int, box.xyxy[0].cpu().numpy())
74
+ conf = float(box.conf[0])
75
+ cls_id = int(box.cls[0])
76
+ cls_name = result.names[cls_id]
77
 
78
  # Generate a color for each class
79
  color = tuple(np.random.randint(0, 255, 3).tolist())
80
 
81
  # Draw bounding box and label
82
+ cv2.rectangle(annotated_image, (x1, y1), (x2, y2), color, 2)
83
  label = f'{cls_name} {conf:.2f}'
84
  (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
85
+ cv2.rectangle(annotated_image, (x1, y1-label_height-5), (x1+label_width, y1), color, -1)
86
+ cv2.putText(annotated_image, label, (x1, y1-5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
87
+
88
+ # Add detection to layout info
89
+ layout_info.append({
90
+ 'bbox': [x1, y1, x2, y2],
91
+ 'class': cls_name,
92
+ 'confidence': conf
93
+ })
94
 
95
  # Format layout information for Qwen
96
+ layout_info_str = json.dumps(layout_info, indent=2)
97
 
98
+ return annotated_image, layout_info_str
99
 
100
  except Exception as e:
101
  return None, f"Error during layout detection: {str(e)}"