Gregniuki committed
Commit b38f0ce · verified · 1 Parent(s): e5e3ba0

Update model/utils.py

Files changed (1)
  1. model/utils.py +28 -13
model/utils.py CHANGED
@@ -557,24 +557,39 @@ def repetition_found(text, length = 2, tolerance = 10):
 
 # load model checkpoint for inference
 
-def load_checkpoint(model, ckpt_path, device, use_ema = True):
-    from ema_pytorch import EMA
+def load_checkpoint(model, ckpt_path, device, dtype=None, use_ema=True):
+    if dtype is None:
+        dtype = (
+            torch.float16 if device == "cuda" and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
+        )
+    model = model.to(dtype)
 
     ckpt_type = ckpt_path.split(".")[-1]
     if ckpt_type == "safetensors":
         from safetensors.torch import load_file
-        checkpoint = load_file(ckpt_path, device=device)
+
+        checkpoint = load_file(ckpt_path)
     else:
-        checkpoint = torch.load(ckpt_path, weights_only=True, map_location=device)
+        checkpoint = torch.load(ckpt_path, weights_only=True)
 
-    if use_ema == True:
-        ema_model = EMA(model, include_online_model = False).to(device)
+    if use_ema:
         if ckpt_type == "safetensors":
-            ema_model.load_state_dict(checkpoint)
-        else:
-            ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
-        ema_model.copy_params_from_ema_to_model()
+            checkpoint = {"ema_model_state_dict": checkpoint}
+        checkpoint["model_state_dict"] = {
+            k.replace("ema_model.", ""): v
+            for k, v in checkpoint["ema_model_state_dict"].items()
+            if k not in ["initted", "step"]
+        }
+
+        # patch for backward compatibility, 305e3ea
+        for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
+            if key in checkpoint["model_state_dict"]:
+                del checkpoint["model_state_dict"][key]
+
+        model.load_state_dict(checkpoint["model_state_dict"])
     else:
-        model.load_state_dict(checkpoint['model_state_dict'])
-
-    return model
+        if ckpt_type == "safetensors":
+            checkpoint = {"model_state_dict": checkpoint}
+        model.load_state_dict(checkpoint["model_state_dict"])
+
+    return model.to(device)
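
For readers skimming the diff: with this change, inference no longer wraps the model in ema_pytorch's EMA just to read the averaged weights. Instead, the "ema_model."-prefixed keys are renamed in place, the EMA bookkeeping entries are dropped, and the result is loaded directly before the model is moved to the auto-selected dtype and the target device. The standalone sketch below reproduces that remapping on a throwaway checkpoint; the tiny nn.Linear model, the temporary file, and the fabricated keys are illustrative assumptions, not part of this repo.

    # Standalone sketch of the new EMA-key handling for a safetensors checkpoint
    # (the use_ema=True path). Everything here is illustrative: a tiny nn.Linear
    # stands in for the real model, and the checkpoint is fabricated on the fly.
    import os
    import tempfile

    import torch
    import torch.nn as nn
    from safetensors.torch import load_file, save_file

    model = nn.Linear(4, 4)

    # Fake an exported EMA checkpoint: weights carry the "ema_model." prefix and
    # the EMA bookkeeping tensors "initted" / "step" ride along.
    ema_ckpt = {f"ema_model.{k}": v.clone() for k, v in model.state_dict().items()}
    ema_ckpt["initted"] = torch.tensor(True)
    ema_ckpt["step"] = torch.tensor(1000)

    path = os.path.join(tempfile.mkdtemp(), "demo.safetensors")
    save_file(ema_ckpt, path)

    # Mirror of what load_checkpoint now does for safetensors + use_ema=True:
    checkpoint = {"ema_model_state_dict": load_file(path)}
    checkpoint["model_state_dict"] = {
        k.replace("ema_model.", ""): v
        for k, v in checkpoint["ema_model_state_dict"].items()
        if k not in ["initted", "step"]  # drop EMA bookkeeping, keep the weights
    }
    model.load_state_dict(checkpoint["model_state_dict"])  # loads without the EMA wrapper

Calling the updated function itself is unchanged apart from the new optional dtype argument, e.g. load_checkpoint(model, ckpt_path, device, dtype=None, use_ema=True); leaving dtype=None picks float16 on CUDA devices with compute capability >= 6 and float32 otherwise, and the loaded model is returned on the requested device.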