import os
import sys
import torch
import numpy as np
import torch.nn as nn
import onnxruntime as ort
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils.parametrizations import weight_norm
sys.path.append(os.getcwd())
os.environ["LRU_CACHE_CAPACITY"] = "3"
from main.library.predictors.FCPE.wav2mel import spawn_wav2mel, Wav2Mel
from main.library.predictors.FCPE.encoder import EncoderLayer, ConformerNaiveEncoder
from main.library.predictors.FCPE.utils import l2_regularization, ensemble_f0, batch_interp_with_replacement_detach, decrypt_model, DotDict
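# PCmer: a stack of EncoderLayer blocks that serves as the sequence backbone of the legacy FCPE model.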
class PCmer(nn.Module):
def __init__(self, num_layers, num_heads, dim_model, dim_keys, dim_values, residual_dropout, attention_dropout):
super().__init__()
self.num_layers = num_layers
self.num_heads = num_heads
self.dim_model = dim_model
self.dim_values = dim_values
self.dim_keys = dim_keys
self.residual_dropout = residual_dropout
self.attention_dropout = attention_dropout
self._layers = nn.ModuleList([EncoderLayer(self) for _ in range(num_layers)])
def forward(self, phone, mask=None):
for layer in self._layers:
phone = layer(phone, mask)
return phone
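# CFNaiveMelPE: conformer-based mel-to-pitch model. A Conv1d stack projects mel frames into the hidden
# dimension, a ConformerNaiveEncoder models the sequence, and a sigmoid head predicts a per-frame
# distribution over `out_dims` cent bins spanning [f0_min, f0_max].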
class CFNaiveMelPE(nn.Module):
def __init__(self, input_channels, out_dims, hidden_dims = 512, n_layers = 6, n_heads = 8, f0_max = 1975.5, f0_min = 32.70, use_fa_norm = False, conv_only = False, conv_dropout = 0, atten_dropout = 0, use_harmonic_emb = False):
super().__init__()
self.input_channels = input_channels
self.out_dims = out_dims
self.hidden_dims = hidden_dims
self.n_layers = n_layers
self.n_heads = n_heads
self.f0_max = f0_max
self.f0_min = f0_min
self.use_fa_norm = use_fa_norm
self.residual_dropout = 0.1
self.attention_dropout = 0.1
self.harmonic_emb = nn.Embedding(9, hidden_dims) if use_harmonic_emb else None
self.input_stack = nn.Sequential(nn.Conv1d(input_channels, hidden_dims, 3, 1, 1), nn.GroupNorm(4, hidden_dims), nn.LeakyReLU(), nn.Conv1d(hidden_dims, hidden_dims, 3, 1, 1))
self.net = ConformerNaiveEncoder(num_layers=n_layers, num_heads=n_heads, dim_model=hidden_dims, use_norm=use_fa_norm, conv_only=conv_only, conv_dropout=conv_dropout, atten_dropout=atten_dropout)
self.norm = nn.LayerNorm(hidden_dims)
self.output_proj = weight_norm(nn.Linear(hidden_dims, out_dims))
self.cent_table_b = torch.linspace(self.f0_to_cent(torch.Tensor([f0_min]))[0], self.f0_to_cent(torch.Tensor([f0_max]))[0], out_dims).detach()
self.register_buffer("cent_table", self.cent_table_b)
self.gaussian_blurred_cent_mask_b = (1200 * torch.log2(torch.Tensor([self.f0_max / 10.])))[0].detach()
self.register_buffer("gaussian_blurred_cent_mask", self.gaussian_blurred_cent_mask_b)
def forward(self, x, _h_emb=None):
x = self.input_stack(x.transpose(-1, -2)).transpose(-1, -2)
if self.harmonic_emb is not None: x = x + self.harmonic_emb(torch.LongTensor([0]).to(x.device)) if _h_emb is None else x + self.harmonic_emb(torch.LongTensor([int(_h_emb)]).to(x.device))
return torch.sigmoid(self.output_proj(self.norm(self.net(x))))
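# latent2cents_decoder: expectation of the cent table under the predicted distribution; frames whose
# peak activation is below `threshold` are multiplied by -inf so they decode to 0 Hz (unvoiced) after cent_to_f0.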
@torch.no_grad()
def latent2cents_decoder(self, y, threshold = 0.05, mask = True):
B, N, _ = y.size()
ci = self.cent_table[None, None, :].expand(B, N, -1)
rtn = torch.sum(ci * y, dim=-1, keepdim=True) / torch.sum(y, dim=-1, keepdim=True)
if mask:
confident = torch.max(y, dim=-1, keepdim=True)[0]
confident_mask = torch.ones_like(confident)
confident_mask[confident <= threshold] = float("-INF")
rtn = rtn * confident_mask
return rtn
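# latent2cents_local_decoder: same idea, but the expectation is restricted to a 9-bin window centred on
# the argmax bin, which keeps probability mass far from the peak from biasing the estimate.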
@torch.no_grad()
def latent2cents_local_decoder(self, y, threshold = 0.05, mask = True):
B, N, _ = y.size()
ci = self.cent_table[None, None, :].expand(B, N, -1)
confident, max_index = torch.max(y, dim=-1, keepdim=True)
local_argmax_index = torch.arange(0, 9).to(max_index.device) + (max_index - 4)
local_argmax_index[local_argmax_index < 0] = 0
local_argmax_index[local_argmax_index >= self.out_dims] = self.out_dims - 1
y_l = torch.gather(y, -1, local_argmax_index)
rtn = torch.sum(torch.gather(ci, -1, local_argmax_index) * y_l, dim=-1, keepdim=True) / torch.sum(y_l, dim=-1, keepdim=True)
if mask:
confident_mask = torch.ones_like(confident)
confident_mask[confident <= threshold] = float("-INF")
rtn = rtn * confident_mask
return rtn
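# infer: mel -> f0 in Hz. "local_argmax" (default) uses the windowed decoder, "argmax" the full expectation.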
@torch.no_grad()
def infer(self, mel, decoder = "local_argmax", threshold = 0.05):
latent = self.forward(mel)
if decoder == "argmax": cents = self.latent2cents_local_decoder
elif decoder == "local_argmax": cents = self.latent2cents_local_decoder
return self.cent_to_f0(cents(latent, threshold=threshold))
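# Cent scale referenced to 10 Hz: f0 = 10 * 2^(cent / 1200) and cent = 1200 * log2(f0 / 10).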
@torch.no_grad()
def cent_to_f0(self, cent: torch.Tensor) -> torch.Tensor:
return 10 * 2 ** (cent / 1200)
@torch.no_grad()
def f0_to_cent(self, f0):
return 1200 * torch.log2(f0 / 10)
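# FCPE_LEGACY: original FCPE architecture (Conv1d stack + PCmer + linear head). forward() returns the
# training loss when infer=False and the decoded f0 when infer=True.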
class FCPE_LEGACY(nn.Module):
def __init__(self, input_channel=128, out_dims=360, n_layers=12, n_chans=512, loss_mse_scale=10, loss_l2_regularization=False, loss_l2_regularization_scale=1, loss_grad1_mse=False, loss_grad1_mse_scale=1, f0_max=1975.5, f0_min=32.70, confidence=False, threshold=0.05, use_input_conv=True):
super().__init__()
self.loss_mse_scale = loss_mse_scale
self.loss_l2_regularization = loss_l2_regularization
self.loss_l2_regularization_scale = loss_l2_regularization_scale
self.loss_grad1_mse = loss_grad1_mse
self.loss_grad1_mse_scale = loss_grad1_mse_scale
self.f0_max = f0_max
self.f0_min = f0_min
self.confidence = confidence
self.threshold = threshold
self.use_input_conv = use_input_conv
self.cent_table_b = torch.Tensor(np.linspace(self.f0_to_cent(torch.Tensor([f0_min]))[0], self.f0_to_cent(torch.Tensor([f0_max]))[0], out_dims))
self.register_buffer("cent_table", self.cent_table_b)
self.stack = nn.Sequential(nn.Conv1d(input_channel, n_chans, 3, 1, 1), nn.GroupNorm(4, n_chans), nn.LeakyReLU(), nn.Conv1d(n_chans, n_chans, 3, 1, 1))
self.decoder = PCmer(num_layers=n_layers, num_heads=8, dim_model=n_chans, dim_keys=n_chans, dim_values=n_chans, residual_dropout=0.1, attention_dropout=0.1)
self.norm = nn.LayerNorm(n_chans)
self.n_out = out_dims
self.dense_out = weight_norm(nn.Linear(n_chans, self.n_out))
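# forward: with infer=False, returns the BCE loss between the predicted bin activations and a
# Gaussian-blurred target built from gt_f0; with infer=True, decodes cents to Hz (or a mel-like
# log scale when return_hz_f0=False) and optionally resamples to output_interp_target_length.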
def forward(self, mel, infer=True, gt_f0=None, return_hz_f0=False, cdecoder="local_argmax", output_interp_target_length=None):
if cdecoder == "argmax": self.cdecoder = self.cents_decoder
elif cdecoder == "local_argmax": self.cdecoder = self.cents_local_decoder
x = torch.sigmoid(self.dense_out(self.norm(self.decoder((self.stack(mel.transpose(1, 2)).transpose(1, 2) if self.use_input_conv else mel)))))
if not infer:
loss_all = self.loss_mse_scale * F.binary_cross_entropy(x, self.gaussian_blurred_cent(self.f0_to_cent(gt_f0)))
if self.loss_l2_regularization: loss_all = loss_all + l2_regularization(model=self, l2_alpha=self.loss_l2_regularization_scale)
x = loss_all
else:
x = self.cent_to_f0(self.cdecoder(x))
x = (1 + x / 700).log() if not return_hz_f0 else x
if output_interp_target_length is not None:
x = F.interpolate(torch.where(x == 0, float("nan"), x).transpose(1, 2), size=int(output_interp_target_length), mode="linear").transpose(1, 2)
x = torch.where(x.isnan(), float(0.0), x)
return x
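# cents_decoder / cents_local_decoder: the same decoding schemes as in CFNaiveMelPE above, with an
# optional confidence value returned when self.confidence is set.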
def cents_decoder(self, y, mask=True):
B, N, _ = y.size()
rtn = torch.sum(self.cent_table[None, None, :].expand(B, N, -1) * y, dim=-1, keepdim=True) / torch.sum(y, dim=-1, keepdim=True)
if mask:
confident = torch.max(y, dim=-1, keepdim=True)[0]
confident_mask = torch.ones_like(confident)
confident_mask[confident <= self.threshold] = float("-INF")
rtn = rtn * confident_mask
return (rtn, confident) if self.confidence else rtn
def cents_local_decoder(self, y, mask=True):
B, N, _ = y.size()
confident, max_index = torch.max(y, dim=-1, keepdim=True)
local_argmax_index = torch.clamp(torch.arange(0, 9).to(max_index.device) + (max_index - 4), 0, self.n_out - 1)
y_l = torch.gather(y, -1, local_argmax_index)
rtn = torch.sum(torch.gather(self.cent_table[None, None, :].expand(B, N, -1), -1, local_argmax_index) * y_l, dim=-1, keepdim=True) / torch.sum(y_l, dim=-1, keepdim=True)
if mask:
confident_mask = torch.ones_like(confident)
confident_mask[confident <= self.threshold] = float("-INF")
rtn = rtn * confident_mask
return (rtn, confident) if self.confidence else rtn
def cent_to_f0(self, cent):
return 10.0 * 2 ** (cent / 1200.0)
def f0_to_cent(self, f0):
return 1200.0 * torch.log2(f0 / 10.0)
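# gaussian_blurred_cent: turns ground-truth cents into a soft Gaussian target over the cent table,
# zeroed outside the valid pitch range; used as the BCE target during training.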
def gaussian_blurred_cent(self, cents):
B, N, _ = cents.size()
return torch.exp(-torch.square(self.cent_table[None, None, :].expand(B, N, -1) - cents) / 1250) * ((cents > 0.1) & (cents < (1200.0 * np.log2(self.f0_max / 10.0)))).float()
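# InferCFNaiveMelPE: inference wrapper that pairs the wav2mel frontend with CFNaiveMelPE and supports
# test-time augmentation over key shifts.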
class InferCFNaiveMelPE(torch.nn.Module):
def __init__(self, args, state_dict):
super().__init__()
self.wav2mel = spawn_wav2mel(args, device="cpu")
self.model = CFNaiveMelPE(input_channels=args.mel.num_mels, out_dims=args.model.out_dims, hidden_dims=args.model.hidden_dims, n_layers=args.model.n_layers, n_heads=args.model.n_heads, f0_max=args.model.f0_max, f0_min=args.model.f0_min, use_fa_norm=args.model.use_fa_norm, conv_only=args.model.conv_only, conv_dropout=args.model.conv_dropout, atten_dropout=args.model.atten_dropout, use_harmonic_emb=False)
self.model.load_state_dict(state_dict)
self.model.eval()
self.args_dict = dict(args)
self.register_buffer("tensor_device_marker", torch.tensor(1.0).float(), persistent=False)
def forward(self, wav, sr, decoder_mode = "local_argmax", threshold = 0.006, key_shifts = [0]):
with torch.no_grad():
mels = rearrange(torch.stack([self.wav2mel(wav.to(self.tensor_device_marker.device), sr, keyshift=keyshift) for keyshift in key_shifts], -1), "B T C K -> (B K) T C")
f0s = rearrange(self.model.infer(mels, decoder=decoder_mode, threshold=threshold), "(B K) T 1 -> B T (K 1)", K=len(key_shifts))
return f0s
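# infer: full wav -> f0 pipeline with optional test-time augmentation (ensemble over pitch-shifted
# copies), unvoiced masking below f0_min, optional interpolation across unvoiced frames, and
# optional resampling of the f0 curve to output_interp_target_length.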
def infer(self, wav, sr, decoder_mode = "local_argmax", threshold = 0.006, f0_min = None, f0_max = None, interp_uv = False, output_interp_target_length = None, return_uv = False, test_time_augmentation = False, tta_uv_penalty = 12.0, tta_key_shifts = [0, -12, 12], tta_use_origin_uv=False):
if test_time_augmentation:
assert len(tta_key_shifts) > 0
flag = 0
if tta_use_origin_uv:
if 0 not in tta_key_shifts:
flag = 1
tta_key_shifts.append(0)
tta_key_shifts.sort(key=lambda x: (x if x >= 0 else -x / 2))
f0s = self.__call__(wav, sr, decoder_mode, threshold, tta_key_shifts)
f0 = ensemble_f0(f0s[:, :, flag:], tta_key_shifts[flag:], tta_uv_penalty)
f0_for_uv = f0s[:, :, [0]] if tta_use_origin_uv else f0
else:
f0 = self.__call__(wav, sr, decoder_mode, threshold)
f0_for_uv = f0
if f0_min is None: f0_min = self.args_dict["model"]["f0_min"]
uv = (f0_for_uv < f0_min).type(f0_for_uv.dtype)
f0 = f0 * (1 - uv)
if interp_uv: f0 = batch_interp_with_replacement_detach(uv.squeeze(-1).bool(), f0.squeeze(-1)).unsqueeze(-1)
if f0_max is not None: f0[f0 > f0_max] = f0_max
if output_interp_target_length is not None:
f0 = F.interpolate(torch.where(f0 == 0, float("nan"), f0).transpose(1, 2), size=int(output_interp_target_length), mode="linear").transpose(1, 2)
f0 = torch.where(f0.isnan(), float(0.0), f0)
if return_uv: return f0, F.interpolate(uv.transpose(1, 2), size=int(output_interp_target_length), mode="nearest").transpose(1, 2)
else: return f0
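# FCPEInfer_LEGACY: loads a legacy FCPE checkpoint (torch) or an encrypted ONNX model and exposes a
# single __call__(audio, sr, ...) that returns the per-frame f0 curve.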
class FCPEInfer_LEGACY:
def __init__(self, configs, model_path, device=None, dtype=torch.float32, providers=None, onnx=False, f0_min=50, f0_max=1100):
if device is None: device = "cuda" if torch.cuda.is_available() else "cpu"
self.device = device
self.dtype = dtype
self.onnx = onnx
self.f0_min = f0_min
self.f0_max = f0_max
if self.onnx:
sess_options = ort.SessionOptions()
sess_options.log_severity_level = 3
self.model = ort.InferenceSession(decrypt_model(configs, model_path), sess_options=sess_options, providers=providers)
else:
ckpt = torch.load(model_path, map_location=torch.device(self.device), weights_only=True)
self.args = DotDict(ckpt["config"])
model = FCPE_LEGACY(input_channel=self.args.model.input_channel, out_dims=self.args.model.out_dims, n_layers=self.args.model.n_layers, n_chans=self.args.model.n_chans, loss_mse_scale=self.args.loss.loss_mse_scale, loss_l2_regularization=self.args.loss.loss_l2_regularization, loss_l2_regularization_scale=self.args.loss.loss_l2_regularization_scale, loss_grad1_mse=self.args.loss.loss_grad1_mse, loss_grad1_mse_scale=self.args.loss.loss_grad1_mse_scale, f0_max=self.f0_max, f0_min=self.f0_min, confidence=self.args.model.confidence)
model.to(self.device).to(self.dtype)
model.load_state_dict(ckpt["model"])
model.eval()
self.model = model
@torch.no_grad()
def __call__(self, audio, sr, threshold=0.05, p_len=None):
if not self.onnx: self.model.threshold = threshold
self.wav2mel = Wav2Mel(device=self.device, dtype=self.dtype)
return (torch.as_tensor(self.model.run([self.model.get_outputs()[0].name], {self.model.get_inputs()[0].name: self.wav2mel(audio=audio[None, :], sample_rate=sr).to(self.dtype).detach().cpu().numpy(), self.model.get_inputs()[1].name: np.array(threshold, dtype=np.float32)})[0], dtype=self.dtype, device=self.device) if self.onnx else self.model(mel=self.wav2mel(audio=audio[None, :], sample_rate=sr).to(self.dtype), infer=True, return_hz_f0=True, output_interp_target_length=p_len))
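# FCPEInfer: same interface as FCPEInfer_LEGACY, but backed by the newer CFNaiveMelPE checkpoints.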
class FCPEInfer:
def __init__(self, configs, model_path, device=None, dtype=torch.float32, providers=None, onnx=False, f0_min=50, f0_max=1100):
if device is None: device = "cuda" if torch.cuda.is_available() else "cpu"
self.device = device
self.dtype = dtype
self.onnx = onnx
self.f0_min = f0_min
self.f0_max = f0_max
if self.onnx:
sess_options = ort.SessionOptions()
sess_options.log_severity_level = 3
self.model = ort.InferenceSession(decrypt_model(configs, model_path), sess_options=sess_options, providers=providers)
else:
ckpt = torch.load(model_path, map_location=torch.device(device), weights_only=True)
ckpt["config_dict"]["model"]["conv_dropout"] = ckpt["config_dict"]["model"]["atten_dropout"] = 0.0
self.args = DotDict(ckpt["config_dict"])
model = InferCFNaiveMelPE(self.args, ckpt["model"])
model = model.to(device).to(self.dtype)
model.eval()
self.model = model
@torch.no_grad()
def __call__(self, audio, sr, threshold=0.05, p_len=None):
if self.onnx: self.wav2mel = Wav2Mel(device=self.device, dtype=self.dtype)
return (torch.as_tensor(self.model.run([self.model.get_outputs()[0].name], {self.model.get_inputs()[0].name: self.wav2mel(audio=audio[None, :], sample_rate=sr).to(self.dtype).detach().cpu().numpy(), self.model.get_inputs()[1].name: np.array(threshold, dtype=np.float32)})[0], dtype=self.dtype, device=self.device) if self.onnx else self.model.infer(audio[None, :], sr, threshold=threshold, f0_min=self.f0_min, f0_max=self.f0_max, output_interp_target_length=p_len))
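# FCPE: public entry point. Selects the legacy or current backend and exposes compute_f0() on raw waveforms.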
class FCPE:
def __init__(self, configs, model_path, hop_length=512, f0_min=50, f0_max=1100, dtype=torch.float32, device=None, sample_rate=16000, threshold=0.05, providers=None, onnx=False, legacy=False):
self.model = FCPEInfer_LEGACY if legacy else FCPEInfer
self.fcpe = self.model(configs, model_path, device=device, dtype=dtype, providers=providers, onnx=onnx, f0_min=f0_min, f0_max=f0_max)
self.hop_length = hop_length
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
self.threshold = threshold
self.sample_rate = sample_rate
self.dtype = dtype
self.legacy = legacy
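# compute_f0: numpy waveform -> numpy f0 curve; p_len defaults to len(wav) // hop_length.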
def compute_f0(self, wav, p_len=None):
x = torch.FloatTensor(wav).to(self.dtype).to(self.device)
p_len = (x.shape[0] // self.hop_length) if p_len is None else p_len
f0 = self.fcpe(x, sr=self.sample_rate, threshold=self.threshold, p_len=p_len)
f0 = f0[:] if f0.dim() == 1 else f0[0, :, 0]
if torch.all(f0 == 0): return f0.cpu().numpy() if p_len is None else np.zeros(p_len), (f0.cpu().numpy() if p_len is None else np.zeros(p_len))
return f0.cpu().numpy()
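# Minimal usage sketch (illustrative only; `configs` and the checkpoint path depend on the caller's
# setup, and `wav` is assumed to be a mono numpy waveform at `sample_rate`):
#   fcpe = FCPE(configs, "path/to/fcpe.pt", device="cpu")
#   f0 = fcpe.compute_f0(wav)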