import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from huggingface_hub import PyTorchModelHubMixin

from .seq import BiGRU
from .deepunet import DeepUnet
from .mel import MelSpectrogram
from .constants import *

logger = logging.getLogger(__name__)
class RMVPE(nn.Module, PyTorchModelHubMixin):
    """Pitch (f0) estimator: mel spectrogram -> DeepUnet -> per-frame
    salience over 360 pitch bins spaced 20 cents apart."""

    def __init__(
        self,
        n_blocks: int,
        n_gru: int,
        kernel_size: int,
        en_de_layers=5,
        inter_layers=4,
        in_channels=1,
        en_out_channels=16,
    ):
        super().__init__()
        self.mel_extractor = MelSpectrogram(
            N_MELS, SAMPLE_RATE, WINDOW_LENGTH, HOP_LENGTH, None, MEL_FMIN, MEL_FMAX
        )
        self.unet = DeepUnet(
            kernel_size,
            n_blocks,
            en_de_layers,
            inter_layers,
            in_channels,
            en_out_channels,
        )
        self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
        if n_gru:
            self.fc = nn.Sequential(
                BiGRU(3 * N_MELS, 256, n_gru),
                nn.Linear(512, N_CLASS),
                nn.Dropout(0.25),
                nn.Sigmoid(),
            )
        else:
            self.fc = nn.Sequential(
                nn.Linear(3 * N_MELS, N_CLASS), nn.Dropout(0.25), nn.Sigmoid()
            )
        # 360 bins on a 20-cent grid offset by MAGIC_CONST, padded by 4 on
        # each side for the local-averaging window in to_local_average_cents.
        cents_mapping = 20 * np.arange(360) + MAGIC_CONST
        self.cents_mapping = np.pad(cents_mapping, (4, 4))  # 368
        self.cents_mapping_torch = torch.from_numpy(self.cents_mapping).to(
            dtype=torch.float32
        )
    def to(self, device):
        # Keep the cents mapping on the same device as the model; it is a
        # plain tensor attribute, not a registered buffer, so nn.Module.to
        # would not move it on its own.
        self.cents_mapping_torch = self.cents_mapping_torch.to(device)
        return super().to(device)
    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        # (B, n_mels, T) -> (B, 1, T, n_mels) for the 2D U-Net.
        mel = mel.transpose(-1, -2).unsqueeze(1)
        # (B, 3, T, n_mels) -> (B, T, 3 * n_mels), then per-frame salience.
        x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
        x = self.fc(x)
        return x
    def mel2hidden(self, mel: torch.Tensor):
        with torch.no_grad():
            n_frames = mel.shape[2]
            # Pad the time axis up to a multiple of 32 so the U-Net's
            # down/upsampling path sees a compatible length
            # (e.g. 100 frames -> 28 frames of padding).
            n_pad = 32 * ((n_frames - 1) // 32 + 1) - n_frames
            mel = F.pad(mel, (0, n_pad), mode="constant")
            hidden = self(mel)
            return hidden[:, :n_frames]
    def decode(self, hidden: torch.Tensor, thred=0.03):
        cents_pred = self.to_local_average_cents(hidden, thred=thred)
        # Cents relative to a 10 Hz reference: f0 = 10 * 2^(cents / 1200).
        f0 = 10 * (2 ** (cents_pred / 1200))
        # Unvoiced frames decode to cents 0, i.e. exactly 10 Hz; zero them out.
        f0[f0 == 10] = 0
        return f0
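    # A worked example of the mapping in `decode` (illustrative values only,
    # not model output): cents_pred = 5700 gives
    # f0 = 10 * 2 ** (5700 / 1200) = 10 * 2 ** 4.75 ≈ 269.1 Hz,
    # while a frame whose peak salience never clears `thred` decodes to
    # cents 0, hence f0 == 10 Hz, which `decode` resets to 0 (unvoiced).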
    def infer(self, audio: torch.Tensor, thred=0.03, return_tensor=False):
        mel = self.mel_extractor(audio.unsqueeze(0))
        hidden = self.mel2hidden(mel)
        hidden = hidden[0].float()
        f0 = self.decode(hidden, thred=thred)
        if return_tensor:
            return f0
        return f0.cpu().numpy()
    def infer_from_audio(self, audio: np.ndarray, thred=0.03):
        # Convenience wrapper: move the numpy audio onto the model's device.
        audio = torch.from_numpy(audio).to(next(self.parameters()).device)
        return self.infer(audio, thred=thred)
    def to_local_average_cents(
        self, salience: torch.Tensor, thred=0.05
    ) -> torch.Tensor:
        # Peak bin per frame; shift by 4 to account for the padding below.
        center = torch.argmax(salience, dim=1)
        salience = F.pad(salience, (4, 4))
        center += 4
        batch_indices = torch.arange(salience.shape[0], device=salience.device)
        # Create indices for the 9-point window around each center
        offsets = torch.arange(-4, 5, device=salience.device)
        indices = center.unsqueeze(1) + offsets.unsqueeze(0)
        # Extract values using advanced indexing
        todo_salience = salience[batch_indices.unsqueeze(1), indices]
        todo_cents_mapping = self.cents_mapping_torch[indices]
        # Salience-weighted average of cents within the window.
        product_sum = torch.sum(todo_salience * todo_cents_mapping, 1)
        weight_sum = torch.sum(todo_salience, 1)
        divided = product_sum / weight_sum
        # Frames whose peak salience is at or below the threshold are unvoiced.
        maxx = torch.max(salience, 1).values
        divided[maxx <= thred] = 0
        return divided
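

# A minimal usage sketch, not part of the module: the hyperparameter values
# (n_blocks=4, n_gru=1, kernel_size=2) and the hub repo id below are
# illustrative placeholders, and real weights would come from a checkpoint
# or from RMVPE.from_pretrained("<hub-repo-id>") via PyTorchModelHubMixin.
# Because this file uses relative imports, run it as a module in its package.
if __name__ == "__main__":
    model = RMVPE(n_blocks=4, n_gru=1, kernel_size=2)
    model.eval()

    # One second of silence as a stand-in for real, SAMPLE_RATE-resampled audio.
    audio = np.zeros(SAMPLE_RATE, dtype=np.float32)
    f0 = model.infer_from_audio(audio, thred=0.03)  # (T,) f0 in Hz; 0 = unvoiced
    print(f0.shape, float(f0.max()))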