import torch
from transformers import ViTImageProcessor, ViTForImageClassification
from fastai.learner import load_learner
from fastai.vision.core import PILImage
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import gradio as gr
import io
import base64
import os
import zipfile
import tensorflow as tf
# --- Extract and load the TensorFlow SavedModel shipped as a zip archive ---
zip_path = "saved_model.zip"
extract_dir = "saved_model"
# Extract only when no SavedModel is present yet. Checking for the
# `saved_model.pb` file instead of the bare directory means that an earlier
# interrupted (or empty) extraction is redone rather than silently reused.
# NOTE(review): assumes the archive stores saved_model.pb at its root, not
# inside a nested folder — confirm against the zip's layout.
if not os.path.isfile(os.path.join(extract_dir, "saved_model.pb")):
    os.makedirs(extract_dir, exist_ok=True)
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(extract_dir)
model_tf = tf.saved_model.load(extract_dir)
# Helper: TensorFlow inference on a single image.
def predict_tf(img: Image.Image, size=(224, 224)):
    """Score one image with the TensorFlow SavedModel held in ``model_tf``.

    The image is resized to ``size`` and forced to 3-channel RGB. It is then
    scaled to [0, 1] as float32, given a leading batch axis, and passed to the
    model's ``serving_default`` signature.

    Args:
        img: Input image. A PIL image is expected. Anything else (e.g. a
            NumPy array) is first converted with ``Image.fromarray``.
        size: ``(width, height)`` the image is resized to before inference.
            Defaults to ``(224, 224)``; adjust it to the model's input size.

    Returns:
        1-D NumPy array: softmax over the first output tensor of the
        signature, for the single batch element.
    """
    if not isinstance(img, Image.Image):
        img = Image.fromarray(np.asarray(img))
    # convert("RGB") guarantees exactly 3 channels for every PIL mode.
    # Grayscale ("L"), "LA" and palette ("P") images would otherwise yield
    # arrays without a 3-channel axis, and CMYK would lose its K channel.
    # Resizing first keeps RGB/RGBA inputs processed exactly as before
    # (RGBA -> RGB just drops the alpha channel).
    resized = img.resize(size).convert("RGB")
    pixels = np.asarray(resized, dtype=np.float32) / 255.0
    batch = tf.expand_dims(tf.convert_to_tensor(pixels, dtype=tf.float32), axis=0)
    # Run the model through its default serving signature.
    infer = model_tf.signatures["serving_default"]
    output = infer(batch)
    # The signature returns a dict of named output tensors; take the first.
    # NOTE(review): if the model exposes several outputs, confirm which key
    # holds the class scores.
    pred = next(iter(output.values())).numpy()[0]
    # NOTE(review): softmax assumes the model emits logits. If its last layer
    # is already a softmax, this re-normalization flattens the scores.
    return tf.nn.softmax(pred).numpy()
# --- Hugging Face ViT classifier (checkpoint trained on HAM-10000 images) ---
MODEL_NAME = "ahishamm/vit-base-HAM-10000-sharpened-patch-32"
# Preprocessor that matches the checkpoint. The `feature_extractor` name is
# kept because the analysis code refers to it.
feature_extractor = ViTImageProcessor.from_pretrained(MODEL_NAME)
# nn.Module.eval() returns the module itself, so loading and switching to
# inference mode can be chained in one expression.
model_vit = ViTForImageClassification.from_pretrained(MODEL_NAME).eval()

# --- fast.ai learners exported to local files ---
# Loaded in order: malignant/benign learner first, then the lesion-type learner.
# NOTE(review): .pkl learners are pickle-based; only load trusted files.
model_malignancy, model_norm2000 = (
    load_learner(learner_path)
    for learner_path in ("ada_learn_malben.pkl", "ada_learn_skin_norm2000.pkl")
)
# --- Class labels (Spanish UI text) and per-class risk metadata ---
# CLASSES[i] is the display label for model output index i.
# NOTE(review): the order presumably follows HAM10000
# (akiec, bcc, bkl, df, mel, nv, vasc) — confirm against the checkpoint's
# id2label mapping.
CLASSES = [
    "Queratosis actínica / Bowen",
    "Carcinoma células basales",
    "Lesión queratósica benigna",
    "Dermatofibroma",
    "Melanoma maligno",
    "Nevus melanocítico",
    "Lesión vascular",
]
# RISK_LEVELS[i] holds the risk label, chart color and numeric weight for
# CLASSES[i]. Built from one (level, color, weight) row per class index.
RISK_LEVELS = {
    class_idx: {'level': level, 'color': color, 'weight': weight}
    for class_idx, (level, color, weight) in enumerate([
        ('Moderado', '#ffaa00', 0.6),
        ('Alto', '#ff4444', 0.8),
        ('Bajo', '#44ff44', 0.1),
        ('Bajo', '#44ff44', 0.1),
        ('Crítico', '#cc0000', 1.0),
        ('Bajo', '#44ff44', 0.1),
        ('Bajo', '#44ff44', 0.1),
    ])
}
def analizar_lesion_combined(img):
# Convertir imagen para Fastai
img_fastai = PILImage.create(img)
# ViT prediction
inputs = feature_extractor(img, return_tensors="pt")
with torch.no_grad():
outputs = model_vit(**inputs)
probs_vit = outputs.logits.softmax(dim=-1).cpu().numpy()[0]
pred_idx_vit = int(np.argmax(probs_vit))
pred_class_vit = CLASSES[pred_idx_vit]
confidence_vit = probs_vit[pred_idx_vit]
# Fast.ai models
pred_fast_malignant, _, probs_fast_mal = model_malignancy.predict(img_fastai)
prob_malignant = float(probs_fast_mal[1]) # 1 = maligno
pred_fast_type, _, probs_fast_type = model_norm2000.predict(img_fastai)
# TensorFlow model prediction
probs_tf = predict_tf(img)
pred_idx_tf = int(np.argmax(probs_tf))
confidence_tf = probs_tf[pred_idx_tf]
if pred_idx_tf < len(CLASSES):
pred_class_tf = CLASSES[pred_idx_tf]
else:
pred_class_tf = f"Clase desconocida (índice {pred_idx_tf})"
# Gráfico ViT
colors_bars = [RISK_LEVELS[i]['color'] for i in range(7)]
fig, ax = plt.subplots(figsize=(8, 3))
ax.bar(CLASSES, probs_vit*100, color=colors_bars)
ax.set_title("Probabilidad ViT por tipo de lesión")
ax.set_ylabel("Probabilidad (%)")
ax.set_xticks(np.arange(len(CLASSES))) # evita warning
ax.set_xticklabels(CLASSES, rotation=45, ha='right')
ax.grid(axis='y', alpha=0.2)
plt.tight_layout()
buf = io.BytesIO()
plt.savefig(buf, format="png")
plt.close(fig)
img_bytes = buf.getvalue()
img_b64 = base64.b64encode(img_bytes).decode("utf-8")
html_chart = f''
# Informe HTML
informe = f"""
🔍 Modelo | Resultado | Confianza |
---|---|---|
🧠 ViT (transformer) | {pred_class_vit} | {confidence_vit:.1%} |
🧬 Fast.ai (clasificación) | {pred_fast_type} | N/A |
⚠️ Fast.ai (malignidad) | {"Maligno" if prob_malignant > 0.5 else "Benigno"} | {prob_malignant:.1%} |
🔬 TensorFlow (saved_model) | {pred_class_tf} | {confidence_tf:.1%} |