import torch
from transformers import ViTImageProcessor, ViTForImageClassification
from fastai.learner import load_learner
from fastai.vision.core import PILImage
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import gradio as gr
import io
import base64
import os
import zipfile
import tensorflow as tf
# --- Extract the TensorFlow SavedModel from its zip archive and load it ---
zip_path = "saved_model.zip"
extract_dir = "saved_model"

# Files whose presence marks a usable SavedModel directory (binary or text proto).
_SAVED_MODEL_MARKERS = ("saved_model.pb", "saved_model.pbtxt")


def _ensure_saved_model_extracted(archive_path, target_dir):
    """Extract *archive_path* into *target_dir* unless a SavedModel is already there.

    Testing only whether *target_dir* exists would treat a directory left
    empty or half-written by an interrupted extraction as complete, so it
    would never be repaired and every later load would fail. Checking for the
    SavedModel proto file makes this step self-healing.

    Assumes the archive stores ``saved_model.pb`` (and ``variables/``) at its
    root, since the directory is handed straight to ``tf.saved_model.load`` —
    TODO confirm against the shipped zip.

    Args:
        archive_path: path to the zip archive containing the SavedModel.
        target_dir: directory to extract into (created if missing).

    Returns:
        *target_dir*, so the call can be passed directly to the loader.
    """
    has_model = any(
        os.path.isfile(os.path.join(target_dir, marker))
        for marker in _SAVED_MODEL_MARKERS
    )
    if not has_model:
        os.makedirs(target_dir, exist_ok=True)
        with zipfile.ZipFile(archive_path, "r") as zip_ref:
            zip_ref.extractall(target_dir)
    return target_dir


model_tf = tf.saved_model.load(_ensure_saved_model_extracted(zip_path, extract_dir))
# Helper: run the TensorFlow SavedModel on a single PIL image.
def predict_tf(img: Image.Image):
    """Return the TensorFlow model's softmax probabilities for *img*.

    Preprocessing steps:
      1. Resize the image to 224x224.
      2. Force it to 3-channel RGB.
      3. Scale it to [0, 1].
      4. Feed it as a float32 batch of one to the ``serving_default``
         signature of ``model_tf``.

    A softmax is then applied to the first output tensor of that signature,
    for the single batch element.

    Args:
        img: input image in any PIL mode.

    Returns:
        1-D numpy array of probabilities, one per model output unit. If
        anything fails, the error is printed and ``np.zeros(2)`` is returned.
    """
    try:
        # Resize first, then convert. For RGB/RGBA input this yields exactly the
        # pixels of "resize, then drop alpha". Grayscale ("L"), "LA" and palette
        # ("P") images are expanded to 3 channels instead of producing an input
        # tensor with the wrong rank or channel count.
        img_rgb = img.resize((224, 224)).convert("RGB")
        img_np = np.array(img_rgb) / 255.0
        img_tf = tf.convert_to_tensor(img_np, dtype=tf.float32)
        # Add the batch dimension: (224, 224, 3) -> (1, 224, 224, 3).
        img_tf = tf.expand_dims(img_tf, axis=0)
        infer = model_tf.signatures["serving_default"]
        output = infer(img_tf)
        # NOTE(review): signature outputs come back as a dict; only the first
        # entry is used, which assumes a single-output signature. Confirm.
        pred = next(iter(output.values())).numpy()[0]
        # NOTE(review): softmax assumes the model emits raw logits. If its last
        # layer already applies softmax/sigmoid, this rescales the scores. Confirm.
        probs = tf.nn.softmax(pred).numpy()
        return probs
    except Exception as e:
        print(f"Error en predict_tf: {e}")
        # NOTE(review): a caller that simply compares the two entries will read
        # this all-zero fallback as a real (tied) result. Consider raising instead.
        return np.zeros(2)
# Hugging Face ViT checkpoint. The analysis code indexes CLASSES by this model's
# argmax, so its label order is assumed to match CLASSES. Confirm against the
# checkpoint's config (id2label).
MODEL_NAME = "ahishamm/vit-base-HAM-10000-sharpened-patch-32"

# Preprocessor and classifier come from the same checkpoint. The classifier is
# switched to inference mode immediately (nn.Module.eval() returns the module).
feature_extractor = ViTImageProcessor.from_pretrained(MODEL_NAME)
model_vit = ViTForImageClassification.from_pretrained(MODEL_NAME).eval()

# Exported fastai learners, loaded in this order: malignancy model first, then
# the lesion-type model. load_learner unpickles these files, so only ship
# artifacts from a trusted source.
model_malignancy, model_norm2000 = [
    load_learner(pkl_path)
    for pkl_path in ("ada_learn_malben.pkl", "ada_learn_skin_norm2000.pkl")
]
# Display names (Spanish) of the seven lesion classes, indexed by the ViT
# model's output position. The order is assumed to match the checkpoint's
# labels. Confirm.
CLASSES = [
    "Queratosis actínica / Bowen",
    "Carcinoma células basales",
    "Lesión queratósica benigna",
    "Dermatofibroma",
    "Melanoma maligno",
    "Nevus melanocítico",
    "Lesión vascular",
]

# One (level, color, weight) row per CLASSES index, in the same order.
# level: Spanish risk label; color: hex color for the per-class bar chart;
# weight: numeric score whose consumer is not visible here (presumably a
# relative severity). Verify before relying on it.
_RISK_TABLE = (
    ("Moderado", "#ffaa00", 0.6),
    ("Alto", "#ff4444", 0.8),
    ("Bajo", "#44ff44", 0.1),
    ("Bajo", "#44ff44", 0.1),
    ("Crítico", "#cc0000", 1.0),
    ("Bajo", "#44ff44", 0.1),
    ("Bajo", "#44ff44", 0.1),
)

# Class index -> {'level', 'color', 'weight'}; each entry is its own dict.
RISK_LEVELS = {
    class_idx: {"level": level, "color": color, "weight": weight}
    for class_idx, (level, color, weight) in enumerate(_RISK_TABLE)
}
def analizar_lesion_combined(img):
try:
img_fastai = PILImage.create(img)
inputs = feature_extractor(img, return_tensors="pt")
with torch.no_grad():
outputs = model_vit(**inputs)
probs_vit = outputs.logits.softmax(dim=-1).cpu().numpy()[0]
pred_idx_vit = int(np.argmax(probs_vit))
pred_class_vit = CLASSES[pred_idx_vit]
confidence_vit = probs_vit[pred_idx_vit]
except:
pred_class_vit = "Error"
confidence_vit = 0.0
probs_vit = np.zeros(len(CLASSES))
try:
pred_fast_malignant, _, probs_fast_mal = model_malignancy.predict(img_fastai)
prob_malignant = float(probs_fast_mal[1])
except:
prob_malignant = 0.0
try:
pred_fast_type, _, _ = model_norm2000.predict(img_fastai)
except:
pred_fast_type = "Error"
try:
probs_tf = predict_tf(img)
if len(probs_tf) == 2:
benign_prob, malignant_prob = probs_tf
pred_class_tf = "Maligno" if malignant_prob > benign_prob else "Benigno"
confidence_tf = max(probs_tf)
else:
pred_class_tf = "Modelo no binario"
confidence_tf = 0.0
except:
pred_class_tf = "Error"
confidence_tf = 0.0
colors_bars = [RISK_LEVELS[i]['color'] for i in range(7)]
fig, ax = plt.subplots(figsize=(8, 3))
ax.bar(CLASSES, probs_vit*100, color=colors_bars)
ax.set_title("Probabilidad ViT por tipo de lesión")
ax.set_ylabel("Probabilidad (%)")
ax.set_xticks(np.arange(len(CLASSES)))
ax.set_xticklabels(CLASSES, rotation=45, ha='right')
ax.grid(axis='y', alpha=0.2)
plt.tight_layout()
buf = io.BytesIO()
plt.savefig(buf, format="png")
plt.close(fig)
img_bytes = buf.getvalue()
img_b64 = base64.b64encode(img_bytes).decode("utf-8")
html_chart = f''
informe = f"""
🔍 Modelo | Resultado | Confianza |
---|---|---|
🧠 ViT (transformer) | {pred_class_vit} | {confidence_vit:.1%} |
🧬 Fast.ai (clasificación) | {pred_fast_type} | N/A |
⚠️ Fast.ai (malignidad) | {"Maligno" if prob_malignant > 0.5 else "Benigno"} | {prob_malignant:.1%} |
🔬 TensorFlow (saved_model) | {pred_class_tf} | {confidence_tf:.1%} |