Navier_Stokes / app.py
ajsbsd's picture
Update app.py
3d488b6 verified
raw
history blame
6.31 kB
import gradio as gr
import torch
from neuralop.models import FNO
import matplotlib.pyplot as plt
import numpy as np
import os
import spaces
from huggingface_hub import hf_hub_download
# --- Configuration ---
# Local checkpoint of the trained FNO; this file must be committed to the Space's repo.
MODEL_PATH = "fno_ckpt_single_res"
# Hub location (repo id, file name) of the initial-condition dataset.
HF_DATASET_REPO_ID, HF_DATASET_FILENAME = (
    "ajsbsd/navier-stokes-2d-dataset",
    "navier_stokes_2d.pt",
)

# --- Lazily populated module-level caches ---
# Both start empty; load_model() fills MODEL and load_dataset() fills FULL_DATASET_X.
MODEL = FULL_DATASET_X = None
# --- Function to Download Dataset from HF Hub ---
def download_file_from_hf_hub(repo_id, filename, repo_type=None):
    """Download ``filename`` from the Hugging Face Hub repository ``repo_id``.

    Args:
        repo_id: Hub repository id, e.g. ``"user/name"``.
        filename: Path of the file inside that repository.
        repo_type: Forwarded unchanged to ``hf_hub_download``. ``None`` keeps
            the library default (a model repo); pass ``"dataset"`` when the
            file lives in a dataset repository.

    Returns:
        The local filesystem path returned by ``hf_hub_download``.

    Raises:
        gr.Error: If the download fails. The underlying exception is chained
            as ``__cause__`` so the real failure stays visible in logs.
    """
    print(f"Downloading {filename} from {repo_id} on Hugging Face Hub...")
    try:
        # hf_hub_download returns the local path to the (possibly cached) file.
        local_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type=repo_type)
    except Exception as e:
        print(f"Error downloading file from HF Hub: {e}")
        raise gr.Error(f"Failed to download dataset from Hugging Face Hub: {e}") from e
    print(f"Downloaded {filename} to {local_path} successfully.")
    return local_path
# --- 1. Model Loading Function (Loads to CPU, device transfer handled in run_inference) ---
def load_model():
    """Load the pre-trained FNO model onto the CPU and cache it in ``MODEL``.

    Subsequent calls return the cached object without touching disk. Device
    placement is deferred to ``run_inference``.

    Returns:
        The cached model, switched to evaluation mode.

    Raises:
        gr.Error: If the checkpoint cannot be loaded or put into eval mode.
            The underlying exception is chained as ``__cause__``.
    """
    global MODEL
    if MODEL is None:
        print("Loading FNO model to CPU...")
        try:
            # NOTE(security): weights_only=False fully unpickles the checkpoint,
            # which can execute arbitrary code. This is acceptable only because
            # MODEL_PATH is a file committed to this Space's own repo. The loaded
            # object is used directly as a module (.eval(), .parameters()), so it
            # presumably is a pickled model rather than a state dict.
            model = torch.load(MODEL_PATH, weights_only=False, map_location='cpu')
            model.eval()  # Set to evaluation mode
        except Exception as e:
            print(f"Error loading model: {e}")
            raise gr.Error(f"Failed to load model: {e}") from e
        # Publish to the cache only after the model is fully initialised, so a
        # failure above does not leave a half-prepared model cached forever.
        MODEL = model
        print("Model loaded successfully to CPU.")
    return MODEL
# --- 2. Dataset Loading Function ---
def _extract_initial_conditions(data):
    """Return the initial-condition tensor from an object produced by ``torch.load``.

    Accepts either a dict with an ``'x'`` entry or a bare tensor. The result
    must be a tensor with at least 3 dimensions, because ``run_inference``
    indexes it as ``dataset[i:i+1, :, :]``.

    Raises:
        ValueError: If the object has neither form, ``'x'`` is not a tensor,
            or the tensor has fewer than 3 dimensions.
    """
    if isinstance(data, dict) and 'x' in data:
        x = data['x']
    elif isinstance(data, torch.Tensor):
        x = data
    else:
        raise ValueError("Unknown dataset format or 'x' key missing.")
    if not isinstance(x, torch.Tensor):
        raise ValueError(f"Dataset entry 'x' must be a torch.Tensor, got {type(x).__name__}.")
    if x.ndim < 3:
        raise ValueError(
            f"Dataset tensor must have at least 3 dimensions, got shape {tuple(x.shape)}."
        )
    return x


def load_dataset():
    """Download and load the initial-conditions dataset, caching it in ``FULL_DATASET_X``.

    The file named by ``HF_DATASET_FILENAME`` is fetched from
    ``HF_DATASET_REPO_ID`` on the Hub and loaded on the CPU. Its first
    dimension is reported as the number of samples.

    Returns:
        The cached dataset tensor.

    Raises:
        gr.Error: If the download fails (raised by ``download_file_from_hf_hub``)
            or the file cannot be loaded or validated. The underlying exception
            is chained as ``__cause__``.
    """
    global FULL_DATASET_X
    if FULL_DATASET_X is None:
        local_dataset_path = download_file_from_hf_hub(HF_DATASET_REPO_ID, HF_DATASET_FILENAME)
        print("Loading dataset from local file...")
        try:
            # NOTE(security): this file comes from the Hub. Whether torch.load
            # unpickles arbitrary objects here depends on the torch version's
            # default for weights_only; consider passing weights_only=True
            # explicitly once the file is confirmed to hold only tensors.
            data = torch.load(local_dataset_path, map_location='cpu')
            dataset_x = _extract_initial_conditions(data)
        except Exception as e:
            print(f"Error loading dataset: {e}")
            raise gr.Error(f"Failed to load dataset from local file: {e}") from e
        # Only cache a validated tensor; a bad file must not poison later calls.
        FULL_DATASET_X = dataset_x
        print(f"Dataset loaded. Total samples: {FULL_DATASET_X.shape[0]}")
    return FULL_DATASET_X
# --- 3. Inference Function for Gradio (explicit device handling) ---
def _make_field_figure(field, title):
    """Render a 2-D array as a Matplotlib figure with a "Vorticity" colorbar.

    The figure is closed in pyplot's global registry before being returned,
    so repeated requests do not accumulate open figures. The returned Figure
    object is still what ``gr.Plot`` receives.
    """
    fig, ax = plt.subplots()
    image = ax.imshow(field, cmap='viridis')
    ax.set_title(title)
    fig.colorbar(image, ax=ax, label="Vorticity")
    plt.close(fig)
    return fig


@spaces.GPU()
def run_inference(sample_index: int):
    """Run the FNO on one dataset sample and plot its input and prediction.

    Args:
        sample_index: Index into the first dimension of the dataset. It is
            cast to ``int`` because the Gradio slider may deliver an
            integral float, which tensor slicing rejects.

    Returns:
        ``(fig_input, fig_output)``: Matplotlib figures of the selected
        initial condition and of the model's predicted solution.

    Raises:
        gr.Error: If ``sample_index`` is outside ``[0, number of samples)``.
    """
    # Target the GPU when one is attached (inside @spaces.GPU), else the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model()  # cached; initially loaded on the CPU
    # Move the model only here, inside the GPU-decorated function, and only
    # when it is not already on the target device.
    if next(model.parameters()).device != device:
        model.to(device)
        print(f"Model moved to {device} within run_inference.")

    dataset = load_dataset()
    num_samples = dataset.shape[0]
    sample_index = int(sample_index)
    if not (0 <= sample_index < num_samples):
        raise gr.Error(f"Sample index out of range. Please choose between 0 and {num_samples-1}.")

    # Slice i:i+1 keeps a leading batch dim of size 1; unsqueeze(1) inserts a
    # new axis 1 (presumably the model's channel axis — confirm against training).
    single_initial_condition = dataset[sample_index:sample_index+1, :, :].unsqueeze(1).to(device)
    print(f"Input moved to {device}.")
    print(f"Running inference for sample index {sample_index}...")
    with torch.no_grad():  # inference only; no autograd graph needed
        predicted_solution = model(single_initial_condition)

    # Matplotlib needs host-side NumPy arrays.
    input_numpy = single_initial_condition.squeeze().cpu().numpy()
    output_numpy = predicted_solution.squeeze().cpu().numpy()

    fig_input = _make_field_figure(input_numpy, f"Initial Condition (Sample {sample_index})")
    fig_output = _make_field_figure(output_numpy, "Predicted Solution")
    return fig_input, fig_output
# --- Gradio Interface Setup ---
with gr.Blocks() as demo:
    gr.Markdown(
        """
# Fourier Neural Operator (FNO) for Navier-Stokes Equations
Select a sample index from the pre-loaded dataset to see the FNO's prediction
of the vorticity field evolution.
"""
    )
    with gr.Row():
        with gr.Column():
            # 9999 is only the initial upper bound. _on_page_load below clamps
            # it to the last valid index of the dataset actually loaded.
            sample_input_slider = gr.Slider(
                minimum=0,
                maximum=9999,
                value=0,
                step=1,
                label="Select Sample Index",
            )
            run_button = gr.Button("Generate Solution")
        with gr.Column():
            input_image_plot = gr.Plot(label="Selected Initial Condition")
            output_image_plot = gr.Plot(label="Predicted Solution")

    run_button.click(
        fn=run_inference,
        inputs=[sample_input_slider],
        outputs=[input_image_plot, output_image_plot],
    )

    def load_initial_data_and_predict():
        """Warm the model and dataset caches and render sample 0.

        Returns:
            The ``(input, output)`` figure pair produced by ``run_inference(0)``.
        """
        # Both loaders cache in module globals, so only the first call pays
        # the load/download cost; later calls return immediately.
        load_model()
        load_dataset()
        # The actual inference call is routed to the GPU via @spaces.GPU().
        return run_inference(0)

    def _on_page_load():
        """Render the default sample and fit the slider range to the dataset."""
        fig_input, fig_output = load_initial_data_and_predict()
        # The dataset is already cached by the call above. max(..., 0) keeps
        # the slider's maximum from dropping below its minimum of 0.
        last_index = max(load_dataset().shape[0] - 1, 0)
        return gr.update(maximum=last_index), fig_input, fig_output

    # Blocks.load fires every time the page is opened in a browser, not once
    # at process startup, so this runs inference (and uses GPU time) per visit.
    demo.load(
        _on_page_load,
        inputs=None,
        outputs=[sample_input_slider, input_image_plot, output_image_plot],
    )

if __name__ == "__main__":
    demo.launch()