import math
import typing as tp
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from dac.nn.quantize import ResidualVectorQuantize
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations

def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Remove padding from x, properly handling zero padding. Only for 1d!"""
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    assert (padding_left + padding_right) <= x.shape[-1]
    end = x.shape[-1] - padding_right
    return x[..., padding_left:end]

def get_extra_padding_for_conv1d(
    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
) -> int:
    """See `pad_for_conv1d`."""
    length = x.shape[-1]
    n_frames = (length - kernel_size + padding_total) / stride + 1
    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
    return ideal_length - length
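
# Worked example (a sketch, not part of the original module): with kernel_size=4,
# stride=2 and padding_total=2, an input of length 10 gives
# n_frames = (10 - 4 + 2) / 2 + 1 = 5.0 and ideal_length = 4 * 2 + 2 = 10, so no extra
# padding is needed. An input of length 11 gives n_frames = 5.5 and ideal_length = 12,
# so one extra sample of right padding is added to complete the last convolution window.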

def pad1d(
    x: torch.Tensor,
    paddings: tp.Tuple[int, int],
    mode: str = "zeros",
    value: float = 0.0,
):
    """Tiny wrapper around F.pad, just to allow reflect padding on small inputs.
    If the input is shorter than the requested padding, extra zero padding is inserted
    on the right before the reflection happens.
    """
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode == "reflect":
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            extra_pad = max_pad - length + 1
            x = F.pad(x, (0, extra_pad))
        padded = F.pad(x, paddings, mode, value)
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]
    else:
        return F.pad(x, paddings, mode, value)
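
# Illustrative sketch (not part of the original module): reflect padding requires the
# input to be longer than the pad amount, so pad1d zero-extends short inputs first.
# A length-2 signal padded with (3, 3) in "reflect" mode would raise in plain F.pad,
# but works here:
#
#   t = torch.randn(1, 1, 2)
#   out = pad1d(t, (3, 3), mode="reflect")   # -> shape (1, 1, 8)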

class CausalConvNet(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        dilation=1,
        stride=1,
        groups=1,
        padding=None,
    ):
        super(CausalConvNet, self).__init__()
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
            groups=groups,
        )
        self.stride = stride
        self.kernel_size = (kernel_size - 1) * dilation + 1
        self.dilation = dilation
        self.padding = self.kernel_size - self.stride

    def forward(self, x):
        pad = self.padding
        extra_padding = get_extra_padding_for_conv1d(
            x, self.kernel_size, self.stride, pad
        )
        x = pad1d(x, (pad, extra_padding), mode="constant", value=0)
        return self.conv(x).contiguous()

    def weight_norm(self, name="weight", dim=0):
        self.conv = weight_norm(self.conv, name=name, dim=dim)
        return self

    def remove_weight_norm(self):
        self.conv = remove_parametrizations(self.conv, "weight")
        return self
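
# Illustrative sketch (not part of the original module): CausalConvNet pads only on the
# left (plus any extra right padding needed to complete the last window), so the output
# length is ceil(length / stride) and no output frame depends on future samples.
#
#   conv = CausalConvNet(16, 32, kernel_size=4, stride=2)
#   y = conv(torch.randn(1, 16, 11))   # -> shape (1, 32, 6) == ceil(11 / 2)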

class CausalTransConvNet(nn.Module):
    def __init__(
        self, in_channels, out_channels, kernel_size, dilation=1, stride=1, padding=None
    ):
        super(CausalTransConvNet, self).__init__()
        self.conv = nn.ConvTranspose1d(
            in_channels, out_channels, kernel_size, stride=stride, dilation=dilation
        )
        self.stride = stride
        self.kernel_size = kernel_size

    def forward(self, x):
        x = self.conv(x)
        pad = self.kernel_size - self.stride
        padding_right = math.ceil(pad)
        padding_left = pad - padding_right
        x = unpad1d(x, (padding_left, padding_right))
        return x.contiguous()

    def weight_norm(self, name="weight", dim=0):
        self.conv = weight_norm(self.conv, name=name, dim=dim)
        return self

    def remove_weight_norm(self):
        self.conv = remove_parametrizations(self.conv, "weight")
        return self
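
# Illustrative sketch (not part of the original module): the transposed conv produces
# (T - 1) * stride + kernel_size samples; trimming kernel_size - stride samples from the
# right makes the output exactly T * stride long and keeps it causal, mirroring
# CausalConvNet above.
#
#   up = CausalTransConvNet(32, 16, kernel_size=4, stride=2)
#   y = up(torch.randn(1, 32, 6))   # -> shape (1, 16, 12) == 6 * 2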

# ConvNeXt Block copied from https://github.com/fishaudio/fish-diffusion/blob/main/fish_diffusion/modules/convnext.py
class ConvNeXtBlock(nn.Module):
    r"""1D ConvNeXt block. There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, L)
    (2) DwConv -> Permute to (N, L, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    We use (2) as we find it slightly faster in PyTorch.

    Args:
        dim (int): Number of input channels.
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
        kernel_size (int): Kernel size for depthwise conv. Default: 7.
        dilation (int): Dilation for depthwise conv. Default: 1.
    """  # noqa: E501
    def __init__(
        self,
        dim: int,
        layer_scale_init_value: float = 1e-6,
        mlp_ratio: float = 4.0,
        kernel_size: int = 7,
        dilation: int = 1,
    ):
        super().__init__()
        convnet_type = CausalConvNet
        self.dwconv = convnet_type(
            dim,
            dim,
            kernel_size=kernel_size,
            # padding=int(dilation * (kernel_size - 1) / 2),
            groups=dim,
            dilation=dilation,
        )  # depthwise conv
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(
            dim, int(mlp_ratio * dim)
        )  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim)
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
            if layer_scale_init_value > 0
            else None
        )

    def forward(self, x, apply_residual: bool = True):
        input = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 1)  # (N, C, L) -> (N, L, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 2, 1)  # (N, L, C) -> (N, C, L)
        if apply_residual:
            x = input + x
        return x
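
# Illustrative sketch (not part of the original module): the block is shape-preserving,
# so it can be stacked freely between the (trans)conv layers below.
#
#   block = ConvNeXtBlock(dim=64)
#   y = block(torch.randn(1, 64, 100))   # -> shape (1, 64, 100)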

@dataclass
class VQResult:
    z: torch.Tensor
    codes: torch.Tensor
    latents: torch.Tensor
    codebook_loss: torch.Tensor
    commitment_loss: torch.Tensor
    semantic_distill_z: torch.Tensor | None = None

class DownsampleResidualVectorQuantize(nn.Module):
    def __init__(
        self,
        input_dim: int = 1024,
        n_codebooks: int = 9,
        codebook_dim: int = 8,
        quantizer_dropout: float = 0.5,
        codebook_size: int = 1024,
        semantic_codebook_size: int = 4096,
        downsample_factor: tuple[int] = (2, 2),
        downsample_dims: tuple[int] | None = None,
        pre_module: nn.Module | None = None,
        post_module: nn.Module | None = None,
        semantic_predictor_module: nn.Module | None = None,
    ):
        super().__init__()

        if downsample_dims is None:
            downsample_dims = [input_dim for _ in range(len(downsample_factor))]

        all_dims = (input_dim,) + tuple(downsample_dims)

        self.semantic_quantizer = ResidualVectorQuantize(
            input_dim=input_dim,
            n_codebooks=1,
            codebook_size=semantic_codebook_size,
            codebook_dim=codebook_dim,
            quantizer_dropout=0.0,
        )
        self.quantizer = ResidualVectorQuantize(
            input_dim=input_dim,
            n_codebooks=n_codebooks,
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            quantizer_dropout=quantizer_dropout,
        )

        self.downsample_factor = downsample_factor
        self.downsample_dims = downsample_dims

        convnet_type = CausalConvNet
        transconvnet_type = CausalTransConvNet

        self.downsample = nn.Sequential(
            *[
                nn.Sequential(
                    convnet_type(
                        all_dims[idx],
                        all_dims[idx + 1],
                        kernel_size=factor,
                        stride=factor,
                    ),
                    ConvNeXtBlock(dim=all_dims[idx + 1]),
                )
                for idx, factor in enumerate(downsample_factor)
            ]
        )

        self.upsample = nn.Sequential(
            *[
                nn.Sequential(
                    transconvnet_type(
                        all_dims[idx + 1],
                        all_dims[idx],
                        kernel_size=factor,
                        stride=factor,
                    ),
                    ConvNeXtBlock(dim=all_dims[idx]),
                )
                for idx, factor in reversed(list(enumerate(downsample_factor)))
            ]
        )

        self.apply(self._init_weights)

        self.pre_module = (
            pre_module if pre_module is not None else nn.Identity()
        )  # placeholder for a transformer, LSTM, Mamba or similar module
        self.post_module = post_module if post_module is not None else nn.Identity()
        self.semantic_predictor_module = (
            semantic_predictor_module
            if semantic_predictor_module is not None
            else nn.Identity()
        )

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            nn.init.constant_(m.bias, 0)

    def forward(
        self, z, n_quantizers: int = None, semantic_len: torch.Tensor = None, **kwargs
    ):
        # z: (B, D, T)
        original_shape = z.shape
        if semantic_len is None:
            semantic_len = torch.LongTensor([z.shape[-1]])

        z = self.downsample(z)
        z = self.pre_module(z)  # B, T, D

        (
            semantic_z,
            semantic_codes,
            semantic_latents,
            semantic_commitment_loss,
            semantic_codebook_loss,
        ) = self.semantic_quantizer(z)
        residual_z = z - semantic_z
        residual_z, codes, latents, commitment_loss, codebook_loss = self.quantizer(
            residual_z, n_quantizers=n_quantizers
        )
        z = semantic_z + residual_z
        commitment_loss = commitment_loss + semantic_commitment_loss
        codebook_loss = codebook_loss + semantic_codebook_loss
        codes = torch.cat([semantic_codes, codes], dim=1)
        latents = torch.cat([semantic_latents, latents], dim=1)

        z = self.post_module(z)
        z = self.upsample(z)
        # z: (B, D, T)

        # semantic distillation (disabled here since only used in training)
        # semantic_distill_z = self.semantic_predictor_module(semantic_z, semantic_len).mT  # wav2vec target is B, T, D

        # Pad or crop z to match the original length
        diff = original_shape[-1] - z.shape[-1]
        right = 0
        left = abs(diff) - right
        if diff > 0:
            z = F.pad(z, (left, right))
        elif diff < 0:
            z = z[..., left:]

        results = VQResult(
            z=z,
            codes=codes,
            latents=latents,
            commitment_loss=commitment_loss,
            codebook_loss=codebook_loss,
        )
        return results
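
    # Rough shape walkthrough (a sketch, not part of the original module), assuming the
    # defaults downsample_factor=(2, 2) and n_codebooks=9: an input z of shape (B, D, T)
    # is downsampled to (B, D, ceil(T / 4)), quantized into 1 semantic codebook plus 9
    # residual codebooks (codes: (B, 10, ceil(T / 4))), upsampled back to roughly T, and
    # finally padded or cropped so that result.z matches the input shape (B, D, T).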

    # def encode(self, z):
    #     z = self.downsample(z)
    #     z = self.pre_module(z)
    #     _, indices, _, _, _ = self.quantizer(z.mT)
    #     indices = rearrange(indices, "g b l r -> b (g r) l")
    #     return indices
    #

    def decode(self, indices: torch.Tensor):
        # indices = rearrange(indices, "b (g r) l -> g b l r", g=self.residual_fsq.groups)
        # print(f"indices: {indices.shape}, semantic_quantizer.codebook_size: {self.semantic_quantizer.codebook_size}, quantizer.codebook_size: {self.quantizer.codebook_size}, semantic min: {indices[:, 0].min()}, max: {indices[:, 0].max()}, quantizer min: {indices[:, 1:].min()}, max: {indices[:, 1:].max()}")
        new_indices = torch.zeros_like(indices)
        new_indices[:, 0] = torch.clamp(
            indices[:, 0], max=self.semantic_quantizer.codebook_size - 1
        )
        new_indices[:, 1:] = torch.clamp(
            indices[:, 1:], max=self.quantizer.codebook_size - 1
        )

        z_q_semantic = self.semantic_quantizer.from_codes(new_indices[:, :1])[0]
        z_q_residual = self.quantizer.from_codes(new_indices[:, 1:])[0]
        z_q = z_q_semantic + z_q_residual
        z_q = self.post_module(z_q)
        z_q = self.upsample(z_q)
        return z_q
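
    # Illustrative sketch (not part of the original module): decode expects codes of
    # shape (B, 1 + n_codebooks, T_down), semantic codebook first, and returns a latent
    # of shape (B, D, T_down * prod(downsample_factor)); out-of-range indices are
    # clamped to the respective codebook sizes. Unlike forward, no final pad/crop is
    # applied.
    #
    #   result = rvq(x)                  # rvq: DownsampleResidualVectorQuantize, x: (B, D, T)
    #   z_hat = rvq.decode(result.codes)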

    # def from_latents(self, latents: torch.Tensor):
    #     z_q, z_p, codes = super().from_latents(latents)
    #     z_q = self.upsample(z_q)
    #     return z_q, z_p, codes


if __name__ == "__main__":
    rvq = DownsampleResidualVectorQuantize(
        input_dim=512,
        n_codebooks=8,
        codebook_dim=8,
        codebook_size=1024,
        quantizer_dropout=0.5,
        downsample_factor=[2, 2],
    )
    rvq.eval()
    x = torch.randn(2, 512, 442)

    result = rvq(x)
    print(rvq)
    print(result.latents.shape, result.codes.shape, result.z.shape)

    # y = rvq.from_codes(result.codes)
    # print(y[0].shape)
    # y = rvq.from_latents(

    result1 = rvq(x[:, :, :40])
    print(result1.latents.shape, result1.codes.shape, result1.z.shape)

    assert torch.allclose(result.z[:, :, :40], result1.z, atol=1e-8)
    print("Success")