File size: 17,920 Bytes
b20c0ea
0ef105d
b20c0ea
 
 
 
 
 
 
c2f47fd
2187315
a3e6550
2187315
 
 
 
b20c0ea
a3e6550
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1974dbd
 
a3e6550
 
b20c0ea
c2f47fd
 
 
 
 
b20c0ea
a3e6550
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2187315
a3e6550
2187315
a3e6550
 
 
 
 
 
 
 
2187315
a3e6550
 
 
 
 
 
2187315
a3e6550
 
 
 
 
 
 
 
 
2187315
a3e6550
 
 
 
 
 
 
 
 
 
 
 
 
2187315
a3e6550
 
 
 
 
2187315
 
a3e6550
 
2187315
a3e6550
c2f47fd
2187315
c2f47fd
2187315
 
 
a3e6550
 
 
c2f47fd
2187315
 
 
a3e6550
b20c0ea
1974dbd
b20c0ea
a499933
c2f47fd
 
 
 
 
3934656
b20c0ea
c2f47fd
 
b20c0ea
1974dbd
 
 
c2f47fd
1974dbd
c2f47fd
 
 
 
 
 
 
1974dbd
 
8b5dff8
ef66abf
b20c0ea
 
 
 
 
 
c2f47fd
 
 
 
1974dbd
c2f47fd
1974dbd
2187315
 
a3e6550
2187315
1974dbd
c2f47fd
 
 
 
1974dbd
c2f47fd
a3e6550
c2f47fd
1974dbd
 
 
 
 
 
 
a3e6550
c2f47fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1974dbd
a3e6550
c2f47fd
a3e6550
2187315
 
 
 
 
c2f47fd
 
 
 
 
a3e6550
c2f47fd
 
a3e6550
c2f47fd
 
 
 
 
 
 
 
 
a3e6550
c2f47fd
1974dbd
 
c2f47fd
 
 
 
 
 
1974dbd
a3e6550
c2f47fd
 
 
 
 
 
 
 
 
1974dbd
c2f47fd
1974dbd
c2f47fd
 
 
 
1974dbd
a3e6550
c2f47fd
b20c0ea
1974dbd
a3e6550
 
 
 
b20c0ea
c2f47fd
1974dbd
c2f47fd
1974dbd
 
c2f47fd
1974dbd
 
c2f47fd
 
 
 
 
 
 
 
 
 
a3e6550
c2f47fd
 
 
 
 
 
 
 
a3e6550
c2f47fd
 
 
 
 
a3e6550
c2f47fd
 
 
 
 
 
 
 
a3e6550
c2f47fd
 
1974dbd
c2f47fd
 
 
1974dbd
c2f47fd
 
 
a3e6550
 
 
 
c2f47fd
1974dbd
c2f47fd
1974dbd
 
c2f47fd
b7029f7
c2f47fd
 
 
b7029f7
c2f47fd
 
 
 
 
b7029f7
1974dbd
a3e6550
b20c0ea
a3e6550
b20c0ea
 
 
 
 
 
 
c2f47fd
 
b20c0ea
 
a3e6550
c2f47fd
 
a3e6550
2187315
 
 
c2f47fd
 
 
 
 
 
 
 
 
a3e6550
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
from typing import Optional
import spaces
import gradio as gr
import numpy as np
import torch
from PIL import Image
import io
import base64, os
from huggingface_hub import snapshot_download
import traceback
import warnings
import sys

# Silence two kinds of warning noise at start-up: every FutureWarning, and
# any warning whose message mentions ``_supports_sdpa``.
warnings.filterwarnings(action="ignore", category=FutureWarning)
warnings.filterwarnings(action="ignore", message=".*_supports_sdpa.*")

# CRITICAL: Fix Florence2 model before any imports
def fix_florence2_import():
    """Install an import hook that patches Florence2's ``_supports_sdpa``.

    The hook reacts only to modules whose last dotted component is
    ``modeling_florence2``. It asks the remaining ``sys.meta_path`` finders to
    locate the module as usual and wraps the loader they return. After the
    module body has executed, the wrapper does two things:

    * sets ``Florence2ForConditionalGeneration._supports_sdpa = False``;
    * wraps the class ``__init__`` so each instance gets the attribute before
      the original initializer runs.

    Other modules, such as ``configuration_florence2``, are never intercepted.
    A module that no other finder can locate still raises
    ``ModuleNotFoundError``. Calling this function more than once installs
    only one hook.

    Note: only imports that go through ``sys.meta_path`` are affected. Code
    that executes a module file directly (``spec_from_file_location`` followed
    by ``exec_module``) bypasses the hook.
    """
    import importlib.abc

    # Idempotence: never stack a second hook on top of an existing one.
    if any(getattr(finder, "_florence2_sdpa_hook", False) for finder in sys.meta_path):
        return

    def _patch_module(module):
        """Force ``_supports_sdpa = False`` on the Florence2 model class, if present."""
        model_cls = getattr(module, 'Florence2ForConditionalGeneration', None)
        if model_cls is None:
            return
        original_init = model_cls.__init__

        def patched_init(self, config, *args, **kwargs):
            # Add the attribute before the original initializer runs, in case
            # anything it calls reads the flag.
            self._supports_sdpa = False
            original_init(self, config, *args, **kwargs)

        model_cls.__init__ = patched_init
        # Replace any class-level descriptor with a plain value, so the
        # instance assignment above cannot hit a read-only property.
        model_cls._supports_sdpa = False

    class Florence2Loader(importlib.abc.Loader):
        """Delegate to the real loader, then patch the freshly executed module."""

        def __init__(self, wrapped):
            self._wrapped = wrapped

        def create_module(self, spec):
            create = getattr(self._wrapped, 'create_module', None)
            return create(spec) if create is not None else None

        def exec_module(self, module):
            self._wrapped.exec_module(module)
            _patch_module(module)

        def __getattr__(self, name):
            # Forward optional loader APIs (get_source, get_filename, ...).
            # Guard '_wrapped' so a half-initialized wrapper cannot recurse.
            if name == '_wrapped':
                raise AttributeError(name)
            return getattr(self._wrapped, name)

    class Florence2ImportHook(importlib.abc.MetaPathFinder):
        """Meta-path finder that wraps the loader of ``*.modeling_florence2``."""

        _florence2_sdpa_hook = True

        def find_spec(self, fullname, path, target=None):
            if fullname.rpartition('.')[2] != 'modeling_florence2':
                return None
            for finder in list(sys.meta_path):
                # Skip this hook (and any duplicate) to avoid recursing.
                if getattr(finder, "_florence2_sdpa_hook", False):
                    continue
                finder_find_spec = getattr(finder, 'find_spec', None)
                if finder_find_spec is None:
                    continue
                spec = finder_find_spec(fullname, path, target)
                if spec is not None:
                    break
            else:
                # Nobody can find it: let the import system raise as normal.
                return None
            if spec.loader is None or not hasattr(spec.loader, 'exec_module'):
                return spec
            spec.loader = Florence2Loader(spec.loader)
            return spec

    # Install the import hook ahead of the default finders.
    sys.meta_path.insert(0, Florence2ImportHook())

# Apply the fix before any model imports
try:
    fix_florence2_import()
except Exception as e:
    print(f"Warning: Could not apply import hook: {e}")

# Alternative fix: Monkey-patch transformers before importing utils
def monkey_patch_transformers():
    """Patch ``transformers.PreTrainedModel`` to tolerate a missing ``_supports_sdpa``.

    Two patches are applied:

    * ``_check_and_adjust_attn_implementation`` is wrapped. It sets
      ``_supports_sdpa = False`` on instances that lack the attribute. If the
      original raises an ``AttributeError`` mentioning ``_supports_sdpa``, it
      returns ``"eager"`` instead.
    * ``__getattribute__`` is wrapped so that reading ``_supports_sdpa``
      returns ``False`` instead of raising ``AttributeError``. Every other
      attribute lookup is passed straight to the original implementation.

    Any failure (e.g. ``transformers`` missing, or a version without
    ``_check_and_adjust_attn_implementation``) is reported as a printed
    warning, not raised.
    """
    import functools

    try:
        import transformers.modeling_utils as modeling_utils

        model_cls = modeling_utils.PreTrainedModel
        original_check = model_cls._check_and_adjust_attn_implementation

        @functools.wraps(original_check)
        def patched_check(self, *args, **kwargs):
            # Add the attribute if missing
            if not hasattr(self, '_supports_sdpa'):
                self._supports_sdpa = False
            try:
                return original_check(self, *args, **kwargs)
            except AttributeError as e:
                if '_supports_sdpa' in str(e):
                    # Return a safe default
                    return "eager"
                raise

        model_cls._check_and_adjust_attn_implementation = patched_check

        # Also patch the getter
        original_getattr = model_cls.__getattribute__

        def patched_getattr(self, name):
            # Fast path: leave every other attribute untouched.
            if name != '_supports_sdpa':
                return original_getattr(self, name)
            # Do not use hasattr() here: it would call back into this method
            # and recurse without bound.
            try:
                return original_getattr(self, name)
            except AttributeError:
                return False

        model_cls.__getattribute__ = patched_getattr

        print("Successfully patched transformers for Florence2 compatibility")

    except Exception as e:
        print(f"Warning: Could not patch transformers: {e}")

# Apply the monkey patch
monkey_patch_transformers()

# Now import the utils after patching
from util.utils import check_ocr_box, get_yolo_model, get_som_labeled_img

# Fetch the OmniParser weights into ./weights unless that directory is
# already present (its contents are not checked).
repo_id = "microsoft/OmniParser-v2.0"
local_dir = "weights"

if os.path.exists(local_dir):
    print(f"Weights already exist at: {local_dir}")
else:
    snapshot_download(repo_id=repo_id, local_dir=local_dir)
    print(f"Repository downloaded to: {local_dir}")

# Custom function to load caption model with proper error handling
def load_caption_model_safe(model_name="florence2", model_name_or_path="weights/icon_caption"):
    """Load the icon-caption model and its processor, with SDPA fallbacks.

    Method 1 calls ``util.utils.get_caption_model_processor``. If that raises
    an ``AttributeError`` whose message mentions ``_supports_sdpa``, method 2
    runs instead. It loads the processor and model directly with
    ``transformers`` (``trust_remote_code=True``), trying several keyword
    configurations in order until one succeeds.

    Args:
        model_name: Model family name passed to
            ``get_caption_model_processor`` (method 1 only).
        model_name_or_path: Path or Hub id of the caption model.

    Returns:
        Whatever ``get_caption_model_processor`` returns (method 1), or the
        dict ``{'model': model, 'processor': processor}`` (method 2).

    Raises:
        AttributeError: Method 1 failed for a reason unrelated to
            ``_supports_sdpa``.
        RuntimeError: Every method-2 configuration failed. Its ``__cause__``
            is the last loading error.
        Exception: Any other method-2 error (e.g. while loading the
            processor) is printed and re-raised.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    try:
        # Method 1: the project's own loader.
        from util.utils import get_caption_model_processor
        return get_caption_model_processor(model_name, model_name_or_path)
    except AttributeError as e:
        if '_supports_sdpa' not in str(e):
            raise
        print("SDPA error detected, trying alternative loading method...")

    # Method 2: Load directly with specific configuration
    try:
        from transformers import AutoProcessor, AutoModelForCausalLM

        print(f"Loading caption model from {model_name_or_path} with alternative method...")

        # Load processor
        processor = AutoProcessor.from_pretrained(
            model_name_or_path,
            trust_remote_code=True,
            revision="main"
        )

        # Keyword sets tried in order; the first one that loads wins.
        configs_to_try = [
            {"attn_implementation": "eager", "use_cache": False},
            {"use_flash_attention_2": False, "use_cache": False},
            {"torch_dtype": torch.float32},  # Try float32 instead of float16
        ]

        model = None
        last_error = None
        for config in configs_to_try:
            try:
                candidate = AutoModelForCausalLM.from_pretrained(
                    model_name_or_path,
                    trust_remote_code=True,
                    device_map="auto" if torch.cuda.is_available() else None,
                    **config
                )

                # Ensure the attribute exists
                if not hasattr(candidate, '_supports_sdpa'):
                    candidate._supports_sdpa = False
            except Exception as e:
                last_error = e
                print(f"Failed with config {config}: {e}")
                continue

            # Publish the model only once it is fully set up, so a
            # half-configured candidate can never be returned.
            model = candidate
            print(f"Model loaded successfully with config: {config}")
            break

        if model is None:
            # Chain the last failure so the traceback shows the root cause.
            raise RuntimeError("Could not load model with any configuration") from last_error

        # Move to device if needed
        if device.type == 'cuda' and not next(model.parameters()).is_cuda:
            model = model.to(device)

        return {'model': model, 'processor': processor}

    except Exception as e:
        print(f"Error in alternative loading: {e}")
        raise

# Load models. Both globals are bound before the try so they always exist,
# even when loading fails part-way; None marks a model that did not load.
yolo_model = None
caption_model_processor = None
try:
    print("Loading YOLO model...")
    yolo_model = get_yolo_model(model_path='weights/icon_detect/model.pt')
    print("YOLO model loaded successfully")

    print("Loading caption model...")
    caption_model_processor = load_caption_model_safe()
    print("Caption model loaded successfully")

except Exception as e:
    print(f"Critical error loading models: {e}")
    print(traceback.format_exc())
    # Don't raise here, let the UI handle it. If YOLO failed, the caption
    # model was never attempted and stays None.

# Markdown header shown at the top of the app.
# NOTE: keep this file UTF-8 encoded; the emoji below are literal characters.
MARKDOWN = """
# OmniParser V2 Pro🔥

<div style="background-color: #f0f8ff; padding: 15px; border-radius: 10px; margin-bottom: 20px;">
    <p style="margin: 0;">🎯 <strong>AI-powered screen understanding tool</strong> that detects UI elements and extracts text with high accuracy.</p>
    <p style="margin: 5px 0 0 0;">📝 Supports both PaddleOCR and EasyOCR for flexible text extraction.</p>
</div>
"""

# Prefer the GPU when PyTorch can see one.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")

# Custom CSS for UI enhancement
custom_css = """
body { background-color: #f0f2f5; }
.gradio-container { font-family: 'Segoe UI', sans-serif; max-width: 1400px; margin: auto; }
h1, h2, h3, h4 { color: #283E51; }
button { border-radius: 6px; transition: all 0.3s ease; }
button:hover { transform: translateY(-2px); box-shadow: 0 4px 12px rgba(0,0,0,0.15); }
.output-image { border: 2px solid #e1e4e8; border-radius: 8px; }
#input_image { border: 2px dashed #4a90e2; border-radius: 8px; }
#input_image:hover { border-color: #2c5aa0; }
.gr-box { border-radius: 8px; }
.gr-padded { padding: 16px; }
"""

def _run_ocr(image_input, use_paddleocr):
    """Run OCR on *image_input* and return ``(text, ocr_bbox)``.

    OCR is best-effort. If ``check_ocr_box`` raises, returns ``None``, or
    returns ``None`` for either field, the missing parts become empty lists
    so element detection can still proceed.
    """
    try:
        ocr_bbox_rslt, _ = check_ocr_box(
            image_input,
            display_img=False,
            output_bb_format='xyxy',
            goal_filtering=None,
            easyocr_args={'paragraph': False, 'text_threshold': 0.9},
            use_paddleocr=use_paddleocr
        )

        # Handle None result from OCR
        if ocr_bbox_rslt is None:
            print("OCR returned None, using empty results")
            text, ocr_bbox = [], []
        else:
            text, ocr_bbox = ocr_bbox_rslt

        # Either field may individually be None; normalize to empty lists.
        if text is None:
            text = []
        if ocr_bbox is None:
            ocr_bbox = []

        print(f"OCR found {len(text)} text regions")

    except Exception as e:
        print(f"OCR error: {e}, continuing with empty OCR results")
        text, ocr_bbox = [], []
    return text, ocr_bbox


def _format_parsed_content(parsed_content_list):
    """Render the detected elements as Markdown for the results tab.

    Each non-empty entry ``v`` at index ``i`` becomes ``**Icon {i}:** {v}``.
    Empty entries are skipped but keep their index. An empty or ``None`` list
    yields a hint to adjust the detection thresholds.
    """
    if not parsed_content_list:
        return "ℹ️ No UI elements detected. Try adjusting the detection thresholds."
    entries = [f"**Icon {i}:** {v}" for i, v in enumerate(parsed_content_list) if v]
    # Markdown folds single newlines into one paragraph, so entries are
    # separated by blank lines to keep one element per line when rendered.
    return "🎯 **Detected Elements:**\n\n" + "\n\n".join(entries) + "\n"


@spaces.GPU
@torch.inference_mode()
def process(
    image_input,
    box_threshold,
    iou_threshold,
    use_paddleocr,
    imgsz
) -> tuple:
    """Detect, caption and annotate UI elements in an uploaded screenshot.

    Args:
        image_input: The uploaded image (the UI component uses ``type='pil'``).
            Its ``.size[0]`` width scales the annotation drawing parameters.
        box_threshold: Passed to ``get_som_labeled_img`` as ``BOX_TRESHOLD``.
        iou_threshold: Passed to ``get_som_labeled_img``.
        use_paddleocr: Passed to ``check_ocr_box``; selects the OCR engine.
        imgsz: Detection image size passed to ``get_som_labeled_img``.

    Returns:
        ``(image, message)``. On success, ``image`` is the decoded annotated
        image and ``message`` is a Markdown list of detected elements. On
        failure, ``image`` is ``None`` or the input image and ``message`` is
        a ⚠️ status string.
    """

    # Input validation
    if image_input is None:
        return None, "⚠️ Please upload an image for processing."

    # Check if caption model is loaded
    if caption_model_processor is None:
        return None, "⚠️ Caption model not loaded. There was an error during initialization. Please check the logs."

    try:
        # Log processing parameters
        print(f"Processing with parameters: box_threshold={box_threshold}, "
              f"iou_threshold={iou_threshold}, use_paddleocr={use_paddleocr}, imgsz={imgsz}")

        # Scale box/label drawing with image width, clamped to [0.5, 2.0].
        image_width = image_input.size[0]
        box_overlay_ratio = max(0.5, min(2.0, image_width / 3200))

        draw_bbox_config = {
            'text_scale': 0.8 * box_overlay_ratio,
            'text_thickness': max(int(2 * box_overlay_ratio), 1),
            'text_padding': max(int(3 * box_overlay_ratio), 1),
            'thickness': max(int(3 * box_overlay_ratio), 1),
        }

        # Run OCR bounding box detection (never raises; empty on failure)
        text, ocr_bbox = _run_ocr(image_input, use_paddleocr)

        # Get labeled image and parsed content
        try:
            # Ensure the model has the required attribute
            if isinstance(caption_model_processor, dict) and 'model' in caption_model_processor:
                model = caption_model_processor['model']
                if not hasattr(model, '_supports_sdpa'):
                    model._supports_sdpa = False

            dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
                image_input,
                yolo_model,
                BOX_TRESHOLD=box_threshold,
                output_coord_in_ratio=True,
                ocr_bbox=ocr_bbox if ocr_bbox else [],
                draw_bbox_config=draw_bbox_config,
                caption_model_processor=caption_model_processor,
                ocr_text=text if text else [],
                iou_threshold=iou_threshold,
                imgsz=imgsz
            )

            if dino_labled_img is None:
                raise ValueError("Failed to generate labeled image")

        except Exception as e:
            print(f"Error in SOM processing: {e}")
            print(traceback.format_exc())
            return image_input, f"⚠️ Error during element detection: {str(e)}"

        # Decode processed image from base64
        try:
            image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
            print('Successfully decoded processed image')
        except Exception as e:
            print(f"Error decoding image: {e}")
            return image_input, f"⚠️ Error decoding processed image: {str(e)}"

        parsed_text = _format_parsed_content(parsed_content_list)

        # parsed_content_list may be None/empty; don't let logging crash a
        # successful run.
        element_count = len(parsed_content_list) if parsed_content_list else 0
        print(f'Finished processing image. Found {element_count} elements.')
        return image, parsed_text

    except Exception as e:
        error_msg = f"⚠️ Unexpected error: {str(e)}"
        print(f"Error during processing: {e}")
        print(traceback.format_exc())
        return None, error_msg

# Build Gradio UI.
# NOTE: keep this file UTF-8 encoded; labels contain literal emoji.
with gr.Blocks(css=custom_css, theme=gr.themes.Soft(), title="OmniParser V2 Pro") as demo:
    gr.Markdown(MARKDOWN)

    # Check if models loaded successfully
    if caption_model_processor is None:
        gr.Markdown("### ⚠️ Warning: Caption model failed to load. Some features may not work.")

    with gr.Row():
        # Left sidebar: Upload and settings
        with gr.Column(scale=1):
            with gr.Accordion("📤 Upload Image & Settings", open=True):
                image_input_component = gr.Image(
                    type='pil',
                    label='Upload Screenshot/UI Image',
                    elem_id="input_image"
                )

                gr.Markdown("### 🎛️ Detection Settings")

                with gr.Group():
                    box_threshold_component = gr.Slider(
                        label='📊 Box Threshold',
                        minimum=0.01,
                        maximum=1.0,
                        step=0.01,
                        value=0.05,
                        info="Lower values detect more elements"
                    )

                    iou_threshold_component = gr.Slider(
                        label='🔲 IOU Threshold',
                        minimum=0.01,
                        maximum=1.0,
                        step=0.01,
                        value=0.1,
                        info="Controls overlap filtering"
                    )

                    use_paddleocr_component = gr.Checkbox(
                        label='🔤 Use PaddleOCR',
                        value=True,
                        info="✓ PaddleOCR | ✗ EasyOCR"
                    )

                    imgsz_component = gr.Slider(
                        label='📏 Detection Image Size',
                        minimum=640,
                        maximum=1920,
                        step=32,
                        value=640,
                        info="Higher = better accuracy but slower"
                    )

                submit_button_component = gr.Button(
                    value='🚀 Process Image',
                    variant='primary',
                    size='lg'
                )

                gr.Markdown("### 💡 Quick Tips")
                gr.Markdown("""
                - **Mobile apps:** Use default settings
                - **Desktop apps:** Try image size 1280
                - **Complex UIs:** Lower box threshold to 0.03
                - **Too many boxes:** Increase IOU threshold
                """)

        # Right main area: Results tabs
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.Tab("🖼️ Annotated Image"):
                    image_output_component = gr.Image(
                        type='pil',
                        label='Processed Image with Annotations',
                        elem_classes=["output-image"]
                    )

                with gr.Tab("📝 Extracted Elements"):
                    text_output_component = gr.Markdown(
                        value="*Parsed elements will appear here after processing...*",
                        elem_classes=["parsed-text"]
                    )

    # Button click event: inputs map positionally onto process()'s parameters.
    submit_button_component.click(
        fn=process,
        inputs=[
            image_input_component,
            box_threshold_component,
            iou_threshold_component,
            use_paddleocr_component,
            imgsz_component
        ],
        outputs=[image_output_component, text_output_component],
        show_progress=True
    )

# Launch with queue support
if __name__ == "__main__":
    # Offline-mode switches (HF_HUB_OFFLINE / TRANSFORMERS_OFFLINE) are read
    # when huggingface_hub / transformers are imported, which has already
    # happened by this point. Setting them here would have no effect, so the
    # user's environment is left untouched.
    try:
        demo.queue(max_size=10)
        demo.launch(
            share=False,
            show_error=True,
            server_name="0.0.0.0",
            server_port=7860
        )
    except Exception as e:
        print(f"Failed to launch app: {e}")
        # Re-raise so the process exits non-zero (the traceback is printed
        # by the interpreter) instead of appearing to finish cleanly.
        raise