Spaces:
Running
Running
Update model/utils.py
Browse files- model/utils.py +28 -13
model/utils.py
CHANGED
@@ -557,24 +557,39 @@ def repetition_found(text, length = 2, tolerance = 10):
|
|
557 |
|
558 |
# load model checkpoint for inference
|
559 |
|
560 |
-
def load_checkpoint(model, ckpt_path, device, use_ema
|
561 |
-
|
|
|
|
|
|
|
|
|
562 |
|
563 |
ckpt_type = ckpt_path.split(".")[-1]
|
564 |
if ckpt_type == "safetensors":
|
565 |
from safetensors.torch import load_file
|
566 |
-
|
|
|
567 |
else:
|
568 |
-
checkpoint = torch.load(ckpt_path, weights_only=True
|
569 |
|
570 |
-
if use_ema
|
571 |
-
ema_model = EMA(model, include_online_model = False).to(device)
|
572 |
if ckpt_type == "safetensors":
|
573 |
-
|
574 |
-
|
575 |
-
ema_model.
|
576 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
577 |
else:
|
578 |
-
|
579 |
-
|
580 |
-
|
|
|
|
|
|
557 |
|
558 |
# load model checkpoint for inference
|
559 |
|
560 |
+
def load_checkpoint(model, ckpt_path, device, dtype=None, use_ema=True):
|
561 |
+
if dtype is None:
|
562 |
+
dtype = (
|
563 |
+
torch.float16 if device == "cuda" and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
|
564 |
+
)
|
565 |
+
model = model.to(dtype)
|
566 |
|
567 |
ckpt_type = ckpt_path.split(".")[-1]
|
568 |
if ckpt_type == "safetensors":
|
569 |
from safetensors.torch import load_file
|
570 |
+
|
571 |
+
checkpoint = load_file(ckpt_path)
|
572 |
else:
|
573 |
+
checkpoint = torch.load(ckpt_path, weights_only=True)
|
574 |
|
575 |
+
if use_ema:
|
|
|
576 |
if ckpt_type == "safetensors":
|
577 |
+
checkpoint = {"ema_model_state_dict": checkpoint}
|
578 |
+
checkpoint["model_state_dict"] = {
|
579 |
+
k.replace("ema_model.", ""): v
|
580 |
+
for k, v in checkpoint["ema_model_state_dict"].items()
|
581 |
+
if k not in ["initted", "step"]
|
582 |
+
}
|
583 |
+
|
584 |
+
# patch for backward compatibility, 305e3ea
|
585 |
+
for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
|
586 |
+
if key in checkpoint["model_state_dict"]:
|
587 |
+
del checkpoint["model_state_dict"][key]
|
588 |
+
|
589 |
+
model.load_state_dict(checkpoint["model_state_dict"])
|
590 |
else:
|
591 |
+
if ckpt_type == "safetensors":
|
592 |
+
checkpoint = {"model_state_dict": checkpoint}
|
593 |
+
model.load_state_dict(checkpoint["model_state_dict"])
|
594 |
+
|
595 |
+
return model.to(device)
|