File size: 13,183 Bytes
d467bd4
 
9979917
d467bd4
 
9979917
3d488b6
d467bd4
71f864d
 
c41cdf5
d467bd4
 
9979917
 
c41cdf5
 
 
9979917
c41cdf5
 
 
 
 
 
 
d467bd4
c41cdf5
d467bd4
 
c41cdf5
d467bd4
 
acadd4b
c41cdf5
d467bd4
 
 
 
 
 
c41cdf5
d467bd4
 
c41cdf5
9979917
d467bd4
c41cdf5
d467bd4
 
71f864d
d467bd4
 
 
 
 
 
9979917
d467bd4
 
acadd4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d467bd4
8dd0d9d
acadd4b
d467bd4
acadd4b
71f864d
 
d467bd4
3d488b6
 
71f864d
3d488b6
acadd4b
3d488b6
c41cdf5
d467bd4
 
 
acadd4b
3d488b6
71f864d
b5b786d
acadd4b
d467bd4
71f864d
8dd0d9d
d467bd4
acadd4b
d467bd4
 
 
acadd4b
 
d467bd4
acadd4b
d467bd4
 
acadd4b
 
 
 
 
0dde314
acadd4b
0dde314
 
 
 
 
 
 
 
 
 
acadd4b
2a00688
acadd4b
 
 
 
 
 
 
 
 
0dde314
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
acadd4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d467bd4
acadd4b
 
 
 
 
 
d467bd4
acadd4b
 
 
 
 
 
0dde314
 
 
 
 
acadd4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0dde314
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
acadd4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d467bd4
 
 
 
 
acadd4b
 
 
 
 
 
 
 
 
d467bd4
 
 
b5b786d
d467bd4
 
acadd4b
 
 
 
 
d467bd4
 
6e6f8fe
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
import gradio as gr
import torch
from neuralop.models import FNO
import matplotlib.pyplot as plt
import numpy as np
import os
from huggingface_hub import hf_hub_download

# Filesystem path of the FNO checkpoint; load_model() unpickles the whole model
# object from it (torch.load with weights_only=False).
MODEL_PATH = "fno_ckpt_single_res"
# Hugging Face Hub repo id and file name of the initial-condition dataset.
# NOTE(review): download_file_from_hf_hub() is called without a repo_type, so
# hf_hub_download treats this as a *model*-type repo despite the name --
# confirm that matches how the repo is hosted on the Hub.
HF_DATASET_REPO_ID = "ajsbsd/navier-stokes-2d-dataset"
HF_DATASET_FILENAME = "navier_stokes_2d.pt"

# Process-wide lazy caches, populated on first use by load_model() and
# load_dataset() respectively.
MODEL = None
FULL_DATASET_X = None

def download_file_from_hf_hub(repo_id, filename, repo_type=None):
    """Download ``filename`` from the Hugging Face Hub repository ``repo_id``.

    Args:
        repo_id: Hub repository id, e.g. ``"user/name"``.
        filename: Path of the file inside the repository.
        repo_type: Optional Hub repository type (``"model"``, ``"dataset"``
            or ``"space"``). ``None`` keeps ``hf_hub_download``'s default,
            which treats the repository as a model repo.

    Returns:
        The local path returned by ``hf_hub_download`` (a file in the local
        Hugging Face cache).

    Raises:
        gr.Error: If the download fails for any reason. The underlying
            exception is chained as ``__cause__`` for debugging.
    """
    print(f"Downloading {filename} from {repo_id} on Hugging Face Hub...")
    try:
        local_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type=repo_type)
    except Exception as e:
        # Boundary with the UI: surface a user-visible gr.Error but keep the
        # original traceback attached.
        print(f"Error downloading file from HF Hub: {e}")
        raise gr.Error(f"Failed to download dataset from Hugging Face Hub: {e}") from e
    print(f"Downloaded {filename} to {local_path} successfully.")
    return local_path

def load_model():
    """Load the pre-trained FNO model onto the CPU, caching it in ``MODEL``.

    The checkpoint at ``MODEL_PATH`` is read only on the first successful
    call; later calls return the cached module.

    Returns:
        The loaded model, already switched to eval mode.

    Raises:
        gr.Error: If the checkpoint cannot be loaded or prepared. The
            underlying exception is chained as ``__cause__``.
    """
    global MODEL
    if MODEL is None:
        print("Loading FNO model to CPU...")
        try:
            # SECURITY: weights_only=False unpickles arbitrary Python objects
            # (the full model, not just a state dict). This is only acceptable
            # because MODEL_PATH is assumed to be a trusted checkpoint shipped
            # with the app -- never point it at user-supplied files.
            model = torch.load(MODEL_PATH, weights_only=False, map_location='cpu')
            model.eval()
        except Exception as e:
            print(f"Error loading model: {e}")
            raise gr.Error(f"Failed to load model: {e}") from e
        # Publish to the global cache only after the model is fully prepared,
        # so a failure above never leaves a half-initialised model cached.
        MODEL = model
        print("Model loaded successfully to CPU.")
    return MODEL

def load_dataset():
    """Download and load the initial-condition samples, caching them in ``FULL_DATASET_X``.

    The file is fetched from ``HF_DATASET_REPO_ID``/``HF_DATASET_FILENAME``
    and loaded on the CPU. Two on-disk layouts are accepted: a dict holding
    the samples under key ``'x'``, or a bare ``torch.Tensor``.

    Returns:
        The cached samples (``data['x']`` or the tensor itself). Callers index
        the first dimension as the sample axis.

    Raises:
        gr.Error: If the download fails, or the file cannot be loaded or has
            an unexpected layout. The underlying exception is chained.
    """
    global FULL_DATASET_X
    if FULL_DATASET_X is None:
        local_dataset_path = download_file_from_hf_hub(HF_DATASET_REPO_ID, HF_DATASET_FILENAME)
        print("Loading dataset from local file...")
        try:
            # weights_only is not passed, so its default depends on the
            # installed torch version; plain tensors/dicts load either way.
            data = torch.load(local_dataset_path, map_location='cpu')
            if isinstance(data, dict) and 'x' in data:
                samples = data['x']
            elif isinstance(data, torch.Tensor):
                samples = data
            else:
                raise ValueError("Unknown dataset format or 'x' key missing.")
            # Accessing .shape also validates that 'x' is array-like.
            print(f"Dataset loaded. Total samples: {samples.shape[0]}")
        except Exception as e:
            print(f"Error loading dataset: {e}")
            raise gr.Error(f"Failed to load dataset from local file: {e}") from e
        # Commit to the global cache only once the samples passed validation,
        # so a malformed file is retried instead of being cached.
        FULL_DATASET_X = samples
    return FULL_DATASET_X

def create_enhanced_plot(data, title, is_input=True):
    """Render a 2-D field as a dark-themed heatmap figure.

    Args:
        data: 2-D array-like passed to ``imshow`` (e.g. a NumPy array).
        title: Figure title.
        is_input: Selects the colormap: ``'plasma'`` when True (initial
            condition), ``'viridis'`` otherwise (prediction).

    Returns:
        A ``matplotlib.figure.Figure`` suitable for a ``gr.Plot`` output.
    """
    # Use the object-oriented Figure API instead of plt.subplots(): pyplot
    # keeps every figure it creates alive until plt.close() is called, so a
    # server rendering two plots per request would leak memory (and trip
    # matplotlib's "too many open figures" warning). A bare Figure is freed
    # normally once Gradio has rendered it.
    from matplotlib.figure import Figure

    # Kept from the original: mutates global rcParams, which also control the
    # colours used when the figure is later saved to an image.
    plt.style.use('dark_background')
    fig = Figure(figsize=(8, 6), facecolor='#0f0f23')
    ax = fig.subplots()
    ax.set_facecolor('#0f0f23')

    # Enhanced colormap and visualization
    cmap = 'plasma' if is_input else 'viridis'
    im = ax.imshow(data, cmap=cmap, interpolation='bilinear')

    # Styling
    ax.set_title(title, color='white', fontsize=14, fontweight='bold', pad=20)
    ax.set_xlabel('X Coordinate', color='#8892b0', fontsize=10)
    ax.set_ylabel('Y Coordinate', color='#8892b0', fontsize=10)

    # Enhanced colorbar
    cbar = fig.colorbar(im, ax=ax, shrink=0.8, pad=0.02)
    cbar.set_label('Vorticity', color='#8892b0', fontsize=10)
    cbar.ax.tick_params(colors='#8892b0', labelsize=8)

    # Grid and ticks
    ax.tick_params(colors='#8892b0', labelsize=8)
    ax.grid(True, alpha=0.1, color='white')

    fig.tight_layout()
    return fig

def run_inference(sample_index: int, progress=gr.Progress()):
    """
    Performs inference for a selected sample index from the dataset on CPU.
    Returns two enhanced Matplotlib figures with progress tracking.

    Args:
        sample_index: Index of the initial condition to simulate. Any value
            accepted by ``int()`` works (a ``gr.Slider`` may deliver a float).
        progress: Gradio progress tracker. The ``gr.Progress()`` default is
            the Gradio idiom that lets the framework inject a live tracker.

    Returns:
        ``(fig_input, fig_output)``: figures of the initial condition and of
        the model's predicted solution.

    Raises:
        gr.Error: If the index is not numeric or out of range, or if the
            model/dataset cannot be loaded.
    """
    progress(0.1, desc="Loading model...")
    device = torch.device("cpu")
    model = load_model()

    if next(model.parameters()).device != device:
        model.to(device)
        print(f"Model moved to {device} within run_inference.")

    progress(0.3, desc="Loading dataset...")
    dataset = load_dataset()

    # Tensor slicing requires a real int: a float index such as 5.0 would
    # raise TypeError below and would show up as "Sample 5.0" in the title.
    try:
        sample_index = int(sample_index)
    except (TypeError, ValueError, OverflowError) as e:
        raise gr.Error(f"Invalid sample index: {sample_index!r}") from e

    if not (0 <= sample_index < dataset.shape[0]):
        raise gr.Error(f"Sample index out of range. Please choose between 0 and {dataset.shape[0]-1}.")

    progress(0.5, desc="Preparing input...")
    # Slicing (rather than indexing) keeps a batch dim of size 1; unsqueeze(1)
    # then inserts a channel dim. NOTE(review): assumes the dataset is
    # (N, H, W) -- confirm against the stored tensor.
    single_initial_condition = dataset[sample_index:sample_index+1, :, :].unsqueeze(1).to(device)
    print(f"Input moved to {device}.")

    progress(0.7, desc="Running inference...")
    print(f"Running inference for sample index {sample_index}...")
    with torch.no_grad():
        predicted_solution = model(single_initial_condition)

    progress(0.9, desc="Generating visualizations...")
    # squeeze() drops the singleton batch/channel dims for 2-D plotting.
    input_numpy = single_initial_condition.squeeze().cpu().numpy()
    output_numpy = predicted_solution.squeeze().cpu().numpy()

    # NOTE(review): the separator in this title looks mis-encoded (mojibake);
    # confirm the file's text encoding.
    fig_input = create_enhanced_plot(input_numpy, f"Initial Condition β€’ Sample {sample_index}", is_input=True)
    fig_output = create_enhanced_plot(output_numpy, "Predicted Solution", is_input=False)

    progress(1.0, desc="Complete!")
    return fig_input, fig_output

def get_random_sample():
    """Pick a uniformly random valid sample index (used by the "Random" button).

    Returns:
        A builtin ``int`` in ``[0, N)``, where ``N`` is the number of samples
        in the dataset.

    Raises:
        gr.Error: If the dataset contains no samples, or if it cannot be
            loaded.
    """
    num_samples = load_dataset().shape[0]
    if num_samples == 0:
        # np.random.randint(0, 0) would raise an opaque "high <= low" error.
        raise gr.Error("The dataset is empty; there is no sample to choose.")
    # The upper bound of np.random.randint is exclusive. Cast the NumPy scalar
    # to a builtin int so the slider receives a plain JSON-friendly value.
    return int(np.random.randint(0, num_samples))

# Custom CSS for professional styling with forced dark mode.
# Passed to gr.Blocks(css=custom_css); most rules use !important so they win
# over the Gradio theme's own colours.
# NOTE(review): `a:hover` is styled twice; the later rule (lighter blue plus
# underline) wins by source order, so the hover settings in the first link
# rule have no effect.
custom_css = """
/* Force dark mode for all elements */
* {
    color-scheme: dark !important;
}

body, html {
    background-color: #0f0f23 !important;
    color: #e2e8f0 !important;
}

#main-container {
    background: linear-gradient(135deg, #1a1d1f 0%, #764ba2 100%);
    min-height: 100vh;
}

.gradio-container {
    background: rgba(15, 15, 35, 0.95) !important;
    backdrop-filter: blur(10px);
    border-radius: 20px;
    border: 1px solid rgba(255, 255, 255, 0.1);
    box-shadow: 0 20px 40px rgba(0, 0, 0, 0.3);
    color: #e2e8f0 !important;
}

/* Force all text to be light colored */
p, span, div, label, h1, h2, h3, h4, h5, h6 {
    color: #e2e8f0 !important;
}

/* Fix link colors for both visited and unvisited */
a, a:link, a:visited, a:hover, a:active {
    color: #60a5fa !important;
    text-decoration: none !important;
    transition: color 0.3s ease !important;
}

a:hover {
    color: #93c5fd !important;
    text-decoration: underline !important;
}

/* Force markdown content colors */
.markdown-content, .markdown-content * {
    color: #e2e8f0 !important;
}

.markdown-content a, .markdown-content a:link, .markdown-content a:visited {
    color: #60a5fa !important;
}

.markdown-content a:hover {
    color: #93c5fd !important;
}

.gr-button {
    background: linear-gradient(45deg, #667eea, #764ba2) !important;
    border: none !important;
    border-radius: 12px !important;
    color: white !important;
    font-weight: 600 !important;
    padding: 12px 24px !important;
    transition: all 0.3s ease !important;
    box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4) !important;
}

.gr-button:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 6px 20px rgba(102, 126, 234, 0.6) !important;
}

.gr-slider input[type="range"] {
    background: linear-gradient(to right, #667eea, #764ba2) !important;
}

.markdown-content h1 {
    background: linear-gradient(45deg, #667eea, #764ba2);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    text-align: center;
    font-size: 2.5rem !important;
    margin-bottom: 1rem !important;
}

.markdown-content h3 {
    color: #8892b0 !important;
    border-left: 4px solid #667eea;
    padding-left: 1rem;
    margin-top: 2rem !important;
}

.feature-card {
    background: rgba(255, 255, 255, 0.05);
    border-radius: 15px;
    padding: 1.5rem;
    border: 1px solid rgba(255, 255, 255, 0.1);
    backdrop-filter: blur(10px);
    color: #e2e8f0 !important;
}

.feature-card strong {
    color: #93c5fd !important;
}

.status-indicator {
    display: inline-block;
    width: 8px;
    height: 8px;
    background: #00ff88;
    border-radius: 50%;
    margin-right: 8px;
    animation: pulse 2s infinite;
}

@keyframes pulse {
    0% { opacity: 1; }
    50% { opacity: 0.5; }
    100% { opacity: 1; }
}

.info-panel {
    background: rgba(102, 126, 234, 0.1);
    border-left: 4px solid #667eea;
    border-radius: 8px;
    padding: 1rem;
    margin: 1rem 0;
    color: #e2e8f0 !important;
}

.info-panel strong {
    color: #93c5fd !important;
}

/* Force input and form elements to dark mode */
input, textarea, select {
    background-color: rgba(255, 255, 255, 0.1) !important;
    color: #e2e8f0 !important;
    border: 1px solid rgba(255, 255, 255, 0.2) !important;
}

/* Force plot containers to dark mode */
.gr-plot {
    background-color: #0f0f23 !important;
}
"""

with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="blue", secondary_hue="purple")) as demo:
    with gr.Column(elem_id="main-container"):
        # Hero Section
        gr.Markdown(
            """
            # 🌊 Neural Fluid Dynamics
            ## Advanced Fourier Neural Operator for Navier-Stokes Equations
            
            <div class="info-panel">
            <span class="status-indicator"></span><strong>AI-Powered Fluid Simulation</strong> β€’ Experience cutting-edge neural operators that solve complex partial differential equations in milliseconds instead of hours.
            </div>
            """,
            elem_classes=["markdown-content"]
        )
        
        with gr.Row():
            # Left Panel - Controls
            with gr.Column(scale=1):
                gr.Markdown(
                    """
                    ### πŸŽ›οΈ Simulation Controls
                    """,
                    elem_classes=["markdown-content"]
                )
                
                with gr.Group():
                    sample_input_slider = gr.Slider(
                        minimum=0,
                        maximum=9999,
                        value=0,
                        step=1,
                        label="πŸ“Š Sample Index",
                        info="Choose from 10,000 unique fluid scenarios"
                    )
                    
                    with gr.Row():
                        run_button = gr.Button("πŸš€ Generate Solution", variant="primary", scale=2)
                        random_button = gr.Button("🎲 Random", variant="secondary", scale=1)
                
                # Information Panel
                gr.Markdown(
                    """
                    ### πŸ“‹ Model Information
                    
                    <div class="feature-card">
                    <strong>Architecture:</strong> Fourier Neural Operator (FNO)<br>
                    <strong>Domain:</strong> 2D Fluid Dynamics<br>
                    <strong>Resolution:</strong> 64Γ—64 Grid<br>
                    <strong>Inference Time:</strong> ~50ms<br>
                    <strong>Training Samples:</strong> 10,000+
                    </div>
                    
                    ### πŸ”¬ Research Context
                    This demo showcases the practical applications of **'Principled approaches for extending neural architectures to function spaces for operator learning'** research.
                    
                    **Key Resources:**
                    - πŸ“„ [Research Paper](https://arxiv.org/abs/2506.10973)
                    - πŸ’» [Source Code](https://github.com/neuraloperator/NNs-to-NOs)
                    - πŸ“Š [Dataset](https://zenodo.org/records/12825163)
                    """,
                    elem_classes=["markdown-content"]
                )
            
            # Right Panel - Visualizations
            with gr.Column(scale=2):
                gr.Markdown(
                    """
                    ### πŸ“ˆ Simulation Results
                    """,
                    elem_classes=["markdown-content"]
                )
                
                with gr.Row():
                    input_image_plot = gr.Plot(
                        label="πŸŒ€ Initial Vorticity Field",
                        container=True
                    )
                    output_image_plot = gr.Plot(
                        label="⚑ Neural Operator Prediction", 
                        container=True
                    )
                
                gr.Markdown(
                    """
                    <div class="info-panel">
                    πŸ’‘ <strong>Pro Tip:</strong> The left plot shows the initial fluid state (vorticity), while the right shows how our neural operator predicts the fluid will evolve. Traditional methods would take hoursβ€”our AI does it instantly!
                    </div>
                    """,
                    elem_classes=["markdown-content"]
                )

    # Event handlers
    run_button.click(
        fn=run_inference,
        inputs=[sample_input_slider],
        outputs=[input_image_plot, output_image_plot]
    )
    
    random_button.click(
        fn=get_random_sample,
        outputs=[sample_input_slider]
    ).then(
        fn=run_inference,
        inputs=[sample_input_slider],
        outputs=[input_image_plot, output_image_plot]
    )

    def load_initial_data_and_predict():
        load_model()
        load_dataset()
        return run_inference(0)

    demo.load(
        load_initial_data_and_predict, 
        inputs=None, 
        outputs=[input_image_plot, output_image_plot]
    )

if __name__ == "__main__":
    demo.launch()