# Navier_Stokes / app.py
# ajsbsd's picture
# Update app.py
# 2a00688 verified
import gradio as gr
import torch
from neuralop.models import FNO
import matplotlib.pyplot as plt
import numpy as np
import os
from huggingface_hub import hf_hub_download
# Path handed to torch.load() by load_model() to restore the FNO checkpoint.
MODEL_PATH = "fno_ckpt_single_res"
# Hugging Face Hub repository id and file name that load_dataset() downloads.
HF_DATASET_REPO_ID = "ajsbsd/navier-stokes-2d-dataset"
HF_DATASET_FILENAME = "navier_stokes_2d.pt"
# Module-level caches: both start as None and are filled on first use by
# load_model() and load_dataset() respectively.
MODEL = None
FULL_DATASET_X = None
def download_file_from_hf_hub(repo_id, filename):
    """Download ``filename`` from the Hub repository ``repo_id``.

    Args:
        repo_id: Repository id forwarded to ``hf_hub_download``.
        filename: Name of the file inside that repository.

    Returns:
        The local path returned by ``hf_hub_download``.

    Raises:
        gr.Error: If the download fails for any reason. The original
            exception is attached as ``__cause__``.
    """
    # NOTE(review): no ``repo_type`` is passed, so ``hf_hub_download`` uses its
    # default repository type -- confirm that matches how the repo is hosted.
    print(f"Downloading {filename} from {repo_id} on Hugging Face Hub...")
    try:
        # Keep the try body to the single call that can actually fail.
        local_path = hf_hub_download(repo_id=repo_id, filename=filename)
    except Exception as e:
        print(f"Error downloading file from HF Hub: {e}")
        raise gr.Error(f"Failed to download dataset from Hugging Face Hub: {e}") from e
    print(f"Downloaded {filename} to {local_path} successfully.")
    return local_path
def load_model():
    """Return the FNO model, loading it onto the CPU on first use.

    The loaded model is cached in the module-level ``MODEL`` so later calls
    return the same object without touching the disk again.

    Returns:
        The object restored by ``torch.load(MODEL_PATH, ...)``, switched to
        evaluation mode via ``.eval()``.

    Raises:
        gr.Error: If loading or switching to eval mode fails. The original
            exception is attached as ``__cause__``.
    """
    global MODEL
    if MODEL is not None:
        return MODEL

    print("Loading FNO model to CPU...")
    try:
        # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
        # objects, which can execute code. Make sure MODEL_PATH only ever
        # points at a trusted checkpoint.
        model = torch.load(MODEL_PATH, weights_only=False, map_location='cpu')
        model.eval()
    except Exception as e:
        print(f"Error loading model: {e}")
        raise gr.Error(f"Failed to load model: {e}") from e

    # Publish to the cache only after load + eval() both succeeded, so a
    # failed initialisation is retried on the next call instead of being
    # served half-initialised.
    MODEL = model
    print("Model loaded successfully to CPU.")
    return MODEL
def load_dataset():
    """Return the initial-conditions dataset, downloading it on first use.

    The file named by ``HF_DATASET_REPO_ID`` / ``HF_DATASET_FILENAME`` is
    fetched through ``download_file_from_hf_hub`` and read with
    ``torch.load``. Two payload layouts are accepted: a dict containing an
    ``'x'`` entry, or a bare ``torch.Tensor``. The result is cached in the
    module-level ``FULL_DATASET_X``.

    Returns:
        The cached dataset object (``data['x']`` or the tensor itself).

    Raises:
        gr.Error: If the download fails, the file cannot be read, or its
            layout is not one of the two accepted forms.
    """
    global FULL_DATASET_X
    if FULL_DATASET_X is not None:
        return FULL_DATASET_X

    local_dataset_path = download_file_from_hf_hub(HF_DATASET_REPO_ID, HF_DATASET_FILENAME)
    print("Loading dataset from local file...")
    try:
        data = torch.load(local_dataset_path, map_location='cpu')
        if isinstance(data, dict) and 'x' in data:
            samples = data['x']
        elif isinstance(data, torch.Tensor):
            samples = data
        else:
            raise ValueError("Unknown dataset format or 'x' key missing.")
        # Touching .shape here also validates the payload before it is cached.
        print(f"Dataset loaded. Total samples: {samples.shape[0]}")
    except Exception as e:
        print(f"Error loading dataset: {e}")
        raise gr.Error(f"Failed to load dataset from local file: {e}") from e

    # Cache only a fully validated payload; a failed load must not leave a
    # bad value behind for later calls to return silently.
    FULL_DATASET_X = samples
    return FULL_DATASET_X
def create_enhanced_plot(data, title, is_input=True):
    """Render ``data`` as a dark-themed image plot and return the figure.

    Args:
        data: Array-like handed straight to ``Axes.imshow``
            (presumably a 2-D field -- confirm against callers).
        title: Text shown above the axes.
        is_input: Selects the colormap: ``'plasma'`` when true,
            ``'viridis'`` otherwise.

    Returns:
        The Matplotlib ``Figure``. It is detached from pyplot's global
        registry before being returned, so the caller owns its lifetime.
    """
    # NOTE: style.use() changes Matplotlib's process-wide rcParams, not just
    # this figure.
    plt.style.use('dark_background')
    fig, ax = plt.subplots(figsize=(8, 6), facecolor='#0f0f23')
    ax.set_facecolor('#0f0f23')

    # Enhanced colormap and visualization
    cmap = 'plasma' if is_input else 'viridis'
    im = ax.imshow(data, cmap=cmap, interpolation='bilinear')

    # Styling
    ax.set_title(title, color='white', fontsize=14, fontweight='bold', pad=20)
    ax.set_xlabel('X Coordinate', color='#8892b0', fontsize=10)
    ax.set_ylabel('Y Coordinate', color='#8892b0', fontsize=10)

    # Enhanced colorbar
    cbar = fig.colorbar(im, ax=ax, shrink=0.8, pad=0.02)
    cbar.set_label('Vorticity', color='#8892b0', fontsize=10)
    cbar.ax.tick_params(colors='#8892b0', labelsize=8)

    # Grid and ticks
    ax.tick_params(colors='#8892b0', labelsize=8)
    ax.grid(True, alpha=0.1, color='white')

    # Lay out this figure explicitly instead of pyplot's "current" figure,
    # which is shared global state.
    fig.tight_layout()

    # Every request builds new figures; without this they pile up in pyplot's
    # registry for the life of the server. The Figure object itself stays
    # usable (e.g. for savefig) after being closed in pyplot.
    plt.close(fig)
    return fig
def run_inference(sample_index: int, progress=gr.Progress()):
    """Run the FNO on one dataset sample on the CPU and plot input and output.

    Args:
        sample_index: Index of the sample along the dataset's first axis.
            Values that ``int()`` accepts (e.g. ``3.0`` from a slider) are
            coerced.
        progress: Gradio progress tracker, called with a fraction and a
            description at each stage.

    Returns:
        A ``(fig_input, fig_output)`` tuple of Matplotlib figures: the selected
        initial condition and the model's prediction.

    Raises:
        gr.Error: If ``sample_index`` is not an integer-like value or is out
            of range, or if loading the model or dataset fails.
    """
    # UI components and API clients may deliver None or a float; slicing
    # below needs a real int, so fail early with a readable message.
    try:
        sample_index = int(sample_index)
    except (TypeError, ValueError) as e:
        raise gr.Error(f"Sample index must be an integer, got {sample_index!r}.") from e

    progress(0.1, desc="Loading model...")
    device = torch.device("cpu")
    model = load_model()
    if next(model.parameters()).device != device:
        model.to(device)
        print(f"Model moved to {device} within run_inference.")

    progress(0.3, desc="Loading dataset...")
    dataset = load_dataset()
    if not (0 <= sample_index < dataset.shape[0]):
        raise gr.Error(f"Sample index out of range. Please choose between 0 and {dataset.shape[0]-1}.")

    progress(0.5, desc="Preparing input...")
    # Slice (rather than index) to keep the leading axis, then insert a new
    # axis at position 1 (presumably the channel axis the model expects).
    single_initial_condition = dataset[sample_index:sample_index+1, :, :].unsqueeze(1).to(device)
    print(f"Input moved to {device}.")

    progress(0.7, desc="Running inference...")
    print(f"Running inference for sample index {sample_index}...")
    with torch.no_grad():
        predicted_solution = model(single_initial_condition)

    progress(0.9, desc="Generating visualizations...")
    # squeeze() drops all singleton axes before handing the data to imshow.
    input_numpy = single_initial_condition.squeeze().cpu().numpy()
    output_numpy = predicted_solution.squeeze().cpu().numpy()
    fig_input = create_enhanced_plot(input_numpy, f"Initial Condition • Sample {sample_index}", is_input=True)
    fig_output = create_enhanced_plot(output_numpy, "Predicted Solution", is_input=False)

    progress(1.0, desc="Complete!")
    return fig_input, fig_output
def get_random_sample():
    """Draw a random index in ``[0, number_of_samples)`` from the dataset."""
    sample_count = load_dataset().shape[0]
    return np.random.randint(0, sample_count)
# Custom CSS for professional styling with forced dark mode
# This stylesheet is passed to gr.Blocks(css=custom_css). It targets the
# "main-container" element id and the "markdown-content", "info-panel",
# "feature-card" and "status-indicator" classes used by the UI below.
# NOTE(review): selectors such as .gr-button, .gr-slider and .gr-plot rely on
# class names generated by Gradio -- confirm they exist in the installed version.
custom_css = """
/* Force dark mode for all elements */
* {
color-scheme: dark !important;
}
body, html {
background-color: #0f0f23 !important;
color: #e2e8f0 !important;
}
#main-container {
background: linear-gradient(135deg, #1a1d1f 0%, #764ba2 100%);
min-height: 100vh;
}
.gradio-container {
background: rgba(15, 15, 35, 0.95) !important;
backdrop-filter: blur(10px);
border-radius: 20px;
border: 1px solid rgba(255, 255, 255, 0.1);
box-shadow: 0 20px 40px rgba(0, 0, 0, 0.3);
color: #e2e8f0 !important;
}
/* Force all text to be light colored */
p, span, div, label, h1, h2, h3, h4, h5, h6 {
color: #e2e8f0 !important;
}
/* Fix link colors for both visited and unvisited */
a, a:link, a:visited, a:hover, a:active {
color: #60a5fa !important;
text-decoration: none !important;
transition: color 0.3s ease !important;
}
a:hover {
color: #93c5fd !important;
text-decoration: underline !important;
}
/* Force markdown content colors */
.markdown-content, .markdown-content * {
color: #e2e8f0 !important;
}
.markdown-content a, .markdown-content a:link, .markdown-content a:visited {
color: #60a5fa !important;
}
.markdown-content a:hover {
color: #93c5fd !important;
}
.gr-button {
background: linear-gradient(45deg, #667eea, #764ba2) !important;
border: none !important;
border-radius: 12px !important;
color: white !important;
font-weight: 600 !important;
padding: 12px 24px !important;
transition: all 0.3s ease !important;
box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4) !important;
}
.gr-button:hover {
transform: translateY(-2px) !important;
box-shadow: 0 6px 20px rgba(102, 126, 234, 0.6) !important;
}
.gr-slider input[type="range"] {
background: linear-gradient(to right, #667eea, #764ba2) !important;
}
.markdown-content h1 {
background: linear-gradient(45deg, #667eea, #764ba2);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
text-align: center;
font-size: 2.5rem !important;
margin-bottom: 1rem !important;
}
.markdown-content h3 {
color: #8892b0 !important;
border-left: 4px solid #667eea;
padding-left: 1rem;
margin-top: 2rem !important;
}
.feature-card {
background: rgba(255, 255, 255, 0.05);
border-radius: 15px;
padding: 1.5rem;
border: 1px solid rgba(255, 255, 255, 0.1);
backdrop-filter: blur(10px);
color: #e2e8f0 !important;
}
.feature-card strong {
color: #93c5fd !important;
}
.status-indicator {
display: inline-block;
width: 8px;
height: 8px;
background: #00ff88;
border-radius: 50%;
margin-right: 8px;
animation: pulse 2s infinite;
}
@keyframes pulse {
0% { opacity: 1; }
50% { opacity: 0.5; }
100% { opacity: 1; }
}
.info-panel {
background: rgba(102, 126, 234, 0.1);
border-left: 4px solid #667eea;
border-radius: 8px;
padding: 1rem;
margin: 1rem 0;
color: #e2e8f0 !important;
}
.info-panel strong {
color: #93c5fd !important;
}
/* Force input and form elements to dark mode */
input, textarea, select {
background-color: rgba(255, 255, 255, 0.1) !important;
color: #e2e8f0 !important;
border: 1px solid rgba(255, 255, 255, 0.2) !important;
}
/* Force plot containers to dark mode */
.gr-plot {
background-color: #0f0f23 !important;
}
"""
# Build the Gradio UI: a controls/info column on the left, two plots on the
# right, and event wiring that feeds the slider value into run_inference.
# Markdown bodies are kept flush-left inside their triple-quoted strings so
# the rendered text never depends on how leading indentation is handled.
# NOTE(review): the source this was restored from had lost its indentation;
# the container nesting below follows the section comments -- confirm the
# layout against the running app.
with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="blue", secondary_hue="purple")) as demo:
    with gr.Column(elem_id="main-container"):
        # Hero Section
        gr.Markdown(
            """
# 🌊 Neural Fluid Dynamics
## Advanced Fourier Neural Operator for Navier-Stokes Equations
<div class="info-panel">
<span class="status-indicator"></span><strong>AI-Powered Fluid Simulation</strong> • Experience cutting-edge neural operators that solve complex partial differential equations in milliseconds instead of hours.
</div>
""",
            elem_classes=["markdown-content"]
        )
        with gr.Row():
            # Left Panel - Controls
            with gr.Column(scale=1):
                gr.Markdown(
                    """
### 🎛️ Simulation Controls
""",
                    elem_classes=["markdown-content"]
                )
                with gr.Group():
                    sample_input_slider = gr.Slider(
                        minimum=0,
                        maximum=9999,
                        value=0,
                        step=1,
                        label="📊 Sample Index",
                        info="Choose from 10,000 unique fluid scenarios"
                    )
                    with gr.Row():
                        run_button = gr.Button("🚀 Generate Solution", variant="primary", scale=2)
                        random_button = gr.Button("🎲 Random", variant="secondary", scale=1)
                # Information Panel
                gr.Markdown(
                    """
### 📋 Model Information
<div class="feature-card">
<strong>Architecture:</strong> Fourier Neural Operator (FNO)<br>
<strong>Domain:</strong> 2D Fluid Dynamics<br>
<strong>Resolution:</strong> 64×64 Grid<br>
<strong>Inference Time:</strong> ~50ms<br>
<strong>Training Samples:</strong> 10,000+
</div>
### 🔬 Research Context
This demo showcases the practical applications of **'Principled approaches for extending neural architectures to function spaces for operator learning'** research.
**Key Resources:**
- 📄 [Research Paper](https://arxiv.org/abs/2506.10973)
- 💻 [Source Code](https://github.com/neuraloperator/NNs-to-NOs)
- 📊 [Dataset](https://zenodo.org/records/12825163)
""",
                    elem_classes=["markdown-content"]
                )
            # Right Panel - Visualizations
            with gr.Column(scale=2):
                gr.Markdown(
                    """
### 📈 Simulation Results
""",
                    elem_classes=["markdown-content"]
                )
                with gr.Row():
                    input_image_plot = gr.Plot(
                        label="🌀 Initial Vorticity Field",
                        container=True
                    )
                    output_image_plot = gr.Plot(
                        label="⚑ Neural Operator Prediction",
                        container=True
                    )
                gr.Markdown(
                    """
<div class="info-panel">
💡 <strong>Pro Tip:</strong> The left plot shows the initial fluid state (vorticity), while the right shows how our neural operator predicts the fluid will evolve. Traditional methods would take hours—our AI does it instantly!
</div>
""",
                    elem_classes=["markdown-content"]
                )

    # Event handlers
    run_button.click(
        fn=run_inference,
        inputs=[sample_input_slider],
        outputs=[input_image_plot, output_image_plot]
    )
    # "Random" first writes a new index into the slider, then runs inference
    # on whatever the slider now holds.
    random_button.click(
        fn=get_random_sample,
        outputs=[sample_input_slider]
    ).then(
        fn=run_inference,
        inputs=[sample_input_slider],
        outputs=[input_image_plot, output_image_plot]
    )

    def load_initial_data_and_predict():
        """Fill the model and dataset caches, then plot sample 0 on page load."""
        load_model()
        load_dataset()
        return run_inference(0)

    demo.load(
        load_initial_data_and_predict,
        inputs=None,
        outputs=[input_image_plot, output_image_plot]
    )

if __name__ == "__main__":
    demo.launch()