Spaces:
Running
Running
Katock
commited on
Commit
·
c0c0eca
1
Parent(s):
6e5351e
Update utils.py
Browse files
utils.py
CHANGED
@@ -6,18 +6,20 @@ import argparse
|
|
6 |
import logging
|
7 |
import json
|
8 |
import subprocess
|
|
|
9 |
import random
|
10 |
-
|
11 |
import librosa
|
12 |
import numpy as np
|
13 |
from scipy.io.wavfile import read
|
14 |
import torch
|
15 |
from torch.nn import functional as F
|
16 |
from modules.commons import sequence_mask
|
17 |
-
|
|
|
18 |
MATPLOTLIB_FLAG = False
|
19 |
|
20 |
-
logging.basicConfig(stream=sys.stdout, level=logging.
|
21 |
logger = logging
|
22 |
|
23 |
f0_bin = 256
|
@@ -26,26 +28,6 @@ f0_min = 50.0
|
|
26 |
f0_mel_min = 1127 * np.log(1 + f0_min / 700)
|
27 |
f0_mel_max = 1127 * np.log(1 + f0_max / 700)
|
28 |
|
29 |
-
|
30 |
-
# def normalize_f0(f0, random_scale=True):
|
31 |
-
# f0_norm = f0.clone() # create a copy of the input Tensor
|
32 |
-
# batch_size, _, frame_length = f0_norm.shape
|
33 |
-
# for i in range(batch_size):
|
34 |
-
# means = torch.mean(f0_norm[i, 0, :])
|
35 |
-
# if random_scale:
|
36 |
-
# factor = random.uniform(0.8, 1.2)
|
37 |
-
# else:
|
38 |
-
# factor = 1
|
39 |
-
# f0_norm[i, 0, :] = (f0_norm[i, 0, :] - means) * factor
|
40 |
-
# return f0_norm
|
41 |
-
# def normalize_f0(f0, random_scale=True):
|
42 |
-
# means = torch.mean(f0[:, 0, :], dim=1, keepdim=True)
|
43 |
-
# if random_scale:
|
44 |
-
# factor = torch.Tensor(f0.shape[0],1).uniform_(0.8, 1.2).to(f0.device)
|
45 |
-
# else:
|
46 |
-
# factor = torch.ones(f0.shape[0], 1, 1).to(f0.device)
|
47 |
-
# f0_norm = (f0 - means.unsqueeze(-1)) * factor.unsqueeze(-1)
|
48 |
-
# return f0_norm
|
49 |
def normalize_f0(f0, x_mask, uv, random_scale=True):
|
50 |
# calculate means based on x_mask
|
51 |
uv_sum = torch.sum(uv, dim=1, keepdim=True)
|
@@ -62,7 +44,6 @@ def normalize_f0(f0, x_mask, uv, random_scale=True):
|
|
62 |
exit(0)
|
63 |
return f0_norm * x_mask
|
64 |
|
65 |
-
|
66 |
def plot_data_to_numpy(x, y):
|
67 |
global MATPLOTLIB_FLAG
|
68 |
if not MATPLOTLIB_FLAG:
|
@@ -86,87 +67,6 @@ def plot_data_to_numpy(x, y):
|
|
86 |
return data
|
87 |
|
88 |
|
89 |
-
|
90 |
-
def interpolate_f0(f0):
|
91 |
-
'''
|
92 |
-
对F0进行插值处理
|
93 |
-
'''
|
94 |
-
|
95 |
-
data = np.reshape(f0, (f0.size, 1))
|
96 |
-
|
97 |
-
vuv_vector = np.zeros((data.size, 1), dtype=np.float32)
|
98 |
-
vuv_vector[data > 0.0] = 1.0
|
99 |
-
vuv_vector[data <= 0.0] = 0.0
|
100 |
-
|
101 |
-
ip_data = data
|
102 |
-
|
103 |
-
frame_number = data.size
|
104 |
-
last_value = 0.0
|
105 |
-
for i in range(frame_number):
|
106 |
-
if data[i] <= 0.0:
|
107 |
-
j = i + 1
|
108 |
-
for j in range(i + 1, frame_number):
|
109 |
-
if data[j] > 0.0:
|
110 |
-
break
|
111 |
-
if j < frame_number - 1:
|
112 |
-
if last_value > 0.0:
|
113 |
-
step = (data[j] - data[i - 1]) / float(j - i)
|
114 |
-
for k in range(i, j):
|
115 |
-
ip_data[k] = data[i - 1] + step * (k - i + 1)
|
116 |
-
else:
|
117 |
-
for k in range(i, j):
|
118 |
-
ip_data[k] = data[j]
|
119 |
-
else:
|
120 |
-
for k in range(i, frame_number):
|
121 |
-
ip_data[k] = last_value
|
122 |
-
else:
|
123 |
-
ip_data[i] = data[i]
|
124 |
-
last_value = data[i]
|
125 |
-
|
126 |
-
return ip_data[:,0], vuv_vector[:,0]
|
127 |
-
|
128 |
-
|
129 |
-
def compute_f0_parselmouth(wav_numpy, p_len=None, sampling_rate=44100, hop_length=512):
|
130 |
-
import parselmouth
|
131 |
-
x = wav_numpy
|
132 |
-
if p_len is None:
|
133 |
-
p_len = x.shape[0]//hop_length
|
134 |
-
else:
|
135 |
-
assert abs(p_len-x.shape[0]//hop_length) < 4, "pad length error"
|
136 |
-
time_step = hop_length / sampling_rate * 1000
|
137 |
-
f0_min = 50
|
138 |
-
f0_max = 1100
|
139 |
-
f0 = parselmouth.Sound(x, sampling_rate).to_pitch_ac(
|
140 |
-
time_step=time_step / 1000, voicing_threshold=0.6,
|
141 |
-
pitch_floor=f0_min, pitch_ceiling=f0_max).selected_array['frequency']
|
142 |
-
|
143 |
-
pad_size=(p_len - len(f0) + 1) // 2
|
144 |
-
if(pad_size>0 or p_len - len(f0) - pad_size>0):
|
145 |
-
f0 = np.pad(f0,[[pad_size,p_len - len(f0) - pad_size]], mode='constant')
|
146 |
-
return f0
|
147 |
-
|
148 |
-
def resize_f0(x, target_len):
|
149 |
-
source = np.array(x)
|
150 |
-
source[source<0.001] = np.nan
|
151 |
-
target = np.interp(np.arange(0, len(source)*target_len, len(source))/ target_len, np.arange(0, len(source)), source)
|
152 |
-
res = np.nan_to_num(target)
|
153 |
-
return res
|
154 |
-
|
155 |
-
def compute_f0_dio(wav_numpy, p_len=None, sampling_rate=44100, hop_length=512):
|
156 |
-
import pyworld
|
157 |
-
if p_len is None:
|
158 |
-
p_len = wav_numpy.shape[0]//hop_length
|
159 |
-
f0, t = pyworld.dio(
|
160 |
-
wav_numpy.astype(np.double),
|
161 |
-
fs=sampling_rate,
|
162 |
-
f0_ceil=800,
|
163 |
-
frame_period=1000 * hop_length / sampling_rate,
|
164 |
-
)
|
165 |
-
f0 = pyworld.stonemask(wav_numpy.astype(np.double), f0, t, sampling_rate)
|
166 |
-
for index, pitch in enumerate(f0):
|
167 |
-
f0[index] = round(pitch, 1)
|
168 |
-
return resize_f0(f0, p_len)
|
169 |
-
|
170 |
def f0_to_coarse(f0):
|
171 |
is_torch = isinstance(f0, torch.Tensor)
|
172 |
f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * np.log(1 + f0 / 700)
|
@@ -174,48 +74,73 @@ def f0_to_coarse(f0):
|
|
174 |
|
175 |
f0_mel[f0_mel <= 1] = 1
|
176 |
f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1
|
177 |
-
f0_coarse = (f0_mel + 0.5).
|
178 |
assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, (f0_coarse.max(), f0_coarse.min())
|
179 |
return f0_coarse
|
180 |
|
181 |
-
|
182 |
-
def get_hubert_model():
|
183 |
-
vec_path = "hubert/checkpoint_best_legacy_500.pt"
|
184 |
-
print("load model(s) from {}".format(vec_path))
|
185 |
-
from fairseq import checkpoint_utils
|
186 |
-
models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
|
187 |
-
[vec_path],
|
188 |
-
suffix="",
|
189 |
-
)
|
190 |
-
model = models[0]
|
191 |
-
model.eval()
|
192 |
-
return model
|
193 |
-
|
194 |
-
def get_hubert_content(hmodel, wav_16k_tensor):
|
195 |
-
feats = wav_16k_tensor
|
196 |
-
if feats.dim() == 2: # double channels
|
197 |
-
feats = feats.mean(-1)
|
198 |
-
assert feats.dim() == 1, feats.dim()
|
199 |
-
feats = feats.view(1, -1)
|
200 |
-
padding_mask = torch.BoolTensor(feats.shape).fill_(False)
|
201 |
-
inputs = {
|
202 |
-
"source": feats.to(wav_16k_tensor.device),
|
203 |
-
"padding_mask": padding_mask.to(wav_16k_tensor.device),
|
204 |
-
"output_layer": 9, # layer 9
|
205 |
-
}
|
206 |
-
with torch.no_grad():
|
207 |
-
logits = hmodel.extract_features(**inputs)
|
208 |
-
feats = hmodel.final_proj(logits[0])
|
209 |
-
return feats.transpose(1, 2)
|
210 |
-
|
211 |
-
|
212 |
def get_content(cmodel, y):
|
213 |
with torch.no_grad():
|
214 |
c = cmodel.extract_features(y.squeeze(1))[0]
|
215 |
c = c.transpose(1, 2)
|
216 |
return c
|
217 |
|
218 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
219 |
|
220 |
def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
|
221 |
assert os.path.isfile(checkpoint_path)
|
@@ -244,6 +169,7 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
|
|
244 |
model.module.load_state_dict(new_state_dict)
|
245 |
else:
|
246 |
model.load_state_dict(new_state_dict)
|
|
|
247 |
logger.info("Loaded checkpoint '{}' (iteration {})".format(
|
248 |
checkpoint_path, iteration))
|
249 |
return model, optimizer, learning_rate, iteration
|
@@ -368,7 +294,7 @@ def load_filepaths_and_text(filename, split="|"):
|
|
368 |
|
369 |
def get_hparams(init=True):
|
370 |
parser = argparse.ArgumentParser()
|
371 |
-
parser.add_argument('-c', '--config', type=str, default="./configs/
|
372 |
help='JSON file for configuration')
|
373 |
parser.add_argument('-m', '--model', type=str, required=True,
|
374 |
help='Model name')
|
@@ -411,7 +337,6 @@ def get_hparams_from_file(config_path):
|
|
411 |
with open(config_path, "r") as f:
|
412 |
data = f.read()
|
413 |
config = json.loads(data)
|
414 |
-
|
415 |
hparams =HParams(**config)
|
416 |
return hparams
|
417 |
|
@@ -468,6 +393,41 @@ def repeat_expand_2d(content, target_len):
|
|
468 |
return target
|
469 |
|
470 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
471 |
class HParams():
|
472 |
def __init__(self, **kwargs):
|
473 |
for k, v in kwargs.items():
|
@@ -499,3 +459,19 @@ class HParams():
|
|
499 |
def __repr__(self):
|
500 |
return self.__dict__.__repr__()
|
501 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
import logging
|
7 |
import json
|
8 |
import subprocess
|
9 |
+
import warnings
|
10 |
import random
|
11 |
+
import functools
|
12 |
import librosa
|
13 |
import numpy as np
|
14 |
from scipy.io.wavfile import read
|
15 |
import torch
|
16 |
from torch.nn import functional as F
|
17 |
from modules.commons import sequence_mask
|
18 |
+
import tqdm
|
19 |
+
|
20 |
MATPLOTLIB_FLAG = False
|
21 |
|
22 |
+
logging.basicConfig(stream=sys.stdout, level=logging.WARN)
|
23 |
logger = logging
|
24 |
|
25 |
f0_bin = 256
|
|
|
28 |
f0_mel_min = 1127 * np.log(1 + f0_min / 700)
|
29 |
f0_mel_max = 1127 * np.log(1 + f0_max / 700)
|
30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
def normalize_f0(f0, x_mask, uv, random_scale=True):
|
32 |
# calculate means based on x_mask
|
33 |
uv_sum = torch.sum(uv, dim=1, keepdim=True)
|
|
|
44 |
exit(0)
|
45 |
return f0_norm * x_mask
|
46 |
|
|
|
47 |
def plot_data_to_numpy(x, y):
|
48 |
global MATPLOTLIB_FLAG
|
49 |
if not MATPLOTLIB_FLAG:
|
|
|
67 |
return data
|
68 |
|
69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
def f0_to_coarse(f0):
|
71 |
is_torch = isinstance(f0, torch.Tensor)
|
72 |
f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * np.log(1 + f0 / 700)
|
|
|
74 |
|
75 |
f0_mel[f0_mel <= 1] = 1
|
76 |
f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1
|
77 |
+
f0_coarse = (f0_mel + 0.5).int() if is_torch else np.rint(f0_mel).astype(np.int)
|
78 |
assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, (f0_coarse.max(), f0_coarse.min())
|
79 |
return f0_coarse
|
80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
def get_content(cmodel, y):
|
82 |
with torch.no_grad():
|
83 |
c = cmodel.extract_features(y.squeeze(1))[0]
|
84 |
c = c.transpose(1, 2)
|
85 |
return c
|
86 |
|
87 |
+
def get_f0_predictor(f0_predictor,hop_length,sampling_rate,**kargs):
|
88 |
+
if f0_predictor == "pm":
|
89 |
+
from modules.F0Predictor.PMF0Predictor import PMF0Predictor
|
90 |
+
f0_predictor_object = PMF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate)
|
91 |
+
elif f0_predictor == "crepe":
|
92 |
+
from modules.F0Predictor.CrepeF0Predictor import CrepeF0Predictor
|
93 |
+
f0_predictor_object = CrepeF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate,device=kargs["device"],threshold=kargs["threshold"])
|
94 |
+
elif f0_predictor == "harvest":
|
95 |
+
from modules.F0Predictor.HarvestF0Predictor import HarvestF0Predictor
|
96 |
+
f0_predictor_object = HarvestF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate)
|
97 |
+
elif f0_predictor == "dio":
|
98 |
+
from modules.F0Predictor.DioF0Predictor import DioF0Predictor
|
99 |
+
f0_predictor_object = DioF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate)
|
100 |
+
else:
|
101 |
+
raise Exception("Unknown f0 predictor")
|
102 |
+
return f0_predictor_object
|
103 |
+
|
104 |
+
def get_speech_encoder(speech_encoder,device=None,**kargs):
|
105 |
+
if speech_encoder == "vec768l12":
|
106 |
+
from vencoder.ContentVec768L12 import ContentVec768L12
|
107 |
+
speech_encoder_object = ContentVec768L12(device = device)
|
108 |
+
elif speech_encoder == "vec256l9":
|
109 |
+
from vencoder.ContentVec256L9 import ContentVec256L9
|
110 |
+
speech_encoder_object = ContentVec256L9(device = device)
|
111 |
+
elif speech_encoder == "vec256l9-onnx":
|
112 |
+
from vencoder.ContentVec256L9_Onnx import ContentVec256L9_Onnx
|
113 |
+
speech_encoder_object = ContentVec256L9_Onnx(device = device)
|
114 |
+
elif speech_encoder == "vec256l12-onnx":
|
115 |
+
from vencoder.ContentVec256L12_Onnx import ContentVec256L12_Onnx
|
116 |
+
speech_encoder_object = ContentVec256L12_Onnx(device = device)
|
117 |
+
elif speech_encoder == "vec768l9-onnx":
|
118 |
+
from vencoder.ContentVec768L9_Onnx import ContentVec768L9_Onnx
|
119 |
+
speech_encoder_object = ContentVec768L9_Onnx(device = device)
|
120 |
+
elif speech_encoder == "vec768l12-onnx":
|
121 |
+
from vencoder.ContentVec768L12_Onnx import ContentVec768L12_Onnx
|
122 |
+
speech_encoder_object = ContentVec768L12_Onnx(device = device)
|
123 |
+
elif speech_encoder == "hubertsoft-onnx":
|
124 |
+
from vencoder.HubertSoft_Onnx import HubertSoft_Onnx
|
125 |
+
speech_encoder_object = HubertSoft_Onnx(device = device)
|
126 |
+
elif speech_encoder == "hubertsoft":
|
127 |
+
from vencoder.HubertSoft import HubertSoft
|
128 |
+
speech_encoder_object = HubertSoft(device = device)
|
129 |
+
elif speech_encoder == "whisper-ppg":
|
130 |
+
from vencoder.WhisperPPG import WhisperPPG
|
131 |
+
speech_encoder_object = WhisperPPG(device = device)
|
132 |
+
elif speech_encoder == "cnhubertlarge":
|
133 |
+
from vencoder.CNHubertLarge import CNHubertLarge
|
134 |
+
speech_encoder_object = CNHubertLarge(device = device)
|
135 |
+
elif speech_encoder == "dphubert":
|
136 |
+
from vencoder.DPHubert import DPHubert
|
137 |
+
speech_encoder_object = DPHubert(device = device)
|
138 |
+
elif speech_encoder == "whisper-ppg-large":
|
139 |
+
from vencoder.WhisperPPGLarge import WhisperPPGLarge
|
140 |
+
speech_encoder_object = WhisperPPGLarge(device = device)
|
141 |
+
else:
|
142 |
+
raise Exception("Unknown speech encoder")
|
143 |
+
return speech_encoder_object
|
144 |
|
145 |
def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
|
146 |
assert os.path.isfile(checkpoint_path)
|
|
|
169 |
model.module.load_state_dict(new_state_dict)
|
170 |
else:
|
171 |
model.load_state_dict(new_state_dict)
|
172 |
+
print("load ")
|
173 |
logger.info("Loaded checkpoint '{}' (iteration {})".format(
|
174 |
checkpoint_path, iteration))
|
175 |
return model, optimizer, learning_rate, iteration
|
|
|
294 |
|
295 |
def get_hparams(init=True):
|
296 |
parser = argparse.ArgumentParser()
|
297 |
+
parser.add_argument('-c', '--config', type=str, default="./configs/config.json",
|
298 |
help='JSON file for configuration')
|
299 |
parser.add_argument('-m', '--model', type=str, required=True,
|
300 |
help='Model name')
|
|
|
337 |
with open(config_path, "r") as f:
|
338 |
data = f.read()
|
339 |
config = json.loads(data)
|
|
|
340 |
hparams =HParams(**config)
|
341 |
return hparams
|
342 |
|
|
|
393 |
return target
|
394 |
|
395 |
|
396 |
+
def mix_model(model_paths,mix_rate,mode):
|
397 |
+
mix_rate = torch.FloatTensor(mix_rate)/100
|
398 |
+
model_tem = torch.load(model_paths[0])
|
399 |
+
models = [torch.load(path)["model"] for path in model_paths]
|
400 |
+
if mode == 0:
|
401 |
+
mix_rate = F.softmax(mix_rate,dim=0)
|
402 |
+
for k in model_tem["model"].keys():
|
403 |
+
model_tem["model"][k] = torch.zeros_like(model_tem["model"][k])
|
404 |
+
for i,model in enumerate(models):
|
405 |
+
model_tem["model"][k] += model[k]*mix_rate[i]
|
406 |
+
torch.save(model_tem,os.path.join(os.path.curdir,"output.pth"))
|
407 |
+
return os.path.join(os.path.curdir,"output.pth")
|
408 |
+
|
409 |
+
def change_rms(data1, sr1, data2, sr2, rate): # 1是输入音频,2是输出音频,rate是2的占比 from RVC
|
410 |
+
# print(data1.max(),data2.max())
|
411 |
+
rms1 = librosa.feature.rms(
|
412 |
+
y=data1, frame_length=sr1 // 2 * 2, hop_length=sr1 // 2
|
413 |
+
) # 每半秒一个点
|
414 |
+
rms2 = librosa.feature.rms(y=data2.detach().cpu().numpy(), frame_length=sr2 // 2 * 2, hop_length=sr2 // 2)
|
415 |
+
rms1 = torch.from_numpy(rms1).to(data2.device)
|
416 |
+
rms1 = F.interpolate(
|
417 |
+
rms1.unsqueeze(0), size=data2.shape[0], mode="linear"
|
418 |
+
).squeeze()
|
419 |
+
rms2 = torch.from_numpy(rms2).to(data2.device)
|
420 |
+
rms2 = F.interpolate(
|
421 |
+
rms2.unsqueeze(0), size=data2.shape[0], mode="linear"
|
422 |
+
).squeeze()
|
423 |
+
rms2 = torch.max(rms2, torch.zeros_like(rms2) + 1e-6)
|
424 |
+
data2 *= (
|
425 |
+
torch.pow(rms1, torch.tensor(1 - rate))
|
426 |
+
* torch.pow(rms2, torch.tensor(rate - 1))
|
427 |
+
)
|
428 |
+
return data2
|
429 |
+
|
430 |
+
|
431 |
class HParams():
|
432 |
def __init__(self, **kwargs):
|
433 |
for k, v in kwargs.items():
|
|
|
459 |
def __repr__(self):
|
460 |
return self.__dict__.__repr__()
|
461 |
|
462 |
+
def get(self,index):
|
463 |
+
return self.__dict__.get(index)
|
464 |
+
|
465 |
+
class Volume_Extractor:
|
466 |
+
def __init__(self, hop_size = 512):
|
467 |
+
self.hop_size = hop_size
|
468 |
+
|
469 |
+
def extract(self, audio): # audio: 2d tensor array
|
470 |
+
if not isinstance(audio,torch.Tensor):
|
471 |
+
audio = torch.Tensor(audio)
|
472 |
+
n_frames = int(audio.size(-1) // self.hop_size)
|
473 |
+
audio2 = audio ** 2
|
474 |
+
audio2 = torch.nn.functional.pad(audio2, (int(self.hop_size // 2), int((self.hop_size + 1) // 2)), mode = 'reflect')
|
475 |
+
volume = torch.FloatTensor([torch.mean(audio2[:,int(n * self.hop_size) : int((n + 1) * self.hop_size)]) for n in range(n_frames)])
|
476 |
+
volume = torch.sqrt(volume)
|
477 |
+
return volume
|