ajsbsd commited on
Commit
9979917
·
verified ·
1 Parent(s): bc194f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -34
app.py CHANGED
@@ -1,43 +1,73 @@
1
  import gradio as gr
2
  import torch
3
- from neuralop.models import FNO # Make sure this import path matches your environment
4
  import matplotlib.pyplot as plt
5
  import numpy as np
6
- import os # For path handling
 
 
7
 
8
  # --- Configuration ---
9
- # Set paths relative to the root of your Hugging Face Space repository
10
- MODEL_PATH = "fno_ckpt_single_res"
11
- DATASET_PATH = "navier_stokes_2d.pt" # Ensure this file is in your repo root
 
12
 
13
  # --- Global Variables for Model and Data (loaded once) ---
14
  MODEL = None
15
- FULL_DATASET_X = None # To store all initial conditions
16
-
17
- # --- 1. Model Loading Function ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  def load_model():
19
  """Loads the pre-trained FNO model."""
20
  global MODEL
21
  if MODEL is None:
22
  print("Loading FNO model...")
23
  try:
24
- # Ensure model is loaded to CPU for general compatibility in Spaces
25
  MODEL = torch.load(MODEL_PATH, weights_only=False, map_location='cpu')
26
- MODEL.eval() # Set to evaluation mode
27
  print("Model loaded successfully.")
28
  except Exception as e:
29
  print(f"Error loading model: {e}")
30
  raise gr.Error(f"Failed to load model: {e}")
31
  return MODEL
32
 
33
- # --- 2. Dataset Loading Function ---
34
  def load_dataset():
35
- """Loads the initial conditions dataset."""
36
  global FULL_DATASET_X
37
  if FULL_DATASET_X is None:
38
- print("Loading dataset...")
 
39
  try:
40
- data = torch.load(DATASET_PATH, map_location='cpu')
41
  if isinstance(data, dict) and 'x' in data:
42
  FULL_DATASET_X = data['x']
43
  elif isinstance(data, torch.Tensor):
@@ -45,54 +75,47 @@ def load_dataset():
45
  else:
46
  raise ValueError("Unknown dataset format or 'x' key missing.")
47
  print(f"Dataset loaded. Total samples: {FULL_DATASET_X.shape[0]}")
48
- except FileNotFoundError:
49
- print(f"Dataset file not found at {DATASET_PATH}")
50
- raise gr.Error(f"Dataset file not found. Please ensure '{DATASET_PATH}' is in your Space.")
51
  except Exception as e:
52
  print(f"Error loading dataset: {e}")
53
- raise gr.Error(f"Failed to load dataset: {e}")
54
  return FULL_DATASET_X
55
 
56
- # --- 3. Inference Function for Gradio ---
57
  def run_inference(sample_index: int):
58
  """
59
  Performs inference for a selected sample index from the dataset.
60
  Returns two Matplotlib figures: one for input, one for output.
61
  """
62
  model = load_model()
63
- dataset = load_dataset()
64
 
65
  if not (0 <= sample_index < dataset.shape[0]):
66
  raise gr.Error(f"Sample index out of range. Please choose between 0 and {dataset.shape[0]-1}.")
67
 
68
- # Extract single initial condition and add channel dimension
69
- # (shape: [1, H, W] -> [1, 1, H, W])
70
  single_initial_condition = dataset[sample_index:sample_index+1, :, :].unsqueeze(1)
71
 
72
  print(f"Running inference for sample index {sample_index}...")
73
  with torch.no_grad():
74
  predicted_solution = model(single_initial_condition)
75
 
76
- # Convert tensors to numpy for plotting
77
  input_numpy = single_initial_condition.squeeze().cpu().numpy()
78
  output_numpy = predicted_solution.squeeze().cpu().numpy()
79
 
80
- # Create Matplotlib figures
81
  fig_input, ax_input = plt.subplots()
82
  im_input = ax_input.imshow(input_numpy, cmap='viridis')
83
  ax_input.set_title(f"Initial Condition (Sample {sample_index})")
84
  fig_input.colorbar(im_input, ax=ax_input, label="Vorticity")
85
- plt.close(fig_input) # Close to prevent display issues in Gradio
86
 
87
  fig_output, ax_output = plt.subplots()
88
  im_output = ax_output.imshow(output_numpy, cmap='viridis')
89
  ax_output.set_title(f"Predicted Solution")
90
  fig_output.colorbar(im_output, ax=ax_output, label="Vorticity")
91
- plt.close(fig_output) # Close to prevent display issues in Gradio
92
 
93
  return fig_input, fig_output
94
 
95
- # --- Gradio Interface Setup ---
96
  with gr.Blocks() as demo:
97
  gr.Markdown(
98
  """
@@ -104,9 +127,11 @@ with gr.Blocks() as demo:
104
 
105
  with gr.Row():
106
  with gr.Column():
 
 
107
  sample_input_slider = gr.Slider(
108
  minimum=0,
109
- maximum=9999, # Assuming 10,000 samples based on your dataset shape
110
  value=0,
111
  step=1,
112
  label="Select Sample Index"
@@ -116,23 +141,18 @@ with gr.Blocks() as demo:
116
  input_image_plot = gr.Plot(label="Selected Initial Condition")
117
  output_image_plot = gr.Plot(label="Predicted Solution")
118
 
119
- # Bind the button click to the inference function
120
  run_button.click(
121
  fn=run_inference,
122
  inputs=[sample_input_slider],
123
  outputs=[input_image_plot, output_image_plot]
124
  )
125
 
126
- # Optional: Load initial data on startup for the first display
127
  def load_initial_data_and_predict():
128
- # Ensure model and dataset are loaded when the space starts
129
  load_model()
130
- load_dataset()
131
- # Run inference for the default value (index 0)
132
  return run_inference(0)
133
 
134
  demo.load(load_initial_data_and_predict, inputs=None, outputs=[input_image_plot, output_image_plot])
135
 
136
- # Launch the Gradio app (only runs when you test locally)
137
  if __name__ == "__main__":
138
  demo.launch()
 
1
  import gradio as gr
2
  import torch
3
+ from neuralop.models import FNO
4
  import matplotlib.pyplot as plt
5
  import numpy as np
6
+ import os
7
+ import requests # <--- ADD THIS IMPORT for downloading files
8
+ from tqdm import tqdm # Optional: for a progress bar during download
9
 
10
  # --- Configuration ---
11
+ MODEL_PATH = "fno_ckpt_single_res" # This model file still needs to be in your repo
12
+ # Zenodo direct download URL for the Navier-Stokes 2D dataset
13
+ DATASET_URL = "https://zenodo.org/record/12825163/files/navier_stokes_2d.pt?download=1"
14
+ LOCAL_DATASET_PATH = "navier_stokes_2d.pt" # Where the file will be saved locally in the Space
15
 
16
  # --- Global Variables for Model and Data (loaded once) ---
17
  MODEL = None
18
+ FULL_DATASET_X = None
19
+
20
+ # --- Function to Download Dataset ---
21
+ def download_file(url, local_filename):
22
+ """Downloads a file from a URL to a local path with a progress bar."""
23
+ if os.path.exists(local_filename):
24
+ print(f"{local_filename} already exists. Skipping download.")
25
+ return
26
+
27
+ print(f"Downloading {url} to {local_filename}...")
28
+ try:
29
+ response = requests.get(url, stream=True)
30
+ response.raise_for_status() # Raise an HTTPError for bad responses (4xx or 5xx)
31
+
32
+ total_size = int(response.headers.get('content-length', 0))
33
+ block_size = 1024 # 1 KB
34
+
35
+ with open(local_filename, 'wb') as f:
36
+ with tqdm(total=total_size, unit='iB', unit_scale=True, desc=local_filename) as pbar:
37
+ for chunk in response.iter_content(chunk_size=block_size):
38
+ if chunk:
39
+ f.write(chunk)
40
+ pbar.update(len(chunk))
41
+ print(f"Downloaded {local_filename} successfully.")
42
+ except requests.exceptions.RequestException as e:
43
+ print(f"Error downloading file: {e}")
44
+ raise gr.Error(f"Failed to download dataset from Zenodo: {e}")
45
+
46
+
47
+ # --- 1. Model Loading Function (No change here for model) ---
48
  def load_model():
49
  """Loads the pre-trained FNO model."""
50
  global MODEL
51
  if MODEL is None:
52
  print("Loading FNO model...")
53
  try:
 
54
  MODEL = torch.load(MODEL_PATH, weights_only=False, map_location='cpu')
55
+ MODEL.eval()
56
  print("Model loaded successfully.")
57
  except Exception as e:
58
  print(f"Error loading model: {e}")
59
  raise gr.Error(f"Failed to load model: {e}")
60
  return MODEL
61
 
62
+ # --- 2. Dataset Loading Function (MODIFIED) ---
63
  def load_dataset():
64
+ """Downloads and loads the initial conditions dataset."""
65
  global FULL_DATASET_X
66
  if FULL_DATASET_X is None:
67
+ download_file(DATASET_URL, LOCAL_DATASET_PATH) # <--- Download here!
68
+ print("Loading dataset from local file...")
69
  try:
70
+ data = torch.load(LOCAL_DATASET_PATH, map_location='cpu')
71
  if isinstance(data, dict) and 'x' in data:
72
  FULL_DATASET_X = data['x']
73
  elif isinstance(data, torch.Tensor):
 
75
  else:
76
  raise ValueError("Unknown dataset format or 'x' key missing.")
77
  print(f"Dataset loaded. Total samples: {FULL_DATASET_X.shape[0]}")
 
 
 
78
  except Exception as e:
79
  print(f"Error loading dataset: {e}")
80
+ raise gr.Error(f"Failed to load dataset from local file: {e}")
81
  return FULL_DATASET_X
82
 
83
+ # --- 3. Inference Function for Gradio (No change) ---
84
  def run_inference(sample_index: int):
85
  """
86
  Performs inference for a selected sample index from the dataset.
87
  Returns two Matplotlib figures: one for input, one for output.
88
  """
89
  model = load_model()
90
+ dataset = load_dataset() # This will trigger download and load if not already done
91
 
92
  if not (0 <= sample_index < dataset.shape[0]):
93
  raise gr.Error(f"Sample index out of range. Please choose between 0 and {dataset.shape[0]-1}.")
94
 
 
 
95
  single_initial_condition = dataset[sample_index:sample_index+1, :, :].unsqueeze(1)
96
 
97
  print(f"Running inference for sample index {sample_index}...")
98
  with torch.no_grad():
99
  predicted_solution = model(single_initial_condition)
100
 
 
101
  input_numpy = single_initial_condition.squeeze().cpu().numpy()
102
  output_numpy = predicted_solution.squeeze().cpu().numpy()
103
 
 
104
  fig_input, ax_input = plt.subplots()
105
  im_input = ax_input.imshow(input_numpy, cmap='viridis')
106
  ax_input.set_title(f"Initial Condition (Sample {sample_index})")
107
  fig_input.colorbar(im_input, ax=ax_input, label="Vorticity")
108
+ plt.close(fig_input)
109
 
110
  fig_output, ax_output = plt.subplots()
111
  im_output = ax_output.imshow(output_numpy, cmap='viridis')
112
  ax_output.set_title(f"Predicted Solution")
113
  fig_output.colorbar(im_output, ax=ax_output, label="Vorticity")
114
+ plt.close(fig_output)
115
 
116
  return fig_input, fig_output
117
 
118
+ # --- Gradio Interface Setup (No change) ---
119
  with gr.Blocks() as demo:
120
  gr.Markdown(
121
  """
 
127
 
128
  with gr.Row():
129
  with gr.Column():
130
+ # Max value can be dynamic based on dataset size if needed,
131
+ # but 9999 for 10,000 samples is correct.
132
  sample_input_slider = gr.Slider(
133
  minimum=0,
134
+ maximum=9999,
135
  value=0,
136
  step=1,
137
  label="Select Sample Index"
 
141
  input_image_plot = gr.Plot(label="Selected Initial Condition")
142
  output_image_plot = gr.Plot(label="Predicted Solution")
143
 
 
144
  run_button.click(
145
  fn=run_inference,
146
  inputs=[sample_input_slider],
147
  outputs=[input_image_plot, output_image_plot]
148
  )
149
 
 
150
  def load_initial_data_and_predict():
 
151
  load_model()
152
+ load_dataset() # This will now download if not present
 
153
  return run_inference(0)
154
 
155
  demo.load(load_initial_data_and_predict, inputs=None, outputs=[input_image_plot, output_image_plot])
156
 
 
157
  if __name__ == "__main__":
158
  demo.launch()