# CancerSkinTest3 / app.py
# LoloSemper's picture
# Update app.py
# 53d4bcf verified
# raw
# history blame
# 5.75 kB
import gradio as gr
import torch
import numpy as np
import matplotlib.pyplot as plt
import base64
import io
from fastai.vision.all import *
import tensorflow as tf
import zipfile
import os
import traceback
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Unpack the bundled model archive on first start; once the "saved_model"
# path exists the extraction is skipped.
if not os.path.exists("saved_model"):
    with zipfile.ZipFile("saved_model.zip") as archive:
        archive.extractall(path="saved_model")

# Load the ISIC TensorFlow model from the extracted directory.
# A load failure is printed and then re-raised, so start-up aborts.
try:
    model_isic = tf.saved_model.load("saved_model")
except Exception as load_error:
    print("\U0001F534 Error al cargar el modelo ISIC:", load_error)
    raise
# Load the two fastai learners from their .pkl files
# (one is used for the benign/malignant call, the other for the lesion type).
# NOTE(review): fastai's load_learner presumably unpickles these files, so
# they should only ever come from a trusted source — confirm provenance.
model_malignancy = load_learner("ada_learn_malben.pkl")
model_norm2000 = load_learner("ada_learn_skin_norm2000.pkl")
# Load the ViT image processor and classification model identified by
# "nateraw/vit-skin-cancer"
feature_extractor = AutoImageProcessor.from_pretrained("nateraw/vit-skin-cancer")
model_vit = AutoModelForImageClassification.from_pretrained("nateraw/vit-skin-cancer")
# Lesion classes with their chart color and risk weight.
# The position of a row in this table is the class index used as the
# RISK_LEVELS key and as the index into CLASSES.
_LESION_TABLE = (
    ("akiec", "#FF6F61", 0.9),
    ("bcc", "#FF8C42", 0.7),
    ("bkl", "#FFD166", 0.3),
    ("df", "#06D6A0", 0.1),
    ("mel", "#EF476F", 1.0),
    ("nv", "#118AB2", 0.2),
    ("vasc", "#073B4C", 0.4),
)

# Class labels in index order.
CLASSES = [label for label, _color, _weight in _LESION_TABLE]

# Class index -> {"label", "color", "weight"}.
RISK_LEVELS = {
    idx: {"label": label, "color": color, "weight": weight}
    for idx, (label, color, weight) in enumerate(_LESION_TABLE)
}
def preprocess_image_isic(pil_image, size=(224, 224)):
    """Turn a PIL image into a batched, 0-1 scaled array for the ISIC model.

    Args:
        pil_image: Image object exposing PIL's ``convert`` and ``resize``.
        size: Target ``(width, height)`` passed to ``resize``; defaults to
            the 224x224 used by the rest of this file.

    Returns:
        A float64 array of shape ``(1, height, width, 3)`` with values
        divided by 255.
    """
    # Force three channels first: RGBA, palette or grayscale inputs would
    # otherwise yield a different (or missing) channel axis.
    rgb_image = pil_image.convert("RGB").resize(size)
    pixels = np.asarray(rgb_image, dtype=np.float64) / 255.0
    # Add the leading batch axis.
    return np.expand_dims(pixels, axis=0)
def analizar_lesion_combined(img):
try:
img_fastai = PILImage.create(img)
inputs = feature_extractor(img, return_tensors="pt")
with torch.no_grad():
outputs = model_vit(**inputs)
probs_vit = outputs.logits.softmax(dim=-1).cpu().numpy()[0]
pred_idx_vit = int(np.argmax(probs_vit))
pred_class_vit = CLASSES[pred_idx_vit]
confidence_vit = probs_vit[pred_idx_vit]
pred_fast_malignant, _, probs_fast_mal = model_malignancy.predict(img_fastai)
prob_malignant = float(probs_fast_mal[1])
pred_fast_type, _, probs_fast_type = model_norm2000.predict(img_fastai)
x_isic = preprocess_image_isic(img)
isic_func = model_isic.signatures["serving_default"]
preds_isic_tensor = isic_func(tf.constant(x_isic))
key = list(preds_isic_tensor.keys())[0]
preds_isic = preds_isic_tensor[key].numpy()[0]
pred_idx_isic = int(np.argmax(preds_isic))
pred_class_isic = CLASSES[pred_idx_isic]
confidence_isic = preds_isic[pred_idx_isic]
colors_bars = [RISK_LEVELS[i]['color'] for i in range(7)]
fig, ax = plt.subplots(figsize=(8, 3))
ax.bar(CLASSES, probs_vit * 100, color=colors_bars)
ax.set_title("Probabilidad ViT por tipo de lesión")
ax.set_ylabel("Probabilidad (%)")
ax.set_xticks(np.arange(len(CLASSES)))
ax.set_xticklabels(CLASSES, rotation=45, ha='right')
ax.grid(axis='y', alpha=0.2)
plt.tight_layout()
buf = io.BytesIO()
plt.savefig(buf, format="png")
plt.close(fig)
img_bytes = buf.getvalue()
img_b64 = base64.b64encode(img_bytes).decode("utf-8")
html_chart = f'<img src="data:image/png;base64,{img_b64}" style="max-width:100%"/>'
informe = f"""<div style="font-family:sans-serif; max-width:800px; margin:auto">
<h2>🧪 Diagnóstico por 4 modelos de IA</h2>
<table style="border-collapse: collapse; width:100%; font-size:16px">
<tr><th style="text-align:left">🔍 Modelo</th><th>Resultado</th><th>Confianza</th></tr>
<tr><td>🧠 ViT (transformer)</td><td><b>{pred_class_vit}</b></td><td>{confidence_vit:.1%}</td></tr>
<tr><td>🧬 Fast.ai (clasificación)</td><td><b>{pred_fast_type}</b></td><td>N/A</td></tr>
<tr><td>⚠️ Fast.ai (malignidad)</td><td><b>{"Maligno" if prob_malignant > 0.5 else "Benigno"}</b></td><td>{prob_malignant:.1%}</td></tr>
<tr><td>🔬 ISIC TensorFlow</td><td><b>{pred_class_isic}</b></td><td>{confidence_isic:.1%}</td></tr>
</table><br><b>🮺 Recomendación automática:</b><br>"
cancer_risk_score = sum(probs_vit[i] * RISK_LEVELS[i]['weight'] for i in range(7))
if prob_malignant > 0.7 or cancer_risk_score > 0.6:
informe += "🚨 <b>CRÍTICO</b> – Derivación urgente a oncología dermatológica"
elif prob_malignant > 0.4 or cancer_risk_score > 0.4:
informe += "⚠️ <b>ALTO RIESGO</b> – Consulta con dermatólogo en 7 días"
elif cancer_risk_score > 0.2:
informe += "📜 <b>RIESGO MODERADO</b> – Evaluación programada (2-4 semanas)"
else:
informe += "✅ <b>BAJO RIESGO</b> – Seguimiento de rutina (3-6 meses)"
informe += "</div>"
return informe, html_chart
except Exception as e:
print("\U0001F534 ERROR en analizar_lesion_combined:")
print(str(e))
traceback.print_exc()
return f"<b>Error interno:</b> {str(e)}", ""
# Gradio UI: one image upload feeding analizar_lesion_combined, whose two
# return values fill the report panel and the chart panel.
_lesion_input = gr.Image(type="pil", label="Sube una imagen de la lesión")
_result_panels = [
    gr.HTML(label="Informe combinado"),
    gr.HTML(label="Gráfico ViT"),
]
demo = gr.Interface(
    fn=analizar_lesion_combined,
    inputs=_lesion_input,
    outputs=_result_panels,
    title="Detector de Lesiones Cutáneas (ViT + Fast.ai + ISIC TensorFlow)",
    description="Comparación entre ViT transformer (HAM10000), dos modelos Fast.ai y el modelo ISIC TensorFlow.",
    flagging_mode="never",
)

# Launch the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()