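# Hybrid Demucs (HDemucs) model definition together with the STFT, complex
# arithmetic and Wiener-filtering helpers it relies on.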
import math
import torch
import typing as tp
from torch import nn
from copy import deepcopy
from typing import Optional
from torch.nn import functional as F
from .states import capture_init
from .demucs import DConv, rescale_module
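# STFT helper: flattens all leading dimensions, computes a normalized complex
# STFT with a Hann window, and restores the leading dimensions on the output.
# Tensors on devices other than CUDA/CPU are moved to CPU before the transform.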
def spectro(x, n_fft=512, hop_length=None, pad=0):
*other, length = x.shape
x = x.reshape(-1, length)
device_type = x.device.type
is_other_gpu = device_type not in ["cuda", "cpu"]
if is_other_gpu: x = x.cpu()
z = torch.stft(x, n_fft * (1 + pad), hop_length or n_fft // 4, window=torch.hann_window(n_fft).to(x), win_length=n_fft, normalized=True, center=True, return_complex=True, pad_mode="reflect")
_, freqs, frame = z.shape
return z.view(*other, freqs, frame)
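# Inverse STFT helper: mirrors spectro(), recovering the waveform from a
# complex spectrogram, with the same CPU fallback for non-CUDA/CPU devices.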
def ispectro(z, hop_length=None, length=None, pad=0):
*other, freqs, frames = z.shape
n_fft = 2 * freqs - 2
z = z.view(-1, freqs, frames)
win_length = n_fft // (1 + pad)
device_type = z.device.type
is_other_gpu = device_type not in ["cuda", "cpu"]
if is_other_gpu: z = z.cpu()
x = torch.istft(z, n_fft, hop_length, window=torch.hann_window(win_length).to(z.real), win_length=win_length, normalized=True, length=length, center=True)
_, length = x.shape
return x.view(*other, length)
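# Element-wise arctangent of y / x with explicit quadrant corrections, an
# alternative implementation of torch.atan2 built from torch.atan.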
def atan2(y, x):
pi = 2 * torch.asin(torch.tensor(1.0))
x += ((x == 0) & (y == 0)) * 1.0
out = torch.atan(y / x)
out += ((y >= 0) & (x < 0)) * pi
out -= ((y < 0) & (x < 0)) * pi
out *= 1 - ((y > 0) & (x == 0)) * 1.0
out += ((y > 0) & (x == 0)) * (pi / 2)
out *= 1 - ((y < 0) & (x == 0)) * 1.0
out += ((y < 0) & (x == 0)) * (-pi / 2)
return out
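# The helpers below operate on complex numbers stored as real tensors whose
# trailing dimension of size 2 holds (real, imaginary) parts: squared
# magnitude, fused multiply-add, multiplication, reciprocal, conjugation and
# inversion of 1x1 / 2x2 covariance matrices.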
def _norm(x: torch.Tensor) -> torch.Tensor:
return torch.abs(x[..., 0]) ** 2 + torch.abs(x[..., 1]) ** 2
def _mul_add(a: torch.Tensor, b: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
target_shape = torch.Size([max(sa, sb) for (sa, sb) in zip(a.shape, b.shape)])
if out is None or out.shape != target_shape: out = torch.zeros(target_shape, dtype=a.dtype, device=a.device)
if out is a:
real_a = a[..., 0]
out[..., 0] = out[..., 0] + (real_a * b[..., 0] - a[..., 1] * b[..., 1])
out[..., 1] = out[..., 1] + (real_a * b[..., 1] + a[..., 1] * b[..., 0])
else:
out[..., 0] = out[..., 0] + (a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1])
out[..., 1] = out[..., 1] + (a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0])
return out
def _mul(a: torch.Tensor, b: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
target_shape = torch.Size([max(sa, sb) for (sa, sb) in zip(a.shape, b.shape)])
if out is None or out.shape != target_shape: out = torch.zeros(target_shape, dtype=a.dtype, device=a.device)
if out is a:
real_a = a[..., 0]
out[..., 0] = real_a * b[..., 0] - a[..., 1] * b[..., 1]
out[..., 1] = real_a * b[..., 1] + a[..., 1] * b[..., 0]
else:
out[..., 0] = a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1]
out[..., 1] = a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0]
return out
def _inv(z: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
ez = _norm(z)
if out is None or out.shape != z.shape: out = torch.zeros_like(z)
out[..., 0] = z[..., 0] / ez
out[..., 1] = -z[..., 1] / ez
return out
def _conj(z, out: Optional[torch.Tensor] = None) -> torch.Tensor:
if out is None or out.shape != z.shape: out = torch.zeros_like(z)
out[..., 0] = z[..., 0]
out[..., 1] = -z[..., 1]
return out
def _invert(M: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
nb_channels = M.shape[-2]
if out is None or out.shape != M.shape: out = torch.empty_like(M)
if nb_channels == 1: out = _inv(M, out)
elif nb_channels == 2:
det = _mul(M[..., 0, 0, :], M[..., 1, 1, :])
det = det - _mul(M[..., 0, 1, :], M[..., 1, 0, :])
invDet = _inv(det)
out[..., 0, 0, :] = _mul(invDet, M[..., 1, 1, :], out[..., 0, 0, :])
out[..., 1, 0, :] = _mul(-invDet, M[..., 1, 0, :], out[..., 1, 0, :])
out[..., 0, 1, :] = _mul(-invDet, M[..., 0, 1, :], out[..., 0, 1, :])
out[..., 1, 1, :] = _mul(invDet, M[..., 0, 0, :], out[..., 1, 1, :])
else: raise Exception("Only 1 or 2 channel covariance matrices are supported.")
return out
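# Expectation-maximization for multichannel Wiener filtering: alternately
# re-estimates each source's power spectrogram (v) and spatial covariance
# matrix (R), then refilters the mixture with the resulting Wiener gains.
# Frames are processed in chunks of `batch_size` to bound memory usage.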
def expectation_maximization(y: torch.Tensor, x: torch.Tensor, iterations: int = 2, eps: float = 1e-10, batch_size: int = 200):
(nb_frames, nb_bins, nb_channels) = x.shape[:-1]
nb_sources = y.shape[-1]
regularization = torch.cat((torch.eye(nb_channels, dtype=x.dtype, device=x.device)[..., None], torch.zeros((nb_channels, nb_channels, 1), dtype=x.dtype, device=x.device)), dim=2)
regularization = torch.sqrt(torch.as_tensor(eps)) * (regularization[None, None, ...].expand((-1, nb_bins, -1, -1, -1)))
R = [torch.zeros((nb_bins, nb_channels, nb_channels, 2), dtype=x.dtype, device=x.device) for j in range(nb_sources)]
weight: torch.Tensor = torch.zeros((nb_bins,), dtype=x.dtype, device=x.device)
v: torch.Tensor = torch.zeros((nb_frames, nb_bins, nb_sources), dtype=x.dtype, device=x.device)
for _ in range(iterations):
v = torch.mean(torch.abs(y[..., 0, :]) ** 2 + torch.abs(y[..., 1, :]) ** 2, dim=-2)
for j in range(nb_sources):
R[j] = torch.tensor(0.0, device=x.device)
weight = torch.tensor(eps, device=x.device)
pos: int = 0
batch_size = batch_size if batch_size else nb_frames
while pos < nb_frames:
t = torch.arange(pos, min(nb_frames, pos + batch_size))
pos = int(t[-1]) + 1
R[j] = R[j] + torch.sum(_covariance(y[t, ..., j]), dim=0)
weight = weight + torch.sum(v[t, ..., j], dim=0)
R[j] = R[j] / weight[..., None, None, None]
weight = torch.zeros_like(weight)
if y.requires_grad: y = y.clone()
pos = 0
while pos < nb_frames:
t = torch.arange(pos, min(nb_frames, pos + batch_size))
pos = int(t[-1]) + 1
y[t, ...] = torch.tensor(0.0, device=x.device, dtype=x.dtype)
Cxx = regularization
for j in range(nb_sources):
Cxx = Cxx + (v[t, ..., j, None, None, None] * R[j][None, ...].clone())
inv_Cxx = _invert(Cxx)
for j in range(nb_sources):
gain = torch.zeros_like(inv_Cxx)
indices = torch.cartesian_prod(torch.arange(nb_channels), torch.arange(nb_channels), torch.arange(nb_channels))
for index in indices:
gain[:, :, index[0], index[1], :] = _mul_add(R[j][None, :, index[0], index[2], :].clone(), inv_Cxx[:, :, index[2], index[1], :], gain[:, :, index[0], index[1], :])
gain = gain * v[t, ..., None, None, None, j]
for i in range(nb_channels):
y[t, ..., j] = _mul_add(gain[..., i, :], x[t, ..., i, None, :], y[t, ..., j])
return y, v, R
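# Wiener filtering entry point: builds initial complex source estimates from
# the target magnitude spectrograms (soft mask or mixture phase), optionally
# appends a residual source, rescales for numerical stability, and refines the
# estimates with `iterations` rounds of expectation_maximization.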
def wiener(targets_spectrograms: torch.Tensor, mix_stft: torch.Tensor, iterations: int = 1, softmask: bool = False, residual: bool = False, scale_factor: float = 10.0, eps: float = 1e-10):
if softmask: y = mix_stft[..., None] * (targets_spectrograms / (eps + torch.sum(targets_spectrograms, dim=-1, keepdim=True).to(mix_stft.dtype)))[..., None, :]
else:
angle = atan2(mix_stft[..., 1], mix_stft[..., 0])[..., None]
nb_sources = targets_spectrograms.shape[-1]
y = torch.zeros(mix_stft.shape + (nb_sources,), dtype=mix_stft.dtype, device=mix_stft.device)
y[..., 0, :] = targets_spectrograms * torch.cos(angle)
y[..., 1, :] = targets_spectrograms * torch.sin(angle)
if residual: y = torch.cat([y, mix_stft[..., None] - y.sum(dim=-1, keepdim=True)], dim=-1)
if iterations == 0: return y
max_abs = torch.max(torch.as_tensor(1.0, dtype=mix_stft.dtype, device=mix_stft.device), torch.sqrt(_norm(mix_stft)).max() / scale_factor)
mix_stft = mix_stft / max_abs
y = y / max_abs
y = expectation_maximization(y, mix_stft, iterations, eps=eps)[0]
y = y * max_abs
return y
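# Empirical spatial covariance y_j y_j^H of one source, per frame and bin.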
def _covariance(y_j):
(nb_frames, nb_bins, nb_channels) = y_j.shape[:-1]
Cj = torch.zeros((nb_frames, nb_bins, nb_channels, nb_channels, 2), dtype=y_j.dtype, device=y_j.device)
indices = torch.cartesian_prod(torch.arange(nb_channels), torch.arange(nb_channels))
for index in indices:
Cj[:, :, index[0], index[1], :] = _mul_add(y_j[:, :, index[0], :], _conj(y_j[:, :, index[1], :]), Cj[:, :, index[0], index[1], :])
return Cj
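# Padding wrapper around F.pad that keeps "reflect" mode valid even when the
# input is shorter than the requested padding, by zero-padding first.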
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = "constant", value: float = 0.0):
x0 = x
length = x.shape[-1]
padding_left, padding_right = paddings
if mode == "reflect":
max_pad = max(padding_left, padding_right)
if length <= max_pad:
extra_pad = max_pad - length + 1
extra_pad_right = min(padding_right, extra_pad)
extra_pad_left = extra_pad - extra_pad_right
paddings = (padding_left - extra_pad_left, padding_right - extra_pad_right)
x = F.pad(x, (extra_pad_left, extra_pad_right))
out = F.pad(x, paddings, mode, value)
assert out.shape[-1] == length + padding_left + padding_right
assert (out[..., padding_left : padding_left + length] == x0).all()
return out
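# Embedding whose weights are divided by `scale` at init and multiplied back at
# forward time, which effectively boosts the learning rate of the embedding;
# `smooth` initializes it with a cumulative sum normalized by sqrt of the index.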
class ScaledEmbedding(nn.Module):
def __init__(self, num_embeddings: int, embedding_dim: int, scale: float = 10.0, smooth=False):
super().__init__()
self.embedding = nn.Embedding(num_embeddings, embedding_dim)
if smooth:
weight = torch.cumsum(self.embedding.weight.data, dim=0)
weight = weight / torch.arange(1, num_embeddings + 1).to(weight).sqrt()[:, None]
self.embedding.weight.data[:] = weight
self.embedding.weight.data /= scale
self.scale = scale
@property
def weight(self):
return self.embedding.weight * self.scale
def forward(self, x):
return self.embedding(x) * self.scale
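# Encoder layer of the hybrid Demucs U-Net. Convolves either over the frequency
# axis (Conv2d with kernel/stride on frequencies only) or over time (Conv1d),
# followed by optional GroupNorm, a DConv residual branch and a rewrite
# convolution with GLU activation.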
class HEncLayer(nn.Module):
def __init__(self, chin, chout, kernel_size=8, stride=4, norm_groups=1, empty=False, freq=True, dconv=True, norm=True, context=0, dconv_kw={}, pad=True, rewrite=True):
super().__init__()
norm_fn = lambda d: nn.Identity()
if norm: norm_fn = lambda d: nn.GroupNorm(norm_groups, d)
pad = kernel_size // 4 if pad else 0
klass = nn.Conv1d
self.freq = freq
self.kernel_size = kernel_size
self.stride = stride
self.empty = empty
self.norm = norm
self.pad = pad
if freq:
kernel_size = [kernel_size, 1]
stride = [stride, 1]
pad = [pad, 0]
klass = nn.Conv2d
self.conv = klass(chin, chout, kernel_size, stride, pad)
if self.empty: return
self.norm1 = norm_fn(chout)
self.rewrite = None
if rewrite:
self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context)
self.norm2 = norm_fn(2 * chout)
self.dconv = None
if dconv: self.dconv = DConv(chout, **dconv_kw)
def forward(self, x, inject=None):
if not self.freq and x.dim() == 4:
B, C, Fr, T = x.shape
x = x.view(B, -1, T)
if not self.freq:
le = x.shape[-1]
if le % self.stride != 0: x = F.pad(x, (0, self.stride - (le % self.stride)))
y = self.conv(x)
if self.empty: return y
if inject is not None:
assert inject.shape[-1] == y.shape[-1], (inject.shape, y.shape)
if inject.dim() == 3 and y.dim() == 4: inject = inject[:, :, None]
y = y + inject
y = F.gelu(self.norm1(y))
if self.dconv:
if self.freq:
B, C, Fr, T = y.shape
y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
y = self.dconv(y)
if self.freq: y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
if self.rewrite:
z = self.norm2(self.rewrite(y))
z = F.glu(z, dim=1)
else: z = y
return z
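# Wraps an HEncLayer / HDecLayer so that the frequency axis is split into bands
# (given by `split_ratios`) processed by independent copies of the layer, with
# care taken at the band boundaries so the output lines up with what the
# wrapped layer would have produced.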
class MultiWrap(nn.Module):
def __init__(self, layer, split_ratios):
super().__init__()
self.split_ratios = split_ratios
self.layers = nn.ModuleList()
self.conv = isinstance(layer, HEncLayer)
assert not layer.norm
assert layer.freq
assert layer.pad
if not self.conv: assert not layer.context_freq
for _ in range(len(split_ratios) + 1):
lay = deepcopy(layer)
if self.conv: lay.conv.padding = (0, 0)
else: lay.pad = False
for m in lay.modules():
if hasattr(m, "reset_parameters"): m.reset_parameters()
self.layers.append(lay)
def forward(self, x, skip=None, length=None):
B, C, Fr, T = x.shape
ratios = list(self.split_ratios) + [1]
start = 0
outs = []
for ratio, layer in zip(ratios, self.layers):
if self.conv:
pad = layer.kernel_size // 4
if ratio == 1:
limit = Fr
frames = -1
else:
limit = int(round(Fr * ratio))
le = limit - start
if start == 0: le += pad
frames = round((le - layer.kernel_size) / layer.stride + 1)
limit = start + (frames - 1) * layer.stride + layer.kernel_size
if start == 0: limit -= pad
assert limit - start > 0, (limit, start)
assert limit <= Fr, (limit, Fr)
y = x[:, :, start:limit, :]
if start == 0: y = F.pad(y, (0, 0, pad, 0))
if ratio == 1: y = F.pad(y, (0, 0, 0, pad))
outs.append(layer(y))
start = limit - layer.kernel_size + layer.stride
else:
limit = Fr if ratio == 1 else int(round(Fr * ratio))
last = layer.last
layer.last = True
y = x[:, :, start:limit]
s = skip[:, :, start:limit]
out, _ = layer(y, s, None)
if outs:
outs[-1][:, :, -layer.stride :] += out[:, :, : layer.stride] - layer.conv_tr.bias.view(1, -1, 1, 1)
out = out[:, :, layer.stride :]
if ratio == 1: out = out[:, :, : -layer.stride // 2, :]
if start == 0: out = out[:, :, layer.stride // 2 :, :]
outs.append(out)
layer.last = last
start = limit
out = torch.cat(outs, dim=2)
if not self.conv and not last: out = F.gelu(out)
if self.conv: return out
else: return out, None
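# Decoder layer mirroring HEncLayer: adds the skip connection, applies the
# optional rewrite/GLU and DConv branch, then upsamples with a transposed
# convolution (2D over frequencies or 1D over time) and trims the padding.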
class HDecLayer(nn.Module):
def __init__(self, chin, chout, last=False, kernel_size=8, stride=4, norm_groups=1, empty=False, freq=True, dconv=True, norm=True, context=1, dconv_kw={}, pad=True, context_freq=True, rewrite=True):
super().__init__()
norm_fn = lambda d: nn.Identity()
if norm: norm_fn = lambda d: nn.GroupNorm(norm_groups, d)
pad = kernel_size // 4 if pad else 0
self.pad = pad
self.last = last
self.freq = freq
self.chin = chin
self.empty = empty
self.stride = stride
self.kernel_size = kernel_size
self.norm = norm
self.context_freq = context_freq
klass = nn.Conv1d
klass_tr = nn.ConvTranspose1d
if freq:
kernel_size = [kernel_size, 1]
stride = [stride, 1]
klass = nn.Conv2d
klass_tr = nn.ConvTranspose2d
self.conv_tr = klass_tr(chin, chout, kernel_size, stride)
self.norm2 = norm_fn(chout)
if self.empty: return
self.rewrite = None
if rewrite:
if context_freq: self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context)
else: self.rewrite = klass(chin, 2 * chin, [1, 1 + 2 * context], 1, [0, context])
self.norm1 = norm_fn(2 * chin)
self.dconv = None
if dconv: self.dconv = DConv(chin, **dconv_kw)
def forward(self, x, skip, length):
if self.freq and x.dim() == 3:
B, C, T = x.shape
x = x.view(B, self.chin, -1, T)
if not self.empty:
x = x + skip
y = F.glu(self.norm1(self.rewrite(x)), dim=1) if self.rewrite else x
if self.dconv:
if self.freq:
B, C, Fr, T = y.shape
y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
y = self.dconv(y)
if self.freq: y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
else:
y = x
assert skip is None
z = self.norm2(self.conv_tr(y))
if self.freq:
if self.pad: z = z[..., self.pad : -self.pad, :]
else:
z = z[..., self.pad : self.pad + length]
assert z.shape[-1] == length, (z.shape[-1], length)
if not self.last: z = F.gelu(z)
return z, y
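# Hybrid Demucs: a spectrogram U-Net and (optionally) a parallel waveform U-Net
# sharing their innermost layers. Spectral outputs are mapped back to the
# waveform domain (complex-as-channels or Wiener filtering) and summed with the
# time-branch output.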
class HDemucs(nn.Module):
@capture_init
def __init__(self, sources, audio_channels=2, channels=48, channels_time=None, growth=2, nfft=4096, wiener_iters=0, end_iters=0, wiener_residual=False, cac=True, depth=6, rewrite=True, hybrid=True, hybrid_old=False, multi_freqs=None, multi_freqs_depth=2, freq_emb=0.2, emb_scale=10, emb_smooth=True, kernel_size=8, time_stride=2, stride=4, context=1, context_enc=0, norm_starts=4, norm_groups=4, dconv_mode=1, dconv_depth=2, dconv_comp=4, dconv_attn=4, dconv_lstm=4, dconv_init=1e-4, rescale=0.1, samplerate=44100, segment=4 * 10):
super().__init__()
self.cac = cac
self.wiener_residual = wiener_residual
self.audio_channels = audio_channels
self.sources = sources
self.kernel_size = kernel_size
self.context = context
self.stride = stride
self.depth = depth
self.channels = channels
self.samplerate = samplerate
self.segment = segment
self.nfft = nfft
self.hop_length = nfft // 4
self.wiener_iters = wiener_iters
self.end_iters = end_iters
self.freq_emb = None
self.hybrid = hybrid
self.hybrid_old = hybrid_old
if hybrid_old: assert hybrid
if hybrid: assert wiener_iters == end_iters
self.encoder = nn.ModuleList()
self.decoder = nn.ModuleList()
if hybrid:
self.tencoder = nn.ModuleList()
self.tdecoder = nn.ModuleList()
chin = audio_channels
chin_z = chin
if self.cac: chin_z *= 2
chout = channels_time or channels
chout_z = channels
freqs = nfft // 2
for index in range(depth):
lstm = index >= dconv_lstm
attn = index >= dconv_attn
norm = index >= norm_starts
freq = freqs > 1
stri = stride
ker = kernel_size
if not freq:
assert freqs == 1
ker = time_stride * 2
stri = time_stride
pad = True
last_freq = False
if freq and freqs <= kernel_size:
ker = freqs
pad = False
last_freq = True
kw = {
"kernel_size": ker,
"stride": stri,
"freq": freq,
"pad": pad,
"norm": norm,
"rewrite": rewrite,
"norm_groups": norm_groups,
"dconv_kw": {"lstm": lstm, "attn": attn, "depth": dconv_depth, "compress": dconv_comp, "init": dconv_init, "gelu": True},
}
kwt = dict(kw)
kwt["freq"] = 0
kwt["kernel_size"] = kernel_size
kwt["stride"] = stride
kwt["pad"] = True
kw_dec = dict(kw)
multi = False
if multi_freqs and index < multi_freqs_depth:
multi = True
kw_dec["context_freq"] = False
if last_freq:
chout_z = max(chout, chout_z)
chout = chout_z
enc = HEncLayer(chin_z, chout_z, dconv=dconv_mode & 1, context=context_enc, **kw)
if hybrid and freq:
tenc = HEncLayer(chin, chout, dconv=dconv_mode & 1, context=context_enc, empty=last_freq, **kwt)
self.tencoder.append(tenc)
if multi: enc = MultiWrap(enc, multi_freqs)
self.encoder.append(enc)
if index == 0:
chin = self.audio_channels * len(self.sources)
chin_z = chin
if self.cac: chin_z *= 2
dec = HDecLayer(chout_z, chin_z, dconv=dconv_mode & 2, last=index == 0, context=context, **kw_dec)
if multi: dec = MultiWrap(dec, multi_freqs)
if hybrid and freq:
tdec = HDecLayer(chout, chin, dconv=dconv_mode & 2, empty=last_freq, last=index == 0, context=context, **kwt)
self.tdecoder.insert(0, tdec)
self.decoder.insert(0, dec)
chin = chout
chin_z = chout_z
chout = int(growth * chout)
chout_z = int(growth * chout_z)
if freq:
if freqs <= kernel_size: freqs = 1
else: freqs //= stride
if index == 0 and freq_emb:
self.freq_emb = ScaledEmbedding(freqs, chin_z, smooth=emb_smooth, scale=emb_scale)
self.freq_emb_scale = freq_emb
if rescale: rescale_module(self, reference=rescale)
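# Forward STFT with the padding/cropping conventions of hybrid Demucs so the
# number of frames lines up with the time-branch stride structure.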
def _spec(self, x):
hl = self.hop_length
nfft = self.nfft
if self.hybrid:
assert hl == nfft // 4
le = int(math.ceil(x.shape[-1] / hl))
pad = hl // 2 * 3
x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode="reflect") if not self.hybrid_old else pad1d(x, (pad, pad + le * hl - x.shape[-1]))
z = spectro(x, nfft, hl)[..., :-1, :]
if self.hybrid:
assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
z = z[..., 2 : 2 + le]
return z
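# Inverse of _spec: re-adds the trimmed frequency bin and frames, then crops
# the waveform back to `length`.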
def _ispec(self, z, length=None, scale=0):
hl = self.hop_length // (4**scale)
z = F.pad(z, (0, 0, 0, 1))
if self.hybrid:
z = F.pad(z, (2, 2))
pad = hl // 2 * 3
le = hl * int(math.ceil(length / hl)) + 2 * pad if not self.hybrid_old else hl * int(math.ceil(length / hl))
x = ispectro(z, hl, length=le)
x = x[..., pad : pad + length] if not self.hybrid_old else x[..., :length]
else: x = ispectro(z, hl, length)
return x
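# Returns the representation fed to the spectral branch: real/imaginary parts
# stacked as channels when cac (complex-as-channels) is enabled, otherwise the
# magnitude.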
def _magnitude(self, z):
if self.cac:
B, C, Fr, T = z.shape
m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
m = m.reshape(B, C * 2, Fr, T)
else: m = z.abs()
return m
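# Converts the decoder output back to complex spectrograms: directly when cac
# is enabled (real/imag channels reinterpreted as complex), otherwise by
# applying the magnitude estimate to the mixture phase or by Wiener filtering,
# depending on the configured number of iterations.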
def _mask(self, z, m):
niters = self.wiener_iters
if self.cac:
B, S, C, Fr, T = m.shape
out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
out = torch.view_as_complex(out.contiguous())
return out
if self.training: niters = self.end_iters
if niters < 0:
z = z[:, None]
return z / (1e-8 + z.abs()) * m
else: return self._wiener(m, z, niters)
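# Applies Wiener filtering in chunks of `wiener_win_len` frames per sample to
# keep memory bounded, then restores the original layout and dtype.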
def _wiener(self, mag_out, mix_stft, niters):
init = mix_stft.dtype
wiener_win_len = 300
residual = self.wiener_residual
B, S, C, Fq, T = mag_out.shape
mag_out = mag_out.permute(0, 4, 3, 2, 1)
mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))
outs = []
for sample in range(B):
pos = 0
out = []
for pos in range(0, T, wiener_win_len):
frame = slice(pos, pos + wiener_win_len)
z_out = wiener(mag_out[sample, frame], mix_stft[sample, frame], niters, residual=residual)
out.append(z_out.transpose(-1, -2))
outs.append(torch.cat(out, dim=0))
out = torch.view_as_complex(torch.stack(outs, 0))
out = out.permute(0, 4, 3, 2, 1).contiguous()
if residual: out = out[:, :-1]
assert list(out.shape) == [B, S, C, Fq, T]
return out.to(init)
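# Full separation pass: normalize the spectrogram (and, in hybrid mode, the
# waveform), run the encoder/decoder stacks with the frequency embedding
# injected at the first layer, map the spectral estimate back to waveforms and
# add the time-branch estimate when hybrid.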
def forward(self, mix):
x = mix
length = x.shape[-1]
z = self._spec(mix)
mag = self._magnitude(z).to(mix.device)
x = mag
B, C, Fq, T = x.shape
mean = x.mean(dim=(1, 2, 3), keepdim=True)
std = x.std(dim=(1, 2, 3), keepdim=True)
x = (x - mean) / (1e-5 + std)
if self.hybrid:
xt = mix
meant = xt.mean(dim=(1, 2), keepdim=True)
stdt = xt.std(dim=(1, 2), keepdim=True)
xt = (xt - meant) / (1e-5 + stdt)
saved = []
saved_t = []
lengths = []
lengths_t = []
for idx, encode in enumerate(self.encoder):
lengths.append(x.shape[-1])
inject = None
if self.hybrid and idx < len(self.tencoder):
lengths_t.append(xt.shape[-1])
tenc = self.tencoder[idx]
xt = tenc(xt)
if not tenc.empty: saved_t.append(xt)
else: inject = xt
x = encode(x, inject)
if idx == 0 and self.freq_emb is not None:
frs = torch.arange(x.shape[-2], device=x.device)
emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
x = x + self.freq_emb_scale * emb
saved.append(x)
x = torch.zeros_like(x)
if self.hybrid: xt = torch.zeros_like(x)
for idx, decode in enumerate(self.decoder):
skip = saved.pop(-1)
x, pre = decode(x, skip, lengths.pop(-1))
if self.hybrid: offset = self.depth - len(self.tdecoder)
if self.hybrid and idx >= offset:
tdec = self.tdecoder[idx - offset]
length_t = lengths_t.pop(-1)
if tdec.empty:
assert pre.shape[2] == 1, pre.shape
pre = pre[:, :, 0]
xt, _ = tdec(pre, None, length_t)
else:
skip = saved_t.pop(-1)
xt, _ = tdec(xt, skip, length_t)
assert len(saved) == 0
assert len(lengths_t) == 0
assert len(saved_t) == 0
S = len(self.sources)
x = x.view(B, S, -1, Fq, T)
x = x * std[:, None] + mean[:, None]
device_type = x.device.type
device_load = f"{device_type}:{x.device.index}" if device_type != "mps" else device_type
x_is_other_gpu = device_type not in ["cuda", "cpu"]
if x_is_other_gpu: x = x.cpu()
zout = self._mask(z, x)
x = self._ispec(zout, length)
if x_is_other_gpu: x = x.to(device_load)
if self.hybrid:
xt = xt.view(B, S, -1, length)
xt = xt * stdt[:, None] + meant[:, None]
x = xt + x
return x
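# Illustrative usage sketch (not part of the original module; shapes assume the
# default constructor arguments and a stereo 44.1 kHz mixture):
#   model = HDemucs(sources=["drums", "bass", "other", "vocals"])
#   mix = torch.randn(1, 2, 44100 * 10)   # (batch, audio_channels, samples)
#   out = model(mix)                       # (batch, len(sources), audio_channels, samples)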