ajsbsd commited on
Commit
b5b786d
·
verified ·
1 Parent(s): 595b768

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -16
app.py CHANGED
@@ -4,14 +4,15 @@ from neuralop.models import FNO
4
  import matplotlib.pyplot as plt
5
  import numpy as np
6
  import os
7
- import requests # <--- ADD THIS IMPORT for downloading files
8
- from tqdm import tqdm # Optional: for a progress bar during download
 
 
9
 
10
  # --- Configuration ---
11
- MODEL_PATH = "fno_ckpt_single_res" # This model file still needs to be in your repo
12
- # Zenodo direct download URL for the Navier-Stokes 2D dataset
13
  DATASET_URL = "https://zenodo.org/record/12825163/files/navier_stokes_2d.pt?download=1"
14
- LOCAL_DATASET_PATH = "navier_stokes_2d.pt" # Where the file will be saved locally in the Space
15
 
16
  # --- Global Variables for Model and Data (loaded once) ---
17
  MODEL = None
@@ -27,10 +28,10 @@ def download_file(url, local_filename):
27
  print(f"Downloading {url} to {local_filename}...")
28
  try:
29
  response = requests.get(url, stream=True)
30
- response.raise_for_status() # Raise an HTTPError for bad responses (4xx or 5xx)
31
 
32
  total_size = int(response.headers.get('content-length', 0))
33
- block_size = 1024 # 1 KB
34
 
35
  with open(local_filename, 'wb') as f:
36
  with tqdm(total=total_size, unit='iB', unit_scale=True, desc=local_filename) as pbar:
@@ -43,15 +44,21 @@ def download_file(url, local_filename):
43
  print(f"Error downloading file: {e}")
44
  raise gr.Error(f"Failed to download dataset from Zenodo: {e}")
45
 
46
-
47
- # --- 1. Model Loading Function (No change here for model) ---
48
  def load_model():
49
  """Loads the pre-trained FNO model."""
50
  global MODEL
51
  if MODEL is None:
52
  print("Loading FNO model...")
53
  try:
 
54
  MODEL = torch.load(MODEL_PATH, weights_only=False, map_location='cpu')
 
 
 
 
 
 
55
  MODEL.eval()
56
  print("Model loaded successfully.")
57
  except Exception as e:
@@ -59,12 +66,12 @@ def load_model():
59
  raise gr.Error(f"Failed to load model: {e}")
60
  return MODEL
61
 
62
- # --- 2. Dataset Loading Function (MODIFIED) ---
63
  def load_dataset():
64
  """Downloads and loads the initial conditions dataset."""
65
  global FULL_DATASET_X
66
  if FULL_DATASET_X is None:
67
- download_file(DATASET_URL, LOCAL_DATASET_PATH) # <--- Download here!
68
  print("Loading dataset from local file...")
69
  try:
70
  data = torch.load(LOCAL_DATASET_PATH, map_location='cpu')
@@ -80,27 +87,40 @@ def load_dataset():
80
  raise gr.Error(f"Failed to load dataset from local file: {e}")
81
  return FULL_DATASET_X
82
 
83
- # --- 3. Inference Function for Gradio (No change) ---
 
84
  def run_inference(sample_index: int):
85
  """
86
  Performs inference for a selected sample index from the dataset.
87
  Returns two Matplotlib figures: one for input, one for output.
88
  """
89
  model = load_model()
90
- dataset = load_dataset() # This will trigger download and load if not already done
91
 
92
  if not (0 <= sample_index < dataset.shape[0]):
93
  raise gr.Error(f"Sample index out of range. Please choose between 0 and {dataset.shape[0]-1}.")
94
 
 
 
95
  single_initial_condition = dataset[sample_index:sample_index+1, :, :].unsqueeze(1)
96
 
 
 
 
 
 
 
 
 
97
  print(f"Running inference for sample index {sample_index}...")
98
  with torch.no_grad():
99
  predicted_solution = model(single_initial_condition)
100
 
 
101
  input_numpy = single_initial_condition.squeeze().cpu().numpy()
102
  output_numpy = predicted_solution.squeeze().cpu().numpy()
103
 
 
104
  fig_input, ax_input = plt.subplots()
105
  im_input = ax_input.imshow(input_numpy, cmap='viridis')
106
  ax_input.set_title(f"Initial Condition (Sample {sample_index})")
@@ -127,8 +147,6 @@ with gr.Blocks() as demo:
127
 
128
  with gr.Row():
129
  with gr.Column():
130
- # Max value can be dynamic based on dataset size if needed,
131
- # but 9999 for 10,000 samples is correct.
132
  sample_input_slider = gr.Slider(
133
  minimum=0,
134
  maximum=9999,
@@ -148,8 +166,10 @@ with gr.Blocks() as demo:
148
  )
149
 
150
  def load_initial_data_and_predict():
 
151
  load_model()
152
- load_dataset() # This will now download if not present
 
153
  return run_inference(0)
154
 
155
  demo.load(load_initial_data_and_predict, inputs=None, outputs=[input_image_plot, output_image_plot])
 
4
  import matplotlib.pyplot as plt
5
  import numpy as np
6
  import os
7
+ import requests
8
+ from tqdm import tqdm
9
+ from huggingface_hub import HfApi, HfFolder, Repository, create_repo # <--- ADD THIS IMPORT
10
+ import spaces # <--- ADD THIS IMPORT
11
 
12
  # --- Configuration ---
13
+ MODEL_PATH = "fno_ckpt_single_res"
 
14
  DATASET_URL = "https://zenodo.org/record/12825163/files/navier_stokes_2d.pt?download=1"
15
+ LOCAL_DATASET_PATH = "navier_stokes_2d.pt"
16
 
17
  # --- Global Variables for Model and Data (loaded once) ---
18
  MODEL = None
 
28
  print(f"Downloading {url} to {local_filename}...")
29
  try:
30
  response = requests.get(url, stream=True)
31
+ response.raise_for_status()
32
 
33
  total_size = int(response.headers.get('content-length', 0))
34
+ block_size = 1024
35
 
36
  with open(local_filename, 'wb') as f:
37
  with tqdm(total=total_size, unit='iB', unit_scale=True, desc=local_filename) as pbar:
 
44
  print(f"Error downloading file: {e}")
45
  raise gr.Error(f"Failed to download dataset from Zenodo: {e}")
46
 
47
+ # --- 1. Model Loading Function ---
 
48
  def load_model():
49
  """Loads the pre-trained FNO model."""
50
  global MODEL
51
  if MODEL is None:
52
  print("Loading FNO model...")
53
  try:
54
+ # Load to CPU, then move to GPU if available and needed
55
  MODEL = torch.load(MODEL_PATH, weights_only=False, map_location='cpu')
56
+ # Move model to GPU if available
57
+ if torch.cuda.is_available():
58
+ MODEL.cuda()
59
+ print("Model moved to GPU.")
60
+ else:
61
+ print("CUDA not available. Model will run on CPU.")
62
  MODEL.eval()
63
  print("Model loaded successfully.")
64
  except Exception as e:
 
66
  raise gr.Error(f"Failed to load model: {e}")
67
  return MODEL
68
 
69
+ # --- 2. Dataset Loading Function ---
70
  def load_dataset():
71
  """Downloads and loads the initial conditions dataset."""
72
  global FULL_DATASET_X
73
  if FULL_DATASET_X is None:
74
+ download_file(DATASET_URL, LOCAL_DATASET_PATH)
75
  print("Loading dataset from local file...")
76
  try:
77
  data = torch.load(LOCAL_DATASET_PATH, map_location='cpu')
 
87
  raise gr.Error(f"Failed to load dataset from local file: {e}")
88
  return FULL_DATASET_X
89
 
90
+ # --- 3. Inference Function for Gradio (MODIFIED with @spaces.GPU()) ---
91
+ @spaces.GPU() # <--- ADD THIS DECORATOR
92
  def run_inference(sample_index: int):
93
  """
94
  Performs inference for a selected sample index from the dataset.
95
  Returns two Matplotlib figures: one for input, one for output.
96
  """
97
  model = load_model()
98
+ dataset = load_dataset()
99
 
100
  if not (0 <= sample_index < dataset.shape[0]):
101
  raise gr.Error(f"Sample index out of range. Please choose between 0 and {dataset.shape[0]-1}.")
102
 
103
+ # Extract single initial condition and add channel dimension
104
+ # (shape: [1, H, W] -> [1, 1, H, W])
105
  single_initial_condition = dataset[sample_index:sample_index+1, :, :].unsqueeze(1)
106
 
107
+ # Move input tensor to GPU if model is on GPU
108
+ if torch.cuda.is_available():
109
+ single_initial_condition = single_initial_condition.cuda()
110
+ print("Input moved to GPU.")
111
+ else:
112
+ print("CUDA not available. Input remains on CPU.")
113
+
114
+
115
  print(f"Running inference for sample index {sample_index}...")
116
  with torch.no_grad():
117
  predicted_solution = model(single_initial_condition)
118
 
119
+ # Move results back to CPU for plotting with Matplotlib
120
  input_numpy = single_initial_condition.squeeze().cpu().numpy()
121
  output_numpy = predicted_solution.squeeze().cpu().numpy()
122
 
123
+ # Create Matplotlib figures
124
  fig_input, ax_input = plt.subplots()
125
  im_input = ax_input.imshow(input_numpy, cmap='viridis')
126
  ax_input.set_title(f"Initial Condition (Sample {sample_index})")
 
147
 
148
  with gr.Row():
149
  with gr.Column():
 
 
150
  sample_input_slider = gr.Slider(
151
  minimum=0,
152
  maximum=9999,
 
166
  )
167
 
168
  def load_initial_data_and_predict():
169
+ # Ensure model and dataset are loaded when the space starts
170
  load_model()
171
+ load_dataset()
172
+ # Run inference for the default value (index 0)
173
  return run_inference(0)
174
 
175
  demo.load(load_initial_data_and_predict, inputs=None, outputs=[input_image_plot, output_image_plot])