import os
import sys
import math
import torch
import random
import numpy as np
import typing as tp
from torch import nn
from einops import rearrange
from fractions import Fraction
from torch.nn import functional as F
now_dir = os.getcwd()
sys.path.append(now_dir)
from .states import capture_init
from .demucs import rescale_module
from main.configs.config import Config
from .hdemucs import pad1d, spectro, ispectro, wiener, ScaledEmbedding, HEncLayer, MultiWrap, HDecLayer
translations = Config().translations
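# 1D sinusoidal positional embedding. Returns a (length, 1, dim) tensor whose first
# dim // 2 channels are cosines and whose last dim // 2 channels are sines; `shift`
# offsets the positions before encoding.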
def create_sin_embedding(length: int, dim: int, shift: int = 0, device="cpu", max_period=10000):
    assert dim % 2 == 0
    pos = shift + torch.arange(length, device=device).view(-1, 1, 1)
    half_dim = dim // 2
    adim = torch.arange(dim // 2, device=device).view(1, 1, -1)
    phase = pos / (max_period ** (adim / (half_dim - 1)))
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
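# 2D sinusoidal embedding for spectrogram-shaped inputs. Returns (1, d_model, height, width);
# the first half of the channels encodes the width axis and the second half the height axis.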
def create_2d_sin_embedding(d_model, height, width, device="cpu", max_period=10000):
    if d_model % 4 != 0: raise ValueError(translations["dims"].format(dims=d_model))
    pe = torch.zeros(d_model, height, width)
    d_model = int(d_model / 2)
    div_term = torch.exp(torch.arange(0.0, d_model, 2) * -(math.log(max_period) / d_model))
    pos_w = torch.arange(0.0, width).unsqueeze(1)
    pos_h = torch.arange(0.0, height).unsqueeze(1)
    pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
    pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
    pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
    pe[d_model + 1 :: 2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
    return pe[None, :].to(device)
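# CAPE-style positional embedding: when `augment` is True the positions receive a random
# global shift, per-position local shift and log-uniform scaling before the sinusoidal
# encoding; with `augment` False it reduces to the plain (optionally mean-centered)
# embedding. Output shape is (length, batch_size, dim).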
def create_sin_embedding_cape(length: int, dim: int, batch_size: int, mean_normalize: bool, augment: bool, max_global_shift: float = 0.0, max_local_shift: float = 0.0, max_scale: float = 1.0, device: str = "cpu", max_period: float = 10000.0):
    assert dim % 2 == 0
    pos = 1.0 * torch.arange(length).view(-1, 1, 1)
    pos = pos.repeat(1, batch_size, 1)
    if mean_normalize: pos -= torch.nanmean(pos, dim=0, keepdim=True)
    if augment:
        delta = np.random.uniform(-max_global_shift, +max_global_shift, size=[1, batch_size, 1])
        delta_local = np.random.uniform(-max_local_shift, +max_local_shift, size=[length, batch_size, 1])
        log_lambdas = np.random.uniform(-np.log(max_scale), +np.log(max_scale), size=[1, batch_size, 1])
        pos = (pos + delta + delta_local) * np.exp(log_lambdas)
    pos = pos.to(device)
    half_dim = dim // 2
    adim = torch.arange(dim // 2, device=device).view(1, 1, -1)
    phase = pos / (max_period ** (adim / (half_dim - 1)))
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1).float()
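# GroupNorm wrapper for (batch, time, channels) inputs: transposes to the
# (batch, channels, time) layout nn.GroupNorm expects and transposes back.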
class MyGroupNorm(nn.GroupNorm):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
    def forward(self, x):
        x = x.transpose(1, 2)
        return super().forward(x).transpose(1, 2)
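# Learnable per-channel scaling of a residual branch, initialized to `init`;
# `channel_last` selects whether channels sit on the last dimension or on dim 1.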
class LayerScale(nn.Module):
    def __init__(self, channels: int, init: float = 0, channel_last=False):
        super().__init__()
        self.channel_last = channel_last
        self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
        self.scale.data[:] = init
    def forward(self, x):
        if self.channel_last: return self.scale * x
        else: return self.scale[:, None] * x
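# Self-attention encoder layer: a stock nn.TransformerEncoderLayer extended with optional
# GroupNorm in place of LayerNorm, optional LayerScale on both residual branches and an
# optional output GroupNorm in the pre-norm configuration. The sparse-attention arguments
# are accepted for config compatibility but are not used in this file.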
class MyTransformerEncoderLayer(nn.TransformerEncoderLayer):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation=F.relu, group_norm=0, norm_first=False, norm_out=False, layer_norm_eps=1e-5, layer_scale=False, init_values=1e-4, device=None, dtype=None, sparse=False, mask_type="diag", mask_random_seed=42, sparse_attn_window=500, global_window=50, auto_sparsity=False, sparsity=0.95, batch_first=False):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout, activation=activation, layer_norm_eps=layer_norm_eps, batch_first=batch_first, norm_first=norm_first, device=device, dtype=dtype)
        self.auto_sparsity = auto_sparsity
        if group_norm:
            self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm_out = None
        if self.norm_first & norm_out: self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)
        self.gamma_1 = LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
        self.gamma_2 = LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        x = src
        T, B, C = x.shape
        if self.norm_first:
            x = x + self.gamma_1(self._sa_block(self.norm1(x), src_mask, src_key_padding_mask))
            x = x + self.gamma_2(self._ff_block(self.norm2(x)))
            if self.norm_out: x = self.norm_out(x)
        else:
            x = self.norm1(x + self.gamma_1(self._sa_block(x, src_mask, src_key_padding_mask)))
            x = self.norm2(x + self.gamma_2(self._ff_block(x)))
        return x
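# Cross-domain transformer used at the bottleneck of HTDemucs. It alternates self-attention
# layers (applied separately to the spectral tokens x and the waveform tokens xt) with
# cross-attention layers in which each branch attends to the other; `cross_first` flips
# which parity comes first. forward() takes x of shape (B, C, Fr, T1) and xt of shape
# (B, C, T2), flattens both to token sequences, adds positional embeddings weighted by
# `weight_pos_embed`, and restores the original shapes on output.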
class CrossTransformerEncoder(nn.Module):
    def __init__(self, dim: int, emb: str = "sin", hidden_scale: float = 4.0, num_heads: int = 8, num_layers: int = 6, cross_first: bool = False, dropout: float = 0.0, max_positions: int = 1000, norm_in: bool = True, norm_in_group: bool = False, group_norm: int = False, norm_first: bool = False, norm_out: bool = False, max_period: float = 10000.0, weight_decay: float = 0.0, lr: tp.Optional[float] = None, layer_scale: bool = False, gelu: bool = True, sin_random_shift: int = 0, weight_pos_embed: float = 1.0, cape_mean_normalize: bool = True, cape_augment: bool = True, cape_glob_loc_scale: list = [5000.0, 1.0, 1.4], sparse_self_attn: bool = False, sparse_cross_attn: bool = False, mask_type: str = "diag", mask_random_seed: int = 42, sparse_attn_window: int = 500, global_window: int = 50, auto_sparsity: bool = False, sparsity: float = 0.95):
        super().__init__()
        assert dim % num_heads == 0
        hidden_dim = int(dim * hidden_scale)
        self.num_layers = num_layers
        self.classic_parity = 1 if cross_first else 0
        self.emb = emb
        self.max_period = max_period
        self.weight_decay = weight_decay
        self.weight_pos_embed = weight_pos_embed
        self.sin_random_shift = sin_random_shift
        if emb == "cape":
            self.cape_mean_normalize = cape_mean_normalize
            self.cape_augment = cape_augment
            self.cape_glob_loc_scale = cape_glob_loc_scale
        if emb == "scaled": self.position_embeddings = ScaledEmbedding(max_positions, dim, scale=0.2)
        self.lr = lr
        activation: tp.Any = F.gelu if gelu else F.relu
        self.norm_in: nn.Module
        self.norm_in_t: nn.Module
        if norm_in:
            self.norm_in = nn.LayerNorm(dim)
            self.norm_in_t = nn.LayerNorm(dim)
        elif norm_in_group:
            self.norm_in = MyGroupNorm(int(norm_in_group), dim)
            self.norm_in_t = MyGroupNorm(int(norm_in_group), dim)
        else:
            self.norm_in = nn.Identity()
            self.norm_in_t = nn.Identity()
        self.layers = nn.ModuleList()
        self.layers_t = nn.ModuleList()
        kwargs_common = {
            "d_model": dim,
            "nhead": num_heads,
            "dim_feedforward": hidden_dim,
            "dropout": dropout,
            "activation": activation,
            "group_norm": group_norm,
            "norm_first": norm_first,
            "norm_out": norm_out,
            "layer_scale": layer_scale,
            "mask_type": mask_type,
            "mask_random_seed": mask_random_seed,
            "sparse_attn_window": sparse_attn_window,
            "global_window": global_window,
            "sparsity": sparsity,
            "auto_sparsity": auto_sparsity,
            "batch_first": True,
        }
        kwargs_classic_encoder = dict(kwargs_common)
        kwargs_classic_encoder.update({"sparse": sparse_self_attn})
        kwargs_cross_encoder = dict(kwargs_common)
        kwargs_cross_encoder.update({"sparse": sparse_cross_attn})
        for idx in range(num_layers):
            if idx % 2 == self.classic_parity:
                self.layers.append(MyTransformerEncoderLayer(**kwargs_classic_encoder))
                self.layers_t.append(MyTransformerEncoderLayer(**kwargs_classic_encoder))
            else:
                self.layers.append(CrossTransformerEncoderLayer(**kwargs_cross_encoder))
                self.layers_t.append(CrossTransformerEncoderLayer(**kwargs_cross_encoder))
    def forward(self, x, xt):
        B, C, Fr, T1 = x.shape
        pos_emb_2d = create_2d_sin_embedding(C, Fr, T1, x.device, self.max_period)
        pos_emb_2d = rearrange(pos_emb_2d, "b c fr t1 -> b (t1 fr) c")
        x = rearrange(x, "b c fr t1 -> b (t1 fr) c")
        x = self.norm_in(x)
        x = x + self.weight_pos_embed * pos_emb_2d
        B, C, T2 = xt.shape
        xt = rearrange(xt, "b c t2 -> b t2 c")
        pos_emb = self._get_pos_embedding(T2, B, C, x.device)
        pos_emb = rearrange(pos_emb, "t2 b c -> b t2 c")
        xt = self.norm_in_t(xt)
        xt = xt + self.weight_pos_embed * pos_emb
        for idx in range(self.num_layers):
            if idx % 2 == self.classic_parity:
                x = self.layers[idx](x)
                xt = self.layers_t[idx](xt)
            else:
                old_x = x
                x = self.layers[idx](x, xt)
                xt = self.layers_t[idx](xt, old_x)
        x = rearrange(x, "b (t1 fr) c -> b c fr t1", t1=T1)
        xt = rearrange(xt, "b t2 c -> b c t2")
        return x, xt
    def _get_pos_embedding(self, T, B, C, device):
        if self.emb == "sin":
            shift = random.randrange(self.sin_random_shift + 1)
            pos_emb = create_sin_embedding(T, C, shift=shift, device=device, max_period=self.max_period)
        elif self.emb == "cape":
            if self.training: pos_emb = create_sin_embedding_cape(T, C, B, device=device, max_period=self.max_period, mean_normalize=self.cape_mean_normalize, augment=self.cape_augment, max_global_shift=self.cape_glob_loc_scale[0], max_local_shift=self.cape_glob_loc_scale[1], max_scale=self.cape_glob_loc_scale[2])
            else: pos_emb = create_sin_embedding_cape(T, C, B, device=device, max_period=self.max_period, mean_normalize=self.cape_mean_normalize, augment=False)
        elif self.emb == "scaled":
            pos = torch.arange(T, device=device)
            pos_emb = self.position_embeddings(pos)[:, None]
        return pos_emb
    def make_optim_group(self):
        group = {"params": list(self.parameters()), "weight_decay": self.weight_decay}
        if self.lr is not None: group["lr"] = self.lr
        return group
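# Cross-attention counterpart of MyTransformerEncoderLayer: queries come from one branch
# and keys/values from the other, followed by a feed-forward block, with the same
# pre-/post-norm, GroupNorm and LayerScale options.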
class CrossTransformerEncoderLayer(nn.Module):
    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, activation=F.relu, layer_norm_eps: float = 1e-5, layer_scale: bool = False, init_values: float = 1e-4, norm_first: bool = False, group_norm: bool = False, norm_out: bool = False, sparse=False, mask_type="diag", mask_random_seed=42, sparse_attn_window=500, global_window=50, sparsity=0.95, auto_sparsity=None, device=None, dtype=None, batch_first=False):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.auto_sparsity = auto_sparsity
        self.cross_attn: nn.Module
        self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
        self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
        self.norm_first = norm_first
        self.norm1: nn.Module
        self.norm2: nn.Module
        self.norm3: nn.Module
        if group_norm:
            self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm3 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
        else:
            self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm_out = None
        if self.norm_first & norm_out:
            self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)
        self.gamma_1 = LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
        self.gamma_2 = LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        if isinstance(activation, str): self.activation = self._get_activation_fn(activation)
        else: self.activation = activation
    def forward(self, q, k, mask=None):
        if self.norm_first:
            x = q + self.gamma_1(self._ca_block(self.norm1(q), self.norm2(k), mask))
            x = x + self.gamma_2(self._ff_block(self.norm3(x)))
            if self.norm_out: x = self.norm_out(x)
        else:
            x = self.norm1(q + self.gamma_1(self._ca_block(q, k, mask)))
            x = self.norm2(x + self.gamma_2(self._ff_block(x)))
        return x
    def _ca_block(self, q, k, attn_mask=None):
        x = self.cross_attn(q, k, k, attn_mask=attn_mask, need_weights=False)[0]
        return self.dropout1(x)
    def _ff_block(self, x):
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)
    def _get_activation_fn(self, activation):
        if activation == "relu": return F.relu
        elif activation == "gelu": return F.gelu
        raise RuntimeError(translations["activation"].format(activation=activation))
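# Hybrid Transformer Demucs: two parallel U-Nets, a spectrogram branch fed by the STFT
# and a waveform branch fed by the raw audio, whose bottleneck features are fused by the
# CrossTransformerEncoder above. The spectrogram estimate is inverted back to the time
# domain and summed with the waveform estimate.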
class HTDemucs(nn.Module):
    def __init__(self, sources, audio_channels=2, channels=48, channels_time=None, growth=2, nfft=4096, wiener_iters=0, end_iters=0, wiener_residual=False, cac=True, depth=4, rewrite=True, multi_freqs=None, multi_freqs_depth=3, freq_emb=0.2, emb_scale=10, emb_smooth=True, kernel_size=8, time_stride=2, stride=4, context=1, context_enc=0, norm_starts=4, norm_groups=4, dconv_mode=1, dconv_depth=2, dconv_comp=8, dconv_init=1e-3, bottom_channels=0, t_layers=5, t_emb="sin", t_hidden_scale=4.0, t_heads=8, t_dropout=0.0, t_max_positions=10000, t_norm_in=True, t_norm_in_group=False, t_group_norm=False, t_norm_first=True, t_norm_out=True, t_max_period=10000.0, t_weight_decay=0.0, t_lr=None, t_layer_scale=True, t_gelu=True, t_weight_pos_embed=1.0, t_sin_random_shift=0, t_cape_mean_normalize=True, t_cape_augment=True, t_cape_glob_loc_scale=[5000.0, 1.0, 1.4], t_sparse_self_attn=False, t_sparse_cross_attn=False, t_mask_type="diag", t_mask_random_seed=42, t_sparse_attn_window=500, t_global_window=100, t_sparsity=0.95, t_auto_sparsity=False, t_cross_first=False, rescale=0.1, samplerate=44100, segment=4 * 10, use_train_segment=True):
        super().__init__()
        self.cac = cac
        self.wiener_residual = wiener_residual
        self.audio_channels = audio_channels
        self.sources = sources
        self.kernel_size = kernel_size
        self.context = context
        self.stride = stride
        self.depth = depth
        self.bottom_channels = bottom_channels
        self.channels = channels
        self.samplerate = samplerate
        self.segment = segment
        self.use_train_segment = use_train_segment
        self.nfft = nfft
        self.hop_length = nfft // 4
        self.wiener_iters = wiener_iters
        self.end_iters = end_iters
        self.freq_emb = None
        assert wiener_iters == end_iters
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.tencoder = nn.ModuleList()
        self.tdecoder = nn.ModuleList()
        chin = audio_channels
        chin_z = chin
        if self.cac: chin_z *= 2
        chout = channels_time or channels
        chout_z = channels
        freqs = nfft // 2
        for index in range(depth):
            norm = index >= norm_starts
            freq = freqs > 1
            stri = stride
            ker = kernel_size
            if not freq:
                assert freqs == 1
                ker = time_stride * 2
                stri = time_stride
            pad = True
            last_freq = False
            if freq and freqs <= kernel_size:
                ker = freqs
                pad = False
                last_freq = True
            kw = {
                "kernel_size": ker,
                "stride": stri,
                "freq": freq,
                "pad": pad,
                "norm": norm,
                "rewrite": rewrite,
                "norm_groups": norm_groups,
                "dconv_kw": {"depth": dconv_depth, "compress": dconv_comp, "init": dconv_init, "gelu": True},
            }
            kwt = dict(kw)
            kwt["freq"] = 0
            kwt["kernel_size"] = kernel_size
            kwt["stride"] = stride
            kwt["pad"] = True
            kw_dec = dict(kw)
            multi = False
            if multi_freqs and index < multi_freqs_depth:
                multi = True
                kw_dec["context_freq"] = False
            if last_freq:
                chout_z = max(chout, chout_z)
                chout = chout_z
            enc = HEncLayer(chin_z, chout_z, dconv=dconv_mode & 1, context=context_enc, **kw)
            if freq:
                tenc = HEncLayer(chin, chout, dconv=dconv_mode & 1, context=context_enc, empty=last_freq, **kwt)
                self.tencoder.append(tenc)
            if multi: enc = MultiWrap(enc, multi_freqs)
            self.encoder.append(enc)
            if index == 0:
                chin = self.audio_channels * len(self.sources)
                chin_z = chin
                if self.cac: chin_z *= 2
            dec = HDecLayer(chout_z, chin_z, dconv=dconv_mode & 2, last=index == 0, context=context, **kw_dec)
            if multi:
                dec = MultiWrap(dec, multi_freqs)
            if freq:
                tdec = HDecLayer(chout, chin, dconv=dconv_mode & 2, empty=last_freq, last=index == 0, context=context, **kwt)
                self.tdecoder.insert(0, tdec)
            self.decoder.insert(0, dec)
            chin = chout
            chin_z = chout_z
            chout = int(growth * chout)
            chout_z = int(growth * chout_z)
            if freq:
                if freqs <= kernel_size: freqs = 1
                else: freqs //= stride
            if index == 0 and freq_emb:
                self.freq_emb = ScaledEmbedding(freqs, chin_z, smooth=emb_smooth, scale=emb_scale)
                self.freq_emb_scale = freq_emb
        if rescale: rescale_module(self, reference=rescale)
        transformer_channels = channels * growth ** (depth - 1)
        if bottom_channels:
            self.channel_upsampler = nn.Conv1d(transformer_channels, bottom_channels, 1)
            self.channel_downsampler = nn.Conv1d(bottom_channels, transformer_channels, 1)
            self.channel_upsampler_t = nn.Conv1d(transformer_channels, bottom_channels, 1)
            self.channel_downsampler_t = nn.Conv1d(bottom_channels, transformer_channels, 1)
            transformer_channels = bottom_channels
        if t_layers > 0: self.crosstransformer = CrossTransformerEncoder(dim=transformer_channels, emb=t_emb, hidden_scale=t_hidden_scale, num_heads=t_heads, num_layers=t_layers, cross_first=t_cross_first, dropout=t_dropout, max_positions=t_max_positions, norm_in=t_norm_in, norm_in_group=t_norm_in_group, group_norm=t_group_norm, norm_first=t_norm_first, norm_out=t_norm_out, max_period=t_max_period, weight_decay=t_weight_decay, lr=t_lr, layer_scale=t_layer_scale, gelu=t_gelu, sin_random_shift=t_sin_random_shift, weight_pos_embed=t_weight_pos_embed, cape_mean_normalize=t_cape_mean_normalize, cape_augment=t_cape_augment, cape_glob_loc_scale=t_cape_glob_loc_scale, sparse_self_attn=t_sparse_self_attn, sparse_cross_attn=t_sparse_cross_attn, mask_type=t_mask_type, mask_random_seed=t_mask_random_seed, sparse_attn_window=t_sparse_attn_window, global_window=t_global_window, sparsity=t_sparsity, auto_sparsity=t_auto_sparsity)
        else: self.crosstransformer = None
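    # STFT helpers: _spec reflect-pads the input so the frame count is a deterministic
    # function of the length, drops the Nyquist bin and trims two frames on each side;
    # _ispec re-adds the missing bin and frames before the inverse STFT and cuts the
    # padding back off.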
    def _spec(self, x):
        hl = self.hop_length
        nfft = self.nfft
        assert hl == nfft // 4
        le = int(math.ceil(x.shape[-1] / hl))
        pad = hl // 2 * 3
        x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode="reflect")
        z = spectro(x, nfft, hl)[..., :-1, :]
        assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
        z = z[..., 2 : 2 + le]
        return z
    def _ispec(self, z, length=None, scale=0):
        hl = self.hop_length // (4**scale)
        z = F.pad(z, (0, 0, 0, 1))
        z = F.pad(z, (2, 2))
        pad = hl // 2 * 3
        le = hl * int(math.ceil(length / hl)) + 2 * pad
        x = ispectro(z, hl, length=le)
        x = x[..., pad : pad + length]
        return x
    def _magnitude(self, z):
        if self.cac:
            B, C, Fr, T = z.shape
            m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
            m = m.reshape(B, C * 2, Fr, T)
        else: m = z.abs()
        return m
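    # With cac=True the decoder output is reinterpreted directly as one complex spectrogram
    # per source (no masking); otherwise either a mixture-phase ratio mask or Wiener EM
    # filtering is applied, the latter in chunks of 300 frames to bound memory.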
    def _mask(self, z, m):
        niters = self.wiener_iters
        if self.cac:
            B, S, C, Fr, T = m.shape
            out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
            out = torch.view_as_complex(out.contiguous())
            return out
        if self.training: niters = self.end_iters
        if niters < 0:
            z = z[:, None]
            return z / (1e-8 + z.abs()) * m
        else: return self._wiener(m, z, niters)
    def _wiener(self, mag_out, mix_stft, niters):
        init = mix_stft.dtype
        wiener_win_len = 300
        residual = self.wiener_residual
        B, S, C, Fq, T = mag_out.shape
        mag_out = mag_out.permute(0, 4, 3, 2, 1)
        mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))
        outs = []
        for sample in range(B):
            pos = 0
            out = []
            for pos in range(0, T, wiener_win_len):
                frame = slice(pos, pos + wiener_win_len)
                z_out = wiener(mag_out[sample, frame], mix_stft[sample, frame], niters, residual=residual)
                out.append(z_out.transpose(-1, -2))
            outs.append(torch.cat(out, dim=0))
        out = torch.view_as_complex(torch.stack(outs, 0))
        out = out.permute(0, 4, 3, 2, 1).contiguous()
        if residual: out = out[:, :-1]
        assert list(out.shape) == [B, S, C, Fq, T]
        return out.to(init)
    def valid_length(self, length: int):
        if not self.use_train_segment: return length
        training_length = int(self.segment * self.samplerate)
        if training_length < length: raise ValueError(translations["length_or_training_length"].format(length=length, training_length=training_length))
        return training_length
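    # Forward pass: STFT -> complex-as-channels magnitude -> per-example normalization of
    # both branches -> paired frequency/time encoders (a frequency embedding is injected
    # after the first layer) -> optional channel up/down-sampling around the
    # cross-transformer -> paired decoders with skip connections -> de-normalization,
    # inverse STFT and summation of the two branches.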
    def forward(self, mix):
        length = mix.shape[-1]
        length_pre_pad = None
        if self.use_train_segment:
            if self.training: self.segment = Fraction(mix.shape[-1], self.samplerate)
            else:
                training_length = int(self.segment * self.samplerate)
                if mix.shape[-1] < training_length:
                    length_pre_pad = mix.shape[-1]
                    mix = F.pad(mix, (0, training_length - length_pre_pad))
        z = self._spec(mix)
        mag = self._magnitude(z).to(mix.device)
        x = mag
        B, C, Fq, T = x.shape
        mean = x.mean(dim=(1, 2, 3), keepdim=True)
        std = x.std(dim=(1, 2, 3), keepdim=True)
        x = (x - mean) / (1e-5 + std)
        xt = mix
        meant = xt.mean(dim=(1, 2), keepdim=True)
        stdt = xt.std(dim=(1, 2), keepdim=True)
        xt = (xt - meant) / (1e-5 + stdt)
        saved = []
        saved_t = []
        lengths = []
        lengths_t = []
        for idx, encode in enumerate(self.encoder):
            lengths.append(x.shape[-1])
            inject = None
            if idx < len(self.tencoder):
                lengths_t.append(xt.shape[-1])
                tenc = self.tencoder[idx]
                xt = tenc(xt)
                if not tenc.empty: saved_t.append(xt)
                else: inject = xt
            x = encode(x, inject)
            if idx == 0 and self.freq_emb is not None:
                frs = torch.arange(x.shape[-2], device=x.device)
                emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
                x = x + self.freq_emb_scale * emb
            saved.append(x)
        if self.crosstransformer:
            if self.bottom_channels:
                b, c, f, t = x.shape
                x = rearrange(x, "b c f t-> b c (f t)")
                x = self.channel_upsampler(x)
                x = rearrange(x, "b c (f t)-> b c f t", f=f)
                xt = self.channel_upsampler_t(xt)
            x, xt = self.crosstransformer(x, xt)
            if self.bottom_channels:
                x = rearrange(x, "b c f t-> b c (f t)")
                x = self.channel_downsampler(x)
                x = rearrange(x, "b c (f t)-> b c f t", f=f)
                xt = self.channel_downsampler_t(xt)
        for idx, decode in enumerate(self.decoder):
            skip = saved.pop(-1)
            x, pre = decode(x, skip, lengths.pop(-1))
            offset = self.depth - len(self.tdecoder)
            if idx >= offset:
                tdec = self.tdecoder[idx - offset]
                length_t = lengths_t.pop(-1)
                if tdec.empty:
                    assert pre.shape[2] == 1, pre.shape
                    pre = pre[:, :, 0]
                    xt, _ = tdec(pre, None, length_t)
                else:
                    skip = saved_t.pop(-1)
                    xt, _ = tdec(xt, skip, length_t)
        assert len(saved) == 0
        assert len(lengths_t) == 0
        assert len(saved_t) == 0
        S = len(self.sources)
        x = x.view(B, S, -1, Fq, T)
        x = x * std[:, None] + mean[:, None]
        device_type = x.device.type
        device_load = f"{device_type}:{x.device.index}" if device_type != "mps" else device_type
        x_is_other_gpu = device_type not in ["cuda", "cpu"]
        if x_is_other_gpu: x = x.cpu()
        zout = self._mask(z, x)
        if self.use_train_segment: x = self._ispec(zout, length) if self.training else self._ispec(zout, training_length)
        else: x = self._ispec(zout, length)
        if x_is_other_gpu: x = x.to(device_load)
        if self.use_train_segment: xt = xt.view(B, S, -1, length) if self.training else xt.view(B, S, -1, training_length)
        else: xt = xt.view(B, S, -1, length)
        xt = xt * stdt[:, None] + meant[:, None]
        x = xt + x
        if length_pre_pad: x = x[..., :length_pre_pad]
        return x