import os
import sys
import torch
import numpy as np
import torch.nn.functional as F
from librosa.filters import mel
from torchaudio.transforms import Resample
sys.path.append(os.getcwd())
from main.library import opencl
from main.library.predictors.FCPE.stft import STFT

def spawn_wav2mel(args, device=None):
    # Normalize the configured mel type; 'none' falls back to 'default'.
    _type = str(args.mel.type).lower()
    if _type == 'none':
        _type = 'default'
    wav2mel = Wav2MelModule(sr=args.mel.sr, n_mels=args.mel.num_mels, n_fft=args.mel.n_fft, win_size=args.mel.win_size, hop_length=args.mel.hop_size, fmin=args.mel.fmin, fmax=args.mel.fmax, clip_val=1e-05, mel_type=_type)
    if device is None:
        # torch.device(None) raises, so pick a default the same way Wav2Mel does below.
        device = "cuda" if torch.cuda.is_available() else "cpu"
    return wav2mel.to(torch.device(device))
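
# Usage sketch for spawn_wav2mel (an illustration, not part of the original file;
# the config layout is inferred from the attribute accesses above and the values
# are hypothetical):
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(mel=SimpleNamespace(type="default", sr=16000, num_mels=128,
#       n_fft=1024, win_size=1024, hop_size=160, fmin=0, fmax=8000))
#   wav2mel = spawn_wav2mel(args, device="cpu")
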
class HannWindow(torch.nn.Module):
    """Holds a Hann window of fixed size as a non-persistent buffer so it follows .to(device)."""

    def __init__(self, win_size):
        super().__init__()
        self.register_buffer('window', torch.hann_window(win_size), persistent=False)

    def forward(self):
        return self.window

class MelModule(torch.nn.Module):
    """Turns a waveform into a log-mel spectrogram (or a raw STFT magnitude when out_stft=True)."""

    def __init__(self, sr, n_mels, n_fft, win_size, hop_length, fmin=None, fmax=None, clip_val=1e-5, out_stft=False):
        super().__init__()
        if fmin is None:
            fmin = 0
        if fmax is None:
            fmax = sr / 2
        self.target_sr = sr
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.win_size = win_size
        self.hop_length = hop_length
        self.fmin = fmin
        self.fmax = fmax
        self.clip_val = clip_val
        # librosa mel filterbank as a non-persistent buffer so it moves with the module.
        self.register_buffer('mel_basis', torch.tensor(mel(sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)).float(), persistent=False)
        # Per-key-shift Hann window cache, populated lazily in __call__.
        self.hann_window = torch.nn.ModuleDict()
        self.out_stft = out_stft

    @torch.no_grad()
    def __call__(self, y, key_shift=0, speed=1, center=False, no_cache_window=False):
        n_fft = self.n_fft
        win_size = self.win_size
        hop_length = self.hop_length
        clip_val = self.clip_val
        # Key shift scales the FFT/window sizes; speed scales the hop.
        factor = 2 ** (key_shift / 12)
        n_fft_new = int(np.round(n_fft * factor))
        win_size_new = int(np.round(win_size * factor))
        hop_length_new = int(np.round(hop_length * speed))
        y = y.squeeze(-1)
        key_shift_key = str(key_shift)
        if not no_cache_window:
            # Reuse (or lazily build) the Hann window cached for this key shift.
            if key_shift_key in self.hann_window:
                hann_window = self.hann_window[key_shift_key]
            else:
                hann_window = HannWindow(win_size_new).to(self.mel_basis.device)
                self.hann_window[key_shift_key] = hann_window
            hann_window_tensor = hann_window()
        else:
            hann_window_tensor = torch.hann_window(win_size_new).to(self.mel_basis.device)
        # Center the frames manually so the STFT can run with center=False.
        pad_left = (win_size_new - hop_length_new) // 2
        pad_right = max((win_size_new - hop_length_new + 1) // 2, win_size_new - y.size(-1) - pad_left)
        mode = 'reflect' if pad_right < y.size(-1) else 'constant'
        pad = F.pad(y.unsqueeze(1), (pad_left, pad_right), mode=mode).squeeze(1)
        if str(y.device).startswith("ocl"):
            # The OpenCL STFT returns the magnitude directly (eps folded in).
            stft = opencl.STFT(filter_length=n_fft_new, hop_length=hop_length_new, win_length=win_size_new).to(y.device)
            spec = stft.transform(pad, 1e-9)
        else:
            spec = torch.stft(pad, n_fft_new, hop_length=hop_length_new, win_length=win_size_new, window=hann_window_tensor, center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
            spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + 1e-9)
        if key_shift != 0:
            # Resize the shifted spectrum back to the nominal bin count and rescale the energy.
            size = n_fft // 2 + 1
            resize = spec.size(1)
            if resize < size:
                spec = F.pad(spec, (0, 0, 0, size - resize))
            spec = spec[:, :size, :] * win_size / win_size_new
        spec = spec[:, :512, :] if self.out_stft else torch.matmul(self.mel_basis, spec)
        return torch.log(torch.clamp(spec, min=clip_val)).transpose(-1, -2)
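
# Shape sketch for MelModule (assumed values for illustration): with sr=16000,
# n_fft=1024, win_size=1024, hop_length=160 and center=False, a (1, 16000) or
# (1, 16000, 1) waveform is padded by 432 samples on each side, so the output
# is (1, 100, 128) after the final transpose: frames on dim 1, mels on dim 2.
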
class Wav2MelModule(torch.nn.Module):
    """Resamples incoming audio to the target rate, then extracts (log-)mel features."""

    def __init__(self, sr, n_mels, n_fft, win_size, hop_length, fmin=None, fmax=None, clip_val=1e-5, mel_type="default"):
        super().__init__()
        if fmin is None:
            fmin = 0
        if fmax is None:
            fmax = sr / 2
        self.sampling_rate = sr
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.win_size = win_size
        self.hop_size = hop_length
        self.fmin = fmin
        self.fmax = fmax
        self.clip_val = clip_val
        # Dummy buffer used only to track which device the module lives on.
        self.register_buffer('tensor_device_marker', torch.tensor(1.0).float(), persistent=False)
        self.resample_kernel = torch.nn.ModuleDict()
        if mel_type == "default":
            self.mel_extractor = MelModule(sr, n_mels, n_fft, win_size, hop_length, fmin, fmax, clip_val, out_stft=False)
        elif mel_type == "stft":
            self.mel_extractor = MelModule(sr, n_mels, n_fft, win_size, hop_length, fmin, fmax, clip_val, out_stft=True)
        else:
            # Fail fast instead of hitting an AttributeError later in __call__.
            raise ValueError(f"Unknown mel_type: {mel_type}")
        self.mel_type = mel_type

    @torch.no_grad()
    def __call__(self, audio, sample_rate, keyshift=0, no_cache_window=False):
        if sample_rate == self.sampling_rate:
            audio_res = audio
        else:
            # Cache one resampler per source sample rate, capped at 8 entries.
            key_str = str(sample_rate)
            if key_str not in self.resample_kernel:
                if len(self.resample_kernel) > 8:
                    self.resample_kernel.clear()
                self.resample_kernel[key_str] = Resample(sample_rate, self.sampling_rate, lowpass_filter_width=128).to(self.tensor_device_marker.device)
            audio_res = self.resample_kernel[key_str](audio.squeeze(-1)).unsqueeze(-1)
        mel = self.mel_extractor(audio_res, keyshift, no_cache_window=no_cache_window)
        # Pad or trim to the expected frame count (one frame per hop, plus one).
        n_frames = int(audio.shape[1] // self.hop_size) + 1
        if n_frames > int(mel.shape[1]):
            mel = torch.cat((mel, mel[:, -1:, :]), 1)
        if n_frames < int(mel.shape[1]):
            mel = mel[:, :n_frames, :]
        return mel
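
# Usage sketch for Wav2MelModule (hypothetical values; the module resamples
# internally, so the input sample rate need not match `sr`):
#
#   m = Wav2MelModule(sr=16000, n_mels=128, n_fft=1024, win_size=1024, hop_length=160)
#   wav = torch.randn(1, 44100, 1)           # one second of audio at 44.1 kHz
#   mel = m(wav, sample_rate=44100)          # resampled to 16 kHz, then log-mel
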
class Wav2Mel:
    def __init__(self, device=None, dtype=torch.float32):
        self.sample_rate = 16000
        self.hop_size = 160
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.dtype = dtype
        # 16 kHz / 128 mels / 1024-point FFT and window / 160-sample hop / 0-8000 Hz band.
        self.stft = STFT(16000, 128, 1024, 1024, 160, 0, 8000)
        self.resample_kernel = {}

    def extract_nvstft(self, audio, keyshift=0, train=False):
        return self.stft.get_mel(audio, keyshift=keyshift, train=train).transpose(1, 2)

    def extract_mel(self, audio, sample_rate, keyshift=0, train=False):
        audio = audio.to(self.dtype).to(self.device)
        if sample_rate == self.sample_rate:
            audio_res = audio
        else:
            key_str = str(sample_rate)
            if key_str not in self.resample_kernel:
                self.resample_kernel[key_str] = Resample(sample_rate, self.sample_rate, lowpass_filter_width=128)
            self.resample_kernel[key_str] = self.resample_kernel[key_str].to(self.dtype).to(self.device)
            audio_res = self.resample_kernel[key_str](audio)
        mel = self.extract_nvstft(audio_res, keyshift=keyshift, train=train)
        # Pad or trim to the expected frame count, mirroring Wav2MelModule.
        n_frames = int(audio.shape[1] // self.hop_size) + 1
        if n_frames > int(mel.shape[1]):
            mel = torch.cat((mel, mel[:, -1:, :]), 1)
        if n_frames < int(mel.shape[1]):
            mel = mel[:, :n_frames, :]
        return mel

    def __call__(self, audio, sample_rate, keyshift=0, train=False):
        return self.extract_mel(audio, sample_rate, keyshift=keyshift, train=train)
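
if __name__ == "__main__":
    # Minimal CPU smoke test (an added sketch, not part of the original module):
    # push one second of noise through MelModule and inspect the output shape.
    # The hyperparameters below are illustrative assumptions.
    _mel_module = MelModule(sr=16000, n_mels=128, n_fft=1024, win_size=1024, hop_length=160)
    _wav = torch.randn(1, 16000)
    print(_mel_module(_wav).shape)  # expected: torch.Size([1, 100, 128])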