AnhP's picture
Upload 170 files
1e4a2ab verified
raw
history blame
3.84 kB
import os
import torch
from torch import nn
from io import BytesIO
from Crypto.Cipher import AES
from Crypto.Util.Padding import unpad
def decrypt_model(configs, input_path):
    """Decrypt an AES-CBC encrypted file and return the plaintext bytes.

    The file at ``input_path`` is read in full; its first 16 bytes are used
    as the CBC initialisation vector and the remainder as the ciphertext.
    The key is the raw content of ``decrypt.bin`` located in the directory
    ``configs["binary_path"]``.

    Args:
        configs: Mapping that provides ``"binary_path"``, the directory that
            contains ``decrypt.bin``.
        input_path: Path of the encrypted file.

    Returns:
        The decrypted payload with its padding removed, as ``bytes``.

    Raises:
        OSError: If either file cannot be opened or read.
        ValueError: If the encrypted file is too short to hold an IV plus one
            cipher block, or if the cipher/unpadding step rejects the data.
    """
    with open(input_path, "rb") as f:
        data = f.read()

    with open(os.path.join(configs["binary_path"], "decrypt.bin"), "rb") as f:
        key = f.read()

    block_size = AES.block_size
    # A valid payload is IV + at least one padded block; fail early with a
    # readable message instead of a low-level cipher error.
    if len(data) < 2 * block_size:
        raise ValueError(
            f"Encrypted file {input_path!r} is too short ({len(data)} bytes)"
        )

    iv, ciphertext = data[:block_size], data[block_size:]
    cipher = AES.new(key, AES.MODE_CBC, iv)
    # unpad() already returns bytes, so no BytesIO round-trip is needed.
    return unpad(cipher.decrypt(ciphertext), block_size)
def calc_same_padding(kernel_size):
    """Return the ``(first, second)`` padding pair for a given kernel size.

    The first value is ``kernel_size // 2``. The second is the same value
    reduced by ``(kernel_size + 1) % 2``, i.e. one less for even integer
    kernel sizes and identical for odd ones.
    """
    half = kernel_size // 2
    even_correction = (kernel_size + 1) % 2
    return half, half - even_correction
def l2_regularization(model, l2_alpha):
    """Compute an L2 penalty over the weights of the model's ``nn.Conv2d`` layers.

    Each layer contributes ``sum(weight ** 2) / 2``; the contributions are
    added together and scaled by ``l2_alpha``. Biases are not included.

    Args:
        model: Module whose ``modules()`` are scanned.
        l2_alpha: Scale factor applied to the summed penalty.

    Returns:
        The scaled penalty. When no matching layer exists the sum is the
        integer ``0``, so the result is ``l2_alpha * 0``.
    """
    conv_terms = [
        (layer.weight**2).sum() / 2.0
        for layer in model.modules()
        # Exact type match on purpose: subclasses of Conv2d are not counted.
        if type(layer) is nn.Conv2d
    ]
    return l2_alpha * sum(conv_terms)
def torch_interp(x, xp, fp):
    """Piecewise-linear interpolation of ``(xp, fp)`` evaluated at ``x``.

    ``xp`` does not need to be sorted; it is sorted internally together with
    ``fp``. Queries below the smallest ``xp`` return the first ``fp`` value
    and queries above the largest ``xp`` return the last one.

    Assumes ``xp`` and ``fp`` are non-empty 1-D tensors of equal length
    (TODO confirm against callers); an empty ``xp`` raises ``IndexError``.

    Args:
        x: Tensor of query positions.
        xp: Tensor of sample positions.
        fp: Tensor of sample values, aligned with ``xp``.

    Returns:
        Tensor of interpolated values with the same shape as ``x``.
    """
    order = torch.argsort(xp)
    xp = xp[order]
    fp = fp[order]

    # Bracket every query between a left and a right sample index.
    right = torch.searchsorted(xp, x).clamp(max=len(xp) - 1)
    left = (right - 1).clamp(min=0)

    x_left = xp[left]
    y_left = fp[left]

    # When left == right (query at or below xp[0], or a single sample) the
    # span is zero. Dividing by it produced NaN for a query exactly equal to
    # xp[0]; substitute 1 so the (zero) numerator yields y_left instead.
    span = xp[right] - x_left
    safe_span = torch.where(span == 0, torch.ones_like(span), span)

    interp_vals = y_left + ((x - x_left) * (fp[right] - y_left) / safe_span)

    # Outside the sampled range: hold the boundary values.
    interp_vals[x < xp[0]] = fp[0]
    interp_vals[x > xp[-1]] = fp[-1]
    return interp_vals
def batch_interp_with_replacement_detach(uv, f0):
    """Fill the masked entries of each row of ``f0`` by linear interpolation.

    For every row ``i``, the positions where ``uv[i]`` is true are replaced
    with values interpolated (via ``torch_interp``) from the positions where
    it is false. The interpolated values are detached, so no gradient flows
    through the filled entries. ``f0`` itself is not modified.

    Rows with nothing to fill, or with no unmasked sample to interpolate
    from, are returned unchanged; the latter previously raised
    ``IndexError`` inside ``torch_interp``.

    Assumes ``uv`` is a boolean mask with the same shape as ``f0``
    (TODO confirm against callers).

    Args:
        uv: Mask tensor; true marks entries to be replaced.
        f0: Value tensor, indexed along dim 0 in step with ``uv``.

    Returns:
        A clone of ``f0`` with the masked entries replaced.
    """
    result = f0.clone()
    for i, fill_mask in enumerate(uv):
        keep_mask = ~fill_mask
        # Nothing to fill, or nothing to interpolate from: leave row as-is.
        if not fill_mask.any() or not keep_mask.any():
            continue

        # Positions are taken along the last dimension of the row mask.
        fill_pos = torch.where(fill_mask)[-1]
        known_pos = torch.where(keep_mask)[-1]
        filled = torch_interp(fill_pos, known_pos, f0[i][keep_mask]).detach()
        result[i][fill_mask] = filled
    return result
def ensemble_f0(f0s, key_shift_list, tta_uv_penalty):
    """Merge several pitch candidates per frame into one track via dynamic programming.

    ``f0s`` is indexed as a 3-D tensor: dim 0 is treated as the batch, dim 1
    is iterated as the sequence axis, and the last dim holds one candidate
    per entry of ``key_shift_list``. Each candidate is first divided by
    ``2 ** (shift / 12)`` to undo its key shift, then converted to a note
    scale ``12 * log2(f / 440) + 69`` with negative values clamped to 0.
    A note value ``<= 0`` is treated as "unvoiced".

    A minimum-cost path over candidates is then searched, where the cost
    combines a per-frame penalty for unvoiced candidates, a squared note
    jump between consecutive voiced candidates (jumps costing less than 0.5
    are free), and a doubled penalty for unvoiced -> voiced transitions.

    Args:
        f0s: Tensor of candidate frequencies, presumably shaped
            (batch, frames, candidates) — TODO confirm against callers.
        key_shift_list: Sequence of key shifts, one per candidate.
        tta_uv_penalty: Base unvoiced penalty; it is squared before use.

    Returns:
        Tensor with the selected frequency per batch item and frame, with a
        trailing singleton dimension.
    """
    device = f0s.device

    # Undo the key shift of every candidate (broadcast over batch and frames).
    shifts = torch.tensor(key_shift_list, device=device).unsqueeze(0).unsqueeze(0)
    f0s = f0s / torch.pow(2, shifts / 12)

    notes = torch.log2(f0s / 440) * 12 + 69
    notes[notes < 0] = 0  # also maps log2(0) = -inf to 0

    uv_penalty = tta_uv_penalty**2
    dp = torch.zeros_like(notes, device=device)
    backtrack = torch.zeros_like(notes, device=device).long()
    dp[:, 0, :] = (notes[:, 0, :] <= 0) * uv_penalty

    for t in range(1, notes.size(1)):
        # penalty[b, prev, cur]: cost of moving from candidate prev to cur.
        penalty = torch.zeros([notes.size(0), notes.size(2), notes.size(2)], device=device)

        t_uv = notes[:, t, :] <= 0
        penalty += uv_penalty * t_uv.unsqueeze(1)

        t1_uv = notes[:, t - 1, :] <= 0
        # Squared note jump, counted only when both ends are voiced.
        l2 = torch.pow((notes[:, t - 1, :].unsqueeze(-1) - notes[:, t, :].unsqueeze(1)) * (~t1_uv).unsqueeze(-1) * (~t_uv).unsqueeze(1), 2) - 0.5
        l2 = l2 * (l2 > 0)
        penalty += l2

        # Unvoiced -> voiced transition.
        penalty += t1_uv.unsqueeze(-1) * (~t_uv).unsqueeze(1) * uv_penalty * 2

        min_value, min_indices = torch.min(dp[:, t - 1, :].unsqueeze(-1) + penalty, dim=1)
        dp[:, t, :] = min_value
        backtrack[:, t, :] = min_indices

    last = f0s.size(1) - 1
    f0_result = torch.zeros_like(f0s[:, :, 0], device=device)
    best = torch.argmin(dp[:, last, :], dim=-1)

    # Walk the path backwards. gather() selects one candidate per batch item;
    # plain fancy indexing (f0s[:, t, best]) would pair every batch row with
    # every index and only works when the batch size is 1.
    for t in range(last, -1, -1):
        pick = best.unsqueeze(-1)
        f0_result[:, t] = f0s[:, t, :].gather(-1, pick).squeeze(-1)
        best = backtrack[:, t, :].gather(-1, pick).squeeze(-1)

    return f0_result.unsqueeze(-1)
class DotDict(dict):
    """Dictionary whose keys can also be read, set and deleted as attributes.

    Reading a missing key through attribute access yields ``None`` rather
    than raising. A value that is a plain ``dict`` is returned wrapped in a
    new ``DotDict`` (a copy), so writes to that wrapper do not reach the
    original nested dict.
    """

    # Invoked by Python as (self, name); the arguments are forwarded
    # untouched to dict.get.
    def __getattr__(*args):
        value = dict.get(*args)
        if type(value) is dict:
            return DotDict(value)
        return value

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
class Swish(nn.Module):
    """Element-wise activation computing ``x * sigmoid(x)``."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate
class Transpose(nn.Module):
    """Module wrapper that swaps two dimensions of its input.

    Args:
        dims: Pair of dimension indices to exchange; must have length 2.
    """

    def __init__(self, dims):
        super().__init__()
        assert len(dims) == 2, "dims == 2"
        self.dims = dims

    def forward(self, x):
        first, second = self.dims
        return torch.transpose(x, first, second)
class GLU(nn.Module):
    """Gated linear unit: split the input in two and gate one half with the other.

    Args:
        dim: Dimension along which the input is split into two chunks.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        # First chunk carries the values, second chunk drives the sigmoid gate.
        value, gate = torch.chunk(x, 2, dim=self.dim)
        return value * torch.sigmoid(gate)