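"""Mel-spectrogram front ends for the FCPE pitch predictor: a log-mel / raw-STFT
feature extractor (MelModule), a resampling wrapper around it (Wav2MelModule),
and a fixed 16 kHz extractor (Wav2Mel) built on the project's own STFT."""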
import os
import sys
import torch
import numpy as np
import torch.nn.functional as F
from librosa.filters import mel
from torchaudio.transforms import Resample
sys.path.append(os.getcwd())
from main.library import opencl
from main.library.predictors.FCPE.stft import STFT

def spawn_wav2mel(args, device=None):
    # 'none' is treated as an alias for the default log-mel front end;
    # any other unrecognized type is passed through and rejected by Wav2MelModule.
    _type = str(args.mel.type).lower()
    if _type == 'none':
        _type = 'default'

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    wav2mel = Wav2MelModule(
        sr=args.mel.sr,
        n_mels=args.mel.num_mels,
        n_fft=args.mel.n_fft,
        win_size=args.mel.win_size,
        hop_length=args.mel.hop_size,
        fmin=args.mel.fmin,
        fmax=args.mel.fmax,
        clip_val=1e-05,
        mel_type=_type,
    )
    return wav2mel.to(torch.device(device))
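
# Hedged usage sketch (illustrative config values, not from the original file):
# spawn_wav2mel expects a config object exposing a nested `mel` namespace.
#
#   from types import SimpleNamespace
#   mel_cfg = SimpleNamespace(type="default", sr=16000, num_mels=128, n_fft=1024,
#                             win_size=1024, hop_size=160, fmin=0, fmax=8000)
#   wav2mel = spawn_wav2mel(SimpleNamespace(mel=mel_cfg), device="cpu")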

class HannWindow(torch.nn.Module):
    # Wraps a Hann window in a Module so that windows cached in a ModuleDict
    # follow .to(device) moves; persistent=False keeps them out of the state dict.
    def __init__(self, win_size):
        super().__init__()
        self.register_buffer('window', torch.hann_window(win_size), persistent=False)

    def forward(self):
        return self.window

class MelModule(torch.nn.Module):
    def __init__(self, sr, n_mels, n_fft, win_size, hop_length, fmin=None, fmax=None, clip_val=1e-5, out_stft=False):
        super().__init__()
        if fmin is None:
            fmin = 0
        if fmax is None:
            fmax = sr / 2

        self.target_sr = sr
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.win_size = win_size
        self.hop_length = hop_length
        self.fmin = fmin
        self.fmax = fmax
        self.clip_val = clip_val
        # Precompute the librosa mel filterbank once; as a non-persistent buffer it
        # moves with .to(device) but is not checkpointed.
        self.register_buffer('mel_basis', torch.tensor(mel(sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)).float(), persistent=False)
        # Hann windows are cached per key shift (see __call__).
        self.hann_window = torch.nn.ModuleDict()
        self.out_stft = out_stft

    @torch.no_grad()
    def __call__(self, y, key_shift=0, speed=1, center=False, no_cache_window=False):
        n_fft = self.n_fft
        win_size = self.win_size
        hop_length = self.hop_length
        clip_val = self.clip_val

        # A key shift of k semitones scales the FFT and window sizes by 2 ** (k / 12);
        # speed scales the hop length so the frame rate tracks the time stretch.
        factor = 2 ** (key_shift / 12)
        n_fft_new = int(np.round(n_fft * factor))
        win_size_new = int(np.round(win_size * factor))
        hop_length_new = int(np.round(hop_length * speed))

        y = y.squeeze(-1)
        key_shift_key = str(key_shift)

        if not no_cache_window:
            # Cache one Hann window per key shift; as submodules they follow device moves.
            if key_shift_key in self.hann_window:
                hann_window = self.hann_window[key_shift_key]
            else:
                hann_window = HannWindow(win_size_new).to(self.mel_basis.device)
                self.hann_window[key_shift_key] = hann_window

            hann_window_tensor = hann_window()
        else:
            hann_window_tensor = torch.hann_window(win_size_new).to(self.mel_basis.device)

        # Pad manually so frames are centered, falling back to constant padding
        # when the signal is too short for reflection.
        pad_left = (win_size_new - hop_length_new) // 2
        pad_right = max((win_size_new - hop_length_new + 1) // 2, win_size_new - y.size(-1) - pad_left)
        mode = 'reflect' if pad_right < y.size(-1) else 'constant'
        pad = F.pad(y.unsqueeze(1), (pad_left, pad_right), mode=mode).squeeze(1)

        if str(y.device).startswith("ocl"):
            # OpenCL devices go through the project's own STFT implementation.
            stft = opencl.STFT(filter_length=n_fft_new, hop_length=hop_length_new, win_length=win_size_new).to(y.device)
            spec = stft.transform(pad, 1e-9)
        else:
            spec = torch.stft(pad, n_fft_new, hop_length=hop_length_new, win_length=win_size_new, window=hann_window_tensor, center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
            spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + 1e-9)

        if key_shift != 0:
            # Crop or zero-pad the shifted spectrum back to the unshifted bin count,
            # then compensate for the window-length change.
            size = n_fft // 2 + 1
            resize = spec.size(1)

            if resize < size:
                spec = F.pad(spec, (0, 0, 0, size - resize))

            spec = spec[:, :size, :] * win_size / win_size_new

        # Return the first 512 STFT bins or project onto the mel filterbank,
        # then log-compress and move to (batch, frames, bins).
        spec = spec[:, :512, :] if self.out_stft else torch.matmul(self.mel_basis, spec)
        return torch.log(torch.clamp(spec, min=clip_val)).transpose(-1, -2)
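
# Hedged usage sketch (illustrative values, not from the original file):
#
#   mel_module = MelModule(sr=16000, n_mels=128, n_fft=1024, win_size=1024, hop_length=160)
#   y = torch.randn(1, 16000, 1)   # (batch, samples, 1), one second at 16 kHz
#   log_mel = mel_module(y)        # -> (1, frames, 128) log-mel spectrogram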

class Wav2MelModule(torch.nn.Module):
    def __init__(self, sr, n_mels, n_fft, win_size, hop_length, fmin=None, fmax=None, clip_val=1e-5, mel_type="default"):
        super().__init__()
        if fmin is None:
            fmin = 0
        if fmax is None:
            fmax = sr / 2

        self.sampling_rate = sr
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.win_size = win_size
        self.hop_size = hop_length
        self.fmin = fmin
        self.fmax = fmax
        self.clip_val = clip_val
        # Dummy buffer whose only job is to report which device the module lives on.
        self.register_buffer('tensor_device_marker', torch.tensor(1.0).float(), persistent=False)
        self.resample_kernel = torch.nn.ModuleDict()

        if mel_type == "default":
            self.mel_extractor = MelModule(sr, n_mels, n_fft, win_size, hop_length, fmin, fmax, clip_val, out_stft=False)
        elif mel_type == "stft":
            self.mel_extractor = MelModule(sr, n_mels, n_fft, win_size, hop_length, fmin, fmax, clip_val, out_stft=True)
        else:
            raise ValueError(f"Unknown mel_type: {mel_type}")

        self.mel_type = mel_type

    @torch.no_grad()
    def __call__(self, audio, sample_rate, keyshift=0, no_cache_window=False):
        # audio: (batch, samples, 1). Resample to the target rate if needed,
        # caching one kernel per source sample rate.
        if sample_rate == self.sampling_rate:
            audio_res = audio
        else:
            key_str = str(sample_rate)
            if key_str not in self.resample_kernel:
                if len(self.resample_kernel) > 8:
                    # Bound the cache so unusual sample rates cannot grow it without limit.
                    self.resample_kernel.clear()

                self.resample_kernel[key_str] = Resample(sample_rate, self.sampling_rate, lowpass_filter_width=128).to(self.tensor_device_marker.device)

            audio_res = self.resample_kernel[key_str](audio.squeeze(-1)).unsqueeze(-1)

        mel = self.mel_extractor(audio_res, keyshift, no_cache_window=no_cache_window)

        # Pad or trim to the frame count implied by the input length and hop size.
        n_frames = int(audio.shape[1] // self.hop_size) + 1
        if n_frames > int(mel.shape[1]):
            mel = torch.cat((mel, mel[:, -1:, :]), 1)
        if n_frames < int(mel.shape[1]):
            mel = mel[:, :n_frames, :]

        return mel
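
# Hedged usage sketch (illustrative values, not from the original file):
#
#   wav2mel = Wav2MelModule(sr=16000, n_mels=128, n_fft=1024, win_size=1024, hop_length=160)
#   audio = torch.randn(1, 44100, 1)            # (batch, samples, 1) at an arbitrary source rate
#   mel = wav2mel(audio, sample_rate=44100)     # resampled to 16 kHz internally -> (1, frames, 128)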

class Wav2Mel:
    def __init__(self, device=None, dtype=torch.float32):
        # Fixed 16 kHz front end with a 160-sample hop.
        self.sample_rate = 16000
        self.hop_size = 160

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        self.device = device
        self.dtype = dtype
        self.stft = STFT(16000, 128, 1024, 1024, 160, 0, 8000)
        self.resample_kernel = {}

    def extract_nvstft(self, audio, keyshift=0, train=False):
        # (batch, n_mels, frames) -> (batch, frames, n_mels)
        return self.stft.get_mel(audio, keyshift=keyshift, train=train).transpose(1, 2)

    def extract_mel(self, audio, sample_rate, keyshift=0, train=False):
        audio = audio.to(self.dtype).to(self.device)

        # Resample to 16 kHz if needed, caching one kernel per source rate.
        if sample_rate == self.sample_rate:
            audio_res = audio
        else:
            key_str = str(sample_rate)
            if key_str not in self.resample_kernel:
                self.resample_kernel[key_str] = Resample(sample_rate, self.sample_rate, lowpass_filter_width=128)

            # Keep the cached kernel on the current dtype/device before applying it.
            self.resample_kernel[key_str] = self.resample_kernel[key_str].to(self.dtype).to(self.device)
            audio_res = self.resample_kernel[key_str](audio)

        mel = self.extract_nvstft(audio_res, keyshift=keyshift, train=train)

        # Pad or trim to the frame count implied by the input length and hop size.
        n_frames = int(audio.shape[1] // self.hop_size) + 1
        if n_frames > int(mel.shape[1]):
            mel = torch.cat((mel, mel[:, -1:, :]), 1)

        return mel[:, :n_frames, :] if n_frames < int(mel.shape[1]) else mel

    def __call__(self, audio, sample_rate, keyshift=0, train=False):
        return self.extract_mel(audio, sample_rate, keyshift=keyshift, train=train)
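
# Hedged smoke test (not part of the original module): exercises the fixed 16 kHz
# Wav2Mel front end end to end. Assumes this file is run from the repository root
# so the main.library imports above resolve.
if __name__ == "__main__":
    wav2mel = Wav2Mel(device="cpu")
    audio = torch.randn(1, 16000)            # one second of mono audio at 16 kHz, (batch, samples)
    mel = wav2mel(audio, sample_rate=16000)
    print(mel.shape)                         # expected shape: (batch, frames, n_mels)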