File size: 5,042 Bytes
fa2b4a8
ba932fd
 
 
 
fa2b4a8
ba932fd
 
fa2b4a8
ba932fd
fa2b4a8
f3d9a1b
 
 
 
 
 
 
ba932fd
 
 
 
e44c49f
f3d9a1b
16c2fe3
 
eb8c75c
f3d9a1b
ba932fd
 
 
 
 
fa2b4a8
ba932fd
f3d9a1b
 
 
 
 
 
fa2b4a8
f3d9a1b
34a30bc
cdc152a
f3d9a1b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba932fd
f3d9a1b
 
ba932fd
b224d9f
ba932fd
 
 
 
 
 
f3d9a1b
ba932fd
 
 
f3d9a1b
 
 
 
 
 
 
 
 
ba932fd
f3d9a1b
 
ba932fd
f3d9a1b
ba932fd
f3d9a1b
 
ba932fd
 
f3d9a1b
ba932fd
 
 
fa2b4a8
 
f3d9a1b
 
 
fa2b4a8
 
 
f3d9a1b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import torch
from transformers import ViTImageProcessor, ViTForImageClassification
from fastai.learner import load_learner
from fastai.vision.core import PILImage
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import gradio as gr
import io
import base64

def _load_vit(checkpoint):
    """Load a ViT image processor and classifier for *checkpoint*.

    The classifier is put into evaluation mode (``eval()``) before it is
    returned. Returns ``(processor, classifier)``.
    """
    processor = ViTImageProcessor.from_pretrained(checkpoint)
    classifier = ViTForImageClassification.from_pretrained(checkpoint)
    classifier.eval()
    return processor, classifier


# --- Pretrained ViT fine-tuned on HAM10000 ---
TF_MODEL_NAME = "Anwarkh1/Skin_Cancer-Image_Classification"
feature_extractor_tf, model_tf_vit = _load_vit(TF_MODEL_NAME)

# 🔹 Base ViT model
MODEL_NAME = "ahishamm/vit-base-HAM-10000-sharpened-patch-32"
feature_extractor, model_vit = _load_vit(MODEL_NAME)

# 🔹 Local fast.ai learners (loaded from .pkl files in the working directory).
# NOTE(review): load_learner deserializes a pickled Learner; only ship trusted files.
model_malignancy = load_learner("ada_learn_malben.pkl")
model_norm2000 = load_learner("ada_learn_skin_norm2000.pkl")

# Standard HAM10000 classes (Spanish display names), indexed by model output.
# NOTE(review): this order must match each classifier head's output order
# (e.g. model.config.id2label for the ViT models) — verify per model.
CLASSES = [
    "Queratosis actínica / Bowen",
    "Carcinoma células basales",
    "Lesión queratósica benigna",
    "Dermatofibroma",
    "Melanoma maligno",
    "Nevus melanocítico",
    "Lesión vascular",
]

# One row per class index: (risk level label, bar colour, weight in risk score).
_RISK_TABLE = (
    ('Moderado', '#ffaa00', 0.6),
    ('Alto',     '#ff4444', 0.8),
    ('Bajo',     '#44ff44', 0.1),
    ('Bajo',     '#44ff44', 0.1),
    ('Crítico',  '#cc0000', 1.0),
    ('Bajo',     '#44ff44', 0.1),
    ('Bajo',     '#44ff44', 0.1),
)
# Class index -> its own dict with 'level', 'color' and 'weight' keys.
RISK_LEVELS = {
    idx: {'level': level, 'color': color, 'weight': weight}
    for idx, (level, color, weight) in enumerate(_RISK_TABLE)
}

# Class indices reported as malignant: akiec, bcc, melanoma.
MALIGNANT_INDICES = [0, 1, 4]

def _vit_predict(processor, model, img):
    """Return the softmax class probabilities of a ViT classifier for *img*.

    *processor* converts the PIL image into a PyTorch batch of one image and
    *model* is the matching image classifier. The result is a 1-D numpy array
    with one probability per output logit.
    """
    inputs = processor(img, return_tensors="pt")
    with torch.no_grad():  # inference only: no autograd graph needed
        logits = model(**inputs).logits
    return logits.softmax(dim=-1).cpu().numpy()[0]


def _grafico_probabilidades(probs):
    """Render per-class probabilities as a bar chart inside an HTML <img> tag.

    *probs* is indexed like CLASSES and is plotted as percentages, each bar
    coloured with its RISK_LEVELS colour. A self-contained ``Figure`` is used
    instead of pyplot: pyplot keeps process-global "current figure" state and
    is not thread-safe. Gradio runs handlers outside the main thread, so
    ``plt.savefig``/``plt.tight_layout`` could otherwise act on another call's
    figure.
    """
    from matplotlib.figure import Figure  # pyplot-free, per-call figure

    n_classes = len(CLASSES)
    colors = [RISK_LEVELS[i]['color'] for i in range(n_classes)]
    positions = np.arange(n_classes)

    fig = Figure(figsize=(8, 3))
    ax = fig.subplots()
    ax.bar(positions, np.asarray(probs) * 100, color=colors)
    ax.set_title("Probabilidad ViT base por tipo de lesión")
    ax.set_ylabel("Probabilidad (%)")
    ax.set_xticks(positions)
    ax.set_xticklabels(CLASSES, rotation=45, ha='right')
    ax.grid(axis='y', alpha=0.2)
    fig.tight_layout()

    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    encoded = base64.b64encode(buf.getvalue()).decode()
    return f'<img src="data:image/png;base64,{encoded}" style="max-width:100%"/>'


def _recomendacion(prob_malign, risk):
    """Map the malignancy probability and weighted risk score to advice HTML.

    Thresholds are checked from most to least severe; the first match wins.
    """
    if prob_malign > 0.7 or risk > 0.6:
        return "🚨 <b>CRÍTICO</b> – Derivación urgente a oncología dermatológica"
    if prob_malign > 0.4 or risk > 0.4:
        return "⚠️ <b>ALTO RIESGO</b> – Consulta con dermatólogo en 7 días"
    if risk > 0.2:
        return "📋 <b>RIESGO MODERADO</b> – Evaluación programada en 2-4 semanas"
    return "✅ <b>BAJO RIESGO</b> – Seguimiento de rutina (3-6 meses)"


def analizar_lesion_combined(img):
    """Analyse a skin-lesion image with every loaded model and build a report.

    Args:
        img: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the form is submitted without an image.

    Returns:
        ``(informe, html_chart)``. ``informe`` is an HTML table with each
        model's result plus an automatic recommendation. ``html_chart`` is an
        inline ``<img>`` bar chart of the base-ViT class probabilities. When
        *img* is ``None``, a short HTML notice and an empty chart are
        returned instead of raising.
    """
    if img is None:
        return "<p>⚠️ Sube una imagen de la lesión para analizarla.</p>", ""

    # Defensive: guarantee a 3-channel image even if RGBA/greyscale/palette
    # input reaches this function; all models here consume RGB.
    img = img.convert("RGB")
    img_fastai = PILImage.create(img)

    # NOTE(review): both ViT outputs are indexed with CLASSES / RISK_LEVELS /
    # MALIGNANT_INDICES, which assumes each model's label order matches
    # CLASSES. Verify against model_vit.config.id2label and
    # model_tf_vit.config.id2label.

    # Base ViT
    probs_vit = _vit_predict(feature_extractor, model_vit, img)
    idx_vit = int(np.argmax(probs_vit))
    class_vit = CLASSES[idx_vit]
    conf_vit = float(probs_vit[idx_vit])

    # fast.ai learners.
    # NOTE(review): probs_mal[1] is treated as P(malignant), i.e. it assumes
    # the malignant class is the second entry of the learner's vocab.
    _, _, probs_mal = model_malignancy.predict(img_fastai)
    prob_malign = float(probs_mal[1])
    pred_fast_type, _, _ = model_norm2000.predict(img_fastai)

    # Pretrained fine-tuned ViT (latest recommended model)
    probs_tf = _vit_predict(feature_extractor_tf, model_tf_vit, img)
    idx_tf = int(np.argmax(probs_tf))
    class_tf_model = CLASSES[idx_tf]
    conf_tf = float(probs_tf[idx_tf])
    mal_tf = "Maligno" if idx_tf in MALIGNANT_INDICES else "Benigno"

    html_chart = _grafico_probabilidades(probs_vit)

    # The malignancy row's "Confianza" column shows P(malignant) regardless of
    # which label is displayed.
    mal_fast = 'Maligno' if prob_malign > 0.5 else 'Benigno'
    informe = f"""
    <div style="font-family:sans-serif; max-width:800px; margin:auto">
    <h2>🧪 Diagnóstico por múltiples modelos de IA</h2>
    <table style="width:100%; font-size:16px; border-collapse:collapse">
      <tr><th>Modelo</th><th>Resultado</th><th>Confianza</th></tr>
      <tr><td>🧠 ViT base</td><td><b>{class_vit}</b></td><td>{conf_vit:.1%}</td></tr>
      <tr><td>🧬 Fast.ai (tipo)</td><td><b>{pred_fast_type}</b></td><td>N/A</td></tr>
      <tr><td>⚠️ Fast.ai (malignidad)</td><td><b>{mal_fast}</b></td><td>{prob_malign:.1%}</td></tr>
      <tr><td>🌟 ViT fine‑tuned (HAM10000)</td><td><b>{mal_tf} ({class_tf_model})</b></td><td>{conf_tf:.1%}</td></tr>
    </table><br>
    <b>🩺 Recomendación automática:</b><br>
    """

    # Weighted risk score: base-ViT probabilities weighted per class (0.1-1.0).
    risk = float(sum(probs_vit[i] * RISK_LEVELS[i]['weight']
                     for i in range(len(CLASSES))))
    informe += _recomendacion(prob_malign, risk)
    informe += "</div>"

    return informe, html_chart

# Gradio UI: one PIL image in, two HTML panels out (report + base-ViT chart).
# Components are built in the same order as before: input first, then outputs.
_entrada_imagen = gr.Image(type="pil")
_salidas_html = [gr.HTML(label="Informe"), gr.HTML(label="Gráfico ViT base")]

demo = gr.Interface(
    fn=analizar_lesion_combined,
    inputs=_entrada_imagen,
    outputs=_salidas_html,
    title="Detector de Lesiones Cutáneas (ViT + Fast.ai)",
)


def _main():
    """Launch the Gradio demo server."""
    demo.launch()


if __name__ == "__main__":
    _main()