github-actions[bot]
Sync from https://github.com/JacobLinCool/zero-rvc
2d9b22b
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from huggingface_hub import PyTorchModelHubMixin
from .seq import BiGRU
from .deepunet import DeepUnet
from .mel import MelSpectrogram
from .constants import *
# Per-module logger, keyed by this module's dotted import path.
logger = logging.getLogger(name=__name__)
class RMVPE(nn.Module, PyTorchModelHubMixin):
    """RMVPE pitch estimator: audio -> mel spectrogram -> pitch salience -> f0.

    The network (a ``DeepUnet`` followed by a 3-channel conv and either a
    BiGRU + linear head or a plain linear head, all ending in a sigmoid)
    maps a mel spectrogram to a per-frame salience over ``N_CLASS`` pitch
    bins.  ``decode`` turns that salience into f0 in Hz with a local
    weighted average of 20-cent-spaced bin centres around the arg-max bin.
    """

    def __init__(
        self,
        n_blocks: int,
        n_gru: int,
        kernel_size: int,
        en_de_layers=5,
        inter_layers=4,
        in_channels=1,
        en_out_channels=16,
    ):
        super().__init__()
        self.mel_extractor = MelSpectrogram(
            N_MELS, SAMPLE_RATE, WINDOW_LENGTH, HOP_LENGTH, None, MEL_FMIN, MEL_FMAX
        )
        self.unet = DeepUnet(
            kernel_size,
            n_blocks,
            en_de_layers,
            inter_layers,
            in_channels,
            en_out_channels,
        )
        self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
        if n_gru:
            self.fc = nn.Sequential(
                BiGRU(3 * N_MELS, 256, n_gru),
                nn.Linear(512, N_CLASS),
                nn.Dropout(0.25),
                nn.Sigmoid(),
            )
        else:
            self.fc = nn.Sequential(
                nn.Linear(3 * N_MELS, N_CLASS), nn.Dropout(0.25), nn.Sigmoid()
            )
        # Bin centres in cents: 360 bins 20 cents apart, offset by MAGIC_CONST,
        # zero-padded by 4 on each side (368 entries) so that a 9-bin window
        # around any arg-max index stays in range.
        cents_mapping = 20 * np.arange(360) + MAGIC_CONST
        self.cents_mapping = np.pad(cents_mapping, (4, 4))  # 368
        # Deliberately a plain attribute, not a registered buffer: it stays out
        # of the state dict and is never cast by .half()/.to(dtype).  Because
        # nn.Module does not track it, it is moved explicitly in `to` and
        # re-homed to the salience's device in `to_local_average_cents`.
        self.cents_mapping_torch = torch.from_numpy(self.cents_mapping).to(
            dtype=torch.float32
        )

    def to(self, *args, **kwargs):
        """Move/cast the module like ``nn.Module.to``, keeping the cents table in sync.

        Accepts every calling convention of ``nn.Module.to`` (device, dtype,
        tensor, or keyword arguments).  The cents lookup table is only moved
        to the parameters' device; it is never cast, so it stays float32
        even after a dtype conversion (half precision cannot represent cent
        values in the thousands accurately).
        """
        module = super().to(*args, **kwargs)
        device = next(self.parameters()).device
        self.cents_mapping_torch = self.cents_mapping_torch.to(device)
        return module

    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        """Map a mel spectrogram to per-frame pitch salience.

        Args:
            mel: assumed ``(batch, n_mels, frames)``; the last two axes are
                swapped and a channel axis is added before the U-Net
                (TODO confirm against MelSpectrogram's output layout).

        Returns:
            Sigmoid salience, presumably ``(batch, frames, N_CLASS)`` given
            that the U-Net preserves spatial size (TODO confirm).
        """
        mel = mel.transpose(-1, -2).unsqueeze(1)
        x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
        x = self.fc(x)
        return x

    def mel2hidden(self, mel: torch.Tensor):
        """Run the network without gradients on a mel of arbitrary length.

        The frame axis (``mel.shape[2]``) is zero-padded up to a multiple of
        32 (presumably because the U-Net halves the time axis repeatedly;
        TODO confirm against DeepUnet), and the output is cropped back to
        the original number of frames.
        """
        with torch.no_grad():
            n_frames = mel.shape[2]
            n_pad = 32 * ((n_frames - 1) // 32 + 1) - n_frames
            mel = F.pad(mel, (0, n_pad), mode="constant")
            hidden = self(mel)
            return hidden[:, :n_frames]

    def decode(self, hidden: torch.Tensor, thred=0.03):
        """Convert a 2-D salience map to f0 in Hz; unvoiced frames become 0.

        Args:
            hidden: salience of shape ``(frames, bins)``.
            thred: frames whose peak salience is ``<= thred`` are unvoiced.
        """
        cents_pred = self.to_local_average_cents(hidden, thred=thred)
        # Cents are measured relative to 10 Hz.
        f0 = 10 * (2 ** (cents_pred / 1200))
        # Thresholded frames have 0 cents, i.e. exactly 10 Hz: mark unvoiced.
        f0[f0 == 10] = 0
        return f0

    def infer(self, audio: torch.Tensor, thred=0.03, return_tensor=False):
        """Estimate f0 for a single 1-D waveform tensor.

        Returns a tensor when ``return_tensor`` is true, otherwise a NumPy
        array (moved to CPU).
        """
        mel = self.mel_extractor(audio.unsqueeze(0))
        hidden = self.mel2hidden(mel)
        hidden = hidden[0].float()
        f0 = self.decode(hidden, thred=thred)
        if return_tensor:
            return f0
        return f0.cpu().numpy()

    def infer_from_audio(self, audio: np.ndarray, thred=0.03):
        """Estimate f0 for a 1-D NumPy waveform; returns a NumPy array.

        The input is normalised to a contiguous float32 array first:
        ``torch.from_numpy`` rejects negative-stride views (e.g. ``x[::-1]``),
        and float64 input would otherwise flow into the float32 model as a
        float64 tensor and cause a dtype mismatch.
        """
        audio = np.ascontiguousarray(audio, dtype=np.float32)
        device = next(self.parameters()).device
        return self.infer(torch.from_numpy(audio).to(device), thred=thred)

    def to_local_average_cents(
        self, salience: torch.Tensor, thred=0.05
    ) -> torch.Tensor:
        """Weighted-average cents over a 9-bin window around each frame's peak.

        Args:
            salience: 2-D ``(frames, bins)`` salience map.
            thred: frames whose peak salience is ``<= thred`` get 0 cents.

        Returns:
            1-D tensor of cents, one value per frame.
        """
        center = torch.argmax(salience, dim=1)
        # Pad the bin axis by 4 on each side so the window never runs out of
        # range; shift the centres to match.  Zero padding does not change the
        # per-frame max below, because sigmoid salience is non-negative.
        salience = F.pad(salience, (4, 4))
        center += 4
        batch_indices = torch.arange(salience.shape[0], device=salience.device)
        # Indices of the 9-bin window around each centre: (frames, 9).
        offsets = torch.arange(-4, 5, device=salience.device)
        indices = center.unsqueeze(1) + offsets.unsqueeze(0)
        todo_salience = salience[batch_indices.unsqueeze(1), indices]
        # The table is not a registered buffer, so .cuda()/.to(device=...)
        # may have left it behind; bring it to the salience's device
        # (a no-op when it is already there).
        cents_mapping = self.cents_mapping_torch.to(salience.device)
        todo_cents_mapping = cents_mapping[indices]
        product_sum = torch.sum(todo_salience * todo_cents_mapping, 1)
        weight_sum = torch.sum(todo_salience, 1)
        divided = product_sum / weight_sum
        maxx = torch.max(salience, 1).values
        divided[maxx <= thred] = 0
        return divided