LoloSemper commited on
Commit
cdc152a
·
verified ·
1 Parent(s): e041ea9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -120
app.py CHANGED
@@ -1,120 +1,54 @@
1
- import torch
2
- from transformers import ViTImageProcessor, ViTForImageClassification, pipeline
3
- from fastai.learner import load_learner
4
- from fastai.vision.core import PILImage
5
- from PIL import Image
6
- import matplotlib.pyplot as plt
7
- import numpy as np
8
- import gradio as gr
9
- import io
10
- import base64
11
-
12
- # 🔹 Modelo ViT desde Hugging Face (HAM10000)
13
- MODEL_NAME = "ahishamm/vit-base-HAM-10000-sharpened-patch-32"
14
- feature_extractor = ViTImageProcessor.from_pretrained(MODEL_NAME)
15
- model_vit = ViTForImageClassification.from_pretrained(MODEL_NAME)
16
- model_vit.eval()
17
-
18
- # 🔹 Modelos Fast.ai desde archivo local
19
- model_malignancy = load_learner("ada_learn_malben.pkl")
20
- model_norm2000 = load_learner("ada_learn_skin_norm2000.pkl")
21
-
22
- # 🔹 Modelo binario ISIC preentrenado (alta fiabilidad)
23
- classifier_isic = pipeline("image-classification", model="VRJBro/skin-cancer-detection")
24
-
25
- # 🔹 Clases y niveles de riesgo
26
- CLASSES = [
27
- "Queratosis actínica / Bowen", "Carcinoma células basales",
28
- "Lesión queratósica benigna", "Dermatofibroma",
29
- "Melanoma maligno", "Nevus melanocítico", "Lesión vascular"
30
- ]
31
- RISK_LEVELS = {
32
- 0: {'level': 'Moderado', 'color': '#ffaa00', 'weight': 0.6},
33
- 1: {'level': 'Alto', 'color': '#ff4444', 'weight': 0.8},
34
- 2: {'level': 'Bajo', 'color': '#44ff44', 'weight': 0.1},
35
- 3: {'level': 'Bajo', 'color': '#44ff44', 'weight': 0.1},
36
- 4: {'level': 'Crítico', 'color': '#cc0000', 'weight': 1.0},
37
- 5: {'level': 'Bajo', 'color': '#44ff44', 'weight': 0.1},
38
- 6: {'level': 'Bajo', 'color': '#44ff44', 'weight': 0.1}
39
- }
40
-
41
- def analizar_lesion_combined(img):
42
- img_fastai = PILImage.create(img)
43
-
44
- # 🔹 ViT prediction
45
- inputs = feature_extractor(img, return_tensors="pt")
46
- with torch.no_grad():
47
- outputs = model_vit(**inputs)
48
- probs_vit = outputs.logits.softmax(dim=-1).cpu().numpy()[0]
49
- pred_idx_vit = int(np.argmax(probs_vit))
50
- pred_class_vit = CLASSES[pred_idx_vit]
51
- confidence_vit = probs_vit[pred_idx_vit]
52
-
53
- # 🔹 Fast.ai predictions
54
- pred_fast_malignant, _, probs_fast_mal = model_malignancy.predict(img_fastai)
55
- prob_malignant = float(probs_fast_mal[1])
56
- pred_fast_type, _, probs_fast_type = model_norm2000.predict(img_fastai)
57
-
58
- # 🔹 ISIC binary classification (modelo 4)
59
- result_isic = classifier_isic(img)
60
- pred_isic = result_isic[0]['label']
61
- confidence_isic = result_isic[0]['score']
62
-
63
- # 🔹 Gráfico ViT
64
- colors_bars = [RISK_LEVELS[i]['color'] for i in range(7)]
65
- fig, ax = plt.subplots(figsize=(8, 3))
66
- ax.bar(CLASSES, probs_vit*100, color=colors_bars)
67
- ax.set_title("Probabilidad ViT por tipo de lesión")
68
- ax.set_ylabel("Probabilidad (%)")
69
- ax.set_xticks(np.arange(len(CLASSES)))
70
- ax.set_xticklabels(CLASSES, rotation=45, ha='right')
71
- ax.grid(axis='y', alpha=0.2)
72
- plt.tight_layout()
73
- buf = io.BytesIO()
74
- plt.savefig(buf, format="png")
75
- plt.close(fig)
76
- img_bytes = buf.getvalue()
77
- img_b64 = base64.b64encode(img_bytes).decode("utf-8")
78
- html_chart = f'<img src="data:image/png;base64,{img_b64}" style="max-width:100%"/>'
79
-
80
- # 🔹 Informe HTML
81
- informe = f"""
82
- <div style="font-family:sans-serif; max-width:800px; margin:auto">
83
- <h2>🧪 Diagnóstico por 4 modelos de IA</h2>
84
- <table style="border-collapse: collapse; width:100%; font-size:16px">
85
- <tr><th style="text-align:left">🔍 Modelo</th><th>Resultado</th><th>Confianza</th></tr>
86
- <tr><td>🧠 ViT (transformer)</td><td><b>{pred_class_vit}</b></td><td>{confidence_vit:.1%}</td></tr>
87
- <tr><td>🧬 Fast.ai (clasificación)</td><td><b>{pred_fast_type}</b></td><td>N/A</td></tr>
88
- <tr><td>⚠️ Fast.ai (malignidad)</td><td><b>{"Maligno" if prob_malignant > 0.5 else "Benigno"}</b></td><td>{prob_malignant:.1%}</td></tr>
89
- <tr><td>🔬 ISIC binario</td><td><b>{pred_isic.capitalize()}</b></td><td>{confidence_isic:.1%}</td></tr>
90
- </table>
91
- <br>
92
- <b>🩺 Recomendación automática:</b><br>
93
- """
94
-
95
- cancer_risk_score = sum(probs_vit[i] * RISK_LEVELS[i]['weight'] for i in range(7))
96
- if prob_malignant > 0.7 or cancer_risk_score > 0.6 or (pred_isic == "cancerous" and confidence_isic > 0.9):
97
- informe += "🚨 <b>CRÍTICO</b> – Derivación urgente a oncología dermatológica"
98
- elif prob_malignant > 0.4 or cancer_risk_score > 0.4 or (pred_isic == "cancerous"):
99
- informe += "⚠️ <b>ALTO RIESGO</b> – Consulta con dermatólogo en 7 días"
100
- elif cancer_risk_score > 0.2:
101
- informe += "📋 <b>RIESGO MODERADO</b> – Evaluación programada (2-4 semanas)"
102
- else:
103
- informe += "✅ <b>BAJO RIESGO</b> – Seguimiento de rutina (3-6 meses)"
104
-
105
- informe += "</div>"
106
-
107
- return informe, html_chart
108
-
109
- # 🔹 Interfaz Gradio actualizada
110
- demo = gr.Interface(
111
- fn=analizar_lesion_combined,
112
- inputs=gr.Image(type="pil", label="Sube una imagen de la lesión"),
113
- outputs=[gr.HTML(label="Informe combinado"), gr.HTML(label="Gráfico ViT")],
114
- title="Detector de Lesiones Cutáneas (ViT + Fast.ai + ISIC)",
115
- description="Comparación entre ViT transformer (HAM10000), dos modelos Fast.ai y un clasificador binario ISIC con alta precisión.",
116
- flagging_mode="never"
117
- )
118
-
119
- if __name__ == "__main__":
120
- demo.launch()
 
1
+ import torch
2
+ from transformers import ViTImageProcessor, ViTForImageClassification, pipeline
3
+ from fastai.learner import load_learner
4
+ from fastai.vision.core import PILImage
5
+ from PIL import Image
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import gradio as gr
9
+ import io
10
+ import base64
11
+
12
+ # 🔹 Cargar modelo ViT desde Hugging Face (HAM10000)
13
+ MODEL_NAME = "ahishamm/vit-base-HAM-10000-sharpened-patch-32"
14
+ feature_extractor = ViTImageProcessor.from_pretrained(MODEL_NAME)
15
+ model_vit = ViTForImageClassification.from_pretrained(MODEL_NAME)
16
+ model_vit.eval()
17
+
18
+ # 🔹 Cargar modelos Fast.ai
19
+ model_malignancy = load_learner("ada_learn_malben.pkl")
20
+ model_norm2000 = load_learner("ada_learn_skin_norm2000.pkl")
21
+
22
+ # 🔹 Cargar modelo ISIC 7 clases
23
+ classifier_isic7 = pipeline("image-classification", model="Anwarkh1/Skin_Cancer-Image_Classification")
24
+
25
+ # 🔹 Clases ViT y niveles de riesgo
26
+ CLASSES = [
27
+ "Queratosis actínica / Bowen", "Carcinoma células basales",
28
+ "Lesión queratósica benigna", "Dermatofibroma",
29
+ "Melanoma maligno", "Nevus melanocítico", "Lesión vascular"
30
+ ]
31
+ RISK_LEVELS = {
32
+ 0: {'level': 'Moderado', 'color': '#ffaa00', 'weight': 0.6},
33
+ 1: {'level': 'Alto', 'color': '#ff4444', 'weight': 0.8},
34
+ 2: {'level': 'Bajo', 'color': '#44ff44', 'weight': 0.1},
35
+ 3: {'level': 'Bajo', 'color': '#44ff44', 'weight': 0.1},
36
+ 4: {'level': 'Crítico', 'color': '#cc0000', 'weight': 1.0},
37
+ 5: {'level': 'Bajo', 'color': '#44ff44', 'weight': 0.1},
38
+ 6: {'level': 'Bajo', 'color': '#44ff44', 'weight': 0.1}
39
+ }
40
+
41
+ def analizar_lesion_combined(img):
42
+ img_fastai = PILImage.create(img)
43
+
44
+ # 🔹 ViT transformer (HAM10000)
45
+ inputs = feature_extractor(img, return_tensors="pt")
46
+ with torch.no_grad():
47
+ outputs = model_vit(**inputs)
48
+ probs_vit = outputs.logits.softmax(dim=-1).cpu().numpy()[0]
49
+ pred_idx_vit = int(np.argmax(probs_vit))
50
+ pred_class_vit = CLASSES[pred_idx_vit]
51
+ confidence_vit = probs_vit[pred_idx_vit]
52
+
53
+ # 🔹 Fast.ai modelos
54
+ pred_fast_malignant, _, pr_