import os import sys import torch import numpy as np import torch.nn as nn import onnxruntime as ort import torch.nn.functional as F from einops import rearrange from torch.nn.utils.parametrizations import weight_norm sys.path.append(os.getcwd()) os.environ["LRU_CACHE_CAPACITY"] = "3" from main.library.predictors.FCPE.wav2mel import spawn_wav2mel, Wav2Mel from main.library.predictors.FCPE.encoder import EncoderLayer, ConformerNaiveEncoder from main.library.predictors.FCPE.utils import l2_regularization, ensemble_f0, batch_interp_with_replacement_detach, decrypt_model, DotDict class PCmer(nn.Module): def __init__(self, num_layers, num_heads, dim_model, dim_keys, dim_values, residual_dropout, attention_dropout): super().__init__() self.num_layers = num_layers self.num_heads = num_heads self.dim_model = dim_model self.dim_values = dim_values self.dim_keys = dim_keys self.residual_dropout = residual_dropout self.attention_dropout = attention_dropout self._layers = nn.ModuleList([EncoderLayer(self) for _ in range(num_layers)]) def forward(self, phone, mask=None): for layer in self._layers: phone = layer(phone, mask) return phone class CFNaiveMelPE(nn.Module): def __init__(self, input_channels, out_dims, hidden_dims = 512, n_layers = 6, n_heads = 8, f0_max = 1975.5, f0_min = 32.70, use_fa_norm = False, conv_only = False, conv_dropout = 0, atten_dropout = 0, use_harmonic_emb = False): super().__init__() self.input_channels = input_channels self.out_dims = out_dims self.hidden_dims = hidden_dims self.n_layers = n_layers self.n_heads = n_heads self.f0_max = f0_max self.f0_min = f0_min self.use_fa_norm = use_fa_norm self.residual_dropout = 0.1 self.attention_dropout = 0.1 self.harmonic_emb = nn.Embedding(9, hidden_dims) if use_harmonic_emb else None self.input_stack = nn.Sequential(nn.Conv1d(input_channels, hidden_dims, 3, 1, 1), nn.GroupNorm(4, hidden_dims), nn.LeakyReLU(), nn.Conv1d(hidden_dims, hidden_dims, 3, 1, 1)) self.net = ConformerNaiveEncoder(num_layers=n_layers, num_heads=n_heads, dim_model=hidden_dims, use_norm=use_fa_norm, conv_only=conv_only, conv_dropout=conv_dropout, atten_dropout=atten_dropout) self.norm = nn.LayerNorm(hidden_dims) self.output_proj = weight_norm(nn.Linear(hidden_dims, out_dims)) self.cent_table_b = torch.linspace(self.f0_to_cent(torch.Tensor([f0_min]))[0], self.f0_to_cent(torch.Tensor([f0_max]))[0], out_dims).detach() self.register_buffer("cent_table", self.cent_table_b) self.gaussian_blurred_cent_mask_b = (1200 * torch.log2(torch.Tensor([self.f0_max / 10.])))[0].detach() self.register_buffer("gaussian_blurred_cent_mask", self.gaussian_blurred_cent_mask_b) def forward(self, x, _h_emb=None): x = self.input_stack(x.transpose(-1, -2)).transpose(-1, -2) if self.harmonic_emb is not None: x = x + self.harmonic_emb(torch.LongTensor([0]).to(x.device)) if _h_emb is None else x + self.harmonic_emb(torch.LongTensor([int(_h_emb)]).to(x.device)) return torch.sigmoid(self.output_proj(self.norm(self.net(x)))) @torch.no_grad() def latent2cents_decoder(self, y, threshold = 0.05, mask = True): B, N, _ = y.size() ci = self.cent_table[None, None, :].expand(B, N, -1) rtn = torch.sum(ci * y, dim=-1, keepdim=True) / torch.sum(y, dim=-1, keepdim=True) if mask: confident = torch.max(y, dim=-1, keepdim=True)[0] confident_mask = torch.ones_like(confident) confident_mask[confident <= threshold] = float("-INF") rtn = rtn * confident_mask return rtn @torch.no_grad() def latent2cents_local_decoder(self, y, threshold = 0.05, mask = True): B, N, _ = y.size() ci = self.cent_table[None, None, :].expand(B, N, -1) confident, max_index = torch.max(y, dim=-1, keepdim=True) local_argmax_index = torch.arange(0, 9).to(max_index.device) + (max_index - 4) local_argmax_index[local_argmax_index < 0] = 0 local_argmax_index[local_argmax_index >= self.out_dims] = self.out_dims - 1 y_l = torch.gather(y, -1, local_argmax_index) rtn = torch.sum(torch.gather(ci, -1, local_argmax_index) * y_l, dim=-1, keepdim=True) / torch.sum(y_l, dim=-1, keepdim=True) if mask: confident_mask = torch.ones_like(confident) confident_mask[confident <= threshold] = float("-INF") rtn = rtn * confident_mask return rtn @torch.no_grad() def infer(self, mel, decoder = "local_argmax", threshold = 0.05): latent = self.forward(mel) if decoder == "argmax": cents = self.latent2cents_local_decoder elif decoder == "local_argmax": cents = self.latent2cents_local_decoder return self.cent_to_f0(cents(latent, threshold=threshold)) @torch.no_grad() def cent_to_f0(self, cent: torch.Tensor) -> torch.Tensor: return 10 * 2 ** (cent / 1200) @torch.no_grad() def f0_to_cent(self, f0): return 1200 * torch.log2(f0 / 10) class FCPE_LEGACY(nn.Module): def __init__(self, input_channel=128, out_dims=360, n_layers=12, n_chans=512, loss_mse_scale=10, loss_l2_regularization=False, loss_l2_regularization_scale=1, loss_grad1_mse=False, loss_grad1_mse_scale=1, f0_max=1975.5, f0_min=32.70, confidence=False, threshold=0.05, use_input_conv=True): super().__init__() self.loss_mse_scale = loss_mse_scale self.loss_l2_regularization = loss_l2_regularization self.loss_l2_regularization_scale = loss_l2_regularization_scale self.loss_grad1_mse = loss_grad1_mse self.loss_grad1_mse_scale = loss_grad1_mse_scale self.f0_max = f0_max self.f0_min = f0_min self.confidence = confidence self.threshold = threshold self.use_input_conv = use_input_conv self.cent_table_b = torch.Tensor(np.linspace(self.f0_to_cent(torch.Tensor([f0_min]))[0], self.f0_to_cent(torch.Tensor([f0_max]))[0], out_dims)) self.register_buffer("cent_table", self.cent_table_b) self.stack = nn.Sequential(nn.Conv1d(input_channel, n_chans, 3, 1, 1), nn.GroupNorm(4, n_chans), nn.LeakyReLU(), nn.Conv1d(n_chans, n_chans, 3, 1, 1)) self.decoder = PCmer(num_layers=n_layers, num_heads=8, dim_model=n_chans, dim_keys=n_chans, dim_values=n_chans, residual_dropout=0.1, attention_dropout=0.1) self.norm = nn.LayerNorm(n_chans) self.n_out = out_dims self.dense_out = weight_norm(nn.Linear(n_chans, self.n_out)) def forward(self, mel, infer=True, gt_f0=None, return_hz_f0=False, cdecoder="local_argmax", output_interp_target_length=None): if cdecoder == "argmax": self.cdecoder = self.cents_decoder elif cdecoder == "local_argmax": self.cdecoder = self.cents_local_decoder x = torch.sigmoid(self.dense_out(self.norm(self.decoder((self.stack(mel.transpose(1, 2)).transpose(1, 2) if self.use_input_conv else mel))))) if not infer: loss_all = self.loss_mse_scale * F.binary_cross_entropy(x, self.gaussian_blurred_cent(self.f0_to_cent(gt_f0))) if self.loss_l2_regularization: loss_all = loss_all + l2_regularization(model=self, l2_alpha=self.loss_l2_regularization_scale) x = loss_all else: x = self.cent_to_f0(self.cdecoder(x)) x = (1 + x / 700).log() if not return_hz_f0 else x if output_interp_target_length is not None: x = F.interpolate(torch.where(x == 0, float("nan"), x).transpose(1, 2), size=int(output_interp_target_length), mode="linear").transpose(1, 2) x = torch.where(x.isnan(), float(0.0), x) return x def cents_decoder(self, y, mask=True): B, N, _ = y.size() rtn = torch.sum(self.cent_table[None, None, :].expand(B, N, -1) * y, dim=-1, keepdim=True) / torch.sum(y, dim=-1, keepdim=True) if mask: confident = torch.max(y, dim=-1, keepdim=True)[0] confident_mask = torch.ones_like(confident) confident_mask[confident <= self.threshold] = float("-INF") rtn = rtn * confident_mask return (rtn, confident) if self.confidence else rtn def cents_local_decoder(self, y, mask=True): B, N, _ = y.size() confident, max_index = torch.max(y, dim=-1, keepdim=True) local_argmax_index = torch.clamp(torch.arange(0, 9).to(max_index.device) + (max_index - 4), 0, self.n_out - 1) y_l = torch.gather(y, -1, local_argmax_index) rtn = torch.sum(torch.gather(self.cent_table[None, None, :].expand(B, N, -1), -1, local_argmax_index) * y_l, dim=-1, keepdim=True) / torch.sum(y_l, dim=-1, keepdim=True) if mask: confident_mask = torch.ones_like(confident) confident_mask[confident <= self.threshold] = float("-INF") rtn = rtn * confident_mask return (rtn, confident) if self.confidence else rtn def cent_to_f0(self, cent): return 10.0 * 2 ** (cent / 1200.0) def f0_to_cent(self, f0): return 1200.0 * torch.log2(f0 / 10.0) def gaussian_blurred_cent(self, cents): B, N, _ = cents.size() return torch.exp(-torch.square(self.cent_table[None, None, :].expand(B, N, -1) - cents) / 1250) * (cents > 0.1) & (cents < (1200.0 * np.log2(self.f0_max / 10.0))).float() class InferCFNaiveMelPE(torch.nn.Module): def __init__(self, args, state_dict): super().__init__() self.wav2mel = spawn_wav2mel(args, device="cpu") self.model = CFNaiveMelPE(input_channels=args.mel.num_mels, out_dims=args.model.out_dims, hidden_dims=args.model.hidden_dims, n_layers=args.model.n_layers, n_heads=args.model.n_heads, f0_max=args.model.f0_max, f0_min=args.model.f0_min, use_fa_norm=args.model.use_fa_norm, conv_only=args.model.conv_only, conv_dropout=args.model.conv_dropout, atten_dropout=args.model.atten_dropout, use_harmonic_emb=False) self.model.load_state_dict(state_dict) self.model.eval() self.args_dict = dict(args) self.register_buffer("tensor_device_marker", torch.tensor(1.0).float(), persistent=False) def forward(self, wav, sr, decoder_mode = "local_argmax", threshold = 0.006, key_shifts = [0]): with torch.no_grad(): mels = rearrange(torch.stack([self.wav2mel(wav.to(self.tensor_device_marker.device), sr, keyshift=keyshift) for keyshift in key_shifts], -1), "B T C K -> (B K) T C") f0s = rearrange(self.model.infer(mels, decoder=decoder_mode, threshold=threshold), "(B K) T 1 -> B T (K 1)", K=len(key_shifts)) return f0s def infer(self, wav, sr, decoder_mode = "local_argmax", threshold = 0.006, f0_min = None, f0_max = None, interp_uv = False, output_interp_target_length = None, return_uv = False, test_time_augmentation = False, tta_uv_penalty = 12.0, tta_key_shifts = [0, -12, 12], tta_use_origin_uv=False): if test_time_augmentation: assert len(tta_key_shifts) > 0 flag = 0 if tta_use_origin_uv: if 0 not in tta_key_shifts: flag = 1 tta_key_shifts.append(0) tta_key_shifts.sort(key=lambda x: (x if x >= 0 else -x / 2)) f0s = self.__call__(wav, sr, decoder_mode, threshold, tta_key_shifts) f0 = ensemble_f0(f0s[:, :, flag:], tta_key_shifts[flag:], tta_uv_penalty) f0_for_uv = f0s[:, :, [0]] if tta_use_origin_uv else f0 else: f0 = self.__call__(wav, sr, decoder_mode, threshold) f0_for_uv = f0 if f0_min is None: f0_min = self.args_dict["model"]["f0_min"] uv = (f0_for_uv < f0_min).type(f0_for_uv.dtype) f0 = f0 * (1 - uv) if interp_uv: f0 = batch_interp_with_replacement_detach(uv.squeeze(-1).bool(), f0.squeeze(-1)).unsqueeze(-1) if f0_max is not None: f0[f0 > f0_max] = f0_max if output_interp_target_length is not None: f0 = F.interpolate(torch.where(f0 == 0, float("nan"), f0).transpose(1, 2), size=int(output_interp_target_length), mode="linear").transpose(1, 2) f0 = torch.where(f0.isnan(), float(0.0), f0) if return_uv: return f0, F.interpolate(uv.transpose(1, 2), size=int(output_interp_target_length), mode="nearest").transpose(1, 2) else: return f0 class FCPEInfer_LEGACY: def __init__(self, configs, model_path, device=None, dtype=torch.float32, providers=None, onnx=False, f0_min=50, f0_max=1100): if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" self.device = device self.dtype = dtype self.onnx = onnx self.f0_min = f0_min self.f0_max = f0_max if self.onnx: sess_options = ort.SessionOptions() sess_options.log_severity_level = 3 self.model = ort.InferenceSession(decrypt_model(configs, model_path), sess_options=sess_options, providers=providers) else: ckpt = torch.load(model_path, map_location=torch.device(self.device), weights_only=True) self.args = DotDict(ckpt["config"]) model = FCPE_LEGACY(input_channel=self.args.model.input_channel, out_dims=self.args.model.out_dims, n_layers=self.args.model.n_layers, n_chans=self.args.model.n_chans, loss_mse_scale=self.args.loss.loss_mse_scale, loss_l2_regularization=self.args.loss.loss_l2_regularization, loss_l2_regularization_scale=self.args.loss.loss_l2_regularization_scale, loss_grad1_mse=self.args.loss.loss_grad1_mse, loss_grad1_mse_scale=self.args.loss.loss_grad1_mse_scale, f0_max=self.f0_max, f0_min=self.f0_min, confidence=self.args.model.confidence) model.to(self.device).to(self.dtype) model.load_state_dict(ckpt["model"]) model.eval() self.model = model @torch.no_grad() def __call__(self, audio, sr, threshold=0.05, p_len=None): if not self.onnx: self.model.threshold = threshold self.wav2mel = Wav2Mel(device=self.device, dtype=self.dtype) return (torch.as_tensor(self.model.run([self.model.get_outputs()[0].name], {self.model.get_inputs()[0].name: self.wav2mel(audio=audio[None, :], sample_rate=sr).to(self.dtype).detach().cpu().numpy(), self.model.get_inputs()[1].name: np.array(threshold, dtype=np.float32)})[0], dtype=self.dtype, device=self.device) if self.onnx else self.model(mel=self.wav2mel(audio=audio[None, :], sample_rate=sr).to(self.dtype), infer=True, return_hz_f0=True, output_interp_target_length=p_len)) class FCPEInfer: def __init__(self, configs, model_path, device=None, dtype=torch.float32, providers=None, onnx=False, f0_min=50, f0_max=1100): if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" self.device = device self.dtype = dtype self.onnx = onnx self.f0_min = f0_min self.f0_max = f0_max if self.onnx: sess_options = ort.SessionOptions() sess_options.log_severity_level = 3 self.model = ort.InferenceSession(decrypt_model(configs, model_path), sess_options=sess_options, providers=providers) else: ckpt = torch.load(model_path, map_location=torch.device(device), weights_only=True) ckpt["config_dict"]["model"]["conv_dropout"] = ckpt["config_dict"]["model"]["atten_dropout"] = 0.0 self.args = DotDict(ckpt["config_dict"]) model = InferCFNaiveMelPE(self.args, ckpt["model"]) model = model.to(device).to(self.dtype) model.eval() self.model = model @torch.no_grad() def __call__(self, audio, sr, threshold=0.05, p_len=None): if self.onnx: self.wav2mel = Wav2Mel(device=self.device, dtype=self.dtype) return (torch.as_tensor(self.model.run([self.model.get_outputs()[0].name], {self.model.get_inputs()[0].name: self.wav2mel(audio=audio[None, :], sample_rate=sr).to(self.dtype).detach().cpu().numpy(), self.model.get_inputs()[1].name: np.array(threshold, dtype=np.float32)})[0], dtype=self.dtype, device=self.device) if self.onnx else self.model.infer(audio[None, :], sr, threshold=threshold, f0_min=self.f0_min, f0_max=self.f0_max, output_interp_target_length=p_len)) class FCPE: def __init__(self, configs, model_path, hop_length=512, f0_min=50, f0_max=1100, dtype=torch.float32, device=None, sample_rate=16000, threshold=0.05, providers=None, onnx=False, legacy=False): self.model = FCPEInfer_LEGACY if legacy else FCPEInfer self.fcpe = self.model(configs, model_path, device=device, dtype=dtype, providers=providers, onnx=onnx, f0_min=f0_min, f0_max=f0_max) self.hop_length = hop_length self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") self.threshold = threshold self.sample_rate = sample_rate self.dtype = dtype self.legacy = legacy def compute_f0(self, wav, p_len=None): x = torch.FloatTensor(wav).to(self.dtype).to(self.device) p_len = (x.shape[0] // self.hop_length) if p_len is None else p_len f0 = self.fcpe(x, sr=self.sample_rate, threshold=self.threshold, p_len=p_len) f0 = f0[:] if f0.dim() == 1 else f0[0, :, 0] if torch.all(f0 == 0): return f0.cpu().numpy() if p_len is None else np.zeros(p_len), (f0.cpu().numpy() if p_len is None else np.zeros(p_len)) return f0.cpu().numpy()