import os
import sys
import torch
import numpy as np
import torch.nn as nn
import onnxruntime as ort
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils.parametrizations import weight_norm
sys.path.append(os.getcwd())
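# cap the LRU cache used by PyTorch's CPU allocator (helps limit memory growth with variable-length inputs)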
os.environ["LRU_CACHE_CAPACITY"] = "3"
from main.library.predictors.FCPE.wav2mel import spawn_wav2mel, Wav2Mel
from main.library.predictors.FCPE.encoder import EncoderLayer, ConformerNaiveEncoder
from main.library.predictors.FCPE.utils import l2_regularization, ensemble_f0, batch_interp_with_replacement_detach, decrypt_model, DotDict
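# PCmer: encoder used by the legacy FCPE model, a stack of identical EncoderLayer
# blocks; each layer receives this module so it can read the hyperparameters below.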
class PCmer(nn.Module):
def __init__(self, num_layers, num_heads, dim_model, dim_keys, dim_values, residual_dropout, attention_dropout):
super().__init__()
self.num_layers = num_layers
self.num_heads = num_heads
self.dim_model = dim_model
self.dim_values = dim_values
self.dim_keys = dim_keys
self.residual_dropout = residual_dropout
self.attention_dropout = attention_dropout
self._layers = nn.ModuleList([EncoderLayer(self) for _ in range(num_layers)])
def forward(self, phone, mask=None):
for layer in self._layers:
phone = layer(phone, mask)
return phone
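# CFNaiveMelPE: "conformer-naive" mel pitch estimator. A small conv stack projects
# the mel input to hidden_dims, a ConformerNaiveEncoder processes it, and a
# weight-normed linear head emits sigmoid activations over out_dims cent bins
# spanning [f0_min, f0_max].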
class CFNaiveMelPE(nn.Module):
def __init__(self, input_channels, out_dims, hidden_dims = 512, n_layers = 6, n_heads = 8, f0_max = 1975.5, f0_min = 32.70, use_fa_norm = False, conv_only = False, conv_dropout = 0, atten_dropout = 0, use_harmonic_emb = False):
super().__init__()
self.input_channels = input_channels
self.out_dims = out_dims
self.hidden_dims = hidden_dims
self.n_layers = n_layers
self.n_heads = n_heads
self.f0_max = f0_max
self.f0_min = f0_min
self.use_fa_norm = use_fa_norm
self.residual_dropout = 0.1
self.attention_dropout = 0.1
self.harmonic_emb = nn.Embedding(9, hidden_dims) if use_harmonic_emb else None
self.input_stack = nn.Sequential(nn.Conv1d(input_channels, hidden_dims, 3, 1, 1), nn.GroupNorm(4, hidden_dims), nn.LeakyReLU(), nn.Conv1d(hidden_dims, hidden_dims, 3, 1, 1))
self.net = ConformerNaiveEncoder(num_layers=n_layers, num_heads=n_heads, dim_model=hidden_dims, use_norm=use_fa_norm, conv_only=conv_only, conv_dropout=conv_dropout, atten_dropout=atten_dropout)
self.norm = nn.LayerNorm(hidden_dims)
self.output_proj = weight_norm(nn.Linear(hidden_dims, out_dims))
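        # bin centers in cents, linearly spaced between cent(f0_min) and cent(f0_max)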
self.cent_table_b = torch.linspace(self.f0_to_cent(torch.Tensor([f0_min]))[0], self.f0_to_cent(torch.Tensor([f0_max]))[0], out_dims).detach()
self.register_buffer("cent_table", self.cent_table_b)
self.gaussian_blurred_cent_mask_b = (1200 * torch.log2(torch.Tensor([self.f0_max / 10.])))[0].detach()
self.register_buffer("gaussian_blurred_cent_mask", self.gaussian_blurred_cent_mask_b)
def forward(self, x, _h_emb=None):
x = self.input_stack(x.transpose(-1, -2)).transpose(-1, -2)
        if self.harmonic_emb is not None:
            # embedding index 0 is the fundamental; _h_emb selects another harmonic
            h_idx = 0 if _h_emb is None else int(_h_emb)
            x = x + self.harmonic_emb(torch.LongTensor([h_idx]).to(x.device))
return torch.sigmoid(self.output_proj(self.norm(self.net(x))))
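    # soft-argmax decoding: confidence-weighted average over all cent bins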
@torch.no_grad()
def latent2cents_decoder(self, y, threshold = 0.05, mask = True):
B, N, _ = y.size()
ci = self.cent_table[None, None, :].expand(B, N, -1)
rtn = torch.sum(ci * y, dim=-1, keepdim=True) / torch.sum(y, dim=-1, keepdim=True)
if mask:
confident = torch.max(y, dim=-1, keepdim=True)[0]
confident_mask = torch.ones_like(confident)
confident_mask[confident <= threshold] = float("-INF")
rtn = rtn * confident_mask
return rtn
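    # local variant: average only a 9-bin window centered on the argmax bin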
@torch.no_grad()
def latent2cents_local_decoder(self, y, threshold = 0.05, mask = True):
B, N, _ = y.size()
ci = self.cent_table[None, None, :].expand(B, N, -1)
confident, max_index = torch.max(y, dim=-1, keepdim=True)
        local_argmax_index = torch.clamp(torch.arange(0, 9).to(max_index.device) + (max_index - 4), 0, self.out_dims - 1)
y_l = torch.gather(y, -1, local_argmax_index)
rtn = torch.sum(torch.gather(ci, -1, local_argmax_index) * y_l, dim=-1, keepdim=True) / torch.sum(y_l, dim=-1, keepdim=True)
if mask:
confident_mask = torch.ones_like(confident)
confident_mask[confident <= threshold] = float("-INF")
rtn = rtn * confident_mask
return rtn
@torch.no_grad()
def infer(self, mel, decoder = "local_argmax", threshold = 0.05):
latent = self.forward(mel)
if decoder == "argmax": cents = self.latent2cents_local_decoder
elif decoder == "local_argmax": cents = self.latent2cents_local_decoder
return self.cent_to_f0(cents(latent, threshold=threshold))
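    # cent <-> f0 conversion around a 10 Hz reference: f0 = 10 * 2 ** (cent / 1200)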
@torch.no_grad()
def cent_to_f0(self, cent: torch.Tensor) -> torch.Tensor:
return 10 * 2 ** (cent / 1200)
@torch.no_grad()
def f0_to_cent(self, f0):
return 1200 * torch.log2(f0 / 10)
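# FCPE_LEGACY: older FCPE architecture (conv stack + PCmer decoder + weight-normed
# linear head over out_dims cent bins).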
class FCPE_LEGACY(nn.Module):
def __init__(self, input_channel=128, out_dims=360, n_layers=12, n_chans=512, loss_mse_scale=10, loss_l2_regularization=False, loss_l2_regularization_scale=1, loss_grad1_mse=False, loss_grad1_mse_scale=1, f0_max=1975.5, f0_min=32.70, confidence=False, threshold=0.05, use_input_conv=True):
super().__init__()
self.loss_mse_scale = loss_mse_scale
self.loss_l2_regularization = loss_l2_regularization
self.loss_l2_regularization_scale = loss_l2_regularization_scale
self.loss_grad1_mse = loss_grad1_mse
self.loss_grad1_mse_scale = loss_grad1_mse_scale
self.f0_max = f0_max
self.f0_min = f0_min
self.confidence = confidence
self.threshold = threshold
self.use_input_conv = use_input_conv
self.cent_table_b = torch.Tensor(np.linspace(self.f0_to_cent(torch.Tensor([f0_min]))[0], self.f0_to_cent(torch.Tensor([f0_max]))[0], out_dims))
self.register_buffer("cent_table", self.cent_table_b)
self.stack = nn.Sequential(nn.Conv1d(input_channel, n_chans, 3, 1, 1), nn.GroupNorm(4, n_chans), nn.LeakyReLU(), nn.Conv1d(n_chans, n_chans, 3, 1, 1))
self.decoder = PCmer(num_layers=n_layers, num_heads=8, dim_model=n_chans, dim_keys=n_chans, dim_values=n_chans, residual_dropout=0.1, attention_dropout=0.1)
self.norm = nn.LayerNorm(n_chans)
self.n_out = out_dims
self.dense_out = weight_norm(nn.Linear(n_chans, self.n_out))
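    # infer=True decodes f0 from the sigmoid output; infer=False returns the BCE training loss against gt_f0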
def forward(self, mel, infer=True, gt_f0=None, return_hz_f0=False, cdecoder="local_argmax", output_interp_target_length=None):
if cdecoder == "argmax": self.cdecoder = self.cents_decoder
elif cdecoder == "local_argmax": self.cdecoder = self.cents_local_decoder
        if self.use_input_conv:
            mel = self.stack(mel.transpose(1, 2)).transpose(1, 2)
        x = torch.sigmoid(self.dense_out(self.norm(self.decoder(mel))))
if not infer:
loss_all = self.loss_mse_scale * F.binary_cross_entropy(x, self.gaussian_blurred_cent(self.f0_to_cent(gt_f0)))
if self.loss_l2_regularization: loss_all = loss_all + l2_regularization(model=self, l2_alpha=self.loss_l2_regularization_scale)
x = loss_all
else:
x = self.cent_to_f0(self.cdecoder(x))
x = (1 + x / 700).log() if not return_hz_f0 else x
if output_interp_target_length is not None:
x = F.interpolate(torch.where(x == 0, float("nan"), x).transpose(1, 2), size=int(output_interp_target_length), mode="linear").transpose(1, 2)
x = torch.where(x.isnan(), float(0.0), x)
return x
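    # soft-argmax decoding: confidence-weighted average over all cent bins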
def cents_decoder(self, y, mask=True):
B, N, _ = y.size()
rtn = torch.sum(self.cent_table[None, None, :].expand(B, N, -1) * y, dim=-1, keepdim=True) / torch.sum(y, dim=-1, keepdim=True)
if mask:
confident = torch.max(y, dim=-1, keepdim=True)[0]
confident_mask = torch.ones_like(confident)
confident_mask[confident <= self.threshold] = float("-INF")
rtn = rtn * confident_mask
return (rtn, confident) if self.confidence else rtn
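    # like cents_decoder, but averages only a 9-bin window around the argmax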
def cents_local_decoder(self, y, mask=True):
B, N, _ = y.size()
confident, max_index = torch.max(y, dim=-1, keepdim=True)
local_argmax_index = torch.clamp(torch.arange(0, 9).to(max_index.device) + (max_index - 4), 0, self.n_out - 1)
y_l = torch.gather(y, -1, local_argmax_index)
rtn = torch.sum(torch.gather(self.cent_table[None, None, :].expand(B, N, -1), -1, local_argmax_index) * y_l, dim=-1, keepdim=True) / torch.sum(y_l, dim=-1, keepdim=True)
if mask:
confident_mask = torch.ones_like(confident)
confident_mask[confident <= self.threshold] = float("-INF")
rtn = rtn * confident_mask
return (rtn, confident) if self.confidence else rtn
def cent_to_f0(self, cent):
return 10.0 * 2 ** (cent / 1200.0)
def f0_to_cent(self, f0):
return 1200.0 * torch.log2(f0 / 10.0)
def gaussian_blurred_cent(self, cents):
B, N, _ = cents.size()
        # mask out frames whose target cent is outside the valid voiced range,
        # then build a Gaussian-blurred one-hot target over the cent bins
        mask = ((cents > 0.1) & (cents < (1200.0 * np.log2(self.f0_max / 10.0)))).float()
        return torch.exp(-torch.square(self.cent_table[None, None, :].expand(B, N, -1) - cents) / 1250) * mask
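# InferCFNaiveMelPE: inference wrapper that owns its own wav2mel frontend and a
# CFNaiveMelPE loaded from a state dict; tensor_device_marker is a dummy buffer
# used only to track the module's current device.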
class InferCFNaiveMelPE(torch.nn.Module):
def __init__(self, args, state_dict):
super().__init__()
self.wav2mel = spawn_wav2mel(args, device="cpu")
self.model = CFNaiveMelPE(input_channels=args.mel.num_mels, out_dims=args.model.out_dims, hidden_dims=args.model.hidden_dims, n_layers=args.model.n_layers, n_heads=args.model.n_heads, f0_max=args.model.f0_max, f0_min=args.model.f0_min, use_fa_norm=args.model.use_fa_norm, conv_only=args.model.conv_only, conv_dropout=args.model.conv_dropout, atten_dropout=args.model.atten_dropout, use_harmonic_emb=False)
self.model.load_state_dict(state_dict)
self.model.eval()
self.args_dict = dict(args)
self.register_buffer("tensor_device_marker", torch.tensor(1.0).float(), persistent=False)
def forward(self, wav, sr, decoder_mode = "local_argmax", threshold = 0.006, key_shifts = [0]):
with torch.no_grad():
mels = rearrange(torch.stack([self.wav2mel(wav.to(self.tensor_device_marker.device), sr, keyshift=keyshift) for keyshift in key_shifts], -1), "B T C K -> (B K) T C")
f0s = rearrange(self.model.infer(mels, decoder=decoder_mode, threshold=threshold), "(B K) T 1 -> B T (K 1)", K=len(key_shifts))
return f0s
def infer(self, wav, sr, decoder_mode = "local_argmax", threshold = 0.006, f0_min = None, f0_max = None, interp_uv = False, output_interp_target_length = None, return_uv = False, test_time_augmentation = False, tta_uv_penalty = 12.0, tta_key_shifts = [0, -12, 12], tta_use_origin_uv=False):
        if test_time_augmentation:
            assert len(tta_key_shifts) > 0
            tta_key_shifts = list(tta_key_shifts)  # copy so the caller's list is not mutated below
            flag = 0
            if tta_use_origin_uv and 0 not in tta_key_shifts:
                flag = 1
                tta_key_shifts.append(0)
            tta_key_shifts.sort(key=lambda x: (x if x >= 0 else -x / 2))
            f0s = self.__call__(wav, sr, decoder_mode, threshold, tta_key_shifts)
            f0 = ensemble_f0(f0s[:, :, flag:], tta_key_shifts[flag:], tta_uv_penalty)
            f0_for_uv = f0s[:, :, [0]] if tta_use_origin_uv else f0
else:
f0 = self.__call__(wav, sr, decoder_mode, threshold)
f0_for_uv = f0
if f0_min is None: f0_min = self.args_dict["model"]["f0_min"]
uv = (f0_for_uv < f0_min).type(f0_for_uv.dtype)
f0 = f0 * (1 - uv)
if interp_uv: f0 = batch_interp_with_replacement_detach(uv.squeeze(-1).bool(), f0.squeeze(-1)).unsqueeze(-1)
if f0_max is not None: f0[f0 > f0_max] = f0_max
        if output_interp_target_length is not None:
            # treat unvoiced (0) frames as NaN during interpolation, then zero them back
            f0 = F.interpolate(torch.where(f0 == 0, float("nan"), f0).transpose(1, 2), size=int(output_interp_target_length), mode="linear").transpose(1, 2)
            f0 = torch.where(f0.isnan(), float(0.0), f0)
            if return_uv: uv = F.interpolate(uv.transpose(1, 2), size=int(output_interp_target_length), mode="nearest").transpose(1, 2)
        return (f0, uv) if return_uv else f0
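# FCPEInfer_LEGACY: loads an FCPE_LEGACY checkpoint (or a decrypted ONNX session)
# and exposes a single audio -> f0 call.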
class FCPEInfer_LEGACY:
def __init__(self, configs, model_path, device=None, dtype=torch.float32, providers=None, onnx=False, f0_min=50, f0_max=1100):
if device is None: device = "cuda" if torch.cuda.is_available() else "cpu"
self.device = device
self.dtype = dtype
self.onnx = onnx
self.f0_min = f0_min
self.f0_max = f0_max
if self.onnx:
sess_options = ort.SessionOptions()
sess_options.log_severity_level = 3
self.model = ort.InferenceSession(decrypt_model(configs, model_path), sess_options=sess_options, providers=providers)
else:
ckpt = torch.load(model_path, map_location=torch.device(self.device), weights_only=True)
self.args = DotDict(ckpt["config"])
model = FCPE_LEGACY(input_channel=self.args.model.input_channel, out_dims=self.args.model.out_dims, n_layers=self.args.model.n_layers, n_chans=self.args.model.n_chans, loss_mse_scale=self.args.loss.loss_mse_scale, loss_l2_regularization=self.args.loss.loss_l2_regularization, loss_l2_regularization_scale=self.args.loss.loss_l2_regularization_scale, loss_grad1_mse=self.args.loss.loss_grad1_mse, loss_grad1_mse_scale=self.args.loss.loss_grad1_mse_scale, f0_max=self.f0_max, f0_min=self.f0_min, confidence=self.args.model.confidence)
model.to(self.device).to(self.dtype)
model.load_state_dict(ckpt["model"])
model.eval()
self.model = model
@torch.no_grad()
def __call__(self, audio, sr, threshold=0.05, p_len=None):
if not self.onnx: self.model.threshold = threshold
self.wav2mel = Wav2Mel(device=self.device, dtype=self.dtype)
        mel = self.wav2mel(audio=audio[None, :], sample_rate=sr).to(self.dtype)
        if self.onnx:
            # the exported ONNX graph takes (mel, threshold) and returns f0 directly
            f0 = self.model.run([self.model.get_outputs()[0].name], {self.model.get_inputs()[0].name: mel.detach().cpu().numpy(), self.model.get_inputs()[1].name: np.array(threshold, dtype=np.float32)})[0]
            return torch.as_tensor(f0, dtype=self.dtype, device=self.device)
        return self.model(mel=mel, infer=True, return_hz_f0=True, output_interp_target_length=p_len)
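# FCPEInfer: loads the current checkpoint format into InferCFNaiveMelPE (or a
# decrypted ONNX session) and exposes the same audio -> f0 call.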
class FCPEInfer:
def __init__(self, configs, model_path, device=None, dtype=torch.float32, providers=None, onnx=False, f0_min=50, f0_max=1100):
if device is None: device = "cuda" if torch.cuda.is_available() else "cpu"
self.device = device
self.dtype = dtype
self.onnx = onnx
self.f0_min = f0_min
self.f0_max = f0_max
if self.onnx:
sess_options = ort.SessionOptions()
sess_options.log_severity_level = 3
self.model = ort.InferenceSession(decrypt_model(configs, model_path), sess_options=sess_options, providers=providers)
else:
ckpt = torch.load(model_path, map_location=torch.device(device), weights_only=True)
ckpt["config_dict"]["model"]["conv_dropout"] = ckpt["config_dict"]["model"]["atten_dropout"] = 0.0
self.args = DotDict(ckpt["config_dict"])
model = InferCFNaiveMelPE(self.args, ckpt["model"])
model = model.to(device).to(self.dtype)
model.eval()
self.model = model
@torch.no_grad()
def __call__(self, audio, sr, threshold=0.05, p_len=None):
if self.onnx: self.wav2mel = Wav2Mel(device=self.device, dtype=self.dtype)
        if self.onnx:
            mel = self.wav2mel(audio=audio[None, :], sample_rate=sr).to(self.dtype)
            f0 = self.model.run([self.model.get_outputs()[0].name], {self.model.get_inputs()[0].name: mel.detach().cpu().numpy(), self.model.get_inputs()[1].name: np.array(threshold, dtype=np.float32)})[0]
            return torch.as_tensor(f0, dtype=self.dtype, device=self.device)
        # the torch model performs its own mel extraction inside infer()
        return self.model.infer(audio[None, :], sr, threshold=threshold, f0_min=self.f0_min, f0_max=self.f0_max, output_interp_target_length=p_len)
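# FCPE: top-level wrapper that selects the legacy or current backend and exposes
# compute_f0() over a raw waveform.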
class FCPE:
def __init__(self, configs, model_path, hop_length=512, f0_min=50, f0_max=1100, dtype=torch.float32, device=None, sample_rate=16000, threshold=0.05, providers=None, onnx=False, legacy=False):
self.model = FCPEInfer_LEGACY if legacy else FCPEInfer
self.fcpe = self.model(configs, model_path, device=device, dtype=dtype, providers=providers, onnx=onnx, f0_min=f0_min, f0_max=f0_max)
self.hop_length = hop_length
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
self.threshold = threshold
self.sample_rate = sample_rate
self.dtype = dtype
self.legacy = legacy
def compute_f0(self, wav, p_len=None):
x = torch.FloatTensor(wav).to(self.dtype).to(self.device)
p_len = (x.shape[0] // self.hop_length) if p_len is None else p_len
f0 = self.fcpe(x, sr=self.sample_rate, threshold=self.threshold, p_len=p_len)
f0 = f0[:] if f0.dim() == 1 else f0[0, :, 0]
        if torch.all(f0 == 0):
            # fully unvoiced: preserve the original two-value return for this case
            rtn = f0.cpu().numpy() if p_len is None else np.zeros(p_len)
            return rtn, rtn
return f0.cpu().numpy()