AbstractPhil commited on
Commit
504e98b
Β·
1 Parent(s): c093ba7
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -39,13 +39,14 @@ config_g = T5_SHUNT_REPOS["clip_g"]["config"]
39
  # ─── Loader ───────────────────────────────────────────────────
40
  from safetensors.torch import safe_open
41
 
 
42
  def load_adapter(repo, filename, config):
43
  # Don't initialize device here
44
  path = hf_hub_download(repo_id=repo, filename=filename)
45
 
46
  model = TwoStreamShuntAdapter(config).eval()
47
  tensors = {}
48
- with safe_open(path, framework="pt", device="cpu") as f:
49
  for key in f.keys():
50
  tensors[key] = f.get_tensor(key)
51
  model.load_state_dict(tensors)
 
39
  # ─── Loader ───────────────────────────────────────────────────
40
  from safetensors.torch import safe_open
41
 
42
+ @spaces.GPU
43
  def load_adapter(repo, filename, config):
44
  # Don't initialize device here
45
  path = hf_hub_download(repo_id=repo, filename=filename)
46
 
47
  model = TwoStreamShuntAdapter(config).eval()
48
  tensors = {}
49
+ with safe_open(path, framework="pt", device="cuda") as f:
50
  for key in f.keys():
51
  tensors[key] = f.get_tensor(key)
52
  model.load_state_dict(tensors)