tezuesh's picture
Upload folder using huggingface_hub
22d5f88 verified
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
import math
import typing as tp
import warnings
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils import weight_norm
from .streaming import RawStreamingConv1d, RawStreamingConvTranspose1d, StreamingModule
# Normalization names accepted by `apply_parametrization_norm`; any other
# value fails that function's assertion.
CONV_NORMALIZATIONS = frozenset(["none", "weight_norm"])
class TransposedLayerNorm(nn.Module):
    """Apply `nn.LayerNorm` to channel-first `[B, C, T]` inputs.

    Axes 1 and 2 are swapped before the normalization and swapped back
    afterwards, so the output has the same layout as the input.
    All keyword arguments are forwarded unchanged to `nn.LayerNorm`.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.layer_norm = nn.LayerNorm(**kwargs)

    def forward(self, x):
        # Swap axes 1 and 2, normalize over the last axis, swap back.
        return self.layer_norm(x.transpose(1, 2)).transpose(1, 2)
def apply_parametrization_norm(module: nn.Module, norm: str = "none"):
    """Optionally wrap `module` with weight normalization.

    Args:
        module: The module to (possibly) reparametrize.
        norm: Must be a member of `CONV_NORMALIZATIONS`.

    Returns:
        `weight_norm(module)` when `norm == "weight_norm"`, otherwise
        `module` itself, untouched.
    """
    assert norm in CONV_NORMALIZATIONS
    # Membership was checked above, so every value other than "weight_norm"
    # requires no reparametrization.
    return weight_norm(module) if norm == "weight_norm" else module
def get_extra_padding_for_conv1d(
    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
) -> int:
    """Return how many samples must be appended so the last window is full.

    See `pad_for_conv1d`.

    Args:
        x: Tensor whose last dimension is the length being convolved.
        kernel_size: Size of the convolution kernel.
        stride: Stride of the convolution.
        padding_total: Fixed padding that is added separately from this
            extra padding.

    Returns:
        The difference between the smallest length that yields a whole
        number of frames and the current length of `x`.
    """
    current_length = x.shape[-1]
    # Part of the kernel that is not already covered by the fixed padding.
    uncovered = kernel_size - padding_total
    frame_count = math.ceil((current_length - uncovered) / stride + 1)
    target_length = (frame_count - 1) * stride + uncovered
    return target_length - current_length
def pad_for_conv1d(
    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
):
    """Right-pad `x` with zeros so that the final convolution window is full.

    Without this extra padding at the end, some trailing time steps could be
    dropped, and an output of the same length could not be rebuilt afterwards
    (even when regular padding is used).

    Example with total padding = 4, kernel size = 4, stride = 2:
        0 0 1 2 3 4 5 0 0   # (0s are padding)
        1   2   3           # (output frames of a convolution, last 0 is never used)
        0 0 1 2 3 4 5 0     # (output of tr. conv., but pos. 5 is going to get removed as padding)
            1 2 3 4         # once the padding is removed, one time step is missing!

    The amount appended is given by `get_extra_padding_for_conv1d`.
    """
    trailing = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
    return F.pad(x, (0, trailing))
def pad1d(
    x: torch.Tensor,
    paddings: tp.Tuple[int, int],
    mode: str = "constant",
    value: float = 0.0,
):
    """Thin wrapper over `F.pad` that also supports reflect padding of short inputs.

    In "reflect" mode, when the input's last dimension is not longer than the
    largest requested padding, zeros are first appended on the right so the
    reflection can be performed; the same number of trailing samples is then
    dropped from the result.

    Args:
        x: Tensor padded along its last dimension.
        paddings: `(left, right)` padding amounts, both non-negative.
        mode: Padding mode forwarded to `F.pad`.
        value: Fill value forwarded to `F.pad`.
    """
    left, right = paddings
    assert left >= 0 and right >= 0, (left, right)
    if mode != "reflect":
        return F.pad(x, paddings, mode, value)

    # Number of zeros needed so the input is strictly longer than the largest
    # padding; zero when the input is already long enough.
    shortfall = max(0, max(left, right) - x.shape[-1] + 1)
    if shortfall:
        x = F.pad(x, (0, shortfall))
    reflected = F.pad(x, paddings, mode, value)
    keep = reflected.shape[-1] - shortfall
    return reflected[..., :keep]
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Strip `(left, right)` padding from the last dimension of `x`.

    A zero amount on either side is handled correctly. Only for 1d!
    Both amounts must be non-negative and together must not exceed the
    length of the last dimension.
    """
    left, right = paddings
    assert left >= 0 and right >= 0, (left, right)
    total = x.shape[-1]
    assert (left + right) <= total
    # Slice with an explicit end index so that `right == 0` keeps the tail.
    return x[..., left : total - right]
class NormConv1d(nn.Module):
    """Wrapper around Conv1d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.

    Args:
        *args: Positional arguments forwarded to `RawStreamingConv1d`.
        causal: Accepted for interface uniformity; not used by this class.
        norm: Normalization name passed to `apply_parametrization_norm`.
        norm_kwargs: Accepted for interface uniformity; not used by this
            class. Defaults to `None` rather than a shared mutable dict.
        **kwargs: Keyword arguments forwarded to `RawStreamingConv1d`.
    """

    def __init__(
        self,
        *args,
        causal: bool = False,
        norm: str = "none",
        norm_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None,
        **kwargs,
    ):
        super().__init__()
        self.conv = apply_parametrization_norm(
            RawStreamingConv1d(*args, **kwargs), norm
        )
        self.norm_type = norm

    def forward(self, x):
        """Run the wrapped convolution on `x`."""
        return self.conv(x)
class NormConvTranspose1d(nn.Module):
    """Wrapper around ConvTranspose1d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.

    Args:
        *args: Positional arguments forwarded to `RawStreamingConvTranspose1d`.
        causal: Accepted for interface uniformity; not used by this class.
        norm: Normalization name passed to `apply_parametrization_norm`.
        norm_kwargs: Accepted for interface uniformity; not used by this
            class. Defaults to `None` rather than a shared mutable dict.
        **kwargs: Keyword arguments forwarded to `RawStreamingConvTranspose1d`.
    """

    def __init__(
        self,
        *args,
        causal: bool = False,
        norm: str = "none",
        norm_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None,
        **kwargs,
    ):
        super().__init__()
        self.convtr = apply_parametrization_norm(
            RawStreamingConvTranspose1d(*args, **kwargs), norm
        )
        self.norm_type = norm

    def forward(self, x):
        """Run the wrapped transposed convolution on `x`."""
        return self.convtr(x)
@dataclass
class _StreamingConv1dState:
    """Streaming state for `StreamingConv1d`."""

    # Left padding still to be applied to the incoming input.
    padding_to_add: int
    # Value that `padding_to_add` is restored to by `reset()`.
    original_padding_to_add: int

    def reset(self) -> None:
        """Restore the pending padding to its original amount."""
        self.padding_to_add = self.original_padding_to_add
class StreamingConv1d(StreamingModule[_StreamingConv1dState]):
    """Conv1d with some builtin handling of asymmetric or causal padding
    and normalization.

    Args:
        in_channels: Forwarded to the wrapped convolution.
        out_channels: Forwarded to the wrapped convolution.
        kernel_size: Forwarded to the wrapped convolution.
        stride: Forwarded to the wrapped convolution.
        dilation: Forwarded to the wrapped convolution.
        groups: Forwarded to the wrapped convolution.
        bias: Forwarded to the wrapped convolution.
        causal: If true, all fixed padding goes on the left; required for
            streaming.
        norm: Normalization name forwarded to `NormConv1d`.
        norm_kwargs: Forwarded to `NormConv1d`; `None` means an empty dict.
        pad_mode: Padding mode handed to `pad1d`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        causal: bool = False,
        norm: str = "none",
        norm_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None,
        pad_mode: str = "reflect",
    ):
        super().__init__()
        # warn user on unusual setup between dilation and stride
        if stride > 1 and dilation > 1:
            warnings.warn(
                "StreamingConv1d has been initialized with stride > 1 and dilation > 1"
                f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})."
            )
        self.conv = NormConv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            dilation=dilation,
            groups=groups,
            bias=bias,
            causal=causal,
            norm=norm,
            # Fresh dict per instance instead of a shared mutable default.
            norm_kwargs={} if norm_kwargs is None else norm_kwargs,
        )
        self.causal = causal
        self.pad_mode = pad_mode

    @property
    def _stride(self) -> int:
        """Stride of the wrapped convolution."""
        return self.conv.conv.stride[0]

    @property
    def _kernel_size(self) -> int:
        """Kernel size of the wrapped convolution (ignoring dilation)."""
        return self.conv.conv.kernel_size[0]

    @property
    def _effective_kernel_size(self) -> int:
        """Kernel size once dilation is taken into account."""
        dilation = self.conv.conv.dilation[0]
        return (self._kernel_size - 1) * dilation + 1

    @property
    def _padding_total(self) -> int:
        """Fixed padding: effective kernel size minus stride."""
        return self._effective_kernel_size - self._stride

    def _init_streaming_state(self, batch_size: int) -> _StreamingConv1dState:
        """Create the streaming state, with `_padding_total` pending on the left."""
        assert self.causal, "streaming is only supported for causal convs"
        return _StreamingConv1dState(self._padding_total, self._padding_total)

    def forward(self, x):
        """Pad `x` according to the current mode and apply the convolution."""
        # The three-way unpack rejects inputs that are not 3-dimensional
        # (the original code named these dimensions B, C, T).
        _batch, _channels, _time = x.shape

        state = self._streaming_state
        if state is not None:
            # Streaming: the pending left padding is applied once, to the
            # first non-empty input, and then cleared.
            if state.padding_to_add > 0 and x.shape[-1] > 0:
                x = pad1d(x, (state.padding_to_add, 0), mode=self.pad_mode)
                state.padding_to_add = 0
            return self.conv(x)

        padding_total = self._padding_total
        extra_padding = get_extra_padding_for_conv1d(
            x, self._effective_kernel_size, self._stride, padding_total
        )
        if self.causal:
            # Left padding for causal
            x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            x = pad1d(
                x, (padding_left, padding_right + extra_padding), mode=self.pad_mode
            )
        return self.conv(x)
@dataclass
class _StreamingConvTr1dState:
    """Streaming state for `StreamingConvTranspose1d`; it has no fields."""

    def reset(self) -> None:
        """Nothing to reset."""
        return None
class StreamingConvTranspose1d(StreamingModule[_StreamingConvTr1dState]):
    """ConvTranspose1d with some builtin handling of asymmetric or causal padding
    and normalization.

    Args:
        in_channels: Forwarded to the wrapped transposed convolution.
        out_channels: Forwarded to the wrapped transposed convolution.
        kernel_size: Forwarded to the wrapped transposed convolution.
        stride: Forwarded to the wrapped transposed convolution.
        groups: Forwarded to the wrapped transposed convolution.
        bias: Forwarded to the wrapped transposed convolution.
        causal: If true, trimming follows `trim_right_ratio`; required for
            streaming.
        norm: Normalization name forwarded to `NormConvTranspose1d`.
        trim_right_ratio: Fraction (in [0, 1]) of the fixed padding trimmed
            from the right in the causal case. Must be 1.0 when not causal.
        norm_kwargs: Forwarded to `NormConvTranspose1d`; `None` means an
            empty dict.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        groups: int = 1,
        bias: bool = True,
        causal: bool = False,
        norm: str = "none",
        trim_right_ratio: float = 1.0,
        norm_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None,
    ):
        super().__init__()
        self.convtr = NormConvTranspose1d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            groups=groups,
            bias=bias,
            causal=causal,
            norm=norm,
            # Fresh dict per instance instead of a shared mutable default.
            norm_kwargs={} if norm_kwargs is None else norm_kwargs,
        )
        self.causal = causal
        self.trim_right_ratio = trim_right_ratio
        assert (
            self.causal or self.trim_right_ratio == 1.0
        ), "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
        assert self.trim_right_ratio >= 0.0 and self.trim_right_ratio <= 1.0

    def _init_streaming_state(self, batch_size: int) -> _StreamingConvTr1dState:
        """Create the (empty) streaming state; only allowed for causal convtrs."""
        assert self.causal, "streaming is only supported for causal convtrs"
        return _StreamingConvTr1dState()

    def forward(self, x):
        """Apply the transposed convolution and, outside streaming, trim fixed padding."""
        kernel_size = self.convtr.convtr.kernel_size[0]
        stride = self.convtr.convtr.stride[0]
        padding_total = kernel_size - stride

        y = self.convtr(x)
        if self.is_streaming:
            # No trimming is performed while streaming.
            return y

        # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
        # removed at the very end, when keeping only the right length for the output,
        # as removing it here would require also passing the length at the matching layer
        # in the encoder.
        if self.causal:
            # Trim the padding on the right according to the specified ratio
            # if trim_right_ratio = 1.0, trim everything from right
            padding_right = math.ceil(padding_total * self.trim_right_ratio)
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
        padding_left = padding_total - padding_right
        return unpad1d(y, (padding_left, padding_right))