File size: 6,311 Bytes
d467bd4
 
9979917
d467bd4
 
9979917
c41cdf5
3d488b6
d467bd4
 
c41cdf5
 
 
d467bd4
 
 
9979917
 
3d488b6
c41cdf5
 
 
9979917
c41cdf5
 
 
 
 
 
 
 
 
3d488b6
d467bd4
c41cdf5
d467bd4
 
c41cdf5
d467bd4
 
3d488b6
c41cdf5
d467bd4
 
 
 
 
3d488b6
d467bd4
c41cdf5
d467bd4
 
c41cdf5
9979917
d467bd4
c41cdf5
d467bd4
 
 
 
 
 
 
 
 
9979917
d467bd4
 
3d488b6
c41cdf5
d467bd4
 
 
3d488b6
d467bd4
 
3d488b6
 
 
 
d467bd4
3d488b6
 
 
 
 
 
 
c41cdf5
d467bd4
 
 
3d488b6
 
 
b5b786d
d467bd4
3d488b6
 
d467bd4
3d488b6
d467bd4
 
 
3d488b6
d467bd4
 
 
 
9979917
d467bd4
 
 
 
 
9979917
d467bd4
 
 
9979917
d467bd4
 
 
 
 
 
 
 
 
 
 
 
 
9979917
d467bd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3d488b6
d467bd4
b5b786d
3d488b6
d467bd4
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import gradio as gr
import torch
from neuralop.models import FNO
import matplotlib.pyplot as plt
import numpy as np
import os
import spaces
from huggingface_hub import hf_hub_download

# --- Configuration ---
# Local checkpoint path; the file must be committed alongside this app in the Space repo.
MODEL_PATH: str = "fno_ckpt_single_res"
# Hugging Face Hub repository and file that hold the initial-condition dataset.
HF_DATASET_REPO_ID: str = "ajsbsd/navier-stokes-2d-dataset"
HF_DATASET_FILENAME: str = "navier_stokes_2d.pt"

# --- Lazily-populated caches (filled on first use by load_model / load_dataset) ---
MODEL = None  # model object returned by torch.load, once loaded
FULL_DATASET_X = None  # initial-condition tensor, once loaded

# --- Function to Download Dataset from HF Hub ---
def download_file_from_hf_hub(repo_id, filename, repo_type=None):
    """Download a single file from the Hugging Face Hub and return its local path.

    Args:
        repo_id: Hub repository identifier, e.g. ``"user/repo"``.
        filename: Path of the file inside the repository.
        repo_type: Forwarded to ``hf_hub_download``. ``None`` keeps the library
            default (a model repository); pass ``"dataset"`` for dataset repos.

    Returns:
        str: Local filesystem path of the downloaded file.

    Raises:
        gr.Error: If the download fails for any reason; the original exception
        is chained as the cause.
    """
    print(f"Downloading {filename} from {repo_id} on Hugging Face Hub...")
    try:
        # hf_hub_download returns the local path to the downloaded file
        local_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type=repo_type)
    except Exception as e:
        # Surface the failure in the Gradio UI while keeping the root cause in the traceback.
        print(f"Error downloading file from HF Hub: {e}")
        raise gr.Error(f"Failed to download dataset from Hugging Face Hub: {e}") from e
    print(f"Downloaded {filename} to {local_path} successfully.")
    return local_path


# --- 1. Model Loading Function (Loads to CPU, device transfer handled in run_inference) ---
def load_model():
    """Load the pre-trained model checkpoint onto the CPU and cache it.

    The checkpoint at ``MODEL_PATH`` is unpickled as a whole model object and
    switched to evaluation mode. The result is stored in the module-level
    ``MODEL`` cache only after both steps succeed, so a failed load is retried
    on the next call instead of leaving a half-initialised object cached.

    Returns:
        The cached model (kept on the CPU here; ``run_inference`` moves it to
        the target device).

    Raises:
        gr.Error: If the checkpoint cannot be loaded or put into eval mode.
    """
    global MODEL
    if MODEL is not None:
        return MODEL

    print("Loading FNO model to CPU...")
    try:
        # weights_only=False unpickles an entire model object, which can run
        # arbitrary code; MODEL_PATH must therefore point at a trusted file
        # (presumably a checkpoint shipped with this app).
        model = torch.load(MODEL_PATH, weights_only=False, map_location='cpu')
        model.eval()  # Set to evaluation mode
    except Exception as e:
        print(f"Error loading model: {e}")
        raise gr.Error(f"Failed to load model: {e}") from e

    # Publish to the cache only after eval() succeeded.
    MODEL = model
    print("Model loaded successfully to CPU.")
    return MODEL

# --- 2. Dataset Loading Function ---
def load_dataset():
    """Download (first call only) and return the initial-conditions tensor.

    The dataset file is fetched from the Hugging Face Hub via
    ``download_file_from_hf_hub`` and loaded onto the CPU. Two on-disk formats
    are accepted: a dict with an ``'x'`` entry, or a bare tensor. The result is
    cached in the module-level ``FULL_DATASET_X`` only after it has been
    validated, so later calls skip the download, and a failed load is retried
    rather than leaving a bad value cached.

    Returns:
        torch.Tensor: The cached initial-condition tensor; its first dimension
        indexes samples.

    Raises:
        gr.Error: If the download fails, the file cannot be loaded, or its
        contents are not in a recognised format.
    """
    global FULL_DATASET_X
    if FULL_DATASET_X is not None:
        return FULL_DATASET_X

    # NOTE(review): hf_hub_download defaults to model repositories; if
    # HF_DATASET_REPO_ID is a *dataset* repo the download needs
    # repo_type="dataset" — confirm against the Hub.
    local_dataset_path = download_file_from_hf_hub(HF_DATASET_REPO_ID, HF_DATASET_FILENAME)
    print("Loading dataset from local file...")
    try:
        # NOTE(review): this file comes from a remote Hub repo. On torch < 2.6
        # torch.load defaults to weights_only=False and will unpickle arbitrary
        # objects; consider weights_only=True once the file is confirmed to
        # contain only tensors/containers.
        data = torch.load(local_dataset_path, map_location='cpu')
        if isinstance(data, dict) and 'x' in data:
            samples = data['x']
        elif isinstance(data, torch.Tensor):
            samples = data
        else:
            raise ValueError("Unknown dataset format or 'x' key missing.")
        if not isinstance(samples, torch.Tensor):
            raise ValueError(f"Expected a tensor of samples, got {type(samples).__name__}.")
        print(f"Dataset loaded. Total samples: {samples.shape[0]}")
    except Exception as e:
        print(f"Error loading dataset: {e}")
        raise gr.Error(f"Failed to load dataset from local file: {e}") from e

    FULL_DATASET_X = samples
    return FULL_DATASET_X

# --- 3. Inference Function for Gradio (explicit device handling) ---
def _plot_field(field, title):
    """Render a 2-D array as a viridis heatmap with a "Vorticity" colorbar.

    The figure is closed in pyplot's registry so repeated calls do not
    accumulate open figures; the Figure object itself is returned for
    gr.Plot to render.
    """
    fig, ax = plt.subplots()
    image = ax.imshow(field, cmap='viridis')
    ax.set_title(title)
    fig.colorbar(image, ax=ax, label="Vorticity")
    plt.close(fig)
    return fig


@spaces.GPU()
def run_inference(sample_index: int):
    """Run the model on one dataset sample and plot the input and the prediction.

    Args:
        sample_index: Index of the initial condition to use. Non-int numbers
            (as a UI slider may send) are truncated with ``int()``.

    Returns:
        tuple: ``(fig_input, fig_output)`` Matplotlib figures showing the
        selected initial condition and the model's predicted field.

    Raises:
        gr.Error: If ``sample_index`` is outside ``[0, num_samples)``, or if
        loading the model or dataset fails.
    """
    # Tensor slicing requires a real int, and this keeps the title as "Sample 3", not "Sample 3.0".
    sample_index = int(sample_index)

    # Prefer the GPU attached by @spaces.GPU(); fall back to CPU if none is visible.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = load_model()  # cached on the CPU by load_model

    # Move the model only inside this decorated function, and only when it is
    # not already on the target device.
    if next(model.parameters()).device != device:
        model.to(device)
        print(f"Model moved to {device} within run_inference.")

    dataset = load_dataset()
    num_samples = dataset.shape[0]
    if not (0 <= sample_index < num_samples):
        raise gr.Error(f"Sample index out of range. Please choose between 0 and {num_samples-1}.")

    # Length-1 slice keeps the batch dimension; unsqueeze(1) inserts a channel axis.
    # NOTE(review): assumes dataset is (N, H, W), giving a (1, 1, H, W) input — confirm.
    single_initial_condition = dataset[sample_index:sample_index + 1, :, :].unsqueeze(1).to(device)
    print(f"Input moved to {device}.")

    print(f"Running inference for sample index {sample_index}...")
    with torch.no_grad():  # inference only; no autograd graph needed
        predicted_solution = model(single_initial_condition)

    # squeeze() drops size-1 axes (presumably leaving (H, W) for imshow);
    # .cpu() is required before .numpy() when the tensors live on the GPU.
    input_numpy = single_initial_condition.squeeze().cpu().numpy()
    output_numpy = predicted_solution.squeeze().cpu().numpy()

    fig_input = _plot_field(input_numpy, f"Initial Condition (Sample {sample_index})")
    fig_output = _plot_field(output_numpy, "Predicted Solution")
    return fig_input, fig_output

# --- Gradio Interface Setup ---
# Builds the single-page UI: a sample-index slider and button in the left
# column, input/prediction plots in the right column. Components are laid out
# in creation order inside the `with` contexts, so statement order matters here.
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Fourier Neural Operator (FNO) for Navier-Stokes Equations
        Select a sample index from the pre-loaded dataset to see the FNO's prediction
        of the vorticity field evolution.
        """
    )

    with gr.Row():
        with gr.Column():
            # NOTE(review): maximum is hard-coded; run_inference rejects any index
            # >= the loaded dataset's sample count, so this assumes the dataset
            # holds 10000 samples — confirm.
            sample_input_slider = gr.Slider(
                minimum=0,
                maximum=9999,
                value=0,
                step=1,
                label="Select Sample Index"
            )
            run_button = gr.Button("Generate Solution")
        with gr.Column():
            input_image_plot = gr.Plot(label="Selected Initial Condition")
            output_image_plot = gr.Plot(label="Predicted Solution")

    # Clicking the button runs inference on the selected index and fills both plots.
    run_button.click(
        fn=run_inference,
        inputs=[sample_input_slider],
        outputs=[input_image_plot, output_image_plot]
    )

    def load_initial_data_and_predict():
        """Warm the model/dataset caches and render sample 0 for the initial view.

        Registered with demo.load below, so it runs when a client loads the
        page, not at process start-up. Once populated, load_model() and
        load_dataset() return their cached globals without reloading.
        """
        # Both loaders keep their results on the CPU.
        load_model()
        load_dataset()
        # GPU work happens only inside run_inference, which carries @spaces.GPU().
        return run_inference(0)

    # Populate both plots with sample 0 whenever the page is loaded.
    demo.load(load_initial_data_and_predict, inputs=None, outputs=[input_image_plot, output_image_plot])

if __name__ == "__main__":
    demo.launch()