lambertxiao committed on
Commit
b98d7a4
·
verified ·
1 Parent(s): e19b803

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -305
app.py CHANGED
@@ -1,311 +1,26 @@
1
- import spaces
2
  import gradio as gr
3
  from transformers import AutoModel, AutoProcessor
4
  from PIL import Image
5
  import torch
6
  import numpy as np
7
- import sys
8
- import os
9
- from pathlib import Path
10
- import shutil
11
- import types
12
 
13
- # Set environment variables for better debugging
14
- os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
15
-
16
- # Model configuration
17
- model_name_or_path = "lambertxiao/Vision-Language-Vision-Captioner-Qwen2.5-3B"
18
-
19
def create_fake_module_structure(
    repo_id="lambertxiao/Vision-Language-Vision-Captioner-Qwen2.5-3B",
):
    """Register placeholder modules under ``transformers_modules`` in ``sys.modules``.

    The placeholders let imports of the form
    ``transformers_modules.<namespace>.<repo>`` resolve while the real
    remote-code files are still being downloaded.

    Args:
        repo_id: Hub repository id in ``"namespace/name"`` form. The default
            is the repository this app was written for, so calling the
            function without arguments registers exactly the same
            ``sys.modules`` keys as before.

    Raises:
        ValueError: If ``repo_id`` is not of the form ``"namespace/name"``.
    """
    namespace, separator, repo_name = repo_id.partition("/")
    if not separator or not namespace or not repo_name:
        raise ValueError(
            f"repo_id must look like 'namespace/name', got {repo_id!r}"
        )

    # Dotted module paths are split on '.', so a repository name such as
    # "...Qwen2.5-3B" is also looked up by its part before the first dot.
    leaf_name = repo_name.split(".", 1)[0]

    root_key = "transformers_modules"
    owner_key = f"{root_key}.{namespace}"

    root = types.ModuleType(root_key)
    owner = types.ModuleType(namespace)
    leaf = types.ModuleType(leaf_name)

    # Register the hierarchy.
    sys.modules[root_key] = root
    sys.modules[owner_key] = owner
    sys.modules[f"{owner_key}.{leaf_name}"] = leaf

    # Link parents to children so attribute access works as well.
    setattr(root, namespace, owner)
    setattr(owner, leaf_name, leaf)

    # Alias the full (possibly dotted) repository name to the same leaf module.
    sys.modules[f"{owner_key}.{repo_name}"] = leaf
37
-
38
def fix_imports_in_file(file_path):
    """Rewrite known relative imports in a Python file to absolute imports.

    Args:
        file_path: Path of the file to patch (anything ``open`` accepts).

    Returns:
        ``True`` if the file content changed and was written back, ``False``
        if nothing needed changing or the file could not be read or written.
        I/O and decoding failures are printed, not raised, because callers
        treat this as a best-effort fix.
    """
    # (old, new) pairs; each relative import becomes a top-level import.
    relative_import_fixes = (
        ("from .De_DiffusionV2_Image import", "from De_DiffusionV2_Image import"),
        ("from .modeling_clip import", "from modeling_clip import"),
        ("from .configuration_clip import", "from configuration_clip import"),
        ("from .modeling_florence2 import", "from modeling_florence2 import"),
        ("from .configuration_florence2 import", "from configuration_florence2 import"),
        ("from .processing_florence2 import", "from processing_florence2 import"),
        ("from .utils import", "from utils import"),
        ("from .build_unfreeze import", "from build_unfreeze import"),
        ("from .sd_config import", "from sd_config import"),
    )

    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            original = handle.read()
    except (OSError, UnicodeError) as e:
        print(f"Error fixing {file_path}: {e}")
        return False

    patched = original
    for old, new in relative_import_fixes:
        patched = patched.replace(old, new)

    # Strip the transformers_modules prefix from absolute imports.
    # NOTE(review): whatever followed the prefix is left directly after
    # "from"; confirm the resulting statement is what the model code needs.
    patched = patched.replace(
        "from transformers_modules.lambertxiao.Vision-Language-Vision-Captioner-Qwen2",
        "# Fixed import - removed transformers_modules prefix\nfrom",
    )

    if patched == original:
        return False

    try:
        with open(file_path, 'w', encoding='utf-8') as handle:
            handle.write(patched)
    except OSError as e:
        print(f"Error fixing {file_path}: {e}")
        return False
    return True
75
-
76
def monitor_and_fix_downloads(cache_base=None):
    """Build a callable that patches Python files found under a cache directory.

    Args:
        cache_base: Directory to scan recursively for ``*.py`` files. When
            omitted, ``~/.cache/huggingface/modules/transformers_modules``
            is used, as before.

    Returns:
        A zero-argument function. Each call scans ``cache_base`` and runs
        ``fix_imports_in_file`` on every Python file that is new or whose
        modification time changed since the previous call. It returns
        ``None``.
    """
    if cache_base is None:
        cache_base = Path.home() / ".cache" / "huggingface" / "modules" / "transformers_modules"
    cache_base = Path(cache_base)

    # path -> modification time (ns) observed after the file was last handled.
    # Keying on mtime means a file that gets re-downloaded is patched again,
    # while untouched files are not re-read on every call.
    processed = {}

    def fix_new_files():
        if not cache_base.is_dir():
            # Nothing has been downloaded yet.
            return
        for py_file in cache_base.rglob("*.py"):
            key = str(py_file)
            try:
                if processed.get(key) == py_file.stat().st_mtime_ns:
                    continue
                if fix_imports_in_file(py_file):
                    print(f"✓ Fixed imports in: {py_file.name}")
                processed[key] = py_file.stat().st_mtime_ns
            except OSError:
                # The file vanished between listing and stat; skip it.
                continue

    return fix_new_files
92
-
93
- # Create fake module structure first
94
- print("🔧 Setting up module structure...")
95
- create_fake_module_structure()
96
-
97
- # Setup file monitoring
98
- fix_files = monitor_and_fix_downloads()
99
-
100
- # Custom import hook to fix files on the fly
101
class ImportFixer:
    """Meta-path hook that runs a file-fixing callback whenever an import starts.

    The hook never resolves a module itself: ``find_spec`` always returns
    ``None`` so the regular finders continue the lookup.

    Args:
        fixer: Zero-argument callable to run on each import attempt. When
            omitted, the module-level ``fix_files`` callable is looked up at
            call time, as before.
    """

    def __init__(self, fixer=None):
        self.fixed_modules = set()
        self._fixer = fixer
        # True while the callback is running, so imports triggered by the
        # callback itself do not re-enter it.
        self._running = False

    def find_spec(self, name, path, target=None):
        """Run the fixer callback (unless already running) and return ``None``."""
        if self._running:
            return None
        self._running = True
        try:
            callback = fix_files if self._fixer is None else self._fixer
            callback()
        finally:
            self._running = False
        return None
109
-
110
- # Install the import hook
111
- import_fixer = ImportFixer()
112
- sys.meta_path.insert(0, import_fixer)
113
-
114
- print("📥 Downloading and loading model...")
115
-
116
- # First attempt - this might fail but will download files
117
- try:
118
- from transformers import AutoConfig
119
-
120
- # Add paths before attempting to load
121
- cache_base = Path.home() / ".cache" / "huggingface" / "modules" / "transformers_modules"
122
- possible_paths = [
123
- cache_base / "lambertxiao" / "Vision-Language-Vision-Captioner-Qwen2.5-3B",
124
- cache_base / "lambertxiao",
125
- cache_base,
126
- ]
127
-
128
- for path in possible_paths:
129
- if path.exists() and str(path) not in sys.path:
130
- sys.path.insert(0, str(path))
131
-
132
- # Try to load config - this triggers download
133
- config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
134
- print("✓ Config loaded successfully")
135
-
136
- except Exception as e:
137
- print(f"⚠️ Initial load failed (expected): {e}")
138
- print("🔧 Fixing downloaded files...")
139
-
140
- # Fix all downloaded files
141
- fix_files()
142
-
143
- # Find and add all relevant directories to path
144
- cache_base = Path.home() / ".cache" / "huggingface" / "modules" / "transformers_modules"
145
- for subdir in cache_base.rglob("*"):
146
- if subdir.is_dir() and "lambertxiao" in str(subdir):
147
- if str(subdir) not in sys.path:
148
- sys.path.insert(0, str(subdir))
149
-
150
- # Now load the model - should work after fixes
151
- print("\n📊 Loading model with fixed imports...")
152
- try:
153
- # Remove the import hook to avoid interference
154
- sys.meta_path.remove(import_fixer)
155
-
156
- # Load the model
157
- model = AutoModel.from_pretrained(
158
- model_name_or_path,
159
- trust_remote_code=True,
160
- low_cpu_mem_usage=False,
161
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
162
- )
163
-
164
- # Move to GPU if available
165
- if torch.cuda.is_available():
166
- model = model.to("cuda")
167
- print(f"✓ Model loaded on GPU: {torch.cuda.get_device_name(0)}")
168
- else:
169
- model = model.to("cpu")
170
- print("✓ Model loaded on CPU")
171
-
172
- except Exception as e:
173
- print(f"❌ Error loading model: {e}")
174
-
175
- # Last resort - try with minimal setup
176
- print("🔧 Attempting minimal setup...")
177
-
178
- # Clear any problematic imports
179
- modules_to_remove = [k for k in sys.modules.keys() if 'lambertxiao' in k or 'Vision-Language-Vision' in k]
180
- for module in modules_to_remove:
181
- del sys.modules[module]
182
-
183
- # Re-create fake modules
184
- create_fake_module_structure()
185
-
186
- # Try one more time
187
- model = AutoModel.from_pretrained(
188
- model_name_or_path,
189
- trust_remote_code=True,
190
- device_map="auto"
191
- )
192
-
193
- print("\n✅ Model setup complete!")
194
-
195
def drop_incomplete_tail(text):
    """Remove a trailing incomplete sentence from ``text``.

    A sentence counts as complete when it ends in ``.``, ``!`` or ``?``
    followed by whitespace or the end of the text. A period inside a number
    such as ``3.5`` therefore does not end a sentence.

    Args:
        text: Caption text; ``None`` or an empty string is allowed.

    Returns:
        The stripped text up to and including the last sentence terminator,
        or ``""`` when the text is empty or has no complete sentence.
    """
    import re

    if not text:
        return ""

    trimmed = text.strip()
    last_end = 0
    for match in re.finditer(r"[.!?](?=\s|$)", trimmed):
        last_end = match.end()
    return trimmed[:last_end]
211
-
212
def _caption_text_from_outputs(outputs):
    """Pull the caption string out of whichever output shape the model returned."""
    if hasattr(outputs, 'generated_text'):
        generated = outputs.generated_text
        if isinstance(generated, list):
            # An empty result yields an empty caption, not an IndexError.
            return generated[0] if generated else ""
        return generated
    if isinstance(outputs, list):
        return outputs[0] if outputs else ""
    if isinstance(outputs, str):
        return outputs
    return str(outputs)


@spaces.GPU(duration=120)
def caption_image(image):
    """Generate a caption for ``image`` with the module-level ``model``.

    Args:
        image: The image passed to the model as a one-element batch.

    Returns:
        The caption text. Failures are not raised: they are printed and
        returned as a string starting with ``"Error generating caption:"``.
    """
    try:
        # Ensure the model is on the GPU when one is available.
        if torch.cuda.is_available():
            if hasattr(model, 'device') and model.device.type != 'cuda':
                model.to("cuda")

        with torch.no_grad():
            try:
                outputs = model([image], 77)
            except RuntimeError as e:
                message = str(e)
                if "CUDA error" not in message and "device-side assert" not in message:
                    # Not a CUDA failure: leave it to the outer handler.
                    raise
                print(f"⚠️ CUDA error: {e}")
                torch.cuda.empty_cache()
                torch.cuda.synchronize()

                # Retry through the generate() entry point.
                outputs = model.generate(images=[image], max_length=77)

        return _caption_text_from_outputs(outputs)

    except Exception as e:
        print(f"Error in caption_image: {e}")
        return f"Error generating caption: {str(e)}"
250
-
251
def _array_to_rgb_image(array):
    """Convert a numpy image array to an RGB ``PIL.Image``."""
    if array.dtype == np.bool_:
        array = array.astype(np.uint8) * 255
    elif np.issubdtype(array.dtype, np.floating):
        # Float images are treated as being in the [0, 1] range.
        array = (np.clip(array, 0, 1) * 255).astype(np.uint8)
    elif array.dtype != np.uint8:
        # NOTE(review): other integer types are assumed to already hold
        # 0-255 values; confirm if 16-bit images need rescaling instead.
        array = np.clip(array, 0, 255).astype(np.uint8)

    if array.ndim == 3 and array.shape[2] == 1:
        # (H, W, 1) -> (H, W) so it is read as grayscale.
        array = array[:, :, 0]

    # Let Pillow infer the mode from the shape, then normalise to RGB.
    return Image.fromarray(array).convert('RGB')


def process_image(image):
    """Normalise the input image to RGB and return its cleaned caption.

    Args:
        image: A ``PIL.Image``, a numpy array, or ``None`` when no image was
            supplied.

    Returns:
        The caption with any incomplete trailing sentence removed, or a
        string starting with ``"Error"`` describing what went wrong.
    """
    if image is None:
        return "Error: no image provided"

    try:
        if isinstance(image, np.ndarray):
            image = _array_to_rgb_image(image)
        elif isinstance(image, Image.Image) and image.mode != 'RGB':
            image = image.convert('RGB')

        raw_text = caption_image(image)

        # caption_image reports failures as text; return that unchanged so
        # the sentence trimming below cannot reduce it to an empty string.
        if isinstance(raw_text, str) and raw_text.startswith("Error generating caption:"):
            return raw_text

        return drop_incomplete_tail(raw_text)

    except Exception as e:
        print(f"Error processing image: {e}")
        return f"Error: {str(e)}"
281
-
282
- # Create Gradio interface
283
- demo = gr.Interface(
284
- fn=process_image,
285
- inputs=gr.Image(type="pil", label="Upload an image"),
286
- outputs=gr.Textbox(label="Generated Caption", lines=3),
287
- title="Vision-Language Image Captioner",
288
- description="Upload an image to generate a detailed caption using Vision-Language-Vision-Captioner-Qwen2.5-3B",
289
- examples=[],
290
- cache_examples=False,
291
- theme=gr.themes.Soft()
292
- )
293
-
294
- # GPU optimizations
295
- if torch.cuda.is_available():
296
- device_name = torch.cuda.get_device_name(0)
297
- print(f"\n🖥️ GPU detected: {device_name}")
298
-
299
- if "H100" in device_name:
300
- torch.backends.cuda.matmul.allow_tf32 = True
301
- torch.backends.cudnn.allow_tf32 = True
302
- torch.cuda.set_per_process_memory_fraction(0.85)
303
-
304
- # Launch the app
305
- if __name__ == "__main__":
306
- print("\n🌐 Launching Gradio interface...")
307
- demo.launch(
308
- share=False,
309
- debug=True,
310
- show_error=True
311
- )
 
 
1
  import gradio as gr
2
  from transformers import AutoModel, AutoProcessor
3
  from PIL import Image
4
  import torch
5
  import numpy as np
 
 
 
 
 
6
 
7
+ model_name_or_path = "lyttt/VLV_captioner"
8
+ model = AutoModel.from_pretrained(model_name_or_path, revision="master", trust_remote_code=True,low_cpu_mem_usage=False)
9
+
10
+ # @spaces.GPU(duration=120)
11
def _drop_incomplete_tail(text):
    """Return ``text`` cut after its last complete sentence (``""`` if none)."""
    import re

    trimmed = text.strip()
    last_end = 0
    # A terminator must be followed by whitespace or the end of the text, so
    # the period in a number such as "3.5" does not end a sentence.
    for match in re.finditer(r"[.!?](?=\s|$)", trimmed):
        last_end = match.end()
    return trimmed[:last_end]


def greet(image):
    """Caption an image with the module-level ``model``.

    Args:
        image: Numpy image array, or ``None`` when no image was submitted.

    Returns:
        The generated caption with any incomplete trailing sentence removed;
        ``""`` when ``image`` is ``None``.
    """
    if image is None:
        return ""

    if image.dtype == np.bool_:
        image = image.astype(np.uint8) * 255
    elif np.issubdtype(image.dtype, np.floating):
        # Float images are treated as being in the [0, 1] range.
        image = (np.clip(image, 0, 1) * 255).astype(np.uint8)
    elif image.dtype != np.uint8:
        # NOTE(review): other integer types are assumed to already hold
        # 0-255 values; confirm if 16-bit images need rescaling instead.
        image = np.clip(image, 0, 255).astype(np.uint8)

    # Let Pillow infer the mode from the array shape so grayscale and RGBA
    # inputs are accepted, then normalise to RGB.
    image = Image.fromarray(image).convert('RGB')

    with torch.no_grad():
        caption = model([image], 300).generated_text[0]

    return _drop_incomplete_tail(caption)
24
+
25
+ demo = gr.Interface(fn=greet, inputs="image", outputs="text")
26
+ demo.launch()