File size: 7,677 Bytes
d467bd4
 
9979917
d467bd4
 
9979917
8dd0d9d
3d488b6
d467bd4
 
c41cdf5
 
 
d467bd4
 
 
9979917
 
3d488b6
c41cdf5
 
 
9979917
c41cdf5
 
 
 
 
 
 
 
 
3d488b6
d467bd4
c41cdf5
d467bd4
 
c41cdf5
d467bd4
 
3d488b6
c41cdf5
d467bd4
 
 
 
 
3d488b6
d467bd4
c41cdf5
d467bd4
 
c41cdf5
9979917
d467bd4
c41cdf5
d467bd4
 
6e6f8fe
 
 
 
 
d467bd4
 
 
 
 
 
9979917
d467bd4
 
8dd0d9d
 
d467bd4
 
8dd0d9d
d467bd4
 
8dd0d9d
 
3d488b6
 
d467bd4
8dd0d9d
3d488b6
 
8dd0d9d
3d488b6
 
c41cdf5
d467bd4
 
 
8dd0d9d
3d488b6
8dd0d9d
b5b786d
d467bd4
3d488b6
8dd0d9d
d467bd4
8dd0d9d
d467bd4
 
 
3d488b6
d467bd4
 
 
 
9979917
d467bd4
 
 
 
 
9979917
d467bd4
 
 
8dd0d9d
d467bd4
 
 
 
 
 
 
 
 
 
 
 
 
9979917
d467bd4
 
 
 
 
8dd0d9d
 
 
 
 
 
 
 
 
 
d467bd4
 
 
 
 
 
 
 
 
 
 
3d488b6
d467bd4
b5b786d
8dd0d9d
d467bd4
 
 
 
 
6e6f8fe
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import gradio as gr
import torch
from neuralop.models import FNO
import matplotlib.pyplot as plt
import numpy as np
import os
# import spaces # No longer needed if running purely on CPU and not using @spaces.GPU()
from huggingface_hub import hf_hub_download

# --- Configuration ---
# Checkpoint loaded by load_model(); the path is relative, so the file has to
# ship in the Space's repository next to this script.
MODEL_PATH: str = "fno_ckpt_single_res"
# Hub repository (and file inside it) that holds the initial conditions.
HF_DATASET_REPO_ID: str = "ajsbsd/navier-stokes-2d-dataset"
HF_DATASET_FILENAME: str = "navier_stokes_2d.pt"

# --- Module-level caches: stay None until the loaders first fill them ---
MODEL = None  # set by load_model()
FULL_DATASET_X = None  # set by load_dataset()

# --- Function to Download Dataset from HF Hub ---
def download_file_from_hf_hub(repo_id, filename, repo_type=None):
    """Download ``filename`` from the Hugging Face Hub repository ``repo_id``.

    Args:
        repo_id: Hub repository id, e.g. ``"user/name"``.
        filename: Path of the file inside the repository.
        repo_type: Forwarded unchanged to ``hf_hub_download``. ``None`` (the
            default, matching the previous behaviour) selects a model
            repository; pass ``"dataset"`` for a dataset repository.

    Returns:
        The local filesystem path returned by ``hf_hub_download``.

    Raises:
        gr.Error: If the download fails for any reason. The underlying
            exception is chained as ``__cause__`` for debugging.
    """
    print(f"Downloading {filename} from {repo_id} on Hugging Face Hub...")
    try:
        local_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type=repo_type)
    except Exception as e:
        # Boundary handler: log, then surface a user-visible Gradio error.
        print(f"Error downloading file from HF Hub: {e}")
        raise gr.Error(f"Failed to download dataset from Hugging Face Hub: {e}") from e
    print(f"Downloaded {filename} to {local_path} successfully.")
    return local_path


# --- 1. Model Loading Function (Loads to CPU, device transfer handled in run_inference) ---
def load_model():
    """Load the pre-trained FNO model onto the CPU and cache it in ``MODEL``.

    The checkpoint at ``MODEL_PATH`` is read only on the first successful
    call; later calls return the cached module.

    Returns:
        The loaded model, already switched to evaluation mode.

    Raises:
        gr.Error: If loading the checkpoint or calling ``eval()`` fails. The
            underlying exception is chained as ``__cause__``.
    """
    global MODEL
    if MODEL is None:
        print("Loading FNO model to CPU...")
        try:
            # weights_only=False unpickles a full module object, which can
            # execute arbitrary code: MODEL_PATH must point at a trusted file.
            model = torch.load(MODEL_PATH, weights_only=False, map_location='cpu')
            model.eval()  # Set to evaluation mode
        except Exception as e:
            print(f"Error loading model: {e}")
            raise gr.Error(f"Failed to load model: {e}") from e
        # Publish to the cache only once the model is fully prepared, so a
        # failure above never leaves a half-initialised model cached.
        MODEL = model
        print("Model loaded successfully to CPU.")
    return MODEL

# --- 2. Dataset Loading Function ---
def load_dataset():
    """Download and load the initial-conditions tensor, caching it in ``FULL_DATASET_X``.

    The file is fetched from ``HF_DATASET_REPO_ID`` only on the first
    successful call. It may hold either a dict with an ``'x'`` entry or a bare
    tensor.

    Returns:
        The initial-conditions tensor; its first dimension indexes samples.

    Raises:
        gr.Error: If the download, the load, or the format check fails.
    """
    global FULL_DATASET_X
    if FULL_DATASET_X is None:
        local_dataset_path = download_file_from_hf_hub(HF_DATASET_REPO_ID, HF_DATASET_FILENAME)
        print("Loading dataset from local file...")
        try:
            data = torch.load(local_dataset_path, map_location='cpu')
            if isinstance(data, dict) and 'x' in data:
                dataset_x = data['x']
            elif isinstance(data, torch.Tensor):
                dataset_x = data
            else:
                raise ValueError("Unknown dataset format or 'x' key missing.")
            print(f"Dataset loaded. Total samples: {dataset_x.shape[0]}")
        except Exception as e:
            print(f"Error loading dataset: {e}")
            raise gr.Error(f"Failed to load dataset from local file: {e}") from e
        # Only cache a value that passed the format check above.
        FULL_DATASET_X = dataset_x
    return FULL_DATASET_X

# --- 3. Inference Function for Gradio ---
def _field_figure(field, title):
    """Render a 2-D array as an ``imshow`` figure with a "Vorticity" colour bar."""
    fig, ax = plt.subplots()
    image = ax.imshow(field, cmap='viridis')
    ax.set_title(title)
    fig.colorbar(image, ax=ax, label="Vorticity")
    # Detach from pyplot's figure manager so figures don't pile up across
    # requests; the returned Figure object is still renderable by gr.Plot.
    plt.close(fig)
    return fig


def run_inference(sample_index: int):
    """Run the FNO on one initial condition from the dataset, on the CPU.

    Args:
        sample_index: Dataset row to use. Gradio sliders may deliver it as a
            float such as ``3.0``, so it is coerced to ``int`` before slicing.

    Returns:
        A ``(fig_input, fig_output)`` pair of Matplotlib figures: the selected
        initial condition and the model's predicted solution.

    Raises:
        gr.Error: If ``sample_index`` is outside ``[0, num_samples)``, or if
            loading the model or dataset fails.
    """
    device = torch.device("cpu")  # This Space runs CPU-only.
    sample_index = int(sample_index)

    model = load_model()
    # load_model() already targets the CPU; the check keeps the target device
    # changeable from the single `device` assignment above.
    if next(model.parameters()).device != device:
        model.to(device)
        print(f"Model moved to {device} within run_inference.")

    dataset = load_dataset()
    num_samples = dataset.shape[0]
    if not 0 <= sample_index < num_samples:
        raise gr.Error(f"Sample index out of range. Please choose between 0 and {num_samples-1}.")

    # Slice (rather than index) to keep a length-1 batch axis, then insert a
    # channel axis at dim 1 before handing the tensor to the model.
    single_initial_condition = dataset[sample_index:sample_index+1, :, :].unsqueeze(1).to(device)
    print(f"Input moved to {device}.")

    print(f"Running inference for sample index {sample_index}...")
    with torch.no_grad():  # inference only; no autograd graph needed
        predicted_solution = model(single_initial_condition)

    # Drop the singleton batch/channel axes for plotting.
    input_numpy = single_initial_condition.squeeze().cpu().numpy()
    output_numpy = predicted_solution.squeeze().cpu().numpy()

    fig_input = _field_figure(input_numpy, f"Initial Condition (Sample {sample_index})")
    fig_output = _field_figure(output_numpy, "Predicted Solution")
    return fig_input, fig_output

# --- Gradio Interface Setup ---
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Fourier Neural Operator (FNO) for Navier-Stokes Equations
        Select a sample index from the pre-loaded dataset to see the FNO's prediction
        of the vorticity field evolution.
        """
    )

    with gr.Row():
        # Left column: controls plus the project attribution text.
        with gr.Column():
            # NOTE(review): the maximum of 9999 is hard-coded and assumes the
            # dataset holds at least 10,000 samples; run_inference rejects
            # larger indices with a gr.Error — confirm against the dataset.
            sample_input_slider = gr.Slider(
                minimum=0,
                maximum=9999,
                value=0,
                step=1,
                label="Select Sample Index"
            )
            run_button = gr.Button("Generate Solution")

            # Attribution for the paper, code and dataset, shown below the controls.
            gr.Markdown(
                """
                ### Project Inspiration
                This Hugging Face Space demonstrates the concepts and models from the research paper **'Principled approaches for extending neural architectures to function spaces for operator learning'** (available as a preprint on [arXiv](https://arxiv.org/abs/2506.10973)). The underlying code for the neural operators and the experiments can be explored further in the associated [GitHub repository](https://github.com/neuraloperator/NNs-to-NOs). The Navier-Stokes dataset used for training and inference, crucial for these fluid dynamics simulations, is openly accessible and citable via [Zenodo](https://zenodo.org/records/12825163).
                """
            )

        # Right column: input and predicted fields.
        with gr.Column():
            input_image_plot = gr.Plot(label="Selected Initial Condition")
            output_image_plot = gr.Plot(label="Predicted Solution")

    run_button.click(
        fn=run_inference,
        inputs=[sample_input_slider],
        outputs=[input_image_plot, output_image_plot]
    )

    def load_initial_data_and_predict():
        """Warm the model/dataset caches and render sample 0.

        Registered via ``demo.load`` below, so Gradio calls it on page load
        rather than once at process startup. After the first call the loaders
        return their cached objects, so later calls only re-run inference.
        """
        load_model()
        load_dataset()
        return run_inference(0)

    demo.load(load_initial_data_and_predict, inputs=None, outputs=[input_image_plot, output_image_plot])

if __name__ == "__main__":
    demo.launch()