Gregniuki commited on
Commit
b5f4618
·
verified ·
1 Parent(s): 6f9ec66

Update model/utils.py

Browse files
Files changed (1) hide show
  1. model/utils.py +4 -4
model/utils.py CHANGED
@@ -562,15 +562,15 @@ def load_checkpoint(model, ckpt_path, device, use_ema = True):
562
 
563
  ckpt_type = ckpt_path.split(".")[-1]
564
  if ckpt_type == "safetensors":
565
- from safetensors.torch import load_file
566
- checkpoint = load_file(ckpt_path, device=device)
567
  else:
568
- checkpoint = torch.load(ckpt_path, weights_only=False, map_location=device)
569
 
570
  if use_ema == True:
571
  ema_model = EMA(model, include_online_model = False).to(device)
572
  if ckpt_type == "safetensors":
573
- ema_model.load_state_dict(checkpoint)
574
  else:
575
  ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
576
  ema_model.copy_params_from_ema_to_model()
 
562
 
563
  ckpt_type = ckpt_path.split(".")[-1]
564
  if ckpt_type == "safetensors":
565
+ # from safetensors.torch import load_file
566
+ # checkpoint = load_file(ckpt_path, device=device)
567
  else:
568
+ checkpoint = torch.load(ckpt_path, weights_only=True, map_location=device)
569
 
570
  if use_ema == True:
571
  ema_model = EMA(model, include_online_model = False).to(device)
572
  if ckpt_type == "safetensors":
573
+ ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
574
  else:
575
  ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
576
  ema_model.copy_params_from_ema_to_model()