Spaces:
Running
Running
Update inference/m4singer/base_svs_infer.py
Browse files
inference/m4singer/base_svs_infer.py
CHANGED
@@ -18,7 +18,7 @@ import re
|
|
18 |
class BaseSVSInfer:
|
19 |
def __init__(self, hparams, device=None):
|
20 |
if device is None:
|
21 |
-
device = 'cuda'
|
22 |
self.hparams = hparams
|
23 |
self.device = device
|
24 |
|
@@ -51,7 +51,7 @@ class BaseSVSInfer:
|
|
51 |
ckpt = sorted(glob.glob(f'{base_dir}/model_ckpt_steps_*.ckpt'), key=
|
52 |
lambda x: int(re.findall(f'{base_dir}/model_ckpt_steps_(\d+).ckpt', x)[0]))[-1]
|
53 |
print('| load HifiGAN: ', ckpt)
|
54 |
-
ckpt_dict = torch.load(ckpt
|
55 |
config = set_hparams(config_path, global_hparams=False)
|
56 |
state = ckpt_dict["state_dict"]["model_gen"]
|
57 |
vocoder = HifiGanGenerator(config)
|
|
|
18 |
class BaseSVSInfer:
|
19 |
def __init__(self, hparams, device=None):
|
20 |
if device is None:
|
21 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
22 |
self.hparams = hparams
|
23 |
self.device = device
|
24 |
|
|
|
51 |
ckpt = sorted(glob.glob(f'{base_dir}/model_ckpt_steps_*.ckpt'), key=
|
52 |
lambda x: int(re.findall(f'{base_dir}/model_ckpt_steps_(\d+).ckpt', x)[0]))[-1]
|
53 |
print('| load HifiGAN: ', ckpt)
|
54 |
+
ckpt_dict = torch.load(ckpt, map_location="cpu")
|
55 |
config = set_hparams(config_path, global_hparams=False)
|
56 |
state = ckpt_dict["state_dict"]["model_gen"]
|
57 |
vocoder = HifiGanGenerator(config)
|