# Hugging Face Space — status: Running on Zero
from typing import Optional | |
import spaces | |
import gradio as gr | |
import numpy as np | |
import torch | |
from PIL import Image | |
import io | |
import base64, os | |
from huggingface_hub import snapshot_download | |
import traceback | |
import warnings | |
import sys | |
# Keep the console readable: drop every FutureWarning, and drop any warning
# whose message mentions ``_supports_sdpa``.
warnings.filterwarnings(action="ignore", category=FutureWarning)
warnings.filterwarnings(action="ignore", message=".*_supports_sdpa.*")
# CRITICAL: Fix Florence2 model before any imports
def fix_florence2_import():
    """Install an import hook that marks Florence2 as not supporting SDPA.

    After this call, any module whose dotted name contains ``florence2``
    (case-insensitive) is imported through its normal loader and then
    post-processed: if it defines ``Florence2ForConditionalGeneration``,
    that class gets a class-level ``_supports_sdpa = False``.

    The hook is inserted at the front of ``sys.meta_path``. Calling this
    function more than once installs the hook only once.
    """
    import importlib.abc

    marker = "_florence2_sdpa_hook"
    if any(getattr(finder, marker, False) for finder in sys.meta_path):
        return  # Already installed; avoid wrapping loaders twice.

    def _patch_module(module):
        """Give the Florence2 model class a safe ``_supports_sdpa`` default."""
        model_cls = getattr(module, "Florence2ForConditionalGeneration", None)
        if model_cls is not None:
            # A class attribute is visible on every instance from the very
            # first line of ``__init__``, so no ``__init__`` wrapper is needed.
            model_cls._supports_sdpa = False

    class Florence2Loader(importlib.abc.Loader):
        """Run the real loader, then patch the freshly executed module."""

        def __init__(self, wrapped):
            self._wrapped = wrapped

        def __getattr__(self, name):
            # Forward optional loader APIs (get_source, is_package, ...) so
            # tooling that inspects ``module.__loader__`` keeps working.
            if name == "_wrapped":
                raise AttributeError(name)
            return getattr(self._wrapped, name)

        def create_module(self, spec):
            return self._wrapped.create_module(spec)

        def exec_module(self, module):
            self._wrapped.exec_module(module)
            _patch_module(module)

    class Florence2ImportHook(importlib.abc.MetaPathFinder):
        """Meta-path finder that decorates the spec found by other finders."""

        _florence2_sdpa_hook = True

        def find_spec(self, fullname, path, target=None):
            if "florence2" not in fullname.lower():
                return None
            # Ask the remaining finders for the real spec. If nobody can find
            # the module we decline too, so the import fails normally instead
            # of yielding an empty module.
            for finder in list(sys.meta_path):
                if finder is self:
                    continue
                delegate = getattr(finder, "find_spec", None)
                if delegate is None:
                    continue
                spec = delegate(fullname, path, target)
                if spec is not None:
                    break
            else:
                return None
            if hasattr(spec.loader, "exec_module"):
                spec.loader = Florence2Loader(spec.loader)
            return spec

    # Install the import hook
    sys.meta_path.insert(0, Florence2ImportHook())
# Install the Florence2 import hook up front, before anything can import the
# model code. A failure here is reported and otherwise ignored.
try:
    fix_florence2_import()
except Exception as hook_error:
    print(f"Warning: Could not apply import hook: {hook_error}")
# Alternative fix: Monkey-patch transformers before importing utils
def monkey_patch_transformers():
    """Patch ``transformers`` so a missing ``_supports_sdpa`` is tolerated.

    Two patches are applied to ``transformers.modeling_utils.PreTrainedModel``:

    * ``_check_and_adjust_attn_implementation`` (only if the class defines
      it) is wrapped so that an ``AttributeError`` mentioning
      ``_supports_sdpa`` yields ``"eager"`` instead of propagating.
    * ``__getattribute__`` is wrapped so that reading ``_supports_sdpa``
      returns ``False`` whenever the normal lookup raises ``AttributeError``.

    Any failure is reported with ``print`` and otherwise swallowed; the
    function never raises. Repeated calls patch the class only once.
    """
    try:
        import transformers.modeling_utils as modeling_utils

        base_cls = modeling_utils.PreTrainedModel
        if getattr(base_cls, "_florence2_sdpa_patched", False):
            return  # Already patched; do not stack wrappers.

        # The method name depends on the installed transformers version, so
        # its absence must not prevent the attribute fallback below.
        original_check = getattr(
            base_cls, "_check_and_adjust_attn_implementation", None
        )
        if original_check is not None:
            def patched_check(self, *args, **kwargs):
                try:
                    return original_check(self, *args, **kwargs)
                except AttributeError as e:
                    if '_supports_sdpa' in str(e):
                        # Return a safe default
                        return "eager"
                    raise

            base_cls._check_and_adjust_attn_implementation = patched_check

        # Also patch the getter
        original_getattr = base_cls.__getattribute__

        def patched_getattr(self, name):
            # Try the real lookup first and only fall back on failure.
            # Probing with ``hasattr`` here would re-enter this very function
            # and recurse until ``RecursionError``.
            try:
                return original_getattr(self, name)
            except AttributeError:
                if name == '_supports_sdpa':
                    return False
                raise

        base_cls.__getattribute__ = patched_getattr
        base_cls._florence2_sdpa_patched = True
        print("Successfully patched transformers for Florence2 compatibility")
    except Exception as e:
        print(f"Warning: Could not patch transformers: {e}")
# Patch transformers first ...
monkey_patch_transformers()

# ... and only then pull in the project helpers.
from util.utils import check_ocr_box, get_yolo_model, get_som_labeled_img

# Fetch the model repository unless the target directory is already present.
repo_id = "microsoft/OmniParser-v2.0"
local_dir = "weights"
if os.path.exists(local_dir):
    print(f"Weights already exist at: {local_dir}")
else:
    snapshot_download(repo_id=repo_id, local_dir=local_dir)
    print(f"Repository downloaded to: {local_dir}")
# Custom function to load caption model with proper error handling
def load_caption_model_safe(model_name="florence2", model_name_or_path="weights/icon_caption"):
    """Load the caption model, falling back to a direct transformers load.

    Method 1 calls ``util.utils.get_caption_model_processor``. If that raises
    an ``AttributeError`` mentioning ``_supports_sdpa``, Method 2 loads the
    processor and model directly through ``transformers``, trying several
    keyword configurations in order until one succeeds.

    Args:
        model_name: Forwarded to ``get_caption_model_processor``.
        model_name_or_path: Model location, used by both methods.

    Returns:
        Whatever ``get_caption_model_processor`` returns (Method 1), or a
        dict ``{'model': ..., 'processor': ...}`` (Method 2).

    Raises:
        AttributeError: Method 1 failed for a reason unrelated to
            ``_supports_sdpa``.
        RuntimeError: No configuration could load the model; ``__cause__``
            holds the last underlying failure.
        Exception: Any other Method 2 error is printed and re-raised.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    try:
        # Method 1: Try the original function with patching
        from util.utils import get_caption_model_processor
        return get_caption_model_processor(model_name, model_name_or_path)
    except AttributeError as e:
        if '_supports_sdpa' not in str(e):
            raise
        print("SDPA error detected, trying alternative loading method...")

    # Method 2: Load directly with specific configuration
    try:
        from transformers import AutoProcessor, AutoModelForCausalLM
        print(f"Loading caption model from {model_name_or_path} with alternative method...")
        # Load processor
        processor = AutoProcessor.from_pretrained(
            model_name_or_path,
            trust_remote_code=True,
            revision="main"
        )
        # Try to load model with different configurations
        # NOTE(review): none of these pins the dtype the Method 1 loader
        # uses — confirm downstream inference accepts the resulting dtype.
        configs_to_try = [
            {"attn_implementation": "eager", "use_cache": False},
            {"use_flash_attention_2": False, "use_cache": False},
            {"torch_dtype": torch.float32},  # Try float32 instead of float16
        ]
        model = None
        last_error = None  # Kept so the final error can name its cause.
        for config in configs_to_try:
            try:
                model = AutoModelForCausalLM.from_pretrained(
                    model_name_or_path,
                    trust_remote_code=True,
                    device_map="auto" if torch.cuda.is_available() else None,
                    **config
                )
                # Ensure the attribute exists
                if not hasattr(model, '_supports_sdpa'):
                    model._supports_sdpa = False
                print(f"Model loaded successfully with config: {config}")
                break
            except Exception as e:
                last_error = e
                print(f"Failed with config {config}: {e}")
                continue
        if model is None:
            raise RuntimeError("Could not load model with any configuration") from last_error
        # Move to device if needed
        if device.type == 'cuda' and not next(model.parameters()).is_cuda:
            model = model.to(device)
        return {'model': model, 'processor': processor}
    except Exception as e:
        print(f"Error in alternative loading: {e}")
        raise
# Load models
# Bind both globals up front so they exist even if loading fails part-way;
# otherwise a YOLO failure would leave ``yolo_model`` undefined.
yolo_model = None
caption_model_processor = None
try:
    print("Loading YOLO model...")
    yolo_model = get_yolo_model(model_path='weights/icon_detect/model.pt')
    print("YOLO model loaded successfully")
    print("Loading caption model...")
    caption_model_processor = load_caption_model_safe()
    print("Caption model loaded successfully")
except Exception as e:
    print(f"Critical error loading models: {e}")
    print(traceback.format_exc())
    caption_model_processor = None
    # Don't raise here, let the UI handle it
# Header shown at the top of the page (Markdown with an embedded HTML banner).
MARKDOWN = """
# OmniParser V2 Proπ₯
<div style="background-color: #f0f8ff; padding: 15px; border-radius: 10px; margin-bottom: 20px;">
<p style="margin: 0;">π― <strong>AI-powered screen understanding tool</strong> that detects UI elements and extracts text with high accuracy.</p>
<p style="margin: 5px 0 0 0;">π Supports both PaddleOCR and EasyOCR for flexible text extraction.</p>
</div>
"""

# Prefer the GPU when torch reports one, otherwise run on the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {DEVICE}")

# Stylesheet handed to gr.Blocks to tweak the default look.
custom_css = """
body { background-color: #f0f2f5; }
.gradio-container { font-family: 'Segoe UI', sans-serif; max-width: 1400px; margin: auto; }
h1, h2, h3, h4 { color: #283E51; }
button { border-radius: 6px; transition: all 0.3s ease; }
button:hover { transform: translateY(-2px); box-shadow: 0 4px 12px rgba(0,0,0,0.15); }
.output-image { border: 2px solid #e1e4e8; border-radius: 8px; }
#input_image { border: 2px dashed #4a90e2; border-radius: 8px; }
#input_image:hover { border-color: #2c5aa0; }
.gr-box { border-radius: 8px; }
.gr-padded { padding: 16px; }
"""
def process(
    image_input,
    box_threshold,
    iou_threshold,
    use_paddleocr,
    imgsz
) -> tuple:
    """Run OCR and element detection on an image and format the results.

    Args:
        image_input: Uploaded image; its ``size[0]`` is read as the width.
            ``None`` means nothing was uploaded.
        box_threshold: Passed to ``get_som_labeled_img`` as ``BOX_TRESHOLD``.
        iou_threshold: Passed to ``get_som_labeled_img``.
        use_paddleocr: Passed to ``check_ocr_box``.
        imgsz: Passed to ``get_som_labeled_img``.

    Returns:
        A ``(image, message)`` pair. On success ``image`` is the annotated
        image decoded from the base64 payload and ``message`` is a Markdown
        listing of the parsed elements. On a detection or decoding failure
        the original ``image_input`` is returned with an error message. For
        missing input, a missing caption model, or any unexpected error,
        ``image`` is ``None``.
    """
    # Input validation
    if image_input is None:
        return None, "⚠️ Please upload an image for processing."
    # Check if caption model is loaded
    if caption_model_processor is None:
        return None, "⚠️ Caption model not loaded. There was an error during initialization. Please check the logs."
    try:
        # Log processing parameters
        print(f"Processing with parameters: box_threshold={box_threshold}, "
              f"iou_threshold={iou_threshold}, use_paddleocr={use_paddleocr}, imgsz={imgsz}")
        # Calculate overlay ratio based on input image width, clamped to
        # [0.5, 2.0] so annotations stay legible on very small/large images.
        image_width = image_input.size[0]
        box_overlay_ratio = max(0.5, min(2.0, image_width / 3200))
        draw_bbox_config = {
            'text_scale': 0.8 * box_overlay_ratio,
            'text_thickness': max(int(2 * box_overlay_ratio), 1),
            'text_padding': max(int(3 * box_overlay_ratio), 1),
            'thickness': max(int(3 * box_overlay_ratio), 1),
        }
        # Run OCR bounding box detection. OCR is best-effort: any failure
        # degrades to "no text found" rather than aborting the request.
        try:
            ocr_bbox_rslt, is_goal_filtered = check_ocr_box(
                image_input,
                display_img=False,
                output_bb_format='xyxy',
                goal_filtering=None,
                easyocr_args={'paragraph': False, 'text_threshold': 0.9},
                use_paddleocr=use_paddleocr
            )
            # Handle None result from OCR
            if ocr_bbox_rslt is None:
                print("OCR returned None, using empty results")
                text, ocr_bbox = [], []
            else:
                text, ocr_bbox = ocr_bbox_rslt
            # Validate OCR results
            if text is None:
                text = []
            if ocr_bbox is None:
                ocr_bbox = []
            print(f"OCR found {len(text)} text regions")
        except Exception as e:
            print(f"OCR error: {e}, continuing with empty OCR results")
            text, ocr_bbox = [], []
        # Get labeled image and parsed content
        try:
            # Ensure the model has the required attribute
            if isinstance(caption_model_processor, dict) and 'model' in caption_model_processor:
                model = caption_model_processor['model']
                if not hasattr(model, '_supports_sdpa'):
                    model._supports_sdpa = False
            dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
                image_input,
                yolo_model,
                BOX_TRESHOLD=box_threshold,
                output_coord_in_ratio=True,
                ocr_bbox=ocr_bbox if ocr_bbox else [],
                draw_bbox_config=draw_bbox_config,
                caption_model_processor=caption_model_processor,
                ocr_text=text if text else [],
                iou_threshold=iou_threshold,
                imgsz=imgsz
            )
            if dino_labled_img is None:
                raise ValueError("Failed to generate labeled image")
        except Exception as e:
            print(f"Error in SOM processing: {e}")
            print(traceback.format_exc())
            return image_input, f"⚠️ Error during element detection: {str(e)}"
        # Decode processed image from base64
        try:
            image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
            print('Successfully decoded processed image')
        except Exception as e:
            print(f"Error decoding image: {e}")
            return image_input, f"⚠️ Error decoding processed image: {str(e)}"
        # Format parsed content list. Normalise a None result to an empty
        # list first, so the element count below cannot raise TypeError and
        # turn a successful run into an "unexpected error".
        elements = parsed_content_list or []
        if elements:
            entries = [
                f"**Icon {i}:** {v}\n"
                for i, v in enumerate(elements)
                if v  # Only add non-empty content
            ]
            parsed_text = "🎯 **Detected Elements:**\n\n" + "".join(entries)
        else:
            parsed_text = "ℹ️ No UI elements detected. Try adjusting the detection thresholds."
        print(f'Finished processing image. Found {len(elements)} elements.')
        return image, parsed_text
    except Exception as e:
        error_msg = f"⚠️ Unexpected error: {str(e)}"
        print(f"Error during processing: {e}")
        print(traceback.format_exc())
        return None, error_msg
# Assemble the Gradio interface: controls on the left, results on the right.
with gr.Blocks(css=custom_css, theme=gr.themes.Soft(), title="OmniParser V2 Pro") as demo:
    gr.Markdown(MARKDOWN)

    # Tell the user up front when the caption model could not be loaded.
    if caption_model_processor is None:
        gr.Markdown("### β οΈ Warning: Caption model failed to load. Some features may not work.")

    with gr.Row():
        # Left column: image upload, detection settings, and usage tips.
        with gr.Column(scale=1):
            with gr.Accordion("π€ Upload Image & Settings", open=True):
                image_input_component = gr.Image(
                    type="pil", label="Upload Screenshot/UI Image", elem_id="input_image"
                )
                gr.Markdown("### ποΈ Detection Settings")
                with gr.Group():
                    box_threshold_component = gr.Slider(
                        label="π Box Threshold",
                        minimum=0.01, maximum=1.0, step=0.01, value=0.05,
                        info="Lower values detect more elements",
                    )
                    iou_threshold_component = gr.Slider(
                        label="π² IOU Threshold",
                        minimum=0.01, maximum=1.0, step=0.01, value=0.1,
                        info="Controls overlap filtering",
                    )
                    use_paddleocr_component = gr.Checkbox(
                        label="π€ Use PaddleOCR",
                        value=True,
                        info="β PaddleOCR | β EasyOCR",
                    )
                    imgsz_component = gr.Slider(
                        label="π Detection Image Size",
                        minimum=640, maximum=1920, step=32, value=640,
                        info="Higher = better accuracy but slower",
                    )
                submit_button_component = gr.Button(
                    value="π Process Image", variant="primary", size="lg"
                )
            gr.Markdown("### π‘ Quick Tips")
            gr.Markdown(
                "\n"
                "- **Mobile apps:** Use default settings\n"
                "- **Desktop apps:** Try image size 1280\n"
                "- **Complex UIs:** Lower box threshold to 0.03\n"
                "- **Too many boxes:** Increase IOU threshold\n"
            )

        # Right column: annotated image and extracted elements, one tab each.
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.Tab("πΌοΈ Annotated Image"):
                    image_output_component = gr.Image(
                        type="pil",
                        label="Processed Image with Annotations",
                        elem_classes=["output-image"],
                    )
                with gr.Tab("π Extracted Elements"):
                    text_output_component = gr.Markdown(
                        value="*Parsed elements will appear here after processing...*",
                        elem_classes=["parsed-text"],
                    )

    # Wire the button to ``process``; the input order matches its parameters.
    submit_button_component.click(
        fn=process,
        inputs=[
            image_input_component,
            box_threshold_component,
            iou_threshold_component,
            use_paddleocr_component,
            imgsz_component,
        ],
        outputs=[image_output_component, text_output_component],
        show_progress=True,
    )
# Script entry point: enable the request queue and start the server.
if __name__ == "__main__":
    try:
        # Explicitly mark both Hugging Face offline switches as "0".
        os.environ.update({'TRANSFORMERS_OFFLINE': '0', 'HF_HUB_OFFLINE': '0'})
        demo.queue(max_size=10)
        demo.launch(share=False, show_error=True, server_name="0.0.0.0", server_port=7860)
    except Exception as launch_error:
        print(f"Failed to launch app: {launch_error}")
        print(traceback.format_exc())