import gradio as gr
import torch
# NOTE(review): FNO is not referenced directly below; presumably imported for the
# pickled model object stored at MODEL_PATH — confirm before removing.
from neuralop.models import FNO
import matplotlib.pyplot as plt
import numpy as np
import os
from huggingface_hub import hf_hub_download

MODEL_PATH = "fno_ckpt_single_res"
HF_DATASET_REPO_ID = "ajsbsd/navier-stokes-2d-dataset"
HF_DATASET_FILENAME = "navier_stokes_2d.pt"

# Lazily populated caches, filled on first use by load_model() / load_dataset().
MODEL = None
FULL_DATASET_X = None


def download_file_from_hf_hub(repo_id, filename, repo_type=None):
    """Download ``filename`` from the Hugging Face Hub repo ``repo_id``.

    Args:
        repo_id: Hub repository id, e.g. ``"user/name"``.
        filename: Path of the file inside the repository.
        repo_type: Forwarded to ``hf_hub_download``. ``None`` (the default)
            keeps the library default; pass ``"dataset"`` for dataset repos.

    Returns:
        The local filesystem path returned by ``hf_hub_download``.

    Raises:
        gr.Error: if the download fails for any reason.
    """
    print(f"Downloading {filename} from {repo_id} on Hugging Face Hub...")
    try:
        local_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type=repo_type)
    except Exception as e:
        print(f"Error downloading file from HF Hub: {e}")
        raise gr.Error(f"Failed to download dataset from Hugging Face Hub: {e}") from e
    print(f"Downloaded {filename} to {local_path} successfully.")
    return local_path


def load_model():
    """Return the cached FNO model, loading it onto the CPU on first call.

    The model is only cached once it has loaded and been switched to eval
    mode, so a failed load is retried on the next call.

    Raises:
        gr.Error: if the checkpoint at MODEL_PATH cannot be loaded.
    """
    global MODEL
    if MODEL is None:
        print("Loading FNO model to CPU...")
        try:
            # weights_only=False unpickles arbitrary Python objects; this is
            # only safe because MODEL_PATH must point at a trusted checkpoint.
            model = torch.load(MODEL_PATH, weights_only=False, map_location='cpu')
            model.eval()
        except Exception as e:
            print(f"Error loading model: {e}")
            raise gr.Error(f"Failed to load model: {e}") from e
        MODEL = model
        print("Model loaded successfully to CPU.")
    return MODEL


def load_dataset():
    """Return the cached initial-conditions tensor, downloading it on first call.

    The downloaded file may contain either a bare tensor or a dict holding
    the tensor under the ``'x'`` key. The result is only cached once it has
    been validated.

    Raises:
        gr.Error: if the download fails or the file has an unexpected format.
    """
    global FULL_DATASET_X
    if FULL_DATASET_X is None:
        # NOTE(review): hf_hub_download defaults to a *model* repo; if
        # HF_DATASET_REPO_ID is a dataset repo, pass repo_type="dataset".
        local_dataset_path = download_file_from_hf_hub(HF_DATASET_REPO_ID, HF_DATASET_FILENAME)
        print("Loading dataset from local file...")
        try:
            data = torch.load(local_dataset_path, map_location='cpu')
            if isinstance(data, dict) and 'x' in data:
                dataset_x = data['x']
            elif isinstance(data, torch.Tensor):
                dataset_x = data
            else:
                raise ValueError("Unknown dataset format or 'x' key missing.")
            # Also validates that the loaded object exposes a .shape.
            print(f"Dataset loaded. Total samples: {dataset_x.shape[0]}")
        except Exception as e:
            print(f"Error loading dataset: {e}")
            raise gr.Error(f"Failed to load dataset from local file: {e}") from e
        FULL_DATASET_X = dataset_x
    return FULL_DATASET_X


def create_enhanced_plot(data, title, is_input=True):
    """Render a 2-D array as a dark-themed Matplotlib heatmap.

    Args:
        data: 2-D array-like passed to ``imshow``.
        title: Plot title.
        is_input: Selects the ``plasma`` colormap when True, ``viridis`` otherwise.

    Returns:
        The Matplotlib ``Figure``. It has already been detached from pyplot's
        figure registry, so repeated calls do not accumulate open figures.
    """
    plt.style.use('dark_background')
    fig, ax = plt.subplots(figsize=(8, 6), facecolor='#0f0f23')
    ax.set_facecolor('#0f0f23')

    # Enhanced colormap and visualization
    cmap = 'plasma' if is_input else 'viridis'
    im = ax.imshow(data, cmap=cmap, interpolation='bilinear')

    # Styling
    ax.set_title(title, color='white', fontsize=14, fontweight='bold', pad=20)
    ax.set_xlabel('X Coordinate', color='#8892b0', fontsize=10)
    ax.set_ylabel('Y Coordinate', color='#8892b0', fontsize=10)

    # Enhanced colorbar
    cbar = fig.colorbar(im, ax=ax, shrink=0.8, pad=0.02)
    cbar.set_label('Vorticity', color='#8892b0', fontsize=10)
    cbar.ax.tick_params(colors='#8892b0', labelsize=8)

    # Grid and ticks
    ax.tick_params(colors='#8892b0', labelsize=8)
    ax.grid(True, alpha=0.1, color='white')

    fig.tight_layout()
    # pyplot keeps every figure it creates alive until closed; in a
    # long-running server that leaks memory on every request. Closing only
    # unregisters it -- the Figure can still be rendered by Gradio.
    plt.close(fig)
    return fig


def run_inference(sample_index: int, progress=gr.Progress()):
    """Run the FNO on one initial condition from the dataset (CPU only).

    Args:
        sample_index: Dataset row to use. Gradio sliders may deliver this as
            a float, so it is coerced to ``int`` before indexing.
        progress: Progress tracker injected by Gradio.

    Returns:
        ``(fig_input, fig_output)``: figures for the initial condition and
        the model's predicted solution.

    Raises:
        gr.Error: if the index is out of range, or if the model or dataset
            cannot be loaded.
    """
    sample_index = int(sample_index)

    progress(0.1, desc="Loading model...")
    device = torch.device("cpu")
    model = load_model()
    if next(model.parameters()).device != device:
        model.to(device)
        print(f"Model moved to {device} within run_inference.")

    progress(0.3, desc="Loading dataset...")
    dataset = load_dataset()
    if not (0 <= sample_index < dataset.shape[0]):
        raise gr.Error(
            f"Sample index out of range. Please choose between 0 and {dataset.shape[0]-1}."
        )

    progress(0.5, desc="Preparing input...")
    # Slicing keeps a batch dim of size 1; unsqueeze(1) inserts a channel dim.
    # NOTE(review): assumes dataset rows are 2-D fields, i.e. (N, H, W) — confirm.
    single_initial_condition = dataset[sample_index:sample_index + 1, :, :].unsqueeze(1).to(device)
    print(f"Input moved to {device}.")

    progress(0.7, desc="Running inference...")
    print(f"Running inference for sample index {sample_index}...")
    with torch.no_grad():
        predicted_solution = model(single_initial_condition)

    progress(0.9, desc="Generating visualizations...")
    input_numpy = single_initial_condition.squeeze().cpu().numpy()
    output_numpy = predicted_solution.squeeze().cpu().numpy()

    fig_input = create_enhanced_plot(input_numpy, f"Initial Condition • Sample {sample_index}", is_input=True)
    fig_output = create_enhanced_plot(output_numpy, "Predicted Solution", is_input=False)

    progress(1.0, desc="Complete!")
    return fig_input, fig_output


def get_random_sample():
    """Return a uniformly random valid dataset index as a plain Python int."""
    dataset = load_dataset()
    # Convert from numpy int64 so Gradio receives a JSON-native value.
    return int(np.random.randint(0, dataset.shape[0]))


# Custom CSS for professional styling with forced dark mode
custom_css = """
/* Force dark mode for all elements */
* {
    color-scheme: dark !important;
}

body, html {
    background-color: #0f0f23 !important;
    color: #e2e8f0 !important;
}

#main-container {
    background: linear-gradient(135deg, #1a1d1f 0%, #764ba2 100%);
    min-height: 100vh;
}

.gradio-container {
    background: rgba(15, 15, 35, 0.95) !important;
    backdrop-filter: blur(10px);
    border-radius: 20px;
    border: 1px solid rgba(255, 255, 255, 0.1);
    box-shadow: 0 20px 40px rgba(0, 0, 0, 0.3);
    color: #e2e8f0 !important;
}

/* Force all text to be light colored */
p, span, div, label, h1, h2, h3, h4, h5, h6 {
    color: #e2e8f0 !important;
}

/* Fix link colors for both visited and unvisited */
a, a:link, a:visited, a:hover, a:active {
    color: #60a5fa !important;
    text-decoration: none !important;
    transition: color 0.3s ease !important;
}

a:hover {
    color: #93c5fd !important;
    text-decoration: underline !important;
}

/* Force markdown content colors */
.markdown-content,
.markdown-content * {
    color: #e2e8f0 !important;
}

.markdown-content a,
.markdown-content a:link,
.markdown-content a:visited {
    color: #60a5fa !important;
}

.markdown-content a:hover {
    color: #93c5fd !important;
}

.gr-button {
    background: linear-gradient(45deg, #667eea, #764ba2) !important;
    border: none !important;
    border-radius: 12px !important;
    color: white !important;
    font-weight: 600 !important;
    padding: 12px 24px !important;
    transition: all 0.3s ease !important;
    box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4) !important;
}

.gr-button:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 6px 20px rgba(102, 126, 234, 0.6) !important;
}

.gr-slider input[type="range"] {
    background: linear-gradient(to right, #667eea, #764ba2) !important;
}

.markdown-content h1 {
    background: linear-gradient(45deg, #667eea, #764ba2);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    text-align: center;
    font-size: 2.5rem !important;
    margin-bottom: 1rem !important;
}

.markdown-content h3 {
    color: #8892b0 !important;
    border-left: 4px solid #667eea;
    padding-left: 1rem;
    margin-top: 2rem !important;
}

.feature-card {
    background: rgba(255, 255, 255, 0.05);
    border-radius: 15px;
    padding: 1.5rem;
    border: 1px solid rgba(255, 255, 255, 0.1);
    backdrop-filter: blur(10px);
    color: #e2e8f0 !important;
}

.feature-card strong {
    color: #93c5fd !important;
}

.status-indicator {
    display: inline-block;
    width: 8px;
    height: 8px;
    background: #00ff88;
    border-radius: 50%;
    margin-right: 8px;
    animation: pulse 2s infinite;
}

@keyframes pulse {
    0% { opacity: 1; }
    50% { opacity: 0.5; }
    100% { opacity: 1; }
}

.info-panel {
    background: rgba(102, 126, 234, 0.1);
    border-left: 4px solid #667eea;
    border-radius: 8px;
    padding: 1rem;
    margin: 1rem 0;
    color: #e2e8f0 !important;
}

.info-panel strong {
    color: #93c5fd !important;
}

/* Force input and form elements to dark mode */
input, textarea, select {
    background-color: rgba(255, 255, 255, 0.1) !important;
    color: #e2e8f0 !important;
    border: 1px solid rgba(255, 255, 255, 0.2) !important;
}

/* Force plot containers to dark mode */
.gr-plot {
    background-color: #0f0f23 !important;
}
"""

with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="blue", secondary_hue="purple")) as demo:
    with gr.Column(elem_id="main-container"):
        # Hero Section
        gr.Markdown(
            """
# 🌊 Neural Fluid Dynamics
## Advanced Fourier Neural Operator for Navier-Stokes Equations

AI-Powered Fluid Simulation • Experience cutting-edge neural operators that solve complex partial differential equations in milliseconds instead of hours.
""",
            elem_classes=["markdown-content"],
        )

        with gr.Row():
            # Left Panel - Controls
            with gr.Column(scale=1):
                gr.Markdown(
                    """
### 🎛️ Simulation Controls
""",
                    elem_classes=["markdown-content"],
                )

                with gr.Group():
                    sample_input_slider = gr.Slider(
                        minimum=0,
                        maximum=9999,
                        value=0,
                        step=1,
                        label="📊 Sample Index",
                        info="Choose from 10,000 unique fluid scenarios",
                    )
                    with gr.Row():
                        run_button = gr.Button("🚀 Generate Solution", variant="primary", scale=2)
                        random_button = gr.Button("🎲 Random", variant="secondary", scale=1)

                # Information Panel
                gr.Markdown(
                    """
### 📋 Model Information

- **Architecture:** Fourier Neural Operator (FNO)
- **Domain:** 2D Fluid Dynamics
- **Resolution:** 64×64 Grid
- **Inference Time:** ~50ms
- **Training Samples:** 10,000+

### 🔬 Research Context

This demo showcases the practical applications of **'Principled approaches for extending neural architectures to function spaces for operator learning'** research.

**Key Resources:**
- 📄 [Research Paper](https://arxiv.org/abs/2506.10973)
- 💻 [Source Code](https://github.com/neuraloperator/NNs-to-NOs)
- 📊 [Dataset](https://zenodo.org/records/12825163)
""",
                    elem_classes=["markdown-content"],
                )

            # Right Panel - Visualizations
            with gr.Column(scale=2):
                gr.Markdown(
                    """
### 📈 Simulation Results
""",
                    elem_classes=["markdown-content"],
                )

                with gr.Row():
                    input_image_plot = gr.Plot(
                        label="🌀 Initial Vorticity Field",
                        container=True,
                    )
                    output_image_plot = gr.Plot(
                        label="⚡ Neural Operator Prediction",
                        container=True,
                    )

                gr.Markdown(
                    """
💡 **Pro Tip:** The left plot shows the initial fluid state (vorticity), while the right shows how our neural operator predicts the fluid will evolve. Traditional methods would take hours—our AI does it instantly!
""",
                    elem_classes=["markdown-content"],
                )

    # Event handlers (must be registered inside the Blocks context)
    run_button.click(
        fn=run_inference,
        inputs=[sample_input_slider],
        outputs=[input_image_plot, output_image_plot],
    )

    # Pick a random index, push it into the slider, then run on that index.
    random_button.click(
        fn=get_random_sample,
        outputs=[sample_input_slider],
    ).then(
        fn=run_inference,
        inputs=[sample_input_slider],
        outputs=[input_image_plot, output_image_plot],
    )

    def load_initial_data_and_predict():
        """Warm both caches on page load and show the prediction for sample 0."""
        load_model()
        load_dataset()
        return run_inference(0)

    demo.load(
        load_initial_data_and_predict,
        inputs=None,
        outputs=[input_image_plot, output_image_plot],
    )

if __name__ == "__main__":
    demo.launch()