ginipick committed on
Commit
2187315
·
verified ·
1 Parent(s): c2f47fd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +142 -27
app.py CHANGED
@@ -8,6 +8,11 @@ import io
8
  import base64, os
9
  from huggingface_hub import snapshot_download
10
  import traceback
 
 
 
 
 
11
 
12
  # Import 유틸리티 함수들
13
  from util.utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img
@@ -23,17 +28,99 @@ if not os.path.exists(local_dir):
23
  else:
24
  print(f"Weights already exist at: {local_dir}")
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  # Load models with error handling
27
  try:
 
28
  yolo_model = get_yolo_model(model_path='weights/icon_detect/model.pt')
29
- caption_model_processor = get_caption_model_processor(
30
- model_name="florence2",
31
- model_name_or_path="weights/icon_caption"
32
- )
33
- print("Models loaded successfully")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  except Exception as e:
35
- print(f"Error loading models: {e}")
36
- raise
 
 
 
37
 
38
  # Markdown header text
39
  MARKDOWN = """
@@ -62,6 +149,22 @@ button:hover { transform: translateY(-2px); box-shadow: 0 4px 12px rgba(0,0,0,0.
62
  .gr-padded { padding: 16px; }
63
  """
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  @spaces.GPU
66
  @torch.inference_mode()
67
  def process(
@@ -77,6 +180,10 @@ def process(
77
  if image_input is None:
78
  return None, "⚠️ Please upload an image for processing."
79
 
 
 
 
 
80
  try:
81
  # Log processing parameters
82
  print(f"Processing with parameters: box_threshold={box_threshold}, "
@@ -125,6 +232,12 @@ def process(
125
 
126
  # Get labeled image and parsed content via SOM (YOLO + caption model)
127
  try:
 
 
 
 
 
 
128
  dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
129
  image_input,
130
  yolo_model,
@@ -141,6 +254,21 @@ def process(
141
  if dino_labled_img is None:
142
  raise ValueError("Failed to generate labeled image")
143
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  except Exception as e:
145
  print(f"Error in SOM processing: {e}")
146
  # Return original image with error message if SOM fails
@@ -258,7 +386,7 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft(), title="OmniParser V2 Pro"
258
 
259
  # Button click event with loading spinner
260
  submit_button_component.click(
261
- fn=process,
262
  inputs=[
263
  image_input_component,
264
  box_threshold_component,
@@ -269,29 +397,15 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft(), title="OmniParser V2 Pro"
269
  outputs=[image_output_component, text_output_component],
270
  show_progress=True
271
  )
272
-
273
- # Add sample images if available
274
- if os.path.exists("samples"):
275
- gr.Examples(
276
- examples=[
277
- ["samples/mobile_app.png", 0.05, 0.1, True, 640],
278
- ["samples/desktop_app.png", 0.05, 0.1, True, 1280],
279
- ],
280
- inputs=[
281
- image_input_component,
282
- box_threshold_component,
283
- iou_threshold_component,
284
- use_paddleocr_component,
285
- imgsz_component
286
- ],
287
- outputs=[image_output_component, text_output_component],
288
- fn=process,
289
- cache_examples=False
290
- )
291
 
292
  # Launch with queue support and error handling
293
  if __name__ == "__main__":
294
  try:
 
 
 
 
 
295
  demo.queue(max_size=10)
296
  demo.launch(
297
  share=False,
@@ -301,4 +415,5 @@ if __name__ == "__main__":
301
  )
302
  except Exception as e:
303
  print(f"Failed to launch app: {e}")
 
304
  raise
 
8
  import base64, os
9
  from huggingface_hub import snapshot_download
10
  import traceback
11
+ import warnings
12
+
13
+ # Suppress specific warnings
14
+ warnings.filterwarnings("ignore", category=FutureWarning)
15
+ warnings.filterwarnings("ignore", message=".*_supports_sdpa.*")
16
 
17
  # Import 유틸리티 함수들
18
  from util.utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img
 
28
  else:
29
  print(f"Weights already exist at: {local_dir}")
30
 
31
+ # Monkey patch for Florence2 model compatibility
32
def patch_florence2_model():
    """Monkey-patch ``AutoModelForCausalLM.from_pretrained`` for Florence2.

    The replacement loader:

    * forces ``trust_remote_code=True``, ``attn_implementation="eager"`` and
      ``use_cache=False`` whenever the requested model name or path contains
      "florence" (case-insensitive) -- the settings are meant to avoid SDPA
      issues;
    * sets ``_supports_sdpa = False`` on any returned model that lacks the
      attribute.

    The patch is best-effort: any failure while applying it is printed as a
    warning and swallowed. Calling this function again after a successful
    patch is a no-op, so wrappers never stack.

    Returns:
        None
    """
    try:
        from transformers import AutoModelForCausalLM

        original_from_pretrained = AutoModelForCausalLM.from_pretrained

        # Already patched by an earlier call: wrapping again would only add
        # another layer around the same logic.
        if getattr(original_from_pretrained, "_florence2_patched", False):
            return

        # The first parameter keeps the upstream name so keyword callers
        # (``pretrained_model_name_or_path=...``) keep working.
        def patched_from_pretrained(pretrained_model_name_or_path, *args, **kwargs):
            # ``str()`` lets os.PathLike arguments through; lowercasing makes
            # the match case-insensitive in a single check.
            if "florence" in str(pretrained_model_name_or_path).lower():
                kwargs["trust_remote_code"] = True
                # Config overrides intended to avoid SDPA issues.
                kwargs["attn_implementation"] = "eager"
                kwargs["use_cache"] = False

            model = original_from_pretrained(
                pretrained_model_name_or_path, *args, **kwargs
            )

            # Add the attribute if the loaded model does not define it.
            if not hasattr(model, "_supports_sdpa"):
                model._supports_sdpa = False

            return model

        patched_from_pretrained._florence2_patched = True
        AutoModelForCausalLM.from_pretrained = patched_from_pretrained
        print("Applied Florence2 compatibility patch")

    except Exception as e:
        # Deliberate best-effort boundary: start-up continues unpatched.
        print(f"Warning: Could not apply Florence2 patch: {e}")
69
+
70
+ # Apply the patch before loading models
71
+ patch_florence2_model()
72
+
73
  # Load models with error handling
74
# Load the YOLO detector and the caption model at import time. If the regular
# Florence2 loader fails, fall back to loading the processor and model
# directly through transformers. Any unrecoverable failure is logged with its
# traceback and re-raised as RuntimeError.
try:
    print("Loading YOLO model...")
    yolo_model = get_yolo_model(model_path='weights/icon_detect/model.pt')
    print("YOLO model loaded successfully")

    print("Loading caption model...")
    # Try loading with fallback options
    try:
        caption_model_processor = get_caption_model_processor(
            model_name="florence2",
            model_name_or_path="weights/icon_caption"
        )
        print("Florence2 caption model loaded successfully")
    except Exception as e:
        print(f"Error loading Florence2, trying alternative approach: {e}")
        # Alternative loading method: make the local weights directory
        # importable, then load processor and model explicitly.
        import sys
        sys.path.insert(0, "weights/icon_caption")

        from transformers import AutoProcessor, AutoModelForCausalLM

        # Load with specific configurations to avoid SDPA issues
        processor = AutoProcessor.from_pretrained(
            "weights/icon_caption",
            trust_remote_code=True,
            revision="main"
        )

        model = AutoModelForCausalLM.from_pretrained(
            "weights/icon_caption",
            # Half precision only when a CUDA device is available.
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            trust_remote_code=True,
            revision="main",
            attn_implementation="eager",  # Avoid SDPA issues
            device_map="auto" if torch.cuda.is_available() else None
        )

        # Add missing attribute
        if not hasattr(model, '_supports_sdpa'):
            model._supports_sdpa = False

        caption_model_processor = {'model': model, 'processor': processor}
        print("Caption model loaded with alternative method")

except Exception as e:
    print(f"Critical error loading models: {e}")
    print(traceback.format_exc())
    # Leave the name defined, then abort start-up: the app cannot run
    # without its models.
    caption_model_processor = None
    # Chain the original exception so the root cause stays attached.
    raise RuntimeError(f"Failed to load models: {e}") from e
124
 
125
  # Markdown header text
126
  MARKDOWN = """
 
149
  .gr-padded { padding: 16px; }
150
  """
151
 
152
def safe_process_wrapper(*args, **kwargs):
    """Call ``process`` and retry once after an ``_supports_sdpa`` error.

    All arguments are forwarded to ``process`` unchanged. If ``process``
    raises an ``AttributeError`` whose message mentions ``_supports_sdpa``,
    the attribute is set to ``False`` on the loaded caption model (when it is
    missing) and ``process`` is called a second time.

    Returns:
        Whatever ``process`` returns.

    Raises:
        AttributeError: if the error is unrelated to ``_supports_sdpa``, or
            if no caption model is available to fix. Errors from the retry
            propagate unchanged.
    """
    try:
        return process(*args, **kwargs)
    except AttributeError as e:
        if '_supports_sdpa' not in str(e):
            raise

        # Nothing to repair: re-raise instead of falling through and
        # silently returning None to the caller.
        if not caption_model_processor or 'model' not in caption_model_processor:
            raise

        # Try to fix the model on the fly, then retry once.
        model = caption_model_processor['model']
        if not hasattr(model, '_supports_sdpa'):
            model._supports_sdpa = False
        return process(*args, **kwargs)
167
+
168
  @spaces.GPU
169
  @torch.inference_mode()
170
  def process(
 
180
  if image_input is None:
181
  return None, "⚠️ Please upload an image for processing."
182
 
183
+ # Check if caption model is loaded
184
+ if caption_model_processor is None:
185
+ return None, "⚠️ Caption model not loaded. Please restart the application."
186
+
187
  try:
188
  # Log processing parameters
189
  print(f"Processing with parameters: box_threshold={box_threshold}, "
 
232
 
233
  # Get labeled image and parsed content via SOM (YOLO + caption model)
234
  try:
235
+ # Fix model attributes before calling
236
+ if isinstance(caption_model_processor, dict) and 'model' in caption_model_processor:
237
+ model = caption_model_processor['model']
238
+ if not hasattr(model, '_supports_sdpa'):
239
+ model._supports_sdpa = False
240
+
241
  dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
242
  image_input,
243
  yolo_model,
 
254
  if dino_labled_img is None:
255
  raise ValueError("Failed to generate labeled image")
256
 
257
+ except AttributeError as e:
258
+ if '_supports_sdpa' in str(e):
259
+ print(f"SDPA attribute error, attempting to fix: {e}")
260
+ # Try to fix and retry
261
+ if isinstance(caption_model_processor, dict) and 'model' in caption_model_processor:
262
+ caption_model_processor['model']._supports_sdpa = False
263
+ # Retry the operation
264
+ dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
265
+ image_input, yolo_model, BOX_TRESHOLD=box_threshold,
266
+ output_coord_in_ratio=True, ocr_bbox=ocr_bbox if ocr_bbox else [],
267
+ draw_bbox_config=draw_bbox_config, caption_model_processor=caption_model_processor,
268
+ ocr_text=text if text else [], iou_threshold=iou_threshold, imgsz=imgsz
269
+ )
270
+ else:
271
+ raise
272
  except Exception as e:
273
  print(f"Error in SOM processing: {e}")
274
  # Return original image with error message if SOM fails
 
386
 
387
  # Button click event with loading spinner
388
  submit_button_component.click(
389
+ fn=safe_process_wrapper, # Use wrapper function
390
  inputs=[
391
  image_input_component,
392
  box_threshold_component,
 
397
  outputs=[image_output_component, text_output_component],
398
  show_progress=True
399
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
400
 
401
  # Launch with queue support and error handling
402
  if __name__ == "__main__":
403
  try:
404
+ # Set environment variables for better compatibility
405
+ os.environ['TRANSFORMERS_OFFLINE'] = '0'
406
+ os.environ['HF_HUB_OFFLINE'] = '0'
407
+ os.environ['CUDA_LAUNCH_BLOCKING'] = '1' # For better error messages
408
+
409
  demo.queue(max_size=10)
410
  demo.launch(
411
  share=False,
 
415
  )
416
  except Exception as e:
417
  print(f"Failed to launch app: {e}")
418
+ print(traceback.format_exc())
419
  raise