File size: 11,594 Bytes
221337b
6ecfb14
53e14a8
221337b
6ecfb14
c2f455f
 
 
 
 
 
 
2e9147d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221337b
363fa1b
 
 
 
 
2e9147d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c2f455f
2e9147d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c2f455f
2e9147d
6ecfb14
2e9147d
c2f455f
 
2e9147d
 
 
c2f455f
 
 
2e9147d
 
82026bf
 
c2f455f
82026bf
 
 
 
 
 
 
 
 
 
 
 
 
c2f455f
 
 
82026bf
 
 
 
 
 
 
 
 
862343e
82026bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e9147d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c2f455f
2e9147d
 
 
 
 
 
c2f455f
2e9147d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221337b
2e9147d
6ecfb14
221337b
2e9147d
82026bf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
import gradio as gr
import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import DetrImageProcessor, DetrForObjectDetection

# Only import pipeline if translation is enabled
ENABLE_TRANSLATION = False  # Set to True only if the Helsinki models can be loaded locally

if ENABLE_TRANSLATION:
    from transformers import pipeline

# Global variables
# Single-slot cache managed by load_model(): the most recently loaded model,
# its processor, and the identifier they were loaded from.
current_model = None
current_processor = None
current_model_name = None

# Model keys accepted by load_model(), mapped to the identifier that is passed
# to `from_pretrained`.
available_models = {
    "DETR ResNet-50": "facebook/detr-resnet-50",
    "DETR ResNet-101": "facebook/detr-resnet-101",
    "DETR DC5": "facebook/detr-resnet-50-dc5",
    "DETR ResNet-50 Face Only": "esraakh/detr_fine_tune_face_detection_final"
}

def load_model(model_key):
    """Return the ``(model, processor)`` pair for ``model_key``, loading on demand.

    The most recently loaded pair is kept in the module-level globals
    ``current_model`` / ``current_processor`` / ``current_model_name`` and is
    reused as long as the same model is requested again.

    Args:
        model_key: A key of ``available_models``.

    Returns:
        Tuple ``(model, processor)``.

    Raises:
        KeyError: If ``model_key`` is not a key of ``available_models``.
    """
    global current_model, current_processor, current_model_name
    model_name = available_models[model_key]
    if current_model_name != model_name:
        print(f"Loading model: {model_name}")
        # Load into locals first and commit all three globals together. If one
        # of the loads raises, the cache keeps its previous, consistent state
        # instead of pairing the new processor with the old model.
        processor = DetrImageProcessor.from_pretrained(model_name)
        model = DetrForObjectDetection.from_pretrained(model_name)
        current_processor = processor
        current_model = model
        current_model_name = model_name
    return current_model, current_processor

def get_font(size=12):
    """Return a font for drawing labels, preferring Arial at ``size``.

    Falls back to Pillow's built-in default font when ``arial.ttf`` cannot be
    loaded. The fallback is best-effort: this function is not expected to
    raise just because a particular font is unavailable.

    Args:
        size: Requested font size.

    Returns:
        A Pillow font object usable with ``ImageDraw``.
    """
    try:
        return ImageFont.truetype("arial.ttf", size=size)
    except Exception:
        # Deliberately broad (font missing, FreeType unavailable, bad size),
        # but no longer a bare ``except`` so KeyboardInterrupt/SystemExit
        # still propagate.
        pass
    try:
        # Newer Pillow releases accept a size for the default font, which
        # keeps labels readable on large images.
        return ImageFont.load_default(size=size)
    except Exception:
        # Older Pillow (no ``size`` parameter) or no FreeType support: use the
        # fixed-size built-in font.
        return ImageFont.load_default()

# UI strings for each supported interface language. Every language entry uses
# the same set of keys; lookups go through t(), which falls back to the
# English entry for an unknown language and to the key itself for an unknown
# key.
translations = {
    "English": {
        "title": "## Enhanced Object Detection App\nUpload an image to detect objects using various DETR models.",
        "input_label": "Input Image",
        "output_label": "Detected Objects",
        "dropdown_label": "Label Language",
        "dropdown_detection_model_label": "Detection Model",
        "threshold_label": "Detection Threshold",
        "button": "Detect Objects",
        "info_label": "Detection Info",
        "model_fast": "General Objects (fast)",
        "model_precision": "General Objects (high precision)",
        "model_small": "Small Objects/Details (slow)",
        "model_faces": "Face Detection (people only)"
    },
    "Spanish": {
        "title": "## Aplicación Mejorada de Detección de Objetos\nSube una imagen para detectar objetos usando varios modelos DETR.",
        "input_label": "Imagen de entrada",
        "output_label": "Objetos detectados",
        "dropdown_label": "Idioma de las etiquetas",
        "dropdown_detection_model_label": "Modelo de detección",
        "threshold_label": "Umbral de detección",
        "button": "Detectar objetos",
        "info_label": "Información de detección",
        "model_fast": "Objetos generales (rápido)",
        "model_precision": "Objetos generales (precisión alta)",
        "model_small": "Objetos pequeños/detalles (lento)",
        "model_faces": "Detección de caras (solo personas)"
    },
    "French": {
        "title": "## Application Améliorée de Détection d'Objets\nTéléchargez une image pour détecter des objets avec divers modèles DETR.",
        "input_label": "Image d'entrée",
        "output_label": "Objets détectés",
        "dropdown_label": "Langue des étiquettes",
        "dropdown_detection_model_label": "Modèle de détection",
        "threshold_label": "Seuil de détection",
        "button": "Détecter les objets",
        "info_label": "Information de détection",
        "model_fast": "Objets généraux (rapide)",
        "model_precision": "Objets généraux (haute précision)",
        "model_small": "Petits objets/détails (lent)",
        "model_faces": "Détection de visages (personnes uniquement)"
    }
}

def t(language, key):
    """Look up the UI string ``key`` for ``language``.

    An unknown language uses the English table; an unknown key is returned
    unchanged.
    """
    strings = translations.get(language, translations["English"])
    return strings.get(key, key)

def get_translated_model_choices(language):
    """List the model dropdown entries for ``language``.

    One entry is produced per key of ``available_models``, in that order.
    Models with a known translation key are shown by their translated
    description; any other model is shown by its raw key.
    """
    description_keys = {
        "DETR ResNet-50": "model_fast",
        "DETR ResNet-101": "model_precision",
        "DETR DC5": "model_small",
        "DETR ResNet-50 Face Only": "model_faces"
    }
    return [
        t(language, description_keys[name]) if name in description_keys else name
        for name in available_models
    ]

def get_model_key_from_translation(translated_name, language):
    """Map a dropdown entry shown in ``language`` back to its model key.

    Resolution order: a model whose translated description equals
    ``translated_name``; then ``translated_name`` itself if it already is a
    key of ``available_models``; otherwise the default ``"DETR ResNet-50"``.
    """
    description_keys = {
        "DETR ResNet-50": "model_fast",
        "DETR ResNet-101": "model_precision",
        "DETR DC5": "model_small",
        "DETR ResNet-50 Face Only": "model_faces"
    }
    matched = next(
        (
            name
            for name, description_key in description_keys.items()
            if t(language, description_key) == translated_name
        ),
        None,
    )
    if matched is not None:
        return matched
    return translated_name if translated_name in available_models else "DETR ResNet-50"

# Label translation (active only when ENABLE_TRANSLATION is set).
# Results are memoised here under "<language>_<label>".
translation_cache = {}

def translate_label(language_label, label):
    """Return ``label`` rendered for ``language_label``.

    English, or translation being disabled, returns the label untouched.
    Otherwise a cached result is reused; when there is none, a placeholder
    marking the label as untranslated is cached and returned.
    """
    if language_label == "English" or not ENABLE_TRANSLATION:
        return label
    # No translation backend is wired in here, so the stored value is only a
    # marker that the label was left untranslated.
    return translation_cache.setdefault(
        f"{language_label}_{label}", f"{label} (no translation)"
    )

def _confidence_color(confidence):
    """Pick the outline colour for a detection from its confidence score."""
    if confidence > 0.8:
        return 'red'
    if confidence > 0.5:
        return 'orange'
    return 'yellow'


def detect_objects(image, language_selector, translated_model_selector, threshold):
    """Run the selected DETR model on ``image`` and draw what it finds.

    Args:
        image: PIL image to analyse, or ``None`` when nothing was uploaded.
        language_selector: UI language name; used to resolve the model choice
            and to translate the detected labels.
        translated_model_selector: Model dropdown entry as displayed in that
            language.
        threshold: Minimum score a detection needs to be kept.

    Returns:
        Tuple ``(annotated_image, report)``. ``annotated_image`` is a copy of
        the input with a coloured box (red above 0.8, orange above 0.5, yellow
        otherwise) and a text label per detection. It is ``None`` when no
        image was supplied or when detection failed, in which case ``report``
        carries the message to show to the user.
    """
    if image is None:
        return None, "Please upload an image before detecting objects."
    try:
        model_selector = get_model_key_from_translation(translated_model_selector, language_selector)
        model, processor = load_model(model_selector)
        # Coloured outlines need an RGB canvas; normalise other modes
        # (greyscale, palette, RGBA) up front.
        if image.mode != "RGB":
            image = image.convert("RGB")
        inputs = processor(images=image, return_tensors="pt")
        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            outputs = model(**inputs)
        # PIL reports (width, height); post-processing expects (height, width).
        target_sizes = torch.tensor([image.size[::-1]])
        results = processor.post_process_object_detection(
            outputs, threshold=threshold, target_sizes=target_sizes
        )[0]
        image_with_boxes = image.copy()
        draw = ImageDraw.Draw(image_with_boxes)
        report = [
            f"Detected {len(results['scores'])} objects with threshold {threshold}\n",
            f"Model: {translated_model_selector} ({model_selector})\n\n",
        ]
        # The label font scales with the image width and is the same for every
        # detection, so build it once instead of once per box.
        label_font = get_font(max(image.size[0] // 40, 12))
        detected_objects = []
        for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
            confidence = score.item()
            box = [round(x, 2) for x in box.tolist()]
            draw.rectangle(box, outline=_confidence_color(confidence), width=3)
            label_text = model.config.id2label[label.item()]
            translated_label = translate_label(language_selector, label_text)
            display_text = f"{translated_label}: {round(confidence, 3)}"
            detected_objects.append({
                'label': label_text,
                'translated': translated_label,
                'confidence': confidence,
                'box': box
            })
            font = label_font
            try:
                left, top, right, bottom = draw.textbbox((0, 0), display_text, font=font)
                text_width = right - left
                text_height = bottom - top
            except Exception:
                # Measuring can fail for some fonts; fall back to a small font
                # and a fixed-size background rather than dropping the label.
                font = get_font(12)
                text_width = 50
                text_height = 20
            # The label strip normally sits directly above the box. When that
            # would leave the canvas, draw it just inside the top of the box.
            text_top = box[1] - text_height - 4
            if text_top < 0:
                text_top = max(box[1], 0)
            text_bg = [
                box[0], text_top,
                box[0] + text_width + 4, text_top + text_height + 4
            ]
            draw.rectangle(text_bg, fill="black")
            draw.text((box[0] + 2, text_top + 2), display_text, fill="white", font=font)
        if detected_objects:
            report.append("Objects found:\n")
            ranked = sorted(detected_objects, key=lambda obj: obj['confidence'], reverse=True)
            report.extend(
                f"- {obj['translated']} ({obj['label']}): {obj['confidence']:.3f}\n"
                for obj in ranked
            )
        else:
            report.append("No objects detected. Try lowering the threshold.")
        return image_with_boxes, "".join(report)
    except Exception as e:
        # UI boundary: log the full traceback and return the error as text so
        # the interface shows a message instead of failing.
        import traceback
        print("ERROR EN DETECT_OBJECTS:", e)
        traceback.print_exc()
        return None, f"Error detecting objects: {e}"

def build_app():
    """Assemble the Gradio interface and wire up its callbacks.

    The UI starts in English. Changing the language dropdown relabels every
    component and resets the model dropdown to that language's default entry;
    the button runs ``detect_objects`` on the uploaded image.

    Returns:
        The ``gr.Blocks`` application (not yet launched).
    """
    def en(key):
        # Initial (English) text for a component.
        return t("English", key)

    with gr.Blocks(theme=gr.themes.Soft()) as app:
        with gr.Row():
            heading = gr.Markdown(en("title"))

        # Controls: language, model and threshold side by side.
        with gr.Row():
            with gr.Column(scale=1):
                language_dropdown = gr.Dropdown(
                    label=en("dropdown_label"),
                    value="English",
                    choices=["English", "Spanish", "French"],
                )
            with gr.Column(scale=1):
                model_dropdown = gr.Dropdown(
                    label=en("dropdown_detection_model_label"),
                    value=en("model_fast"),
                    choices=get_translated_model_choices("English"),
                )
            with gr.Column(scale=1):
                threshold_control = gr.Slider(
                    label=en("threshold_label"),
                    value=0.5,
                    minimum=0.1,
                    maximum=0.95,
                    step=0.05,
                )

        # Input on the left, results on the right.
        with gr.Row():
            with gr.Column(scale=1):
                source_image = gr.Image(type="pil", label=en("input_label"))
                run_button = gr.Button(en("button"), variant="primary")
            with gr.Column(scale=1):
                annotated_image = gr.Image(label=en("output_label"))
                report_box = gr.Textbox(
                    label=en("info_label"),
                    lines=10,
                    max_lines=15,
                )

        def update_interface(selected_language):
            # One update per component, in the same order as the `outputs`
            # list of the change handler below.
            def tr(key):
                return t(selected_language, key)

            model_dropdown_update = gr.update(
                choices=get_translated_model_choices(selected_language),
                value=tr("model_fast"),
                label=tr("dropdown_detection_model_label"),
            )
            return [
                gr.update(value=tr("title")),
                gr.update(label=tr("dropdown_label")),
                model_dropdown_update,
                gr.update(label=tr("threshold_label")),
                gr.update(label=tr("input_label")),
                gr.update(value=tr("button")),
                gr.update(label=tr("output_label")),
                gr.update(label=tr("info_label")),
            ]

        relabelled_components = [
            heading,
            language_dropdown,
            model_dropdown,
            threshold_control,
            source_image,
            run_button,
            annotated_image,
            report_box,
        ]
        language_dropdown.change(
            fn=update_interface,
            inputs=language_dropdown,
            outputs=relabelled_components,
            queue=False,
        )
        run_button.click(
            fn=detect_objects,
            inputs=[source_image, language_dropdown, model_dropdown, threshold_control],
            outputs=[annotated_image, report_box],
        )
    return app

# Load the default model as soon as the module is imported; load_model() keeps
# it cached, so a later request for the same model reuses it.
load_model("DETR ResNet-50")

# Build and serve the UI only when this file is executed as a script.
if __name__ == "__main__":
    app = build_app()
    app.launch()