You can easily add that blurb by inserting a `gr.Markdown()` component within the same `gr.Column()` as your `sample_input_slider` and `run_button`. This places it within Gradio's flexbox-style column layout, ensuring it is always visible below the slider and button.

Here's your `app.py` with the blurb added in the correct place. I've also updated the `run_inference` function to explicitly target `torch.device("cpu")` and removed the `@spaces.GPU()` decorator, which aligns with your successful run on ZeroCPU.

```python
import gradio as gr
import torch
from neuralop.models import FNO
import matplotlib.pyplot as plt
import numpy as np
import os
# import spaces # No longer needed if running purely on CPU and not using @spaces.GPU()
from huggingface_hub import hf_hub_download
# --- Configuration ---
MODEL_PATH = "fno_ckpt_single_res" # This model file still needs to be in your Space's repo
HF_DATASET_REPO_ID = "ajsbsd/navier-stokes-2d-dataset" # Your new repo ID
HF_DATASET_FILENAME = "navier_stokes_2d.pt"
# --- Global Variables for Model and Data (loaded once) ---
MODEL = None
FULL_DATASET_X = None
# --- Function to Download Dataset from HF Hub ---
def download_file_from_hf_hub(repo_id, filename):
"""Downloads a file from Hugging Face Hub."""
print(f"Downloading {filename} from {repo_id} on Hugging Face Hub...")
try:
# hf_hub_download returns the local path to the downloaded file
local_path = hf_hub_download(repo_id=repo_id, filename=filename)
print(f"Downloaded {filename} to {local_path} successfully.")
return local_path
except Exception as e:
print(f"Error downloading file from HF Hub: {e}")
raise gr.Error(f"Failed to download dataset from Hugging Face Hub: {e}")
# --- 1. Model Loading Function (Loads to CPU, device transfer handled in run_inference) ---
def load_model():
"""Loads the pre-trained FNO model to CPU."""
global MODEL
if MODEL is None:
print("Loading FNO model to CPU...")
try:
MODEL = torch.load(MODEL_PATH, weights_only=False, map_location='cpu')
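            # weights_only=False unpickles arbitrary Python objects, so the
            # checkpoint file should come from a source you trust.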
            MODEL.eval()  # Set to evaluation mode
            print("Model loaded successfully to CPU.")
        except Exception as e:
            print(f"Error loading model: {e}")
            raise gr.Error(f"Failed to load model: {e}")
    return MODEL
# --- 2. Dataset Loading Function ---
def load_dataset():
"""Downloads and loads the initial conditions dataset from HF Hub."""
global FULL_DATASET_X
if FULL_DATASET_X is None:
local_dataset_path = download_file_from_hf_hub(HF_DATASET_REPO_ID, HF_DATASET_FILENAME)
print("Loading dataset from local file...")
try:
data = torch.load(local_dataset_path, map_location='cpu')
if isinstance(data, dict) and 'x' in data:
FULL_DATASET_X = data['x']
elif isinstance(dYou can easily add that blurb by inserting a `gr.Markdown()` component within the same `gr.Column()` as your `sample_input_slider` and `run_button`. This effectively places it within Gradio's "flexbox" layout, ensuring it's always visible below the slider and button.
Here's your `app.py` code with the blurb added in the correct place. I've also updated the `run_inference` function to explicitly target `torch.device("cpu")` and removed the `@spaces.GPU()` decorator, which aligns with your successful run on ZeroCPU.
```pythonata, torch.Tensor):
FULL_DATASET_X = data
else:
raise ValueError("Unknown dataset format or 'x' key missing.")
print(f"Dataset loaded. Total samples: {FULL_DATASET_X.shape[0]}")
except Exception as e:
print(f"Error loading dataset: {e}")
raise gr.Error(f"Failed to load dataset from local file: {e}")
return FULL_DATASET_X
# --- 3. Inference Function for Gradio ---
# Removed @spaces.GPU() decorator as you're running on ZeroCPU
def run_inference(sample_index: int):
"""
Performs inference for a selected sample index from the dataset on CPU.
Returns two Matplotlib figures: one for input, one for output.
"""
# Determine the target device (always CPU for ZeroCPU space)
device = torch.device("cpu") # Explicitly set to CPU as you're on ZeroCPU
model = load_model() # Model is initially loaded to CPU
# Model device check is still good practice, even if always CPU here
if next(model.parameters()).device != device:
model.to(device)
print(f"Model moved to {device} within run_inference.") # Will now print 'Model moved to cpu...'
dataset = load_dataset()
if not (0 <= sample_index < dataset.shape[0]):
raise gr.Error(f"Sample index out of range. Please choose between 0 and {dataset.shape[0]-1}.")
# Move input tensor to the correct device
single_initial_condition = dataset[sample_index:sample_index+1, :, :].unsqueeze(1).to(device)
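    # dataset[i:i+1] keeps the batch dimension, giving shape (1, H, W);
    # unsqueeze(1) adds a channel axis, since the FNO expects (batch, channels, H, W).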
print(f"Input moved to {device}.") # Will now print 'Input moved to cpu.'
print(f"Running inference for sample index {sample_index}...")
with torch.no_grad(): # Disable gradient calculations for inference
predicted_solution = model(single_initial_condition)
# Move results back to CPU for plotting with Matplotlib (already on CPU now)
input_numpy = single_initial_condition.squeeze().cpu().numpy()
output_numpy = predicted_solution.squeeze().cpu().numpy()
# Create Matplotlib figures
fig_input, ax_input = plt.subplots()
im_input = ax_input.imshow(input_numpy, cmap='viridis')
ax_input.set_title(f"Initial Condition (Sample {sample_index})")
fig_input.colorbar(im_input, ax=ax_input, label="Vorticity")
plt.close(fig_input)
fig_output, ax_output = plt.subplots()
im_output = ax_output.imshow(output_numpy, cmap='viridis')
ax_output.set_title(f"Predicted Solution")
fig_output.colorbar(im_output, ax=ax_output, label="Vorticity")
plt.close(fig_output)
return fig_input, fig_output
# --- Gradio Interface Setup (MODIFIED to add blurb) ---
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Fourier Neural Operator (FNO) for Navier-Stokes Equations
        Select a sample index from the pre-loaded dataset to see the FNO's prediction
        of the vorticity field evolution.
        """
    )
    with gr.Row():
        with gr.Column():
            sample_input_slider = gr.Slider(
                minimum=0,
                maximum=9999,
                value=0,
                step=1,
                label="Select Sample Index"
            )
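            # NOTE: maximum=9999 assumes the dataset holds 10,000 samples; if in
            # doubt, derive it from FULL_DATASET_X.shape[0] - 1 after load_dataset().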
            run_button = gr.Button("Generate Solution")
            # --- ADDED BLURB HERE ---
            gr.Markdown(
                """
                ### Project Inspiration
                This Hugging Face Space demonstrates the concepts and models from the research paper
                **'Principled approaches for extending neural architectures to function spaces for operator learning'**
                (available as a preprint on [arXiv](https://arxiv.org/abs/2506.10973)). The underlying code for the
                neural operators and the experiments can be explored further in the associated
                [GitHub repository](https://github.com/neuraloperator/NNs-to-NOs). The Navier-Stokes dataset used
                for training and inference is openly accessible and citable via
                [Zenodo](https://zenodo.org/records/12825163).
                """
            )
            # --- END ADDED BLURB ---
        with gr.Column():
            input_image_plot = gr.Plot(label="Selected Initial Condition")
            output_image_plot = gr.Plot(label="Predicted Solution")
    run_button.click(
        fn=run_inference,
        inputs=[sample_input_slider],
        outputs=[input_image_plot, output_image_plot]
    )
    def load_initial_data_and_predict():
        # These functions are called during main-process startup (CPU)
        load_model()
        load_dataset()
        # The initial inference call here also runs on CPU
        return run_inference(0)
    demo.load(load_initial_data_and_predict, inputs=None, outputs=[input_image_plot, output_image_plot])
if __name__ == "__main__":
    demo.launch()
```
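
If you want to sanity-check the checkpoint and dataset outside Gradio, a minimal sketch along the lines below should work; `smoke_test.py` is a hypothetical name, and it assumes `app.py` is importable from the same directory and that the checkpoint and dataset resolve as configured above.

```python
# smoke_test.py -- hypothetical helper, not part of the Space itself.
# Reuses the loaders from app.py to run a single CPU inference and print shapes.
import torch

from app import load_model, load_dataset

model = load_model()        # FNO checkpoint, loaded onto the CPU
data = load_dataset()       # initial conditions, shape (N, H, W)

x = data[0:1].unsqueeze(1)  # one sample as (1, 1, H, W)
with torch.no_grad():
    y = model(x)

print("input:", tuple(x.shape), "output:", tuple(y.shape))
```

Importing `app` builds the Blocks UI but does not launch it, since `demo.launch()` is guarded by the `if __name__ == "__main__":` check.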