# *****************************************************************************
# MIT License
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# *****************************************************************************
import math
import typing as tp
from typing import Any, Dict, List, Optional

import torch
from torch import nn
from torch.nn import functional as F


class _ScaledEmbedding(torch.nn.Module):
    r"""Make continuous embeddings and boost learning rate.

    Args:
        num_embeddings (int): number of embeddings
        embedding_dim (int): embedding dimensions
        scale (float, optional): amount to scale learning rate (Default: 10.0)
        smooth (bool, optional): choose to apply smoothing (Default: ``False``)
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, scale: float = 10.0, smooth: bool = False):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        if smooth:
            weight = torch.cumsum(self.embedding.weight.data, dim=0)
            # when summing gaussians, the scale grows as sqrt(n), so we normalize by that.
            # (cast the integer arange to the weight dtype before sqrt.)
            weight = weight / torch.arange(1, num_embeddings + 1).to(weight).sqrt()[:, None]
            self.embedding.weight.data[:] = weight
        self.embedding.weight.data /= scale
        self.scale = scale

    def weight(self) -> torch.Tensor:
        return self.embedding.weight * self.scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""Forward pass for embedding with scale.

        Args:
            x (torch.Tensor): input tensor of shape `(num_embeddings)`

        Returns:
            (Tensor):
                Embedding output of shape `(num_embeddings, embedding_dim)`
        """
        out = self.embedding(x) * self.scale
        return out
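

# Illustrative sketch (not part of the original file): the embedding stores its
# weight divided by `scale` and multiplies it back in `forward`, so a fixed
# optimizer step moves the effective embedding faster, i.e. its learning rate
# is boosted.
# >>> emb = _ScaledEmbedding(num_embeddings=512, embedding_dim=48, smooth=True)
# >>> emb(torch.arange(512)).shape
# torch.Size([512, 48])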


class _HEncLayer(torch.nn.Module):
    r"""Encoder layer. This is used by both the time and the frequency branches.

    Args:
        chin (int): number of input channels.
        chout (int): number of output channels.
        kernel_size (int, optional): Kernel size for encoder (Default: 8)
        stride (int, optional): Stride for encoder layer (Default: 4)
        norm_groups (int, optional): number of groups for group norm. (Default: 4)
        empty (bool, optional): used to make a layer with just the first conv. This is used
            before merging the time and freq. branches. (Default: ``False``)
        freq (bool, optional): boolean for whether conv layer is for frequency domain (Default: ``True``)
        norm_type (string, optional): Norm type, either ``group_norm`` or ``none`` (Default: ``group_norm``)
        context (int, optional): context size for the 1x1 conv. (Default: 0)
        dconv_kw (Dict[str, Any] or None, optional): dictionary of kwargs for the DConv class. (Default: ``None``)
        pad (bool, optional): true to pad the input. Padding is done so that the output size is
            always the input size / stride. (Default: ``True``)
    """

    def __init__(
        self,
        chin: int,
        chout: int,
        kernel_size: int = 8,
        stride: int = 4,
        norm_groups: int = 4,
        empty: bool = False,
        freq: bool = True,
        norm_type: str = "group_norm",
        context: int = 0,
        dconv_kw: Optional[Dict[str, Any]] = None,
        pad: bool = True,
    ):
        super().__init__()
        if dconv_kw is None:
            dconv_kw = {}
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm_type == "group_norm":
            norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa
        pad_val = kernel_size // 4 if pad else 0
        klass = nn.Conv1d
        self.freq = freq
        self.kernel_size = kernel_size
        self.stride = stride
        self.empty = empty
        self.pad = pad_val
        if freq:
            kernel_size = [kernel_size, 1]
            stride = [stride, 1]
            pad_val = [pad_val, 0]
            klass = nn.Conv2d
        self.conv = klass(chin, chout, kernel_size, stride, pad_val)
        self.norm1 = norm_fn(chout)
        if self.empty:
            self.rewrite = nn.Identity()
            self.norm2 = nn.Identity()
            self.dconv = nn.Identity()
        else:
            self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context)
            self.norm2 = norm_fn(2 * chout)
            self.dconv = _DConv(chout, **dconv_kw)

    def forward(self, x: torch.Tensor, inject: Optional[torch.Tensor] = None) -> torch.Tensor:
        r"""Forward pass for encoding layer.

        Size depends on whether frequency or time.

        Args:
            x (torch.Tensor): tensor input of shape `(B, C, F, T)` for frequency and shape
                `(B, C, T)` for time
            inject (torch.Tensor, optional): on last layer, combine frequency and time branches through inject param,
                same shape as x (default: ``None``)

        Returns:
            Tensor
                output tensor after encoder layer of shape `(B, C, F / stride, T)` for frequency
                and shape `(B, C, ceil(T / stride))` for time
        """
        if not self.freq and x.dim() == 4:
            B, C, Fr, T = x.shape
            x = x.view(B, -1, T)
        if not self.freq:
            le = x.shape[-1]
            if le % self.stride != 0:
                x = F.pad(x, (0, self.stride - (le % self.stride)))
        y = self.conv(x)
        if self.empty:
            return y
        if inject is not None:
            if inject.shape[-1] != y.shape[-1]:
                raise ValueError("Injection shapes do not align")
            if inject.dim() == 3 and y.dim() == 4:
                inject = inject[:, :, None]
            y = y + inject
        y = F.gelu(self.norm1(y))
        if self.freq:
            B, C, Fr, T = y.shape
            y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
            y = self.dconv(y)
            y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
        else:
            y = self.dconv(y)
        z = self.norm2(self.rewrite(y))
        z = F.glu(z, dim=1)
        return z
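

# Illustrative shape check (assumption, not part of the original file): in
# frequency mode the layer downsamples the frequency axis by `stride` while
# keeping the time axis; in time mode it downsamples time instead.
# >>> enc = _HEncLayer(4, 8, freq=True)
# >>> enc(torch.randn(1, 4, 512, 32)).shape
# torch.Size([1, 8, 128, 32])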


class _HDecLayer(torch.nn.Module):
    r"""Decoder layer. This is used by both the time and the frequency branches.

    Args:
        chin (int): number of input channels.
        chout (int): number of output channels.
        last (bool, optional): whether current layer is final layer (Default: ``False``)
        kernel_size (int, optional): Kernel size for decoder (Default: 8)
        stride (int, optional): Stride for decoder layer (Default: 4)
        norm_groups (int, optional): number of groups for group norm. (Default: 1)
        empty (bool, optional): used to make a layer with just the first conv. This is used
            before merging the time and freq. branches. (Default: ``False``)
        freq (bool, optional): boolean for whether conv layer is for frequency (Default: ``True``)
        norm_type (str, optional): Norm type, either ``group_norm`` or ``none`` (Default: ``group_norm``)
        context (int, optional): context size for the 1x1 conv. (Default: 1)
        dconv_kw (Dict[str, Any] or None, optional): dictionary of kwargs for the DConv class. (Default: ``None``)
        pad (bool, optional): true to pad the input. Padding is done so that the output size is
            always the input size / stride. (Default: ``True``)
    """

    def __init__(
        self,
        chin: int,
        chout: int,
        last: bool = False,
        kernel_size: int = 8,
        stride: int = 4,
        norm_groups: int = 1,
        empty: bool = False,
        freq: bool = True,
        norm_type: str = "group_norm",
        context: int = 1,
        dconv_kw: Optional[Dict[str, Any]] = None,
        pad: bool = True,
    ):
        super().__init__()
        if dconv_kw is None:
            dconv_kw = {}
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm_type == "group_norm":
            norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa
        if pad:
            if (kernel_size - stride) % 2 != 0:
                raise ValueError("Kernel size and stride do not align")
            pad = (kernel_size - stride) // 2
        else:
            pad = 0
        self.pad = pad
        self.last = last
        self.freq = freq
        self.chin = chin
        self.empty = empty
        self.stride = stride
        self.kernel_size = kernel_size
        klass = nn.Conv1d
        klass_tr = nn.ConvTranspose1d
        if freq:
            kernel_size = [kernel_size, 1]
            stride = [stride, 1]
            klass = nn.Conv2d
            klass_tr = nn.ConvTranspose2d
        self.conv_tr = klass_tr(chin, chout, kernel_size, stride)
        self.norm2 = norm_fn(chout)
        if self.empty:
            self.rewrite = nn.Identity()
            self.norm1 = nn.Identity()
        else:
            self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context)
            self.norm1 = norm_fn(2 * chin)

    def forward(self, x: torch.Tensor, skip: Optional[torch.Tensor], length):
        r"""Forward pass for decoding layer.

        Size depends on whether frequency or time.

        Args:
            x (torch.Tensor): tensor input of shape `(B, C, F, T)` for frequency and shape
                `(B, C, T)` for time
            skip (torch.Tensor, optional): skip connection from the matching encoder layer. On the
                first (empty) layer, where the frequency and time branches separate, this must be
                ``None`` (default: ``None``)
            length (int): Size of tensor for output

        Returns:
            (Tensor, Tensor):
                Tensor
                    output tensor after decoder layer of shape `(B, C, F * stride, T)` for frequency domain except last
                    frequency layer shape is `(B, C, kernel_size, T)`. Shape is `(B, C, stride * T)`
                    for time domain.
                Tensor
                    contains the output just before final transposed convolution, which is used when the
                    freq. and time branch separate. Otherwise, does not matter. Shape is
                    `(B, C, F, T)` for frequency and `(B, C, T)` for time.
        """
        if self.freq and x.dim() == 3:
            B, C, T = x.shape
            x = x.view(B, self.chin, -1, T)
        if not self.empty:
            x = x + skip
            y = F.glu(self.norm1(self.rewrite(x)), dim=1)
        else:
            y = x
            if skip is not None:
                raise ValueError("Skip must be none when empty is true.")
        z = self.norm2(self.conv_tr(y))
        if self.freq:
            if self.pad:
                z = z[..., self.pad : -self.pad, :]
        else:
            z = z[..., self.pad : self.pad + length]
            if z.shape[-1] != length:
                raise ValueError("Last index of z must be equal to length")
        if not self.last:
            z = F.gelu(z)
        return z, y
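

# Illustrative round trip (assumption, not part of the original file): a time
# branch encoder layer followed by the matching decoder layer restores the
# original number of frames, length / stride -> length.
# >>> enc = _HEncLayer(4, 8, freq=False)
# >>> dec = _HDecLayer(8, 4, freq=False)
# >>> y = enc(torch.randn(1, 4, 100))  # (1, 8, 25)
# >>> z, _ = dec(y, torch.zeros_like(y), 100)
# >>> z.shape
# torch.Size([1, 4, 100])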


class HDemucs(torch.nn.Module):
    r"""Hybrid Demucs model from
    *Hybrid Spectrogram and Waveform Source Separation* :cite:`defossez2021hybrid`.

    See Also:
        * :class:`torchaudio.pipelines.SourceSeparationBundle`: Source separation pipeline with pre-trained models.

    Args:
        sources (List[str]): list of source names. List can contain the following source
            options: [``"bass"``, ``"drums"``, ``"other"``, ``"mixture"``, ``"vocals"``].
        audio_channels (int, optional): input/output audio channels. (Default: 2)
        channels (int, optional): initial number of hidden channels. (Default: 48)
        growth (int, optional): increase the number of hidden channels by this factor at each layer. (Default: 2)
        nfft (int, optional): number of fft bins. Note that changing this requires careful computation of
            various shape parameters and will not work out of the box for hybrid models. (Default: 4096)
        depth (int, optional): number of layers in encoder and decoder (Default: 6)
        freq_emb (float, optional): add frequency embedding after the first frequency layer if > 0,
            the actual value controls the weight of the embedding. (Default: 0.2)
        emb_scale (int, optional): equivalent to scaling the embedding learning rate (Default: 10)
        emb_smooth (bool, optional): initialize the embedding with a smooth one (with respect to frequencies).
            (Default: ``True``)
        kernel_size (int, optional): kernel_size for encoder and decoder layers. (Default: 8)
        time_stride (int, optional): stride for the final time layer, after the merge. (Default: 2)
        stride (int, optional): stride for encoder and decoder layers. (Default: 4)
        context (int, optional): context for 1x1 conv in the decoder. (Default: 1)
        context_enc (int, optional): context for 1x1 conv in the encoder. (Default: 0)
        norm_starts (int, optional): layer at which group norm starts being used.
            decoder layers are numbered in reverse order. (Default: 4)
        norm_groups (int, optional): number of groups for group norm. (Default: 4)
        dconv_depth (int, optional): depth of residual DConv branch. (Default: 2)
        dconv_comp (int, optional): compression of DConv branch. (Default: 4)
        dconv_attn (int, optional): adds attention layers in DConv branch starting at this layer. (Default: 4)
        dconv_lstm (int, optional): adds a LSTM layer in DConv branch starting at this layer. (Default: 4)
        dconv_init (float, optional): initial scale for the DConv branch LayerScale. (Default: 1e-4)
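
    Example
        >>> # Hedged usage sketch (not part of the original docstring):
        >>> model = HDemucs(sources=["drums", "bass", "other", "vocals"])
        >>> mix = torch.rand(1, 2, 44100)  # one second of stereo audio at 44.1 kHz
        >>> separated = model(mix)  # (batch, len(sources), audio_channels, num_frames)
        >>> separated.shape
        torch.Size([1, 4, 2, 44100])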
""" | |
def __init__( | |
self, | |
sources: List[str], | |
audio_channels: int = 2, | |
channels: int = 48, | |
growth: int = 2, | |
nfft: int = 4096, | |
depth: int = 6, | |
freq_emb: float = 0.2, | |
emb_scale: int = 10, | |
emb_smooth: bool = True, | |
kernel_size: int = 8, | |
time_stride: int = 2, | |
stride: int = 4, | |
context: int = 1, | |
context_enc: int = 0, | |
norm_starts: int = 4, | |
norm_groups: int = 4, | |
dconv_depth: int = 2, | |
dconv_comp: int = 4, | |
dconv_attn: int = 4, | |
dconv_lstm: int = 4, | |
dconv_init: float = 1e-4, | |
): | |
super().__init__() | |
self.depth = depth | |
self.nfft = nfft | |
self.audio_channels = audio_channels | |
self.sources = sources | |
self.kernel_size = kernel_size | |
self.context = context | |
self.stride = stride | |
self.channels = channels | |
self.hop_length = self.nfft // 4 | |
self.freq_emb = None | |
self.freq_encoder = nn.ModuleList() | |
self.freq_decoder = nn.ModuleList() | |
self.time_encoder = nn.ModuleList() | |
self.time_decoder = nn.ModuleList() | |
chin = audio_channels | |
chin_z = chin * 2 # number of channels for the freq branch | |
chout = channels | |
chout_z = channels | |
freqs = self.nfft // 2 | |
for index in range(self.depth): | |
lstm = index >= dconv_lstm | |
attn = index >= dconv_attn | |
norm_type = "group_norm" if index >= norm_starts else "none" | |
freq = freqs > 1 | |
stri = stride | |
ker = kernel_size | |
if not freq: | |
if freqs != 1: | |
raise ValueError("When freq is false, freqs must be 1.") | |
ker = time_stride * 2 | |
stri = time_stride | |
pad = True | |
last_freq = False | |
if freq and freqs <= kernel_size: | |
ker = freqs | |
pad = False | |
last_freq = True | |
kw = { | |
"kernel_size": ker, | |
"stride": stri, | |
"freq": freq, | |
"pad": pad, | |
"norm_type": norm_type, | |
"norm_groups": norm_groups, | |
"dconv_kw": { | |
"lstm": lstm, | |
"attn": attn, | |
"depth": dconv_depth, | |
"compress": dconv_comp, | |
"init": dconv_init, | |
}, | |
} | |
kwt = dict(kw) | |
kwt["freq"] = 0 | |
kwt["kernel_size"] = kernel_size | |
kwt["stride"] = stride | |
kwt["pad"] = True | |
kw_dec = dict(kw) | |
if last_freq: | |
chout_z = max(chout, chout_z) | |
chout = chout_z | |
enc = _HEncLayer(chin_z, chout_z, context=context_enc, **kw) | |
if freq: | |
if last_freq is True and nfft == 2048: | |
kwt["stride"] = 2 | |
kwt["kernel_size"] = 4 | |
tenc = _HEncLayer(chin, chout, context=context_enc, empty=last_freq, **kwt) | |
self.time_encoder.append(tenc) | |
self.freq_encoder.append(enc) | |
if index == 0: | |
chin = self.audio_channels * len(self.sources) | |
chin_z = chin * 2 | |
dec = _HDecLayer(chout_z, chin_z, last=index == 0, context=context, **kw_dec) | |
if freq: | |
tdec = _HDecLayer(chout, chin, empty=last_freq, last=index == 0, context=context, **kwt) | |
self.time_decoder.insert(0, tdec) | |
self.freq_decoder.insert(0, dec) | |
chin = chout | |
chin_z = chout_z | |
chout = int(growth * chout) | |
chout_z = int(growth * chout_z) | |
if freq: | |
if freqs <= kernel_size: | |
freqs = 1 | |
else: | |
freqs //= stride | |
if index == 0 and freq_emb: | |
self.freq_emb = _ScaledEmbedding(freqs, chin_z, smooth=emb_smooth, scale=emb_scale) | |
self.freq_emb_scale = freq_emb | |
_rescale_module(self) | |

    def _spec(self, x):
        hl = self.hop_length
        nfft = self.nfft
        x0 = x  # noqa
        # We re-pad the signal in order to keep the property
        # that the size of the output is exactly the size of the input
        # divided by the stride (here hop_length), when divisible.
        # This is achieved by padding by 1/4th of the kernel size (here nfft),
        # which is not supported by torch.stft.
        # Having all convolution operations follow this convention allows us to
        # easily align the time and frequency branches later on.
        if hl != nfft // 4:
            raise ValueError("Hop length must be nfft // 4")
        le = int(math.ceil(x.shape[-1] / hl))
        pad = hl // 2 * 3
        x = self._pad1d(x, pad, pad + le * hl - x.shape[-1], mode="reflect")
        z = _spectro(x, nfft, hl)[..., :-1, :]
        if z.shape[-1] != le + 4:
            raise ValueError("Spectrogram's last dimension must be 4 + input size divided by stride")
        z = z[..., 2 : 2 + le]
        return z
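
    # Illustrative check (assumption, not part of the original file): with this
    # padding convention the spectrogram has exactly ceil(T / hop_length) frames
    # and nfft // 2 frequency bins.
    # >>> m = HDemucs(sources=["vocals"], nfft=1024, depth=5)
    # >>> m._spec(torch.randn(1, 2, 8000)).shape  # ceil(8000 / 256) == 32
    # torch.Size([1, 2, 512, 32])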

    def _ispec(self, z, length=None):
        hl = self.hop_length
        z = F.pad(z, [0, 0, 0, 1])
        z = F.pad(z, [2, 2])
        pad = hl // 2 * 3
        le = hl * int(math.ceil(length / hl)) + 2 * pad
        x = _ispectro(z, hl, length=le)
        x = x[..., pad : pad + length]
        return x

    def _pad1d(self, x: torch.Tensor, padding_left: int, padding_right: int, mode: str = "zero", value: float = 0.0):
        """Wrapper around F.pad that allows reflect padding when the number of frames is smaller
        than the requested padding: extra zero padding is added first so that the reflect padding
        does not fail."""
        length = x.shape[-1]
        if mode == "reflect":
            max_pad = max(padding_left, padding_right)
            if length <= max_pad:
                x = F.pad(x, (0, max_pad - length + 1))
        return F.pad(x, (padding_left, padding_right), mode, value)
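
    # Illustrative sketch (assumption, not part of the original file): a 3-frame
    # signal is first zero-padded to 6 frames so the 5-frame reflect pad is
    # valid, giving 6 + 5 + 5 = 16 output frames.
    # >>> m = HDemucs(sources=["vocals"], nfft=1024, depth=5)
    # >>> m._pad1d(torch.randn(1, 1, 3), 5, 5, mode="reflect").shape
    # torch.Size([1, 1, 16])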

    def _magnitude(self, z):
        # move the complex dimension to the channel one.
        B, C, Fr, T = z.shape
        m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
        m = m.reshape(B, C * 2, Fr, T)
        return m

    def _mask(self, m):
        # `m` is a full spectrogram with real/imaginary parts stacked in the
        # channel dimension; convert it back to a complex spectrogram.
        B, S, C, Fr, T = m.shape
        out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
        out = torch.view_as_complex(out.contiguous())
        return out
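
    # Illustrative check (assumption, not part of the original file): `_mask` is
    # the inverse of `_magnitude`, up to the extra sources dimension.
    # >>> m = HDemucs(sources=["vocals"])
    # >>> z = torch.view_as_complex(torch.randn(1, 2, 16, 8, 2))
    # >>> torch.equal(m._mask(m._magnitude(z)[:, None])[:, 0], z)
    # True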

    def forward(self, input: torch.Tensor):
        r"""HDemucs forward call

        Args:
            input (torch.Tensor): input mixed tensor of shape `(batch_size, channel, num_frames)`

        Returns:
            Tensor
                output tensor split into sources of shape `(batch_size, num_sources, channel, num_frames)`
        """
        if input.ndim != 3:
            raise ValueError(f"Expected 3D tensor with dimensions (batch, channel, frames). Found: {input.shape}")
        if input.shape[1] != self.audio_channels:
            raise ValueError(
                f"The channel dimension of input Tensor must match `audio_channels` of HDemucs model. "
                f"Found: {input.shape[1]}."
            )
        x = input
        length = x.shape[-1]

        z = self._spec(input)
        mag = self._magnitude(z)
        x = mag

        B, C, Fq, T = x.shape

        # unlike previous Demucs, we always normalize because it is easier.
        mean = x.mean(dim=(1, 2, 3), keepdim=True)
        std = x.std(dim=(1, 2, 3), keepdim=True)
        x = (x - mean) / (1e-5 + std)
        # x will be the freq. branch input.

        # Prepare the time branch input.
        xt = input
        meant = xt.mean(dim=(1, 2), keepdim=True)
        stdt = xt.std(dim=(1, 2), keepdim=True)
        xt = (xt - meant) / (1e-5 + stdt)

        saved = []  # skip connections, freq.
        saved_t = []  # skip connections, time.
        lengths: List[int] = []  # saved lengths to properly remove padding, freq branch.
        lengths_t: List[int] = []  # saved lengths for time branch.
        for idx, encode in enumerate(self.freq_encoder):
            lengths.append(x.shape[-1])
            inject = None
            if idx < len(self.time_encoder):
                # we have not yet merged branches.
                lengths_t.append(xt.shape[-1])
                tenc = self.time_encoder[idx]
                xt = tenc(xt)
                if not tenc.empty:
                    # save for skip connection
                    saved_t.append(xt)
                else:
                    # tenc contains just the first conv., so that now time and freq.
                    # branches have the same shape and can be merged.
                    inject = xt
            x = encode(x, inject)
            if idx == 0 and self.freq_emb is not None:
                # add frequency embedding to allow for non equivariant convolutions
                # over the frequency axis.
                frs = torch.arange(x.shape[-2], device=x.device)
                emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
                x = x + self.freq_emb_scale * emb
            saved.append(x)

        x = torch.zeros_like(x)
        xt = torch.zeros_like(x)
        # initialize everything to zero (signal will go through u-net skips).

        for idx, decode in enumerate(self.freq_decoder):
            skip = saved.pop(-1)
            x, pre = decode(x, skip, lengths.pop(-1))
            # `pre` contains the output just before final transposed convolution,
            # which is used when the freq. and time branch separate.
            offset = self.depth - len(self.time_decoder)
            if idx >= offset:
                tdec = self.time_decoder[idx - offset]
                length_t = lengths_t.pop(-1)
                if tdec.empty:
                    if pre.shape[2] != 1:
                        raise ValueError(f"If tdec empty is True, pre shape does not match {pre.shape}")
                    pre = pre[:, :, 0]
                    xt, _ = tdec(pre, None, length_t)
                else:
                    skip = saved_t.pop(-1)
                    xt, _ = tdec(xt, skip, length_t)

        if len(saved) != 0:
            raise AssertionError("saved is not empty")
        if len(lengths_t) != 0:
            raise AssertionError("lengths_t is not empty")
        if len(saved_t) != 0:
            raise AssertionError("saved_t is not empty")

        S = len(self.sources)
        x = x.view(B, S, -1, Fq, T)
        x = x * std[:, None] + mean[:, None]

        zout = self._mask(x)
        x = self._ispec(zout, length)

        xt = xt.view(B, S, -1, length)
        xt = xt * stdt[:, None] + meant[:, None]
        x = xt + x
        return x


class _DConv(torch.nn.Module):
    r"""
    New residual branches in each encoder layer.
    This alternates dilated convolutions, potentially with LSTMs and attention.
    Also, before entering each residual branch, the dimension is projected on a smaller subspace,
    e.g. of dim `channels // compress`.

    Args:
        channels (int): input/output channels for residual branch.
        compress (float, optional): amount of channel compression inside the branch. (default: 4)
        depth (int, optional): number of layers in the residual branch. Each layer has its own
            projection, and potentially LSTM and attention. (default: 2)
        init (float, optional): initial scale for the LayerScale rescaling. (default: 1e-4)
        norm_type (str, optional): Norm type, either ``group_norm`` or ``none`` (Default: ``group_norm``)
        attn (bool, optional): use LocalAttention. (Default: ``False``)
        heads (int, optional): number of heads for the LocalAttention. (default: 4)
        ndecay (int, optional): number of decay controls in the LocalAttention. (default: 4)
        lstm (bool, optional): use LSTM. (Default: ``False``)
        kernel_size (int, optional): kernel size for the (dilated) convolutions. (default: 3)
    """

    def __init__(
        self,
        channels: int,
        compress: float = 4,
        depth: int = 2,
        init: float = 1e-4,
        norm_type: str = "group_norm",
        attn: bool = False,
        heads: int = 4,
        ndecay: int = 4,
        lstm: bool = False,
        kernel_size: int = 3,
    ):
        super().__init__()
        if kernel_size % 2 == 0:
            raise ValueError("Kernel size should not be divisible by 2")
        self.channels = channels
        self.compress = compress
        self.depth = abs(depth)
        dilate = depth > 0

        norm_fn: tp.Callable[[int], nn.Module]
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm_type == "group_norm":
            norm_fn = lambda d: nn.GroupNorm(1, d)  # noqa

        hidden = int(channels / compress)
        act = nn.GELU

        self.layers = nn.ModuleList([])
        for d in range(self.depth):
            dilation = pow(2, d) if dilate else 1
            padding = dilation * (kernel_size // 2)
            mods = [
                nn.Conv1d(channels, hidden, kernel_size, dilation=dilation, padding=padding),
                norm_fn(hidden),
                act(),
                nn.Conv1d(hidden, 2 * channels, 1),
                norm_fn(2 * channels),
                nn.GLU(1),
                _LayerScale(channels, init),
            ]
            if attn:
                mods.insert(3, _LocalState(hidden, heads=heads, ndecay=ndecay))
            if lstm:
                mods.insert(3, _BLSTM(hidden, layers=2, skip=True))
            layer = nn.Sequential(*mods)
            self.layers.append(layer)

    def forward(self, x):
        r"""DConv forward call

        Args:
            x (torch.Tensor): input tensor for convolution

        Returns:
            Tensor
                Output after being run through layers.
        """
        for layer in self.layers:
            x = x + layer(x)
        return x
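

# Illustrative shape check (assumption, not part of the original file): each
# dilated-conv layer is residual, so the branch preserves the input shape; the
# dilation doubles at each layer (1, 2, ...) while padding keeps time aligned.
# >>> branch = _DConv(channels=32)
# >>> branch(torch.randn(1, 32, 100)).shape
# torch.Size([1, 32, 100])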


class _BLSTM(torch.nn.Module):
    r"""
    BiLSTM with same hidden units as input dim.
    If `max_steps` is not None, the input will be split into overlapping
    chunks and the LSTM applied separately on each chunk.

    Args:
        dim (int): dimensions at LSTM layer.
        layers (int, optional): number of LSTM layers. (default: 1)
        skip (bool, optional): whether to add a residual connection from input to output. (default: ``False``)
    """

    def __init__(self, dim, layers: int = 1, skip: bool = False):
        super().__init__()
        self.max_steps = 200
        self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
        self.linear = nn.Linear(2 * dim, dim)
        self.skip = skip

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""BLSTM forward call

        Args:
            x (torch.Tensor): input tensor for BLSTM shape is `(batch_size, dim, time_steps)`

        Returns:
            Tensor
                Output after being run through bidirectional LSTM. Shape is `(batch_size, dim, time_steps)`
        """
        B, C, T = x.shape
        y = x
        framed = False
        width = 0
        stride = 0
        nframes = 0
        if self.max_steps is not None and T > self.max_steps:
            width = self.max_steps
            stride = width // 2
            frames = _unfold(x, width, stride)
            nframes = frames.shape[2]
            framed = True
            x = frames.permute(0, 2, 1, 3).reshape(-1, C, width)

        x = x.permute(2, 0, 1)
        x = self.lstm(x)[0]
        x = self.linear(x)
        x = x.permute(1, 2, 0)

        if framed:
            out = []
            frames = x.reshape(B, -1, C, width)
            limit = stride // 2
            for k in range(nframes):
                if k == 0:
                    out.append(frames[:, k, :, :-limit])
                elif k == nframes - 1:
                    out.append(frames[:, k, :, limit:])
                else:
                    out.append(frames[:, k, :, limit:-limit])
            out = torch.cat(out, -1)
            out = out[..., :T]
            x = out
        if self.skip:
            x = x + y
        return x
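

# Illustrative shape check (assumption, not part of the original file): inputs
# longer than `max_steps` (200) are processed in overlapping 200-frame chunks,
# then stitched back together, so the output length matches the input.
# >>> blstm = _BLSTM(16, layers=2, skip=True)
# >>> blstm(torch.randn(2, 16, 500)).shape
# torch.Size([2, 16, 500])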


class _LocalState(nn.Module):
    """Local state allows to have attention based only on data (no positional embedding),
    but while setting a constraint on the time window (e.g. decaying penalty term).
    Also a failed experiment at providing some frequency-based attention.
    """

    def __init__(self, channels: int, heads: int = 4, ndecay: int = 4):
        r"""
        Args:
            channels (int): Size of Conv1d layers.
            heads (int, optional): (default: 4)
            ndecay (int, optional): (default: 4)
        """
        super().__init__()
        if channels % heads != 0:
            raise ValueError("Channels must be divisible by heads.")
        self.heads = heads
        self.ndecay = ndecay
        self.content = nn.Conv1d(channels, channels, 1)
        self.query = nn.Conv1d(channels, channels, 1)
        self.key = nn.Conv1d(channels, channels, 1)
        self.query_decay = nn.Conv1d(channels, heads * ndecay, 1)
        if ndecay:
            # Initialize decay close to zero (there is a sigmoid), for maximum initial window.
            self.query_decay.weight.data *= 0.01
            if self.query_decay.bias is None:
                raise ValueError("bias must not be None.")
            self.query_decay.bias.data[:] = -2
        # `channels + heads * 0` is apparently a leftover from the failed
        # frequency-based attention experiment mentioned in the class docstring.
        self.proj = nn.Conv1d(channels + heads * 0, channels, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""LocalState forward call

        Args:
            x (torch.Tensor): input tensor for LocalState

        Returns:
            Tensor
                Output after being run through LocalState layer.
        """
        B, C, T = x.shape
        heads = self.heads
        indexes = torch.arange(T, device=x.device, dtype=x.dtype)
        # left index are keys, right index are queries
        delta = indexes[:, None] - indexes[None, :]
        queries = self.query(x).view(B, heads, -1, T)
        keys = self.key(x).view(B, heads, -1, T)
        # t are keys, s are queries
        dots = torch.einsum("bhct,bhcs->bhts", keys, queries)
        dots /= math.sqrt(keys.shape[2])
        if self.ndecay:
            decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype)
            decay_q = self.query_decay(x).view(B, heads, -1, T)
            decay_q = torch.sigmoid(decay_q) / 2
            decay_kernel = -decays.view(-1, 1, 1) * delta.abs() / math.sqrt(self.ndecay)
            dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q)

        # Kill self reference.
        dots.masked_fill_(torch.eye(T, device=dots.device, dtype=torch.bool), -100)
        weights = torch.softmax(dots, dim=2)

        content = self.content(x).view(B, heads, -1, T)
        result = torch.einsum("bhts,bhct->bhcs", weights, content)
        result = result.reshape(B, -1, T)
        return x + self.proj(result)
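

# Illustrative shape check (assumption, not part of the original file): the
# attention is residual and content-only, with a learnt per-head decay that
# penalizes distant time steps, so the output shape matches the input.
# >>> attn = _LocalState(channels=32, heads=4, ndecay=4)
# >>> attn(torch.randn(1, 32, 50)).shape
# torch.Size([1, 32, 50])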


class _LayerScale(nn.Module):
    """Layer scale from [Touvron et al. 2021] (https://arxiv.org/pdf/2103.17239.pdf).
    This rescales residual outputs diagonally, close to 0 initially, then learnt.
    """

    def __init__(self, channels: int, init: float = 0):
        r"""
        Args:
            channels (int): Size of rescaling
            init (float, optional): Scale to default to (default: 0)
        """
        super().__init__()
        self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
        self.scale.data[:] = init

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""LayerScale forward call

        Args:
            x (torch.Tensor): input tensor for LayerScale

        Returns:
            Tensor
                Output after rescaling tensor.
        """
        return self.scale[:, None] * x
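

# Illustrative sketch (assumption, not part of the original file): every channel
# is multiplied by its own learnt scalar, initialized to `init`.
# >>> ls = _LayerScale(channels=8, init=1e-4)
# >>> x = torch.randn(1, 8, 10)
# >>> torch.allclose(ls(x), 1e-4 * x)
# True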


def _unfold(a: torch.Tensor, kernel_size: int, stride: int) -> torch.Tensor:
    """Given input of size `[*OT, T]`, output Tensor of size `[*OT, F, K]`
    with `K` the kernel size, by extracting frames with the given stride.
    This will pad the input so that `F = ceil(T / stride)`.

    see https://github.com/pytorch/pytorch/issues/60466
    """
    shape = list(a.shape[:-1])
    length = int(a.shape[-1])
    n_frames = math.ceil(length / stride)
    tgt_length = (n_frames - 1) * stride + kernel_size
    a = F.pad(input=a, pad=[0, tgt_length - length])
    strides = [a.stride(dim) for dim in range(a.dim())]
    if strides[-1] != 1:
        raise ValueError("Data should be contiguous.")
    strides = strides[:-1] + [stride, 1]
    shape.append(n_frames)
    shape.append(kernel_size)
    return a.as_strided(shape, strides)
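

# Illustrative sketch (assumption, not part of the original file): frames
# overlap when stride < kernel_size, and the tail is zero-padded.
# >>> _unfold(torch.arange(10.0).view(1, 10), kernel_size=4, stride=2).shape
# torch.Size([1, 5, 4])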


def _rescale_module(module):
    r"""
    Rescales the initial weight scale of all convolution layers within the module.
    """
    for sub in module.modules():
        if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)):
            std = sub.weight.std().detach()
            scale = (std / 0.1) ** 0.5
            sub.weight.data /= scale
            if sub.bias is not None:
                sub.bias.data /= scale


def _spectro(x: torch.Tensor, n_fft: int = 512, hop_length: int = 0, pad: int = 0) -> torch.Tensor:
    other = list(x.shape[:-1])
    length = int(x.shape[-1])
    x = x.reshape(-1, length)
    z = torch.stft(
        x,
        n_fft * (1 + pad),
        hop_length,
        window=torch.hann_window(n_fft).to(x),
        win_length=n_fft,
        normalized=True,
        center=True,
        return_complex=True,
        pad_mode="reflect",
    )
    _, freqs, frame = z.shape
    other.extend([freqs, frame])
    return z.view(other)


def _ispectro(z: torch.Tensor, hop_length: int = 0, length: int = 0, pad: int = 0) -> torch.Tensor:
    other = list(z.shape[:-2])
    freqs = int(z.shape[-2])
    frames = int(z.shape[-1])
    n_fft = 2 * freqs - 2
    z = z.view(-1, freqs, frames)
    win_length = n_fft // (1 + pad)
    x = torch.istft(
        z,
        n_fft,
        hop_length,
        window=torch.hann_window(win_length).to(z.real),
        win_length=win_length,
        normalized=True,
        length=length,
        center=True,
    )
    _, length = x.shape
    other.append(length)
    return x.view(other)
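

# Illustrative round trip (assumption, not part of the original file):
# `_ispectro` inverts `_spectro` when given the original length.
# >>> wav = torch.randn(2, 4096)
# >>> spec = _spectro(wav, n_fft=512, hop_length=128)  # (2, 257, 33)
# >>> _ispectro(spec, hop_length=128, length=4096).shape
# torch.Size([2, 4096])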


def hdemucs_low(sources: List[str]) -> HDemucs:
    r"""Builds low nfft (1024) version of :class:`HDemucs`, suitable for sample rates around 8 kHz.

    Args:
        sources (List[str]): See :py:func:`HDemucs`.

    Returns:
        HDemucs:
            HDemucs model.
    """
    return HDemucs(sources=sources, nfft=1024, depth=5)


def hdemucs_medium(sources: List[str]) -> HDemucs:
    r"""Builds medium nfft (2048) version of :class:`HDemucs`, suitable for sample rates of 16-32 kHz.

    .. note::

        Medium HDemucs has not been tested against the original Hybrid Demucs as this nfft and depth configuration is
        not compatible with the original implementation in https://github.com/facebookresearch/demucs

    Args:
        sources (List[str]): See :py:func:`HDemucs`.

    Returns:
        HDemucs:
            HDemucs model.
    """
    return HDemucs(sources=sources, nfft=2048, depth=6)


def hdemucs_high(sources: List[str]) -> HDemucs:
    r"""Builds high nfft (4096) version of :class:`HDemucs`, suitable for sample rates of 44.1-48 kHz.

    Args:
        sources (List[str]): See :py:func:`HDemucs`.

    Returns:
        HDemucs:
            HDemucs model.
    """
    return HDemucs(sources=sources, nfft=4096, depth=6)
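

# Illustrative usage sketch (assumption, not part of the original file): pick a
# factory by sample rate, then separate a stereo mixture.
# >>> model = hdemucs_high(sources=["drums", "bass", "other", "vocals"])
# >>> model(torch.rand(1, 2, 44100)).shape
# torch.Size([1, 4, 2, 44100])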