ajsbsd commited on
Commit
d467bd4
·
verified ·
1 Parent(s): 05fe443

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +138 -0
app.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from neuralop.models import FNO # Make sure this import path matches your environment
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import os # For path handling
7
+
8
+ # --- Configuration ---
9
+ # Set paths relative to the root of your Hugging Face Space repository
10
+ MODEL_PATH = "fno_ckpt_single_res"
11
+ DATASET_PATH = "navier_stokes_2d.pt" # Ensure this file is in your repo root
12
+
13
+ # --- Global Variables for Model and Data (loaded once) ---
14
+ MODEL = None
15
+ FULL_DATASET_X = None # To store all initial conditions
16
+
17
+ # --- 1. Model Loading Function ---
18
+ def load_model():
19
+ """Loads the pre-trained FNO model."""
20
+ global MODEL
21
+ if MODEL is None:
22
+ print("Loading FNO model...")
23
+ try:
24
+ # Ensure model is loaded to CPU for general compatibility in Spaces
25
+ MODEL = torch.load(MODEL_PATH, weights_only=False, map_location='cpu')
26
+ MODEL.eval() # Set to evaluation mode
27
+ print("Model loaded successfully.")
28
+ except Exception as e:
29
+ print(f"Error loading model: {e}")
30
+ raise gr.Error(f"Failed to load model: {e}")
31
+ return MODEL
32
+
33
+ # --- 2. Dataset Loading Function ---
34
+ def load_dataset():
35
+ """Loads the initial conditions dataset."""
36
+ global FULL_DATASET_X
37
+ if FULL_DATASET_X is None:
38
+ print("Loading dataset...")
39
+ try:
40
+ data = torch.load(DATASET_PATH, map_location='cpu')
41
+ if isinstance(data, dict) and 'x' in data:
42
+ FULL_DATASET_X = data['x']
43
+ elif isinstance(data, torch.Tensor):
44
+ FULL_DATASET_X = data
45
+ else:
46
+ raise ValueError("Unknown dataset format or 'x' key missing.")
47
+ print(f"Dataset loaded. Total samples: {FULL_DATASET_X.shape[0]}")
48
+ except FileNotFoundError:
49
+ print(f"Dataset file not found at {DATASET_PATH}")
50
+ raise gr.Error(f"Dataset file not found. Please ensure '{DATASET_PATH}' is in your Space.")
51
+ except Exception as e:
52
+ print(f"Error loading dataset: {e}")
53
+ raise gr.Error(f"Failed to load dataset: {e}")
54
+ return FULL_DATASET_X
55
+
56
+ # --- 3. Inference Function for Gradio ---
57
+ def run_inference(sample_index: int):
58
+ """
59
+ Performs inference for a selected sample index from the dataset.
60
+ Returns two Matplotlib figures: one for input, one for output.
61
+ """
62
+ model = load_model()
63
+ dataset = load_dataset()
64
+
65
+ if not (0 <= sample_index < dataset.shape[0]):
66
+ raise gr.Error(f"Sample index out of range. Please choose between 0 and {dataset.shape[0]-1}.")
67
+
68
+ # Extract single initial condition and add channel dimension
69
+ # (shape: [1, H, W] -> [1, 1, H, W])
70
+ single_initial_condition = dataset[sample_index:sample_index+1, :, :].unsqueeze(1)
71
+
72
+ print(f"Running inference for sample index {sample_index}...")
73
+ with torch.no_grad():
74
+ predicted_solution = model(single_initial_condition)
75
+
76
+ # Convert tensors to numpy for plotting
77
+ input_numpy = single_initial_condition.squeeze().cpu().numpy()
78
+ output_numpy = predicted_solution.squeeze().cpu().numpy()
79
+
80
+ # Create Matplotlib figures
81
+ fig_input, ax_input = plt.subplots()
82
+ im_input = ax_input.imshow(input_numpy, cmap='viridis')
83
+ ax_input.set_title(f"Initial Condition (Sample {sample_index})")
84
+ fig_input.colorbar(im_input, ax=ax_input, label="Vorticity")
85
+ plt.close(fig_input) # Close to prevent display issues in Gradio
86
+
87
+ fig_output, ax_output = plt.subplots()
88
+ im_output = ax_output.imshow(output_numpy, cmap='viridis')
89
+ ax_output.set_title(f"Predicted Solution")
90
+ fig_output.colorbar(im_output, ax=ax_output, label="Vorticity")
91
+ plt.close(fig_output) # Close to prevent display issues in Gradio
92
+
93
+ return fig_input, fig_output
94
+
95
+ # --- Gradio Interface Setup ---
96
+ with gr.Blocks() as demo:
97
+ gr.Markdown(
98
+ """
99
+ # Fourier Neural Operator (FNO) for Navier-Stokes Equations
100
+ Select a sample index from the pre-loaded dataset to see the FNO's prediction
101
+ of the vorticity field evolution.
102
+ """
103
+ )
104
+
105
+ with gr.Row():
106
+ with gr.Column():
107
+ sample_input_slider = gr.Slider(
108
+ minimum=0,
109
+ maximum=9999, # Assuming 10,000 samples based on your dataset shape
110
+ value=0,
111
+ step=1,
112
+ label="Select Sample Index"
113
+ )
114
+ run_button = gr.Button("Generate Solution")
115
+ with gr.Column():
116
+ input_image_plot = gr.Plot(label="Selected Initial Condition")
117
+ output_image_plot = gr.Plot(label="Predicted Solution")
118
+
119
+ # Bind the button click to the inference function
120
+ run_button.click(
121
+ fn=run_inference,
122
+ inputs=[sample_input_slider],
123
+ outputs=[input_image_plot, output_image_plot]
124
+ )
125
+
126
+ # Optional: Load initial data on startup for the first display
127
+ def load_initial_data_and_predict():
128
+ # Ensure model and dataset are loaded when the space starts
129
+ load_model()
130
+ load_dataset()
131
+ # Run inference for the default value (index 0)
132
+ return run_inference(0)
133
+
134
+ demo.load(load_initial_data_and_predict, inputs=None, outputs=[input_image_plot, output_image_plot])
135
+
136
+ # Launch the Gradio app (only runs when you test locally)
137
+ if __name__ == "__main__":
138
+ demo.launch()