import os
import sys
import math
import torch
import random

import numpy as np
import typing as tp

from torch import nn
from einops import rearrange
from fractions import Fraction
from torch.nn import functional as F

now_dir = os.getcwd()
sys.path.append(now_dir)

from .states import capture_init
from .demucs import rescale_module
from main.configs.config import Config
from .hdemucs import pad1d, spectro, ispectro, wiener, ScaledEmbedding, HEncLayer, MultiWrap, HDecLayer

translations = Config().translations

def create_sin_embedding(length: int, dim: int, shift: int = 0, device="cpu", max_period=10000):
    # Standard 1D sinusoidal positional embedding, returned as (length, 1, dim).
    assert dim % 2 == 0
    pos = shift + torch.arange(length, device=device).view(-1, 1, 1)
    half_dim = dim // 2
    adim = torch.arange(dim // 2, device=device).view(1, 1, -1)
    phase = pos / (max_period ** (adim / (half_dim - 1)))
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)

def create_2d_sin_embedding(d_model, height, width, device="cpu", max_period=10000):
    # 2D sinusoidal embedding: half the channels encode width, half encode height.
    if d_model % 4 != 0:
        raise ValueError(translations["dims"].format(dims=d_model))

    pe = torch.zeros(d_model, height, width)
    d_model = int(d_model / 2)
    div_term = torch.exp(torch.arange(0.0, d_model, 2) * -(math.log(max_period) / d_model))
    pos_w = torch.arange(0.0, width).unsqueeze(1)
    pos_h = torch.arange(0.0, height).unsqueeze(1)

    pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
    pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
    pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
    pe[d_model + 1 :: 2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)

    return pe[None, :].to(device)

def create_sin_embedding_cape(
    length: int,
    dim: int,
    batch_size: int,
    mean_normalize: bool,
    augment: bool,
    max_global_shift: float = 0.0,
    max_local_shift: float = 0.0,
    max_scale: float = 1.0,
    device: str = "cpu",
    max_period: float = 10000.0,
):
    # CAPE positional embedding: sinusoidal positions with optional random
    # global shift, per-position jitter and scaling (training-time augmentation).
    assert dim % 2 == 0
    pos = 1.0 * torch.arange(length).view(-1, 1, 1)
    pos = pos.repeat(1, batch_size, 1)

    if mean_normalize:
        pos -= torch.nanmean(pos, dim=0, keepdim=True)

    if augment:
        delta = np.random.uniform(-max_global_shift, +max_global_shift, size=[1, batch_size, 1])
        delta_local = np.random.uniform(-max_local_shift, +max_local_shift, size=[length, batch_size, 1])
        log_lambdas = np.random.uniform(-np.log(max_scale), +np.log(max_scale), size=[1, batch_size, 1])
        pos = (pos + delta + delta_local) * np.exp(log_lambdas)

    pos = pos.to(device)
    half_dim = dim // 2
    adim = torch.arange(dim // 2, device=device).view(1, 1, -1)
    phase = pos / (max_period ** (adim / (half_dim - 1)))
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1).float()

class MyGroupNorm(nn.GroupNorm):
    # GroupNorm over (B, T, C) inputs: transpose to (B, C, T), normalize, transpose back.
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x):
        x = x.transpose(1, 2)
        return super().forward(x).transpose(1, 2)

class LayerScale(nn.Module):
    # Learnable per-channel scaling of the residual branch (LayerScale).
    def __init__(self, channels: int, init: float = 0, channel_last=False):
        super().__init__()
        self.channel_last = channel_last
        self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
        self.scale.data[:] = init

    def forward(self, x):
        if self.channel_last:
            return self.scale * x
        return self.scale[:, None] * x
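# Shape sanity check for the embeddings above (illustrative sketch, not part of
# the module API; sizes follow directly from the definitions):
#
#   emb = create_sin_embedding(length=100, dim=512)           # (100, 1, 512)
#   emb2d = create_2d_sin_embedding(64, height=32, width=16)  # (1, 64, 32, 16)
#
# The 2D variant is consumed by CrossTransformerEncoder.forward below, where it
# is flattened to (1, 16 * 32, 64) and added to the spectral tokens.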
class MyTransformerEncoderLayer(nn.TransformerEncoderLayer):
    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation=F.relu,
        group_norm=0,
        norm_first=False,
        norm_out=False,
        layer_norm_eps=1e-5,
        layer_scale=False,
        init_values=1e-4,
        device=None,
        dtype=None,
        sparse=False,
        mask_type="diag",
        mask_random_seed=42,
        sparse_attn_window=500,
        global_window=50,
        auto_sparsity=False,
        sparsity=0.95,
        batch_first=False,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout, activation=activation, layer_norm_eps=layer_norm_eps, batch_first=batch_first, norm_first=norm_first, device=device, dtype=dtype)
        self.auto_sparsity = auto_sparsity

        if group_norm:
            # Replace the LayerNorms from the parent class with GroupNorms.
            self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)

        self.norm_out = None
        if self.norm_first and norm_out:
            self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)

        self.gamma_1 = LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
        self.gamma_2 = LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        x = src

        if self.norm_first:
            x = x + self.gamma_1(self._sa_block(self.norm1(x), src_mask, src_key_padding_mask))
            x = x + self.gamma_2(self._ff_block(self.norm2(x)))
            if self.norm_out:
                x = self.norm_out(x)
        else:
            x = self.norm1(x + self.gamma_1(self._sa_block(x, src_mask, src_key_padding_mask)))
            x = self.norm2(x + self.gamma_2(self._ff_block(x)))

        return x
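# Both residual forms of the layer above, for reference (sa = self-attention
# block, ff = feed-forward block, g1/g2 the LayerScale gains):
#
#   norm_first=True : x = x + g1(sa(norm1(x))); x = x + g2(ff(norm2(x)))
#   norm_first=False: x = norm1(x + g1(sa(x))); x = norm2(x + g2(ff(x)))
#
# The pre-norm variant can additionally apply a GroupNorm (`norm_out`) at the end.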
class CrossTransformerEncoder(nn.Module):
    def __init__(
        self,
        dim: int,
        emb: str = "sin",
        hidden_scale: float = 4.0,
        num_heads: int = 8,
        num_layers: int = 6,
        cross_first: bool = False,
        dropout: float = 0.0,
        max_positions: int = 1000,
        norm_in: bool = True,
        norm_in_group: bool = False,
        group_norm: int = False,
        norm_first: bool = False,
        norm_out: bool = False,
        max_period: float = 10000.0,
        weight_decay: float = 0.0,
        lr: tp.Optional[float] = None,
        layer_scale: bool = False,
        gelu: bool = True,
        sin_random_shift: int = 0,
        weight_pos_embed: float = 1.0,
        cape_mean_normalize: bool = True,
        cape_augment: bool = True,
        cape_glob_loc_scale: list = [5000.0, 1.0, 1.4],
        sparse_self_attn: bool = False,
        sparse_cross_attn: bool = False,
        mask_type: str = "diag",
        mask_random_seed: int = 42,
        sparse_attn_window: int = 500,
        global_window: int = 50,
        auto_sparsity: bool = False,
        sparsity: float = 0.95,
    ):
        super().__init__()
        assert dim % num_heads == 0

        hidden_dim = int(dim * hidden_scale)
        self.num_layers = num_layers
        # Layers at indices with idx % 2 == classic_parity are self-attention;
        # with cross_first, cross-attention layers come first instead.
        self.classic_parity = 1 if cross_first else 0
        self.emb = emb
        self.max_period = max_period
        self.weight_decay = weight_decay
        self.weight_pos_embed = weight_pos_embed
        self.sin_random_shift = sin_random_shift

        if emb == "cape":
            self.cape_mean_normalize = cape_mean_normalize
            self.cape_augment = cape_augment
            self.cape_glob_loc_scale = cape_glob_loc_scale

        if emb == "scaled":
            self.position_embeddings = ScaledEmbedding(max_positions, dim, scale=0.2)

        self.lr = lr
        activation: tp.Any = F.gelu if gelu else F.relu

        self.norm_in: nn.Module
        self.norm_in_t: nn.Module
        if norm_in:
            self.norm_in = nn.LayerNorm(dim)
            self.norm_in_t = nn.LayerNorm(dim)
        elif norm_in_group:
            self.norm_in = MyGroupNorm(int(norm_in_group), dim)
            self.norm_in_t = MyGroupNorm(int(norm_in_group), dim)
        else:
            self.norm_in = nn.Identity()
            self.norm_in_t = nn.Identity()

        self.layers = nn.ModuleList()
        self.layers_t = nn.ModuleList()

        kwargs_common = {
            "d_model": dim,
            "nhead": num_heads,
            "dim_feedforward": hidden_dim,
            "dropout": dropout,
            "activation": activation,
            "group_norm": group_norm,
            "norm_first": norm_first,
            "norm_out": norm_out,
            "layer_scale": layer_scale,
            "mask_type": mask_type,
            "mask_random_seed": mask_random_seed,
            "sparse_attn_window": sparse_attn_window,
            "global_window": global_window,
            "sparsity": sparsity,
            "auto_sparsity": auto_sparsity,
            "batch_first": True,
        }

        kwargs_classic_encoder = dict(kwargs_common)
        kwargs_classic_encoder.update({"sparse": sparse_self_attn})
        kwargs_cross_encoder = dict(kwargs_common)
        kwargs_cross_encoder.update({"sparse": sparse_cross_attn})

        for idx in range(num_layers):
            if idx % 2 == self.classic_parity:
                self.layers.append(MyTransformerEncoderLayer(**kwargs_classic_encoder))
                self.layers_t.append(MyTransformerEncoderLayer(**kwargs_classic_encoder))
            else:
                self.layers.append(CrossTransformerEncoderLayer(**kwargs_cross_encoder))
                self.layers_t.append(CrossTransformerEncoderLayer(**kwargs_cross_encoder))

    def forward(self, x, xt):
        # x: spectral branch (B, C, Fr, T1); xt: waveform branch (B, C, T2).
        B, C, Fr, T1 = x.shape
        pos_emb_2d = create_2d_sin_embedding(C, Fr, T1, x.device, self.max_period)
        pos_emb_2d = rearrange(pos_emb_2d, "b c fr t1 -> b (t1 fr) c")

        x = rearrange(x, "b c fr t1 -> b (t1 fr) c")
        x = self.norm_in(x)
        x = x + self.weight_pos_embed * pos_emb_2d

        B, C, T2 = xt.shape
        xt = rearrange(xt, "b c t2 -> b t2 c")
        pos_emb = self._get_pos_embedding(T2, B, C, x.device)
        pos_emb = rearrange(pos_emb, "t2 b c -> b t2 c")
        xt = self.norm_in_t(xt)
        xt = xt + self.weight_pos_embed * pos_emb

        for idx in range(self.num_layers):
            if idx % 2 == self.classic_parity:
                # Self-attention within each branch.
                x = self.layers[idx](x)
                xt = self.layers_t[idx](xt)
            else:
                # Cross-attention: each branch attends to the other's tokens.
                old_x = x
                x = self.layers[idx](x, xt)
                xt = self.layers_t[idx](xt, old_x)

        x = rearrange(x, "b (t1 fr) c -> b c fr t1", t1=T1)
        xt = rearrange(xt, "b t2 c -> b c t2")
        return x, xt

    def _get_pos_embedding(self, T, B, C, device):
        if self.emb == "sin":
            shift = random.randrange(self.sin_random_shift + 1)
            pos_emb = create_sin_embedding(T, C, shift=shift, device=device, max_period=self.max_period)
        elif self.emb == "cape":
            if self.training:
                pos_emb = create_sin_embedding_cape(T, C, B, device=device, max_period=self.max_period, mean_normalize=self.cape_mean_normalize, augment=self.cape_augment, max_global_shift=self.cape_glob_loc_scale[0], max_local_shift=self.cape_glob_loc_scale[1], max_scale=self.cape_glob_loc_scale[2])
            else:
                pos_emb = create_sin_embedding_cape(T, C, B, device=device, max_period=self.max_period, mean_normalize=self.cape_mean_normalize, augment=False)
        elif self.emb == "scaled":
            pos = torch.arange(T, device=device)
            pos_emb = self.position_embeddings(pos)[:, None]

        return pos_emb

    def make_optim_group(self):
        group = {"params": list(self.parameters()), "weight_decay": self.weight_decay}
        if self.lr is not None:
            group["lr"] = self.lr
        return group
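# Sketch of how make_optim_group() is meant to be consumed (hedged: `model` and
# `base_params` are hypothetical names for a parent model and its remaining
# parameters). It returns a standard torch parameter-group dict, so the
# transformer can carry its own lr/weight_decay:
#
#   groups = [{"params": base_params}]
#   groups.append(model.crosstransformer.make_optim_group())
#   optimizer = torch.optim.Adam(groups, lr=3e-4)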
class CrossTransformerEncoderLayer(nn.Module):
    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation=F.relu,
        layer_norm_eps: float = 1e-5,
        layer_scale: bool = False,
        init_values: float = 1e-4,
        norm_first: bool = False,
        group_norm: bool = False,
        norm_out: bool = False,
        sparse=False,
        mask_type="diag",
        mask_random_seed=42,
        sparse_attn_window=500,
        global_window=50,
        sparsity=0.95,
        auto_sparsity=None,
        device=None,
        dtype=None,
        batch_first=False,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.auto_sparsity = auto_sparsity
        self.cross_attn: nn.Module
        self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
        self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
        self.norm_first = norm_first

        self.norm1: nn.Module
        self.norm2: nn.Module
        self.norm3: nn.Module
        if group_norm:
            self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm3 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
        else:
            self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)

        self.norm_out = None
        if self.norm_first and norm_out:
            self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)

        self.gamma_1 = LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
        self.gamma_2 = LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        if isinstance(activation, str):
            self.activation = self._get_activation_fn(activation)
        else:
            self.activation = activation

    def forward(self, q, k, mask=None):
        if self.norm_first:
            x = q + self.gamma_1(self._ca_block(self.norm1(q), self.norm2(k), mask))
            x = x + self.gamma_2(self._ff_block(self.norm3(x)))
            if self.norm_out:
                x = self.norm_out(x)
        else:
            x = self.norm1(q + self.gamma_1(self._ca_block(q, k, mask)))
            x = self.norm2(x + self.gamma_2(self._ff_block(x)))

        return x

    def _ca_block(self, q, k, attn_mask=None):
        # Queries come from one branch, keys/values from the other.
        x = self.cross_attn(q, k, k, attn_mask=attn_mask, need_weights=False)[0]
        return self.dropout1(x)

    def _ff_block(self, x):
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)

    def _get_activation_fn(self, activation):
        if activation == "relu":
            return F.relu
        elif activation == "gelu":
            return F.gelu
        raise RuntimeError(translations["activation"].format(activation=activation))
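# Minimal smoke test for the two transformer pieces together (illustrative;
# all sizes are arbitrary, dim just has to be divisible by num_heads):
#
#   enc = CrossTransformerEncoder(dim=64, num_heads=8, num_layers=2)
#   x = torch.randn(1, 64, 32, 16)    # spectral branch (B, C, Fr, T1)
#   xt = torch.randn(1, 64, 512)      # waveform branch (B, C, T2)
#   x, xt = enc(x, xt)                # shapes are preserved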
class HTDemucs(nn.Module):
    @capture_init
    def __init__(
        self,
        sources,
        audio_channels=2,
        channels=48,
        channels_time=None,
        growth=2,
        nfft=4096,
        wiener_iters=0,
        end_iters=0,
        wiener_residual=False,
        cac=True,
        depth=4,
        rewrite=True,
        multi_freqs=None,
        multi_freqs_depth=3,
        freq_emb=0.2,
        emb_scale=10,
        emb_smooth=True,
        kernel_size=8,
        time_stride=2,
        stride=4,
        context=1,
        context_enc=0,
        norm_starts=4,
        norm_groups=4,
        dconv_mode=1,
        dconv_depth=2,
        dconv_comp=8,
        dconv_init=1e-3,
        bottom_channels=0,
        t_layers=5,
        t_emb="sin",
        t_hidden_scale=4.0,
        t_heads=8,
        t_dropout=0.0,
        t_max_positions=10000,
        t_norm_in=True,
        t_norm_in_group=False,
        t_group_norm=False,
        t_norm_first=True,
        t_norm_out=True,
        t_max_period=10000.0,
        t_weight_decay=0.0,
        t_lr=None,
        t_layer_scale=True,
        t_gelu=True,
        t_weight_pos_embed=1.0,
        t_sin_random_shift=0,
        t_cape_mean_normalize=True,
        t_cape_augment=True,
        t_cape_glob_loc_scale=[5000.0, 1.0, 1.4],
        t_sparse_self_attn=False,
        t_sparse_cross_attn=False,
        t_mask_type="diag",
        t_mask_random_seed=42,
        t_sparse_attn_window=500,
        t_global_window=100,
        t_sparsity=0.95,
        t_auto_sparsity=False,
        t_cross_first=False,
        rescale=0.1,
        samplerate=44100,
        segment=4 * 10,
        use_train_segment=True,
    ):
        super().__init__()
        self.cac = cac
        self.wiener_residual = wiener_residual
        self.audio_channels = audio_channels
        self.sources = sources
        self.kernel_size = kernel_size
        self.context = context
        self.stride = stride
        self.depth = depth
        self.bottom_channels = bottom_channels
        self.channels = channels
        self.samplerate = samplerate
        self.segment = segment
        self.use_train_segment = use_train_segment
        self.nfft = nfft
        self.hop_length = nfft // 4
        self.wiener_iters = wiener_iters
        self.end_iters = end_iters
        self.freq_emb = None
        assert wiener_iters == end_iters

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.tencoder = nn.ModuleList()
        self.tdecoder = nn.ModuleList()

        chin = audio_channels
        chin_z = chin
        if self.cac:
            chin_z *= 2
        chout = channels_time or channels
        chout_z = channels
        freqs = nfft // 2

        for index in range(depth):
            norm = index >= norm_starts
            freq = freqs > 1
            stri = stride
            ker = kernel_size

            if not freq:
                assert freqs == 1
                ker = time_stride * 2
                stri = time_stride

            pad = True
            last_freq = False
            if freq and freqs <= kernel_size:
                ker = freqs
                pad = False
                last_freq = True

            kw = {
                "kernel_size": ker,
                "stride": stri,
                "freq": freq,
                "pad": pad,
                "norm": norm,
                "rewrite": rewrite,
                "norm_groups": norm_groups,
                "dconv_kw": {"depth": dconv_depth, "compress": dconv_comp, "init": dconv_init, "gelu": True},
            }

            kwt = dict(kw)
            kwt["freq"] = 0
            kwt["kernel_size"] = kernel_size
            kwt["stride"] = stride
            kwt["pad"] = True
            kw_dec = dict(kw)

            multi = False
            if multi_freqs and index < multi_freqs_depth:
                multi = True
                kw_dec["context_freq"] = False

            if last_freq:
                chout_z = max(chout, chout_z)
                chout = chout_z

            enc = HEncLayer(chin_z, chout_z, dconv=dconv_mode & 1, context=context_enc, **kw)
            if freq:
                tenc = HEncLayer(chin, chout, dconv=dconv_mode & 1, context=context_enc, empty=last_freq, **kwt)
                self.tencoder.append(tenc)

            if multi:
                enc = MultiWrap(enc, multi_freqs)
            self.encoder.append(enc)

            if index == 0:
                chin = self.audio_channels * len(self.sources)
                chin_z = chin
                if self.cac:
                    chin_z *= 2

            dec = HDecLayer(chout_z, chin_z, dconv=dconv_mode & 2, last=index == 0, context=context, **kw_dec)
            if multi:
                dec = MultiWrap(dec, multi_freqs)

            if freq:
                tdec = HDecLayer(chout, chin, dconv=dconv_mode & 2, empty=last_freq, last=index == 0, context=context, **kwt)
                self.tdecoder.insert(0, tdec)
            self.decoder.insert(0, dec)

            chin = chout
            chin_z = chout_z
            chout = int(growth * chout)
            chout_z = int(growth * chout_z)

            if freq:
                if freqs <= kernel_size:
                    freqs = 1
                else:
                    freqs //= stride

            if index == 0 and freq_emb:
                self.freq_emb = ScaledEmbedding(freqs, chin_z, smooth=emb_smooth, scale=emb_scale)
                self.freq_emb_scale = freq_emb

        if rescale:
            rescale_module(self, reference=rescale)

        transformer_channels = channels * growth ** (depth - 1)
        if bottom_channels:
            self.channel_upsampler = nn.Conv1d(transformer_channels, bottom_channels, 1)
            self.channel_downsampler = nn.Conv1d(bottom_channels, transformer_channels, 1)
            self.channel_upsampler_t = nn.Conv1d(transformer_channels, bottom_channels, 1)
            self.channel_downsampler_t = nn.Conv1d(bottom_channels, transformer_channels, 1)
            transformer_channels = bottom_channels

        if t_layers > 0:
            self.crosstransformer = CrossTransformerEncoder(dim=transformer_channels, emb=t_emb, hidden_scale=t_hidden_scale, num_heads=t_heads, num_layers=t_layers, cross_first=t_cross_first, dropout=t_dropout, max_positions=t_max_positions, norm_in=t_norm_in, norm_in_group=t_norm_in_group, group_norm=t_group_norm, norm_first=t_norm_first, norm_out=t_norm_out, max_period=t_max_period, weight_decay=t_weight_decay, lr=t_lr, layer_scale=t_layer_scale, gelu=t_gelu, sin_random_shift=t_sin_random_shift, weight_pos_embed=t_weight_pos_embed, cape_mean_normalize=t_cape_mean_normalize, cape_augment=t_cape_augment, cape_glob_loc_scale=t_cape_glob_loc_scale, sparse_self_attn=t_sparse_self_attn, sparse_cross_attn=t_sparse_cross_attn, mask_type=t_mask_type, mask_random_seed=t_mask_random_seed, sparse_attn_window=t_sparse_attn_window, global_window=t_global_window, sparsity=t_sparsity, auto_sparsity=t_auto_sparsity)
        else:
            self.crosstransformer = None
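    # _spec/_ispec below form a matched STFT round-trip: the input is
    # reflect-padded by 3 * hop // 2 on each side so spectro() yields le + 4
    # frames, the two margin frames on each side are trimmed, and _ispec
    # re-inserts them before the inverse transform. Worked numbers with the
    # defaults (nfft=4096, hop=1024) and length=44100:
    # le = ceil(44100 / 1024) = 44 frames kept, pad = 1536 samples per side.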
    def _spec(self, x):
        hl = self.hop_length
        nfft = self.nfft
        assert hl == nfft // 4

        le = int(math.ceil(x.shape[-1] / hl))
        pad = hl // 2 * 3
        x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode="reflect")

        z = spectro(x, nfft, hl)[..., :-1, :]
        assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
        z = z[..., 2 : 2 + le]
        return z

    def _ispec(self, z, length=None, scale=0):
        hl = self.hop_length // (4**scale)
        z = F.pad(z, (0, 0, 0, 1))
        z = F.pad(z, (2, 2))
        pad = hl // 2 * 3
        le = hl * int(math.ceil(length / hl)) + 2 * pad
        x = ispectro(z, hl, length=le)
        x = x[..., pad : pad + length]
        return x

    def _magnitude(self, z):
        # In CAC (complex-as-channels) mode, stack real and imaginary parts as
        # extra channels instead of taking the magnitude.
        if self.cac:
            B, C, Fr, T = z.shape
            m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
            m = m.reshape(B, C * 2, Fr, T)
        else:
            m = z.abs()
        return m

    def _mask(self, z, m):
        # Turn the network output `m` back into complex spectrograms, either
        # directly (CAC), by ratio masking of the mixture `z`, or via Wiener filtering.
        niters = self.wiener_iters
        if self.cac:
            B, S, C, Fr, T = m.shape
            out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
            out = torch.view_as_complex(out.contiguous())
            return out

        if self.training:
            niters = self.end_iters

        if niters < 0:
            z = z[:, None]
            return z / (1e-8 + z.abs()) * m
        return self._wiener(m, z, niters)

    def _wiener(self, mag_out, mix_stft, niters):
        init = mix_stft.dtype
        wiener_win_len = 300
        residual = self.wiener_residual

        B, S, C, Fq, T = mag_out.shape
        mag_out = mag_out.permute(0, 4, 3, 2, 1)
        mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))

        outs = []
        for sample in range(B):
            out = []
            # Wiener filtering is applied over windows of frames to bound memory.
            for pos in range(0, T, wiener_win_len):
                frame = slice(pos, pos + wiener_win_len)
                z_out = wiener(mag_out[sample, frame], mix_stft[sample, frame], niters, residual=residual)
                out.append(z_out.transpose(-1, -2))
            outs.append(torch.cat(out, dim=0))

        out = torch.view_as_complex(torch.stack(outs, 0))
        out = out.permute(0, 4, 3, 2, 1).contiguous()

        if residual:
            out = out[:, :-1]
        assert list(out.shape) == [B, S, C, Fq, T]
        return out.to(init)

    def valid_length(self, length: int):
        if not self.use_train_segment:
            return length
        training_length = int(self.segment * self.samplerate)
        if training_length < length:
            raise ValueError(translations["length_or_training_length"].format(length=length, training_length=training_length))
        return training_length
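    # forward() normalizes the two branches independently (per-item mean/std of
    # the magnitude spectrogram for x, of the raw waveform for xt) and restores
    # those statistics on the way out; the 1e-5 floor guards near-silent inputs.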
    def forward(self, mix):
        length = mix.shape[-1]
        length_pre_pad = None

        if self.use_train_segment:
            if self.training:
                self.segment = Fraction(mix.shape[-1], self.samplerate)
            else:
                training_length = int(self.segment * self.samplerate)
                if mix.shape[-1] < training_length:
                    length_pre_pad = mix.shape[-1]
                    mix = F.pad(mix, (0, training_length - length_pre_pad))

        z = self._spec(mix)
        mag = self._magnitude(z).to(mix.device)
        x = mag
        B, C, Fq, T = x.shape

        mean = x.mean(dim=(1, 2, 3), keepdim=True)
        std = x.std(dim=(1, 2, 3), keepdim=True)
        x = (x - mean) / (1e-5 + std)

        xt = mix
        meant = xt.mean(dim=(1, 2), keepdim=True)
        stdt = xt.std(dim=(1, 2), keepdim=True)
        xt = (xt - meant) / (1e-5 + stdt)

        saved = []
        saved_t = []
        lengths = []
        lengths_t = []

        for idx, encode in enumerate(self.encoder):
            lengths.append(x.shape[-1])
            inject = None
            if idx < len(self.tencoder):
                lengths_t.append(xt.shape[-1])
                tenc = self.tencoder[idx]
                xt = tenc(xt)
                if not tenc.empty:
                    saved_t.append(xt)
                else:
                    # The last time-branch layer is injected into the frequency branch.
                    inject = xt

            x = encode(x, inject)
            if idx == 0 and self.freq_emb is not None:
                frs = torch.arange(x.shape[-2], device=x.device)
                emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
                x = x + self.freq_emb_scale * emb

            saved.append(x)

        if self.crosstransformer:
            if self.bottom_channels:
                b, c, f, t = x.shape
                x = rearrange(x, "b c f t -> b c (f t)")
                x = self.channel_upsampler(x)
                x = rearrange(x, "b c (f t) -> b c f t", f=f)
                xt = self.channel_upsampler_t(xt)

            x, xt = self.crosstransformer(x, xt)

            if self.bottom_channels:
                x = rearrange(x, "b c f t -> b c (f t)")
                x = self.channel_downsampler(x)
                x = rearrange(x, "b c (f t) -> b c f t", f=f)
                xt = self.channel_downsampler_t(xt)

        for idx, decode in enumerate(self.decoder):
            skip = saved.pop(-1)
            x, pre = decode(x, skip, lengths.pop(-1))
            offset = self.depth - len(self.tdecoder)

            if idx >= offset:
                tdec = self.tdecoder[idx - offset]
                length_t = lengths_t.pop(-1)
                if tdec.empty:
                    assert pre.shape[2] == 1, pre.shape
                    pre = pre[:, :, 0]
                    xt, _ = tdec(pre, None, length_t)
                else:
                    skip = saved_t.pop(-1)
                    xt, _ = tdec(xt, skip, length_t)

        assert len(saved) == 0
        assert len(lengths_t) == 0
        assert len(saved_t) == 0

        S = len(self.sources)
        x = x.view(B, S, -1, Fq, T)
        x = x * std[:, None] + mean[:, None]

        # The complex/Wiener masking step is not supported on every backend, so
        # fall back to the CPU on devices other than CUDA and CPU (e.g. MPS).
        device_type = x.device.type
        device_load = f"{device_type}:{x.device.index}" if device_type != "mps" else device_type
        x_is_other_gpu = device_type not in ("cuda", "cpu")

        if x_is_other_gpu:
            x = x.cpu()

        zout = self._mask(z, x)
        if self.use_train_segment:
            x = self._ispec(zout, length) if self.training else self._ispec(zout, training_length)
        else:
            x = self._ispec(zout, length)

        if x_is_other_gpu:
            x = x.to(device_load)

        if self.use_train_segment:
            xt = xt.view(B, S, -1, length) if self.training else xt.view(B, S, -1, training_length)
        else:
            xt = xt.view(B, S, -1, length)

        xt = xt * stdt[:, None] + meant[:, None]
        x = xt + x

        if length_pre_pad:
            x = x[..., :length_pre_pad]
        return x
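# End-to-end usage sketch (illustrative; the source list and input length are
# assumptions, not fixed by this module):
#
#   model = HTDemucs(sources=["drums", "bass", "other", "vocals"])
#   model.eval()
#   mix = torch.randn(1, 2, 44100)    # (batch, audio_channels, time)
#   with torch.no_grad():
#       out = model(mix)              # (batch, len(sources), audio_channels, time)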