import typing as tp

from einops import rearrange
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.checkpoint import checkpoint as torch_checkpoint
from xformers import ops


# Attention backend used by StreamingMultiheadAttention: 'torch' selects
# torch.nn.functional.scaled_dot_product_attention in the memory-efficient path.
_efficient_attention_backend: str = 'torch'


def _get_attention_time_dimension(memory_efficient: bool) -> int:
    # Index of the time axis in attention tensors: 2 for the "b h t d" layout used
    # with the torch SDPA backend, 1 for the "b t h d" layout otherwise.
    if _efficient_attention_backend == 'torch' and memory_efficient:
        return 2
    else:
        return 1


def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000,
                         dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Create sinusoidal positional embedding, with shape `[B, T, C]`.

    Args:
        positions (torch.Tensor): LongTensor of positions.
        dim (int): Dimension of the embedding.
        max_period (float): Maximum period of the cosine/sine functions.
        dtype (torch.dtype or str): dtype to use to generate the embedding.
    Returns:
        torch.Tensor: Sinusoidal positional embedding.
    """
    assert dim % 2 == 0
    half_dim = dim // 2
    positions = positions.to(dtype)
    adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1)
    max_period_tensor = torch.full([], max_period, device=positions.device, dtype=dtype)
    phase = positions / (max_period_tensor ** (adim / (half_dim - 1)))
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
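

# Illustrative only (not part of the original file): a minimal sketch of how
# create_sin_embedding is called, assuming positions shaped [B, T, 1]. The helper
# name `_demo_create_sin_embedding` is hypothetical.
def _demo_create_sin_embedding() -> torch.Tensor:
    positions = torch.arange(16).view(1, -1, 1)       # [B=1, T=16, 1] position indices
    pos_emb = create_sin_embedding(positions, dim=8)  # -> [1, 16, 8]
    assert pos_emb.shape == (1, 16, 8)
    return pos_emb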


def expand_repeated_kv(x: torch.Tensor, n_rep: int, memory_efficient: bool) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep) from xlformers."""
    if n_rep == 1:
        return x
    if _efficient_attention_backend == 'torch' and memory_efficient:
        bs, n_kv_heads, slen, head_dim = x.shape
        return (
            x[:, :, None, :, :]
            .expand(bs, n_kv_heads, n_rep, slen, head_dim)
            .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
        )
    else:
        bs, slen, n_kv_heads, head_dim = x.shape
        return (
            x[:, :, :, None, :]
            .expand(bs, slen, n_kv_heads, n_rep, head_dim)
            .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
        )
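

# Illustrative only (not part of the original file): a minimal check, assuming the
# non-memory-efficient "b t h d" layout, that expand_repeated_kv matches
# torch.repeat_interleave over the head dimension. The helper name
# `_demo_expand_repeated_kv` is hypothetical.
def _demo_expand_repeated_kv() -> None:
    x = torch.randn(2, 5, 4, 8)  # [B, T, n_kv_heads, head_dim]
    expanded = expand_repeated_kv(x, n_rep=3, memory_efficient=False)
    reference = torch.repeat_interleave(x, repeats=3, dim=2)
    assert torch.equal(expanded, reference)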


class StreamingMultiheadAttention(nn.Module):

    def __init__(self, embed_dim, num_heads, dropout: float = 0.0, bias: bool = True,
                 causal: bool = False, past_context: tp.Optional[int] = None, custom: bool = False,
                 memory_efficient: bool = False, attention_as_float32: bool = False,
                 cross_attention: bool = False, kv_repeat: int = 1,
                 device=None, dtype=None):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        if past_context is not None:
            assert causal

        self.embed_dim = embed_dim
        # Streaming KV cache; grows on every self-attention call until reset to None.
        self.k_history = None
        self.v_history = None
        self.memory_efficient = memory_efficient
        self.cross_attention = cross_attention
        self.num_heads = num_heads
        self.dropout = dropout
        self.kv_repeat = kv_repeat

        # This module always uses the custom projection path, regardless of the
        # `custom` argument.
        self.custom = True
        if self.custom:
            out_dim = embed_dim
            assert num_heads % kv_repeat == 0
            assert not cross_attention or kv_repeat == 1
            num_kv = num_heads // kv_repeat
            kv_dim = (embed_dim // num_heads) * num_kv
            out_dim += 2 * kv_dim
            in_proj = nn.Linear(embed_dim, out_dim, bias=bias, **factory_kwargs)
            # Keep only the packed q/k/v projection parameters, not the Linear module itself.
            self.in_proj_weight = in_proj.weight
            self.in_proj_bias = in_proj.bias
            if bias:
                self.in_proj_bias.data.zero_()
            self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
            if bias:
                self.out_proj.bias.data.zero_()
        else:
            assert kv_repeat == 1
            self.mha = nn.MultiheadAttention(
                embed_dim, num_heads, dropout=dropout, bias=bias, batch_first=True,
                **factory_kwargs)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        if not self.custom:
            # Remap flat parameter names from custom-layout checkpoints onto the
            # nested nn.MultiheadAttention ("mha.") parameter names.
            keys = [n for n, _ in self.mha.named_parameters()]
            for key in keys:
                if prefix + key in state_dict:
                    state_dict[prefix + "mha." + key] = state_dict.pop(prefix + key)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def forward(self, query, key=None, value=None):
        # Attention tensors are kept in the "b h t d" layout expected by torch SDPA.
        layout = "b h t d"

        if self.custom:
            if self.cross_attention:
                # Cross-attention requires kv_repeat == 1, so the packed projection
                # weight is exactly [q; k; v] along dim 0.
                dim = self.in_proj_weight.shape[0] // 3
                if self.in_proj_bias is None:
                    bias_q, bias_k, bias_v = None, None, None
                else:
                    bias_q = self.in_proj_bias[:dim]
                    bias_k = self.in_proj_bias[dim: 2 * dim]
                    bias_v = self.in_proj_bias[2 * dim:]
                q = nn.functional.linear(query, self.in_proj_weight[:dim], bias_q)
                k = nn.functional.linear(key, self.in_proj_weight[dim: 2 * dim], bias_k)
                v = nn.functional.linear(value, self.in_proj_weight[2 * dim:], bias_v)
                q, k, v = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
            else:
                # Self-attention: a single packed projection of the query into q/k/v.
                projected = nn.functional.linear(query, self.in_proj_weight, self.in_proj_bias)
                if self.kv_repeat == 1:
                    bound_layout = "b h p t d"
                    packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
                    q, k, v = ops.unbind(packed, dim=2)

                # Streaming KV cache: append new keys/values along the time dimension.
                if self.k_history is not None:
                    self.k_history = torch.cat([self.k_history, k], 2)
                    self.v_history = torch.cat([self.v_history, v], 2)
                else:
                    # First step of the stream: initialize the cache.
                    self.k_history = k
                    self.v_history = v
                k = self.k_history
                v = self.v_history

            # Only the memory-efficient torch SDPA path is implemented here.
            if self.memory_efficient:
                p = self.dropout if self.training else 0
                if _efficient_attention_backend == 'torch':
                    x = torch.nn.functional.scaled_dot_product_attention(
                        q, k, v, is_causal=False, dropout_p=p)

            x = x.to(q.dtype)
            x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
            x = self.out_proj(x)
        return x
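

# Illustrative only (not part of the original file): a minimal sketch of streaming
# self-attention with the KV cache, assuming memory_efficient=True and the default
# 'torch' backend. The helper name `_demo_streaming_attention` is hypothetical.
def _demo_streaming_attention() -> torch.Tensor:
    attn = StreamingMultiheadAttention(embed_dim=64, num_heads=4, memory_efficient=True)
    attn.eval()
    first = attn(torch.randn(1, 10, 64))    # fills k_history/v_history with 10 steps
    nxt = attn(torch.randn(1, 1, 64))       # the new step attends over all 11 cached steps
    assert first.shape == (1, 10, 64) and nxt.shape == (1, 1, 64)
    attn.k_history = attn.v_history = None  # reset the cache between sequences
    return nxt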


class StreamingTransformerLayer(nn.Module):

    def __init__(self, d_model: int, num_heads: int, dim_feedforward: int = 2048,
                 dropout: float = 0.1, bias_ff: bool = True, bias_attn: bool = True,
                 custom: bool = False, memory_efficient: bool = False,
                 attention_as_float32: bool = False, cross_attention: bool = False,
                 attention_dropout: tp.Optional[float] = None, kv_repeat: int = 1,
                 norm: str = 'layer_norm', device=None, dtype=None, **kwargs):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        attn_kwargs: tp.Dict[str, tp.Any] = {
            'embed_dim': d_model,
            'num_heads': num_heads,
            'dropout': dropout if attention_dropout is None else attention_dropout,
            'bias': bias_attn,
            'custom': custom,
            'memory_efficient': memory_efficient,
            'attention_as_float32': attention_as_float32,
        }
        self.self_attn = StreamingMultiheadAttention(
            kv_repeat=kv_repeat, **attn_kwargs, **factory_kwargs)
        self.linear1 = nn.Linear(d_model, dim_feedforward, bias=bias_ff, **factory_kwargs)
        self.linear2 = nn.Linear(dim_feedforward, d_model, bias=bias_ff, **factory_kwargs)

        self.cross_attention = None
        if cross_attention:
            self.cross_attention = StreamingMultiheadAttention(
                cross_attention=True, **attn_kwargs, **factory_kwargs)
            self.dropout_cross = nn.Dropout(dropout)
            self.norm_cross = nn.LayerNorm(d_model, eps=1e-5, **factory_kwargs)
        self.norm1 = nn.LayerNorm(d_model, eps=1e-5)
        self.norm2 = nn.LayerNorm(d_model, eps=1e-5)

    def forward(self, src, cross_attention_src=None):
        # NOTE: the pretrained weights may be saved in float16; it is left open
        # whether `src` should be cast to float16 here.
        x = src
        x = x + self.self_attn(self.norm1(x))
        if cross_attention_src is not None:
            x = x + self.cross_attention(
                query=self.norm_cross(x),
                key=cross_attention_src,
                value=cross_attention_src)
        x = x + self.linear2(F.gelu(self.linear1(self.norm2(x))))
        return x
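

# Illustrative only (not part of the original file): a minimal sketch of a single
# pre-norm layer forward pass, assuming memory_efficient=True so the torch SDPA path
# is taken. The helper name `_demo_streaming_layer` is hypothetical.
def _demo_streaming_layer() -> torch.Tensor:
    layer = StreamingTransformerLayer(d_model=64, num_heads=4, dim_feedforward=128,
                                      memory_efficient=True)
    layer.eval()
    out = layer(torch.randn(1, 12, 64))
    assert out.shape == (1, 12, 64)
    return out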


class StreamingTransformer(nn.Module):

    def __init__(self, d_model: int, num_heads: int, num_layers: int,
                 dim_feedforward: int = 2048, dropout: float = 0.1,
                 bias_ff: bool = True, bias_attn: bool = True,
                 custom: bool = False, memory_efficient: bool = False,
                 attention_as_float32: bool = False, cross_attention: bool = False,
                 positional_embedding: str = 'sin', max_period: float = 10_000,
                 layer_class=StreamingTransformerLayer, checkpointing: str = 'none',
                 device=None, dtype=None, **kwargs):
        super().__init__()
        assert d_model % num_heads == 0

        self.positional_embedding = positional_embedding
        self.max_period = max_period
        self.checkpointing = checkpointing

        self.layers = nn.ModuleList()
        for idx in range(num_layers):
            self.layers.append(
                layer_class(
                    d_model=d_model, num_heads=num_heads, dim_feedforward=dim_feedforward,
                    dropout=dropout, bias_ff=bias_ff, bias_attn=bias_attn,
                    custom=custom,
                    memory_efficient=memory_efficient, attention_as_float32=attention_as_float32,
                    cross_attention=cross_attention,
                    device=device, dtype=dtype, **kwargs))

        if self.checkpointing != 'none':
            for layer in self.layers:
                # Flag consumed elsewhere to mark these layers for activation checkpointing.
                layer._magma_checkpointed = True

    def forward(self, x: torch.Tensor, *args, **kwargs):
        # Expects `token_count` (position offset of the current streaming step) and
        # `cross_attention_src` to be passed as keyword arguments.
        B, T, C = x.shape

        if self.positional_embedding in ['sin', 'sin_rope']:
            positions = torch.arange(T, device=x.device).view(1, -1, 1)
            positions = positions + kwargs['token_count']
            pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
            x = x + pos_emb

        for j, lay in enumerate(self.layers):
            x = lay(x, cross_attention_src=kwargs["cross_attention_src"])
        return x
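

# Illustrative only (not part of the original file): a minimal sketch of a full forward
# pass, assuming memory_efficient=True and no cross-attention source. The required
# kwargs `token_count` and `cross_attention_src` are passed explicitly; the helper
# name `_demo_streaming_transformer` is hypothetical.
def _demo_streaming_transformer() -> torch.Tensor:
    model = StreamingTransformer(d_model=64, num_heads=4, num_layers=2,
                                 dim_feedforward=128, memory_efficient=True)
    model.eval()
    out = model(torch.randn(1, 8, 64), token_count=0, cross_attention_src=None)
    assert out.shape == (1, 8, 64)
    return out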