|
|
|
|
|
|
|
|
|
""" |
|
Transformer model, with streaming support, compatible with CUDA graphs.
|
Optimized for inference. |
|
|
|
See `StreamingTransformer` for more information. |
|
""" |
|
|
|
from contextlib import ExitStack |
|
from dataclasses import dataclass |
|
import typing as tp |
|
|
|
from einops import rearrange |
|
import torch |
|
import torch.nn as nn |
|
from torch.nn import functional as F |
|
|
|
from ..utils.compile import no_compile |
|
from .gating import make_gating |
|
from .rope import RotaryEmbedding |
|
from .streaming import StreamingModule, StreamingContainer |
|
|
|
|
|
class LayerNormF32(nn.LayerNorm): |
|
def forward(self, input: torch.Tensor) -> torch.Tensor: |
|
x_f32 = input.float() |
|
out_f32 = super().forward(x_f32) |
|
return out_f32.to(input.dtype) |
|
|
|
|
|
def _rms_norm( |
|
x: torch.Tensor, |
|
alpha: torch.Tensor, |
|
dtype: tp.Optional[torch.dtype], |
|
eps: float, |
|
): |
|
assert x.dim() == 3, f"RMSNorm expects 3D inputs but got {x.shape}" |
|
x_dtype = x.dtype |
|
if dtype is not None: |
|
x = x.to(dtype) |
|
var = eps + torch.mean(x**2, dim=2, keepdim=True) |
|
y = (x * (alpha.to(var) * torch.rsqrt(var))).to(x_dtype) |
|
return y |
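# `_rms_norm` computes y = alpha * x / sqrt(mean(x**2, dim=-1) + eps): a pure rescaling
# by the root-mean-square of the features, with no centering and no bias, unlike
# `nn.LayerNorm`. The statistics can optionally be computed in a higher precision
# `dtype` while the output keeps the input dtype.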
|
|
|
|
|
class RMSNorm(nn.Module): |
|
def __init__( |
|
self, |
|
dim: int, |
|
eps: float = 1e-5, |
|
dtype: tp.Optional[torch.dtype] = None, |
|
device=None, |
|
): |
|
super().__init__() |
|
self.eps = eps |
|
self.dtype = dtype |
|
self.alpha = nn.Parameter( |
|
torch.full((1, 1, dim), 1.0, requires_grad=True, device=device, dtype=dtype) |
|
) |
|
|
|
def forward(self, x: torch.Tensor): |
|
return _rms_norm(x, self.alpha, self.dtype, self.eps) |
|
|
|
|
|
class LayerScale(nn.Module): |
|
"""Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf). |
|
    This diagonally rescales the residual outputs with a learnt per-channel scale, initialized close to 0.
|
|
|
Args: |
|
channels (int): Number of channels. |
|
init (float): Initial scale. |
|
channel_last (bool): If True, expect `[*, C]` shaped tensors, otherwise, `[*, C, T]`. |
|
device (torch.device or str, optional): Device on which to initialize the module. |
|
dtype (torch.dtype, optional): dtype to use to initialize the module. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
channels: int, |
|
init: float = 1e-4, |
|
channel_last: bool = True, |
|
device=None, |
|
dtype=None, |
|
): |
|
super().__init__() |
|
self.channel_last = channel_last |
|
self.scale = nn.Parameter( |
|
torch.full( |
|
(channels,), init, requires_grad=True, device=device, dtype=dtype |
|
) |
|
) |
|
|
|
def forward(self, x: torch.Tensor): |
|
if self.channel_last: |
|
return self.scale * x |
|
else: |
|
return self.scale[:, None] * x |
|
|
|
|
|
def create_norm_fn(norm_type: str, dim: int, **kwargs) -> nn.Module: |
|
"""Create normalization module for transformer encoder layer. |
|
|
|
Args: |
|
norm_type (str): Normalization method. |
|
dim (int): Dimension of the normalized layer. |
|
**kwargs (dict): Additional parameters for normalization layer. |
|
Returns: |
|
nn.Module: Normalization module. |
|
""" |
|
if norm_type == "layer_norm": |
|
return nn.LayerNorm(dim, eps=1e-5, **kwargs) |
|
elif norm_type == "layer_norm_f32": |
|
kwargs.pop("dtype", None) |
|
return LayerNormF32(dim, eps=1e-8, **kwargs) |
|
elif norm_type in {"rms_norm"}: |
|
return RMSNorm(dim, eps=1e-5, **kwargs) |
|
elif norm_type in {"rms_norm_f32"}: |
|
kwargs.pop("dtype", None) |
|
return RMSNorm(dim, eps=1e-8, dtype=torch.float, **kwargs) |
|
else: |
|
raise ValueError(f"Unknown norm type: {norm_type}") |
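# For instance, `create_norm_fn("rms_norm_f32", 512)` returns an `RMSNorm` whose statistics
# are computed in float32 regardless of the activation dtype, while
# `create_norm_fn("layer_norm", 512)` returns a plain `nn.LayerNorm`.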
|
|
|
|
|
def create_sin_embedding( |
|
positions: torch.Tensor, |
|
dim: int, |
|
max_period: float = 10000, |
|
dtype: torch.dtype = torch.float32, |
|
) -> torch.Tensor: |
|
"""Create sinusoidal positional embedding, with shape `[B, T, C]`. |
|
|
|
Args: |
|
positions (torch.Tensor): LongTensor of positions. |
|
dim (int): Dimension of the embedding. |
|
max_period (float): Maximum period of the cosine/sine functions. |
|
dtype (torch.dtype or str): dtype to use to generate the embedding. |
|
Returns: |
|
torch.Tensor: Sinusoidal positional embedding. |
|
""" |
|
|
|
assert dim % 2 == 0 |
|
half_dim = dim // 2 |
|
positions = positions.to(dtype) |
|
adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1) |
|
max_period_tensor = torch.full( |
|
[], max_period, device=positions.device, dtype=dtype |
|
) |
|
phase = positions / (max_period_tensor ** (adim / (half_dim - 1))) |
|
return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1) |
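# Usage sketch (illustrative values): positions are expected as a `[B, T, 1]` tensor.
#
#     positions = torch.arange(16).view(1, -1, 1)         # [1, 16, 1]
#     pos_emb = create_sin_embedding(positions, dim=512)   # [1, 16, 512]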
|
|
|
|
|
def multi_linear( |
|
num_linear: int, |
|
weight: torch.Tensor, |
|
x: torch.Tensor, |
|
offset: int, |
|
): |
|
"""Utility to apply a multi linear layer to the given input. A multi linear layer |
|
    applies a different set of weights for each time step.
|
|
|
Args: |
|
num_linear (int): Number of possible time steps and so number of linears. |
|
weight (torch.Tensor): Weight tensor, with shape `[num_linear * chout, chin]`. |
|
x (torch.Tensor): Input tensor, with shape `[B, T, C]`. |
|
offset (int): offset for the current time step, in particular for decoding, with |
|
time steps provided one by one. |
|
""" |
|
B, T, C = x.shape |
|
ys = [] |
|
chout, chin = weight.shape |
|
weight = weight.view(num_linear, -1, chin) |
|
for t in range(T): |
|
y = F.linear(x[:, t], weight[t + offset]) |
|
ys.append(y) |
|
out = torch.stack(ys, 1) |
|
return out |
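# Usage sketch (illustrative values): one `[chout, chin]` weight matrix per time step,
# stacked along the first dimension of `weight`.
#
#     num_steps, chin, chout = 4, 8, 8
#     weight = torch.randn(num_steps * chout, chin)
#     x = torch.randn(2, 3, chin)                        # 3 time steps, starting at offset 1
#     y = multi_linear(num_steps, weight, x, offset=1)   # y[:, t] = x[:, t] @ weight_{t + 1}.T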
|
|
|
|
|
def set_attention_context(model: nn.Module, context: tp.Optional[int] = None) -> None: |
|
"""Deactivates or changes the context span (in time steps) in a model. |
|
Args: |
|
model (nn.Module): model over which to look for attentions. |
|
context (int or None): new temporary context value. |
|
|
|
    .. note:: This is not a context manager but a plain function changing the context forever.
|
Initially, it was a context manager, but that led to interesting bugs when using |
|
activation checkpointing, with the context being inconsistent between the forward |
|
and backward. |
|
""" |
|
for module in model.modules(): |
|
if isinstance(module, StreamingMultiheadAttention): |
|
module.context = context |
|
|
|
|
|
class KVCacheResult(tp.NamedTuple): |
|
keys: torch.Tensor |
|
values: torch.Tensor |
|
positions: torch.Tensor |
|
|
|
@staticmethod |
|
def from_kv(keys: torch.Tensor, values: torch.Tensor) -> "KVCacheResult": |
|
B, H, T, D = keys.shape |
|
assert tuple(values.shape[:-1]) == (B, H, T) |
|
positions = torch.arange(T, device=keys.device, dtype=torch.long) |
|
return KVCacheResult(keys, values, positions) |
|
|
|
|
|
class RingKVCache:
    """Efficient streaming KV cache, designed to be compatible with CUDA graphs.
|
|
|
Args: |
|
batch_size (int): Batch size. |
|
num_heads (int): Number of heads in the attention. |
|
        dim_per_head (int): Dimension per head.
        capacity (int): Maximum number of time steps stored in the ring buffer.
|
device (torch.device): Device on which to initialize the cache. |
|
dtype (torch.dtype): dtype to use for the cache. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
batch_size: int, |
|
num_heads: int, |
|
dim_per_head: int, |
|
capacity: int, |
|
device: torch.device = torch.device("cuda"), |
|
dtype: torch.dtype = torch.bfloat16, |
|
): |
|
self.capacity = capacity |
|
self.cache = torch.zeros( |
|
(2, batch_size, num_heads, capacity, dim_per_head), |
|
device=device, |
|
dtype=dtype, |
|
) |
|
self.end_offset = torch.zeros(1, device=device, dtype=torch.long) |
|
|
|
def reset(self): |
|
self.end_offset.zero_() |
|
|
|
def complete(self, k: torch.Tensor, v: torch.Tensor) -> KVCacheResult: |
|
assert k.shape[:-1] == v.shape[:-1], (k.shape, v.shape) |
|
B, H, T, D = k.shape |
|
indexes = torch.arange(T, device=self.end_offset.device, dtype=self.end_offset.dtype) + self.end_offset |
|
indexes = indexes % self.capacity |
|
self.cache[0].index_copy_(2, indexes, k) |
|
self.cache[1].index_copy_(2, indexes, v) |
|
self.end_offset.add_(T) |
|
|
|
keys = self.cache[0] |
|
values = self.cache[1] |
|
|
|
indexes = torch.arange( |
|
self.capacity, device=self.end_offset.device, dtype=torch.long |
|
) |
|
invalid = indexes >= self.end_offset |
|
|
|
end_index = self.end_offset % self.capacity |
|
delta = indexes - end_index |
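        # Recover the absolute position stored in each slot of the ring buffer.
        # Slots with `delta <= 0` (at or before the next write position `end_index`)
        # map to `end_offset + delta`; slots with `delta > 0` come from the previous
        # wrap of the ring and map to `end_offset + delta - capacity`. The slot at
        # `end_index` itself is labelled `end_offset`, i.e. it is treated as a future
        # position and will be masked by the causal attention bias.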
|
|
positions = torch.where( |
|
delta <= 0, |
|
self.end_offset + delta, |
|
self.end_offset + delta - self.capacity, |
|
) |
|
positions = torch.where(invalid, torch.full_like(positions, -1), positions) |
|
|
|
return KVCacheResult(keys, values, positions) |
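# Usage sketch (illustrative values): a ring buffer of capacity 4 on CPU.
#
#     cache = RingKVCache(batch_size=1, num_heads=2, dim_per_head=8, capacity=4,
#                         device=torch.device("cpu"), dtype=torch.float32)
#     k = v = torch.randn(1, 2, 3, 8)         # three new time steps
#     keys, values, positions = cache.complete(k, v)
#     # keys/values have shape [1, 2, 4, 8]; positions is [0, 1, 2, -1]
#     # (-1 marks the slot that has not been written yet).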
|
|
|
|
|
@dataclass |
|
class _MHAState: |
|
kv_cache: RingKVCache |
|
offset: torch.Tensor |
|
offset_cpu: int |
|
|
|
def reset(self): |
|
self.kv_cache.reset() |
|
self.offset.zero_() |
|
self.offset_cpu = 0 |
|
|
|
|
|
class StreamingMultiheadAttention(StreamingModule[_MHAState]): |
|
"""Similar to `nn.MultiheadAttention` but with support for streaming, causal evaluation. |
|
|
|
Args: |
|
embed_dim (int): Dimension to project to. |
|
num_heads (int): Number of heads. |
|
causal (bool): Causal mask applied automatically. |
|
        context (int, optional): Number of time steps the attention can access.
            When causal, it can access `context` steps in the past; when non-causal,
            it can access `context // 2` steps in the past and as many in the future.
        rope (`RotaryEmbedding`, optional): Rope embedding to use.
        weights_per_step (int): Use different weights per time step. If non-zero, should correspond to the
            number of possible time steps.
|
device (torch.device, optional): Device on which to initialize. |
|
dtype (torch.dtype, optional): dtype to use. |
|
""" |
|
|
|
_fsdp_final = True |
|
|
|
def __init__( |
|
self, |
|
embed_dim: int, |
|
num_heads: int, |
|
causal: bool = False, |
|
context: tp.Optional[int] = None, |
|
rope: tp.Optional[RotaryEmbedding] = None, |
|
weights_per_step: int = 0, |
|
device=None, |
|
dtype=None, |
|
): |
|
super().__init__() |
|
factory_kwargs = {"device": device, "dtype": dtype} |
|
|
|
self.embed_dim = embed_dim |
|
self.causal = causal |
|
self.context = context |
|
self.rope = rope |
|
self.num_heads = num_heads |
|
|
|
        # Queries, keys and values are packed along the output dimension of a single projection.
        out_dim = 3 * embed_dim
|
mult = 1 |
|
self.weights_per_step = weights_per_step |
|
if weights_per_step: |
|
mult = weights_per_step |
|
in_proj = nn.Linear(embed_dim, mult * out_dim, bias=False, **factory_kwargs) |
|
|
|
self.in_proj_weight = in_proj.weight |
|
self.in_proj_bias = in_proj.bias |
|
self.out_proj = nn.Linear( |
|
embed_dim, mult * embed_dim, bias=False, **factory_kwargs |
|
) |
|
|
|
def _init_streaming_state(self, batch_size: int) -> _MHAState: |
|
if self.context is None: |
|
if self.weights_per_step: |
|
capacity = self.weights_per_step |
|
else: |
|
raise RuntimeError( |
|
"Cannot create a streaming KVCache without a context to estimate capacity." |
|
) |
|
else: |
|
capacity = self.context |
|
device = self.in_proj_weight.device |
|
|
|
dtype = self.in_proj_weight.dtype |
|
dim_per_head = self.embed_dim // self.num_heads |
|
kv_cache = RingKVCache( |
|
batch_size, self.num_heads, dim_per_head, capacity, device, dtype |
|
) |
|
return _MHAState( |
|
kv_cache, |
|
offset=torch.zeros(1, device=device, dtype=torch.long), |
|
offset_cpu=0, |
|
) |
|
|
|
def _complete_kv(self, k, v) -> KVCacheResult: |
|
state = self._streaming_state |
|
if state is None: |
|
return KVCacheResult.from_kv(k, v) |
|
else: |
|
return state.kv_cache.complete(k, v) |
|
|
|
def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor): |
|
state = self._streaming_state |
|
T = query.shape[1] |
|
|
|
if state is None: |
|
offset = torch.zeros(1, device=query.device, dtype=torch.long) |
|
offset_cpu = 0 |
|
else: |
|
assert self.causal, "Streaming only available for causal" |
|
offset = state.offset |
|
offset_cpu = state.offset_cpu |
|
|
|
if self.weights_per_step: |
|
projected = multi_linear( |
|
self.weights_per_step, self.in_proj_weight, query, offset_cpu |
|
) |
|
else: |
|
projected = nn.functional.linear(query, self.in_proj_weight) |
|
q, k, v = rearrange( |
|
projected, "b t (p h d) -> p b h t d", p=3, h=self.num_heads |
|
) |
|
|
|
if self.rope: |
|
q, k = self.rope(q, k, offset, time_before_heads=False) |
|
|
|
k, v, pos_k = self._complete_kv(k, v) |
|
if self.causal: |
|
pos_k = pos_k.view(1, -1) |
|
pos_q = offset + torch.arange(T, device=q.device, dtype=torch.long).view( |
|
-1, 1 |
|
) |
|
delta = pos_q - pos_k |
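            # Build a boolean attention bias: `pos_k >= 0` discards cache slots that were
            # never written, `delta >= 0` enforces causality (keys cannot be in the future
            # of the query), and `delta < context` limits the attention span when a finite
            # context is configured.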
|
attn_bias = (pos_k >= 0) & (delta >= 0) |
|
if self.context is not None: |
|
attn_bias = attn_bias & (delta < self.context) |
|
else: |
|
attn_bias = None |
|
x = F.scaled_dot_product_attention(q, k, v, attn_bias, dropout_p=0.0) |
|
|
|
x = rearrange(x, "b h t d -> b t (h d)") |
|
if self.weights_per_step: |
|
x = multi_linear(self.weights_per_step, self.out_proj.weight, x, offset_cpu) |
|
else: |
|
x = self.out_proj(x) |
|
if state is not None: |
|
state.offset.add_(T) |
|
state.offset_cpu += T |
|
return x |
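# Usage sketch (illustrative, non-streaming): causal self-attention over a [B, T, C] tensor.
# Note that only `query` is projected into queries, keys and values; the `key` and `value`
# arguments are currently unused, so `_sa_block` below simply passes `x` three times.
#
#     mha = StreamingMultiheadAttention(embed_dim=512, num_heads=8, causal=True)
#     x = torch.randn(2, 16, 512)
#     y = mha(x, x, x)                        # -> [2, 16, 512]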
|
|
|
|
|
@dataclass |
|
class _LayerState: |
|
offset_cpu: int |
|
|
|
def reset(self): |
|
self.offset_cpu = 0 |
|
|
|
|
|
class StreamingTransformerLayer(StreamingModule[_LayerState]): |
|
"""TransformerLayer with Streaming / Causal support. |
|
|
|
Args: |
|
d_model (int): Dimension of the data. |
|
num_heads (int): Number of heads. |
|
dim_feedforward (int): Intermediate dimension of FF module. |
|
causal (bool): Causal mask applied automatically. |
|
context (int, optional): Receptive field for the causal mask, infinite if None. |
|
        rope (`RotaryEmbedding`, optional): Rope embedding to use.
        norm (str): Normalization to use; one of 'layer_norm', 'layer_norm_f32',
            'rms_norm' or 'rms_norm_f32' (see `create_norm_fn`).
        layer_scale (float, optional): If not None, LayerScale will be used with the given value as initial scale.
        gating (str): If not "none", replaces the FFN with a special gating module, like GLU, GSiGLU etc.
        weights_per_step (int): Use different weights per time step. If non-zero, should correspond to the
            number of possible time steps.
        activation: Activation function of the feed-forward network, used only when `gating` is "none".
        skip_self_attn (bool): If True, skips the self-attention block and its norm.
|
device (torch.device, optional): Device on which to initialize. |
|
dtype (torch.dtype, optional): dtype to use. |
|
""" |
|
|
|
_fsdp_final = True |
|
|
|
def __init__( |
|
self, |
|
d_model: int, |
|
num_heads: int, |
|
dim_feedforward: int | list[int] = 2048, |
|
causal: bool = False, |
|
context: tp.Optional[int] = None, |
|
rope: tp.Optional[RotaryEmbedding] = None, |
|
norm: str = "layer_norm", |
|
layer_scale: tp.Optional[float] = None, |
|
gating: str = "none", |
|
weights_per_step: int = 0, |
|
activation=F.gelu, |
|
skip_self_attn: bool = False, |
|
device=None, |
|
dtype=None, |
|
): |
|
super().__init__() |
|
factory_kwargs = {"device": device, "dtype": dtype} |
|
|
|
attn_kwargs: tp.Dict[str, tp.Any] = { |
|
"embed_dim": d_model, |
|
"num_heads": num_heads, |
|
} |
|
if not skip_self_attn: |
|
self.self_attn: StreamingMultiheadAttention = StreamingMultiheadAttention( |
|
causal=causal, |
|
context=context, |
|
rope=rope, |
|
weights_per_step=weights_per_step, |
|
**attn_kwargs, |
|
**factory_kwargs, |
|
) |
|
self.norm1 = create_norm_fn(norm, d_model, **factory_kwargs) |
|
self.norm2 = create_norm_fn(norm, d_model, **factory_kwargs) |
|
|
|
self.weights_per_step = weights_per_step |
|
self.gating: tp.Optional[nn.Module] = None |
|
self.linear1: tp.Optional[nn.Module] = None |
|
self.linear2: tp.Optional[nn.Module] = None |
|
self.activation = activation |
|
self.skip_self_attn = skip_self_attn |
|
|
|
if isinstance(dim_feedforward, list): |
|
assert dim_feedforward |
|
assert len(dim_feedforward) == weights_per_step, ( |
|
"Length of dim_feedforward must match weights_per_step," |
|
f" got {len(dim_feedforward)} != {weights_per_step}" |
|
) |
|
if gating == "none": |
|
assert ( |
|
not weights_per_step |
|
), "weights_per_step without gating not supported for now." |
|
assert not isinstance( |
|
dim_feedforward, list |
|
), "List dim_feedforward without gating not supported for now." |
|
self.linear1 = nn.Linear( |
|
d_model, dim_feedforward, bias=False, **factory_kwargs |
|
) |
|
self.linear2 = nn.Linear( |
|
dim_feedforward, d_model, bias=False, **factory_kwargs |
|
) |
|
else: |
|
self.linear1 = None |
|
self.linear2 = None |
|
if weights_per_step: |
|
if isinstance(dim_feedforward, int): |
|
dim_feedforward = [dim_feedforward] * weights_per_step |
|
assert isinstance(dim_feedforward, list), dim_feedforward |
|
self.gating = nn.ModuleList( |
|
[ |
|
make_gating(gating, d_model, dim, **factory_kwargs) |
|
for dim in dim_feedforward |
|
] |
|
) |
|
else: |
|
assert isinstance(dim_feedforward, int) |
|
self.gating = make_gating( |
|
gating, d_model, dim_feedforward, **factory_kwargs |
|
) |
|
|
|
self.layer_scale_1: nn.Module |
|
self.layer_scale_2: nn.Module |
|
if layer_scale is None: |
|
self.layer_scale_1 = nn.Identity() |
|
self.layer_scale_2 = nn.Identity() |
|
else: |
|
self.layer_scale_1 = LayerScale(d_model, layer_scale, **factory_kwargs) |
|
self.layer_scale_2 = LayerScale(d_model, layer_scale, **factory_kwargs) |
|
|
|
def _init_streaming_state(self, batch_size: int) -> _LayerState: |
|
return _LayerState(offset_cpu=0) |
|
|
|
|
|
def _ff_block(self, x: torch.Tensor) -> torch.Tensor: |
|
state = self._streaming_state |
|
offset = 0 |
|
if state is not None: |
|
offset = state.offset_cpu |
|
x_orig = x |
|
x = self.norm2(x) |
|
if self.gating is None: |
|
assert self.linear1 is not None |
|
assert self.linear2 is not None |
|
update = self.linear2(self.activation(self.linear1(x))) |
|
else: |
|
if self.weights_per_step: |
|
assert isinstance(self.gating, nn.ModuleList) |
|
B, T, D = x.shape |
|
ys = [] |
|
for t in range(T): |
|
y = self.gating[offset + t](x[:, t : t + 1]) |
|
ys.append(y) |
|
update = torch.cat(ys, dim=1) |
|
else: |
|
update = self.gating(x) |
|
return x_orig + self.layer_scale_2(update) |
|
|
|
def _sa_block(self, x: torch.Tensor): |
|
if self.skip_self_attn: |
|
return x |
|
x_orig = x |
|
x = self.norm1(x) |
|
update = self.self_attn(x, x, x) |
|
return x_orig + self.layer_scale_1(update) |
|
|
|
def forward(self, x: torch.Tensor): |
|
with ExitStack() as stack: |
|
if x.device.type != 'cuda': |
|
stack.enter_context(no_compile()) |
|
x = self._sa_block(x) |
|
x = self._ff_block(x) |
|
state = self._streaming_state |
|
if state: |
|
state.offset_cpu += x.shape[1] |
|
return x |
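# `StreamingTransformerLayer` uses a pre-norm residual structure:
#     x = x + layer_scale_1(self_attn(norm1(x)))      # skipped when `skip_self_attn` is set
#     x = x + layer_scale_2(ff_or_gating(norm2(x)))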
|
|
|
|
|
@dataclass |
|
class _TransformerState: |
|
offset: torch.Tensor |
|
|
|
def reset(self): |
|
self.offset.zero_() |
|
|
|
|
|
class StreamingTransformer(StreamingModule[_TransformerState]): |
|
"""Transformer with Streaming / Causal support. |
|
|
|
Args: |
|
d_model (int): Dimension of the data. |
|
        num_heads (int): Number of heads.
        num_layers (int): Number of transformer layers.
|
dim_feedforward (int): Intermediate dimension of FF module. |
|
causal (bool): Causal mask applied automatically. |
|
context (int, optional): Receptive field for the causal mask, infinite if None. |
|
layer_scale (float, optional): If not None, LayerScale will be used |
|
with the given value as initial scale. |
|
positional_embedding (str): Positional embedding strategy (sin, rope, sin_rope, or none). |
|
max_period (float): Maximum period of the time embedding. |
|
positional_scale (float): Scale of positional embedding, set to 0 to deactivate. |
|
        layer_class (subclass of `StreamingTransformerLayer`): Class to use
            to initialize the layers, allowing further customization outside of AudioCraft.
|
device (torch.device, optional): Device on which to initialize. |
|
dtype (torch.dtype, optional): dtype to use. |
|
**kwargs: See `StreamingTransformerLayer`. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
d_model: int, |
|
num_heads: int, |
|
num_layers: int, |
|
dim_feedforward: int | list[int] = 2048, |
|
causal: bool = False, |
|
context: tp.Optional[int] = None, |
|
positional_embedding: str = "sin", |
|
max_period: float = 10_000, |
|
positional_scale: float = 1.0, |
|
betas: tp.Optional[tp.Tuple[float, float]] = None, |
|
layer_class: tp.Type[StreamingTransformerLayer] = StreamingTransformerLayer, |
|
device=None, |
|
dtype=None, |
|
**kwargs, |
|
): |
|
super().__init__() |
|
assert d_model % num_heads == 0 |
|
|
|
self.positional_embedding = positional_embedding |
|
self.max_period = max_period |
|
self.positional_scale = positional_scale |
|
self.betas = betas |
|
|
|
assert positional_embedding in {"sin", "rope", "sin_rope", "none"} |
|
self.rope: tp.Optional[RotaryEmbedding] = None |
|
if self.positional_embedding in {"rope", "sin_rope"}: |
|
self.rope = RotaryEmbedding(max_period=max_period) |
|
|
|
self.layers = nn.ModuleList() |
|
for _ in range(num_layers): |
|
self.layers.append( |
|
layer_class( |
|
d_model=d_model, |
|
num_heads=num_heads, |
|
dim_feedforward=dim_feedforward, |
|
causal=causal, |
|
context=context, |
|
rope=self.rope, |
|
device=device, |
|
dtype=dtype, |
|
**kwargs, |
|
) |
|
) |
|
|
|
def _init_streaming_state(self, batch_size: int) -> _TransformerState: |
|
device = next(self.parameters()).device |
|
return _TransformerState(offset=torch.zeros(1, device=device, dtype=torch.long)) |
|
|
|
def forward(self, x: torch.Tensor, *args, **kwargs): |
|
B, T, C = x.shape |
|
|
|
state = self._streaming_state |
|
if state is None: |
|
offset = torch.zeros(1, dtype=torch.long, device=x.device) |
|
else: |
|
offset = state.offset |
|
|
|
if self.positional_embedding in {"sin", "sin_rope"}: |
|
positions = torch.arange(T, device=x.device).view(1, -1, 1) |
|
positions = positions + offset.view(-1, 1, 1) |
|
pos_emb = create_sin_embedding( |
|
positions, C, max_period=self.max_period, dtype=x.dtype |
|
) |
|
x = x + self.positional_scale * pos_emb |
|
|
|
for layer in self.layers: |
|
x = layer(x, *args, **kwargs) |
|
|
|
if state is not None: |
|
state.offset.add_(T) |
|
return x |
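# Usage sketch (illustrative values): a small causal transformer over [B, T, C] inputs.
#
#     model = StreamingTransformer(
#         d_model=256, num_heads=4, num_layers=2,
#         causal=True, context=64, positional_embedding="sin",
#     )
#     x = torch.randn(2, 16, 256)
#     y = model(x)                            # -> [2, 16, 256]
#
# In streaming mode (see `StreamingModule`), the per-layer KV caches and offsets persist
# across calls, so the same model can be fed chunks of a sequence one at a time.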
|
|
|
|
|
class ProjectedTransformer(StreamingContainer): |
|
"""Transformer with optional projections of the input and output to different dimensions when needed. |
|
Supports multiple outputs. |
|
|
|
Args: |
|
input_dimension (int): dimension of the input. |
|
output_dimensions (tuple[int]): dimensions of the outputs. |
|
d_model (int): inner dimension of the Transformer. |
|
conv_layout (bool): If True, expects `[B, C, T]` shaped tensors, otherwise, `[B, T, C]`. |
|
Similarly, the output will have the same layout. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
input_dimension: int, |
|
output_dimensions: tp.Tuple[int, ...], |
|
d_model: int, |
|
*, |
|
conv_layout: bool = False, |
|
**kwargs, |
|
): |
|
super().__init__() |
|
self.transformer = StreamingTransformer(d_model=d_model, **kwargs) |
|
self.input_dimension = input_dimension |
|
self.output_dimensions = output_dimensions |
|
self.conv_layout = conv_layout |
|
self.input_proj = None |
|
if d_model != input_dimension: |
|
self.input_proj = nn.Linear(input_dimension, d_model, bias=False) |
|
|
|
self.output_projs = nn.ModuleList() |
|
for output_dimension in output_dimensions: |
|
if d_model == output_dimension: |
|
self.output_projs.append(nn.Identity()) |
|
else: |
|
self.output_projs.append( |
|
nn.Linear(d_model, output_dimension, bias=False) |
|
) |
|
|
|
def forward(self, x, *args, **kwargs): |
|
if self.conv_layout: |
|
x = x.transpose(1, 2) |
|
if self.input_proj is not None: |
|
x = self.input_proj(x) |
|
z = self.transformer(x, *args, **kwargs) |
|
ys = [] |
|
for output_proj in self.output_projs: |
|
y = output_proj(z) |
|
if self.conv_layout: |
|
y = y.transpose(1, 2) |
|
ys.append(y) |
|
return ys |
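# Usage sketch (illustrative values): project [B, C, T] convolutional features through the
# transformer and back out to two different dimensions.
#
#     proj = ProjectedTransformer(
#         input_dimension=128, output_dimensions=(128, 64), d_model=256,
#         conv_layout=True, num_heads=4, num_layers=2, causal=True, context=32,
#     )
#     x = torch.randn(2, 128, 50)             # [B, C, T]
#     y_full, y_small = proj(x)               # [2, 128, 50] and [2, 64, 50]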
|
|