ajsbsd commited on
Commit
c41cdf5
·
verified ·
1 Parent(s): b5b786d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -57
app.py CHANGED
@@ -4,77 +4,60 @@ from neuralop.models import FNO
4
  import matplotlib.pyplot as plt
5
  import numpy as np
6
  import os
7
- import requests
8
- from tqdm import tqdm
9
- from huggingface_hub import HfApi, HfFolder, Repository, create_repo # <--- ADD THIS IMPORT
10
- import spaces # <--- ADD THIS IMPORT
11
 
12
  # --- Configuration ---
13
- MODEL_PATH = "fno_ckpt_single_res"
14
- DATASET_URL = "https://zenodo.org/record/12825163/files/navier_stokes_2d.pt?download=1"
15
- LOCAL_DATASET_PATH = "navier_stokes_2d.pt"
 
16
 
17
  # --- Global Variables for Model and Data (loaded once) ---
18
  MODEL = None
19
  FULL_DATASET_X = None
20
 
21
- # --- Function to Download Dataset ---
22
- def download_file(url, local_filename):
23
- """Downloads a file from a URL to a local path with a progress bar."""
24
- if os.path.exists(local_filename):
25
- print(f"{local_filename} already exists. Skipping download.")
26
- return
27
-
28
- print(f"Downloading {url} to {local_filename}...")
29
  try:
30
- response = requests.get(url, stream=True)
31
- response.raise_for_status()
32
-
33
- total_size = int(response.headers.get('content-length', 0))
34
- block_size = 1024
35
-
36
- with open(local_filename, 'wb') as f:
37
- with tqdm(total=total_size, unit='iB', unit_scale=True, desc=local_filename) as pbar:
38
- for chunk in response.iter_content(chunk_size=block_size):
39
- if chunk:
40
- f.write(chunk)
41
- pbar.update(len(chunk))
42
- print(f"Downloaded {local_filename} successfully.")
43
- except requests.exceptions.RequestException as e:
44
- print(f"Error downloading file: {e}")
45
- raise gr.Error(f"Failed to download dataset from Zenodo: {e}")
46
-
47
- # --- 1. Model Loading Function ---
48
  def load_model():
49
- """Loads the pre-trained FNO model."""
50
  global MODEL
51
  if MODEL is None:
52
- print("Loading FNO model...")
53
  try:
54
- # Load to CPU, then move to GPU if available and needed
55
  MODEL = torch.load(MODEL_PATH, weights_only=False, map_location='cpu')
56
- # Move model to GPU if available
57
- if torch.cuda.is_available():
58
- MODEL.cuda()
59
- print("Model moved to GPU.")
60
- else:
61
- print("CUDA not available. Model will run on CPU.")
62
  MODEL.eval()
63
- print("Model loaded successfully.")
64
  except Exception as e:
65
  print(f"Error loading model: {e}")
66
  raise gr.Error(f"Failed to load model: {e}")
67
  return MODEL
68
 
69
- # --- 2. Dataset Loading Function ---
70
  def load_dataset():
71
- """Downloads and loads the initial conditions dataset."""
72
  global FULL_DATASET_X
73
  if FULL_DATASET_X is None:
74
- download_file(DATASET_URL, LOCAL_DATASET_PATH)
 
75
  print("Loading dataset from local file...")
76
  try:
77
- data = torch.load(LOCAL_DATASET_PATH, map_location='cpu')
78
  if isinstance(data, dict) and 'x' in data:
79
  FULL_DATASET_X = data['x']
80
  elif isinstance(data, torch.Tensor):
@@ -87,8 +70,8 @@ def load_dataset():
87
  raise gr.Error(f"Failed to load dataset from local file: {e}")
88
  return FULL_DATASET_X
89
 
90
- # --- 3. Inference Function for Gradio (MODIFIED with @spaces.GPU()) ---
91
- @spaces.GPU() # <--- ADD THIS DECORATOR
92
  def run_inference(sample_index: int):
93
  """
94
  Performs inference for a selected sample index from the dataset.
@@ -97,30 +80,28 @@ def run_inference(sample_index: int):
97
  model = load_model()
98
  dataset = load_dataset()
99
 
 
 
 
 
100
  if not (0 <= sample_index < dataset.shape[0]):
101
  raise gr.Error(f"Sample index out of range. Please choose between 0 and {dataset.shape[0]-1}.")
102
 
103
- # Extract single initial condition and add channel dimension
104
- # (shape: [1, H, W] -> [1, 1, H, W])
105
  single_initial_condition = dataset[sample_index:sample_index+1, :, :].unsqueeze(1)
106
 
107
- # Move input tensor to GPU if model is on GPU
108
  if torch.cuda.is_available():
109
  single_initial_condition = single_initial_condition.cuda()
110
  print("Input moved to GPU.")
111
  else:
112
  print("CUDA not available. Input remains on CPU.")
113
 
114
-
115
  print(f"Running inference for sample index {sample_index}...")
116
  with torch.no_grad():
117
  predicted_solution = model(single_initial_condition)
118
 
119
- # Move results back to CPU for plotting with Matplotlib
120
  input_numpy = single_initial_condition.squeeze().cpu().numpy()
121
  output_numpy = predicted_solution.squeeze().cpu().numpy()
122
 
123
- # Create Matplotlib figures
124
  fig_input, ax_input = plt.subplots()
125
  im_input = ax_input.imshow(input_numpy, cmap='viridis')
126
  ax_input.set_title(f"Initial Condition (Sample {sample_index})")
@@ -166,10 +147,8 @@ with gr.Blocks() as demo:
166
  )
167
 
168
  def load_initial_data_and_predict():
169
- # Ensure model and dataset are loaded when the space starts
170
  load_model()
171
  load_dataset()
172
- # Run inference for the default value (index 0)
173
  return run_inference(0)
174
 
175
  demo.load(load_initial_data_and_predict, inputs=None, outputs=[input_image_plot, output_image_plot])
 
4
  import matplotlib.pyplot as plt
5
  import numpy as np
6
  import os
7
+ # import requests # <--- NO LONGER NEEDED for Zenodo download
8
+ # from tqdm import tqdm # <--- NO LONGER NEEDED for Zenodo download
9
+ import spaces
10
+ from huggingface_hub import hf_hub_download # <--- ADD THIS IMPORT
11
 
12
  # --- Configuration ---
13
+ MODEL_PATH = "fno_ckpt_single_res" # This model file still needs to be in your Space's repo
14
+ # Updated: Hugging Face Dataset/Model ID and filename
15
+ HF_DATASET_REPO_ID = "ajsbsd/navier-stokes-2d-dataset" # Your new repo ID
16
+ HF_DATASET_FILENAME = "navier_stokes_2d.pt"
17
 
18
  # --- Global Variables for Model and Data (loaded once) ---
19
  MODEL = None
20
  FULL_DATASET_X = None
21
 
22
+ # --- Function to Download Dataset (MODIFIED to use hf_hub_download) ---
23
+ def download_file_from_hf_hub(repo_id, filename):
24
+ """Downloads a file from Hugging Face Hub."""
25
+ print(f"Downloading {filename} from {repo_id} on Hugging Face Hub...")
 
 
 
 
26
  try:
27
+ # hf_hub_download returns the local path to the downloaded file
28
+ local_path = hf_hub_download(repo_id=repo_id, filename=filename)
29
+ print(f"Downloaded {filename} to {local_path} successfully.")
30
+ return local_path
31
+ except Exception as e:
32
+ print(f"Error downloading file from HF Hub: {e}")
33
+ raise gr.Error(f"Failed to download dataset from Hugging Face Hub: {e}")
34
+
35
+
36
+ # --- 1. Model Loading Function (No change from last successful CUDA fix) ---
 
 
 
 
 
 
 
 
37
  def load_model():
38
+ """Loads the pre-trained FNO model to CPU."""
39
  global MODEL
40
  if MODEL is None:
41
+ print("Loading FNO model to CPU...")
42
  try:
 
43
  MODEL = torch.load(MODEL_PATH, weights_only=False, map_location='cpu')
 
 
 
 
 
 
44
  MODEL.eval()
45
+ print("Model loaded successfully to CPU.")
46
  except Exception as e:
47
  print(f"Error loading model: {e}")
48
  raise gr.Error(f"Failed to load model: {e}")
49
  return MODEL
50
 
51
+ # --- 2. Dataset Loading Function (MODIFIED) ---
52
  def load_dataset():
53
+ """Downloads and loads the initial conditions dataset from HF Hub."""
54
  global FULL_DATASET_X
55
  if FULL_DATASET_X is None:
56
+ # Call the new HF Hub download function
57
+ local_dataset_path = download_file_from_hf_hub(HF_DATASET_REPO_ID, HF_DATASET_FILENAME)
58
  print("Loading dataset from local file...")
59
  try:
60
+ data = torch.load(local_dataset_path, map_location='cpu')
61
  if isinstance(data, dict) and 'x' in data:
62
  FULL_DATASET_X = data['x']
63
  elif isinstance(data, torch.Tensor):
 
70
  raise gr.Error(f"Failed to load dataset from local file: {e}")
71
  return FULL_DATASET_X
72
 
73
+ # --- 3. Inference Function for Gradio (No changes needed here) ---
74
+ @spaces.GPU()
75
  def run_inference(sample_index: int):
76
  """
77
  Performs inference for a selected sample index from the dataset.
 
80
  model = load_model()
81
  dataset = load_dataset()
82
 
83
+ if torch.cuda.is_available() and next(model.parameters()).device == torch.device('cpu'):
84
+ model.cuda()
85
+ print("Model moved to GPU within run_inference.")
86
+
87
  if not (0 <= sample_index < dataset.shape[0]):
88
  raise gr.Error(f"Sample index out of range. Please choose between 0 and {dataset.shape[0]-1}.")
89
 
 
 
90
  single_initial_condition = dataset[sample_index:sample_index+1, :, :].unsqueeze(1)
91
 
 
92
  if torch.cuda.is_available():
93
  single_initial_condition = single_initial_condition.cuda()
94
  print("Input moved to GPU.")
95
  else:
96
  print("CUDA not available. Input remains on CPU.")
97
 
 
98
  print(f"Running inference for sample index {sample_index}...")
99
  with torch.no_grad():
100
  predicted_solution = model(single_initial_condition)
101
 
 
102
  input_numpy = single_initial_condition.squeeze().cpu().numpy()
103
  output_numpy = predicted_solution.squeeze().cpu().numpy()
104
 
 
105
  fig_input, ax_input = plt.subplots()
106
  im_input = ax_input.imshow(input_numpy, cmap='viridis')
107
  ax_input.set_title(f"Initial Condition (Sample {sample_index})")
 
147
  )
148
 
149
  def load_initial_data_and_predict():
 
150
  load_model()
151
  load_dataset()
 
152
  return run_inference(0)
153
 
154
  demo.load(load_initial_data_and_predict, inputs=None, outputs=[input_image_plot, output_image_plot])