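# app.py -- uploaded by gaoyu1314 (commit message "Upload app.py", revision 3ef5de0, verified).
# File size: 4.74 kB. (The "raw" / "history" / "blame" labels were file-viewer links, not part of the program.)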
import ast
import base64, os
import io
from typing import Optional
from typing import Tuple

import gradio as gr
import numpy as np
import torch
from PIL import Image

from util.utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img
# Model locations, given relative to the working directory. os.path.join is
# used so the separators are correct on every platform.
MODEL_DIR = 'weights'
YOLO_MODEL_PATH = os.path.join(MODEL_DIR, 'icon_detect', 'model.pt')
CAPTION_MODEL_PATH = os.path.join(MODEL_DIR, 'icon_caption')
# Alternative caption weights, if the BLIP2 model is used instead:
# BLIP2_CAPTION_MODEL_PATH = os.path.join(MODEL_DIR, 'icon_caption_blip2')

# Both models are loaded once at import time and shared by every request.
# The detector path comes from YOLO_MODEL_PATH rather than a second,
# hard-coded copy of the same string, so the two cannot drift apart.
yolo_model = get_yolo_model(model_path=YOLO_MODEL_PATH)
caption_model_processor = get_caption_model_processor(model_name="ollama", model_name_or_path=CAPTION_MODEL_PATH)
# BLIP2 alternative:
# caption_model_processor = get_caption_model_processor(model_name="blip2", model_name_or_path="weights/icon_caption_blip2")

# Header rendered at the top of the demo page.
MARKDOWN = """
# OmniParser for Pure Vision Based General GUI Agent嘻嘻 🔥
<div>
<a href="https://arxiv.org/pdf/2408.00203">
<img src="https://img.shields.io/badge/arXiv-2408.00203-b31b1b.svg" alt="Arxiv" style="display:inline-block;">
</a>
</div>
OmniParser is a screen parsing tool to convert general GUI screen to structured elements.
"""

# NOTE(review): DEVICE is not referenced anywhere else in the visible code and
# is hard-wired to CUDA with no CPU fallback -- confirm whether it is still needed.
DEVICE = torch.device('cuda')
def _parse_ocr_boxes(texts, raw_boxes):
    """Return ``(ocr_text, ocr_bbox)`` with unusable boxes filtered out.

    Each entry of ``raw_boxes`` that is a string is parsed as a Python
    literal; entries that are not strings are used unchanged. A string that
    cannot be parsed is reported and skipped, and -- when ``texts`` holds
    exactly one item per box -- the matching text is dropped too, so the two
    returned sequences stay index-aligned.
    """
    if not raw_boxes:
        return [], []

    # Texts can only be filtered in lockstep when there is one per box.
    one_text_per_box = (
        isinstance(texts, (list, tuple)) and len(texts) == len(raw_boxes)
    )

    boxes = []
    kept_indices = []
    for index, raw in enumerate(raw_boxes):
        if isinstance(raw, str):
            try:
                # SECURITY: this used to be eval(), which executes arbitrary
                # expressions found in the OCR output. literal_eval accepts
                # plain literals only (tuples, lists, numbers, ...).
                box = ast.literal_eval(raw)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                # Warn and carry on with the remaining boxes.
                print(f"警告:无法解析边界框字符串:{raw}")
                continue
        else:
            # Not a string, so there is nothing to parse.
            box = raw
        boxes.append(box)
        kept_indices.append(index)

    if one_text_per_box and len(kept_indices) != len(raw_boxes):
        texts = [texts[i] for i in kept_indices]
    return texts, boxes


def process(
    image_input,
    box_threshold,
    iou_threshold,
    use_paddleocr,
    imgsz
) -> Tuple[Image.Image, str]:
    """Run OCR and icon detection on an uploaded screenshot.

    The image is written to ``imgs/saved_image_demo.png``, passed through
    ``check_ocr_box`` and then ``get_som_labeled_img``; the labelled image that
    comes back (base64 encoded) is decoded for display.

    Args:
        image_input: Uploaded image; must provide ``save()`` and ``size``
            (the UI supplies a PIL image).
        box_threshold: Forwarded to ``get_som_labeled_img`` as ``BOX_TRESHOLD``.
        iou_threshold: Forwarded to ``get_som_labeled_img`` as ``iou_threshold``.
        use_paddleocr: Forwarded to ``check_ocr_box`` as ``use_paddleocr``.
        imgsz: Forwarded to ``get_som_labeled_img`` as ``imgsz``.

    Returns:
        A ``(image, text)`` pair: the decoded labelled image and a string with
        one ``icon <i>: <element>`` line per parsed element.

    Raises:
        gr.Error: If no image was supplied.
    """
    if image_input is None:
        # Submitting without an upload used to crash with AttributeError.
        raise gr.Error('Please upload an image before submitting.')

    image_save_path = 'imgs/saved_image_demo.png'
    # The target directory is not guaranteed to exist on a fresh checkout.
    os.makedirs(os.path.dirname(image_save_path), exist_ok=True)
    image_input.save(image_save_path)

    # Scale the overlay drawing parameters with the image width (3200 px == 1.0).
    # The width is read from the upload itself instead of re-opening the saved
    # file, which previously left an unclosed file handle behind.
    box_overlay_ratio = image_input.size[0] / 3200
    draw_bbox_config = {
        'text_scale': 0.8 * box_overlay_ratio,
        'text_thickness': max(int(2 * box_overlay_ratio), 1),
        'text_padding': max(int(3 * box_overlay_ratio), 1),
        'thickness': max(int(3 * box_overlay_ratio), 1),
    }

    ocr_bbox_rslt, _ = check_ocr_box(
        image_save_path,
        display_img=False,
        output_bb_format='xyxy',
        goal_filtering=None,
        easyocr_args={'paragraph': False, 'text_threshold': 0.9},
        use_paddleocr=use_paddleocr,
    )
    text, ocr_bbox_input = ocr_bbox_rslt
    ocr_text, ocr_bbox = _parse_ocr_boxes(text, ocr_bbox_input)

    # The second return value (label coordinates) is not used by the demo.
    dino_labled_img, _, parsed_content_list = get_som_labeled_img(
        image_save_path,
        yolo_model,
        BOX_TRESHOLD=box_threshold,
        output_coord_in_ratio=True,
        ocr_bbox=ocr_bbox,
        draw_bbox_config=draw_bbox_config,
        caption_model_processor=caption_model_processor,
        ocr_text=ocr_text,
        iou_threshold=iou_threshold,
        imgsz=imgsz,
    )
    image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
    print('finish processing')

    parsed_content = '\n'.join(
        f'icon {i}: ' + str(v) for i, v in enumerate(parsed_content_list)
    )
    return image, parsed_content
# Build the demo page: the first column holds the inputs and the submit
# button, the second column holds the two outputs.
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)

    with gr.Row():
        with gr.Column():
            image_input_component = gr.Image(
                type='pil',
                label='Upload image',
            )
            box_threshold_component = gr.Slider(
                label='Box Threshold',
                minimum=0.01,
                maximum=1.0,
                step=0.01,
                value=0.05,
            )
            iou_threshold_component = gr.Slider(
                label='IOU Threshold',
                minimum=0.01,
                maximum=1.0,
                step=0.01,
                value=0.1,
            )
            use_paddleocr_component = gr.Checkbox(
                label='Use PaddleOCR',
                value=True,
            )
            imgsz_component = gr.Slider(
                label='Icon Detect Image Size',
                minimum=640,
                maximum=1920,
                step=32,
                value=640,
            )
            submit_button_component = gr.Button(
                value='Submit',
                variant='primary',
            )

        with gr.Column():
            image_output_component = gr.Image(
                type='pil',
                label='Image Output',
            )
            text_output_component = gr.Textbox(
                label='Parsed screen elements',
                placeholder='Text Output',
            )

    # The inputs are listed in the same order as the parameters of process().
    submit_button_component.click(
        fn=process,
        inputs=[
            image_input_component,
            box_threshold_component,
            iou_threshold_component,
            use_paddleocr_component,
            imgsz_component,
        ],
        outputs=[
            image_output_component,
            text_output_component,
        ],
    )

# Start the server on port 7861, bound to 0.0.0.0, with share=True.
demo.launch(
    share=True,
    server_port=7861,
    server_name='0.0.0.0',
)