Yiming-M commited on
Commit
68d6ff9
·
verified ·
1 Parent(s): e5db566

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -2
app.py CHANGED
@@ -1,6 +1,8 @@
1
  import torch
2
  import torch.nn.functional as F
3
  from torch import Tensor
 
 
4
  import numpy as np
5
  from PIL import Image
6
  import json, os, random
@@ -60,7 +62,8 @@ model_configs = {
60
  },
61
  }
62
 
63
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
64
 
65
  if truncation is None: # regression, no truncation.
66
  bins, anchor_points = None, None
@@ -109,7 +112,6 @@ def load_model(model_choice: str):
109
  state_dict = load_file(weights_path)
110
  new_state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
111
  model.load_state_dict(new_state_dict)
112
- model.to(device)
113
  model.eval()
114
 
115
  loaded_model = model
@@ -139,6 +141,7 @@ def transform(image: Image.Image):
139
  # -----------------------------
140
  # Inference function
141
  # -----------------------------
 
142
  def predict(image: Image.Image, model_choice: str = "CLIP_EBC_ViT_B_16"):
143
  """
144
  Given an input image, preprocess it, run the model to obtain a density map,
@@ -149,6 +152,7 @@ def predict(image: Image.Image, model_choice: str = "CLIP_EBC_ViT_B_16"):
149
  if loaded_model is None or model_configs[model_choice]["model_name"] not in loaded_model.__class__.__name__:
150
  load_model(model_choice)
151
 
 
152
  # Preprocess the image
153
  input_width, input_height = image.size
154
  input_tensor = transform(image).to(device) # shape: (1, 3, H, W)
 
1
  import torch
2
  import torch.nn.functional as F
3
  from torch import Tensor
4
+ import spaces
5
+
6
  import numpy as np
7
  from PIL import Image
8
  import json, os, random
 
62
  },
63
  }
64
 
65
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
66
+ device = "cuda"
67
 
68
  if truncation is None: # regression, no truncation.
69
  bins, anchor_points = None, None
 
112
  state_dict = load_file(weights_path)
113
  new_state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
114
  model.load_state_dict(new_state_dict)
 
115
  model.eval()
116
 
117
  loaded_model = model
 
141
  # -----------------------------
142
  # Inference function
143
  # -----------------------------
144
+ @spaces.GPU(duration=120)
145
  def predict(image: Image.Image, model_choice: str = "CLIP_EBC_ViT_B_16"):
146
  """
147
  Given an input image, preprocess it, run the model to obtain a density map,
 
152
  if loaded_model is None or model_configs[model_choice]["model_name"] not in loaded_model.__class__.__name__:
153
  load_model(model_choice)
154
 
155
+ loaded_model.to(device)
156
  # Preprocess the image
157
  input_width, input_height = image.size
158
  input_tensor = transform(image).to(device) # shape: (1, 3, H, W)