# Test / app.py
# AndersonConforto's picture
# first commit
# c984c16
# raw
# history blame
# 3.49 kB
import os
import requests
import importlib.util
import sys
import torch
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import gradio as gr
# ================================
# 1. Função para baixar arquivos
# ================================
def download_file(url, filename):
    """Download ``url`` to ``filename`` unless that file already exists.

    The response body is streamed to ``<filename>.part`` and renamed into
    place only after the transfer has finished, so a failed or interrupted
    download never leaves a truncated ``filename`` behind that a later run
    would mistake for a complete file.

    Args:
        url: HTTP(S) address of the resource to fetch.
        filename: Destination path; its existence is also the
            "already downloaded" marker.

    Raises:
        requests.HTTPError: The server answered with a 4xx/5xx status.
        requests.RequestException: Connection failure or timeout.
        OSError: The destination could not be written or renamed.
    """
    if os.path.exists(filename):
        return

    print(f"Baixando {filename}...")
    partial = f"{filename}.part"
    try:
        # Stream in chunks instead of reading ``r.content``: this helper is
        # also used for the model checkpoint, which may be large.
        with requests.get(url, stream=True, timeout=(10, 120)) as r:
            # Without this check an error page (e.g. a 404 body) would be
            # stored as if it were the requested file.
            r.raise_for_status()
            with open(partial, "wb") as f:
                for chunk in r.iter_content(chunk_size=1024 * 1024):
                    f.write(chunk)
        os.replace(partial, filename)
    finally:
        # No-op after a successful rename; removes leftovers on failure.
        if os.path.exists(partial):
            os.remove(partial)
    print(f"{filename} baixado!")
# ================================
# 2. Baixar arquivos do Surya
# ================================
# Surya source files fetched next to this script, keyed by local file name.
_SURYA_RAW_BASE = "https://raw.githubusercontent.com/NASA-IMPACT/Surya/main/surya"
files = {
    source_name: f"{_SURYA_RAW_BASE}/{source_name}"
    for source_name in (
        "helio_spectformer.py",
        "spectformer.py",
        "embedding.py",
        "flow.py",
    )
}
for fname, url in files.items():
    download_file(url, fname)
# ================================
# 3. Baixar pesos do Surya
# ================================
# Local checkpoint name; the remote file carries the same name.
MODEL_FILE = "surya.366m.v1.pt"
MODEL_URL = (
    "https://huggingface.co/nasa-ibm-ai4science/Surya-1.0/resolve/main/"
    + MODEL_FILE
)
download_file(url=MODEL_URL, filename=MODEL_FILE)
# ================================
# 4. Importar dinamicamente as classes
# ================================
# Load helio_spectformer.py from the working directory as a module.
spec = importlib.util.spec_from_file_location(
    "helio_spectformer",
    "helio_spectformer.py",
)
helio_module = importlib.util.module_from_spec(spec)
# Register the module before executing it, so that any import of
# "helio_spectformer" made while its body runs resolves to this object.
sys.modules[spec.name] = helio_module
spec.loader.exec_module(helio_module)
HelioSpectFormer = helio_module.HelioSpectFormer
# ================================
# 5. Instanciar o modelo
# ================================
# Constructor arguments for HelioSpectFormer.
# NOTE(review): these must describe the same architecture as the weights in
# MODEL_FILE, otherwise load_state_dict() below raises on missing/unexpected
# keys or shape mismatches — confirm against the published Surya config.
surya_config = dict(
    img_size=224,
    patch_size=16,
    in_chans=1,
    embed_dim=366,
    time_embedding={"type": "linear", "time_dim": 1},
    depth=8,
    n_spectral_blocks=4,
    num_heads=8,
    mlp_ratio=4.0,
    drop_rate=0.0,
    window_size=7,
    dp_rank=1,
    learned_flow=False,
    finetune=True,
)
model = HelioSpectFormer(**surya_config)

# NOTE(security): torch.load() unpickles the file, and the checkpoint was
# downloaded from the network. Only use a trusted source; consider passing
# weights_only=True once the checkpoint is confirmed to be a plain
# tensor state dict.
state_dict = torch.load(MODEL_FILE, map_location="cpu")
model.load_state_dict(state_dict)
# Inference only: put the module in evaluation mode.
model.eval()
# ================================
# 6. Função de inferência
# ================================
def infer_solar_image_heatmap(img):
    """Run the loaded model on one image and plot its normalized output.

    Args:
        img: PIL image delivered by the ``gr.Image(type="pil")`` input, or
            ``None`` when the form is submitted without an image.

    Returns:
        A new Matplotlib figure showing the min-max normalized model output
        rendered with the "hot" colormap and the axes hidden.

    Raises:
        gr.Error: No image was provided.
    """
    if img is None:
        # Surface a readable message in the UI instead of an AttributeError.
        raise gr.Error("Envie uma imagem solar antes de executar.")

    # Pre-processing: grayscale, 224x224, pixel values scaled to [0, 1].
    img_gray = img.convert("L").resize((224, 224))
    pixels = torch.tensor(np.array(img_gray), dtype=torch.float32) / 255.0
    # (224, 224) -> (1, 1, 1, 224, 224).
    # NOTE(review): presumably (batch, channel, time, H, W) — confirm against
    # the model's expected input layout.
    ts_tensor = pixels.unsqueeze(0).unsqueeze(0).unsqueeze(2)
    batch = {"ts": ts_tensor, "time_delta_input": torch.zeros((1, 1))}

    with torch.no_grad():
        outputs = model(batch)

    # Min-max normalize the output; the epsilon guards against division by
    # zero when the output is constant.
    # NOTE(review): assumes the squeezed output is 2-D so imshow can draw it.
    emb = outputs.squeeze().numpy()
    heatmap = emb - emb.min()
    heatmap /= heatmap.max() + 1e-8

    # Draw on a dedicated figure rather than pyplot's implicit "current"
    # figure, which is global state shared by every request.
    fig, ax = plt.subplots()
    ax.imshow(heatmap, cmap="hot")
    ax.axis("off")
    fig.tight_layout()
    # Drop the figure from pyplot's registry so figures do not accumulate
    # across requests; the Figure object itself remains usable by the caller.
    plt.close(fig)
    return fig
# ================================
# 7. Interface Gradio
# ================================
# Single-input, single-output Gradio UI: uploaded image in, heatmap plot out.
interface = gr.Interface(
    title="Playground Surya-1.0",
    description="Upload de imagem solar → visualize heatmap gerado pelo Surya-1.0",
    fn=infer_solar_image_heatmap,
    inputs=gr.Image(type="pil"),
    outputs=gr.Plot(label="Heatmap do embedding Surya"),
)
interface.launch()