File size: 3,488 Bytes
c984c16
 
 
 
dbb647b
 
 
 
c984c16
dbb647b
c4ccf03
c984c16
c4ccf03
c984c16
 
 
 
 
39f30cb
c984c16
 
 
 
 
 
 
 
 
 
 
 
 
 
39f30cb
c984c16
 
 
 
 
 
39f30cb
c4ccf03
c984c16
c4ccf03
c984c16
 
 
 
c4ccf03
c984c16
c4ccf03
 
c984c16
c4ccf03
c984c16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c4ccf03
dbb647b
 
c4ccf03
c984c16
c4ccf03
dbb647b
c984c16
 
 
 
dbb647b
 
c984c16
 
 
39f30cb
dbb647b
c4ccf03
dbb647b
c984c16
 
dbb647b
39f30cb
dbb647b
c4ccf03
c984c16
c4ccf03
dbb647b
 
 
 
c984c16
dbb647b
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import requests
import importlib.util
import sys
import torch
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import gradio as gr

# ================================
# 1. File download helper
# ================================
def download_file(url, filename, timeout=60, chunk_size=1024 * 1024):
    """Download ``url`` to ``filename`` unless ``filename`` already exists.

    The response body is streamed to disk in chunks, so a large checkpoint is
    never held in memory all at once. It is written to a temporary
    ``<filename>.part`` file and renamed into place only after the download
    completes. A failed or interrupted download therefore never leaves behind
    a truncated file that the existence check would later accept as complete.

    Args:
        url: HTTP(S) URL to fetch.
        filename: Destination path. If it already exists, nothing is fetched.
        timeout: Seconds ``requests`` waits to connect, or between received
            bytes, before giving up.
        chunk_size: Number of bytes read from the response per iteration.

    Raises:
        requests.HTTPError: The server answered with a 4xx/5xx status.
        requests.RequestException: Any other network failure, including a
            timeout.
        OSError: The file could not be written or renamed.
    """
    if os.path.exists(filename):
        return

    print(f"Baixando {filename}...")
    part_path = f"{filename}.part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Without this check, an HTML error page (404, 503, ...) would be
            # saved under `filename` and never re-fetched on later runs.
            response.raise_for_status()
            with open(part_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
        # `filename` only appears once its content is complete.
        os.replace(part_path, filename)
    finally:
        # On any failure (including Ctrl+C), drop the partial file so the next
        # run retries from scratch. After a successful rename it is already gone.
        if os.path.exists(part_path):
            os.remove(part_path)
    print(f"{filename} baixado!")

# ================================
# 2. Fetch the Surya source files
# ================================
# Local filename -> raw GitHub URL in the NASA-IMPACT/Surya repository.
# download_file skips any file that is already present in the working directory.
# NOTE(review): these paths assume the modules live directly under surya/ on
# the main branch — confirm against the current repository layout.
files = dict(
    [
        ("helio_spectformer.py", "https://raw.githubusercontent.com/NASA-IMPACT/Surya/main/surya/helio_spectformer.py"),
        ("spectformer.py", "https://raw.githubusercontent.com/NASA-IMPACT/Surya/main/surya/spectformer.py"),
        ("embedding.py", "https://raw.githubusercontent.com/NASA-IMPACT/Surya/main/surya/embedding.py"),
        ("flow.py", "https://raw.githubusercontent.com/NASA-IMPACT/Surya/main/surya/flow.py"),
    ]
)

for fname in files:
    url = files[fname]
    download_file(url, fname)

# ================================
# 3. Fetch the Surya weights
# ================================
# Checkpoint filename on the Hugging Face Hub. The download URL is built from
# it so the local name and the remote name cannot drift apart.
MODEL_FILE = "surya.366m.v1.pt"
MODEL_URL = (
    "https://huggingface.co/nasa-ibm-ai4science/Surya-1.0/resolve/main/"
    + MODEL_FILE
)
download_file(MODEL_URL, MODEL_FILE)

# ================================
# 4. Dynamically import the model class
# ================================
# Load the downloaded helio_spectformer.py (from the working directory) as a
# top-level module named "helio_spectformer". The module is registered in
# sys.modules before its code runs, as in the importlib recipe for importing a
# source file directly.
_module_name = "helio_spectformer"
spec = importlib.util.spec_from_file_location(_module_name, _module_name + ".py")
helio_module = importlib.util.module_from_spec(spec)
sys.modules[_module_name] = helio_module
spec.loader.exec_module(helio_module)

HelioSpectFormer = getattr(helio_module, "HelioSpectFormer")

# ================================
# 5. Build the model and load the weights
# ================================
# HelioSpectFormer constructor arguments, gathered in one mapping.
# NOTE(review): embed_dim=366 looks like it was taken from the "366m" in the
# checkpoint name, which more likely denotes ~366M parameters than the
# embedding width; 366 is also not divisible by num_heads=8. These values must
# match the configuration the checkpoint was trained with, otherwise
# load_state_dict (strict by default) raises — confirm against the official
# Surya config.
_model_kwargs = {
    "img_size": 224,
    "patch_size": 16,
    "in_chans": 1,
    "embed_dim": 366,
    "time_embedding": {"type": "linear", "time_dim": 1},
    "depth": 8,
    "n_spectral_blocks": 4,
    "num_heads": 8,
    "mlp_ratio": 4.0,
    "drop_rate": 0.0,
    "window_size": 7,
    "dp_rank": 1,
    "learned_flow": False,
    "finetune": True,
}
model = HelioSpectFormer(**_model_kwargs)

# Tensors are mapped onto the CPU regardless of the device they were saved from.
# NOTE(review): torch.load unpickles the file, and this file was downloaded from
# the network; consider weights_only=True (supported by recent PyTorch) if the
# checkpoint is a plain state dict.
cpu = torch.device("cpu")
state_dict = torch.load(MODEL_FILE, map_location=cpu)
model.load_state_dict(state_dict)
model.eval()  # inference mode: disables dropout / uses running batch-norm stats

# ================================
# 6. Inference function
# ================================
def infer_solar_image_heatmap(img):
    """Run the Surya model on an uploaded image and return a heatmap figure.

    The image is converted to 8-bit grayscale, resized to 224x224, scaled to
    [0, 1] and fed to the module-level ``model``. The squeezed model output is
    min-max normalised and drawn with the "hot" colormap.

    Args:
        img: PIL image from the Gradio ``Image(type="pil")`` input, or ``None``
            when the user submits without uploading an image.

    Returns:
        A new ``matplotlib.figure.Figure`` holding the heatmap, or ``None`` when
        ``img`` is ``None`` (Gradio then shows an empty plot).
    """
    if img is None:
        return None

    # Local import: builds figures without pyplot's global figure registry.
    from matplotlib.figure import Figure

    # Pre-processing: grayscale, fixed 224x224 size, float values in [0, 1].
    img_gray = img.convert("L").resize((224, 224))
    pixels = torch.tensor(np.array(img_gray), dtype=torch.float32) / 255.0
    # (224, 224) -> (1, 1, 1, 224, 224).
    # NOTE(review): presumably (batch, channels, time, H, W) as HelioSpectFormer
    # expects — confirm.
    ts_tensor = pixels.unsqueeze(0).unsqueeze(0).unsqueeze(2)
    batch = {"ts": ts_tensor, "time_delta_input": torch.zeros((1, 1))}

    with torch.no_grad():
        outputs = model(batch)

    # Min-max normalise to [0, 1]; the epsilon avoids division by zero when
    # the output is constant.
    # NOTE(review): imshow needs the squeezed output to be 2-D (or RGB[A]);
    # confirm the model's output shape.
    emb = outputs.squeeze().numpy()
    heatmap = emb - emb.min()
    heatmap /= heatmap.max() + 1e-8

    # A fresh Figure per call. Drawing on pyplot's shared "current figure"
    # would stack a new image onto the same figure on every request and share
    # that figure between concurrent requests.
    fig = Figure()
    ax = fig.add_subplot(111)
    ax.imshow(heatmap, cmap="hot")
    ax.axis("off")
    fig.tight_layout()
    return fig

# ================================
# 7. Gradio interface
# ================================
# One PIL image in, one heatmap plot out; the web server starts at import time.
_image_input = gr.Image(type="pil")
_heatmap_output = gr.Plot(label="Heatmap do embedding Surya")

interface = gr.Interface(
    fn=infer_solar_image_heatmap,
    inputs=_image_input,
    outputs=_heatmap_output,
    title="Playground Surya-1.0",
    description="Upload de imagem solar → visualize heatmap gerado pelo Surya-1.0",
)

interface.launch()