File size: 2,858 Bytes
d3fc090
dbb647b
d3fc090
 
dbb647b
 
 
d3fc090
 
c984c16
 
d3fc090
c984c16
d3fc090
 
c984c16
d3fc090
 
 
 
 
 
 
39f30cb
d3fc090
39f30cb
c4ccf03
d3fc090
c4ccf03
d3fc090
 
 
c4ccf03
 
d3fc090
c4ccf03
c984c16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d3fc090
 
c4ccf03
dbb647b
 
c4ccf03
d3fc090
c4ccf03
dbb647b
c984c16
d3fc090
c984c16
 
dbb647b
d3fc090
dbb647b
c984c16
 
d3fc090
39f30cb
dbb647b
c4ccf03
dbb647b
d3fc090
 
 
dbb647b
39f30cb
dbb647b
c4ccf03
d3fc090
c4ccf03
dbb647b
 
 
 
d3fc090
dbb647b
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import gradio as gr
import torch
from torch import nn
from einops import rearrange
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import requests
import os

# ================================
# 1. Baixar pesos do Surya-1.0
# ================================
MODEL_URL = "https://huggingface.co/nasa-ibm-ai4science/Surya-1.0/resolve/main/surya.366m.v1.pt"
MODEL_FILE = "surya.366m.v1.pt"

def download_model(url=MODEL_URL, path=MODEL_FILE, *, timeout=60, chunk_size=1 << 20):
    """Download the Surya-1.0 checkpoint to ``path`` unless it already exists.

    The response body is streamed to ``<path>.part`` in ``chunk_size`` pieces
    and only renamed to ``path`` once it has been written completely. An
    interrupted or failed download therefore never leaves a truncated file at
    ``path``, which the existence check would otherwise mistake for a complete
    checkpoint on every later run.

    Args:
        url: Location of the checkpoint (defaults to the Surya-1.0 weights).
        path: Destination file (defaults to ``MODEL_FILE`` in the CWD).
        timeout: Passed to ``requests.get`` so a stalled server cannot hang
            the script indefinitely.
        chunk_size: Number of bytes requested per streamed chunk.

    Raises:
        requests.HTTPError: if the server answers with an error status, so
            an error page is never saved in place of the weights.
    """
    if os.path.exists(path):
        return
    print("Baixando pesos do Surya-1.0...")
    tmp_path = os.fspath(path) + ".part"
    # stream=True keeps the (large) checkpoint out of memory; we write it
    # chunk by chunk instead of materialising r.content.
    with requests.get(url, stream=True, timeout=timeout) as r:
        r.raise_for_status()
        with open(tmp_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=chunk_size):
                f.write(chunk)
    # Atomic rename: `path` only ever appears once the file is complete.
    os.replace(tmp_path, path)
    print("Download concluído!")

# Runs at import time: fetch the checkpoint (skipped when MODEL_FILE already
# exists) so that the model set up below can load it from disk.
download_model()

# ================================
# 2. HelioSpectFormer model class
# ================================
# The class is imported from a local copy of the Surya repository, so the
# `surya` package (the repo's `surya/` folder) must be importable from here.
from surya.helio_spectformer import HelioSpectFormer  # se você tiver a pasta surya local

# ================================
# 3. Instanciar o modelo com parâmetros padrão
# ================================
# NOTE(review): "366m" in the checkpoint name most likely refers to the
# parameter count, not to `embed_dim`. If `load_state_dict` below reports
# size mismatches, take the architecture hyperparameters from the
# checkpoint's published config instead of these values — TODO confirm.
model = HelioSpectFormer(
    img_size=224,
    patch_size=16,
    in_chans=1,
    embed_dim=366,
    time_embedding={"type": "linear", "time_dim": 1},
    depth=8,
    n_spectral_blocks=4,
    num_heads=8,
    mlp_ratio=4.0,
    drop_rate=0.0,
    window_size=7,
    dp_rank=1,
    learned_flow=False,
    finetune=True
)

# Load the weights on CPU. The file was fetched from the network and
# `torch.load` is pickle-based; `weights_only=True` restricts unpickling to
# tensors and primitive containers, so a tampered checkpoint cannot execute
# arbitrary code. (weights_only=True is only the default since torch 2.6.)
state_dict = torch.load(
    MODEL_FILE,
    map_location=torch.device('cpu'),
    weights_only=True,
)
model.load_state_dict(state_dict)
model.eval()

# ================================
# 4. Função de inferência para heatmap
# ================================
def _embedding_to_heatmap(outputs):
    """Reduce a model output tensor to a 2-D array min-max scaled into [0, 1].

    Singleton axes are dropped. Any axes beyond the last two are averaged
    away so the result can always be drawn with ``imshow``, and a 0-D/1-D
    result is promoted to a single-row 2-D array.
    """
    emb = outputs.detach().cpu().float().squeeze().numpy()
    if emb.ndim > 2:
        # NOTE(review): assumes the last two axes are spatial — confirm
        # against HelioSpectFormer's output layout.
        emb = emb.reshape(-1, *emb.shape[-2:]).mean(axis=0)
    emb = np.atleast_2d(emb)
    heatmap = emb - emb.min()
    # Epsilon avoids 0/0 for a constant output (result is then all zeros).
    heatmap /= heatmap.max() + 1e-8
    return heatmap

def infer_solar_image_heatmap(img):
    """Run the Surya model on one uploaded image and render its output as a heatmap.

    Args:
        img: PIL image from the Gradio ``Image`` input, or ``None`` when the
            user submits without uploading anything.

    Returns:
        A new ``matplotlib.figure.Figure`` showing the min-max-scaled model
        output with the ``hot`` colormap, or ``None`` when no image was given.
    """
    if img is None:
        return None
    # A bare Figure is not registered with pyplot's global figure list, so
    # each request gets its own canvas (no drawing over a shared "current"
    # figure) and it can be garbage-collected once Gradio has rendered it.
    from matplotlib.figure import Figure

    # Preprocess: grayscale, resized to the model's img_size (224), scaled to
    # [0, 1] and shaped (1, 1, 1, 224, 224).
    # NOTE(review): presumably (batch, channel, time, H, W) — confirm.
    img_gray = img.convert("L").resize((224, 224))
    pixels = np.asarray(img_gray, dtype=np.float32) / 255.0
    ts_tensor = torch.from_numpy(pixels)[None, None, None]
    batch = {"ts": ts_tensor, "time_delta_input": torch.zeros((1, 1))}

    with torch.no_grad():
        outputs = model(batch)

    heatmap = _embedding_to_heatmap(outputs)

    fig = Figure()
    ax = fig.add_subplot()
    ax.imshow(heatmap, cmap='hot')
    ax.axis('off')
    fig.tight_layout()
    return fig

# ================================
# 5. Interface Gradio
# ================================
# Build the input and output components first, then wire them to the
# inference function in a single Interface and start the Gradio app.
_image_input = gr.Image(type="pil")
_heatmap_output = gr.Plot(label="Heatmap do embedding Surya")

interface = gr.Interface(
    fn=infer_solar_image_heatmap,
    inputs=_image_input,
    outputs=_heatmap_output,
    title="Playground Surya-1.0 com Heatmap",
    description="Upload de imagem solar → visualize heatmap gerado pelo Surya-1.0",
)

interface.launch()