hugo flores garcia commited on
Commit
21eac81
·
1 Parent(s): cad13c9
Files changed (2) hide show
  1. README.md +1 -1
  2. app.py +9 -2
README.md CHANGED
@@ -82,7 +82,7 @@ output_signal.write("scratch/output.wav")
82
  You can launch a gradio UI to play with vampnet.
83
 
84
  ```bash
85
- python app.py --args.load conf/interface.yml --Interface.device cuda
86
  ```
87
 
88
  # Training / Fine-tuning
 
82
  You can launch a gradio UI to play with vampnet.
83
 
84
  ```bash
85
+ python app.py
86
  ```
87
 
88
  # Training / Fine-tuning
app.py CHANGED
@@ -15,7 +15,14 @@ import gradio as gr
15
  from vampnet.interface import Interface, signal_concat
16
  from vampnet import mask as pmask
17
 
18
- device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
 
19
 
20
  interface = Interface.default()
21
  init_model_choice = open("DEFAULT_MODEL").read().strip()
@@ -134,7 +141,7 @@ def _vamp(
134
 
135
 
136
  t0 = time.time()
137
- interface.to("cuda" if torch.cuda.is_available() else "cpu")
138
  print(f"using device {interface.device}")
139
  _seed = seed if seed > 0 else None
140
  if _seed is None:
 
15
  from vampnet.interface import Interface, signal_concat
16
  from vampnet import mask as pmask
17
 
18
+ if torch.cuda.is_available():
19
+ device = "cuda"
20
+ elif torch.backends.mps.is_available():
21
+ device = "mps"
22
+ else:
23
+ device = "cpu"
24
+
25
+ print(f"using device {device}\n"*10)
26
 
27
  interface = Interface.default()
28
  init_model_choice = open("DEFAULT_MODEL").read().strip()
 
141
 
142
 
143
  t0 = time.time()
144
+ interface.to(device)
145
  print(f"using device {interface.device}")
146
  _seed = seed if seed > 0 else None
147
  if _seed is None: