Víctor Sáez commited on
Commit
4a473ee
·
1 Parent(s): 0b1e00c

Restirung Adding multilenguage support

Browse files
Files changed (1) hide show
  1. app.py +189 -133
app.py CHANGED
@@ -2,19 +2,17 @@ import gradio as gr
2
  import torch
3
  from PIL import Image, ImageDraw, ImageFont
4
  from transformers import DetrImageProcessor, DetrForObjectDetection
 
 
5
 
6
- # Only import pipeline if translation is enabled
7
- ENABLE_TRANSLATION = False # Cambia a True solo si puedes cargar modelos Helsinki localmente
8
-
9
- if ENABLE_TRANSLATION:
10
- from transformers import pipeline
11
-
12
- # Global variables
13
  current_model = None
14
  current_processor = None
15
  current_model_name = None
16
 
 
17
  available_models = {
 
18
  "DETR ResNet-50": "facebook/detr-resnet-50",
19
  "DETR ResNet-101": "facebook/detr-resnet-101",
20
  "DETR DC5": "facebook/detr-resnet-50-dc5",
@@ -23,23 +21,37 @@ available_models = {
23
 
24
 
25
  def load_model(model_key):
 
26
  global current_model, current_processor, current_model_name
 
27
  model_name = available_models[model_key]
 
 
28
  if current_model_name != model_name:
29
  print(f"Loading model: {model_name}")
30
  current_processor = DetrImageProcessor.from_pretrained(model_name)
31
  current_model = DetrForObjectDetection.from_pretrained(model_name)
32
  current_model_name = model_name
 
 
 
33
  return current_model, current_processor
34
 
35
 
36
- def get_font(size=12):
37
- try:
38
- return ImageFont.truetype("arial.ttf", size=size)
39
- except:
40
- return ImageFont.load_default()
41
 
42
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  translations = {
44
  "English": {
45
  "title": "## Enhanced Object Detection App\nUpload an image to detect objects using various DETR models.",
@@ -91,131 +103,186 @@ def t(language, key):
91
 
92
 
93
  def get_translated_model_choices(language):
 
94
  model_mapping = {
95
  "DETR ResNet-50": "model_fast",
96
  "DETR ResNet-101": "model_precision",
97
  "DETR DC5": "model_small",
98
  "DETR ResNet-50 Face Only": "model_faces"
99
  }
 
100
  translated_choices = []
101
  for model_key in available_models.keys():
102
  if model_key in model_mapping:
103
  translation_key = model_mapping[model_key]
104
  translated_name = t(language, translation_key)
105
  else:
106
- translated_name = model_key
107
  translated_choices.append(translated_name)
 
108
  return translated_choices
109
 
110
 
111
  def get_model_key_from_translation(translated_name, language):
 
112
  model_mapping = {
113
  "DETR ResNet-50": "model_fast",
114
  "DETR ResNet-101": "model_precision",
115
  "DETR DC5": "model_small",
116
  "DETR ResNet-50 Face Only": "model_faces"
117
  }
 
 
118
  for model_key, translation_key in model_mapping.items():
119
  if t(language, translation_key) == translated_name:
120
  return model_key
 
 
121
  if translated_name in available_models:
122
  return translated_name
 
 
123
  return "DETR ResNet-50"
124
 
125
 
126
- # Translation logic (only if ENABLE_TRANSLATION and model is local)
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  translation_cache = {}
128
 
129
 
 
130
  def translate_label(language_label, label):
131
- if language_label == "English" or not ENABLE_TRANSLATION:
132
- return label
133
  cache_key = f"{language_label}_{label}"
134
  if cache_key in translation_cache:
135
  return translation_cache[cache_key]
136
- # Dummy fallback in Spaces, or if not preloaded, just warn
137
- translation_cache[cache_key] = f"{label} (no translation)"
138
- return translation_cache[cache_key]
139
 
 
 
 
140
 
141
- def detect_objects(image, language_selector, translated_model_selector, threshold):
142
  try:
143
- if image is None:
144
- return None, "Please upload an image before detecting objects."
145
- model_selector = get_model_key_from_translation(translated_model_selector, language_selector)
146
- model, processor = load_model(model_selector)
147
- inputs = processor(images=image, return_tensors="pt")
148
- outputs = model(**inputs)
149
- target_sizes = torch.tensor([image.size[::-1]])
150
- results = processor.post_process_object_detection(
151
- outputs, threshold=threshold, target_sizes=target_sizes
152
- )[0]
153
- image_with_boxes = image.copy()
154
- draw = ImageDraw.Draw(image_with_boxes)
155
- detection_info = f"Detected {len(results['scores'])} objects with threshold {threshold}\n"
156
- detection_info += f"Model: {translated_model_selector} ({model_selector})\n\n"
157
- colors = {
158
- 'high': 'red',
159
- 'medium': 'orange',
160
- 'low': 'yellow'
161
- }
162
- detected_objects = []
163
- for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
164
- confidence = score.item()
165
- box = [round(x, 2) for x in box.tolist()]
166
- if confidence > 0.8:
167
- color = colors['high']
168
- elif confidence > 0.5:
169
- color = colors['medium']
170
- else:
171
- color = colors['low']
172
- draw.rectangle(box, outline=color, width=3)
173
- label_text = model.config.id2label[label.item()]
174
- translated_label = translate_label(language_selector, label_text)
175
- display_text = f"{translated_label}: {round(confidence, 3)}"
176
- detected_objects.append({
177
- 'label': label_text,
178
- 'translated': translated_label,
179
- 'confidence': confidence,
180
- 'box': box
181
- })
182
- try:
183
- image_width = image.size[0]
184
- font_size = max(image_width // 40, 12)
185
- font = get_font(font_size)
186
- text_bbox = draw.textbbox((0, 0), display_text, font=font)
187
- text_width = text_bbox[2] - text_bbox[0]
188
- text_height = text_bbox[3] - text_bbox[1]
189
- except:
190
- font = get_font(12)
191
- text_width = 50
192
- text_height = 20
193
- text_bg = [
194
- box[0], box[1] - text_height - 4,
195
- box[0] + text_width + 4, box[1]
196
- ]
197
- draw.rectangle(text_bg, fill="black")
198
- draw.text((box[0] + 2, box[1] - text_height - 2), display_text, fill="white", font=font)
199
- if detected_objects:
200
- detection_info += "Objects found:\n"
201
- for obj in sorted(detected_objects, key=lambda x: x['confidence'], reverse=True):
202
- detection_info += f"- {obj['translated']} ({obj['label']}): {obj['confidence']:.3f}\n"
203
- else:
204
- detection_info += "No objects detected. Try lowering the threshold."
205
- return image_with_boxes, detection_info
206
  except Exception as e:
207
- import traceback
208
- print("ERROR EN DETECT_OBJECTS:", e)
209
- traceback.print_exc()
210
- return None, f"Error detecting objects: {e}"
211
 
212
 
213
- def build_app():
214
- # Crear componentes con referencias globales
215
- title = gr.Markdown(t("English", "title"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
  with gr.Blocks(theme=gr.themes.Soft()) as app:
218
- title.render()
 
219
 
220
  with gr.Row():
221
  with gr.Column(scale=1):
@@ -227,14 +294,14 @@ def build_app():
227
  with gr.Column(scale=1):
228
  model_selector = gr.Dropdown(
229
  choices=get_translated_model_choices("English"),
230
- value=t("English", "model_fast"),
231
  label=t("English", "dropdown_detection_model_label")
232
  )
233
  with gr.Column(scale=1):
234
  threshold_slider = gr.Slider(
235
  minimum=0.1,
236
  maximum=0.95,
237
- value=0.5,
238
  step=0.05,
239
  label=t("English", "threshold_label")
240
  )
@@ -251,56 +318,36 @@ def build_app():
251
  max_lines=15
252
  )
253
 
 
254
  def update_interface(selected_language):
255
- try:
256
- translated_choices = get_translated_model_choices(selected_language)
257
- default_model = t(selected_language, "model_fast")
258
-
259
- # Asegurar que default_model está en las opciones
260
- if default_model not in translated_choices:
261
- default_model = translated_choices[0] if translated_choices else "General Objects (fast)"
262
-
263
- updates = []
264
- updates.append(gr.update(value=t(selected_language, "title"))) # title
265
- updates.append(gr.update(label=t(selected_language, "dropdown_label"))) # language_selector
266
- updates.append(gr.update(
267
  choices=translated_choices,
268
  value=default_model,
269
  label=t(selected_language, "dropdown_detection_model_label")
270
- )) # model_selector
271
- updates.append(gr.update(label=t(selected_language, "threshold_label"))) # threshold_slider
272
- updates.append(gr.update(label=t(selected_language, "input_label"))) # input_image
273
- updates.append(gr.update(value=t(selected_language, "button"))) # button
274
- updates.append(gr.update(label=t(selected_language, "output_label"))) # output_image
275
- updates.append(gr.update(label=t(selected_language, "info_label"))) # detection_info
276
-
277
- return updates
278
- except Exception as e:
279
- print(f"Error in update_interface: {e}")
280
- import traceback
281
- traceback.print_exc()
282
- # Retornar valores por defecto en caso de error
283
- return [
284
- gr.update(), # title
285
- gr.update(), # language_selector
286
- gr.update(), # model_selector
287
- gr.update(), # threshold_slider
288
- gr.update(), # input_image
289
- gr.update(), # button
290
- gr.update(), # output_image
291
- gr.update() # detection_info
292
- ]
293
-
294
- # Configurar el evento de cambio de idioma
295
  language_selector.change(
296
  fn=update_interface,
297
- inputs=[language_selector],
298
  outputs=[title, language_selector, model_selector, threshold_slider,
299
  input_image, button, output_image, detection_info],
300
  queue=False
301
  )
302
 
303
- # Configurar el botón de detección
304
  button.click(
305
  fn=detect_objects,
306
  inputs=[input_image, language_selector, model_selector, threshold_slider],
@@ -310,9 +357,18 @@ def build_app():
310
  return app
311
 
312
 
313
- # Precargar modelo por defecto
 
 
 
 
 
 
 
 
314
  load_model("DETR ResNet-50")
315
 
 
316
  if __name__ == "__main__":
317
  app = build_app()
318
  app.launch()
 
2
  import torch
3
  from PIL import Image, ImageDraw, ImageFont
4
  from transformers import DetrImageProcessor, DetrForObjectDetection
5
+ from pathlib import Path
6
+ import transformers
7
 
8
+ # Global variables to cache models
 
 
 
 
 
 
9
  current_model = None
10
  current_processor = None
11
  current_model_name = None
12
 
13
+ # Available models with better selection
14
  available_models = {
15
+ # DETR Models
16
  "DETR ResNet-50": "facebook/detr-resnet-50",
17
  "DETR ResNet-101": "facebook/detr-resnet-101",
18
  "DETR DC5": "facebook/detr-resnet-50-dc5",
 
21
 
22
 
23
  def load_model(model_key):
24
+ """Load model and processor based on selected model key"""
25
  global current_model, current_processor, current_model_name
26
+
27
  model_name = available_models[model_key]
28
+
29
+ # Only load if it's a different model
30
  if current_model_name != model_name:
31
  print(f"Loading model: {model_name}")
32
  current_processor = DetrImageProcessor.from_pretrained(model_name)
33
  current_model = DetrForObjectDetection.from_pretrained(model_name)
34
  current_model_name = model_name
35
+ print(f"Model loaded: {model_name}")
36
+ print(f"Available labels: {list(current_model.config.id2label.values())}")
37
+
38
  return current_model, current_processor
39
 
40
 
 
 
 
 
 
41
 
42
 
43
+
44
+
45
+ # Load font
46
+ font_path = Path("assets/fonts/arial.ttf")
47
+ if not font_path.exists():
48
+
49
+ print(f"Font file {font_path} not found. Using default font.")
50
+ font = ImageFont.load_default()
51
+ else:
52
+ font = ImageFont.truetype(str(font_path), size=100) # Reduced font size
53
+
54
+ # Set up translations for the app
55
  translations = {
56
  "English": {
57
  "title": "## Enhanced Object Detection App\nUpload an image to detect objects using various DETR models.",
 
103
 
104
 
105
  def get_translated_model_choices(language):
106
+ """Get model choices translated to the selected language"""
107
  model_mapping = {
108
  "DETR ResNet-50": "model_fast",
109
  "DETR ResNet-101": "model_precision",
110
  "DETR DC5": "model_small",
111
  "DETR ResNet-50 Face Only": "model_faces"
112
  }
113
+
114
  translated_choices = []
115
  for model_key in available_models.keys():
116
  if model_key in model_mapping:
117
  translation_key = model_mapping[model_key]
118
  translated_name = t(language, translation_key)
119
  else:
120
+ translated_name = model_key # Fallback to original name
121
  translated_choices.append(translated_name)
122
+
123
  return translated_choices
124
 
125
 
126
  def get_model_key_from_translation(translated_name, language):
127
+ """Get the original model key from translated name"""
128
  model_mapping = {
129
  "DETR ResNet-50": "model_fast",
130
  "DETR ResNet-101": "model_precision",
131
  "DETR DC5": "model_small",
132
  "DETR ResNet-50 Face Only": "model_faces"
133
  }
134
+
135
+ # Reverse lookup
136
  for model_key, translation_key in model_mapping.items():
137
  if t(language, translation_key) == translated_name:
138
  return model_key
139
+
140
+ # If not found, try direct match
141
  if translated_name in available_models:
142
  return translated_name
143
+
144
+ # Default fallback
145
  return "DETR ResNet-50"
146
 
147
 
148
+ def get_helsinki_model(language_label):
149
+ """Returns the Helsinki-NLP model name for translating from English to the selected language."""
150
+ lang_map = {
151
+ "Spanish": "es",
152
+ "French": "fr",
153
+ "English": "en"
154
+ }
155
+ target = lang_map.get(language_label)
156
+ if not target or target == "en":
157
+ return None
158
+ return f"Helsinki-NLP/opus-mt-en-{target}"
159
+
160
+
161
+ # add cache for translations
162
  translation_cache = {}
163
 
164
 
165
+
166
  def translate_label(language_label, label):
167
+ """Translates the given label to the target language."""
168
+ # Check cache first
169
  cache_key = f"{language_label}_{label}"
170
  if cache_key in translation_cache:
171
  return translation_cache[cache_key]
 
 
 
172
 
173
+ model_name = get_helsinki_model(language_label)
174
+ if not model_name:
175
+ return label
176
 
 
177
  try:
178
+ translator = transformers.pipeline("translation", model=model_name)
179
+ result = translator(label, max_length=40)
180
+ translated = result[0]['translation_text']
181
+ # Cache the result
182
+ translation_cache[cache_key] = translated
183
+ return translated
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  except Exception as e:
185
+ print(f"Translation error (429 or other): {e}")
186
+ return label # Return original if translation fails
 
 
187
 
188
 
189
+ def detect_objects(image, language_selector, translated_model_selector, threshold):
190
+ """Enhanced object detection with adjustable threshold and better info"""
191
+ # Get the actual model key from the translated name
192
+ model_selector = get_model_key_from_translation(translated_model_selector, language_selector)
193
+
194
+ print(f"Processing image. Language: {language_selector}, Model: {model_selector}, Threshold: {threshold}")
195
+
196
+ # Load the selected model
197
+ model, processor = load_model(model_selector)
198
+
199
+ # Process the image
200
+ inputs = processor(images=image, return_tensors="pt")
201
+ outputs = model(**inputs)
202
+
203
+ # Convert model output to usable detection results with custom threshold
204
+ target_sizes = torch.tensor([image.size[::-1]])
205
+ results = processor.post_process_object_detection(
206
+ outputs, threshold=threshold, target_sizes=target_sizes
207
+ )[0]
208
+
209
+ # Create a copy of the image for drawing
210
+ image_with_boxes = image.copy()
211
+ draw = ImageDraw.Draw(image_with_boxes)
212
+
213
+ # Detection info
214
+ detection_info = f"Detected {len(results['scores'])} objects with threshold {threshold}\n"
215
+ detection_info += f"Model: {translated_model_selector} ({model_selector})\n\n"
216
+
217
+ # Colors for different confidence levels
218
+ colors = {
219
+ 'high': 'red', # > 0.8
220
+ 'medium': 'orange', # 0.5-0.8
221
+ 'low': 'yellow' # < 0.5
222
+ }
223
+
224
+ detected_objects = []
225
+
226
+ for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
227
+ confidence = score.item()
228
+ box = [round(x, 2) for x in box.tolist()]
229
 
230
+ # Choose color based on confidence
231
+ if confidence > 0.8:
232
+ color = colors['high']
233
+ elif confidence > 0.5:
234
+ color = colors['medium']
235
+ else:
236
+ color = colors['low']
237
+
238
+ # Draw bounding box
239
+ draw.rectangle(box, outline=color, width=3)
240
+
241
+ # Prepare label text
242
+ label_text = model.config.id2label[label.item()]
243
+ translated_label = translate_label(language_selector, label_text)
244
+ display_text = f"{translated_label}: {round(confidence, 3)}"
245
+
246
+ # Store detection info
247
+ detected_objects.append({
248
+ 'label': label_text,
249
+ 'translated': translated_label,
250
+ 'confidence': confidence,
251
+ 'box': box
252
+ })
253
+
254
+ # Calculate text position and size
255
+ try:
256
+ text_bbox = draw.textbbox((0, 0), display_text, font=font)
257
+ text_width = text_bbox[2] - text_bbox[0]
258
+ text_height = text_bbox[3] - text_bbox[1]
259
+ except:
260
+ # Fallback for older PIL versions
261
+ text_width, text_height = draw.textsize(display_text, font=font)
262
+
263
+ # Draw text background
264
+ text_bg = [
265
+ box[0], box[1] - text_height - 4,
266
+ box[0] + text_width + 4, box[1]
267
+ ]
268
+ draw.rectangle(text_bg, fill="black")
269
+ draw.text((box[0] + 2, box[1] - text_height - 2), display_text, fill="white", font=font)
270
+
271
+ # Create detailed detection info
272
+ if detected_objects:
273
+ detection_info += "Objects found:\n"
274
+ for obj in sorted(detected_objects, key=lambda x: x['confidence'], reverse=True):
275
+ detection_info += f"- {obj['translated']} ({obj['label']}): {obj['confidence']:.3f}\n"
276
+ else:
277
+ detection_info += "No objects detected. Try lowering the threshold."
278
+
279
+ return image_with_boxes, detection_info
280
+
281
+
282
+ def build_app():
283
  with gr.Blocks(theme=gr.themes.Soft()) as app:
284
+ with gr.Row():
285
+ title = gr.Markdown(t("English", "title"))
286
 
287
  with gr.Row():
288
  with gr.Column(scale=1):
 
294
  with gr.Column(scale=1):
295
  model_selector = gr.Dropdown(
296
  choices=get_translated_model_choices("English"),
297
+ value=t("English", "model_fast"), # Default to translated "fast" option
298
  label=t("English", "dropdown_detection_model_label")
299
  )
300
  with gr.Column(scale=1):
301
  threshold_slider = gr.Slider(
302
  minimum=0.1,
303
  maximum=0.95,
304
+ value=0.5, # Lowered default threshold
305
  step=0.05,
306
  label=t("English", "threshold_label")
307
  )
 
318
  max_lines=15
319
  )
320
 
321
+ # Function to update interface when language changes
322
  def update_interface(selected_language):
323
+ translated_choices = get_translated_model_choices(selected_language)
324
+ default_model = t(selected_language, "model_fast")
325
+
326
+ return [
327
+ gr.update(value=t(selected_language, "title")),
328
+ gr.update(label=t(selected_language, "dropdown_label")),
329
+ gr.update(
 
 
 
 
 
330
  choices=translated_choices,
331
  value=default_model,
332
  label=t(selected_language, "dropdown_detection_model_label")
333
+ ),
334
+ gr.update(label=t(selected_language, "threshold_label")),
335
+ gr.update(label=t(selected_language, "input_label")),
336
+ gr.update(value=t(selected_language, "button")),
337
+ gr.update(label=t(selected_language, "output_label")),
338
+ gr.update(label=t(selected_language, "info_label"))
339
+ ]
340
+
341
+ # Connect language change event
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
342
  language_selector.change(
343
  fn=update_interface,
344
+ inputs=language_selector,
345
  outputs=[title, language_selector, model_selector, threshold_slider,
346
  input_image, button, output_image, detection_info],
347
  queue=False
348
  )
349
 
350
+ # Connect detection button click event
351
  button.click(
352
  fn=detect_objects,
353
  inputs=[input_image, language_selector, model_selector, threshold_slider],
 
357
  return app
358
 
359
 
360
+
361
+
362
+
363
+
364
+
365
+
366
+
367
+
368
+ # Initialize with default model
369
  load_model("DETR ResNet-50")
370
 
371
+ # Launch the application
372
  if __name__ == "__main__":
373
  app = build_app()
374
  app.launch()