Spaces:
Running
Running
Commit
·
9014040
1
Parent(s):
d0f5598
feat
Browse files
app.py
CHANGED
@@ -112,21 +112,43 @@ model.eval()
|
|
112 |
def infer_solar_image_heatmap(img):
|
113 |
# Pré-processamento da imagem
|
114 |
img_gray = img.convert("L").resize((224, 224))
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
with torch.no_grad():
|
120 |
-
|
121 |
-
|
122 |
-
# Pegar o embedding da saída
|
123 |
-
emb = outputs.squeeze().numpy()
|
124 |
-
heatmap = emb - emb.min()
|
125 |
-
heatmap /= heatmap.max() + 1e-8
|
126 |
|
127 |
-
#
|
128 |
-
|
129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
plt.tight_layout()
|
131 |
return plt.gcf()
|
132 |
|
|
|
112 |
def infer_solar_image_heatmap(img):
|
113 |
# Pré-processamento da imagem
|
114 |
img_gray = img.convert("L").resize((224, 224))
|
115 |
+
img_np = np.array(img_gray)
|
116 |
+
ts_tensor = (
|
117 |
+
torch.tensor(img_np, dtype=torch.float32)
|
118 |
+
.unsqueeze(0)
|
119 |
+
.unsqueeze(0)
|
120 |
+
.unsqueeze(2)
|
121 |
+
/ 255.0
|
122 |
+
) # [B=1,C=1,T=1,H=224,W=224]
|
123 |
+
batch = {"ts": ts_tensor, "time_delta_input": torch.zeros((1, 1))}
|
124 |
+
|
125 |
+
# Inferência (retorna tokens [1, L, D] com finetune=True)
|
126 |
with torch.no_grad():
|
127 |
+
tokens = model(batch).squeeze(0).cpu() # [L, D]
|
|
|
|
|
|
|
|
|
|
|
128 |
|
129 |
+
# Remover o componente estático de posição para evitar mapa "igual" entre imagens
|
130 |
+
try:
|
131 |
+
pos = model.embedding.pos_embed.squeeze(0).to(tokens.dtype).cpu() # [L, D]
|
132 |
+
if pos.shape == tokens.shape:
|
133 |
+
tokens = tokens - pos
|
134 |
+
except Exception:
|
135 |
+
pass
|
136 |
+
|
137 |
+
# Agregar energia por patch (L2) e remontar 14x14
|
138 |
+
L, D = tokens.shape
|
139 |
+
side = int(L ** 0.5) # 14 para 224/16
|
140 |
+
heat_vec = torch.sqrt((tokens**2).mean(dim=1)) # [L]
|
141 |
+
heat = heat_vec.reshape(side, side).numpy()
|
142 |
+
|
143 |
+
# Normalizar e upsample p/ 224x224 (nearest para simplicidade)
|
144 |
+
heat = (heat - heat.min()) / (heat.max() - heat.min() + 1e-8)
|
145 |
+
heat224 = np.kron(heat, np.ones((224 // side, 224 // side)))
|
146 |
+
|
147 |
+
# Overlay sobre a imagem original
|
148 |
+
plt.figure(figsize=(5, 5))
|
149 |
+
plt.imshow(img_np, cmap="gray")
|
150 |
+
plt.imshow(heat224, cmap="inferno", alpha=0.5, vmin=0.0, vmax=1.0)
|
151 |
+
plt.axis("off")
|
152 |
plt.tight_layout()
|
153 |
return plt.gcf()
|
154 |
|