# Hugging Face Spaces status: Running
import gradio as gr | |
import cv2 | |
import numpy as np | |
import os | |
import tempfile | |
from ultralytics import YOLO | |
# Load the YOLOv8 weights once at import time; every request reuses `model`.
model_path = "docgenome_object_detection_yolov8.pt"
model = YOLO(model_path)
def _class_color(cls_id):
    """Return a colour tuple that is always the same for a given class id."""
    # Seeding the generator with the class id makes the colour reproducible,
    # so all boxes of one class share a colour across boxes and across calls.
    rng = np.random.default_rng(cls_id)
    return tuple(rng.integers(0, 255, size=3).tolist())


def _yolo_line(cls_id, x1, y1, x2, y2, img_width, img_height):
    """Format one box as a YOLO line: class id plus normalised centre/size."""
    x_center = (x1 + x2) / (2 * img_width)
    y_center = (y1 + y2) / (2 * img_height)
    width = (x2 - x1) / img_width
    height = (y2 - y1) / img_height
    return f"{cls_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"


def detect_and_visualize(image):
    """Run object detection on an uploaded image and draw the results.

    Args:
        image: The uploaded image as a NumPy array in RGB channel order (as
            delivered by the ``gr.Image(type="numpy")`` input), or ``None``
            when nothing has been uploaded.

    Returns:
        A tuple ``(annotated_image, yolo_annotations)``:
            annotated_image: Copy of ``image`` with one rectangle and one
                ``"<class name> <confidence>"`` label per detection.
            yolo_annotations: Newline-joined YOLO lines of the form
                ``"<class_id> <x_center> <y_center> <width> <height>"``, with
                the four box values normalised to the image size.
        Returns ``(None, "")`` when ``image`` is ``None``.
    """
    # Gradio passes None when the button is clicked without an upload.
    if image is None:
        return None, ""

    # Ultralytics documents NumPy input as BGR, while Gradio delivers RGB, so
    # swap the channels for inference only. Drawing stays on the RGB copy.
    if image.ndim == 3 and image.shape[2] == 3:
        model_input = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    else:
        model_input = image

    # A single image yields a single result.
    result = model(model_input)[0]

    annotated_image = image.copy()
    img_height, img_width = image.shape[:2]
    yolo_annotations = []
    class_colors = {}

    for box in result.boxes:
        # Keep the float corners for the annotation text; integer pixels are
        # only needed for drawing.
        fx1, fy1, fx2, fy2 = (float(v) for v in box.xyxy[0].cpu().numpy())
        x1, y1, x2, y2 = int(fx1), int(fy1), int(fx2), int(fy2)

        conf = float(box.conf[0])
        cls_id = int(box.cls[0])
        cls_name = result.names[cls_id]

        if cls_id not in class_colors:
            class_colors[cls_id] = _class_color(cls_id)
        color = class_colors[cls_id]

        cv2.rectangle(annotated_image, (x1, y1), (x2, y2), color, 2)

        label = f'{cls_name} {conf:.2f}'
        (label_width, label_height), _ = cv2.getTextSize(
            label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
        )

        # The label sits just above the box; when that would fall off the top
        # of the image, it is moved just inside the box instead.
        if y1 - label_height - 5 >= 0:
            label_bottom = y1
        else:
            label_bottom = y1 + label_height + 5
        label_top = label_bottom - label_height - 5

        # Filled background, then white text on top of it.
        cv2.rectangle(
            annotated_image,
            (x1, label_top),
            (x1 + label_width, label_bottom),
            color,
            -1,
        )
        cv2.putText(
            annotated_image,
            label,
            (x1, label_bottom - 5),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            (255, 255, 255),
            1,
        )

        yolo_annotations.append(
            _yolo_line(cls_id, fx1, fy1, fx2, fy2, img_width, img_height)
        )

    yolo_annotations_str = "\n".join(yolo_annotations)
    return annotated_image, yolo_annotations_str
def save_yolo_annotations(yolo_annotations_str):
    """Write YOLO annotations to a temporary ``.txt`` file and return its path.

    Args:
        yolo_annotations_str: Annotation text to write. ``None`` is treated
            as empty text.

    Returns:
        Path of the newly created file. The file is created with
        ``delete=False``, so it stays on disk after this function returns.
    """
    # Write through the handle that created the file and let the context
    # manager close it, rather than leaving it open and reopening by name.
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".txt", delete=False, encoding="utf-8"
    ) as temp_file:
        temp_file.write(yolo_annotations_str or "")
        temp_file_path = temp_file.name
    return temp_file_path
# Build the Gradio interface: one column for the upload and the detect button,
# one column for the annotated image, the YOLO text and the download controls.
with gr.Blocks(title="YOLOv8目标检测可视化") as demo:
    gr.Markdown("# YOLOv8目标检测可视化")
    gr.Markdown("上传图像,使用YOLOv8模型进行目标检测,并下载YOLO格式的标注。")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="上传图像", type="numpy")
            detect_btn = gr.Button("开始检测")

        with gr.Column():
            output_image = gr.Image(label="检测结果")
            yolo_annotations = gr.Textbox(label="YOLO标注", lines=10)
            download_btn = gr.Button("下载YOLO标注")
            download_file = gr.File(label="下载文件")

    # Event wiring: detection fills the image and the text box; the download
    # button turns the current text box content into a file.
    detect_btn.click(
        detect_and_visualize,
        [input_image],
        [output_image, yolo_annotations],
    )
    download_btn.click(
        save_yolo_annotations,
        [yolo_annotations],
        [download_file],
    )

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()