AndersonConforto commited on
Commit
9014040
·
1 Parent(s): d0f5598
Files changed (1) hide show
  1. app.py +35 -13
app.py CHANGED
@@ -112,21 +112,43 @@ model.eval()
112
  def infer_solar_image_heatmap(img):
113
  # Pré-processamento da imagem
114
  img_gray = img.convert("L").resize((224, 224))
115
- ts_tensor = torch.tensor(np.array(img_gray), dtype=torch.float32).unsqueeze(0).unsqueeze(0).unsqueeze(2) / 255.0
116
- batch = {"ts": ts_tensor, "time_delta_input": torch.zeros((1,1))}
117
-
118
- # Inferência
 
 
 
 
 
 
 
119
  with torch.no_grad():
120
- outputs = model(batch)
121
-
122
- # Pegar o embedding da saída
123
- emb = outputs.squeeze().numpy()
124
- heatmap = emb - emb.min()
125
- heatmap /= heatmap.max() + 1e-8
126
 
127
- # Criar figura do heatmap
128
- plt.imshow(heatmap, cmap='hot')
129
- plt.axis('off')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  plt.tight_layout()
131
  return plt.gcf()
132
 
 
112
  def infer_solar_image_heatmap(img):
113
  # Pré-processamento da imagem
114
  img_gray = img.convert("L").resize((224, 224))
115
+ img_np = np.array(img_gray)
116
+ ts_tensor = (
117
+ torch.tensor(img_np, dtype=torch.float32)
118
+ .unsqueeze(0)
119
+ .unsqueeze(0)
120
+ .unsqueeze(2)
121
+ / 255.0
122
+ ) # [B=1,C=1,T=1,H=224,W=224]
123
+ batch = {"ts": ts_tensor, "time_delta_input": torch.zeros((1, 1))}
124
+
125
+ # Inferência (retorna tokens [1, L, D] com finetune=True)
126
  with torch.no_grad():
127
+ tokens = model(batch).squeeze(0).cpu() # [L, D]
 
 
 
 
 
128
 
129
+ # Remover o componente estático de posição para evitar mapa "igual" entre imagens
130
+ try:
131
+ pos = model.embedding.pos_embed.squeeze(0).to(tokens.dtype).cpu() # [L, D]
132
+ if pos.shape == tokens.shape:
133
+ tokens = tokens - pos
134
+ except Exception:
135
+ pass
136
+
137
+ # Agregar energia por patch (L2) e remontar 14x14
138
+ L, D = tokens.shape
139
+ side = int(L ** 0.5) # 14 para 224/16
140
+ heat_vec = torch.sqrt((tokens**2).mean(dim=1)) # [L]
141
+ heat = heat_vec.reshape(side, side).numpy()
142
+
143
+ # Normalizar e upsample p/ 224x224 (nearest para simplicidade)
144
+ heat = (heat - heat.min()) / (heat.max() - heat.min() + 1e-8)
145
+ heat224 = np.kron(heat, np.ones((224 // side, 224 // side)))
146
+
147
+ # Overlay sobre a imagem original
148
+ plt.figure(figsize=(5, 5))
149
+ plt.imshow(img_np, cmap="gray")
150
+ plt.imshow(heat224, cmap="inferno", alpha=0.5, vmin=0.0, vmax=1.0)
151
+ plt.axis("off")
152
  plt.tight_layout()
153
  return plt.gcf()
154