import os
from io import BytesIO

import torch
from torch import nn
from Crypto.Cipher import AES
from Crypto.Util.Padding import unpad


def decrypt_model(configs, input_path):
    """Decrypt an AES-CBC encrypted model file and return its raw bytes.

    The first 16 bytes of the file are the IV; the key is read from
    ``decrypt.bin`` under ``configs["binary_path"]``.
    """
    with open(input_path, "rb") as f:
        data = f.read()
    with open(os.path.join(configs["binary_path"], "decrypt.bin"), "rb") as f:
        key = f.read()
    cipher = AES.new(key, AES.MODE_CBC, data[:16])
    return BytesIO(unpad(cipher.decrypt(data[16:]), AES.block_size)).read()


def calc_same_padding(kernel_size):
    """Return (left, right) padding that preserves the input length."""
    pad = kernel_size // 2
    return (pad, pad - (kernel_size + 1) % 2)


def l2_regularization(model, l2_alpha):
    """Sum of squared Conv2d weights, scaled by ``l2_alpha``."""
    l2_loss = []
    for module in model.modules():
        if type(module) is nn.Conv2d:
            l2_loss.append((module.weight ** 2).sum() / 2.0)
    return l2_alpha * sum(l2_loss)


def torch_interp(x, xp, fp):
    """1-D linear interpolation, mirroring ``numpy.interp`` for tensors.

    Values of ``x`` outside the range of ``xp`` are clamped to the
    endpoint values ``fp[0]`` and ``fp[-1]``.
    """
    sort_idx = torch.argsort(xp)
    xp = xp[sort_idx]
    fp = fp[sort_idx]
    right_idxs = torch.searchsorted(xp, x).clamp(max=len(xp) - 1)
    left_idxs = (right_idxs - 1).clamp(min=0)
    x_left = xp[left_idxs]
    y_left = fp[left_idxs]
    interp_vals = y_left + (x - x_left) * (fp[right_idxs] - y_left) / (xp[right_idxs] - x_left)
    interp_vals[x < xp[0]] = fp[0]
    interp_vals[x > xp[-1]] = fp[-1]
    return interp_vals


def batch_interp_with_replacement_detach(uv, f0):
    """Fill unvoiced frames (``uv`` mask) of each f0 curve by linear
    interpolation from the surrounding voiced frames; gradients through
    the interpolated values are detached."""
    result = f0.clone()
    for i in range(uv.shape[0]):
        interp_vals = torch_interp(
            torch.where(uv[i])[-1],
            torch.where(~uv[i])[-1],
            f0[i][~uv[i]],
        ).detach()
        result[i][uv[i]] = interp_vals
    return result


def ensemble_f0(f0s, key_shift_list, tta_uv_penalty):
    """Merge key-shifted f0 estimates of shape (B, T, K) into one curve of
    shape (B, T, 1) via Viterbi-style dynamic programming over the K
    candidates per frame."""
    device = f0s.device
    # Undo the key shifts so all candidates live on the same pitch scale.
    key_shifts = torch.tensor(key_shift_list, device=device)
    f0s = f0s / torch.pow(2, key_shifts.unsqueeze(0).unsqueeze(0) / 12)

    # Convert Hz to MIDI note numbers; non-positive notes mark unvoiced frames.
    notes = torch.log2(f0s / 440) * 12 + 69
    notes[notes < 0] = 0

    uv_penalty = tta_uv_penalty ** 2
    dp = torch.zeros_like(notes, device=device)
    backtrack = torch.zeros_like(notes, device=device).long()
    dp[:, 0, :] = (notes[:, 0, :] <= 0) * uv_penalty

    for t in range(1, notes.size(1)):
        # Transition cost from every candidate at t-1 to every candidate at t.
        penalty = torch.zeros([notes.size(0), notes.size(2), notes.size(2)], device=device)

        # Penalize landing on an unvoiced candidate.
        t_uv = notes[:, t, :] <= 0
        penalty += uv_penalty * t_uv.unsqueeze(1)

        # Penalize large pitch jumps between voiced candidates.
        t1_uv = notes[:, t - 1, :] <= 0
        l2 = torch.pow(
            (notes[:, t - 1, :].unsqueeze(-1) - notes[:, t, :].unsqueeze(1))
            * (~t1_uv).unsqueeze(-1)
            * (~t_uv).unsqueeze(1),
            2,
        ) - 0.5
        l2 = l2 * (l2 > 0)
        penalty += l2

        # Penalize unvoiced -> voiced transitions.
        penalty += t1_uv.unsqueeze(-1) * (~t_uv).unsqueeze(1) * uv_penalty * 2

        min_value, min_indices = torch.min(dp[:, t - 1, :].unsqueeze(-1) + penalty, dim=1)
        dp[:, t, :] = min_value
        backtrack[:, t, :] = min_indices

    # Backtrack from the cheapest final state to recover the best path.
    t = f0s.size(1) - 1
    f0_result = torch.zeros_like(f0s[:, :, 0], device=device)
    min_indices = torch.argmin(dp[:, t, :], dim=-1)
    # Gather per batch element; plain `[:, t - i, min_indices]` would
    # cross-index batch entries and break for batch sizes > 1.
    batch_indices = torch.arange(f0s.size(0), device=device)
    for i in range(0, t + 1):
        f0_result[:, t - i] = f0s[batch_indices, t - i, min_indices]
        min_indices = backtrack[batch_indices, t - i, min_indices]
    return f0_result.unsqueeze(-1)


class DotDict(dict):
    """Dict whose keys are also reachable as attributes, recursively."""

    def __getattr__(*args):
        val = dict.get(*args)
        return DotDict(val) if type(val) is dict else val

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__


class Swish(nn.Module):
    """Swish / SiLU activation: x * sigmoid(x)."""

    def forward(self, x):
        return x * x.sigmoid()


class Transpose(nn.Module):
    """Swap two tensor dimensions, usable inside nn.Sequential."""

    def __init__(self, dims):
        super().__init__()
        assert len(dims) == 2, "dims must name exactly two dimensions"
        self.dims = dims

    def forward(self, x):
        return x.transpose(*self.dims)


class GLU(nn.Module):
    """Gated Linear Unit: splits ``x`` in half along ``dim`` and gates one
    half with the sigmoid of the other."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        out, gate = x.chunk(2, dim=self.dim)
        return out * gate.sigmoid()
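

# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal example of filling unvoiced gaps in an f0 curve with
# batch_interp_with_replacement_detach. The tensor values below are
# made-up assumptions, chosen so the expected output is easy to verify
# by hand: frames 1 and 2 are linearly interpolated between 100 and 130.
if __name__ == "__main__":
    f0 = torch.tensor([[100.0, 0.0, 0.0, 130.0, 140.0]])
    uv = f0 <= 0  # mask of unvoiced frames (no pitch detected)
    print(batch_interp_with_replacement_detach(uv, f0))
    # -> tensor([[100., 110., 120., 130., 140.]])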