import importlib.util
import os
import sys
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
from PIL import Image


# ================================
# 1. Helper to download files
# ================================
def download_file(url, filename, *, timeout=60, chunk_size=1 << 20):
    """Download ``url`` to ``filename`` unless ``filename`` already exists.

    The response body is streamed in chunks into a temporary ``.part`` file in
    the destination directory. It is moved onto ``filename`` only after the
    whole body has been written. A failed or interrupted download therefore
    never leaves a truncated ``filename`` behind. Such a file would otherwise
    be mistaken for a complete one by the existence check on every later run.

    Args:
        url: HTTP(S) URL to fetch.
        filename: Destination path; the download is skipped if it exists.
        timeout: Timeout in seconds forwarded to ``requests.get``.
        chunk_size: Number of bytes requested per streamed chunk.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: On connection errors or timeouts.
    """
    if os.path.exists(filename):
        return
    print(f"Baixando {filename}...")
    # Create the temp file next to the target so os.replace stays on one
    # filesystem, which makes the final rename atomic.
    target_dir = os.path.dirname(os.path.abspath(filename))
    fd, tmp_path = tempfile.mkstemp(dir=target_dir, suffix=".part")
    try:
        with os.fdopen(fd, "wb") as f, requests.get(
            url, stream=True, timeout=timeout
        ) as r:
            # Fail loudly instead of saving an HTML error page as the file.
            r.raise_for_status()
            # Stream to disk rather than buffering the whole body in memory.
            for chunk in r.iter_content(chunk_size=chunk_size):
                f.write(chunk)
        os.replace(tmp_path, filename)
    finally:
        # On any failure (including Ctrl-C) drop the partial file so the next
        # run retries. After a successful os.replace it no longer exists.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print(f"{filename} baixado!")


# ================================
# 2. Download the Surya source files
# ================================
files = {
    "helio_spectformer.py": "https://raw.githubusercontent.com/NASA-IMPACT/Surya/main/surya/helio_spectformer.py",
    "spectformer.py": "https://raw.githubusercontent.com/NASA-IMPACT/Surya/main/surya/spectformer.py",
    "embedding.py": "https://raw.githubusercontent.com/NASA-IMPACT/Surya/main/surya/embedding.py",
    "flow.py": "https://raw.githubusercontent.com/NASA-IMPACT/Surya/main/surya/flow.py",
}
# NOTE(review): these raw URLs are hard-coded to paths under surya/ on the
# main branch; with the status check above, a moved file now raises an
# HTTPError instead of silently writing an error page to disk.
for fname, url in files.items():
    download_file(url, fname)

# ================================
# 3. Download the Surya weights
# ================================
MODEL_FILE = "surya.366m.v1.pt"
MODEL_URL = "https://huggingface.co/nasa-ibm-ai4science/Surya-1.0/resolve/main/surya.366m.v1.pt"
download_file(MODEL_URL, MODEL_FILE)

# ================================
# 4. Dynamically import the model class
# ================================
# Load helio_spectformer.py from the current working directory and register
# it in sys.modules under its own name before executing it (the standard
# importlib recipe for importing a source file directly).
# NOTE(review): if helio_spectformer.py imports its sibling modules
# (spectformer/embedding/flow) by name, they are only found when the working
# directory is on sys.path — confirm how the upstream file imports them.
spec = importlib.util.spec_from_file_location("helio_spectformer", "helio_spectformer.py")
helio_module = importlib.util.module_from_spec(spec)
sys.modules["helio_spectformer"] = helio_module
spec.loader.exec_module(helio_module)
HelioSpectFormer = helio_module.HelioSpectFormer
# ================================
# 5. Instantiate the model
# ================================
# Square input resolution shared by the model and the preprocessing below.
IMG_SIZE = 224

model = HelioSpectFormer(
    img_size=IMG_SIZE,
    patch_size=16,
    in_chans=1,
    # NOTE(review): 366 looks like it was taken from the "366m" in the
    # checkpoint name (presumably a parameter count, not a width) — verify
    # every hyper-parameter here against the config the checkpoint used.
    embed_dim=366,
    time_embedding={"type": "linear", "time_dim": 1},
    depth=8,
    n_spectral_blocks=4,
    num_heads=8,
    mlp_ratio=4.0,
    drop_rate=0.0,
    window_size=7,
    dp_rank=1,
    learned_flow=False,
    finetune=True,
)
# NOTE(review): torch.load unpickles the downloaded file; on torch < 2.6 the
# default is weights_only=False, which can execute arbitrary code from an
# untrusted checkpoint — consider passing weights_only=True explicitly.
state_dict = torch.load(MODEL_FILE, map_location=torch.device("cpu"))
# Strict by default: raises if the keys/shapes do not match the model above.
model.load_state_dict(state_dict)
model.eval()


# ================================
# 6. Inference function
# ================================
def infer_solar_image_heatmap(img):
    """Run Surya on an uploaded image and return a heatmap of the model output.

    Args:
        img: PIL image from the Gradio ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        A standalone matplotlib ``Figure`` showing the squeezed model output,
        min-max normalised to [0, 1], with the ``hot`` colormap and no axes.

    Raises:
        gr.Error: If no image was provided (Gradio shows it to the user).
    """
    # Local import: a standalone Figure is not registered with pyplot, so no
    # global figure state is shared between calls or worker threads.
    from matplotlib.figure import Figure

    if img is None:
        raise gr.Error("Envie uma imagem solar primeiro.")

    # Preprocessing: grayscale, resize to the model input size, scale to [0, 1].
    img_gray = img.convert("L").resize((IMG_SIZE, IMG_SIZE))
    pixels = torch.tensor(np.array(img_gray), dtype=torch.float32) / 255.0
    # (H, W) -> (1, 1, 1, H, W); presumably (batch, channel, time, H, W) —
    # TODO confirm against HelioSpectFormer.forward.
    ts_tensor = pixels.unsqueeze(0).unsqueeze(0).unsqueeze(2)
    batch = {"ts": ts_tensor, "time_delta_input": torch.zeros((1, 1))}

    with torch.no_grad():
        outputs = model(batch)

    # Build the heatmap from the output.
    # Assumes the squeezed output is 2-D (H, W) — TODO confirm the output shape.
    emb = outputs.squeeze().cpu().numpy()
    heatmap = emb - emb.min()
    heatmap /= heatmap.max() + 1e-8  # epsilon avoids 0/0 for a constant output

    # Draw on a fresh Figure per call instead of pyplot's global "current
    # figure". Reusing that one figure stacked images across calls and was
    # shared by concurrent Gradio requests.
    fig = Figure()
    ax = fig.add_subplot(111)
    ax.imshow(heatmap, cmap="hot")
    ax.axis("off")
    fig.tight_layout()
    return fig


# ================================
# 7. Gradio interface
# ================================
interface = gr.Interface(
    fn=infer_solar_image_heatmap,
    inputs=gr.Image(type="pil"),
    outputs=gr.Plot(label="Heatmap do embedding Surya"),
    title="Playground Surya-1.0",
    description="Upload de imagem solar → visualize heatmap gerado pelo Surya-1.0",
)
interface.launch()