import torch
from transformers import ViTImageProcessor, ViTForImageClassification
from fastai.learner import load_learner
from fastai.vision.core import PILImage
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import gradio as gr
import io
import base64
from torchvision import transforms
from efficientnet_pytorch import EfficientNet
# --- Fine-tuned ViT checkpoint for HAM10000 lesion types (Hugging Face Hub) ---
TF_MODEL_NAME = "Anwarkh1/Skin_Cancer-Image_Classification"
# Processor and model are fetched from the same checkpoint id; the model is
# switched to eval mode (nn.Module.eval() returns the module itself).
feature_extractor_tf, model_tf_vit = (
    ViTImageProcessor.from_pretrained(TF_MODEL_NAME),
    ViTForImageClassification.from_pretrained(TF_MODEL_NAME).eval(),
)
# Base ViT classifier loaded from the MODEL_NAME checkpoint on the Hugging Face Hub.
MODEL_NAME = "ahishamm/vit-base-HAM-10000-sharpened-patch-32"
# Same pattern as the fine-tuned model: processor first, then the model in
# eval mode (eval() returns the module, so the assignment keeps that object).
feature_extractor, model_vit = (
    ViTImageProcessor.from_pretrained(MODEL_NAME),
    ViTForImageClassification.from_pretrained(MODEL_NAME).eval(),
)
# Local fast.ai learners: a malignancy (benign/malignant) model and a lesion-type
# model. Paths are relative, so they resolve against the current working
# directory. load_learner unpickles the export files, so only load trusted files.
model_malignancy, model_norm2000 = map(
    load_learner,
    ("ada_learn_malben.pkl", "ada_learn_skin_norm2000.pkl"),
)
# EfficientNet-B7 used as a binary (benign vs. malignant) classifier.
# NOTE(review): efficientnet_pytorch's `from_pretrained` appears to restore the
# ImageNet classifier head only when num_classes == 1000; with num_classes=2
# the final layer looks freshly initialised, so unless fine-tuned weights are
# loaded elsewhere its predictions are not meaningful — confirm.
model_eff = EfficientNet.from_pretrained("efficientnet-b7", num_classes=2)
model_eff.eval()


def _to_rgb(image):
    """Return *image* as a 3-channel RGB PIL image when it is a PIL image.

    RGBA/palette uploads (common for PNGs) and grayscale images would make
    ToTensor() produce 4 or 1 channels, which the 3-channel Normalize step
    below cannot process. Inputs that are not PIL images pass through unchanged.
    """
    if isinstance(image, Image.Image) and image.mode != "RGB":
        return image.convert("RGB")
    return image


# Preprocessing for model_eff: force RGB, resize to 224x224, convert to a
# tensor and normalise with the standard ImageNet per-channel mean/std.
eff_transform = transforms.Compose([
    _to_rgb,
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225]),
])
# Standard HAM10000 lesion classes (display names in Spanish). Model argmax
# indices are looked up in this list, so the order matters.
# NOTE(review): both ViT checkpoints are mapped through this same list; confirm
# it matches each model's config.id2label ordering.
CLASSES = [
    "Queratosis actínica / Bowen",  # 0 akiec
    "Carcinoma células basales",  # 1 bcc
    "Lesión queratósica benigna",  # 2 bkl
    "Dermatofibroma",  # 3 df
    "Melanoma maligno",  # 4 mel
    "Nevus melanocítico",  # 5 nv
    "Lesión vascular",  # 6 vasc
]

# Per-class risk label, chart colour and weight, keyed by class index.
# Each entry is built as its own dict (no shared objects between classes).
RISK_LEVELS = {
    idx: {'level': level, 'color': color, 'weight': weight}
    for idx, (level, color, weight) in enumerate([
        ('Moderado', '#ffaa00', 0.6),  # akiec
        ('Alto', '#ff4444', 0.8),  # bcc
        ('Bajo', '#44ff44', 0.1),  # bkl
        ('Bajo', '#44ff44', 0.1),  # df
        ('Crítico', '#cc0000', 1.0),  # melanoma
        ('Bajo', '#44ff44', 0.1),  # nv
        ('Bajo', '#44ff44', 0.1),  # vasc
    ])
}

# Class indices treated as malignant: akiec, bcc, melanoma.
MALIGNANT_INDICES = [0, 1, 4]
def analizar_lesion_combined(img):
img_fastai = PILImage.create(img)
# ViT base
inputs = feature_extractor(img, return_tensors="pt")
with torch.no_grad():
outputs = model_vit(**inputs)
probs_vit = outputs.logits.softmax(dim=-1).cpu().numpy()[0]
idx_vit = int(np.argmax(probs_vit))
class_vit = CLASSES[idx_vit]
conf_vit = probs_vit[idx_vit]
# Fast.ai modelos
_, _, probs_mal = model_malignancy.predict(img_fastai)
prob_malign = float(probs_mal[1])
pred_fast_type, _, _ = model_norm2000.predict(img_fastai)
# ViT fine‑tuned (último modelo recomendado)
inputs_tf = feature_extractor_tf(img, return_tensors="pt")
with torch.no_grad():
outputs_tf = model_tf_vit(**inputs_tf)
probs_tf = outputs_tf.logits.softmax(dim=-1).cpu().numpy()[0]
idx_tf = int(np.argmax(probs_tf))
class_tf_model = CLASSES[idx_tf]
conf_tf = probs_tf[idx_tf]
mal_tf = "Maligno" if idx_tf in MALIGNANT_INDICES else "Benigno"
# EfficientNet B7
img_eff = eff_transform(img).unsqueeze(0)
with torch.no_grad():
out_eff = model_eff(img_eff)
prob_eff = torch.softmax(out_eff, dim=1)[0, 1].item()
eff_result = "Maligno" if prob_eff > 0.5 else "Benigno"
# Gráfico ViT base
colors = [RISK_LEVELS[i]['color'] for i in range(7)]
fig, ax = plt.subplots(figsize=(8, 3))
ax.bar(CLASSES, probs_vit*100, color=colors)
ax.set_title("Probabilidad ViT base por tipo de lesión")
ax.set_ylabel("Probabilidad (%)")
ax.set_xticks(np.arange(len(CLASSES)))
ax.set_xticklabels(CLASSES, rotation=45, ha='right')
ax.grid(axis='y', alpha=0.2)
plt.tight_layout()
buf = io.BytesIO()
plt.savefig(buf, format="png")
plt.close(fig)
html_chart = f''
# Generar informe
informe = f"""
Modelo | Resultado | Confianza |
---|---|---|
🧠 ViT base | {class_vit} | {conf_vit:.1%} |
🧬 Fast.ai (tipo) | {pred_fast_type} | N/A |
⚠️ Fast.ai (malignidad) | {'Maligno' if prob_malign > 0.5 else 'Benigno'} | {prob_malign:.1%} |
🌟 ViT fined‑tuned (HAM10000) | {mal_tf} ({class_tf_model}) | {conf_tf:.1%} |
🏥 EfficientNet B7 (binario) | {eff_result} | {prob_eff:.1%} |