jadechoghari commited on
Commit
5707c0d
·
verified ·
1 Parent(s): 7e1b955

Update modeling.py

Browse files
Files changed (1) hide show
  1. modeling.py +3 -3
modeling.py CHANGED
@@ -10,7 +10,7 @@ import argparse
10
  from model import OptimizedAudioRestorationModel
11
  import librosa
12
  from inference_long import apply_overlap_windowing_waveform, reconstruct_waveform_from_windows
13
- from huggingface_hub import snapshot_download
14
 
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
 
@@ -46,8 +46,8 @@ class VoiceRestore(PreTrainedModel):
46
 
47
  # Optimized restoration model
48
  self.optimized_model = OptimizedAudioRestorationModel(device=device, bigvgan_model=self.bigvgan_model)
49
- save_path = "/content/voicerestore/checkpoints/voice-restore-20d-16h-optim.pt"
50
- state_dict = torch.load(save_path, map_location=torch.device(device))
51
  if 'model_state_dict' in state_dict:
52
  state_dict = state_dict['model_state_dict']
53
 
 
10
  from model import OptimizedAudioRestorationModel
11
  import librosa
12
  from inference_long import apply_overlap_windowing_waveform, reconstruct_waveform_from_windows
13
+ from safetensors.torch import load_file
14
 
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
 
 
46
 
47
  # Optimized restoration model
48
  self.optimized_model = OptimizedAudioRestorationModel(device=device, bigvgan_model=self.bigvgan_model)
49
+ save_path = "./pytorch_model.safetensors"
50
+ state_dict = load_file(save_path, device=device)
51
  if 'model_state_dict' in state_dict:
52
  state_dict = state_dict['model_state_dict']
53