AndersonConforto commited on
Commit
c4ccf03
·
1 Parent(s): 53dc186

first commit

Browse files
Files changed (1) hide show
  1. app.py +30 -12
app.py CHANGED
@@ -6,13 +6,12 @@ import matplotlib.pyplot as plt
6
  import requests
7
  import os
8
 
9
- # URLs dos arquivos do modelo
 
 
10
  MODEL_URL = "https://huggingface.co/nasa-ibm-ai4science/Surya-1.0/resolve/main/surya.366m.v1.pt"
11
-
12
- # Nome local do arquivo
13
  MODEL_FILE = "surya.366m.v1.pt"
14
 
15
- # Função para baixar o modelo se não existir
16
  def download_model():
17
  if not os.path.exists(MODEL_FILE):
18
  print("Baixando pesos do Surya-1.0...")
@@ -21,33 +20,52 @@ def download_model():
21
  f.write(r.content)
22
  print("Download concluído!")
23
 
24
- # Baixar modelo
25
  download_model()
26
 
27
- # Carregar modelo PyTorch
28
- model = torch.load(MODEL_FILE, map_location=torch.device('cpu'))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  model.eval()
30
 
31
- # Função para gerar heatmap
 
 
32
  def infer_solar_image_heatmap(img):
33
- # Pré-processamento: grayscale, resize 224x224
34
  img = img.convert("L").resize((224, 224))
35
  img_tensor = torch.tensor(np.array(img), dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.0
36
 
37
  with torch.no_grad():
38
  outputs = model(img_tensor)
39
 
40
- # Criar heatmap
41
  emb = outputs.squeeze().numpy()
42
  heatmap = emb - emb.min()
43
- heatmap /= heatmap.max() + 1e-8 # normalização 0-1
44
 
45
  plt.imshow(heatmap, cmap='hot')
46
  plt.axis('off')
47
  plt.tight_layout()
48
  return plt.gcf()
49
 
50
- # Interface Gradio
 
 
51
  interface = gr.Interface(
52
  fn=infer_solar_image_heatmap,
53
  inputs=gr.Image(type="pil"),
 
6
  import requests
7
  import os
8
 
9
+ # ================================
10
+ # 1. Baixar pesos do Surya-1.0
11
+ # ================================
12
  MODEL_URL = "https://huggingface.co/nasa-ibm-ai4science/Surya-1.0/resolve/main/surya.366m.v1.pt"
 
 
13
  MODEL_FILE = "surya.366m.v1.pt"
14
 
 
15
  def download_model():
16
  if not os.path.exists(MODEL_FILE):
17
  print("Baixando pesos do Surya-1.0...")
 
20
  f.write(r.content)
21
  print("Download concluído!")
22
 
 
23
  download_model()
24
 
25
+ # ================================
26
+ # 2. Definir a arquitetura do Surya
27
+ # ================================
28
+ # Aqui você deve colar ou importar a classe SuryaModel do repo oficial
29
+ # Exemplo genérico:
30
+ import torch.nn as nn
31
+
32
+ class SuryaModel(nn.Module):
33
+ def __init__(self):
34
+ super().__init__()
35
+ self.conv = nn.Conv2d(1, 1, kernel_size=3, padding=1)
36
+ def forward(self, x):
37
+ return self.conv(x)
38
+
39
+ # ================================
40
+ # 3. Criar instância e carregar pesos
41
+ # ================================
42
+ model = SuryaModel()
43
+ state_dict = torch.load(MODEL_FILE, map_location=torch.device('cpu'))
44
+ model.load_state_dict(state_dict)
45
  model.eval()
46
 
47
+ # ================================
48
+ # 4. Função de inferência para heatmap
49
+ # ================================
50
  def infer_solar_image_heatmap(img):
 
51
  img = img.convert("L").resize((224, 224))
52
  img_tensor = torch.tensor(np.array(img), dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.0
53
 
54
  with torch.no_grad():
55
  outputs = model(img_tensor)
56
 
 
57
  emb = outputs.squeeze().numpy()
58
  heatmap = emb - emb.min()
59
+ heatmap /= heatmap.max() + 1e-8
60
 
61
  plt.imshow(heatmap, cmap='hot')
62
  plt.axis('off')
63
  plt.tight_layout()
64
  return plt.gcf()
65
 
66
+ # ================================
67
+ # 5. Interface Gradio
68
+ # ================================
69
  interface = gr.Interface(
70
  fn=infer_solar_image_heatmap,
71
  inputs=gr.Image(type="pil"),