Commit 969cb52 · Parent(s): 99cc645
change mel func
orator/src/orator/models/voice_encoder/melspec.py
CHANGED
@@ -1,75 +1,78 @@
from functools import lru_cache

+from scipy import signal
import numpy as np
-import torch
-from torchaudio.transforms import MelSpectrogram
-
-from .config import VoiceEncConfig
-
-
-class ResembleMelSpectrogram(torch.nn.Module):
-    def __init__(self, hp=VoiceEncConfig()):
-        """
-        Torch implementation of Resemble's mel extraction.
-        Note that the values are NOT identical to librosa's implementation due to floating point precisions, however
-        the results are very very close. One test file gave an L1 error of just 0.005%, full results:
-            Librosa mel max: 0.871768
-            Torch mel max: 0.871768
-            Librosa mel mean: 0.316302
-            Torch mel mean: 0.316289
-            Max diff: 0.061105
-            Mean diff: 1.453384e-05
-            Percent error: 0.004595%
-        """
-        super().__init__()
-        self.melspec = MelSpectrogram(
-            hp.sample_rate,
-            n_fft=hp.n_fft,
-            win_length=hp.win_size,
-            hop_length=hp.hop_size,
-            f_min=hp.fmin,
-            f_max=hp.fmax,
-            n_mels=hp.num_mels,
-            power=1,
-            normalized=False,
-            # NOTE: Folowing librosa's default.
-            pad_mode="constant",
-            norm="slaney",
-            mel_scale="slaney",
-        )
-        self.register_buffer(
-            "stft_magnitude_min",
-            torch.FloatTensor([hp.stft_magnitude_min])
-        )
-        self.min_level_db = 20 * np.log10(hp.stft_magnitude_min)
-        self.preemphasis = hp.preemphasis
-        self.hop_size = hp.hop_size
-
-    def forward(self, wav, pad=True):
-        """
-        Args:
-            wav: [B, T]
-        """
-        if self.preemphasis > 0:
-            wav = torch.nn.functional.pad(wav, [1, 0], value=0)
-            wav = wav[..., 1:] - self.preemphasis * wav[..., :-1]
-
-        mel = self.melspec(wav)
-
-        mel = self._amp_to_db(mel)
-        mel_normed = self._normalize(mel)
-        assert not pad or mel_normed.shape[-1] == 1 + \
-            wav.shape[-1] // self.hop_size  # Sanity check
-        return mel_normed  # (M, T)
-
-    def _normalize(self, s, headroom_db=15):
-        s = (s - self.min_level_db) / (-self.min_level_db + headroom_db)
-        return s
-
-    def _amp_to_db(self, x):
-        return 20 * torch.maximum(self.stft_magnitude_min, x).log10()
+import librosa


@lru_cache()
-def
-
+def mel_basis(hp):
+    assert hp.fmax <= hp.sample_rate // 2
+    return librosa.filters.mel(
+        sr=hp.sample_rate,
+        n_fft=hp.n_fft,
+        n_mels=hp.num_mels,
+        fmin=hp.fmin,
+        fmax=hp.fmax)  # -> (nmel, nfreq)
+
+
+def preemphasis(wav, hp):
+    assert hp.preemphasis != 0
+    wav = signal.lfilter([1, -hp.preemphasis], [1], wav)
+    wav = np.clip(wav, -1, 1)
+    return wav
+
+
+def melspectrogram(wav, hp, pad=True):
+    # Run through pre-emphasis
+    if hp.preemphasis > 0:
+        wav = preemphasis(wav, hp)
+        assert np.abs(wav).max() - 1 < 1e-07
+
+    # Do the stft
+    spec_complex = _stft(wav, hp, pad=pad)
+
+    # Get the magnitudes
+    spec_magnitudes = np.abs(spec_complex)
+
+    if hp.mel_power != 1.0:
+        spec_magnitudes **= hp.mel_power
+
+    # Get the mel and convert magnitudes->db
+    mel = np.dot(mel_basis(hp), spec_magnitudes)
+    if hp.mel_type == "db":
+        mel = _amp_to_db(mel, hp)
+
+    # Normalise the mel from db to 0,1
+    if hp.normalized_mels:
+        mel = _normalize(mel, hp).astype(np.float32)
+
+    assert not pad or mel.shape[1] == 1 + len(wav) // hp.hop_size  # Sanity check
+    return mel  # (M, T)
+
+
+def _stft(y, hp, pad=True):
+    # NOTE: after 0.8, pad mode defaults to constant, setting this to reflect for
+    # historical consistency and streaming-version consistency
+    return librosa.stft(
+        y,
+        n_fft=hp.n_fft,
+        hop_length=hp.hop_size,
+        win_length=hp.win_size,
+        center=pad,
+        pad_mode="reflect",
+    )
+
+
+def _amp_to_db(x, hp):
+    return 20 * np.log10(np.maximum(hp.stft_magnitude_min, x))
+
+
+def _db_to_amp(x):
+    return np.power(10.0, x * 0.05)
+
+
+def _normalize(s, hp, headroom_db=15):
+    min_level_db = 20 * np.log10(hp.stft_magnitude_min)
+    s = (s - min_level_db) / (-min_level_db + headroom_db)
+    return s
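For reference, a minimal sketch of how the new librosa-based melspectrogram pipeline can be exercised. The import paths and the use of VoiceEncConfig (imported in the previous version of this file) are assumptions, not part of the commit; it also assumes the config exposes the fields the new functions read (sample_rate, n_fft, win_size, hop_size, fmin, fmax, num_mels, preemphasis, mel_power, mel_type, normalized_mels, stft_magnitude_min).

import numpy as np

from orator.models.voice_encoder.config import VoiceEncConfig   # assumed import path
from orator.models.voice_encoder.melspec import melspectrogram  # assumed import path

hp = VoiceEncConfig()

# One second of low-amplitude noise as a stand-in for real speech.
wav = np.random.uniform(-0.5, 0.5, size=hp.sample_rate).astype(np.float32)

mel = melspectrogram(wav, hp)  # (num_mels, T)

# Frame count matches the sanity check inside melspectrogram() when pad=True.
assert mel.shape[0] == hp.num_mels
assert mel.shape[1] == 1 + len(wav) // hp.hop_size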
orator/src/orator/models/voice_encoder/voice_encoder.py
CHANGED
@@ -269,8 +269,6 @@ class VoiceEncoder(nn.Module):
        if "rate" not in kwargs:
            kwargs["rate"] = 1.3  # Resemble's default value.

-
-        mels = [mel_func(torch.from_numpy(w)
-                         [None])[0].T for w in wavs]
+        mels = [melspectrogram(w, self.hp).T for w in wavs]

        return self.embeds_from_mels(mels, as_spk=as_spk, batch_size=batch_size, **kwargs)
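The voice_encoder.py change swaps the torch mel_func path for the new numpy pipeline. A hedged sketch of what that list comprehension produces, reproduced outside the class (hp stands in for the encoder's self.hp; import paths are assumptions):

import numpy as np

from orator.models.voice_encoder.config import VoiceEncConfig   # assumed import path
from orator.models.voice_encoder.melspec import melspectrogram  # assumed import path

hp = VoiceEncConfig()

# Two dummy utterances of different lengths, as float32 mono waveforms.
wavs = [np.zeros(1 * hp.sample_rate, dtype=np.float32),
        np.zeros(2 * hp.sample_rate, dtype=np.float32)]

# melspectrogram() returns (num_mels, T); the .T mirrors the old
# mel_func(...)[0].T convention, so each entry is (T, num_mels) frames
# that are then passed to embeds_from_mels().
mels = [melspectrogram(w, hp).T for w in wavs]
assert all(m.shape[1] == hp.num_mels for m in mels)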