ajsbsd commited on
Commit
acadd4b
Β·
verified Β·
1 Parent(s): 71f864d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +231 -45
app.py CHANGED
@@ -31,7 +31,7 @@ def load_model():
31
  print("Loading FNO model to CPU...")
32
  try:
33
  MODEL = torch.load(MODEL_PATH, weights_only=False, map_location='cpu')
34
- MODEL.eval() # Set to evaluation mode
35
  print("Model loaded successfully to CPU.")
36
  except Exception as e:
37
  print(f"Error loading model: {e}")
@@ -58,91 +58,277 @@ def load_dataset():
58
  raise gr.Error(f"Failed to load dataset from local file: {e}")
59
  return FULL_DATASET_X
60
 
61
- def run_inference(sample_index: int):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  """
63
  Performs inference for a selected sample index from the dataset on CPU.
64
- Returns two Matplotlib figures: one for input, one for output.
65
  """
 
66
  device = torch.device("cpu")
67
-
68
  model = load_model()
69
 
70
  if next(model.parameters()).device != device:
71
  model.to(device)
72
  print(f"Model moved to {device} within run_inference.")
73
 
 
74
  dataset = load_dataset()
75
 
76
  if not (0 <= sample_index < dataset.shape[0]):
77
  raise gr.Error(f"Sample index out of range. Please choose between 0 and {dataset.shape[0]-1}.")
78
 
 
79
  single_initial_condition = dataset[sample_index:sample_index+1, :, :].unsqueeze(1).to(device)
80
  print(f"Input moved to {device}.")
81
 
 
82
  print(f"Running inference for sample index {sample_index}...")
83
  with torch.no_grad():
84
  predicted_solution = model(single_initial_condition)
85
 
 
86
  input_numpy = single_initial_condition.squeeze().cpu().numpy()
87
  output_numpy = predicted_solution.squeeze().cpu().numpy()
88
 
89
- fig_input, ax_input = plt.subplots()
90
- im_input = ax_input.imshow(input_numpy, cmap='viridis')
91
- ax_input.set_title(f"Initial Condition (Sample {sample_index})")
92
- fig_input.colorbar(im_input, ax=ax_input, label="Vorticity")
93
- plt.close(fig_input)
94
-
95
- fig_output, ax_output = plt.subplots()
96
- im_output = ax_output.imshow(output_numpy, cmap='viridis')
97
- ax_output.set_title(f"Predicted Solution")
98
- fig_output.colorbar(im_output, ax=ax_output, label="Vorticity")
99
- plt.close(fig_output)
100
 
 
101
  return fig_input, fig_output
102
 
103
- with gr.Blocks() as demo:
104
- gr.Markdown(
105
- """
106
- # Fourier Neural Operator (FNO) for Navier-Stokes Equations
107
- Select a sample index from the pre-loaded dataset to see the FNO's prediction
108
- of the vorticity field evolution.
109
- """
110
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
- with gr.Row():
113
- with gr.Column():
114
- sample_input_slider = gr.Slider(
115
- minimum=0,
116
- maximum=9999,
117
- value=0,
118
- step=1,
119
- label="Select Sample Index"
120
- )
121
- run_button = gr.Button("Generate Solution")
122
-
123
- gr.Markdown(
124
- """
125
- ### Project Inspiration
126
- This Hugging Face Space demonstrates the concepts and models from the research paper **'Principled approaches for extending neural architectures to function spaces for operator learning'** (available as a preprint on [arXiv](https://arxiv.org/abs/2506.10973)). The underlying code for the neural operators and the experiments can be explored further in the associated [GitHub repository](https://github.com/neuraloperator/NNs-to-NOs). The Navier-Stokes dataset used for training and inference, crucial for these fluid dynamics simulations, is openly accessible and citable via [Zenodo](https://zenodo.org/records/12825163).
127
- """
128
- )
129
-
130
- with gr.Column():
131
- input_image_plot = gr.Plot(label="Selected Initial Condition")
132
- output_image_plot = gr.Plot(label="Predicted Solution")
133
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  run_button.click(
135
  fn=run_inference,
136
  inputs=[sample_input_slider],
137
  outputs=[input_image_plot, output_image_plot]
138
  )
 
 
 
 
 
 
 
 
 
139
 
140
  def load_initial_data_and_predict():
141
  load_model()
142
  load_dataset()
143
  return run_inference(0)
144
 
145
- demo.load(load_initial_data_and_predict, inputs=None, outputs=[input_image_plot, output_image_plot])
 
 
 
 
146
 
147
  if __name__ == "__main__":
148
  demo.launch()
 
31
  print("Loading FNO model to CPU...")
32
  try:
33
  MODEL = torch.load(MODEL_PATH, weights_only=False, map_location='cpu')
34
+ MODEL.eval()
35
  print("Model loaded successfully to CPU.")
36
  except Exception as e:
37
  print(f"Error loading model: {e}")
 
58
  raise gr.Error(f"Failed to load dataset from local file: {e}")
59
  return FULL_DATASET_X
60
 
61
+ def create_enhanced_plot(data, title, is_input=True):
62
+ """Creates enhanced matplotlib plots with professional styling."""
63
+ plt.style.use('dark_background')
64
+ fig, ax = plt.subplots(figsize=(8, 6), facecolor='#0f0f23')
65
+ ax.set_facecolor('#0f0f23')
66
+
67
+ # Enhanced colormap and visualization
68
+ cmap = 'plasma' if is_input else 'viridis'
69
+ im = ax.imshow(data, cmap=cmap, interpolation='bilinear')
70
+
71
+ # Styling
72
+ ax.set_title(title, color='white', fontsize=14, fontweight='bold', pad=20)
73
+ ax.set_xlabel('X Coordinate', color='#8892b0', fontsize=10)
74
+ ax.set_ylabel('Y Coordinate', color='#8892b0', fontsize=10)
75
+
76
+ # Enhanced colorbar
77
+ cbar = fig.colorbar(im, ax=ax, shrink=0.8, pad=0.02)
78
+ cbar.set_label('Vorticity', color='#8892b0', fontsize=10)
79
+ cbar.ax.tick_params(colors='#8892b0', labelsize=8)
80
+
81
+ # Grid and ticks
82
+ ax.tick_params(colors='#8892b0', labelsize=8)
83
+ ax.grid(True, alpha=0.1, color='white')
84
+
85
+ plt.tight_layout()
86
+ return fig
87
+
88
+ def run_inference(sample_index: int, progress=gr.Progress()):
89
  """
90
  Performs inference for a selected sample index from the dataset on CPU.
91
+ Returns two enhanced Matplotlib figures with progress tracking.
92
  """
93
+ progress(0.1, desc="Loading model...")
94
  device = torch.device("cpu")
 
95
  model = load_model()
96
 
97
  if next(model.parameters()).device != device:
98
  model.to(device)
99
  print(f"Model moved to {device} within run_inference.")
100
 
101
+ progress(0.3, desc="Loading dataset...")
102
  dataset = load_dataset()
103
 
104
  if not (0 <= sample_index < dataset.shape[0]):
105
  raise gr.Error(f"Sample index out of range. Please choose between 0 and {dataset.shape[0]-1}.")
106
 
107
+ progress(0.5, desc="Preparing input...")
108
  single_initial_condition = dataset[sample_index:sample_index+1, :, :].unsqueeze(1).to(device)
109
  print(f"Input moved to {device}.")
110
 
111
+ progress(0.7, desc="Running inference...")
112
  print(f"Running inference for sample index {sample_index}...")
113
  with torch.no_grad():
114
  predicted_solution = model(single_initial_condition)
115
 
116
+ progress(0.9, desc="Generating visualizations...")
117
  input_numpy = single_initial_condition.squeeze().cpu().numpy()
118
  output_numpy = predicted_solution.squeeze().cpu().numpy()
119
 
120
+ fig_input = create_enhanced_plot(input_numpy, f"Initial Condition β€’ Sample {sample_index}", is_input=True)
121
+ fig_output = create_enhanced_plot(output_numpy, "Predicted Solution", is_input=False)
 
 
 
 
 
 
 
 
 
122
 
123
+ progress(1.0, desc="Complete!")
124
  return fig_input, fig_output
125
 
126
+ def get_random_sample():
127
+ """Returns a random sample index for exploration."""
128
+ dataset = load_dataset()
129
+ return np.random.randint(0, dataset.shape[0])
130
+
131
+ # Custom CSS for professional styling
132
+ custom_css = """
133
+ #main-container {
134
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
135
+ min-height: 100vh;
136
+ }
137
+
138
+ .gradio-container {
139
+ background: rgba(15, 15, 35, 0.95) !important;
140
+ backdrop-filter: blur(10px);
141
+ border-radius: 20px;
142
+ border: 1px solid rgba(255, 255, 255, 0.1);
143
+ box-shadow: 0 20px 40px rgba(0, 0, 0, 0.3);
144
+ }
145
+
146
+ .gr-button {
147
+ background: linear-gradient(45deg, #667eea, #764ba2) !important;
148
+ border: none !important;
149
+ border-radius: 12px !important;
150
+ color: white !important;
151
+ font-weight: 600 !important;
152
+ padding: 12px 24px !important;
153
+ transition: all 0.3s ease !important;
154
+ box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4) !important;
155
+ }
156
+
157
+ .gr-button:hover {
158
+ transform: translateY(-2px) !important;
159
+ box-shadow: 0 6px 20px rgba(102, 126, 234, 0.6) !important;
160
+ }
161
+
162
+ .gr-slider input[type="range"] {
163
+ background: linear-gradient(to right, #667eea, #764ba2) !important;
164
+ }
165
+
166
+ .markdown-content h1 {
167
+ background: linear-gradient(45deg, #667eea, #764ba2);
168
+ -webkit-background-clip: text;
169
+ -webkit-text-fill-color: transparent;
170
+ text-align: center;
171
+ font-size: 2.5rem !important;
172
+ margin-bottom: 1rem !important;
173
+ }
174
 
175
+ .markdown-content h3 {
176
+ color: #8892b0 !important;
177
+ border-left: 4px solid #667eea;
178
+ padding-left: 1rem;
179
+ margin-top: 2rem !important;
180
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
 
182
+ .feature-card {
183
+ background: rgba(255, 255, 255, 0.05);
184
+ border-radius: 15px;
185
+ padding: 1.5rem;
186
+ border: 1px solid rgba(255, 255, 255, 0.1);
187
+ backdrop-filter: blur(10px);
188
+ }
189
+
190
+ .status-indicator {
191
+ display: inline-block;
192
+ width: 8px;
193
+ height: 8px;
194
+ background: #00ff88;
195
+ border-radius: 50%;
196
+ margin-right: 8px;
197
+ animation: pulse 2s infinite;
198
+ }
199
+
200
+ @keyframes pulse {
201
+ 0% { opacity: 1; }
202
+ 50% { opacity: 0.5; }
203
+ 100% { opacity: 1; }
204
+ }
205
+
206
+ .info-panel {
207
+ background: rgba(102, 126, 234, 0.1);
208
+ border-left: 4px solid #667eea;
209
+ border-radius: 8px;
210
+ padding: 1rem;
211
+ margin: 1rem 0;
212
+ }
213
+ """
214
+
215
+ with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="blue", secondary_hue="purple")) as demo:
216
+ with gr.Column(elem_id="main-container"):
217
+ # Hero Section
218
+ gr.Markdown(
219
+ """
220
+ # 🌊 Neural Fluid Dynamics
221
+ ## Advanced Fourier Neural Operator for Navier-Stokes Equations
222
+
223
+ <div class="info-panel">
224
+ <span class="status-indicator"></span><strong>AI-Powered Fluid Simulation</strong> β€’ Experience cutting-edge neural operators that solve complex partial differential equations in milliseconds instead of hours.
225
+ </div>
226
+ """,
227
+ elem_classes=["markdown-content"]
228
+ )
229
+
230
+ with gr.Row():
231
+ # Left Panel - Controls
232
+ with gr.Column(scale=1):
233
+ gr.Markdown(
234
+ """
235
+ ### πŸŽ›οΈ Simulation Controls
236
+ """,
237
+ elem_classes=["markdown-content"]
238
+ )
239
+
240
+ with gr.Group():
241
+ sample_input_slider = gr.Slider(
242
+ minimum=0,
243
+ maximum=9999,
244
+ value=0,
245
+ step=1,
246
+ label="πŸ“Š Sample Index",
247
+ info="Choose from 10,000 unique fluid scenarios"
248
+ )
249
+
250
+ with gr.Row():
251
+ run_button = gr.Button("πŸš€ Generate Solution", variant="primary", scale=2)
252
+ random_button = gr.Button("🎲 Random", variant="secondary", scale=1)
253
+
254
+ # Information Panel
255
+ gr.Markdown(
256
+ """
257
+ ### πŸ“‹ Model Information
258
+
259
+ <div class="feature-card">
260
+ <strong>Architecture:</strong> Fourier Neural Operator (FNO)<br>
261
+ <strong>Domain:</strong> 2D Fluid Dynamics<br>
262
+ <strong>Resolution:</strong> 64Γ—64 Grid<br>
263
+ <strong>Inference Time:</strong> ~50ms<br>
264
+ <strong>Training Samples:</strong> 10,000+
265
+ </div>
266
+
267
+ ### πŸ”¬ Research Context
268
+ This demo showcases the practical applications of **'Principled approaches for extending neural architectures to function spaces for operator learning'** research.
269
+
270
+ **Key Resources:**
271
+ - πŸ“„ [Research Paper](https://arxiv.org/abs/2506.10973)
272
+ - πŸ’» [Source Code](https://github.com/neuraloperator/NNs-to-NOs)
273
+ - πŸ“Š [Dataset](https://zenodo.org/records/12825163)
274
+ """,
275
+ elem_classes=["markdown-content"]
276
+ )
277
+
278
+ # Right Panel - Visualizations
279
+ with gr.Column(scale=2):
280
+ gr.Markdown(
281
+ """
282
+ ### πŸ“ˆ Simulation Results
283
+ """,
284
+ elem_classes=["markdown-content"]
285
+ )
286
+
287
+ with gr.Row():
288
+ input_image_plot = gr.Plot(
289
+ label="πŸŒ€ Initial Vorticity Field",
290
+ container=True
291
+ )
292
+ output_image_plot = gr.Plot(
293
+ label="⚑ Neural Operator Prediction",
294
+ container=True
295
+ )
296
+
297
+ gr.Markdown(
298
+ """
299
+ <div class="info-panel">
300
+ πŸ’‘ <strong>Pro Tip:</strong> The left plot shows the initial fluid state (vorticity), while the right shows how our neural operator predicts the fluid will evolve. Traditional methods would take hoursβ€”our AI does it instantly!
301
+ </div>
302
+ """,
303
+ elem_classes=["markdown-content"]
304
+ )
305
+
306
+ # Event handlers
307
  run_button.click(
308
  fn=run_inference,
309
  inputs=[sample_input_slider],
310
  outputs=[input_image_plot, output_image_plot]
311
  )
312
+
313
+ random_button.click(
314
+ fn=get_random_sample,
315
+ outputs=[sample_input_slider]
316
+ ).then(
317
+ fn=run_inference,
318
+ inputs=[sample_input_slider],
319
+ outputs=[input_image_plot, output_image_plot]
320
+ )
321
 
322
  def load_initial_data_and_predict():
323
  load_model()
324
  load_dataset()
325
  return run_inference(0)
326
 
327
+ demo.load(
328
+ load_initial_data_and_predict,
329
+ inputs=None,
330
+ outputs=[input_image_plot, output_image_plot]
331
+ )
332
 
333
  if __name__ == "__main__":
334
  demo.launch()