Spaces:
Running
Running
Update model/utils.py
Browse files- model/utils.py +4 -4
model/utils.py
CHANGED
@@ -562,15 +562,15 @@ def load_checkpoint(model, ckpt_path, device, use_ema = True):
|
|
562 |
|
563 |
ckpt_type = ckpt_path.split(".")[-1]
|
564 |
if ckpt_type == "safetensors":
|
565 |
-
|
566 |
-
|
567 |
else:
|
568 |
-
checkpoint = torch.load(ckpt_path, weights_only=
|
569 |
|
570 |
if use_ema == True:
|
571 |
ema_model = EMA(model, include_online_model = False).to(device)
|
572 |
if ckpt_type == "safetensors":
|
573 |
-
ema_model.load_state_dict(checkpoint)
|
574 |
else:
|
575 |
ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
|
576 |
ema_model.copy_params_from_ema_to_model()
|
|
|
562 |
|
563 |
ckpt_type = ckpt_path.split(".")[-1]
|
564 |
if ckpt_type == "safetensors":
|
565 |
+
# from safetensors.torch import load_file
|
566 |
+
# checkpoint = load_file(ckpt_path, device=device)
|
567 |
else:
|
568 |
+
checkpoint = torch.load(ckpt_path, weights_only=True, map_location=device)
|
569 |
|
570 |
if use_ema == True:
|
571 |
ema_model = EMA(model, include_online_model = False).to(device)
|
572 |
if ckpt_type == "safetensors":
|
573 |
+
ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
|
574 |
else:
|
575 |
ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
|
576 |
ema_model.copy_params_from_ema_to_model()
|