Navier_Stokes / app.py
ajsbsd's picture
Update app.py
0dde314 verified
raw
history blame
13.2 kB
import gradio as gr
import torch
from neuralop.models import FNO
import matplotlib.pyplot as plt
import numpy as np
import os
from huggingface_hub import hf_hub_download
# Local path of the FNO checkpoint that load_model() unpickles with torch.load.
MODEL_PATH: str = "fno_ckpt_single_res"

# Hugging Face Hub location of the initial-conditions dataset used by load_dataset().
HF_DATASET_REPO_ID: str = "ajsbsd/navier-stokes-2d-dataset"
HF_DATASET_FILENAME: str = "navier_stokes_2d.pt"

# Lazily-populated process-wide caches; None until the first successful load.
MODEL = None
FULL_DATASET_X = None
def download_file_from_hf_hub(repo_id, filename):
    """Download a single file from a Hugging Face Hub repository.

    Args:
        repo_id: Repository identifier on the Hub, e.g. ``"user/repo"``.
        filename: Path of the file inside the repository.

    Returns:
        The local filesystem path returned by ``hf_hub_download``.

    Raises:
        gr.Error: If the download fails for any reason. The underlying
            exception is chained as ``__cause__`` so the full traceback is
            kept in the server logs while the UI shows a short message.
    """
    print(f"Downloading {filename} from {repo_id} on Hugging Face Hub...")
    try:
        local_path = hf_hub_download(repo_id=repo_id, filename=filename)
    except Exception as e:
        # Broad catch is deliberate: this is a UI boundary and every failure
        # must reach the user as a gr.Error instead of an opaque crash.
        print(f"Error downloading file from HF Hub: {e}")
        raise gr.Error(f"Failed to download dataset from Hugging Face Hub: {e}") from e
    print(f"Downloaded {filename} to {local_path} successfully.")
    return local_path
def load_model():
    """Load the pre-trained FNO model onto the CPU, caching it globally.

    The checkpoint at ``MODEL_PATH`` is loaded once and stored in the
    module-level ``MODEL``; subsequent calls return the cached instance.

    Returns:
        The loaded model, switched to evaluation mode.

    Raises:
        gr.Error: If the checkpoint cannot be loaded or put into eval mode.
            The cache is left empty in that case, so a later call retries.
    """
    global MODEL
    if MODEL is None:
        print("Loading FNO model to CPU...")
        try:
            # weights_only=False lets torch.load unpickle arbitrary Python
            # objects (here presumably a full nn.Module). Unpickling can run
            # arbitrary code, so MODEL_PATH must only ever point at a trusted,
            # app-shipped checkpoint.
            model = torch.load(MODEL_PATH, weights_only=False, map_location='cpu')
            model.eval()
        except Exception as e:
            print(f"Error loading model: {e}")
            raise gr.Error(f"Failed to load model: {e}") from e
        # Publish to the cache only after a fully successful load. Otherwise a
        # failed attempt (e.g. a checkpoint without .eval()) would stay cached
        # and be handed back to every later caller.
        MODEL = model
        print("Model loaded successfully to CPU.")
    return MODEL
def load_dataset():
    """Download and load the initial-conditions dataset, caching it globally.

    The file ``HF_DATASET_FILENAME`` is fetched from ``HF_DATASET_REPO_ID`` on
    the Hugging Face Hub and loaded onto the CPU. Two formats are accepted:
    a dict holding the inputs under key ``'x'``, or a bare tensor. The result
    is stored in the module-level ``FULL_DATASET_X`` so later calls skip both
    the download and the load.

    Returns:
        The loaded inputs; dimension 0 indexes samples.

    Raises:
        gr.Error: If the download fails, the file cannot be loaded, or its
            format is not recognized. The cache is left empty in that case,
            so a later call retries.
    """
    global FULL_DATASET_X
    if FULL_DATASET_X is None:
        local_dataset_path = download_file_from_hf_hub(HF_DATASET_REPO_ID, HF_DATASET_FILENAME)
        print("Loading dataset from local file...")
        try:
            data = torch.load(local_dataset_path, map_location='cpu')
            if isinstance(data, dict) and 'x' in data:
                dataset_x = data['x']
            elif isinstance(data, torch.Tensor):
                dataset_x = data
            else:
                raise ValueError("Unknown dataset format or 'x' key missing.")
            # Accessing .shape doubles as validation that 'x' is array-like.
            print(f"Dataset loaded. Total samples: {dataset_x.shape[0]}")
        except Exception as e:
            print(f"Error loading dataset: {e}")
            raise gr.Error(f"Failed to load dataset from local file: {e}") from e
        # Cache only a value that passed validation; never a half-checked one.
        FULL_DATASET_X = dataset_x
    return FULL_DATASET_X
def create_enhanced_plot(data, title, is_input=True):
    """Render a field as a dark-themed heatmap with a "Vorticity" colorbar.

    The figure is built with the object-oriented ``matplotlib.figure.Figure``
    API instead of ``plt.subplots``. pyplot keeps every figure it creates
    alive in a global registry until ``plt.close`` is called, and nothing in
    this file closes them. On a long-running server, each call would
    therefore leak a figure. A standalone ``Figure`` is garbage-collected
    once the caller drops it.

    Args:
        data: Array accepted by ``Axes.imshow`` (a 2-D field in this app).
        title: Plot title.
        is_input: Selects the colormap: ``'plasma'`` when True (initial
            conditions), ``'viridis'`` otherwise (predictions).

    Returns:
        The populated ``matplotlib.figure.Figure``.
    """
    from matplotlib.figure import Figure

    # Applies the dark style through global rcParams, exactly as before, so the
    # new figure picks up the same text/tick/background defaults.
    plt.style.use('dark_background')
    fig = Figure(figsize=(8, 6), facecolor='#0f0f23')
    ax = fig.subplots()
    ax.set_facecolor('#0f0f23')

    # Colormap distinguishes inputs from predictions at a glance.
    cmap = 'plasma' if is_input else 'viridis'
    im = ax.imshow(data, cmap=cmap, interpolation='bilinear')

    # Title and axis labels
    ax.set_title(title, color='white', fontsize=14, fontweight='bold', pad=20)
    ax.set_xlabel('X Coordinate', color='#8892b0', fontsize=10)
    ax.set_ylabel('Y Coordinate', color='#8892b0', fontsize=10)

    # Colorbar
    cbar = fig.colorbar(im, ax=ax, shrink=0.8, pad=0.02)
    cbar.set_label('Vorticity', color='#8892b0', fontsize=10)
    cbar.ax.tick_params(colors='#8892b0', labelsize=8)

    # Grid and ticks
    ax.tick_params(colors='#8892b0', labelsize=8)
    ax.grid(True, alpha=0.1, color='white')
    fig.tight_layout()
    return fig
def run_inference(sample_index: int, progress=gr.Progress()):
    """Run the FNO on one initial condition and plot input and prediction.

    Everything runs on the CPU. Progress is reported through ``progress`` at
    each stage.

    Args:
        sample_index: Index of the initial condition in the dataset. Gradio
            sliders may deliver this as a float such as ``3.0``, so it is
            coerced to ``int`` before use.
        progress: Gradio progress tracker (supplied by Gradio when this
            function is an event handler).

    Returns:
        ``(fig_input, fig_output)``: figures of the initial condition and of
        the model's prediction.

    Raises:
        gr.Error: If ``sample_index`` is outside ``[0, N)`` for a dataset of
            ``N`` samples, or propagated from model/dataset loading.
    """
    # Tensor slicing rejects float indices, and the plot title should read
    # "Sample 3" rather than "Sample 3.0".
    sample_index = int(sample_index)

    progress(0.1, desc="Loading model...")
    device = torch.device("cpu")
    model = load_model()
    if next(model.parameters()).device != device:
        model.to(device)
        print(f"Model moved to {device} within run_inference.")

    progress(0.3, desc="Loading dataset...")
    dataset = load_dataset()
    num_samples = dataset.shape[0]
    if not (0 <= sample_index < num_samples):
        raise gr.Error(f"Sample index out of range. Please choose between 0 and {num_samples-1}.")

    progress(0.5, desc="Preparing input...")
    # A slice (not a plain index) keeps the leading batch axis; unsqueeze(1)
    # then inserts a singleton channel axis. Assumes the dataset is (N, H, W)
    # so the model sees (1, 1, H, W) — TODO confirm against the stored tensor.
    single_initial_condition = dataset[sample_index:sample_index+1, :, :].unsqueeze(1).to(device)
    print(f"Input moved to {device}.")

    progress(0.7, desc="Running inference...")
    print(f"Running inference for sample index {sample_index}...")
    with torch.no_grad():
        predicted_solution = model(single_initial_condition)

    progress(0.9, desc="Generating visualizations...")
    # squeeze() drops every singleton axis, presumably leaving (H, W) for imshow.
    input_numpy = single_initial_condition.squeeze().cpu().numpy()
    output_numpy = predicted_solution.squeeze().cpu().numpy()
    fig_input = create_enhanced_plot(input_numpy, f"Initial Condition β€’ Sample {sample_index}", is_input=True)
    fig_output = create_enhanced_plot(output_numpy, "Predicted Solution", is_input=False)
    progress(1.0, desc="Complete!")
    return fig_input, fig_output
def get_random_sample():
    """Return a uniformly random valid sample index for exploration.

    Returns:
        A builtin ``int`` in ``[0, N)``, where ``N`` is the number of samples
        in the loaded dataset.
    """
    dataset = load_dataset()
    # np.random.randint yields a NumPy integer scalar. Convert it to a plain
    # int because the value is written back into a UI slider and must
    # serialize cleanly regardless of the JSON encoder in use.
    return int(np.random.randint(0, dataset.shape[0]))
# Custom CSS for professional styling with forced dark mode.
# Passed to gr.Blocks(css=...) when the UI is built. Selectors target the
# elem_id / elem_classes set on components ("#main-container",
# ".markdown-content") and the raw-HTML classes used inside the Markdown
# blocks (".info-panel", ".feature-card", ".status-indicator").
# Most rules use !important, presumably so they win over the theme's own styles.
# NOTE(review): ".gr-button", ".gr-slider" and ".gr-plot" look like class
# names from older Gradio releases; confirm they still match the DOM of the
# installed Gradio version, otherwise those rules silently do nothing.
custom_css = """
/* Force dark mode for all elements */
* {
color-scheme: dark !important;
}
body, html {
background-color: #0f0f23 !important;
color: #e2e8f0 !important;
}
#main-container {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
}
.gradio-container {
background: rgba(15, 15, 35, 0.95) !important;
backdrop-filter: blur(10px);
border-radius: 20px;
border: 1px solid rgba(255, 255, 255, 0.1);
box-shadow: 0 20px 40px rgba(0, 0, 0, 0.3);
color: #e2e8f0 !important;
}
/* Force all text to be light colored */
p, span, div, label, h1, h2, h3, h4, h5, h6 {
color: #e2e8f0 !important;
}
/* Fix link colors for both visited and unvisited */
a, a:link, a:visited, a:hover, a:active {
color: #60a5fa !important;
text-decoration: none !important;
transition: color 0.3s ease !important;
}
a:hover {
color: #93c5fd !important;
text-decoration: underline !important;
}
/* Force markdown content colors */
.markdown-content, .markdown-content * {
color: #e2e8f0 !important;
}
.markdown-content a, .markdown-content a:link, .markdown-content a:visited {
color: #60a5fa !important;
}
.markdown-content a:hover {
color: #93c5fd !important;
}
.gr-button {
background: linear-gradient(45deg, #667eea, #764ba2) !important;
border: none !important;
border-radius: 12px !important;
color: white !important;
font-weight: 600 !important;
padding: 12px 24px !important;
transition: all 0.3s ease !important;
box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4) !important;
}
.gr-button:hover {
transform: translateY(-2px) !important;
box-shadow: 0 6px 20px rgba(102, 126, 234, 0.6) !important;
}
.gr-slider input[type="range"] {
background: linear-gradient(to right, #667eea, #764ba2) !important;
}
.markdown-content h1 {
background: linear-gradient(45deg, #667eea, #764ba2);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
text-align: center;
font-size: 2.5rem !important;
margin-bottom: 1rem !important;
}
.markdown-content h3 {
color: #8892b0 !important;
border-left: 4px solid #667eea;
padding-left: 1rem;
margin-top: 2rem !important;
}
.feature-card {
background: rgba(255, 255, 255, 0.05);
border-radius: 15px;
padding: 1.5rem;
border: 1px solid rgba(255, 255, 255, 0.1);
backdrop-filter: blur(10px);
color: #e2e8f0 !important;
}
.feature-card strong {
color: #93c5fd !important;
}
.status-indicator {
display: inline-block;
width: 8px;
height: 8px;
background: #00ff88;
border-radius: 50%;
margin-right: 8px;
animation: pulse 2s infinite;
}
@keyframes pulse {
0% { opacity: 1; }
50% { opacity: 0.5; }
100% { opacity: 1; }
}
.info-panel {
background: rgba(102, 126, 234, 0.1);
border-left: 4px solid #667eea;
border-radius: 8px;
padding: 1rem;
margin: 1rem 0;
color: #e2e8f0 !important;
}
.info-panel strong {
color: #93c5fd !important;
}
/* Force input and form elements to dark mode */
input, textarea, select {
background-color: rgba(255, 255, 255, 0.1) !important;
color: #e2e8f0 !important;
border: 1px solid rgba(255, 255, 255, 0.2) !important;
}
/* Force plot containers to dark mode */
.gr-plot {
background-color: #0f0f23 !important;
}
"""
# Build the Gradio UI. `demo` is the Blocks app launched in the __main__ guard.
with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="blue", secondary_hue="purple")) as demo:
    # Layout: a hero banner, then one row holding the controls column
    # (scale=1) and the results column (scale=2, i.e. twice as wide).
    with gr.Column(elem_id="main-container"):
        # Hero Section
        gr.Markdown(
            """
# 🌊 Neural Fluid Dynamics
## Advanced Fourier Neural Operator for Navier-Stokes Equations
<div class="info-panel">
<span class="status-indicator"></span><strong>AI-Powered Fluid Simulation</strong> β€’ Experience cutting-edge neural operators that solve complex partial differential equations in milliseconds instead of hours.
</div>
""",
            elem_classes=["markdown-content"]
        )
        with gr.Row():
            # Left Panel - Controls
            with gr.Column(scale=1):
                gr.Markdown(
                    """
### πŸŽ›οΈ Simulation Controls
""",
                    elem_classes=["markdown-content"]
                )
                with gr.Group():
                    # Sample selector. The 0..9999 bound is hard-coded here;
                    # run_inference re-validates against the actual dataset size.
                    sample_input_slider = gr.Slider(
                        minimum=0,
                        maximum=9999,
                        value=0,
                        step=1,
                        label="πŸ“Š Sample Index",
                        info="Choose from 10,000 unique fluid scenarios"
                    )
                    with gr.Row():
                        run_button = gr.Button("πŸš€ Generate Solution", variant="primary", scale=2)
                        random_button = gr.Button("🎲 Random", variant="secondary", scale=1)
                # Information Panel (static text; values are not read from the model)
                gr.Markdown(
                    """
### πŸ“‹ Model Information
<div class="feature-card">
<strong>Architecture:</strong> Fourier Neural Operator (FNO)<br>
<strong>Domain:</strong> 2D Fluid Dynamics<br>
<strong>Resolution:</strong> 64Γ—64 Grid<br>
<strong>Inference Time:</strong> ~50ms<br>
<strong>Training Samples:</strong> 10,000+
</div>
### πŸ”¬ Research Context
This demo showcases the practical applications of **'Principled approaches for extending neural architectures to function spaces for operator learning'** research.
**Key Resources:**
- πŸ“„ [Research Paper](https://arxiv.org/abs/2506.10973)
- πŸ’» [Source Code](https://github.com/neuraloperator/NNs-to-NOs)
- πŸ“Š [Dataset](https://zenodo.org/records/12825163)
""",
                    elem_classes=["markdown-content"]
                )
            # Right Panel - Visualizations
            with gr.Column(scale=2):
                gr.Markdown(
                    """
### πŸ“ˆ Simulation Results
""",
                    elem_classes=["markdown-content"]
                )
                # Side-by-side plots; both are filled by run_inference's
                # (fig_input, fig_output) return value.
                with gr.Row():
                    input_image_plot = gr.Plot(
                        label="πŸŒ€ Initial Vorticity Field",
                        container=True
                    )
                    output_image_plot = gr.Plot(
                        label="⚑ Neural Operator Prediction",
                        container=True
                    )
                gr.Markdown(
                    """
<div class="info-panel">
πŸ’‘ <strong>Pro Tip:</strong> The left plot shows the initial fluid state (vorticity), while the right shows how our neural operator predicts the fluid will evolve. Traditional methods would take hoursβ€”our AI does it instantly!
</div>
""",
                    elem_classes=["markdown-content"]
                )
    # Event handlers
    # "Generate": run inference on the slider's current value.
    run_button.click(
        fn=run_inference,
        inputs=[sample_input_slider],
        outputs=[input_image_plot, output_image_plot]
    )
    # "Random": first write a random index into the slider, then (.then) run
    # inference reading the slider's updated value.
    random_button.click(
        fn=get_random_sample,
        outputs=[sample_input_slider]
    ).then(
        fn=run_inference,
        inputs=[sample_input_slider],
        outputs=[input_image_plot, output_image_plot]
    )

    def load_initial_data_and_predict():
        """Warm the model/dataset caches and return the figures for sample 0."""
        load_model()
        load_dataset()
        return run_inference(0)

    # On page load, populate both plots so they are not empty initially.
    demo.load(
        load_initial_data_and_predict,
        inputs=None,
        outputs=[input_image_plot, output_image_plot]
    )
if __name__ == "__main__":
    # Start the Gradio web server only when this file is executed directly,
    # not when it is imported as a module.
    demo.launch()