cocktailpeanut commited on
Commit
00b82df
Β·
1 Parent(s): 55a4620
Files changed (1) hide show
  1. diffrhythm/infer/infer_utils.py +2 -2
diffrhythm/infer/infer_utils.py CHANGED
@@ -14,7 +14,7 @@ def prepare_model(device):
14
  # prepare cfm model
15
  dit_ckpt_path = hf_hub_download(repo_id="ASLP-lab/DiffRhythm-base", filename="cfm_model.pt")
16
  dit_config_path = "./diffrhythm/config/diffrhythm-1b.json"
17
- with open(dit_config_path) as f:
18
  model_config = json.load(f)
19
  dit_model_cls = DiT
20
  cfm = CFM(
@@ -194,4 +194,4 @@ def load_checkpoint(model, ckpt_path, device, use_ema=True):
194
  checkpoint = {"model_state_dict": checkpoint}
195
  model.load_state_dict(checkpoint["model_state_dict"], strict=False)
196
 
197
- return model.to(device)
 
14
  # prepare cfm model
15
  dit_ckpt_path = hf_hub_download(repo_id="ASLP-lab/DiffRhythm-base", filename="cfm_model.pt")
16
  dit_config_path = "./diffrhythm/config/diffrhythm-1b.json"
17
+ with open(dit_config_path, encoding="utf-8") as f:
18
  model_config = json.load(f)
19
  dit_model_cls = DiT
20
  cfm = CFM(
 
194
  checkpoint = {"model_state_dict": checkpoint}
195
  model.load_state_dict(checkpoint["model_state_dict"], strict=False)
196
 
197
+ return model.to(device)