Hanf Chase commited on
Commit
88158ba
·
1 Parent(s): 71e7eab
Files changed (2) hide show
  1. app.py +111 -72
  2. test_yolo.py +0 -58
app.py CHANGED
@@ -5,130 +5,169 @@ import os
5
  import tempfile
6
  from ultralytics import YOLO
7
 
8
- # 加载YOLOv8模型
9
- model_path = "docgenome_object_detection_yolov8.pt"
10
  model = YOLO(model_path)
11
 
12
  def detect_and_visualize(image):
13
  """
14
- 对上传的图像进行目标检测并可视化结果
15
 
16
  Args:
17
- image: 上传的图像
18
 
19
  Returns:
20
- annotated_image: 带有检测框的图像
21
- yolo_annotations: YOLO格式的标注内容
22
  """
23
- # 运行检测
24
- results = model(image)
25
 
26
- # 获取第一帧的结果
 
27
  result = results[0]
28
 
29
- # 创建图像副本用于可视化
30
  annotated_image = image.copy()
 
31
 
32
- # 准备YOLO格式的标注内容
33
- yolo_annotations = []
34
-
35
- # 获取图像尺寸
36
  img_height, img_width = image.shape[:2]
37
 
38
- # 在原图上绘制检测结果
39
  for box in result.boxes:
40
- # 获取边界框坐标
41
- x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
42
- x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
43
-
44
- # 获取置信度
45
  conf = float(box.conf[0])
46
-
47
- # 获取类别ID和名称
48
  cls_id = int(box.cls[0])
49
  cls_name = result.names[cls_id]
50
 
51
- # 为每个类别生成不同的颜色
52
  color = tuple(np.random.randint(0, 255, 3).tolist())
53
 
54
- # 绘制边界框
55
  cv2.rectangle(annotated_image, (x1, y1), (x2, y2), color, 2)
56
-
57
- # 准备标签文本
58
  label = f'{cls_name} {conf:.2f}'
59
-
60
- # 计算标签大小
61
  (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
62
-
63
- # 绘制标签背景
64
  cv2.rectangle(annotated_image, (x1, y1-label_height-5), (x1+label_width, y1), color, -1)
65
-
66
- # 绘制标签文本
67
  cv2.putText(annotated_image, label, (x1, y1-5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
68
 
69
- # 转换为YOLO格式 (x_center, y_center, width, height) 归一化到0-1
70
  x_center = (x1 + x2) / (2 * img_width)
71
  y_center = (y1 + y2) / (2 * img_height)
72
  width = (x2 - x1) / img_width
73
  height = (y2 - y1) / img_height
74
-
75
- # 添加到YOLO标注列表
76
- yolo_annotations.append(f"{cls_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}")
77
 
78
- # 将YOLO标注转换为字符串
79
- yolo_annotations_str = "\n".join(yolo_annotations)
80
-
81
- return annotated_image, yolo_annotations_str
82
 
83
- def save_yolo_annotations(yolo_annotations_str):
84
  """
85
- 保存YOLO标注到临时文件并返回文件路径
86
 
87
  Args:
88
- yolo_annotations_str: YOLO格式的标注字符串
89
 
90
  Returns:
91
- file_path: 保存的标注文件路径
92
  """
93
- # 创建临时文件
94
- temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".txt")
95
- temp_file_path = temp_file.name
96
-
97
- # 写入标注内容
98
- with open(temp_file_path, "w") as f:
99
- f.write(yolo_annotations_str)
100
 
101
- return temp_file_path
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
- # 创建Gradio界面
104
- with gr.Blocks(title="YOLOv8目标检测可视化") as demo:
105
- gr.Markdown("# YOLOv8目标检测可视化")
106
- gr.Markdown("上传图像,使用YOLOv8模型进行目标检测���并下载YOLO格式的标注。")
 
 
 
 
 
 
 
 
 
107
 
 
108
  with gr.Row():
109
- with gr.Column():
110
- input_image = gr.Image(label="上传图像", type="numpy")
111
- detect_btn = gr.Button("开始检测")
 
 
 
 
 
 
 
 
 
 
 
112
 
113
- with gr.Column():
114
- output_image = gr.Image(label="检测结果")
115
- yolo_annotations = gr.Textbox(label="YOLO标注", lines=10)
116
- download_btn = gr.Button("下载YOLO标注")
117
- download_file = gr.File(label="下载文件")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
- # 设置点击事件
120
  detect_btn.click(
121
  fn=detect_and_visualize,
122
- inputs=[input_image],
123
- outputs=[output_image, yolo_annotations]
 
 
 
 
 
 
124
  )
125
 
126
  download_btn.click(
127
- fn=save_yolo_annotations,
128
- inputs=[yolo_annotations],
129
- outputs=[download_file]
130
  )
131
 
132
- # 启动应用
 
133
  if __name__ == "__main__":
134
- demo.launch()
 
5
  import tempfile
6
  from ultralytics import YOLO
7
 
8
+ # Load the Latex2Layout model
9
+ model_path = "latex2layout_object_detection_yolov8.pt"
10
  model = YOLO(model_path)
11
 
12
  def detect_and_visualize(image):
13
  """
14
+ Perform layout detection on the uploaded image using the Latex2Layout model and visualize the results.
15
 
16
  Args:
17
+ image: The uploaded image
18
 
19
  Returns:
20
+ annotated_image: Image with detection boxes
21
+ layout_annotations: Annotations in YOLO format
22
  """
23
+ if image is None:
24
+ return None, "Error: No image uploaded."
25
 
26
+ # Run detection using the Latex2Layout model
27
+ results = model(image)
28
  result = results[0]
29
 
30
+ # Create a copy of the image for visualization
31
  annotated_image = image.copy()
32
+ layout_annotations = []
33
 
34
+ # Get image dimensions
 
 
 
35
  img_height, img_width = image.shape[:2]
36
 
37
+ # Draw detection results
38
  for box in result.boxes:
39
+ x1, y1, x2, y2 = map(int, box.xyxy[0].cpu().numpy())
 
 
 
 
40
  conf = float(box.conf[0])
 
 
41
  cls_id = int(box.cls[0])
42
  cls_name = result.names[cls_id]
43
 
44
+ # Generate a color for each class
45
  color = tuple(np.random.randint(0, 255, 3).tolist())
46
 
47
+ # Draw bounding box and label
48
  cv2.rectangle(annotated_image, (x1, y1), (x2, y2), color, 2)
 
 
49
  label = f'{cls_name} {conf:.2f}'
 
 
50
  (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
 
 
51
  cv2.rectangle(annotated_image, (x1, y1-label_height-5), (x1+label_width, y1), color, -1)
 
 
52
  cv2.putText(annotated_image, label, (x1, y1-5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
53
 
54
+ # Convert to YOLO format (normalized)
55
  x_center = (x1 + x2) / (2 * img_width)
56
  y_center = (y1 + y2) / (2 * img_height)
57
  width = (x2 - x1) / img_width
58
  height = (y2 - y1) / img_height
59
+ layout_annotations.append(f"{cls_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}")
 
 
60
 
61
+ return annotated_image, "\n".join(layout_annotations)
 
 
 
62
 
63
+ def save_layout_annotations(layout_annotations_str):
64
  """
65
+ Save layout annotations to a temporary file and return the file path.
66
 
67
  Args:
68
+ layout_annotations_str: Annotations string in YOLO format
69
 
70
  Returns:
71
+ file_path: Path to the saved annotation file
72
  """
73
+ if not layout_annotations_str:
74
+ return None
 
 
 
 
 
75
 
76
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".txt")
77
+ with open(temp_file.name, "w") as f:
78
+ f.write(layout_annotations_str)
79
+ return temp_file.name
80
+
81
+ # Custom CSS for styling
82
+ custom_css = """
83
+ .container { max-width: 1200px; margin: auto; }
84
+ .button-primary { background-color: #4CAF50; color: white; }
85
+ .button-secondary { background-color: #008CBA; color: white; }
86
+ .gr-image { border: 2px solid #ddd; border-radius: 5px; }
87
+ .gr-textbox { font-family: monospace; }
88
+ """
89
 
90
+ # Create Gradio interface with enhanced styling
91
+ with gr.Blocks(
92
+ title="Latex2Layout Detection",
93
+ theme=gr.themes.Default(),
94
+ css=custom_css
95
+ ) as demo:
96
+ # Header with instructions
97
+ gr.Markdown(
98
+ """
99
+ # Latex2Layout Layout Detection
100
+ Upload an image to detect layout elements using the **Latex2Layout** model. View the annotated image and download the results in YOLO format.
101
+ """
102
+ )
103
 
104
+ # Main layout with two columns
105
  with gr.Row():
106
+ # Input column
107
+ with gr.Column(scale=1):
108
+ input_image = gr.Image(
109
+ label="Upload Image",
110
+ type="numpy",
111
+ height=400,
112
+ elem_classes="gr-image"
113
+ )
114
+ detect_btn = gr.Button(
115
+ "Start Detection",
116
+ variant="primary",
117
+ elem_classes="button-primary"
118
+ )
119
+ gr.Markdown("**Tip**: Upload a clear image for optimal detection results.")
120
 
121
+ # Output column
122
+ with gr.Column(scale=1):
123
+ output_image = gr.Image(
124
+ label="Detection Results",
125
+ height=400,
126
+ elem_classes="gr-image"
127
+ )
128
+ layout_annotations = gr.Textbox(
129
+ label="Layout Annotations (YOLO Format)",
130
+ lines=10,
131
+ max_lines=15,
132
+ elem_classes="gr-textbox"
133
+ )
134
+ download_btn = gr.Button(
135
+ "Download Annotations",
136
+ variant="secondary",
137
+ elem_classes="button-secondary"
138
+ )
139
+ download_file = gr.File(
140
+ label="Download File",
141
+ interactive=False
142
+ )
143
+
144
+ # Example image button (optional)
145
+ with gr.Row():
146
+ gr.Button("Load Example Image").click(
147
+ fn=lambda: cv2.imread("example_image.jpg"),
148
+ outputs=input_image
149
+ )
150
 
151
+ # Event handlers
152
  detect_btn.click(
153
  fn=detect_and_visualize,
154
+ inputs=input_image,
155
+ outputs=[output_image, layout_annotations],
156
+ _js="() => { document.querySelector('.button-primary').innerText = 'Processing...'; }",
157
+ show_progress=True
158
+ ).then(
159
+ fn=lambda: gr.update(value="Start Detection"),
160
+ outputs=detect_btn,
161
+ _js="() => { document.querySelector('.button-primary').innerText = 'Start Detection'; }"
162
  )
163
 
164
  download_btn.click(
165
+ fn=save_layout_annotations,
166
+ inputs=layout_annotations,
167
+ outputs=download_file
168
  )
169
 
170
+
171
+ # Launch the application
172
  if __name__ == "__main__":
173
+ demo.launch()
test_yolo.py DELETED
@@ -1,58 +0,0 @@
1
- from ultralytics import YOLO
2
- import cv2
3
- import numpy as np
4
-
5
- def detect_and_visualize(image_path, model_path):
6
- # 加载YOLOv8模型
7
- model = YOLO(model_path) # 例如 'yolov8n.pt', 'yolov8s.pt' 等
8
-
9
- # 读取图片
10
- image = cv2.imread(image_path)
11
-
12
- # 运行检测
13
- results = model(image)
14
-
15
- # 获取第一帧的结果
16
- result = results[0]
17
-
18
- # 在原图上绘制检测结果
19
- for box in result.boxes:
20
- # 获取边界框坐标
21
- x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
22
- x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
23
-
24
- # 获取置信度
25
- conf = float(box.conf[0])
26
-
27
- # 获取类别ID和名称
28
- cls_id = int(box.cls[0])
29
- cls_name = result.names[cls_id]
30
-
31
- # 为每个类别生成不同的颜色
32
- color = tuple(np.random.randint(0, 255, 3).tolist())
33
-
34
- # 绘制边界框
35
- cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
36
-
37
- # 准备标签文本
38
- label = f'{cls_name} {conf:.2f}'
39
-
40
- # 计算标签大小
41
- (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
42
-
43
- # 绘制标签背景
44
- cv2.rectangle(image, (x1, y1-label_height-5), (x1+label_width, y1), color, -1)
45
-
46
- # 绘制标签文本
47
- cv2.putText(image, label, (x1, y1-5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
48
-
49
- # 保存结果图片
50
- output_path = 'output_detected.jpg'
51
- cv2.imwrite(output_path, image)
52
- print(f"检测结果已保存至: {output_path}")
53
-
54
- # 使用示例
55
- if __name__ == "__main__":
56
- image_path = "./test_math.png" # 替换为你的图片路径
57
- model_path = "docgenome_object_detection_yolov8.pt" # 替换为你的模型权重路径
58
- detect_and_visualize(image_path, model_path)