LoloSemper commited on
Commit
b207263
·
verified ·
1 Parent(s): 3c9cef1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -5
app.py CHANGED
@@ -8,6 +8,8 @@ import numpy as np
8
  import gradio as gr
9
  import io
10
  import base64
 
 
11
 
12
  # --- Cargar modelo ViT preentrenado fine‑tuned HAM10000 ---
13
  TF_MODEL_NAME = "Anwarkh1/Skin_Cancer-Image_Classification"
@@ -25,6 +27,17 @@ model_vit.eval()
25
  model_malignancy = load_learner("ada_learn_malben.pkl")
26
  model_norm2000 = load_learner("ada_learn_skin_norm2000.pkl")
27
 
 
 
 
 
 
 
 
 
 
 
 
28
  # Clases estándar de HAM10000
29
  CLASSES = [
30
  "Queratosis actínica / Bowen", "Carcinoma células basales",
@@ -59,7 +72,7 @@ def analizar_lesion_combined(img):
59
  prob_malign = float(probs_mal[1])
60
  pred_fast_type, _, _ = model_norm2000.predict(img_fastai)
61
 
62
- # ViT pre-trained fine-tuned (último modelo recomendado)
63
  inputs_tf = feature_extractor_tf(img, return_tensors="pt")
64
  with torch.no_grad():
65
  outputs_tf = model_tf_vit(**inputs_tf)
@@ -69,6 +82,13 @@ def analizar_lesion_combined(img):
69
  conf_tf = probs_tf[idx_tf]
70
  mal_tf = "Maligno" if idx_tf in MALIGNANT_INDICES else "Benigno"
71
 
 
 
 
 
 
 
 
72
  # Gráfico ViT base
73
  colors = [RISK_LEVELS[i]['color'] for i in range(7)]
74
  fig, ax = plt.subplots(figsize=(8, 3))
@@ -84,6 +104,7 @@ def analizar_lesion_combined(img):
84
  plt.close(fig)
85
  html_chart = f'<img src="data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}" style="max-width:100%"/>'
86
 
 
87
  informe = f"""
88
  <div style="font-family:sans-serif; max-width:800px; margin:auto">
89
  <h2>🧪 Diagnóstico por múltiples modelos de IA</h2>
@@ -93,19 +114,22 @@ def analizar_lesion_combined(img):
93
  <tr><td>🧬 Fast.ai (tipo)</td><td><b>{pred_fast_type}</b></td><td>N/A</td></tr>
94
  <tr><td>⚠️ Fast.ai (malignidad)</td><td><b>{'Maligno' if prob_malign > 0.5 else 'Benigno'}</b></td><td>{prob_malign:.1%}</td></tr>
95
  <tr><td>🌟 ViT fined‑tuned (HAM10000)</td><td><b>{mal_tf} ({class_tf_model})</b></td><td>{conf_tf:.1%}</td></tr>
 
96
  </table><br>
97
  <b>🩺 Recomendación automática:</b><br>
98
  """
 
 
99
  risk = sum(probs_vit[i] * RISK_LEVELS[i]['weight'] for i in range(7))
100
- if prob_malign > 0.7 or risk > 0.6:
101
  informe += "🚨 <b>CRÍTICO</b> – Derivación urgente a oncología dermatológica"
102
- elif prob_malign > 0.4 or risk > 0.4:
103
  informe += "⚠️ <b>ALTO RIESGO</b> – Consulta con dermatólogo en 7 días"
104
  elif risk > 0.2:
105
  informe += "📋 <b>RIESGO MODERADO</b> – Evaluación programada en 2-4 semanas"
106
  else:
107
  informe += "✅ <b>BAJO RIESGO</b> – Seguimiento de rutina (3-6 meses)"
108
- informe += "</div>"""
109
 
110
  return informe, html_chart
111
 
@@ -113,8 +137,10 @@ demo = gr.Interface(
113
  fn=analizar_lesion_combined,
114
  inputs=gr.Image(type="pil"),
115
  outputs=[gr.HTML(label="Informe"), gr.HTML(label="Gráfico ViT base")],
116
- title="Detector de Lesiones Cutáneas (ViT + Fast.ai)",
117
  )
 
118
  if __name__ == "__main__":
119
  demo.launch()
120
 
 
 
8
  import gradio as gr
9
  import io
10
  import base64
11
+ from torchvision import transforms
12
+ from efficientnet_pytorch import EfficientNet
13
 
14
  # --- Cargar modelo ViT preentrenado fine‑tuned HAM10000 ---
15
  TF_MODEL_NAME = "Anwarkh1/Skin_Cancer-Image_Classification"
 
27
  model_malignancy = load_learner("ada_learn_malben.pkl")
28
  model_norm2000 = load_learner("ada_learn_skin_norm2000.pkl")
29
 
30
+ # 🔹 EfficientNet B7 para binario (benigno vs maligno)
31
+ model_eff = EfficientNet.from_pretrained("efficientnet-b7", num_classes=2)
32
+ model_eff.eval()
33
+
34
+ eff_transform = transforms.Compose([
35
+ transforms.Resize((224, 224)),
36
+ transforms.ToTensor(),
37
+ transforms.Normalize([0.485, 0.456, 0.406],
38
+ [0.229, 0.224, 0.225])
39
+ ])
40
+
41
  # Clases estándar de HAM10000
42
  CLASSES = [
43
  "Queratosis actínica / Bowen", "Carcinoma células basales",
 
72
  prob_malign = float(probs_mal[1])
73
  pred_fast_type, _, _ = model_norm2000.predict(img_fastai)
74
 
75
+ # ViT finetuned (último modelo recomendado)
76
  inputs_tf = feature_extractor_tf(img, return_tensors="pt")
77
  with torch.no_grad():
78
  outputs_tf = model_tf_vit(**inputs_tf)
 
82
  conf_tf = probs_tf[idx_tf]
83
  mal_tf = "Maligno" if idx_tf in MALIGNANT_INDICES else "Benigno"
84
 
85
+ # EfficientNet B7
86
+ img_eff = eff_transform(img).unsqueeze(0)
87
+ with torch.no_grad():
88
+ out_eff = model_eff(img_eff)
89
+ prob_eff = torch.softmax(out_eff, dim=1)[0, 1].item()
90
+ eff_result = "Maligno" if prob_eff > 0.5 else "Benigno"
91
+
92
  # Gráfico ViT base
93
  colors = [RISK_LEVELS[i]['color'] for i in range(7)]
94
  fig, ax = plt.subplots(figsize=(8, 3))
 
104
  plt.close(fig)
105
  html_chart = f'<img src="data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}" style="max-width:100%"/>'
106
 
107
+ # Generar informe
108
  informe = f"""
109
  <div style="font-family:sans-serif; max-width:800px; margin:auto">
110
  <h2>🧪 Diagnóstico por múltiples modelos de IA</h2>
 
114
  <tr><td>🧬 Fast.ai (tipo)</td><td><b>{pred_fast_type}</b></td><td>N/A</td></tr>
115
  <tr><td>⚠️ Fast.ai (malignidad)</td><td><b>{'Maligno' if prob_malign > 0.5 else 'Benigno'}</b></td><td>{prob_malign:.1%}</td></tr>
116
  <tr><td>🌟 ViT fined‑tuned (HAM10000)</td><td><b>{mal_tf} ({class_tf_model})</b></td><td>{conf_tf:.1%}</td></tr>
117
+ <tr><td>🏥 EfficientNet B7 (binario)</td><td><b>{eff_result}</b></td><td>{prob_eff:.1%}</td></tr>
118
  </table><br>
119
  <b>🩺 Recomendación automática:</b><br>
120
  """
121
+
122
+ # Nivel de riesgo automático
123
  risk = sum(probs_vit[i] * RISK_LEVELS[i]['weight'] for i in range(7))
124
+ if prob_malign > 0.7 or risk > 0.6 or prob_eff > 0.7:
125
  informe += "🚨 <b>CRÍTICO</b> – Derivación urgente a oncología dermatológica"
126
+ elif prob_malign > 0.4 or risk > 0.4 or prob_eff > 0.5:
127
  informe += "⚠️ <b>ALTO RIESGO</b> – Consulta con dermatólogo en 7 días"
128
  elif risk > 0.2:
129
  informe += "📋 <b>RIESGO MODERADO</b> – Evaluación programada en 2-4 semanas"
130
  else:
131
  informe += "✅ <b>BAJO RIESGO</b> – Seguimiento de rutina (3-6 meses)"
132
+ informe += "</div>"
133
 
134
  return informe, html_chart
135
 
 
137
  fn=analizar_lesion_combined,
138
  inputs=gr.Image(type="pil"),
139
  outputs=[gr.HTML(label="Informe"), gr.HTML(label="Gráfico ViT base")],
140
+ title="Detector de Lesiones Cutáneas (ViT + Fast.ai + EfficientNet)",
141
  )
142
+
143
  if __name__ == "__main__":
144
  demo.launch()
145
 
146
+