File size: 7,126 Bytes
fa2b4a8
ba932fd
 
 
 
fa2b4a8
ba932fd
 
fa2b4a8
ba932fd
fa2b4a8
ba932fd
 
fa2b4a8
ba932fd
 
 
 
 
 
 
 
 
 
 
 
8cfacf4
 
 
 
 
 
 
 
ba932fd
8cfacf4
 
 
 
 
 
 
 
 
ba932fd
 
 
 
 
 
 
 
16c2fe3
 
eb8c75c
ba932fd
 
 
 
 
 
fa2b4a8
ba932fd
 
 
 
 
 
 
fa2b4a8
 
cdc152a
8cfacf4
 
 
ba932fd
8cfacf4
 
 
 
 
 
 
 
 
 
 
 
 
ba932fd
8cfacf4
 
 
 
 
 
 
ba932fd
8cfacf4
 
 
 
 
3f789e4
8cfacf4
 
 
 
 
 
 
 
 
 
 
 
 
ba932fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa2b4a8
 
 
 
ba932fd
 
fa2b4a8
 
 
 
 
8658ccb
4309199
ba932fd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import torch
from transformers import ViTImageProcessor, ViTForImageClassification
from fastai.learner import load_learner
from fastai.vision.core import PILImage
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import gradio as gr
import io
import base64
import os
import zipfile
import tensorflow as tf

# --- Extract the zipped TensorFlow SavedModel (first run only) and load it ---
zip_path = "saved_model.zip"
extract_dir = "saved_model"
if not os.path.exists(extract_dir):
    # Unpack into a scratch directory and rename it into place only after the
    # whole archive has been extracted. Creating `extract_dir` up front meant
    # that a missing, corrupt or partially extracted zip left the directory
    # behind, so every later start skipped extraction and failed while loading
    # an incomplete model.
    tmp_extract_dir = extract_dir + ".tmp"
    os.makedirs(tmp_extract_dir, exist_ok=True)
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(tmp_extract_dir)
    os.rename(tmp_extract_dir, extract_dir)

# NOTE(review): this assumes the archive has the SavedModel files at its root
# (not inside a nested folder) — confirm against the actual zip layout.
model_tf = tf.saved_model.load(extract_dir)

# Función helper para inferencia TensorFlow
def predict_tf(img: Image.Image):
    """Run the TensorFlow SavedModel on a single image.

    Args:
        img: PIL image in any mode. It is converted to RGB, resized to
            224x224 and scaled to the [0, 1] range before inference.

    Returns:
        The softmax of the first output of the ``serving_default`` signature
        for the single-image batch. If anything fails, the error is printed
        and an all-zero vector of length ``len(CLASSES)`` is returned, so an
        all-zero result means "prediction unavailable".
    """
    try:
        # Force exactly three channels. Handling only a 4-channel alpha left
        # grayscale ("L"), palette ("P") and other modes with the wrong shape.
        # NOTE(review): 224x224 is assumed to match the model input — confirm.
        img_rgb = img.convert("RGB").resize((224, 224))
        img_np = np.array(img_rgb) / 255.0
        img_tf = tf.convert_to_tensor(img_np, dtype=tf.float32)
        img_tf = tf.expand_dims(img_tf, axis=0)  # add the batch dimension

        # Run the model through its default serving signature.
        infer = model_tf.signatures["serving_default"]
        output = infer(img_tf)
        pred = list(output.values())[0].numpy()[0]
        # NOTE(review): softmax assumes the model emits logits; if the
        # SavedModel already ends in a softmax this re-normalizes — confirm.
        probs = tf.nn.softmax(pred).numpy()
        return probs
    except Exception as e:
        # Deliberate best-effort: the UI keeps working without this model.
        print(f"Error en predict_tf: {e}")
        return np.zeros(len(CLASSES))

# 🔹 Load the ViT model from Hugging Face
# The image processor and the classifier are both resolved from MODEL_NAME.
MODEL_NAME = "ahishamm/vit-base-HAM-10000-sharpened-patch-32"
feature_extractor = ViTImageProcessor.from_pretrained(MODEL_NAME)
model_vit = ViTForImageClassification.from_pretrained(MODEL_NAME)
model_vit.eval()  # switch the module to evaluation mode; it is only used for inference

# 🔹 Load the Fast.ai models from local files
# NOTE(review): .pkl learners are presumably pickle-based — only load files from a trusted source.
model_malignancy = load_learner("ada_learn_malben.pkl")
model_norm2000 = load_learner("ada_learn_skin_norm2000.pkl")

# 🔹 Class labels and per-class risk levels
# CLASSES holds the display name for each class index used by the models.
CLASSES = [
    "Queratosis actínica / Bowen",
    "Carcinoma células basales",
    "Lesión queratósica benigna",
    "Dermatofibroma",
    "Melanoma maligno",
    "Nevus melanocítico",
    "Lesión vascular",
]

# One (level, color, weight) row per class index, in the same order as CLASSES.
_RISK_ROWS = (
    ('Moderado', '#ffaa00', 0.6),
    ('Alto', '#ff4444', 0.8),
    ('Bajo', '#44ff44', 0.1),
    ('Bajo', '#44ff44', 0.1),
    ('Crítico', '#cc0000', 1.0),
    ('Bajo', '#44ff44', 0.1),
    ('Bajo', '#44ff44', 0.1),
)

# Class index -> risk descriptor. Each entry is its own dict object, so
# mutating one entry never affects another.
RISK_LEVELS = {
    class_idx: {'level': level, 'color': color, 'weight': weight}
    for class_idx, (level, color, weight) in enumerate(_RISK_ROWS)
}

def _build_vit_chart_html(probs_vit):
    """Render the per-class ViT probabilities as a bar chart inside an <img> tag."""
    n_classes = len(CLASSES)
    colors_bars = [RISK_LEVELS[i]['color'] for i in range(n_classes)]
    fig, ax = plt.subplots(figsize=(8, 3))
    try:
        ax.bar(CLASSES, probs_vit * 100, color=colors_bars)
        ax.set_title("Probabilidad ViT por tipo de lesión")
        ax.set_ylabel("Probabilidad (%)")
        # Fix the tick positions before relabelling them (avoids a warning).
        ax.set_xticks(np.arange(n_classes))
        ax.set_xticklabels(CLASSES, rotation=45, ha='right')
        ax.grid(axis='y', alpha=0.2)
        # Work on this figure explicitly instead of pyplot's "current figure",
        # which is global state shared by every request.
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
    finally:
        # Always release the figure, even when drawing fails.
        plt.close(fig)
    img_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    return f'<img src="data:image/png;base64,{img_b64}" style="max-width:100%"/>'


def _build_recommendation_html(prob_malignant, cancer_risk_score, incomplete):
    """Map the malignancy probability and weighted ViT risk score to a triage message."""
    if prob_malignant > 0.7 or cancer_risk_score > 0.6:
        return "🚨 <b>CRÍTICO</b> – Derivación urgente a oncología dermatológica"
    if prob_malignant > 0.4 or cancer_risk_score > 0.4:
        return "⚠️ <b>ALTO RIESGO</b> – Consulta con dermatólogo en 7 días"
    if incomplete:
        # A failed model contributes zeros, which can only lower the score.
        # Elevated tiers above remain valid, but a moderate or low result
        # cannot be trusted, so no reassuring message is issued.
        return ("❓ <b>ANÁLISIS INCOMPLETO</b> – Uno o más modelos fallaron; "
                "no es posible descartar riesgo. Repita el análisis o consulte a un dermatólogo")
    if cancer_risk_score > 0.2:
        return "📋 <b>RIESGO MODERADO</b> – Evaluación programada (2-4 semanas)"
    return "✅ <b>BAJO RIESGO</b> – Seguimiento de rutina (3-6 meses)"


def analizar_lesion_combined(img):
    """Run the four models on a lesion image and build the combined report.

    Each model runs in its own best-effort step: a failure is printed, shown
    as "Error" in the report, and never presented as a benign or low-risk
    finding.

    Args:
        img: PIL image supplied by the Gradio input, or None when nothing
            was uploaded.

    Returns:
        A ``(report_html, chart_html)`` tuple: the combined report with the
        automatic recommendation, and an ``<img>`` tag with the ViT
        probability chart (empty string when no image was provided).
    """
    if img is None:
        return (
            '<div style="font-family:sans-serif; max-width:800px; margin:auto">'
            "⚠️ No se recibió ninguna imagen. Sube una imagen de la lesión para analizarla."
            "</div>",
            "",
        )

    # --- ViT prediction ---
    vit_ok = False
    try:
        inputs = feature_extractor(img, return_tensors="pt")
        with torch.no_grad():
            outputs = model_vit(**inputs)
            probs_vit = outputs.logits.softmax(dim=-1).cpu().numpy()[0]
        pred_idx_vit = int(np.argmax(probs_vit))
        pred_class_vit = CLASSES[pred_idx_vit]
        confidence_vit = float(probs_vit[pred_idx_vit])
        vit_ok = True
    except Exception as e:
        print(f"Error en ViT prediction: {e}")
        pred_class_vit = "Error"
        confidence_vit = 0.0
        probs_vit = np.zeros(len(CLASSES))

    # --- Fast.ai input conversion ---
    # Kept apart from the ViT step so a conversion failure cannot skip ViT.
    try:
        img_fastai = PILImage.create(img)
    except Exception as e:
        print(f"Error al preparar imagen para Fast.ai: {e}")
        img_fastai = None

    # --- Fast.ai malignancy model ---
    prob_malignant = 0.0
    malignancy_ok = False
    if img_fastai is not None:
        try:
            _, _, probs_fast_mal = model_malignancy.predict(img_fastai)
            # NOTE(review): index 1 is taken to be the "malignant" class, as
            # in the original code — confirm against the learner's vocab.
            prob_malignant = float(probs_fast_mal[1])
            malignancy_ok = True
        except Exception as e:
            print(f"Error en Fast.ai malignancy: {e}")
            prob_malignant = 0.0
    malignancy_label = (
        ("Maligno" if prob_malignant > 0.5 else "Benigno") if malignancy_ok else "Error"
    )

    # --- Fast.ai lesion-type model ---
    pred_fast_type = "Error"
    if img_fastai is not None:
        try:
            pred_fast_type, _, _ = model_norm2000.predict(img_fastai)
        except Exception as e:
            print(f"Error en Fast.ai tipo: {e}")
            pred_fast_type = "Error"

    # --- TensorFlow model ---
    pred_class_tf = "Error"
    confidence_tf = 0.0
    try:
        probs_tf = predict_tf(img)
        # predict_tf reports failure as an all-zero vector; a real softmax
        # output is never all zeros, so do not present that as class 0.
        if np.any(probs_tf):
            pred_idx_tf = int(np.argmax(probs_tf))
            confidence_tf = float(probs_tf[pred_idx_tf])
            if pred_idx_tf < len(CLASSES):
                pred_class_tf = CLASSES[pred_idx_tf]
            else:
                pred_class_tf = f"Clase desconocida (índice {pred_idx_tf})"
    except Exception as e:
        print(f"Error en TensorFlow prediction: {e}")
        pred_class_tf = "Error"
        confidence_tf = 0.0

    html_chart = _build_vit_chart_html(probs_vit)

    # HTML report
    informe = f"""
    <div style="font-family:sans-serif; max-width:800px; margin:auto">
    <h2>🧪 Diagnóstico por 4 modelos de IA</h2>
    <table style="border-collapse: collapse; width:100%; font-size:16px">
        <tr><th style="text-align:left">🔍 Modelo</th><th>Resultado</th><th>Confianza</th></tr>
        <tr><td>🧠 ViT (transformer)</td><td><b>{pred_class_vit}</b></td><td>{confidence_vit:.1%}</td></tr>
        <tr><td>🧬 Fast.ai (clasificación)</td><td><b>{pred_fast_type}</b></td><td>N/A</td></tr>
        <tr><td>⚠️ Fast.ai (malignidad)</td><td><b>{malignancy_label}</b></td><td>{prob_malignant:.1%}</td></tr>
        <tr><td>🔬 TensorFlow (saved_model)</td><td><b>{pred_class_tf}</b></td><td>{confidence_tf:.1%}</td></tr>
    </table>
    <br>
    <b>🩺 Recomendación automática:</b><br>
    """

    # The recommendation uses the ViT distribution (weighted by per-class
    # risk) plus the Fast.ai malignancy probability; the TensorFlow model is
    # informational only.
    cancer_risk_score = float(
        sum(probs_vit[i] * RISK_LEVELS[i]['weight'] for i in range(len(CLASSES)))
    )
    informe += _build_recommendation_html(
        prob_malignant,
        cancer_risk_score,
        incomplete=not (vit_ok and malignancy_ok),
    )
    informe += "</div>"

    return informe, html_chart

# Gradio interface: one uploaded image in, two HTML panels out
# (the combined report and the ViT probability chart).
_lesion_image_input = gr.Image(type="pil", label="Sube una imagen de la lesión")
_report_outputs = [
    gr.HTML(label="Informe combinado"),
    gr.HTML(label="Gráfico ViT"),
]
_interface_options = {
    "title": "Detector de Lesiones Cutáneas (ViT + Fast.ai + TensorFlow)",
    "description": "Comparación entre ViT transformer (HAM10000), dos modelos Fast.ai y un modelo TensorFlow.",
    "flagging_mode": "never",
}

demo = gr.Interface(
    fn=analizar_lesion_combined,
    inputs=_lesion_image_input,
    outputs=_report_outputs,
    **_interface_options,
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()