import gradio as gr
import torch
# FNO must be importable so torch.load can unpickle the serialized model object.
from neuralop.models import FNO
import matplotlib.pyplot as plt
import numpy as np
from huggingface_hub import hf_hub_download

MODEL_PATH = "fno_ckpt_single_res"
HF_DATASET_REPO_ID = "ajsbsd/navier-stokes-2d-dataset"
HF_DATASET_FILENAME = "navier_stokes_2d.pt"

MODEL = None
FULL_DATASET_X = None
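
# Both globals above are populated lazily on first use, so the Space starts
# quickly and the download/deserialization cost is paid once per process.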

def download_file_from_hf_hub(repo_id, filename):
    """Downloads a file from Hugging Face Hub."""
    print(f"Downloading {filename} from {repo_id} on Hugging Face Hub...")
    try:
        local_path = hf_hub_download(repo_id=repo_id, filename=filename)
        print(f"Downloaded {filename} to {local_path} successfully.")
        return local_path
    except Exception as e:
        print(f"Error downloading file from HF Hub: {e}")
        raise gr.Error(f"Failed to download dataset from Hugging Face Hub: {e}")

def load_model():
    """Loads the pre-trained FNO model to CPU."""
    global MODEL
    if MODEL is None:
        print("Loading FNO model to CPU...")
        try:
            # weights_only=False because the checkpoint stores the full pickled
            # module, not just a state_dict; only load trusted checkpoints this way.
            MODEL = torch.load(MODEL_PATH, weights_only=False, map_location='cpu')
            MODEL.eval()
            print("Model loaded successfully to CPU.")
        except Exception as e:
            print(f"Error loading model: {e}")
            raise gr.Error(f"Failed to load model: {e}")
    return MODEL
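
# If the checkpoint were a plain state_dict instead of a pickled module, a
# sketch like the following would apply (the n_modes/hidden_channels values
# here are illustrative assumptions, not the checkpoint's actual configuration):
#
#   model = FNO(n_modes=(16, 16), hidden_channels=64, in_channels=1, out_channels=1)
#   model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
#   model.eval()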

def load_dataset():
    """Downloads and loads the initial conditions dataset from HF Hub."""
    global FULL_DATASET_X
    if FULL_DATASET_X is None:
        local_dataset_path = download_file_from_hf_hub(HF_DATASET_REPO_ID, HF_DATASET_FILENAME)
        print("Loading dataset from local file...")
        try:
            data = torch.load(local_dataset_path, map_location='cpu')
            if isinstance(data, dict) and 'x' in data:
                FULL_DATASET_X = data['x']
            elif isinstance(data, torch.Tensor):
                FULL_DATASET_X = data
            else:
                raise ValueError("Unknown dataset format or 'x' key missing.")
            print(f"Dataset loaded. Total samples: {FULL_DATASET_X.shape[0]}")
        except Exception as e:
            print(f"Error loading dataset: {e}")
            raise gr.Error(f"Failed to load dataset from local file: {e}")
    return FULL_DATASET_X
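
# The dataset tensor is assumed to have shape (num_samples, 64, 64): one scalar
# vorticity field per sample at the resolution the model was trained on.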

def create_enhanced_plot(data, title, is_input=True):
    """Creates enhanced matplotlib plots with professional styling."""
    plt.style.use('dark_background')
    fig, ax = plt.subplots(figsize=(8, 6), facecolor='#0f0f23')
    ax.set_facecolor('#0f0f23')

    # Enhanced colormap and visualization
    cmap = 'plasma' if is_input else 'viridis'
    im = ax.imshow(data, cmap=cmap, interpolation='bilinear')

    # Styling
    ax.set_title(title, color='white', fontsize=14, fontweight='bold', pad=20)
    ax.set_xlabel('X Coordinate', color='#8892b0', fontsize=10)
    ax.set_ylabel('Y Coordinate', color='#8892b0', fontsize=10)

    # Enhanced colorbar
    cbar = fig.colorbar(im, ax=ax, shrink=0.8, pad=0.02)
    cbar.set_label('Vorticity', color='#8892b0', fontsize=10)
    cbar.ax.tick_params(colors='#8892b0', labelsize=8)

    # Grid and ticks
    ax.tick_params(colors='#8892b0', labelsize=8)
    ax.grid(True, alpha=0.1, color='white')

    plt.tight_layout()
    return fig
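
# Quick sanity-check sketch (illustrative, not used by the app): render a
# random 64x64 field to confirm the styling looks right.
#
#   fig = create_enhanced_plot(np.random.rand(64, 64), "Random Field")
#   fig.savefig("sanity_check.png")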

def run_inference(sample_index: int, progress=gr.Progress()):
    """
    Performs inference for a selected sample index from the dataset on CPU.
    Returns two enhanced Matplotlib figures with progress tracking.
    """
    progress(0.1, desc="Loading model...")
    device = torch.device("cpu")
    model = load_model()
    if next(model.parameters()).device != device:
        model.to(device)
        print(f"Model moved to {device} within run_inference.")

    progress(0.3, desc="Loading dataset...")
    dataset = load_dataset()
    if not (0 <= sample_index < dataset.shape[0]):
        raise gr.Error(f"Sample index out of range. Please choose between 0 and {dataset.shape[0] - 1}.")

    progress(0.5, desc="Preparing input...")
    # Slicing keeps the batch dimension, (1, H, W); unsqueeze adds the channel
    # dimension the FNO expects, giving (1, 1, H, W).
    single_initial_condition = dataset[sample_index:sample_index + 1, :, :].unsqueeze(1).to(device)
    print(f"Input moved to {device}.")

    progress(0.7, desc="Running inference...")
    print(f"Running inference for sample index {sample_index}...")
    with torch.no_grad():
        predicted_solution = model(single_initial_condition)

    progress(0.9, desc="Generating visualizations...")
    input_numpy = single_initial_condition.squeeze().cpu().numpy()
    output_numpy = predicted_solution.squeeze().cpu().numpy()
    fig_input = create_enhanced_plot(input_numpy, f"Initial Condition • Sample {sample_index}", is_input=True)
    fig_output = create_enhanced_plot(output_numpy, "Predicted Solution", is_input=False)

    progress(1.0, desc="Complete!")
    return fig_input, fig_output
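
# Standalone usage sketch (assumes the checkpoint and dataset are reachable).
# Outside a Gradio event the default gr.Progress() may lack a request context,
# so pass a no-op callback instead:
#
#   fig_in, fig_out = run_inference(42, progress=lambda *args, **kwargs: None)
#   fig_in.savefig("initial_condition.png")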

def get_random_sample():
    """Returns a random sample index for exploration."""
    dataset = load_dataset()
    # Cast to a plain Python int so the slider receives a standard value.
    return int(np.random.randint(0, dataset.shape[0]))

# Custom CSS for professional styling with forced dark mode
custom_css = """
/* Force dark mode for all elements */
* {
    color-scheme: dark !important;
}
body, html {
    background-color: #0f0f23 !important;
    color: #e2e8f0 !important;
}
#main-container {
    background: linear-gradient(135deg, #1a1d1f 0%, #764ba2 100%);
    min-height: 100vh;
}
.gradio-container {
    background: rgba(15, 15, 35, 0.95) !important;
    backdrop-filter: blur(10px);
    border-radius: 20px;
    border: 1px solid rgba(255, 255, 255, 0.1);
    box-shadow: 0 20px 40px rgba(0, 0, 0, 0.3);
    color: #e2e8f0 !important;
}
/* Force all text to be light colored */
p, span, div, label, h1, h2, h3, h4, h5, h6 {
    color: #e2e8f0 !important;
}
/* Fix link colors for both visited and unvisited */
a, a:link, a:visited, a:hover, a:active {
    color: #60a5fa !important;
    text-decoration: none !important;
    transition: color 0.3s ease !important;
}
a:hover {
    color: #93c5fd !important;
    text-decoration: underline !important;
}
/* Force markdown content colors */
.markdown-content, .markdown-content * {
    color: #e2e8f0 !important;
}
.markdown-content a, .markdown-content a:link, .markdown-content a:visited {
    color: #60a5fa !important;
}
.markdown-content a:hover {
    color: #93c5fd !important;
}
.gr-button {
    background: linear-gradient(45deg, #667eea, #764ba2) !important;
    border: none !important;
    border-radius: 12px !important;
    color: white !important;
    font-weight: 600 !important;
    padding: 12px 24px !important;
    transition: all 0.3s ease !important;
    box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4) !important;
}
.gr-button:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 6px 20px rgba(102, 126, 234, 0.6) !important;
}
.gr-slider input[type="range"] {
    background: linear-gradient(to right, #667eea, #764ba2) !important;
}
.markdown-content h1 {
    background: linear-gradient(45deg, #667eea, #764ba2);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    text-align: center;
    font-size: 2.5rem !important;
    margin-bottom: 1rem !important;
}
.markdown-content h3 {
    color: #8892b0 !important;
    border-left: 4px solid #667eea;
    padding-left: 1rem;
    margin-top: 2rem !important;
}
.feature-card {
    background: rgba(255, 255, 255, 0.05);
    border-radius: 15px;
    padding: 1.5rem;
    border: 1px solid rgba(255, 255, 255, 0.1);
    backdrop-filter: blur(10px);
    color: #e2e8f0 !important;
}
.feature-card strong {
    color: #93c5fd !important;
}
.status-indicator {
    display: inline-block;
    width: 8px;
    height: 8px;
    background: #00ff88;
    border-radius: 50%;
    margin-right: 8px;
    animation: pulse 2s infinite;
}
@keyframes pulse {
    0% { opacity: 1; }
    50% { opacity: 0.5; }
    100% { opacity: 1; }
}
.info-panel {
    background: rgba(102, 126, 234, 0.1);
    border-left: 4px solid #667eea;
    border-radius: 8px;
    padding: 1rem;
    margin: 1rem 0;
    color: #e2e8f0 !important;
}
.info-panel strong {
    color: #93c5fd !important;
}
/* Force input and form elements to dark mode */
input, textarea, select {
    background-color: rgba(255, 255, 255, 0.1) !important;
    color: #e2e8f0 !important;
    border: 1px solid rgba(255, 255, 255, 0.2) !important;
}
/* Force plot containers to dark mode */
.gr-plot {
    background-color: #0f0f23 !important;
}
"""

with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="blue", secondary_hue="purple")) as demo:
    with gr.Column(elem_id="main-container"):
        # Hero Section
        gr.Markdown(
            """
            # Neural Fluid Dynamics
            ## Advanced Fourier Neural Operator for Navier-Stokes Equations

            <div class="info-panel">
            <span class="status-indicator"></span><strong>AI-Powered Fluid Simulation</strong> • Neural operators approximate solutions to complex partial differential equations in milliseconds instead of hours.
            </div>
            """,
            elem_classes=["markdown-content"]
        )

        with gr.Row():
            # Left Panel - Controls
            with gr.Column(scale=1):
                gr.Markdown(
                    """
                    ### Simulation Controls
                    """,
                    elem_classes=["markdown-content"]
                )

                with gr.Group():
                    sample_input_slider = gr.Slider(
                        minimum=0,
                        maximum=9999,
                        value=0,
                        step=1,
                        label="Sample Index",
                        info="Choose from 10,000 unique fluid scenarios"
                    )

                    with gr.Row():
                        run_button = gr.Button("Generate Solution", variant="primary", scale=2)
                        random_button = gr.Button("Random", variant="secondary", scale=1)

                # Information Panel
                gr.Markdown(
                    """
                    ### Model Information

                    <div class="feature-card">
                    <strong>Architecture:</strong> Fourier Neural Operator (FNO)<br>
                    <strong>Domain:</strong> 2D Fluid Dynamics<br>
                    <strong>Resolution:</strong> 64×64 Grid<br>
                    <strong>Inference Time:</strong> ~50 ms<br>
                    <strong>Training Samples:</strong> 10,000+
                    </div>

                    ### Research Context

                    This demo showcases practical applications of the research in **"Principled approaches for extending neural architectures to function spaces for operator learning"**.

                    **Key Resources:**
                    - [Research Paper](https://arxiv.org/abs/2506.10973)
                    - [Source Code](https://github.com/neuraloperator/NNs-to-NOs)
                    - [Dataset](https://zenodo.org/records/12825163)
                    """,
                    elem_classes=["markdown-content"]
                )

            # Right Panel - Visualizations
            with gr.Column(scale=2):
                gr.Markdown(
                    """
                    ### Simulation Results
                    """,
                    elem_classes=["markdown-content"]
                )

                with gr.Row():
                    input_image_plot = gr.Plot(
                        label="Initial Vorticity Field",
                        container=True
                    )
                    output_image_plot = gr.Plot(
                        label="Neural Operator Prediction",
                        container=True
                    )

                gr.Markdown(
                    """
                    <div class="info-panel">
                    <strong>Pro Tip:</strong> The left plot shows the initial fluid state (vorticity), while the right shows how the neural operator predicts the fluid will evolve. Traditional solvers can take hours; the neural operator answers in milliseconds.
                    </div>
                    """,
                    elem_classes=["markdown-content"]
                )

    # Event handlers
    run_button.click(
        fn=run_inference,
        inputs=[sample_input_slider],
        outputs=[input_image_plot, output_image_plot]
    )

    # Picking a random sample first updates the slider; .then() reruns
    # inference with the slider's new value.
    random_button.click(
        fn=get_random_sample,
        outputs=[sample_input_slider]
    ).then(
        fn=run_inference,
        inputs=[sample_input_slider],
        outputs=[input_image_plot, output_image_plot]
    )

    def load_initial_data_and_predict():
        load_model()
        load_dataset()
        return run_inference(0)

    demo.load(
        load_initial_data_and_predict,
        inputs=None,
        outputs=[input_image_plot, output_image_plot]
    )
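
# Note: gr.Progress updates are delivered through Gradio's queue. Recent Gradio
# releases enable queueing by default; on older 3.x versions you may need
# demo.queue() before demo.launch().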
if __name__ == "__main__":
    demo.launch()