kevinwang676 commited on
Commit
e535096
·
verified ·
1 Parent(s): 318a6da

Update inference/m4singer/base_svs_infer.py

Browse files
inference/m4singer/base_svs_infer.py CHANGED
@@ -18,7 +18,7 @@ import re
18
  class BaseSVSInfer:
19
  def __init__(self, hparams, device=None):
20
  if device is None:
21
- device = 'cuda' #if torch.cuda.is_available() else 'cpu'
22
  self.hparams = hparams
23
  self.device = device
24
 
@@ -51,7 +51,7 @@ class BaseSVSInfer:
51
  ckpt = sorted(glob.glob(f'{base_dir}/model_ckpt_steps_*.ckpt'), key=
52
  lambda x: int(re.findall(f'{base_dir}/model_ckpt_steps_(\d+).ckpt', x)[0]))[-1]
53
  print('| load HifiGAN: ', ckpt)
54
- ckpt_dict = torch.load(ckpt) #torch.load(ckpt, map_location="cpu")
55
  config = set_hparams(config_path, global_hparams=False)
56
  state = ckpt_dict["state_dict"]["model_gen"]
57
  vocoder = HifiGanGenerator(config)
 
18
  class BaseSVSInfer:
19
  def __init__(self, hparams, device=None):
20
  if device is None:
21
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
22
  self.hparams = hparams
23
  self.device = device
24
 
 
51
  ckpt = sorted(glob.glob(f'{base_dir}/model_ckpt_steps_*.ckpt'), key=
52
  lambda x: int(re.findall(f'{base_dir}/model_ckpt_steps_(\d+).ckpt', x)[0]))[-1]
53
  print('| load HifiGAN: ', ckpt)
54
+ ckpt_dict = torch.load(ckpt, map_location="cpu")
55
  config = set_hparams(config_path, global_hparams=False)
56
  state = ckpt_dict["state_dict"]["model_gen"]
57
  vocoder = HifiGanGenerator(config)