import math
import typing as tp
from dataclasses import dataclass
from typing import List, Optional, Union

import hydra
import librosa
import numpy as np
import soundfile as sf
import torch
from audiotools import AudioSignal
from audiotools.ml import BaseModel
from dac.model.base import CodecMixin
from dac.nn.layers import Snake1d, WNConv1d, WNConvTranspose1d
from omegaconf import OmegaConf
from torch import Tensor, nn
from torch.nn import functional as F
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations

@dataclass
class VQResult:
    z: torch.Tensor
    codes: torch.Tensor
    latents: torch.Tensor
    codebook_loss: torch.Tensor
    commitment_loss: torch.Tensor
    semantic_distill_z: torch.Tensor | None = None

def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
        return n
    return n + k - (n % k)

@dataclass
class ModelArgs:
    block_size: int = 2048
    n_layer: int = 8
    n_head: int = 8
    dim: int = 512
    intermediate_size: int = 1536
    n_local_heads: int = -1
    head_dim: int = 64
    rope_base: float = 10000
    norm_eps: float = 1e-5
    dropout_rate: float = 0.1
    attn_dropout_rate: float = 0.1
    channels_first: bool = True  # to be compatible with conv1d input/output
    pos_embed_type: str = "rope"  # can be "rope" or "conformer"
    max_relative_position: int = 128  # for conformer-style relative position embedding

    def __post_init__(self):
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head
        if self.intermediate_size is None:
            hidden_dim = 4 * self.dim
            n_hidden = int(2 * hidden_dim / 3)
            self.intermediate_size = find_multiple(n_hidden, 256)
        assert self.pos_embed_type in [
            "rope",
            "conformer",
        ], "pos_embed_type must be either 'rope' or 'conformer'"

class KVCache(nn.Module):
    def __init__(
        self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16
    ):
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: [S], k_val: [B, H, S, D]
        assert input_pos.shape[0] == k_val.shape[2]
        k_out = self.k_cache
        v_out = self.v_cache
        k_out[:, :, input_pos] = k_val
        v_out[:, :, input_pos] = v_val
        return (
            k_out[:, :, : input_pos.max() + 1, :],
            v_out[:, :, : input_pos.max() + 1, :],
        )

    def clear_cache(self, prompt_len):
        self.k_cache[:, :, prompt_len:, :].fill_(0)
        self.v_cache[:, :, prompt_len:, :].fill_(0)

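
# Illustrative sketch (not part of the original module): how KVCache is meant
# to be driven during incremental decoding. The batch size, head count, head
# dim and dtype below are arbitrary example values.
def _demo_kv_cache():
    cache = KVCache(
        max_batch_size=1, max_seq_length=16, n_heads=2, head_dim=4, dtype=torch.float32
    )
    for pos in range(3):
        step_k = torch.randn(1, 2, 1, 4)
        step_v = torch.randn(1, 2, 1, 4)
        k, v = cache.update(torch.tensor([pos]), step_k, step_v)
    # After three single-token updates the returned views cover positions 0..2.
    assert k.shape == (1, 2, 3, 4) and v.shape == (1, 2, 3, 4)
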
class Transformer(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.config = config

        self.layers = nn.ModuleList(
            TransformerBlock(config) for _ in range(config.n_layer)
        )
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)

        # Only compute RoPE frequencies if using RoPE
        if config.pos_embed_type == "rope":
            freqs_cis = precompute_freqs_cis(
                self.config.block_size, self.config.head_dim, self.config.rope_base
            )
            self.register_buffer("freqs_cis", freqs_cis)
        else:
            self.register_buffer("freqs_cis", None)

        causal_mask = torch.tril(
            torch.ones(self.config.block_size, self.config.block_size, dtype=torch.bool)
        )
        self.register_buffer("causal_mask", causal_mask)
        self.max_batch_size = -1
        self.max_seq_length = -1
        self.use_kv_cache = False

    def setup_caches(self, max_batch_size, max_seq_length):
        """
        This method will only be called during inference when using KV cache.
        """
        # use the configured head_dim so the cache matches Attention's projections
        head_dim = self.config.head_dim
        max_seq_length = find_multiple(max_seq_length, 8)
        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size
        dtype = self.norm.weight.dtype
        device = self.norm.weight.device

        for b in self.layers:
            b.attention.kv_cache = KVCache(
                max_batch_size,
                max_seq_length,
                self.config.n_local_heads,
                head_dim,
                dtype,
            ).to(device)
        self.use_kv_cache = True

    def forward(
        self,
        x: Tensor,
        input_pos: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
    ) -> Tensor:
        if self.config.pos_embed_type == "rope":
            assert (
                self.freqs_cis is not None
            ), "RoPE frequencies must be initialized for RoPE positional embedding"
            freqs_cis = self.freqs_cis[input_pos]
        else:
            freqs_cis = None

        if mask is None:  # in case of non-causal model
            if not self.training and self.use_kv_cache:
                mask = self.causal_mask[None, None, input_pos]
                mask = mask[..., : input_pos.max() + 1]
            else:
                mask = self.causal_mask[None, None, input_pos]
                mask = mask[..., input_pos]

        for i, layer in enumerate(self.layers):
            x = layer(x, input_pos, freqs_cis, mask)
        x = self.norm(x)
        return x

class TransformerBlock(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.attention = Attention(config)
        self.feed_forward = FeedForward(config)
        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.attention_layer_scale = LayerScale(config.dim, inplace=True)
        self.ffn_layer_scale = LayerScale(config.dim, inplace=True)

    def forward(
        self,
        x: Tensor,
        input_pos: Tensor,
        freqs_cis: Tensor,
        mask: Tensor,
    ) -> Tensor:
        h = x + self.attention_layer_scale(
            self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
        )
        out = h + self.ffn_layer_scale(self.feed_forward(self.ffn_norm(h)))
        return out

class Attention(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        assert config.dim % config.n_head == 0

        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # key, query, value projections for all heads, but in a batch
        self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
        self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False)
        self.kv_cache = None

        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        self.attn_dropout_rate = config.attn_dropout_rate
        self.pos_embed_type = config.pos_embed_type

        # Add relative position embedding for conformer-style
        if self.pos_embed_type == "conformer":
            self.max_relative_position = config.max_relative_position
            num_pos_embeddings = 2 * config.max_relative_position + 1
            self.rel_pos_embeddings = nn.Parameter(
                torch.zeros(num_pos_embeddings, self.head_dim)
            )
            nn.init.normal_(self.rel_pos_embeddings, mean=0.0, std=0.02)

    def _compute_conformer_pos_scores(self, q: Tensor, seqlen: int) -> Tensor:
        # q: [B, H, S, D]
        # Returns: [B, H, S, S]
        positions = torch.arange(seqlen, device=q.device)
        relative_positions = positions.unsqueeze(1) - positions.unsqueeze(0)  # [S, S]
        relative_positions = torch.clamp(
            relative_positions + self.max_relative_position,
            0,
            2 * self.max_relative_position,
        )
        rel_embeddings = self.rel_pos_embeddings[relative_positions]  # [S, S, D]

        # Compute attention scores with relative position embeddings
        q = q.transpose(1, 2)  # [B, S, H, D]
        rel_logits = torch.matmul(q, rel_embeddings.transpose(-2, -1))  # [B, S, H, S]
        rel_logits = rel_logits.transpose(1, 2)  # [B, H, S, S]
        return rel_logits

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        bsz, seqlen, _ = x.shape

        kv_size = self.n_local_heads * self.head_dim
        # queries take n_head * head_dim channels; keys/values take kv_size each
        q, k, v = self.wqkv(x).split(
            [self.n_head * self.head_dim, kv_size, kv_size], dim=-1
        )
        context_seqlen = seqlen

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)

        if self.pos_embed_type == "rope":
            q = apply_rotary_emb(q, freqs_cis)
            k = apply_rotary_emb(k, freqs_cis)

        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

        if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)

        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)

        if self.pos_embed_type == "conformer":
            # Compute attention scores
            scale = 1.0 / math.sqrt(self.head_dim)
            scores = torch.matmul(q, k.transpose(-2, -1)) * scale

            # Add relative position embeddings for conformer-style
            rel_scores = self._compute_conformer_pos_scores(q, seqlen)
            scores = scores + rel_scores

            # Apply attention
            if mask is not None:
                scores = scores.masked_fill(~mask, float("-inf"))
            attn = F.softmax(scores, dim=-1)
            if self.attn_dropout_rate > 0 and self.training:
                attn = F.dropout(attn, p=self.attn_dropout_rate)
            y = torch.matmul(attn, v)
        else:
            y = F.scaled_dot_product_attention(
                q,
                k,
                v,
                dropout_p=self.attn_dropout_rate if self.training else 0.0,
                attn_mask=mask,
            )

        y = (
            y.transpose(1, 2)
            .contiguous()
            .view(bsz, seqlen, self.head_dim * self.n_head)
        )

        y = self.wo(y)
        return y

class FeedForward(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, x: Tensor) -> Tensor:
        return self.w2(self.dropout(F.silu(self.w1(x)) * self.w3(x)))

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

class LayerScale(nn.Module):
    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-2,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        return x.mul_(self.gamma) if self.inplace else x * self.gamma

class WindowLimitedTransformer(Transformer):
    """
    Causal Transformer with window-limited attention.
    """

    def __init__(
        self,
        config: ModelArgs,
        input_dim: int = 512,
        window_size: Optional[int] = None,
        causal: bool = True,
        look_ahead_conv: nn.Module = None,
    ):
        super().__init__(config)
        self.window_size = window_size
        self.causal = causal
        self.channels_first = config.channels_first
        self.look_ahead_conv = (
            look_ahead_conv if look_ahead_conv is not None else nn.Identity()
        )
        self.input_proj = (
            nn.Linear(input_dim, config.dim)
            if input_dim != config.dim
            else nn.Identity()
        )
        self.output_proj = (
            nn.Linear(config.dim, input_dim)
            if input_dim != config.dim
            else nn.Identity()
        )

    def make_window_limited_mask(
        self,
        max_length: int,
        x_lens: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Make a mask for window-limited (banded causal) attention.
        """
        if self.causal:
            mask = torch.tril(torch.ones(max_length, max_length))
            row_indices = torch.arange(max_length).view(-1, 1)
            window_size = self.window_size or max_length
            valid_range = (row_indices - window_size + 1).clamp(min=0)
            column_indices = torch.arange(max_length)
            mask = (column_indices >= valid_range) & mask.bool()
        else:
            raise NotImplementedError
        mask = mask.bool()[None, None]
        return mask

    def make_mask(
        self,
        max_length: int,
        x_lens: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Make an ordinary (causal or full) mask when no window size is specified.
        """
        if self.causal:
            mask = torch.tril(torch.ones(max_length, max_length))
        else:
            mask = torch.ones(max_length, max_length)
        mask = mask.bool()
        if x_lens is not None:
            # Expand to one mask per batch item and hide positions beyond each
            # sequence's valid length.
            mask = mask[None, None].repeat(len(x_lens), 1, 1, 1)
            for i, x_len in enumerate(x_lens):
                mask[i, :, :, x_len:] = False
        else:
            mask = mask[None, None]
        return mask

    def forward(
        self,
        x: Tensor,
        x_lens: Optional[Tensor] = None,
    ) -> Tensor:
        if self.channels_first:
            x = x.transpose(1, 2)
        x = self.input_proj(x)  # (B, T, D)
        x = self.look_ahead_conv(x)
        input_pos = torch.arange(x.shape[1], device=x.device)
        # construct mask to form window limited attention
        max_length = x.shape[1]
        if self.window_size is not None:
            mask = self.make_window_limited_mask(max_length, x_lens)
        else:
            mask = self.make_mask(max_length, x_lens)
        mask = mask.to(x.device)
        x = super().forward(x, input_pos, mask)
        x = self.output_proj(x)  # (B, T, D)
        if self.channels_first:
            x = x.transpose(1, 2)
        return x

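
# Illustrative sketch (not part of the original module): what the
# window-limited causal mask looks like for a short sequence, plus a shape
# check of the channels-first forward pass. The tiny ModelArgs values are
# arbitrary example settings.
def _demo_window_limited_transformer():
    config = ModelArgs(
        block_size=64, n_layer=1, n_head=2, dim=32, head_dim=16, intermediate_size=64
    )
    model = WindowLimitedTransformer(config, input_dim=32, window_size=3, causal=True)
    mask = model.make_window_limited_mask(5)
    # Row i attends only to columns max(0, i - 2) .. i (window of 3, causal).
    assert mask[0, 0].long().tolist() == [
        [1, 0, 0, 0, 0],
        [1, 1, 0, 0, 0],
        [1, 1, 1, 0, 0],
        [0, 1, 1, 1, 0],
        [0, 0, 1, 1, 1],
    ]
    x = torch.randn(1, 32, 10)  # (B, C, T) because channels_first=True
    y = model.eval()(x)
    assert y.shape == (1, 32, 10)
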
def precompute_freqs_cis(
    seq_len: int, n_elem: int, base: int = 10000, dtype: torch.dtype = torch.bfloat16
) -> Tensor:
    freqs = 1.0 / (
        base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
    )
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
    return cache.to(dtype=dtype)


def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
        ],
        -1,
    )
    x_out2 = x_out2.flatten(3)
    return x_out2.type_as(x)

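
# Illustrative sketch (not part of the original module): apply_rotary_emb only
# rotates each (even, odd) channel pair by a position-dependent angle, so it
# preserves per-position norms. Shapes are (B, S, H, D) with example sizes.
def _demo_rotary_emb():
    freqs = precompute_freqs_cis(seq_len=8, n_elem=16, dtype=torch.float32)
    x = torch.randn(1, 8, 2, 16)
    y = apply_rotary_emb(x, freqs)
    assert y.shape == x.shape
    assert torch.allclose(x.norm(dim=-1), y.norm(dim=-1), atol=1e-5)
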
def init_weights(m):
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)

def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Remove padding from x, handling zero padding properly. Only for 1d!"""
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    assert (padding_left + padding_right) <= x.shape[-1]
    end = x.shape[-1] - padding_right
    return x[..., padding_left:end]


def get_extra_padding_for_conv1d(
    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
) -> int:
    """See `pad_for_conv1d`."""
    length = x.shape[-1]
    n_frames = (length - kernel_size + padding_total) / stride + 1
    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
    return ideal_length - length


def pad1d(
    x: torch.Tensor,
    paddings: tp.Tuple[int, int],
    mode: str = "constant",
    value: float = 0.0,
):
    """Tiny wrapper around F.pad, just to allow reflect padding on small inputs.
    In that case, we insert extra zero padding to the right before the
    reflection happens.
    """
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode == "reflect":
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            extra_pad = max_pad - length + 1
            x = F.pad(x, (0, extra_pad))
        padded = F.pad(x, paddings, mode, value)
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]
    else:
        return F.pad(x, paddings, mode, value)

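
# Illustrative sketch (not part of the original module): the arithmetic behind
# get_extra_padding_for_conv1d and the reflect-mode fallback in pad1d, using
# made-up sizes.
def _demo_padding_helpers():
    x = torch.randn(1, 1, 11)
    # (11 - 4 + 2) / 2 + 1 = 5.5 frames, so one extra sample is needed for the
    # last frame to be complete.
    assert get_extra_padding_for_conv1d(x, kernel_size=4, stride=2, padding_total=2) == 1
    # Reflect padding normally requires length > pad; pad1d zero-pads a short
    # signal first so the reflection is still valid.
    short = torch.randn(1, 1, 2)
    out = pad1d(short, (3, 3), mode="reflect")
    assert out.shape[-1] == 2 + 3 + 3
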
class CausalConvNet(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        dilation=1,
        stride=1,
        groups=1,
        padding=None,
    ):
        super(CausalConvNet, self).__init__()
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
            groups=groups,
        )
        self.stride = stride
        # effective (dilated) kernel size
        self.kernel_size = (kernel_size - 1) * dilation + 1
        self.dilation = dilation
        self.padding = self.kernel_size - self.stride

    def forward(self, x):
        pad = self.padding
        extra_padding = get_extra_padding_for_conv1d(
            x, self.kernel_size, self.stride, pad
        )
        # pad on the left (plus the few extra right samples needed to complete
        # the last frame) so no future samples leak into the output
        x = pad1d(x, (pad, extra_padding), mode="constant", value=0)
        return self.conv(x).contiguous()

    def weight_norm(self, name="weight", dim=0):
        self.conv = weight_norm(self.conv, name=name, dim=dim)
        return self

    def remove_weight_norm(self):
        self.conv = remove_parametrizations(self.conv, "weight")
        return self

class CausalTransConvNet(nn.Module):
    def __init__(
        self, in_channels, out_channels, kernel_size, dilation=1, stride=1, padding=None
    ):
        super(CausalTransConvNet, self).__init__()
        self.conv = nn.ConvTranspose1d(
            in_channels, out_channels, kernel_size, stride=stride, dilation=dilation
        )
        self.stride = stride
        self.kernel_size = kernel_size

    def forward(self, x):
        x = self.conv(x)
        # trim the transposed-conv overlap on the right only, keeping the output causal
        pad = self.kernel_size - self.stride
        padding_right = math.ceil(pad)
        padding_left = pad - padding_right
        x = unpad1d(x, (padding_left, padding_right))
        return x.contiguous()

    def weight_norm(self, name="weight", dim=0):
        self.conv = weight_norm(self.conv, name=name, dim=dim)
        return self

    def remove_weight_norm(self):
        self.conv = remove_parametrizations(self.conv, "weight")
        return self


def CausalWNConv1d(*args, **kwargs):
    return CausalConvNet(*args, **kwargs).weight_norm()


def CausalWNConvTranspose1d(*args, **kwargs):
    return CausalTransConvNet(*args, **kwargs).weight_norm()

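
# Illustrative sketch (not part of the original module): the causal conv /
# transposed-conv pair keeps lengths at exactly T/stride and back, which is
# what lets the codec operate frame by frame. Channel counts are arbitrary.
def _demo_causal_conv_lengths():
    x = torch.randn(1, 1, 50)
    down = CausalConvNet(1, 4, kernel_size=7, stride=2)
    up = CausalTransConvNet(4, 1, kernel_size=4, stride=2)
    h = down(x)
    y = up(h)
    assert h.shape[-1] == 25 and y.shape[-1] == 50
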
class ResidualUnit(nn.Module):
    def __init__(self, dim: int = 16, dilation: int = 1, causal: bool = False):
        super().__init__()
        conv_class = CausalWNConv1d if causal else WNConv1d
        pad = ((7 - 1) * dilation) // 2
        self.block = nn.Sequential(
            Snake1d(dim),
            conv_class(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
            Snake1d(dim),
            conv_class(dim, dim, kernel_size=1),
        )
        self.causal = causal

    def forward(self, x):
        y = self.block(x)
        pad = x.shape[-1] - y.shape[-1]
        if pad > 0:
            if self.causal:
                x = x[..., :-pad]
            else:
                x = x[..., pad // 2 : -pad // 2]
        return x + y

class EncoderBlock(nn.Module):
    def __init__(
        self,
        dim: int = 16,
        stride: int = 1,
        causal: bool = False,
        n_t_layer: int = 0,
        transformer_general_config=None,
    ):
        super().__init__()
        conv_class = CausalWNConv1d if causal else WNConv1d
        transformer_module = (
            nn.Identity()
            if n_t_layer == 0
            else (
                WindowLimitedTransformer(
                    causal=causal,
                    input_dim=dim,
                    window_size=512,
                    config=transformer_general_config(
                        n_layer=n_t_layer,
                        n_head=dim // 64,
                        dim=dim,
                        intermediate_size=dim * 3,
                    ),
                )
            )
        )
        self.block = nn.Sequential(
            ResidualUnit(dim // 2, dilation=1, causal=causal),
            ResidualUnit(dim // 2, dilation=3, causal=causal),
            ResidualUnit(dim // 2, dilation=9, causal=causal),
            Snake1d(dim // 2),
            conv_class(
                dim // 2,
                dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            ),
            transformer_module,
        )

    def forward(self, x):
        return self.block(x)

class Encoder(nn.Module):
    def __init__(
        self,
        d_model: int = 64,
        strides: list = [2, 4, 8, 8],
        d_latent: int = 64,
        n_transformer_layers: list = [0, 0, 4, 4],
        transformer_general_config: ModelArgs = None,
        causal: bool = False,
    ):
        super().__init__()
        conv_class = CausalWNConv1d if causal else WNConv1d
        # Create first convolution
        self.block = [conv_class(1, d_model, kernel_size=7, padding=3)]

        # Create EncoderBlocks that double channels as they downsample by `stride`
        for stride, n_t_layer in zip(strides, n_transformer_layers):
            d_model *= 2
            self.block += [
                EncoderBlock(
                    d_model,
                    stride=stride,
                    causal=causal,
                    n_t_layer=n_t_layer,
                    transformer_general_config=transformer_general_config,
                )
            ]

        # Create last convolution
        self.block += [
            Snake1d(d_model),
            conv_class(d_model, d_latent, kernel_size=3, padding=1),
        ]

        # Wrap block into nn.Sequential
        self.block = nn.Sequential(*self.block)
        self.enc_dim = d_model

    def forward(self, x):
        return self.block(x)

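
# Illustrative sketch (not part of the original module): the Encoder
# downsamples by prod(strides) overall (512x with the default [2, 4, 8, 8])
# while doubling channels per block, and its output is projected to d_latent.
# The small sizes below are arbitrary example values.
def _demo_encoder_shapes():
    enc = Encoder(
        d_model=32, strides=[2, 4], d_latent=64, n_transformer_layers=[0, 0], causal=True
    )
    x = torch.randn(1, 1, 800)
    z = enc(x)
    assert z.shape == (1, 64, 100)  # 800 / (2 * 4) = 100 frames
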
class DecoderBlock(nn.Module):
    def __init__(
        self,
        input_dim: int = 16,
        output_dim: int = 8,
        stride: int = 1,
        causal: bool = False,
        n_t_layer: int = 0,
        transformer_general_config=None,
    ):
        super().__init__()
        conv_trans_class = CausalWNConvTranspose1d if causal else WNConvTranspose1d
        transformer_module = (
            nn.Identity()
            if n_t_layer == 0
            else (
                WindowLimitedTransformer(
                    causal=causal,
                    input_dim=input_dim,
                    window_size=None,
                    config=transformer_general_config(
                        n_layer=n_t_layer,
                        n_head=input_dim // 64,
                        dim=input_dim,
                        intermediate_size=input_dim * 3,
                    ),
                )
            )
        )
        self.block = nn.Sequential(
            # transformer_module,
            Snake1d(input_dim),
            conv_trans_class(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            ),
            ResidualUnit(output_dim, dilation=1, causal=causal),
            ResidualUnit(output_dim, dilation=3, causal=causal),
            ResidualUnit(output_dim, dilation=9, causal=causal),
        )

    def forward(self, x):
        return self.block(x)

class Decoder(nn.Module):
    def __init__(
        self,
        input_channel,
        channels,
        rates,
        d_out: int = 1,
        causal: bool = False,
        n_transformer_layers: list = [0, 0, 0, 0],
        transformer_general_config=None,
    ):
        super().__init__()
        conv_class = CausalWNConv1d if causal else WNConv1d

        # Add first conv layer
        layers = [conv_class(input_channel, channels, kernel_size=7, padding=3)]

        # Add upsampling + MRF blocks
        for i, (stride, n_t_layer) in enumerate(zip(rates, n_transformer_layers)):
            input_dim = channels // 2**i
            output_dim = channels // 2 ** (i + 1)
            layers += [
                DecoderBlock(
                    input_dim,
                    output_dim,
                    stride,
                    causal=causal,
                    n_t_layer=n_t_layer,
                    transformer_general_config=transformer_general_config,
                )
            ]

        # Add final conv layer
        layers += [
            Snake1d(output_dim),
            conv_class(output_dim, d_out, kernel_size=7, padding=3),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

class DAC(BaseModel, CodecMixin):
    def __init__(
        self,
        encoder_dim: int = 64,
        encoder_rates: List[int] = [2, 4, 8, 8],
        latent_dim: int = None,
        decoder_dim: int = 1536,
        decoder_rates: List[int] = [8, 8, 4, 2],
        quantizer: torch.nn.Module = None,
        sample_rate: int = 44100,
        causal: bool = True,
        encoder_transformer_layers: List[int] = [0, 0, 0, 0],
        decoder_transformer_layers: List[int] = [0, 0, 0, 0],
        transformer_general_config=None,
    ):
        super().__init__()
        self.encoder_dim = encoder_dim
        self.encoder_rates = encoder_rates
        self.decoder_dim = decoder_dim
        self.decoder_rates = decoder_rates
        self.sample_rate = sample_rate

        if latent_dim is None:
            latent_dim = encoder_dim * (2 ** len(encoder_rates))

        self.latent_dim = latent_dim

        self.hop_length = np.prod(encoder_rates)
        self.encoder = Encoder(
            encoder_dim,
            encoder_rates,
            latent_dim,
            causal=causal,
            n_transformer_layers=encoder_transformer_layers,
            transformer_general_config=transformer_general_config,
        )

        self.quantizer = quantizer

        self.decoder = Decoder(
            latent_dim,
            decoder_dim,
            decoder_rates,
            causal=causal,
            n_transformer_layers=decoder_transformer_layers,
            transformer_general_config=transformer_general_config,
        )

        self.apply(init_weights)
        self.delay = self.get_delay()
        self.frame_length = self.hop_length * 4

    def preprocess(self, audio_data, sample_rate):
        if sample_rate is None:
            sample_rate = self.sample_rate
        assert sample_rate == self.sample_rate

        length = audio_data.shape[-1]
        right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
        audio_data = nn.functional.pad(audio_data, (0, right_pad))

        return audio_data

    def encode(
        self,
        audio_data: torch.Tensor,
        audio_lengths: torch.Tensor = None,
        n_quantizers: int = None,
        **kwargs,
    ):
        """Encode the given audio data and return quantized codebook indices.

        Parameters
        ----------
        audio_data : Tensor[B x 1 x T] (or Tensor[B x T])
            Audio data to encode
        audio_lengths : Tensor[B], optional
            Number of valid samples per batch item, by default None
            If None, the padded input length is used for every item.
        n_quantizers : int, optional
            Number of quantizers to use, by default None
            If None, all quantizers are used.

        Returns
        -------
        indices : Tensor[B x N x T']
            Codebook indices for each codebook
            (quantized discrete representation of input)
        indices_lens : Tensor[B]
            Number of code frames per batch item
        """
        # pad to a multiple of self.frame_length
        if audio_data.ndim == 2:
            audio_data = audio_data.unsqueeze(1)
        length = audio_data.shape[-1]
        right_pad = math.ceil(length / self.frame_length) * self.frame_length - length
        audio_data = nn.functional.pad(audio_data, (0, right_pad))
        if audio_lengths is None:
            audio_lengths = torch.LongTensor([length + right_pad]).to(audio_data.device)

        z = self.encoder(audio_data)
        vq_results = self.quantizer(z, n_quantizers, **kwargs)
        indices = vq_results.codes
        indices_lens = torch.ceil(audio_lengths / self.frame_length).long()
        return indices, indices_lens

    def decode(self, indices: torch.Tensor, feature_lengths):
        if indices.ndim == 2:
            indices = indices[None]
        z = self.quantizer.decode(indices)
        audio_lengths = feature_lengths * self.frame_length
        return self.decoder(z), audio_lengths

    def forward(
        self,
        audio_data: torch.Tensor,
        template: torch.Tensor = None,
        mask: torch.Tensor = None,
        sample_rate: int = None,
        n_quantizers: int = None,
        **kwargs,
    ):
        """Model forward pass: encode to codebook indices and decode back to audio.

        Parameters
        ----------
        audio_data : Tensor[B x 1 x T]
            Audio data to encode
        sample_rate : int, optional
            Sample rate of audio data in Hz, by default None
            If None, defaults to `self.sample_rate`
        n_quantizers : int, optional
            Number of quantizers to use, by default None.
            If None, all quantizers are used.

        Returns
        -------
        audio : Tensor[B x 1 x T]
            Reconstructed audio, trimmed to the input length.
        (indices, indices_lens) : tuple of Tensors
            Codebook indices and per-item frame lengths from `encode`.
        """
        length = audio_data.shape[-1]
        audio_data = self.preprocess(audio_data, sample_rate)
        indices, indices_lens = self.encode(
            audio_data, n_quantizers=n_quantizers, **kwargs
        )
        audio, _ = self.decode(indices, indices_lens)
        return audio[..., :length], (indices, indices_lens)
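
# Illustrative sketch (not part of the original module): wiring the codec end
# to end with a stand-in quantizer. The real quantizer (injected via the
# constructor, e.g. from a hydra config) must return a VQResult with `codes`
# and expose a `decode` method; the pass-through stub below only demonstrates
# the expected shapes, and all sizes are arbitrary example values.
class _ExampleQuantizer(nn.Module):
    def forward(self, z, n_quantizers=None, **kwargs):
        return VQResult(
            z=z,
            codes=z,
            latents=z,
            codebook_loss=z.new_zeros(1),
            commitment_loss=z.new_zeros(1),
        )

    def decode(self, codes):
        return codes


def _demo_dac_round_trip():
    model = DAC(
        encoder_dim=32,
        decoder_dim=512,
        quantizer=_ExampleQuantizer(),
        causal=True,
    )
    audio = torch.randn(1, 1, 8192)
    codes, code_lens = model.encode(audio)
    recon, recon_lens = model.decode(codes, code_lens)
    # hop_length = prod([2, 4, 8, 8]) = 512, so 8192 samples -> 16 latent frames,
    # and the decoder upsamples back to 8192 samples.
    assert codes.shape[-1] == 8192 // 512
    assert recon.shape[-1] == 8192
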