Tonic committed on
Commit 408a931 · unverified · 1 Parent(s): a33adcb

attempt transformers fix
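
Context for this change: newer `transformers` releases removed `DynamicCache.get_max_length()`, which the GOT-OCR2_0 `trust_remote_code` model code still calls during generation, so `model.chat(...)` fails with an `AttributeError` naming `get_max_length`. The diff below works around this with a version check, aggressive cache clearing, and layered retry fallbacks. A smaller shim that circulated for this error is sketched here for reference; it is not part of this commit, and it assumes a transformers version where `get_max_cache_shape()` exists as the replacement method:

    # Hypothetical shim, not in this commit: restore the removed
    # DynamicCache.get_max_length() by aliasing its replacement.
    # Assumes a recent transformers that provides get_max_cache_shape().
    from transformers.cache_utils import DynamicCache

    if not hasattr(DynamicCache, "get_max_length"):
        DynamicCache.get_max_length = DynamicCache.get_max_cache_shape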

Files changed (1): app.py +228 -12
app.py CHANGED
@@ -18,6 +18,20 @@ import cv2
 import re
 import warnings
 
+# Check transformers version for compatibility
+try:
+    import transformers
+    transformers_version = transformers.__version__
+    print(f"Transformers version: {transformers_version}")
+
+    # Check if we need to use legacy cache handling
+    if transformers_version.startswith(('4.4', '4.5', '4.6')):
+        USE_LEGACY_CACHE = True
+    else:
+        USE_LEGACY_CACHE = False
+except:
+    USE_LEGACY_CACHE = False
+
 # Try to import spaces module for ZeroGPU compatibility
 try:
     import spaces
@@ -35,6 +49,118 @@ warnings.filterwarnings("ignore", message="Setting `pad_token_id` to `eos_token_
 warnings.filterwarnings("ignore", message="The attention mask is not set and cannot be inferred")
 warnings.filterwarnings("ignore", message="The `seen_tokens` attribute is deprecated")
 
+class ModelCacheManager:
+    """
+    Manages model cache to prevent DynamicCache errors
+    """
+    def __init__(self, model):
+        self.model = model
+        self._clear_all_caches()
+
+    def _clear_all_caches(self):
+        """Clear all possible caches"""
+        # Clear model cache
+        if hasattr(self.model, 'clear_cache'):
+            try:
+                self.model.clear_cache()
+            except:
+                pass
+
+        if hasattr(self.model, '_clear_cache'):
+            try:
+                self.model._clear_cache()
+            except:
+                pass
+
+        # Clear transformers cache based on version
+        try:
+            if USE_LEGACY_CACHE:
+                # Legacy cache clearing for older transformers versions
+                from transformers import GenerationConfig
+                if hasattr(GenerationConfig, 'clear_cache'):
+                    GenerationConfig.clear_cache()
+            else:
+                # New cache clearing for recent transformers versions
+                try:
+                    from transformers.cache_utils import clear_cache
+                    clear_cache()
+                except:
+                    pass
+
+                # Also try the old method as fallback
+                try:
+                    from transformers import GenerationConfig
+                    if hasattr(GenerationConfig, 'clear_cache'):
+                        GenerationConfig.clear_cache()
+                except:
+                    pass
+        except:
+            pass
+
+        # Clear torch cache
+        try:
+            import torch
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+        except:
+            pass
+
+    def safe_call(self, method_name, *args, **kwargs):
+        """Safely call model methods with cache management"""
+        try:
+            # First attempt
+            method = getattr(self.model, method_name)
+            return method(*args, **kwargs)
+        except AttributeError as e:
+            if "get_max_length" in str(e):
+                # Clear cache and retry
+                self._clear_all_caches()
+                try:
+                    return method(*args, **kwargs)
+                except:
+                    # Try without cache
+                    kwargs_copy = kwargs.copy()
+                    kwargs_copy['use_cache'] = False
+                    return method(*args, **kwargs_copy)
+            else:
+                raise e
+
+    def direct_call(self, method_name, *args, **kwargs):
+        """Direct call bypassing all cache mechanisms"""
+        try:
+            # Disable cache completely
+            kwargs_copy = kwargs.copy()
+            kwargs_copy['use_cache'] = False
+
+            # Clear all caches first
+            self._clear_all_caches()
+
+            # Make the call
+            method = getattr(self.model, method_name)
+            return method(*args, **kwargs_copy)
+        except Exception as e:
+            # If still failing, try the original safe_call as last resort
+            return self.safe_call(method_name, *args, **kwargs)
+
+    def legacy_call(self, method_name, *args, **kwargs):
+        """Legacy call method for older transformers versions"""
+        try:
+            # For legacy versions, we need to handle cache differently
+            kwargs_copy = kwargs.copy()
+
+            # Remove any cache-related parameters
+            if 'use_cache' in kwargs_copy:
+                del kwargs_copy['use_cache']
+
+            # Clear caches
+            self._clear_all_caches()
+
+            # Make the call
+            method = getattr(self.model, method_name)
+            return method(*args, **kwargs_copy)
+        except Exception as e:
+            # Fallback to direct call
+            return self.direct_call(method_name, *args, **kwargs)
 
 def initialize_model_safely():
     """
@@ -53,6 +179,7 @@ def initialize_model_safely():
 
         config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
 
+        # Initialize model with proper settings to avoid warnings
         model = AutoModel.from_pretrained(
             'ucaslcl/GOT-OCR2_0',
             trust_remote_code=True,
@@ -71,8 +198,11 @@ def initialize_model_safely():
         if hasattr(model, 'config'):
             model.config.pad_token_id = tokenizer.eos_token_id
             model.config.eos_token_id = tokenizer.eos_token_id
-
-        return model, tokenizer
+
+        # Create cache manager
+        cache_manager = ModelCacheManager(model)
+
+        return model, tokenizer, cache_manager
 
     except Exception as e:
         print(f"Error initializing model: {str(e)}")
@@ -90,11 +220,16 @@ def initialize_model_safely():
             use_safetensors=True
         )
         model = model.eval().to(device)
-        return model, tokenizer
+
+        # Create cache manager for fallback model
+        cache_manager = ModelCacheManager(model)
+
+        return model, tokenizer, cache_manager
     except Exception as fallback_error:
         raise Exception(f"Failed to initialize model: {str(e)}. Fallback also failed: {str(fallback_error)}")
 
-model, tokenizer = initialize_model_safely()
+# Initialize model, tokenizer, and cache manager
+model, tokenizer, cache_manager = initialize_model_safely()
 
 UPLOAD_FOLDER = "./uploads"
 RESULTS_FOLDER = "./results"
@@ -120,8 +255,20 @@ def safe_model_chat(model, tokenizer, image_path, **kwargs):
         if "get_max_length" in str(e):
             # Try to fix the cache issue by clearing it
             try:
+                # Clear any existing cache
                 if hasattr(model, 'clear_cache'):
                     model.clear_cache()
+                elif hasattr(model, '_clear_cache'):
+                    model._clear_cache()
+
+                # Try to clear cache from transformers
+                try:
+                    from transformers import GenerationConfig
+                    if hasattr(GenerationConfig, 'clear_cache'):
+                        GenerationConfig.clear_cache()
+                except:
+                    pass
+
                 # Retry the call
                 return model.chat(tokenizer, image_path, **kwargs)
             except:
@@ -131,9 +278,18 @@ def safe_model_chat(model, tokenizer, image_path, **kwargs):
                     kwargs_copy = kwargs.copy()
                     if 'use_cache' in kwargs_copy:
                         del kwargs_copy['use_cache']
+
+                    # Try with cache disabled
                     return model.chat(tokenizer, image_path, **kwargs_copy)
                 except:
-                    raise Exception("Model compatibility issue: DynamicCache error. Please try again.")
+                    # Last resort: try to recreate the model call without cache
+                    try:
+                        # Force cache clearing by setting use_cache=False
+                        kwargs_copy = kwargs.copy()
+                        kwargs_copy['use_cache'] = False
+                        return model.chat(tokenizer, image_path, **kwargs_copy)
+                    except:
+                        raise Exception("Model compatibility issue: DynamicCache error. Please try again.")
         else:
             raise e
     except Exception as e:
@@ -159,8 +315,20 @@ def safe_model_chat_crop(model, tokenizer, image_path, **kwargs):
         if "get_max_length" in str(e):
             # Try to fix the cache issue by clearing it
             try:
+                # Clear any existing cache
                 if hasattr(model, 'clear_cache'):
                     model.clear_cache()
+                elif hasattr(model, '_clear_cache'):
+                    model._clear_cache()
+
+                # Try to clear cache from transformers
+                try:
+                    from transformers import GenerationConfig
+                    if hasattr(GenerationConfig, 'clear_cache'):
+                        GenerationConfig.clear_cache()
+                except:
+                    pass
+
                 # Retry the call
                 return model.chat_crop(tokenizer, image_path, **kwargs)
             except:
@@ -170,9 +338,18 @@ def safe_model_chat_crop(model, tokenizer, image_path, **kwargs):
                     kwargs_copy = kwargs.copy()
                     if 'use_cache' in kwargs_copy:
                         del kwargs_copy['use_cache']
+
+                    # Try with cache disabled
                     return model.chat_crop(tokenizer, image_path, **kwargs_copy)
                 except:
-                    raise Exception("Model compatibility issue: DynamicCache error. Please try again.")
+                    # Last resort: try to recreate the model call without cache
+                    try:
+                        # Force cache clearing by setting use_cache=False
+                        kwargs_copy = kwargs.copy()
+                        kwargs_copy['use_cache'] = False
+                        return model.chat_crop(tokenizer, image_path, **kwargs_copy)
+                    except:
+                        raise Exception("Model compatibility issue: DynamicCache error. Please try again.")
        else:
            raise e
    except Exception as e:
@@ -218,19 +395,58 @@ def process_image(image, task, ocr_type=None, ocr_box=None, ocr_color=None):
     # Wrap model calls in try-except to handle DynamicCache errors
     try:
         if task == "Plain Text OCR":
-            res = safe_model_chat(model, tokenizer, image_path, ocr_type='ocr')
+            # Use cache manager for safer calls
+            try:
+                res = cache_manager.safe_call('chat', tokenizer, image_path, ocr_type='ocr')
+            except:
+                try:
+                    # Fallback to direct call
+                    res = cache_manager.direct_call('chat', tokenizer, image_path, ocr_type='ocr')
+                except:
+                    # Final fallback to legacy call
+                    res = cache_manager.legacy_call('chat', tokenizer, image_path, ocr_type='ocr')
             return res, None, unique_id
         else:
             if task == "Format Text OCR":
-                res = safe_model_chat(model, tokenizer, image_path, ocr_type='format', render=True, save_render_file=result_path)
+                try:
+                    res = cache_manager.safe_call('chat', tokenizer, image_path, ocr_type='format', render=True, save_render_file=result_path)
+                except:
+                    try:
+                        res = cache_manager.direct_call('chat', tokenizer, image_path, ocr_type='format', render=True, save_render_file=result_path)
+                    except:
+                        res = cache_manager.legacy_call('chat', tokenizer, image_path, ocr_type='format', render=True, save_render_file=result_path)
             elif task == "Fine-grained OCR (Box)":
-                res = safe_model_chat(model, tokenizer, image_path, ocr_type=ocr_type, ocr_box=ocr_box, render=True, save_render_file=result_path)
+                try:
+                    res = cache_manager.safe_call('chat', tokenizer, image_path, ocr_type=ocr_type, ocr_box=ocr_box, render=True, save_render_file=result_path)
+                except:
+                    try:
+                        res = cache_manager.direct_call('chat', tokenizer, image_path, ocr_type=ocr_type, ocr_box=ocr_box, render=True, save_render_file=result_path)
+                    except:
+                        res = cache_manager.legacy_call('chat', tokenizer, image_path, ocr_type=ocr_type, ocr_box=ocr_box, render=True, save_render_file=result_path)
             elif task == "Fine-grained OCR (Color)":
-                res = safe_model_chat(model, tokenizer, image_path, ocr_type=ocr_type, ocr_color=ocr_color, render=True, save_render_file=result_path)
+                try:
+                    res = cache_manager.safe_call('chat', tokenizer, image_path, ocr_type=ocr_type, ocr_color=ocr_color, render=True, save_render_file=result_path)
+                except:
+                    try:
+                        res = cache_manager.direct_call('chat', tokenizer, image_path, ocr_type=ocr_type, ocr_color=ocr_color, render=True, save_render_file=result_path)
+                    except:
+                        res = cache_manager.legacy_call('chat', tokenizer, image_path, ocr_type=ocr_type, ocr_color=ocr_color, render=True, save_render_file=result_path)
             elif task == "Multi-crop OCR":
-                res = safe_model_chat_crop(model, tokenizer, image_path, ocr_type='format', render=True, save_render_file=result_path)
+                try:
+                    res = cache_manager.safe_call('chat_crop', tokenizer, image_path, ocr_type='format', render=True, save_render_file=result_path)
+                except:
+                    try:
+                        res = cache_manager.direct_call('chat_crop', tokenizer, image_path, ocr_type='format', render=True, save_render_file=result_path)
+                    except:
+                        res = cache_manager.legacy_call('chat_crop', tokenizer, image_path, ocr_type='format', render=True, save_render_file=result_path)
            elif task == "Render Formatted OCR":
-                res = safe_model_chat(model, tokenizer, image_path, ocr_type='format', render=True, save_render_file=result_path)
+                try:
+                    res = cache_manager.safe_call('chat', tokenizer, image_path, ocr_type='format', render=True, save_render_file=result_path)
+                except:
+                    try:
+                        res = cache_manager.direct_call('chat', tokenizer, image_path, ocr_type='format', render=True, save_render_file=result_path)
+                    except:
+                        res = cache_manager.legacy_call('chat', tokenizer, image_path, ocr_type='format', render=True, save_render_file=result_path)
 
         if os.path.exists(result_path):
             with open(result_path, 'r') as f:
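
For reference, a minimal sketch of how the pieces introduced by this commit fit together (names taken from the diff above; 'sample.png' is a placeholder path, not from the commit):

    # Sketch only: the same escalation that process_image() now performs.
    model, tokenizer, cache_manager = initialize_model_safely()

    try:
        # Normal path: retries internally on the get_max_length AttributeError
        res = cache_manager.safe_call('chat', tokenizer, 'sample.png', ocr_type='ocr')
    except Exception:
        try:
            # Second attempt: clears all caches and forces use_cache=False
            res = cache_manager.direct_call('chat', tokenizer, 'sample.png', ocr_type='ocr')
        except Exception:
            # Last attempt: strips cache kwargs for older transformers versions
            res = cache_manager.legacy_call('chat', tokenizer, 'sample.png', ocr_type='ocr')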