import gradio as gr import torch from neuralop.models import FNO import matplotlib.pyplot as plt import numpy as np import os from huggingface_hub import hf_hub_download MODEL_PATH = "fno_ckpt_single_res" HF_DATASET_REPO_ID = "ajsbsd/navier-stokes-2d-dataset" HF_DATASET_FILENAME = "navier_stokes_2d.pt" MODEL = None FULL_DATASET_X = None def download_file_from_hf_hub(repo_id, filename): """Downloads a file from Hugging Face Hub.""" print(f"Downloading {filename} from {repo_id} on Hugging Face Hub...") try: local_path = hf_hub_download(repo_id=repo_id, filename=filename) print(f"Downloaded {filename} to {local_path} successfully.") return local_path except Exception as e: print(f"Error downloading file from HF Hub: {e}") raise gr.Error(f"Failed to download dataset from Hugging Face Hub: {e}") def load_model(): """Loads the pre-trained FNO model to CPU.""" global MODEL if MODEL is None: print("Loading FNO model to CPU...") try: MODEL = torch.load(MODEL_PATH, weights_only=False, map_location='cpu') MODEL.eval() print("Model loaded successfully to CPU.") except Exception as e: print(f"Error loading model: {e}") raise gr.Error(f"Failed to load model: {e}") return MODEL def load_dataset(): """Downloads and loads the initial conditions dataset from HF Hub.""" global FULL_DATASET_X if FULL_DATASET_X is None: local_dataset_path = download_file_from_hf_hub(HF_DATASET_REPO_ID, HF_DATASET_FILENAME) print("Loading dataset from local file...") try: data = torch.load(local_dataset_path, map_location='cpu') if isinstance(data, dict) and 'x' in data: FULL_DATASET_X = data['x'] elif isinstance(data, torch.Tensor): FULL_DATASET_X = data else: raise ValueError("Unknown dataset format or 'x' key missing.") print(f"Dataset loaded. Total samples: {FULL_DATASET_X.shape[0]}") except Exception as e: print(f"Error loading dataset: {e}") raise gr.Error(f"Failed to load dataset from local file: {e}") return FULL_DATASET_X def create_enhanced_plot(data, title, is_input=True): """Creates enhanced matplotlib plots with professional styling.""" plt.style.use('dark_background') fig, ax = plt.subplots(figsize=(8, 6), facecolor='#0f0f23') ax.set_facecolor('#0f0f23') # Enhanced colormap and visualization cmap = 'plasma' if is_input else 'viridis' im = ax.imshow(data, cmap=cmap, interpolation='bilinear') # Styling ax.set_title(title, color='white', fontsize=14, fontweight='bold', pad=20) ax.set_xlabel('X Coordinate', color='#8892b0', fontsize=10) ax.set_ylabel('Y Coordinate', color='#8892b0', fontsize=10) # Enhanced colorbar cbar = fig.colorbar(im, ax=ax, shrink=0.8, pad=0.02) cbar.set_label('Vorticity', color='#8892b0', fontsize=10) cbar.ax.tick_params(colors='#8892b0', labelsize=8) # Grid and ticks ax.tick_params(colors='#8892b0', labelsize=8) ax.grid(True, alpha=0.1, color='white') plt.tight_layout() return fig def run_inference(sample_index: int, progress=gr.Progress()): """ Performs inference for a selected sample index from the dataset on CPU. Returns two enhanced Matplotlib figures with progress tracking. """ progress(0.1, desc="Loading model...") device = torch.device("cpu") model = load_model() if next(model.parameters()).device != device: model.to(device) print(f"Model moved to {device} within run_inference.") progress(0.3, desc="Loading dataset...") dataset = load_dataset() if not (0 <= sample_index < dataset.shape[0]): raise gr.Error(f"Sample index out of range. Please choose between 0 and {dataset.shape[0]-1}.") progress(0.5, desc="Preparing input...") single_initial_condition = dataset[sample_index:sample_index+1, :, :].unsqueeze(1).to(device) print(f"Input moved to {device}.") progress(0.7, desc="Running inference...") print(f"Running inference for sample index {sample_index}...") with torch.no_grad(): predicted_solution = model(single_initial_condition) progress(0.9, desc="Generating visualizations...") input_numpy = single_initial_condition.squeeze().cpu().numpy() output_numpy = predicted_solution.squeeze().cpu().numpy() fig_input = create_enhanced_plot(input_numpy, f"Initial Condition • Sample {sample_index}", is_input=True) fig_output = create_enhanced_plot(output_numpy, "Predicted Solution", is_input=False) progress(1.0, desc="Complete!") return fig_input, fig_output def get_random_sample(): """Returns a random sample index for exploration.""" dataset = load_dataset() return np.random.randint(0, dataset.shape[0]) # Custom CSS for professional styling with forced dark mode custom_css = """ /* Force dark mode for all elements */ * { color-scheme: dark !important; } body, html { background-color: #0f0f23 !important; color: #e2e8f0 !important; } #main-container { background: linear-gradient(135deg, #1a1d1f 0%, #764ba2 100%); min-height: 100vh; } .gradio-container { background: rgba(15, 15, 35, 0.95) !important; backdrop-filter: blur(10px); border-radius: 20px; border: 1px solid rgba(255, 255, 255, 0.1); box-shadow: 0 20px 40px rgba(0, 0, 0, 0.3); color: #e2e8f0 !important; } /* Force all text to be light colored */ p, span, div, label, h1, h2, h3, h4, h5, h6 { color: #e2e8f0 !important; } /* Fix link colors for both visited and unvisited */ a, a:link, a:visited, a:hover, a:active { color: #60a5fa !important; text-decoration: none !important; transition: color 0.3s ease !important; } a:hover { color: #93c5fd !important; text-decoration: underline !important; } /* Force markdown content colors */ .markdown-content, .markdown-content * { color: #e2e8f0 !important; } .markdown-content a, .markdown-content a:link, .markdown-content a:visited { color: #60a5fa !important; } .markdown-content a:hover { color: #93c5fd !important; } .gr-button { background: linear-gradient(45deg, #667eea, #764ba2) !important; border: none !important; border-radius: 12px !important; color: white !important; font-weight: 600 !important; padding: 12px 24px !important; transition: all 0.3s ease !important; box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4) !important; } .gr-button:hover { transform: translateY(-2px) !important; box-shadow: 0 6px 20px rgba(102, 126, 234, 0.6) !important; } .gr-slider input[type="range"] { background: linear-gradient(to right, #667eea, #764ba2) !important; } .markdown-content h1 { background: linear-gradient(45deg, #667eea, #764ba2); -webkit-background-clip: text; -webkit-text-fill-color: transparent; text-align: center; font-size: 2.5rem !important; margin-bottom: 1rem !important; } .markdown-content h3 { color: #8892b0 !important; border-left: 4px solid #667eea; padding-left: 1rem; margin-top: 2rem !important; } .feature-card { background: rgba(255, 255, 255, 0.05); border-radius: 15px; padding: 1.5rem; border: 1px solid rgba(255, 255, 255, 0.1); backdrop-filter: blur(10px); color: #e2e8f0 !important; } .feature-card strong { color: #93c5fd !important; } .status-indicator { display: inline-block; width: 8px; height: 8px; background: #00ff88; border-radius: 50%; margin-right: 8px; animation: pulse 2s infinite; } @keyframes pulse { 0% { opacity: 1; } 50% { opacity: 0.5; } 100% { opacity: 1; } } .info-panel { background: rgba(102, 126, 234, 0.1); border-left: 4px solid #667eea; border-radius: 8px; padding: 1rem; margin: 1rem 0; color: #e2e8f0 !important; } .info-panel strong { color: #93c5fd !important; } /* Force input and form elements to dark mode */ input, textarea, select { background-color: rgba(255, 255, 255, 0.1) !important; color: #e2e8f0 !important; border: 1px solid rgba(255, 255, 255, 0.2) !important; } /* Force plot containers to dark mode */ .gr-plot { background-color: #0f0f23 !important; } """ with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="blue", secondary_hue="purple")) as demo: with gr.Column(elem_id="main-container"): # Hero Section gr.Markdown( """ # 🌊 Neural Fluid Dynamics ## Advanced Fourier Neural Operator for Navier-Stokes Equations