import os
from io import BytesIO

import torch
from torch import nn
from Crypto.Cipher import AES
from Crypto.Util.Padding import unpad

def decrypt_model(configs, input_path):
    # Read the encrypted checkpoint and the AES key stored next to the binaries.
    with open(input_path, "rb") as f:
        data = f.read()
    with open(os.path.join(configs["binary_path"], "decrypt.bin"), "rb") as f:
        key = f.read()
    # Layout: the first 16 bytes are the CBC IV, the rest is padded ciphertext.
    cipher = AES.new(key, AES.MODE_CBC, data[:16])
    return BytesIO(unpad(cipher.decrypt(data[16:]), AES.block_size)).read()
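
# Usage sketch (the paths and config key below are illustrative assumptions,
# not part of this module): decrypt a checkpoint into memory, then hand the
# raw bytes to torch.load via an in-memory buffer.
#
#   configs = {"binary_path": "./bin"}                   # hypothetical layout
#   raw = decrypt_model(configs, "./bin/model.pt.enc")   # hypothetical file
#   state_dict = torch.load(BytesIO(raw), map_location="cpu")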

def calc_same_padding(kernel_size):
    # (left, right) padding that preserves the input length for a stride-1
    # convolution; the right pad is one less when kernel_size is even.
    pad = kernel_size // 2
    return (pad, pad - (kernel_size + 1) % 2)
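
# For example, kernel_size=5 gives (2, 2) and kernel_size=4 gives (2, 1),
# suitable for use with F.pad before a convolution:
#
#   left, right = calc_same_padding(31)
#   x = nn.functional.pad(x, (left, right))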

def l2_regularization(model, l2_alpha):
    # L2 penalty over all Conv2d weights, scaled by l2_alpha.
    l2_loss = []
    for module in model.modules():
        if type(module) is nn.Conv2d:
            l2_loss.append((module.weight ** 2).sum() / 2.0)
    return l2_alpha * sum(l2_loss)
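
# Typically added to the task loss during training; the coefficient below is
# illustrative, not a value prescribed by this module:
#
#   loss = task_loss + l2_regularization(model, 1e-5)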

def torch_interp(x, xp, fp):
    # 1-D linear interpolation mirroring numpy.interp: xp/fp are the sample
    # points, x the query points; values outside [xp[0], xp[-1]] are clamped
    # to the edge values.
    sort_idx = torch.argsort(xp)
    xp = xp[sort_idx]
    fp = fp[sort_idx]
    # Clamp so every query has a valid bracket; min=1 avoids a zero-width
    # interval (and a NaN from 0/0) when x equals xp[0].
    right_idxs = torch.searchsorted(xp, x).clamp(min=1, max=len(xp) - 1)
    left_idxs = right_idxs - 1
    x_left = xp[left_idxs]
    y_left = fp[left_idxs]
    interp_vals = y_left + (x - x_left) * (fp[right_idxs] - y_left) / (xp[right_idxs] - x_left)
    interp_vals[x < xp[0]] = fp[0]
    interp_vals[x > xp[-1]] = fp[-1]
    return interp_vals
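
# Example (values chosen for illustration):
#
#   xp = torch.tensor([0.0, 1.0, 2.0])
#   fp = torch.tensor([0.0, 10.0, 20.0])
#   torch_interp(torch.tensor([0.5, 1.5, 3.0]), xp, fp)  # -> [5.0, 15.0, 20.0]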

def batch_interp_with_replacement_detach(uv, f0):
    # For each item in the batch, replace unvoiced frames (uv == True) with
    # values interpolated from the surrounding voiced frames; the fill-in
    # values are detached so no gradient flows through them.
    result = f0.clone()
    for i in range(uv.shape[0]):
        interp_vals = torch_interp(torch.where(uv[i])[-1],
                                   torch.where(~uv[i])[-1],
                                   f0[i][~uv[i]]).detach()
        result[i][uv[i]] = interp_vals
    return result
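
# Example with a toy f0 curve of shape (batch, frames):
#
#   f0 = torch.tensor([[100.0, 0.0, 0.0, 130.0]])
#   uv = f0 <= 0
#   batch_interp_with_replacement_detach(uv, f0)  # -> [[100., 110., 120., 130.]]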

def ensemble_f0(f0s, key_shift_list, tta_uv_penalty):
    # Merge f0 candidates from several key-shifted inference passes
    # (test-time augmentation) into one track via Viterbi-style dynamic
    # programming. f0s: (batch, frames, candidates) in Hz.
    device = f0s.device
    # Undo each candidate's key shift so all tracks share the original pitch.
    shifts = torch.tensor(key_shift_list, device=device).unsqueeze(0).unsqueeze(0)
    f0s = f0s / torch.pow(2, shifts / 12)
    # Convert Hz to MIDI note numbers; non-positive notes mark unvoiced frames.
    notes = torch.log2(f0s / 440) * 12 + 69
    notes[notes < 0] = 0
    uv_penalty = tta_uv_penalty ** 2
    dp = torch.zeros_like(notes, device=device)
    backtrack = torch.zeros_like(notes, device=device).long()
    dp[:, 0, :] = (notes[:, 0, :] <= 0) * uv_penalty
    for t in range(1, notes.size(1)):
        # penalty[b, i, j]: cost of moving from candidate i at t-1 to j at t.
        penalty = torch.zeros([notes.size(0), notes.size(2), notes.size(2)], device=device)
        # Cost of picking an unvoiced candidate at t.
        t_uv = notes[:, t, :] <= 0
        penalty += uv_penalty * t_uv.unsqueeze(1)
        # Squared pitch jump between consecutive voiced frames, with a dead
        # zone so jitter below sqrt(0.5) (~0.7) semitones costs nothing.
        t1_uv = notes[:, t - 1, :] <= 0
        l2 = torch.pow((notes[:, t - 1, :].unsqueeze(-1) - notes[:, t, :].unsqueeze(1))
                       * (~t1_uv).unsqueeze(-1) * (~t_uv).unsqueeze(1), 2) - 0.5
        penalty += l2 * (l2 > 0)
        # Extra cost for an unvoiced-to-voiced transition.
        penalty += t1_uv.unsqueeze(-1) * (~t_uv).unsqueeze(1) * uv_penalty * 2
        min_value, min_indices = torch.min(dp[:, t - 1, :].unsqueeze(-1) + penalty, dim=1)
        dp[:, t, :] = min_value
        backtrack[:, t, :] = min_indices
    # Backtrack from the cheapest final state; index per batch element so the
    # recovered path is correct for batch sizes greater than one.
    t = f0s.size(1) - 1
    batch_idx = torch.arange(f0s.size(0), device=device)
    f0_result = torch.zeros_like(f0s[:, :, 0], device=device)
    min_indices = torch.argmin(dp[:, t, :], dim=-1)
    for i in range(0, t + 1):
        f0_result[:, t - i] = f0s[batch_idx, t - i, min_indices]
        min_indices = backtrack[batch_idx, t - i, min_indices]
    return f0_result.unsqueeze(-1)
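
# Usage sketch (shapes and shift list are illustrative; extract_f0 stands in
# for a hypothetical per-shift inference call returning (batch, frames)):
#
#   key_shifts = [0, -12, 12]
#   f0s = torch.stack([extract_f0(audio, s) for s in key_shifts], dim=-1)
#   f0 = ensemble_f0(f0s, key_shifts, tta_uv_penalty=12.0)  # (batch, frames, 1)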

class DotDict(dict):
    # dict whose keys are also readable as attributes; nested dicts are
    # wrapped in DotDict on access, so chained lookups like cfg.model.dim work.
    def __getattr__(*args):
        val = dict.get(*args)
        return DotDict(val) if type(val) is dict else val

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
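
# Example:
#
#   cfg = DotDict({"model": {"dim": 256}})
#   cfg.model.dim       # 256
#   cfg.device = "cpu"  # same as cfg["device"] = "cpu"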

class Swish(nn.Module):
    # Swish / SiLU activation: x * sigmoid(x).
    def forward(self, x):
        return x * x.sigmoid()

class Transpose(nn.Module):
    # Module wrapper around Tensor.transpose, usable inside nn.Sequential.
    def __init__(self, dims):
        super().__init__()
        assert len(dims) == 2, "dims must name exactly two dimensions to swap"
        self.dims = dims

    def forward(self, x):
        return x.transpose(*self.dims)

class GLU(nn.Module):
    # Gated linear unit: split the input in two along `dim` and gate one half
    # with the sigmoid of the other, halving that dimension.
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        out, gate = x.chunk(2, dim=self.dim)
        return out * gate.sigmoid()
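
# Example: a (batch, 2 * channels, frames) tensor gated down to
# (batch, channels, frames), as in a Conformer convolution module.
#
#   glu = GLU(dim=1)
#   glu(torch.randn(4, 512, 100)).shape  # torch.Size([4, 256, 100])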