LoloSemper commited on
Commit
e041ea9
·
verified ·
1 Parent(s): ca9df00

Upload 4 files

Browse files
ada_learn_malben.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:441760abd9a3a6b143917c68cf1709a1d96323fa334b4952009f50a34ddb57c2
3
+ size 87548073
ada_learn_skin_norm2000.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:820252747a54e3e7494fdc86c998d1d67996cb31bde244ddd4260307e80dd819
3
+ size 87739881
app.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import ViTImageProcessor, ViTForImageClassification, pipeline
3
+ from fastai.learner import load_learner
4
+ from fastai.vision.core import PILImage
5
+ from PIL import Image
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import gradio as gr
9
+ import io
10
+ import base64
11
+
12
+ # 🔹 Modelo ViT desde Hugging Face (HAM10000)
13
+ MODEL_NAME = "ahishamm/vit-base-HAM-10000-sharpened-patch-32"
14
+ feature_extractor = ViTImageProcessor.from_pretrained(MODEL_NAME)
15
+ model_vit = ViTForImageClassification.from_pretrained(MODEL_NAME)
16
+ model_vit.eval()
17
+
18
+ # 🔹 Modelos Fast.ai desde archivo local
19
+ model_malignancy = load_learner("ada_learn_malben.pkl")
20
+ model_norm2000 = load_learner("ada_learn_skin_norm2000.pkl")
21
+
22
+ # 🔹 Modelo binario ISIC preentrenado (alta fiabilidad)
23
+ classifier_isic = pipeline("image-classification", model="VRJBro/skin-cancer-detection")
24
+
25
+ # 🔹 Clases y niveles de riesgo
26
+ CLASSES = [
27
+ "Queratosis actínica / Bowen", "Carcinoma células basales",
28
+ "Lesión queratósica benigna", "Dermatofibroma",
29
+ "Melanoma maligno", "Nevus melanocítico", "Lesión vascular"
30
+ ]
31
+ RISK_LEVELS = {
32
+ 0: {'level': 'Moderado', 'color': '#ffaa00', 'weight': 0.6},
33
+ 1: {'level': 'Alto', 'color': '#ff4444', 'weight': 0.8},
34
+ 2: {'level': 'Bajo', 'color': '#44ff44', 'weight': 0.1},
35
+ 3: {'level': 'Bajo', 'color': '#44ff44', 'weight': 0.1},
36
+ 4: {'level': 'Crítico', 'color': '#cc0000', 'weight': 1.0},
37
+ 5: {'level': 'Bajo', 'color': '#44ff44', 'weight': 0.1},
38
+ 6: {'level': 'Bajo', 'color': '#44ff44', 'weight': 0.1}
39
+ }
40
+
41
+ def analizar_lesion_combined(img):
42
+ img_fastai = PILImage.create(img)
43
+
44
+ # 🔹 ViT prediction
45
+ inputs = feature_extractor(img, return_tensors="pt")
46
+ with torch.no_grad():
47
+ outputs = model_vit(**inputs)
48
+ probs_vit = outputs.logits.softmax(dim=-1).cpu().numpy()[0]
49
+ pred_idx_vit = int(np.argmax(probs_vit))
50
+ pred_class_vit = CLASSES[pred_idx_vit]
51
+ confidence_vit = probs_vit[pred_idx_vit]
52
+
53
+ # 🔹 Fast.ai predictions
54
+ pred_fast_malignant, _, probs_fast_mal = model_malignancy.predict(img_fastai)
55
+ prob_malignant = float(probs_fast_mal[1])
56
+ pred_fast_type, _, probs_fast_type = model_norm2000.predict(img_fastai)
57
+
58
+ # 🔹 ISIC binary classification (modelo 4)
59
+ result_isic = classifier_isic(img)
60
+ pred_isic = result_isic[0]['label']
61
+ confidence_isic = result_isic[0]['score']
62
+
63
+ # 🔹 Gráfico ViT
64
+ colors_bars = [RISK_LEVELS[i]['color'] for i in range(7)]
65
+ fig, ax = plt.subplots(figsize=(8, 3))
66
+ ax.bar(CLASSES, probs_vit*100, color=colors_bars)
67
+ ax.set_title("Probabilidad ViT por tipo de lesión")
68
+ ax.set_ylabel("Probabilidad (%)")
69
+ ax.set_xticks(np.arange(len(CLASSES)))
70
+ ax.set_xticklabels(CLASSES, rotation=45, ha='right')
71
+ ax.grid(axis='y', alpha=0.2)
72
+ plt.tight_layout()
73
+ buf = io.BytesIO()
74
+ plt.savefig(buf, format="png")
75
+ plt.close(fig)
76
+ img_bytes = buf.getvalue()
77
+ img_b64 = base64.b64encode(img_bytes).decode("utf-8")
78
+ html_chart = f'<img src="data:image/png;base64,{img_b64}" style="max-width:100%"/>'
79
+
80
+ # 🔹 Informe HTML
81
+ informe = f"""
82
+ <div style="font-family:sans-serif; max-width:800px; margin:auto">
83
+ <h2>🧪 Diagnóstico por 4 modelos de IA</h2>
84
+ <table style="border-collapse: collapse; width:100%; font-size:16px">
85
+ <tr><th style="text-align:left">🔍 Modelo</th><th>Resultado</th><th>Confianza</th></tr>
86
+ <tr><td>🧠 ViT (transformer)</td><td><b>{pred_class_vit}</b></td><td>{confidence_vit:.1%}</td></tr>
87
+ <tr><td>🧬 Fast.ai (clasificación)</td><td><b>{pred_fast_type}</b></td><td>N/A</td></tr>
88
+ <tr><td>⚠️ Fast.ai (malignidad)</td><td><b>{"Maligno" if prob_malignant > 0.5 else "Benigno"}</b></td><td>{prob_malignant:.1%}</td></tr>
89
+ <tr><td>🔬 ISIC binario</td><td><b>{pred_isic.capitalize()}</b></td><td>{confidence_isic:.1%}</td></tr>
90
+ </table>
91
+ <br>
92
+ <b>🩺 Recomendación automática:</b><br>
93
+ """
94
+
95
+ cancer_risk_score = sum(probs_vit[i] * RISK_LEVELS[i]['weight'] for i in range(7))
96
+ if prob_malignant > 0.7 or cancer_risk_score > 0.6 or (pred_isic == "cancerous" and confidence_isic > 0.9):
97
+ informe += "🚨 <b>CRÍTICO</b> – Derivación urgente a oncología dermatológica"
98
+ elif prob_malignant > 0.4 or cancer_risk_score > 0.4 or (pred_isic == "cancerous"):
99
+ informe += "⚠️ <b>ALTO RIESGO</b> – Consulta con dermatólogo en 7 días"
100
+ elif cancer_risk_score > 0.2:
101
+ informe += "📋 <b>RIESGO MODERADO</b> – Evaluación programada (2-4 semanas)"
102
+ else:
103
+ informe += "✅ <b>BAJO RIESGO</b> – Seguimiento de rutina (3-6 meses)"
104
+
105
+ informe += "</div>"
106
+
107
+ return informe, html_chart
108
+
109
+ # 🔹 Interfaz Gradio actualizada
110
+ demo = gr.Interface(
111
+ fn=analizar_lesion_combined,
112
+ inputs=gr.Image(type="pil", label="Sube una imagen de la lesión"),
113
+ outputs=[gr.HTML(label="Informe combinado"), gr.HTML(label="Gráfico ViT")],
114
+ title="Detector de Lesiones Cutáneas (ViT + Fast.ai + ISIC)",
115
+ description="Comparación entre ViT transformer (HAM10000), dos modelos Fast.ai y un clasificador binario ISIC con alta precisión.",
116
+ flagging_mode="never"
117
+ )
118
+
119
+ if __name__ == "__main__":
120
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch==2.0.1
2
+ torchvision==0.15.2
3
+ fastai==2.7.12
4
+ transformers==4.41.2
5
+ numpy<2
6
+ gradio==5.38.1
7
+ matplotlib