|
import os
|
|
import sys
|
|
import torch
|
|
|
|
import numpy as np
|
|
import torch.nn.functional as F
|
|
|
|
from librosa.filters import mel
|
|
from torchaudio.transforms import Resample
|
|
|
|
sys.path.append(os.getcwd())
|
|
|
|
from main.library import opencl
|
|
from main.library.predictors.FCPE.stft import STFT
|
|
|
|
def spawn_wav2mel(args, device = None):
|
|
_type = args.mel.type
|
|
if (str(_type).lower() == 'none') or (str(_type).lower() == 'default'): _type = 'default'
|
|
elif str(_type).lower() == 'stft': _type = 'stft'
|
|
wav2mel = Wav2MelModule(sr=args.mel.sr, n_mels=args.mel.num_mels, n_fft=args.mel.n_fft, win_size=args.mel.win_size, hop_length=args.mel.hop_size, fmin=args.mel.fmin, fmax=args.mel.fmax, clip_val=1e-05, mel_type=_type)
|
|
|
|
return wav2mel.to(torch.device(device))
|
|
|
|
class HannWindow(torch.nn.Module):
|
|
def __init__(self, win_size):
|
|
super().__init__()
|
|
self.register_buffer('window', torch.hann_window(win_size), persistent=False)
|
|
|
|
def forward(self):
|
|
return self.window
|
|
|
|
class MelModule(torch.nn.Module):
|
|
def __init__(self, sr, n_mels, n_fft, win_size, hop_length, fmin = None, fmax = None, clip_val = 1e-5, out_stft = False):
|
|
super().__init__()
|
|
if fmin is None: fmin = 0
|
|
if fmax is None: fmax = sr / 2
|
|
self.target_sr = sr
|
|
self.n_mels = n_mels
|
|
self.n_fft = n_fft
|
|
self.win_size = win_size
|
|
self.hop_length = hop_length
|
|
self.fmin = fmin
|
|
self.fmax = fmax
|
|
self.clip_val = clip_val
|
|
self.register_buffer('mel_basis', torch.tensor(mel(sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)).float(), persistent=False)
|
|
self.hann_window = torch.nn.ModuleDict()
|
|
self.out_stft = out_stft
|
|
|
|
@torch.no_grad()
|
|
def __call__(self, y, key_shift = 0, speed = 1, center = False, no_cache_window = False):
|
|
n_fft = self.n_fft
|
|
win_size = self.win_size
|
|
hop_length = self.hop_length
|
|
clip_val = self.clip_val
|
|
factor = 2 ** (key_shift / 12)
|
|
n_fft_new = int(np.round(n_fft * factor))
|
|
win_size_new = int(np.round(win_size * factor))
|
|
hop_length_new = int(np.round(hop_length * speed))
|
|
|
|
y = y.squeeze(-1)
|
|
key_shift_key = str(key_shift)
|
|
|
|
if not no_cache_window:
|
|
if key_shift_key in self.hann_window: hann_window = self.hann_window[key_shift_key]
|
|
else:
|
|
hann_window = HannWindow(win_size_new).to(self.mel_basis.device)
|
|
self.hann_window[key_shift_key] = hann_window
|
|
|
|
hann_window_tensor = hann_window()
|
|
else: hann_window_tensor = torch.hann_window(win_size_new).to(self.mel_basis.device)
|
|
|
|
pad_left = (win_size_new - hop_length_new) // 2
|
|
pad_right = max((win_size_new - hop_length_new + 1) // 2, win_size_new - y.size(-1) - pad_left)
|
|
|
|
mode = 'reflect' if pad_right < y.size(-1) else 'constant'
|
|
pad = F.pad(y.unsqueeze(1), (pad_left, pad_right), mode=mode).squeeze(1)
|
|
|
|
if str(y.device).startswith("ocl"):
|
|
stft = opencl.STFT(filter_length=n_fft_new, hop_length=hop_length_new, win_length=win_size_new).to(y.device)
|
|
spec = stft.transform(pad, 1e-9)
|
|
else:
|
|
spec = torch.stft(pad, n_fft_new, hop_length=hop_length_new, win_length=win_size_new, window=hann_window_tensor, center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
|
|
spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + 1e-9)
|
|
|
|
if key_shift != 0:
|
|
size = n_fft // 2 + 1
|
|
resize = spec.size(1)
|
|
|
|
if resize < size: spec = F.pad(spec, (0, 0, 0, size - resize))
|
|
spec = spec[:, :size, :] * win_size / win_size_new
|
|
|
|
spec = spec[:, :512, :] if self.out_stft else torch.matmul(self.mel_basis, spec)
|
|
return torch.log(torch.clamp(spec, min=clip_val) * 1).transpose(-1, -2)
|
|
|
|
class Wav2MelModule(torch.nn.Module):
|
|
def __init__(self, sr, n_mels, n_fft, win_size, hop_length, fmin = None, fmax = None, clip_val = 1e-5, mel_type="default"):
|
|
super().__init__()
|
|
if fmin is None: fmin = 0
|
|
if fmax is None: fmax = sr / 2
|
|
self.sampling_rate = sr
|
|
self.n_mels = n_mels
|
|
self.n_fft = n_fft
|
|
self.win_size = win_size
|
|
self.hop_size = hop_length
|
|
self.fmin = fmin
|
|
self.fmax = fmax
|
|
self.clip_val = clip_val
|
|
self.register_buffer('tensor_device_marker', torch.tensor(1.0).float(), persistent=False)
|
|
self.resample_kernel = torch.nn.ModuleDict()
|
|
if mel_type == "default": self.mel_extractor = MelModule(sr, n_mels, n_fft, win_size, hop_length, fmin, fmax, clip_val, out_stft=False)
|
|
elif mel_type == "stft": self.mel_extractor = MelModule(sr, n_mels, n_fft, win_size, hop_length, fmin, fmax, clip_val, out_stft=True)
|
|
self.mel_type = mel_type
|
|
|
|
@torch.no_grad()
|
|
def __call__(self, audio, sample_rate, keyshift = 0, no_cache_window = False):
|
|
if sample_rate == self.sampling_rate: audio_res = audio
|
|
else:
|
|
key_str = str(sample_rate)
|
|
if key_str not in self.resample_kernel:
|
|
if len(self.resample_kernel) > 8: self.resample_kernel.clear()
|
|
self.resample_kernel[key_str] = Resample(sample_rate, self.sampling_rate, lowpass_filter_width=128).to(self.tensor_device_marker.device)
|
|
|
|
audio_res = self.resample_kernel[key_str](audio.squeeze(-1)).unsqueeze(-1)
|
|
|
|
mel = self.mel_extractor(audio_res, keyshift, no_cache_window=no_cache_window)
|
|
n_frames = int(audio.shape[1] // self.hop_size) + 1
|
|
if n_frames > int(mel.shape[1]): mel = torch.cat((mel, mel[:, -1:, :]), 1)
|
|
if n_frames < int(mel.shape[1]): mel = mel[:, :n_frames, :]
|
|
|
|
return mel
|
|
|
|
class Wav2Mel:
|
|
def __init__(self, device=None, dtype=torch.float32):
|
|
self.sample_rate = 16000
|
|
self.hop_size = 160
|
|
if device is None: device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
self.device = device
|
|
self.dtype = dtype
|
|
self.stft = STFT(16000, 128, 1024, 1024, 160, 0, 8000)
|
|
self.resample_kernel = {}
|
|
|
|
def extract_nvstft(self, audio, keyshift=0, train=False):
|
|
return self.stft.get_mel(audio, keyshift=keyshift, train=train).transpose(1, 2)
|
|
|
|
def extract_mel(self, audio, sample_rate, keyshift=0, train=False):
|
|
audio = audio.to(self.dtype).to(self.device)
|
|
if sample_rate == self.sample_rate: audio_res = audio
|
|
else:
|
|
key_str = str(sample_rate)
|
|
if key_str not in self.resample_kernel: self.resample_kernel[key_str] = Resample(sample_rate, self.sample_rate, lowpass_filter_width=128)
|
|
self.resample_kernel[key_str] = (self.resample_kernel[key_str].to(self.dtype).to(self.device))
|
|
audio_res = self.resample_kernel[key_str](audio)
|
|
|
|
mel = self.extract_nvstft(audio_res, keyshift=keyshift, train=train)
|
|
n_frames = int(audio.shape[1] // self.hop_size) + 1
|
|
mel = (torch.cat((mel, mel[:, -1:, :]), 1) if n_frames > int(mel.shape[1]) else mel)
|
|
return mel[:, :n_frames, :] if n_frames < int(mel.shape[1]) else mel
|
|
|
|
def __call__(self, audio, sample_rate, keyshift=0, train=False):
|
|
return self.extract_mel(audio, sample_rate, keyshift=keyshift, train=train) |