"""Gradio demo that runs several skin-lesion classifiers on one image.

The uploaded image is passed through two ViT checkpoints, two local fast.ai
learners and an EfficientNet-B7; the results are combined into an HTML report
with an automatic risk recommendation and a bar chart of the base ViT
probabilities.
"""

import base64
import html
import io

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from efficientnet_pytorch import EfficientNet
from fastai.learner import load_learner
from fastai.vision.core import PILImage
from PIL import Image
from torchvision import transforms
from transformers import ViTForImageClassification, ViTImageProcessor

# Charts are only rendered to an in-memory PNG inside the Gradio callback, so
# a non-interactive backend is sufficient and avoids GUI-toolkit problems.
plt.switch_backend("Agg")

# --- Load the pretrained ViT fine-tuned on HAM10000 ---
TF_MODEL_NAME = "Anwarkh1/Skin_Cancer-Image_Classification"
feature_extractor_tf = ViTImageProcessor.from_pretrained(TF_MODEL_NAME)
model_tf_vit = ViTForImageClassification.from_pretrained(TF_MODEL_NAME)
model_tf_vit.eval()

# --- Load the base ViT model ---
MODEL_NAME = "ahishamm/vit-base-HAM-10000-sharpened-patch-32"
feature_extractor = ViTImageProcessor.from_pretrained(MODEL_NAME)
model_vit = ViTForImageClassification.from_pretrained(MODEL_NAME)
model_vit.eval()

# --- Load the local fast.ai learners ---
# NOTE: load_learner unpickles the file; only load .pkl files from a trusted
# source.
model_malignancy = load_learner("ada_learn_malben.pkl")
model_norm2000 = load_learner("ada_learn_skin_norm2000.pkl")

# --- EfficientNet-B7 for the binary task (benign vs malignant) ---
# NOTE(review): no task-specific weights are loaded here; with num_classes=2
# the classification head presumably stays randomly initialised, which would
# make prob_eff (and the recommendation that reads it) unreliable. Confirm and
# load fine-tuned weights if they exist.
model_eff = EfficientNet.from_pretrained("efficientnet-b7", num_classes=2)
model_eff.eval()
eff_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Standard HAM10000 classes (display names).
# NOTE(review): list positions are used directly as model output indices for
# BOTH ViT checkpoints; verify the order against each model.config.id2label.
CLASSES = [
    "Queratosis actínica / Bowen",
    "Carcinoma células basales",
    "Lesión queratósica benigna",
    "Dermatofibroma",
    "Melanoma maligno",
    "Nevus melanocítico",
    "Lesión vascular",
]

# Per-class risk level, chart colour and weight used for the risk score.
RISK_LEVELS = {
    0: {'level': 'Moderado', 'color': '#ffaa00', 'weight': 0.6},
    1: {'level': 'Alto', 'color': '#ff4444', 'weight': 0.8},
    2: {'level': 'Bajo', 'color': '#44ff44', 'weight': 0.1},
    3: {'level': 'Bajo', 'color': '#44ff44', 'weight': 0.1},
    4: {'level': 'Crítico', 'color': '#cc0000', 'weight': 1.0},
    5: {'level': 'Bajo', 'color': '#44ff44', 'weight': 0.1},
    6: {'level': 'Bajo', 'color': '#44ff44', 'weight': 0.1},
}

MALIGNANT_INDICES = [0, 1, 4]  # akiec, bcc, melanoma


def _vit_probabilities(processor, model, img):
    """Return the softmax probabilities of ``model`` for ``img`` as a 1-D array."""
    inputs = processor(img, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.logits.softmax(dim=-1).cpu().numpy()[0]


def _chart_html(probs):
    """Render the per-class probabilities as a bar chart embedded in an <img> tag."""
    colors = [RISK_LEVELS[i]['color'] for i in range(len(CLASSES))]
    fig, ax = plt.subplots(figsize=(8, 3))
    try:
        ax.bar(CLASSES, probs * 100, color=colors)
        ax.set_title("Probabilidad ViT base por tipo de lesión")
        ax.set_ylabel("Probabilidad (%)")
        ax.set_xticks(np.arange(len(CLASSES)))
        ax.set_xticklabels(CLASSES, rotation=45, ha='right')
        ax.grid(axis='y', alpha=0.2)
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
    finally:
        # Always release the figure, even if rendering fails.
        plt.close(fig)
    encoded = base64.b64encode(buf.getvalue()).decode("ascii")
    return (
        f'<img src="data:image/png;base64,{encoded}" '
        f'alt="Probabilidad ViT base por tipo de lesión" style="max-width:100%"/>'
    )


def _recommendation(prob_malign, risk, prob_eff):
    """Map the three risk signals to the recommendation text shown to the user."""
    if prob_malign > 0.7 or risk > 0.6 or prob_eff > 0.7:
        return "🚨 CRÍTICO – Derivación urgente a oncología dermatológica"
    if prob_malign > 0.4 or risk > 0.4 or prob_eff > 0.5:
        return "⚠️ ALTO RIESGO – Consulta con dermatólogo en 7 días"
    if risk > 0.2:
        return "📋 RIESGO MODERADO – Evaluación programada en 2-4 semanas"
    return "✅ BAJO RIESGO – Seguimiento de rutina (3-6 meses)"


def analizar_lesion_combined(img):
    """Analyse a lesion image with every loaded model and build the report.

    Args:
        img: PIL image supplied by the Gradio input, or ``None`` when the
            user submitted without uploading anything.

    Returns:
        A ``(informe, html_chart)`` tuple of HTML strings: the multi-model
        report with the automatic recommendation, and an ``<img>`` tag with
        the base ViT probability chart (empty when no image was given).
    """
    if img is None:
        return "<p>⚠️ No se ha proporcionado ninguna imagen.</p>", ""

    # The 3-channel normalisation below requires RGB; uploads may be RGBA or
    # greyscale.
    img = img.convert("RGB")
    img_fastai = PILImage.create(img)

    # Base ViT
    probs_vit = _vit_probabilities(feature_extractor, model_vit, img)
    idx_vit = int(np.argmax(probs_vit))
    class_vit = CLASSES[idx_vit]
    conf_vit = float(probs_vit[idx_vit])

    # fast.ai models
    _, _, probs_mal = model_malignancy.predict(img_fastai)
    # NOTE(review): index 1 is taken as the malignant class — confirm against
    # the learner's vocab.
    prob_malign = float(probs_mal[1])
    pred_fast_type, _, _ = model_norm2000.predict(img_fastai)
    fast_malign = 'Maligno' if prob_malign > 0.5 else 'Benigno'

    # Fine-tuned ViT
    probs_tf = _vit_probabilities(feature_extractor_tf, model_tf_vit, img)
    idx_tf = int(np.argmax(probs_tf))
    class_tf_model = CLASSES[idx_tf]
    conf_tf = float(probs_tf[idx_tf])
    mal_tf = "Maligno" if idx_tf in MALIGNANT_INDICES else "Benigno"

    # EfficientNet-B7
    img_eff = eff_transform(img).unsqueeze(0)
    with torch.no_grad():
        out_eff = model_eff(img_eff)
    prob_eff = torch.softmax(out_eff, dim=1)[0, 1].item()
    eff_result = "Maligno" if prob_eff > 0.5 else "Benigno"

    # Base ViT chart
    html_chart = _chart_html(probs_vit)

    # Weighted risk score from the base ViT class probabilities.
    risk = float(sum(
        probs_vit[i] * RISK_LEVELS[i]['weight'] for i in range(len(CLASSES))
    ))

    # Build the report. The fast.ai label is escaped because it comes from
    # the learner's vocabulary rather than from a literal in this file.
    informe = f"""
<div>
<h2>🧪 Diagnóstico por múltiples modelos de IA</h2>
<table>
<tr><th>Modelo</th><th>Resultado</th><th>Confianza</th></tr>
<tr><td>🧠 ViT base</td><td>{class_vit}</td><td>{conf_vit:.1%}</td></tr>
<tr><td>🧬 Fast.ai (tipo)</td><td>{html.escape(str(pred_fast_type))}</td><td>N/A</td></tr>
<tr><td>⚠️ Fast.ai (malignidad)</td><td>{fast_malign}</td><td>{prob_malign:.1%}</td></tr>
<tr><td>🌟 ViT fined‑tuned (HAM10000)</td><td>{mal_tf} ({class_tf_model})</td><td>{conf_tf:.1%}</td></tr>
<tr><td>🏥 EfficientNet B7 (binario)</td><td>{eff_result}</td><td>{prob_eff:.1%}</td></tr>
</table>
<p><b>🩺 Recomendación automática:</b><br>
"""
    informe += _recommendation(prob_malign, risk, prob_eff)
    informe += "</p>\n</div>"
    return informe, html_chart


demo = gr.Interface(
    fn=analizar_lesion_combined,
    inputs=gr.Image(type="pil"),
    outputs=[gr.HTML(label="Informe"), gr.HTML(label="Gráfico ViT base")],
    title="Detector de Lesiones Cutáneas (ViT + Fast.ai + EfficientNet)",
)

if __name__ == "__main__":
    demo.launch()