File size: 5,886 Bytes
cdc152a
dcd58f1
cdc152a
 
 
 
 
 
 
 
dcd58f1
 
 
cdc152a
dcd58f1
cdc152a
 
 
 
 
dcd58f1
cdc152a
 
 
dcd58f1
 
 
 
 
 
 
cdc152a
dcd58f1
cdc152a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dcd58f1
 
 
 
 
 
 
 
 
cdc152a
dcd58f1
cdc152a
 
dcd58f1
cdc152a
 
 
 
 
 
 
 
dcd58f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import torch
from transformers import ViTImageProcessor, ViTForImageClassification
from fastai.learner import load_learner
from fastai.vision.core import PILImage
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import gradio as gr
import io
import base64
import tensorflow as tf
import zipfile
import os

# 🔹 Load the ViT image classifier and its matching preprocessor from the
#    Hugging Face Hub (downloaded on first run).
MODEL_NAME = "ahishamm/vit-base-HAM-10000-sharpened-patch-32"
feature_extractor, model_vit = (
    ViTImageProcessor.from_pretrained(MODEL_NAME),
    ViTForImageClassification.from_pretrained(MODEL_NAME),
)
model_vit.eval()  # switch to inference mode

# 🔹 Load the two Fast.ai learners from local files.
#    load_learner unpickles the file, so only ship trusted artifacts here.
model_malignancy, model_norm2000 = map(
    load_learner, ("ada_learn_malben.pkl", "ada_learn_skin_norm2000.pkl")
)


def _extract_zip_once(archive_path, destination):
    """Extract *archive_path* into *destination* unless *destination* exists.

    NOTE(review): only the existence of *destination* is checked, so an
    extraction that was interrupted part-way is not retried on the next start.
    """
    if os.path.exists(destination):
        return
    with zipfile.ZipFile(archive_path, 'r') as archive:
        archive.extractall(destination)


# 🔹 Unpack (first run only) and load the ISIC TensorFlow model directory.
zip_path, extract_dir = "saved_model.zip", "saved_model"
_extract_zip_once(zip_path, extract_dir)
model_isic = tf.keras.models.load_model(extract_dir)

# 🔹 Lesion classes and their risk metadata.
# Display names (shown in the UI) indexed by model output position.
# NOTE(review): the order is assumed to match the classifiers' label order
# (HAM10000: akiec, bcc, bkl, df, mel, nv, vasc) — confirm against the models.
CLASSES = [
    "Queratosis actínica / Bowen",
    "Carcinoma células basales",
    "Lesión queratósica benigna",
    "Dermatofibroma",
    "Melanoma maligno",
    "Nevus melanocítico",
    "Lesión vascular",
]

# One row per class index: (risk level label, bar colour, risk weight).
# The weight multiplies the ViT probability of that class in the overall
# risk score; the colour is used for that class's bar in the chart.
_RISK_TABLE = [
    ('Moderado', '#ffaa00', 0.6),
    ('Alto', '#ff4444', 0.8),
    ('Bajo', '#44ff44', 0.1),
    ('Bajo', '#44ff44', 0.1),
    ('Crítico', '#cc0000', 1.0),
    ('Bajo', '#44ff44', 0.1),
    ('Bajo', '#44ff44', 0.1),
]
RISK_LEVELS = {
    idx: {'level': level, 'color': color, 'weight': weight}
    for idx, (level, color, weight) in enumerate(_RISK_TABLE)
}

def preprocess_image_isic(image: Image.Image, target_size=(224, 224)):
    """Turn a PIL image into a single-item input batch for the ISIC TF model.

    The image is converted to 3-channel RGB, resized and scaled to [0, 1].

    Args:
        image: Lesion image as uploaded by the user (any PIL mode).
        target_size: ``(width, height)`` passed to ``Image.resize``. Defaults
            to 224x224. NOTE(review): confirm it matches the input size of the
            loaded ISIC model.

    Returns:
        A float32 array of shape ``(1, height, width, 3)`` with values in
        [0, 1].
    """
    # convert("RGB") drops an alpha channel (RGBA) and also expands
    # grayscale ("L") and palette ("P") images to 3 channels. The previous
    # alpha-only check left 1-channel images as 2-D arrays, giving a wrongly
    # shaped model input.
    if image.mode != "RGB":
        image = image.convert("RGB")
    image = image.resize(target_size)
    # float32 is the dtype Keras models consume by default; building float64
    # only to have it cast again doubles the memory for nothing.
    img_array = np.asarray(image, dtype=np.float32) / 255.0
    return img_array[np.newaxis, ...]  # add the batch dimension

def _render_vit_chart(probs_vit):
    """Render the ViT per-class probabilities as an inline HTML <img> (PNG).

    Args:
        probs_vit: ViT softmax probabilities, one entry per label in CLASSES.

    Returns:
        An ``<img>`` tag whose src is a base64-encoded PNG data URI.
    """
    colors_bars = [RISK_LEVELS[i]['color'] for i in range(len(CLASSES))]
    fig, ax = plt.subplots(figsize=(8, 3))
    try:
        ax.bar(CLASSES, probs_vit * 100, color=colors_bars)
        ax.set_title("Probabilidad ViT por tipo de lesión")
        ax.set_ylabel("Probabilidad (%)")
        ax.set_xticks(np.arange(len(CLASSES)))
        ax.set_xticklabels(CLASSES, rotation=45, ha='right')
        ax.grid(axis='y', alpha=0.2)
        # Use the Figure's own methods: plt.tight_layout()/plt.savefig() act on
        # pyplot's global "current figure", which a concurrent Gradio request
        # may have replaced in the meantime.
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
    finally:
        # Always unregister the figure from pyplot, even on error, so figures
        # do not accumulate across requests.
        plt.close(fig)
    img_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    return f'<img src="data:image/png;base64,{img_b64}" style="max-width:100%"/>'


def _recommendation(probs_vit, prob_malignant):
    """Return the triage recommendation (HTML fragment) for the report.

    The ViT class probabilities are combined with the per-class RISK_LEVELS
    weights into a risk score. That score and the Fast.ai malignancy
    probability are then compared against fixed thresholds, most severe first.
    """
    cancer_risk_score = sum(
        probs_vit[i] * RISK_LEVELS[i]['weight'] for i in range(len(CLASSES))
    )
    if prob_malignant > 0.7 or cancer_risk_score > 0.6:
        return "🚨 <b>CRÍTICO</b> – Derivación urgente a oncología dermatológica"
    if prob_malignant > 0.4 or cancer_risk_score > 0.4:
        return "⚠️ <b>ALTO RIESGO</b> – Consulta con dermatólogo en 7 días"
    if cancer_risk_score > 0.2:
        return "📋 <b>RIESGO MODERADO</b> – Evaluación programada (2-4 semanas)"
    return "✅ <b>BAJO RIESGO</b> – Seguimiento de rutina (3-6 meses)"


def analizar_lesion_combined(img):
    """Run all four models on a lesion image and build the HTML report.

    Args:
        img: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading an image.

    Returns:
        A tuple ``(report_html, chart_html)``. The first item is the summary
        table plus a triage recommendation; the second is the ViT probability
        bar chart as an inline image.
    """
    if img is None:
        return "<p>⚠️ Sube una imagen de la lesión para analizarla.</p>", ""

    # Normalize to 3-channel RGB once so every model sees the same input;
    # RGBA/grayscale uploads would otherwise break the ViT processor.
    img = img.convert("RGB")
    img_fastai = PILImage.create(img)

    # --- ViT (transformer) ---
    inputs = feature_extractor(img, return_tensors="pt")
    with torch.no_grad():
        outputs = model_vit(**inputs)
        probs_vit = outputs.logits.softmax(dim=-1).cpu().numpy()[0]
    pred_idx_vit = int(np.argmax(probs_vit))
    pred_class_vit = CLASSES[pred_idx_vit]
    confidence_vit = probs_vit[pred_idx_vit]

    # --- Fast.ai malignancy (binary) ---
    _, _, probs_fast_mal = model_malignancy.predict(img_fastai)
    # NOTE(review): assumes the learner's vocab puts "malignant" at index 1
    # (e.g. a sorted ['benign', 'malignant'] vocab) — confirm against
    # model_malignancy.dls.vocab.
    prob_malignant = float(probs_fast_mal[1])
    is_malignant = prob_malignant > 0.5
    # Report the confidence of the verdict actually shown, not P(malignant)
    # unconditionally ("Benigno 5%" used to mean 95% benign).
    confidence_malignancy = prob_malignant if is_malignant else 1.0 - prob_malignant

    # --- Fast.ai lesion type ---
    pred_fast_type, _, probs_fast_type = model_norm2000.predict(img_fastai)
    # max() is the probability of the predicted class for a single-label model.
    confidence_fast_type = float(probs_fast_type.max())

    # --- ISIC TensorFlow ---
    x_isic = preprocess_image_isic(img)
    preds_isic = np.asarray(model_isic.predict(x_isic, verbose=0))[0]
    pred_idx_isic = int(np.argmax(preds_isic))
    # NOTE(review): assumes the ISIC head shares CLASSES' order; if it has more
    # outputs than CLASSES, fall back to a generic label instead of crashing.
    if pred_idx_isic < len(CLASSES):
        pred_class_isic = CLASSES[pred_idx_isic]
    else:
        pred_class_isic = f"Clase {pred_idx_isic}"
    confidence_isic = float(preds_isic[pred_idx_isic])

    verdict_malignancy = "Maligno" if is_malignant else "Benigno"

    # HTML report covering the 4 models
    informe = f"""
    <div style="font-family:sans-serif; max-width:800px; margin:auto">
    <h2>🧪 Diagnóstico por 4 modelos de IA</h2>
    <table style="border-collapse: collapse; width:100%; font-size:16px">
        <tr><th style="text-align:left">🔍 Modelo</th><th>Resultado</th><th>Confianza</th></tr>
        <tr><td>🧠 ViT (transformer)</td><td><b>{pred_class_vit}</b></td><td>{confidence_vit:.1%}</td></tr>
        <tr><td>🧬 Fast.ai (clasificación)</td><td><b>{pred_fast_type}</b></td><td>{confidence_fast_type:.1%}</td></tr>
        <tr><td>⚠️ Fast.ai (malignidad)</td><td><b>{verdict_malignancy}</b></td><td>{confidence_malignancy:.1%}</td></tr>
        <tr><td>🔬 ISIC TensorFlow</td><td><b>{pred_class_isic}</b></td><td>{confidence_isic:.1%}</td></tr>
    </table>
    <br>
    <b>🩺 Recomendación automática:</b><br>
    """

    # Recommendation based on ViT risk score + Fast.ai malignancy
    informe += _recommendation(probs_vit, prob_malignant)
    informe += "</div>"

    return informe, _render_vit_chart(probs_vit)

# 🔹 Gradio UI: one PIL image in; two HTML panels out (combined report and
#    ViT probability chart), matching the tuple analizar_lesion_combined returns.
_lesion_input = gr.Image(type="pil", label="Sube una imagen de la lesión")
_report_outputs = [
    gr.HTML(label="Informe combinado"),
    gr.HTML(label="Gráfico ViT"),
]

demo = gr.Interface(
    fn=analizar_lesion_combined,
    inputs=_lesion_input,
    outputs=_report_outputs,
    title="Detector de Lesiones Cutáneas (ViT + Fast.ai + ISIC TensorFlow)",
    description="Comparación entre ViT transformer (HAM10000), dos modelos Fast.ai y el modelo ISIC TensorFlow.",
    flagging_mode="never",  # no flagging button / flagged-data storage
)

if __name__ == "__main__":
    # Start the web server only when run as a script, not on import.
    demo.launch()