Update modeling.py
Browse files- modeling.py +3 -3
modeling.py
CHANGED
@@ -10,7 +10,7 @@ import argparse
|
|
10 |
from model import OptimizedAudioRestorationModel
|
11 |
import librosa
|
12 |
from inference_long import apply_overlap_windowing_waveform, reconstruct_waveform_from_windows
|
13 |
-
from
|
14 |
|
15 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
16 |
|
@@ -46,8 +46,8 @@ class VoiceRestore(PreTrainedModel):
|
|
46 |
|
47 |
# Optimized restoration model
|
48 |
self.optimized_model = OptimizedAudioRestorationModel(device=device, bigvgan_model=self.bigvgan_model)
|
49 |
-
save_path = "
|
50 |
-
state_dict =
|
51 |
if 'model_state_dict' in state_dict:
|
52 |
state_dict = state_dict['model_state_dict']
|
53 |
|
|
|
10 |
from model import OptimizedAudioRestorationModel
|
11 |
import librosa
|
12 |
from inference_long import apply_overlap_windowing_waveform, reconstruct_waveform_from_windows
|
13 |
+
from safetensors.torch import load_file
|
14 |
|
15 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
16 |
|
|
|
46 |
|
47 |
# Optimized restoration model
|
48 |
self.optimized_model = OptimizedAudioRestorationModel(device=device, bigvgan_model=self.bigvgan_model)
|
49 |
+
save_path = "./pytorch_model.safetensors"
|
50 |
+
state_dict = load_file(save_path, device=device)
|
51 |
if 'model_state_dict' in state_dict:
|
52 |
state_dict = state_dict['model_state_dict']
|
53 |
|