Spaces:
Paused
Paused
import logging | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import numpy as np | |
from huggingface_hub import PyTorchModelHubMixin | |
from .seq import BiGRU | |
from .deepunet import DeepUnet | |
from .mel import MelSpectrogram | |
from .constants import * | |
logger = logging.getLogger(__name__) | |
class RMVPE(nn.Module, PyTorchModelHubMixin):
    """Pitch (f0) estimator: mel spectrogram -> DeepUnet -> conv -> (BiGRU) -> salience.

    The network maps a mel spectrogram to a per-frame salience over pitch
    bins; `decode` turns that salience into f0 values via a locally weighted
    average of a cents table.

    Args:
        n_blocks: Forwarded to `DeepUnet`.
        n_gru: Number of GRU layers forwarded to `BiGRU`. When falsy, the
            recurrent layer is skipped and a plain linear head is used.
        kernel_size: Forwarded to `DeepUnet`.
        en_de_layers: Forwarded to `DeepUnet`.
        inter_layers: Forwarded to `DeepUnet`.
        in_channels: Forwarded to `DeepUnet`.
        en_out_channels: Forwarded to `DeepUnet`; also the number of input
            channels of the 3x3 convolution applied to the U-Net output.
    """

    def __init__(
        self,
        n_blocks: int,
        n_gru: int,
        kernel_size: int,
        en_de_layers=5,
        inter_layers=4,
        in_channels=1,
        en_out_channels=16,
    ):
        super().__init__()
        self.mel_extractor = MelSpectrogram(
            N_MELS, SAMPLE_RATE, WINDOW_LENGTH, HOP_LENGTH, None, MEL_FMIN, MEL_FMAX
        )
        self.unet = DeepUnet(
            kernel_size,
            n_blocks,
            en_de_layers,
            inter_layers,
            in_channels,
            en_out_channels,
        )
        self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
        if n_gru:
            self.fc = nn.Sequential(
                BiGRU(3 * N_MELS, 256, n_gru),
                nn.Linear(512, N_CLASS),
                nn.Dropout(0.25),
                nn.Sigmoid(),
            )
        else:
            self.fc = nn.Sequential(
                nn.Linear(3 * N_MELS, N_CLASS), nn.Dropout(0.25), nn.Sigmoid()
            )

        # Cents value of each of the 360 pitch bins, zero-padded by 4 on both
        # sides (368 entries) so a 9-bin window centred on any bin stays in
        # range.
        cents_mapping = 20 * np.arange(360) + MAGIC_CONST
        self.cents_mapping = np.pad(cents_mapping, (4, 4))  # 368
        # Kept as a plain float32 attribute (not a parameter or buffer), so it
        # is not part of the state dict and is not touched by dtype casts of
        # the module.
        self.cents_mapping_torch = torch.from_numpy(self.cents_mapping).to(
            dtype=torch.float32
        )

    def to(self, *args, **kwargs):
        """Move/cast the module like `nn.Module.to` and keep the cents table in sync.

        Accepts every call form of `nn.Module.to` (the previous single
        `device` argument still works). The cents table follows the device of
        the module's parameters but always stays float32.
        """
        module = super().to(*args, **kwargs)
        first_param = next(self.parameters(), None)
        if first_param is not None:
            self.cents_mapping_torch = self.cents_mapping_torch.to(first_param.device)
        return module

    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        """Compute the salience for a mel spectrogram.

        The last two dimensions of `mel` are swapped and a channel dimension
        is inserted at index 1 before the U-Net. The conv output is then
        rearranged so its channel and last dimensions are flattened into one
        feature vector per position along dim 1, which is fed to `self.fc`.
        Presumably `mel` is (batch, n_mels, frames) and the result is
        (batch, frames, N_CLASS) - TODO confirm against `DeepUnet`.
        """
        mel = mel.transpose(-1, -2).unsqueeze(1)
        features = self.cnn(self.unet(mel))
        features = features.transpose(1, 2).flatten(-2)
        return self.fc(features)

    def mel2hidden(self, mel: torch.Tensor):
        """Run the network without gradients on a mel of arbitrary length.

        `mel.shape[2]` is taken as the frame count. The last dimension is
        zero-padded up to the next multiple of 32 before the forward pass and
        the output is cropped back to the original frame count along dim 1.
        NOTE(review): 32 presumably matches the U-Net's total downsampling for
        the default `en_de_layers=5` - confirm before using other depths.
        """
        with torch.no_grad():
            n_frames = mel.shape[2]
            n_pad = 32 * ((n_frames - 1) // 32 + 1) - n_frames
            mel = F.pad(mel, (0, n_pad), mode="constant")
            hidden = self(mel)
            return hidden[:, :n_frames]

    def decode(self, hidden: torch.Tensor, thred=0.03):
        """Convert a salience matrix into f0 values.

        Cents are mapped with ``f0 = 10 * 2 ** (cents / 1200)``. Frames whose
        cents value is 0 (rejected by the threshold) evaluate to exactly 10
        and are reported as 0.
        """
        cents_pred = self.to_local_average_cents(hidden, thred=thred)
        f0 = 10 * (2 ** (cents_pred / 1200))
        f0[f0 == 10] = 0
        return f0

    def infer(self, audio: torch.Tensor, thred=0.03, return_tensor=False):
        """Estimate f0 for a single audio tensor.

        A batch dimension is added before mel extraction, and the first batch
        item of the salience is decoded as float32.

        Args:
            audio: Audio samples as a tensor, without a batch dimension.
            thred: Salience threshold passed to `decode`.
            return_tensor: If true, return the f0 tensor as is; otherwise
                return it as a NumPy array on the CPU.
        """
        mel = self.mel_extractor(audio.unsqueeze(0))
        hidden = self.mel2hidden(mel)
        hidden = hidden[0].float()
        f0 = self.decode(hidden, thred=thred)
        if return_tensor:
            return f0
        return f0.cpu().numpy()

    def infer_from_audio(self, audio: np.ndarray, thred=0.03):
        """Estimate f0 from a NumPy audio array and return a NumPy array.

        The array is moved to the device of the model's parameters first.
        """
        device = next(self.parameters()).device
        audio = torch.from_numpy(audio).to(device)
        return self.infer(audio, thred=thred)

    def to_local_average_cents(
        self, salience: torch.Tensor, thred=0.05
    ) -> torch.Tensor:
        """Return one cents value per row of a 2-D salience matrix.

        For each row, the 9 bins centred on the arg-max bin are averaged,
        weighted by their salience. Rows whose maximum salience is not above
        `thred` yield 0.

        Args:
            salience: 2-D tensor, one row per frame and one column per pitch
                bin. The cents table has 360 bins, so 360 columns are
                expected.
            thred: Minimum peak salience for a row to be given a pitch.
        """
        # Non in-place add: do not mutate the argmax result.
        center = torch.argmax(salience, dim=1) + 4
        # Pad by 4 bins per side so the window never leaves the tensor.
        salience = F.pad(salience, (4, 4))

        # The table is a plain attribute, so it can be left behind when the
        # module is moved with .cuda()/.cpu(); bring it to the data's device.
        mapping = self.cents_mapping_torch
        if mapping.device != salience.device:
            mapping = mapping.to(salience.device)
            self.cents_mapping_torch = mapping

        rows = torch.arange(salience.shape[0], device=salience.device)
        # Indices of the 9-point window around each centre.
        offsets = torch.arange(-4, 5, device=salience.device)
        indices = center.unsqueeze(1) + offsets.unsqueeze(0)

        # Gather the windowed salience and the matching cents values.
        todo_salience = salience[rows.unsqueeze(1), indices]
        todo_cents_mapping = mapping[indices]

        product_sum = torch.sum(todo_salience * todo_cents_mapping, 1)
        weight_sum = torch.sum(todo_salience, 1)
        divided = product_sum / weight_sum

        maxx = torch.max(salience, 1).values
        divided[maxx <= thred] = 0
        return divided