Test / app.py
AndersonConforto's picture
first commit
dbb647b
raw
history blame
1.33 kB
import gradio as gr
from transformers import AutoModel
import torch
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
# Load the Surya checkpoint once, when the module is first executed, and put
# it in evaluation mode. ``nn.Module.eval()`` returns the module itself, so
# the call can be chained directly onto the loader.
model_name = "nasa-ibm-ai4science/Surya-1.0"
model = AutoModel.from_pretrained(model_name).eval()
# Gradio callback: turn an uploaded image into a heatmap of the model output
def infer_solar_image_heatmap(img):
    """Run the Surya model on an uploaded image and return a heatmap figure.

    Args:
        img: PIL image delivered by the Gradio ``Image`` input, or ``None``
            when the user submits the form without uploading anything.

    Returns:
        A Matplotlib ``Figure`` showing ``outputs[0]`` of the model, min-max
        normalised to the 0-1 range and drawn with the ``hot`` colormap.

    Raises:
        gr.Error: If no image was provided.
    """
    # Imported here so the block is self-contained; matplotlib is already a
    # dependency of this file (pyplot is imported at module level).
    from matplotlib.figure import Figure

    if img is None:
        # Without this guard the call below fails with an opaque
        # AttributeError. The message is in Portuguese to match the UI text.
        raise gr.Error("Envie uma imagem solar antes de gerar o heatmap.")

    # Pre-processing: grayscale, 224x224, scaled to [0, 1], then two leading
    # singleton axes are added so the tensor is (1, 1, 224, 224).
    gray = img.convert("L").resize((224, 224))
    pixels = np.array(gray)
    img_tensor = (
        torch.tensor(pixels, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
        / 255.0
    )

    with torch.no_grad():
        outputs = model(img_tensor)

    # NOTE(review): assumes ``outputs[0]`` squeezes down to a 2-D array that
    # ``imshow`` can render -- confirm against the model's real output shape.
    emb = outputs[0].squeeze().numpy()

    # Min-max normalisation; the epsilon avoids division by zero when the
    # output is constant.
    heatmap = emb - emb.min()
    heatmap /= heatmap.max() + 1e-8

    # Build a fresh, standalone Figure for every request instead of drawing
    # on pyplot's implicit "current figure". The pyplot figure is global
    # state: it would accumulate one image per call on the same axes and be
    # shared between concurrent requests.
    fig = Figure()
    ax = fig.subplots()
    ax.imshow(heatmap, cmap="hot")
    ax.axis("off")
    fig.tight_layout()
    return fig
# Gradio UI: a PIL image goes in, a plot of the heatmap comes out.
# The callback, input and output are passed positionally (fn, inputs, outputs).
interface = gr.Interface(
    infer_solar_image_heatmap,
    gr.Image(type="pil"),
    gr.Plot(label="Heatmap do embedding Surya"),
    title="Playground Surya-1.0 com Heatmap",
    description="Upload de imagem solar → visualize heatmap gerado pelo Surya-1.0",
)

# Start the web server for the interface.
interface.launch()