ZeyuXie commited on
Commit
0faef57
·
verified ·
1 Parent(s): 7e86f55

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -10
app.py CHANGED
@@ -16,11 +16,11 @@ class dotdict(dict):
16
  __delattr__ = dict.__delitem__
17
 
18
  class InferRunner:
19
- def __init__(self):
20
  vae_config = json.load(open("ckpts/ldm/vae_config.json"))
21
  self.vae = AutoencoderKL(**vae_config).to(device)
22
- vae_weights = torch.load("ckpts/ldm/pytorch_model_vae.bin", map_location=device)
23
- self.vae.load_state_dict(vae_weights)
24
 
25
  train_args = dotdict(json.loads(open("ckpts/pico_model/summary.jsonl").readlines()[0]))
26
  self.pico_model = PicoDiffusion(
@@ -39,13 +39,9 @@ def infer(caption, runner):
39
  wave = runner.vae.decode_to_waveform(mel)[0][:audio_len]
40
  sf.write(f"synthesized/{caption}.wav", wave, samplerate=16000, subtype='PCM_16')
41
 
42
- infer_runner = InferRunner()
43
- if torch.cuda.is_available():
44
- device = "cuda"
45
- device_selection = "cuda:0"
46
- else:
47
- device = "cpu"
48
- device_selection = "cpu"
49
 
50
  with gr.Blocks() as demo:
51
  with gr.Row():
 
16
  __delattr__ = dict.__delitem__
17
 
18
  class InferRunner:
19
+ def __init__(self, device):
20
  vae_config = json.load(open("ckpts/ldm/vae_config.json"))
21
  self.vae = AutoencoderKL(**vae_config).to(device)
22
+ # vae_weights = torch.load("ckpts/ldm/pytorch_model_vae.bin", map_location=device)
23
+ # self.vae.load_state_dict(vae_weights)
24
 
25
  train_args = dotdict(json.loads(open("ckpts/pico_model/summary.jsonl").readlines()[0]))
26
  self.pico_model = PicoDiffusion(
 
39
  wave = runner.vae.decode_to_waveform(mel)[0][:audio_len]
40
  sf.write(f"synthesized/{caption}.wav", wave, samplerate=16000, subtype='PCM_16')
41
 
42
+ device = "cuda" if torch.cuda.is_available() else "cpu"
43
+ infer_runner = InferRunner(device)
44
+
 
 
 
 
45
 
46
  with gr.Blocks() as demo:
47
  with gr.Row():