AndersonConforto commited on
Commit
c984c16
·
1 Parent(s): c4ccf03

first commit

Browse files
Files changed (2) hide show
  1. app.py +68 -36
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,76 +1,108 @@
1
- import gradio as gr
 
 
 
2
  import torch
3
  from PIL import Image
4
  import numpy as np
5
  import matplotlib.pyplot as plt
6
- import requests
7
- import os
8
 
9
  # ================================
10
- # 1. Baixar pesos do Surya-1.0
11
  # ================================
12
- MODEL_URL = "https://huggingface.co/nasa-ibm-ai4science/Surya-1.0/resolve/main/surya.366m.v1.pt"
13
- MODEL_FILE = "surya.366m.v1.pt"
14
-
15
- def download_model():
16
- if not os.path.exists(MODEL_FILE):
17
- print("Baixando pesos do Surya-1.0...")
18
- r = requests.get(MODEL_URL)
19
- with open(MODEL_FILE, "wb") as f:
20
  f.write(r.content)
21
- print("Download concluído!")
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
- download_model()
 
 
 
 
 
24
 
25
  # ================================
26
- # 2. Definir a arquitetura do Surya
27
  # ================================
28
- # Aqui você deve colar ou importar a classe SuryaModel do repo oficial
29
- # Exemplo genérico:
30
- import torch.nn as nn
 
31
 
32
- class SuryaModel(nn.Module):
33
- def __init__(self):
34
- super().__init__()
35
- self.conv = nn.Conv2d(1, 1, kernel_size=3, padding=1)
36
- def forward(self, x):
37
- return self.conv(x)
38
 
39
  # ================================
40
- # 3. Criar instância e carregar pesos
41
  # ================================
42
- model = SuryaModel()
43
- state_dict = torch.load(MODEL_FILE, map_location=torch.device('cpu'))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  model.load_state_dict(state_dict)
45
  model.eval()
46
 
47
  # ================================
48
- # 4. Função de inferência para heatmap
49
  # ================================
50
  def infer_solar_image_heatmap(img):
51
- img = img.convert("L").resize((224, 224))
52
- img_tensor = torch.tensor(np.array(img), dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.0
 
 
53
 
54
  with torch.no_grad():
55
- outputs = model(img_tensor)
56
-
 
57
  emb = outputs.squeeze().numpy()
58
  heatmap = emb - emb.min()
59
  heatmap /= heatmap.max() + 1e-8
60
 
61
- plt.imshow(heatmap, cmap='hot')
62
- plt.axis('off')
63
  plt.tight_layout()
64
  return plt.gcf()
65
 
66
  # ================================
67
- # 5. Interface Gradio
68
  # ================================
69
  interface = gr.Interface(
70
  fn=infer_solar_image_heatmap,
71
  inputs=gr.Image(type="pil"),
72
  outputs=gr.Plot(label="Heatmap do embedding Surya"),
73
- title="Playground Surya-1.0 com Heatmap",
74
  description="Upload de imagem solar → visualize heatmap gerado pelo Surya-1.0"
75
  )
76
 
 
1
+ import os
2
+ import requests
3
+ import importlib.util
4
+ import sys
5
  import torch
6
  from PIL import Image
7
  import numpy as np
8
  import matplotlib.pyplot as plt
9
+ import gradio as gr
 
10
 
11
  # ================================
12
+ # 1. Função para baixar arquivos
13
  # ================================
14
+ def download_file(url, filename):
15
+ if not os.path.exists(filename):
16
+ print(f"Baixando {filename}...")
17
+ r = requests.get(url)
18
+ with open(filename, "wb") as f:
 
 
 
19
  f.write(r.content)
20
+ print(f"{filename} baixado!")
21
+
22
+ # ================================
23
+ # 2. Baixar arquivos do Surya
24
+ # ================================
25
+ files = {
26
+ "helio_spectformer.py": "https://raw.githubusercontent.com/NASA-IMPACT/Surya/main/surya/helio_spectformer.py",
27
+ "spectformer.py": "https://raw.githubusercontent.com/NASA-IMPACT/Surya/main/surya/spectformer.py",
28
+ "embedding.py": "https://raw.githubusercontent.com/NASA-IMPACT/Surya/main/surya/embedding.py",
29
+ "flow.py": "https://raw.githubusercontent.com/NASA-IMPACT/Surya/main/surya/flow.py"
30
+ }
31
+
32
+ for fname, url in files.items():
33
+ download_file(url, fname)
34
 
35
+ # ================================
36
+ # 3. Baixar pesos do Surya
37
+ # ================================
38
+ MODEL_FILE = "surya.366m.v1.pt"
39
+ MODEL_URL = "https://huggingface.co/nasa-ibm-ai4science/Surya-1.0/resolve/main/surya.366m.v1.pt"
40
+ download_file(MODEL_URL, MODEL_FILE)
41
 
42
  # ================================
43
+ # 4. Importar dinamicamente as classes
44
  # ================================
45
+ spec = importlib.util.spec_from_file_location("helio_spectformer", "helio_spectformer.py")
46
+ helio_module = importlib.util.module_from_spec(spec)
47
+ sys.modules["helio_spectformer"] = helio_module
48
+ spec.loader.exec_module(helio_module)
49
 
50
+ HelioSpectFormer = helio_module.HelioSpectFormer
 
 
 
 
 
51
 
52
  # ================================
53
+ # 5. Instanciar o modelo
54
  # ================================
55
+ model = HelioSpectFormer(
56
+ img_size=224,
57
+ patch_size=16,
58
+ in_chans=1,
59
+ embed_dim=366,
60
+ time_embedding={"type": "linear", "time_dim": 1},
61
+ depth=8,
62
+ n_spectral_blocks=4,
63
+ num_heads=8,
64
+ mlp_ratio=4.0,
65
+ drop_rate=0.0,
66
+ window_size=7,
67
+ dp_rank=1,
68
+ learned_flow=False,
69
+ finetune=True
70
+ )
71
+
72
+ state_dict = torch.load(MODEL_FILE, map_location=torch.device("cpu"))
73
  model.load_state_dict(state_dict)
74
  model.eval()
75
 
76
  # ================================
77
+ # 6. Função de inferência
78
  # ================================
79
  def infer_solar_image_heatmap(img):
80
+ # Pré-processamento da imagem
81
+ img_gray = img.convert("L").resize((224,224))
82
+ ts_tensor = torch.tensor(np.array(img_gray), dtype=torch.float32).unsqueeze(0).unsqueeze(0).unsqueeze(2) / 255.0
83
+ batch = {"ts": ts_tensor, "time_delta_input": torch.zeros((1,1))}
84
 
85
  with torch.no_grad():
86
+ outputs = model(batch)
87
+
88
+ # Criar heatmap da saída
89
  emb = outputs.squeeze().numpy()
90
  heatmap = emb - emb.min()
91
  heatmap /= heatmap.max() + 1e-8
92
 
93
+ plt.imshow(heatmap, cmap="hot")
94
+ plt.axis("off")
95
  plt.tight_layout()
96
  return plt.gcf()
97
 
98
  # ================================
99
+ # 7. Interface Gradio
100
  # ================================
101
  interface = gr.Interface(
102
  fn=infer_solar_image_heatmap,
103
  inputs=gr.Image(type="pil"),
104
  outputs=gr.Plot(label="Heatmap do embedding Surya"),
105
+ title="Playground Surya-1.0",
106
  description="Upload de imagem solar → visualize heatmap gerado pelo Surya-1.0"
107
  )
108
 
requirements.txt CHANGED
@@ -1,6 +1,7 @@
1
  torch
 
2
  pillow
3
  numpy
4
  matplotlib
5
  gradio
6
- requests
 
1
  torch
2
+ einops
3
  pillow
4
  numpy
5
  matplotlib
6
  gradio
7
+ requests