Test / app.py
AndersonConforto's picture
first commit
53dc186
raw
history blame
1.72 kB
import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
from matplotlib.figure import Figure
from PIL import Image
# Local filename under which the Surya-1.0 checkpoint is cached (a relative
# path, so it is resolved against the current working directory).
MODEL_FILE = "surya.366m.v1.pt"
# Download URL of that checkpoint on the Hugging Face Hub; reusing MODEL_FILE
# keeps the remote and local names from drifting apart.
MODEL_URL = f"https://huggingface.co/nasa-ibm-ai4science/Surya-1.0/resolve/main/{MODEL_FILE}"
def download_model():
    """Download the Surya-1.0 checkpoint to MODEL_FILE unless it already exists.

    The response is streamed into a temporary ``MODEL_FILE + ".part"`` file,
    which is renamed to MODEL_FILE only after the whole body has been written.
    An interrupted or failed download therefore never leaves a truncated file
    under MODEL_FILE that later runs would treat as a valid checkpoint. A stale
    ``.part`` file left by a failed attempt is simply overwritten next time.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on connection errors or timeouts.
    """
    if os.path.exists(MODEL_FILE):
        return

    print("Baixando pesos do Surya-1.0...")
    tmp_file = MODEL_FILE + ".part"
    # stream=True + iter_content writes the body in chunks instead of holding
    # the whole (presumably large, "366m") checkpoint in memory via r.content.
    # The timeout makes a stalled connection raise instead of hanging forever.
    with requests.get(MODEL_URL, stream=True, timeout=60) as r:
        # Without this check an error page (404/5xx body) would be saved as the model.
        r.raise_for_status()
        with open(tmp_file, "wb") as f:
            for chunk in r.iter_content(chunk_size=1 << 20):
                f.write(chunk)
    # Rename into place only once the file is complete.
    os.replace(tmp_file, MODEL_FILE)
    print("Download concluído!")
# Fetch the checkpoint if it is not cached locally yet.
download_model()
# Load the PyTorch checkpoint, mapping all tensors onto the CPU.
# SECURITY NOTE(review): torch.load unpickles the file. Depending on the installed
# torch version's `weights_only` default, this can execute arbitrary code contained
# in a file fetched over the network; only load checkpoints from a trusted source.
model = torch.load(MODEL_FILE, map_location=torch.device('cpu'))
# The app calls model.eval() and model(...), so a full nn.Module is required.
# Fail with an explicit message if the file holds something else (e.g. a plain
# state_dict) instead of an opaque AttributeError on .eval().
if not isinstance(model, torch.nn.Module):
    raise TypeError(
        f"{MODEL_FILE} did not deserialize to a torch.nn.Module "
        f"(got {type(model).__name__}); if it is a state_dict, construct the "
        "model architecture first and call load_state_dict() on it."
    )
# Put the module in evaluation mode for inference.
model.eval()
def infer_solar_image_heatmap(img):
    """Run the Surya model on an uploaded image and render its output as a heatmap.

    Args:
        img: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the form is submitted without an upload.

    Returns:
        A new ``matplotlib.figure.Figure`` showing the model output,
        min-max normalised to [0, 1] and drawn with the ``hot`` colormap.

    Raises:
        gr.Error: if no image was provided (the message is shown in the UI).
    """
    if img is None:
        raise gr.Error("Envie uma imagem solar.")

    # Pre-processing: grayscale, resize to 224x224, scale to [0, 1] and add
    # batch + channel axes, giving a float32 tensor of shape (1, 1, 224, 224).
    # NOTE(review): assumes the model accepts a single-channel 224x224 input;
    # confirm against the Surya-1.0 input specification.
    gray = img.convert("L").resize((224, 224))
    img_tensor = torch.tensor(np.array(gray), dtype=torch.float32)
    img_tensor = img_tensor.unsqueeze(0).unsqueeze(0) / 255.0

    with torch.no_grad():
        outputs = model(img_tensor)

    # Drop singleton axes and convert to NumPy for plotting.
    # NOTE(review): assumes `outputs` is a single tensor (not a tuple/dict) whose
    # squeezed form is 2-D; confirm what Surya-1.0 actually returns.
    emb = outputs.squeeze().detach().cpu().numpy()
    # imshow rejects 0-D/1-D data; display such output as a single row.
    emb = np.atleast_2d(emb)

    # Min-max normalisation to [0, 1]; the epsilon avoids division by zero when
    # the output is constant (the heatmap is then all zeros).
    heatmap = emb - emb.min()
    heatmap /= heatmap.max() + 1e-8

    # Build a fresh, pyplot-independent Figure per call rather than drawing on
    # pyplot's global "current figure". That shared figure is never closed, so
    # each request would stack another image onto it (unbounded memory growth),
    # and every request would return the very same Figure object.
    fig = Figure()
    ax = fig.add_subplot(111)
    ax.imshow(heatmap, cmap='hot')
    ax.axis('off')
    fig.tight_layout()
    return fig
# Gradio UI: a single PIL image in, a single matplotlib plot out, handled by
# infer_solar_image_heatmap.
interface = gr.Interface(
    title="Playground Surya-1.0 com Heatmap",
    description="Upload de imagem solar → visualize heatmap gerado pelo Surya-1.0",
    inputs=gr.Image(type="pil"),
    outputs=gr.Plot(label="Heatmap do embedding Surya"),
    fn=infer_solar_image_heatmap,
)
# Start the web app; this runs unconditionally whenever the module is executed.
interface.launch()