Spaces:
Sleeping
Sleeping
File size: 3,488 Bytes
c984c16 dbb647b c984c16 dbb647b c4ccf03 c984c16 c4ccf03 c984c16 39f30cb c984c16 39f30cb c984c16 39f30cb c4ccf03 c984c16 c4ccf03 c984c16 c4ccf03 c984c16 c4ccf03 c984c16 c4ccf03 c984c16 c4ccf03 dbb647b c4ccf03 c984c16 c4ccf03 dbb647b c984c16 dbb647b c984c16 39f30cb dbb647b c4ccf03 dbb647b c984c16 dbb647b 39f30cb dbb647b c4ccf03 c984c16 c4ccf03 dbb647b c984c16 dbb647b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
import importlib.util
import os
import sys

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
from matplotlib.figure import Figure
from PIL import Image
# ================================
# 1. Helper to download files
# ================================
def download_file(url, filename, timeout=60, chunk_size=1 << 20):
    """Download *url* to *filename* unless *filename* already exists.

    The response body is streamed to ``<filename>.part`` and renamed into
    place only after it has been written completely. A failed or interrupted
    download therefore never leaves behind a partial file that the existence
    check would then treat as finished on the next run.

    Args:
        url: HTTP(S) URL to fetch.
        filename: Destination path. If it already exists, nothing is fetched.
        timeout: Seconds passed to ``requests.get`` as its timeout, so a
            stalled server cannot block start-up forever.
        chunk_size: Number of bytes per chunk written while streaming.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: On connection errors or timeouts.
    """
    if os.path.exists(filename):
        return
    print(f"Baixando {filename}...")
    tmp_path = f"{filename}.part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as r:
            # Without this check, an HTML error page (e.g. a 404) would be
            # saved as the file and then skipped forever by the check above.
            r.raise_for_status()
            with open(tmp_path, "wb") as f:
                # Stream in chunks instead of buffering the whole body in
                # memory, which matters for the large model checkpoint.
                for chunk in r.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
    except BaseException:
        # Drop the partial file so the next attempt starts from scratch,
        # then propagate the original error.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    os.replace(tmp_path, filename)
    print(f"{filename} baixado!")
# ================================
# 2. Download the Surya source files
# ================================
# NOTE(review): these URLs assume the modules sit directly under `surya/` on
# the `main` branch of NASA-IMPACT/Surya -- confirm against the upstream
# repository layout, since the dynamic import below needs every file present.
_SURYA_RAW_BASE = "https://raw.githubusercontent.com/NASA-IMPACT/Surya/main/surya/"
files = {
    source_name: _SURYA_RAW_BASE + source_name
    for source_name in (
        "helio_spectformer.py",
        "spectformer.py",
        "embedding.py",
        "flow.py",
    )
}
for fname, url in files.items():
    download_file(url, fname)
# ================================
# 3. Download the Surya weights
# ================================
MODEL_FILE = "surya.366m.v1.pt"
# The checkpoint lives at the root of the Hugging Face model repository.
MODEL_URL = f"https://huggingface.co/nasa-ibm-ai4science/Surya-1.0/resolve/main/{MODEL_FILE}"
download_file(MODEL_URL, MODEL_FILE)
# ================================
# 4. Dynamically import the model classes
# ================================
# Load the downloaded helio_spectformer.py from the working directory as a
# top-level module. NOTE(review): it has no parent package, so any relative
# imports inside it (`from .x import ...`) would fail -- confirm the upstream
# file only uses absolute imports of the sibling files fetched above.
_helio_module_name = "helio_spectformer"
spec = importlib.util.spec_from_file_location(
    _helio_module_name, f"{_helio_module_name}.py"
)
helio_module = importlib.util.module_from_spec(spec)
# Register the module before executing it, matching the order recommended
# for importing a source file directly.
sys.modules[_helio_module_name] = helio_module
spec.loader.exec_module(helio_module)
HelioSpectFormer = getattr(helio_module, "HelioSpectFormer")
# ================================
# 5. Instantiate the model
# ================================
# NOTE(review): embed_dim=366 is not divisible by num_heads=8 (366 / 8 = 45.75),
# which multi-head attention normally requires. "366m" in the checkpoint name
# reads like a parameter count rather than an embedding width. Confirm these
# values against the checkpoint's own config: load_state_dict is strict by
# default and raises on any missing/unexpected key or shape mismatch.
MODEL_CONFIG = dict(
    img_size=224,
    patch_size=16,
    in_chans=1,
    embed_dim=366,
    time_embedding={"type": "linear", "time_dim": 1},
    depth=8,
    n_spectral_blocks=4,
    num_heads=8,
    mlp_ratio=4.0,
    drop_rate=0.0,
    window_size=7,
    dp_rank=1,
    learned_flow=False,
    finetune=True,
)
model = HelioSpectFormer(**MODEL_CONFIG)

# NOTE(security): torch.load unpickles the downloaded file, so only load
# checkpoints from a trusted source. Consider weights_only=True once it is
# confirmed the checkpoint holds only tensors (newer PyTorch defaults to it).
state_dict = torch.load(MODEL_FILE, map_location="cpu")
model.load_state_dict(state_dict)
model.eval()  # inference only: disables dropout/training-mode behaviour
# ================================
# 6. Inference function
# ================================
def infer_solar_image_heatmap(img):
    """Run the global Surya ``model`` on an image and plot its output as a heatmap.

    The image is converted to 8-bit grayscale, resized to 224x224 and scaled
    to [0, 1] before being passed to ``model``. The squeezed output is
    min-max normalised and drawn with the ``"hot"`` colormap.

    Args:
        img: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        A new ``matplotlib.figure.Figure`` containing the heatmap.

    Raises:
        gr.Error: If no image was provided, so Gradio shows a message
            instead of a stack trace.
    """
    if img is None:
        raise gr.Error("Envie uma imagem solar.")

    # Preprocess: grayscale, 224x224, pixel values scaled to [0, 1].
    img_gray = img.convert("L").resize((224, 224))
    pixels = torch.tensor(np.array(img_gray), dtype=torch.float32) / 255.0
    # (H, W) -> (1, 1, 1, H, W). Presumably (batch, channel, time, H, W) as
    # HelioSpectFormer expects -- TODO confirm against its forward().
    ts_tensor = pixels[None, None, None]
    batch = {"ts": ts_tensor, "time_delta_input": torch.zeros((1, 1))}

    with torch.no_grad():
        outputs = model(batch)

    # NOTE(review): imshow needs a 2-D (or RGB/RGBA) array; this assumes the
    # output squeezes down to (H, W) -- confirm the model's output shape.
    emb = outputs.squeeze().cpu().numpy()
    heatmap = emb - emb.min()
    heatmap /= heatmap.max() + 1e-8  # epsilon avoids 0/0 on constant output

    # Build a fresh Figure per request instead of drawing on pyplot's global
    # current figure. The old code added another image to that single shared,
    # never-closed figure on every call and returned the same object each time.
    fig = Figure()
    ax = fig.subplots()
    ax.imshow(heatmap, cmap="hot")
    ax.axis("off")
    fig.tight_layout()
    return fig
# ================================
# 7. Gradio interface
# ================================
# Input: a solar image delivered to the handler as a PIL image.
# Output: the matplotlib figure returned by infer_solar_image_heatmap.
_image_input = gr.Image(type="pil")
_heatmap_output = gr.Plot(label="Heatmap do embedding Surya")

interface = gr.Interface(
    fn=infer_solar_image_heatmap,
    inputs=_image_input,
    outputs=_heatmap_output,
    title="Playground Surya-1.0",
    description="Upload de imagem solar → visualize heatmap gerado pelo Surya-1.0",
)
interface.launch()
|