nguyentantoan commited on
Commit
8afdc67
·
verified ·
1 Parent(s): a7f24f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -189
app.py CHANGED
@@ -1,23 +1,22 @@
1
- # app.py - Fixed version for HF Spaces
2
  import gradio as gr
3
  import torch
4
  from transformers import AutoModel, AutoTokenizer
5
  import torchvision.transforms as T
6
  from torchvision.transforms.functional import InterpolationMode
7
  from PIL import Image
8
- import base64
9
- import io
10
  import time
 
11
  import traceback
12
 
13
  # Setup
14
- device = "cpu" # HF Spaces miễn phí chỉ có CPU
15
  model = None
16
  tokenizer = None
17
  transform = None
18
 
19
  def build_transform(input_size=448):
20
- """Build image transform pipeline"""
21
  IMAGENET_MEAN = (0.485, 0.456, 0.406)
22
  IMAGENET_STD = (0.229, 0.224, 0.225)
23
 
@@ -29,20 +28,19 @@ def build_transform(input_size=448):
29
  ])
30
 
31
  def load_model():
32
- """Load Vintern model"""
33
  global model, tokenizer, transform
34
  try:
35
- print("🤖 Loading Vintern-1B-v3.5...")
36
 
37
- model_name = "5CD-AI/Vintern-1B-v3_5"
 
38
 
39
- # Load tokenizer
40
  tokenizer = AutoTokenizer.from_pretrained(
41
  model_name,
42
  trust_remote_code=True
43
  )
44
 
45
- # Load model
46
  model = AutoModel.from_pretrained(
47
  model_name,
48
  torch_dtype=torch.float32,
@@ -50,240 +48,119 @@ def load_model():
50
  low_cpu_mem_usage=True
51
  )
52
 
53
- # Build transform
 
 
 
54
  transform = build_transform()
55
 
56
- print("✅ Model loaded successfully!")
57
  return True
58
 
59
  except Exception as e:
60
- print(f"❌ Error loading model: {e}")
61
  traceback.print_exc()
62
  return False
63
 
64
- def safe_image_processing(image):
65
- """Safely process image input"""
66
- try:
67
- # Handle different input types
68
- if image is None:
69
- return None, "❌ Không có ảnh đầu vào"
70
-
71
- # If it's a file path (string)
72
- if isinstance(image, str):
73
- if image.startswith('data:image'):
74
- # Base64 image
75
- image_data = image.split(',')[1]
76
- image_bytes = base64.b64decode(image_data)
77
- image = Image.open(io.BytesIO(image_bytes))
78
- else:
79
- # File path
80
- image = Image.open(image)
81
-
82
- # Ensure it's a PIL Image
83
- if not hasattr(image, 'mode'):
84
- return None, "❌ Định dạng ảnh không hợp lệ"
85
-
86
- # Convert to RGB if needed
87
- if image.mode != 'RGB':
88
- image = image.convert('RGB')
89
-
90
- return image, None
91
-
92
- except Exception as e:
93
- return None, f"❌ Lỗi xử lý ảnh: {str(e)}"
94
-
95
- def analyze_image(image):
96
- """Analyze image with Vintern model"""
97
  if model is None:
98
- return "❌ Model chưa được tải. Vui lòng chờ..."
99
 
100
  try:
101
  start_time = time.time()
102
 
103
- # Safe image processing
104
- processed_image, error = safe_image_processing(image)
105
- if error:
106
- return error
107
 
108
- if processed_image is None:
109
- return "❌ Không thể xử lý ảnh đầu vào"
110
 
111
- # Transform image
112
- image_tensor = transform(processed_image).unsqueeze(0).to(device)
113
 
114
  with torch.no_grad():
115
- # Main description
116
- query = "Mô tả chi tiết những gì bạn thấy trong hình ảnh này:"
117
 
118
  try:
119
- description = model.chat(
120
  tokenizer,
121
  image_tensor,
122
  query,
123
  generation_config=dict(
124
- max_new_tokens=200,
125
- do_sample=True,
126
  temperature=0.7,
127
- top_p=0.9,
128
- repetition_penalty=1.1
129
  )
130
  )
131
- except Exception as chat_error:
132
- print(f"Chat method failed: {chat_error}")
133
- # Fallback to simple generation
134
  inputs = tokenizer(query, return_tensors="pt").to(device)
135
  outputs = model.generate(
136
  **inputs,
137
- max_new_tokens=150,
138
- do_sample=True,
139
- temperature=0.7
140
- )
141
- description = tokenizer.decode(outputs[0], skip_special_tokens=True)
142
- description = description.replace(query, "").strip()
143
-
144
- # Get objects
145
- try:
146
- object_query = "Liệt kê các đối tượng chính trong ảnh:"
147
- objects_text = model.chat(
148
- tokenizer,
149
- image_tensor,
150
- object_query,
151
- generation_config=dict(max_new_tokens=100, temperature=0.5)
152
  )
153
- objects = [obj.strip() for obj in objects_text.replace(',', ' ').split() if len(obj.strip()) > 2][:5]
154
- objects_str = ", ".join(objects) if objects else "Không có"
155
- except:
156
- objects_str = "Không có"
157
 
158
  processing_time = time.time() - start_time
159
 
160
- # Format output
161
- return f"""**📝 Mô tả từ Vintern AI:**
162
- {description}
163
-
164
- **🔍 Đối tượng nhận diện:**
165
- {objects_str}
166
 
167
- **⚡ Thời gian xử lý:** {processing_time:.2f}s
168
- **🤖 Model:** Vintern-1B-v3.5 (Hugging Face Spaces)
169
- **💻 Device:** {device.upper()}
170
 
171
  ---
172
- *Để sử dụng cho video real-time, hãy sử dụng API endpoint của Space này với trangchu.html*
173
  """
174
 
175
  except Exception as e:
176
- error_msg = f"❌ Lỗi phân tích: {str(e)}"
177
- print(error_msg)
178
- traceback.print_exc()
179
- return error_msg
180
 
181
- def analyze_for_api(image_file):
182
- """API-friendly analysis function"""
183
- try:
184
- result = analyze_image(image_file)
185
- # Return simple text for API consumption
186
- return result
187
- except Exception as e:
188
- return f"Error: {str(e)}"
189
-
190
- # Load model when starting
191
- print("🚀 Initializing Vintern-1B-v3.5...")
192
  model_loaded = load_model()
193
 
194
- # Create Gradio interface
195
  with gr.Blocks(
196
- title="Vintern-1B-v3.5 Video Recognition",
197
- theme=gr.themes.Soft(),
198
- css="""
199
- .gradio-container {
200
- max-width: 1200px !important;
201
- }
202
- .upload-area {
203
- min-height: 300px;
204
- }
205
- """
206
  ) as demo:
207
 
208
- gr.Markdown("""
209
- # 🎥 Vintern-1B-v3.5 - Nhận Diện Ảnh Tiếng Việt
210
-
211
- **Powered by Hugging Face Spaces** | Model chạy hoàn toàn trên cloud
212
- """)
213
 
214
- if not model_loaded:
215
- gr.Markdown("⚠️ **Model đang được tải...** Vui lòng chờ vài phút refresh trang.")
216
- else:
217
- gr.Markdown("✅ **Model đã sẵn sàng!** Upload ảnh để bắt đầu nhận diện.")
218
 
219
  with gr.Row():
220
- with gr.Column(scale=1):
221
- image_input = gr.Image(
222
- type="pil",
223
- label="📤 Upload Ảnh",
224
- elem_classes=["upload-area"]
225
- )
226
-
227
- with gr.Row():
228
- analyze_btn = gr.Button("🔍 Phân Tích", variant="primary", scale=2)
229
- clear_btn = gr.Button("🗑️ Xóa", variant="secondary", scale=1)
230
-
231
- with gr.Column(scale=1):
232
- result_output = gr.Textbox(
233
- label="📋 Kết Quả Phân Tích",
234
- lines=15,
235
- max_lines=20,
236
- show_copy_button=True
237
- )
238
-
239
- # Event handlers
240
- analyze_btn.click(
241
- fn=analyze_image,
242
- inputs=image_input,
243
- outputs=result_output
244
- )
245
-
246
- clear_btn.click(
247
- fn=lambda: (None, ""),
248
- outputs=[image_input, result_output]
249
- )
250
 
251
- # Auto-analyze on image upload
252
  image_input.change(
253
- fn=analyze_image,
254
  inputs=image_input,
255
  outputs=result_output
256
  )
257
 
258
  gr.Markdown("""
259
- ---
260
- ## 💡 Hướng dẫn sử dụng:
261
-
262
- ### 🖼️ Phân tích ảnh đơn lẻ:
263
- 1. **Upload ảnh** từ máy tính hoặc drag & drop
264
- 2. **Kết quả tự động** hiển thị sau khi upload
265
- 3. **Hoặc nhấn "Phân Tích"** để chạy lại
266
-
267
- ### 🎥 Phân tích video real-time:
268
- 1. **Copy URL Space này:** `https://nguyentantoan-vintern-video-recognition.hf.space`
269
- 2. **Mở trangchu.html** đã được cung cấp
270
- 3. **Thay URL** trong code JavaScript
271
- 4. **Sử dụng camera** để phân tích real-time
272
-
273
- ### 🔗 API Usage:
274
- ```javascript
275
- // POST to: https://nguyentantoan-vintern-video-recognition.hf.space/api/predict
276
- // Body: FormData with image file
277
- ```
278
-
279
- **⚠️ Lưu ý:** Việc phân tích có thể mất 10-30 giây do chạy trên CPU miễn phí của HF Spaces.
280
  """)
281
 
282
- # Launch the app
283
  if __name__ == "__main__":
284
- demo.launch(
285
- server_name="0.0.0.0",
286
- server_port=7860,
287
- show_error=True,
288
- quiet=False
289
- )
 
1
+ # app_fast.py - Vintern-1B Fast Version
2
  import gradio as gr
3
  import torch
4
  from transformers import AutoModel, AutoTokenizer
5
  import torchvision.transforms as T
6
  from torchvision.transforms.functional import InterpolationMode
7
  from PIL import Image
 
 
8
  import time
9
+ import json
10
  import traceback
11
 
12
  # Setup
13
+ device = "cpu"
14
  model = None
15
  tokenizer = None
16
  transform = None
17
 
18
  def build_transform(input_size=448):
19
+ """Optimized transform"""
20
  IMAGENET_MEAN = (0.485, 0.456, 0.406)
21
  IMAGENET_STD = (0.229, 0.224, 0.225)
22
 
 
28
  ])
29
 
30
  def load_model():
31
+ """Load Vintern-1B (faster version)"""
32
  global model, tokenizer, transform
33
  try:
34
+ print("🚀 Loading Vintern-1B (Fast Version)...")
35
 
36
+ # Sử dụng model nhẹ hơn
37
+ model_name = "5CD-AI/Vintern-1B-v2" # Thay vì v3.5
38
 
 
39
  tokenizer = AutoTokenizer.from_pretrained(
40
  model_name,
41
  trust_remote_code=True
42
  )
43
 
 
44
  model = AutoModel.from_pretrained(
45
  model_name,
46
  torch_dtype=torch.float32,
 
48
  low_cpu_mem_usage=True
49
  )
50
 
51
+ # Optimize model for inference
52
+ model.eval()
53
+ model = torch.jit.optimize_for_inference(model)
54
+
55
  transform = build_transform()
56
 
57
+ print("✅ Fast model loaded!")
58
  return True
59
 
60
  except Exception as e:
61
+ print(f"❌ Error: {e}")
62
  traceback.print_exc()
63
  return False
64
 
65
+ def fast_analyze(image):
66
+ """Optimized analysis function"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  if model is None:
68
+ return "❌ Model chưa sẵn sàng"
69
 
70
  try:
71
  start_time = time.time()
72
 
73
+ # Quick image processing
74
+ if image is None:
75
+ return "❌ Không có ảnh"
 
76
 
77
+ if hasattr(image, 'mode') and image.mode != 'RGB':
78
+ image = image.convert('RGB')
79
 
80
+ # Fast transform
81
+ image_tensor = transform(image).unsqueeze(0).to(device)
82
 
83
  with torch.no_grad():
84
+ # Shorter, faster generation
85
+ query = "Mô tả ngắn gọn:"
86
 
87
  try:
88
+ result = model.chat(
89
  tokenizer,
90
  image_tensor,
91
  query,
92
  generation_config=dict(
93
+ max_new_tokens=100, # Ngắn hơn → nhanh hơn
94
+ do_sample=False, # Greedy → nhanh hơn
95
  temperature=0.7,
96
+ num_beams=1 # No beam search → nhanh hơn
 
97
  )
98
  )
99
+ except:
100
+ # Fallback nhanh
 
101
  inputs = tokenizer(query, return_tensors="pt").to(device)
102
  outputs = model.generate(
103
  **inputs,
104
+ max_new_tokens=80,
105
+ do_sample=False,
106
+ num_beams=1
 
 
 
 
 
 
 
 
 
 
 
 
107
  )
108
+ result = tokenizer.decode(outputs[0], skip_special_tokens=True)
109
+ result = result.replace(query, "").strip()
 
 
110
 
111
  processing_time = time.time() - start_time
112
 
113
+ return f"""**📝 Mô tả nhanh:**
114
+ {result}
 
 
 
 
115
 
116
+ **⚡ Thời gian:** {processing_time:.1f}s
117
+ **🤖 Model:** Vintern-1B-v2 (Optimized)
118
+ **💨 Tốc độ:** {1/processing_time:.1f} FPS
119
 
120
  ---
121
+ *Model được tối ưu cho tốc độ - phù hợp real-time*
122
  """
123
 
124
  except Exception as e:
125
+ return f"❌ Lỗi: {str(e)}"
 
 
 
126
 
127
+ # Load model
128
+ print("🚀 Starting Fast Vintern Server...")
 
 
 
 
 
 
 
 
 
129
  model_loaded = load_model()
130
 
131
+ # Lightweight Gradio interface
132
  with gr.Blocks(
133
+ title="Vintern-1B Fast",
134
+ theme=gr.themes.Base(),
 
 
 
 
 
 
 
 
135
  ) as demo:
136
 
137
+ gr.Markdown("# ⚡ Vintern-1B Fast - Tốc Độ Cao")
 
 
 
 
138
 
139
+ if model_loaded:
140
+ gr.Markdown(" **Model sẵn sàng!** Tối ưu cho tốc độreal-time.")
 
 
141
 
142
  with gr.Row():
143
+ image_input = gr.Image(type="pil", label="📤 Upload Ảnh")
144
+ result_output = gr.Textbox(
145
+ label="📋 Kết Quả",
146
+ lines=8,
147
+ show_copy_button=True
148
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
 
150
+ # Auto-analyze on upload
151
  image_input.change(
152
+ fn=fast_analyze,
153
  inputs=image_input,
154
  outputs=result_output
155
  )
156
 
157
  gr.Markdown("""
158
+ ### ⚡ Tối ưu cho tốc độ:
159
+ - **Model nhẹ**: Vintern-1B-v2 (~1.5GB)
160
+ - **Fast generation**: Greedy decode, short output
161
+ - **Optimized**: JIT compilation, no beam search
162
+ - **Real-time ready**: ~2-5 giây/ảnh
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  """)
164
 
 
165
  if __name__ == "__main__":
166
+ demo.launch(server_name="0.0.0.0", server_port=7860)