import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange


def identity(t, *args, **kwargs):
    return t


def append_dims(x, num_dims):
    if num_dims <= 0:
        return x
    return x.view(*x.shape, *((1,) * num_dims))


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


def padding_to_multiple_of(n, mult):
    remainder = n % mult
    if remainder == 0:
        return 0
    return mult - remainder
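# A minimal sketch, not part of the original file, showing how the helpers above
# behave; the function name _demo_helpers and the sample values are illustrative only.
def _demo_helpers():
    # padding_to_multiple_of: elements needed so a length-250 sequence divides
    # evenly into groups of 256.
    assert padding_to_multiple_of(250, 256) == 6
    assert padding_to_multiple_of(512, 256) == 0
    # append_dims: append trailing singleton dimensions for broadcasting.
    t = torch.ones(2, 3)
    assert append_dims(t, 2).shape == (2, 3, 1, 1)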
class Transpose(nn.Module):
    """Wrapper around torch.transpose() so it can be used inside nn.Sequential."""

    def __init__(self, shape: tuple):
        super(Transpose, self).__init__()
        self.shape = shape

    def forward(self, x):
        return x.transpose(*self.shape)
class DepthwiseConv1d(nn.Module):
    """
    When groups == in_channels and out_channels == K * in_channels, where K is a positive integer,
    this operation is known in the literature as a depthwise convolution.

    Args:
        in_channels (int): Number of channels in the input
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        bias (bool, optional): If True, adds a learnable bias to the output. Default: False

    Inputs: inputs
        - **inputs** (batch, in_channels, time): Tensor containing the input sequence

    Returns: outputs
        - **outputs** (batch, out_channels, time): Tensor produced by the depthwise 1-D convolution.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = False,
    ) -> None:
        super(DepthwiseConv1d, self).__init__()
        assert (
            out_channels % in_channels == 0
        ), "out_channels should be a constant multiple of in_channels"
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            groups=in_channels,
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, inputs):
        return self.conv(inputs)
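# A minimal usage sketch (shapes assumed for illustration, not from the original
# file): with out_channels == in_channels each channel is filtered independently,
# and 'SAME' padding preserves the time axis.
def _demo_depthwise_conv():
    conv = DepthwiseConv1d(in_channels=64, out_channels=64, kernel_size=17,
                           padding=(17 - 1) // 2)
    x = torch.randn(2, 64, 100)    # (batch, in_channels, time)
    y = conv(x)
    assert y.shape == (2, 64, 100)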
class ConvModule(nn.Module):
    """
    Simplified Conformer-style convolution module: a single depthwise 1-D
    convolution with 'SAME' padding, applied over the time axis and added back
    to the input as a residual connection (the pointwise convolutions, GLU and
    batch norm of the full Conformer block are not used here).

    Args:
        in_channels (int): Number of channels in the input
        kernel_size (int or tuple, optional): Size of the convolving kernel. Default: 17
        dropout_p (float, optional): probability of dropout

    Inputs: inputs
        inputs (batch, time, dim): Tensor containing the input sequence

    Outputs: outputs
        outputs (batch, time, dim): Tensor produced by the convolution module.
    """

    def __init__(
        self,
        in_channels: int,
        kernel_size: int = 17,
        expansion_factor: int = 2,
        dropout_p: float = 0.1,
    ) -> None:
        super(ConvModule, self).__init__()
        assert (
            kernel_size - 1
        ) % 2 == 0, "kernel_size should be an odd number for 'SAME' padding"
        assert expansion_factor == 2, "Currently, only supports expansion_factor 2"
        self.sequential = nn.Sequential(
            Transpose(shape=(1, 2)),
            DepthwiseConv1d(
                in_channels,
                in_channels,
                kernel_size,
                stride=1,
                padding=(kernel_size - 1) // 2,
            ),
        )

    def forward(self, inputs):
        return inputs + self.sequential(inputs).transpose(1, 2)
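# A minimal usage sketch (shapes assumed): ConvModule takes (batch, time, dim),
# transposes internally for the 1-D convolution, and returns the same shape
# because of the residual connection.
def _demo_conv_module():
    module = ConvModule(in_channels=64, kernel_size=17)
    x = torch.randn(2, 100, 64)    # (batch, time, dim)
    assert module(x).shape == x.shape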
class OffsetScale(nn.Module):
    def __init__(self, dim, heads=1):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(heads, dim))
        self.beta = nn.Parameter(torch.zeros(heads, dim))
        nn.init.normal_(self.gamma, std=0.02)

    def forward(self, x):
        out = einsum("... d, h d -> ... h d", x, self.gamma) + self.beta
        return out.unbind(dim=-2)
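# A minimal usage sketch (dimensions assumed): OffsetScale learns a per-head
# scale (gamma) and offset (beta) over the feature dimension and unbinds the
# head axis, which is how FLASH_ShareA_FFConvM below derives quad_q, lin_q,
# quad_k and lin_k from one shared projection.
def _demo_offset_scale():
    offset_scale = OffsetScale(dim=128, heads=4)
    qk = torch.randn(2, 100, 128)    # (batch, time, query_key_dim)
    quad_q, lin_q, quad_k, lin_k = offset_scale(qk)
    assert quad_q.shape == (2, 100, 128)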
class FFConvM(nn.Module):
    def __init__(self, dim_in, dim_out, norm_klass=nn.LayerNorm, dropout=0.1):
        super().__init__()
        self.mdl = nn.Sequential(
            norm_klass(dim_in),
            nn.Linear(dim_in, dim_out),
            nn.SiLU(),
            ConvModule(dim_out),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        output = self.mdl(x)
        return output
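# A minimal usage sketch (shapes assumed): FFConvM is the norm -> linear -> SiLU
# -> ConvModule -> dropout block used for every projection in the attention
# module below.
def _demo_ffconvm():
    ff = FFConvM(dim_in=64, dim_out=128)
    x = torch.randn(2, 100, 64)    # (batch, time, dim_in)
    assert ff(x).shape == (2, 100, 128)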
class FLASH_ShareA_FFConvM(nn.Module):
    def __init__(
        self,
        *,
        dim,
        group_size=256,
        query_key_dim=128,
        expansion_factor=1.0,
        causal=False,
        dropout=0.1,
        rotary_pos_emb=None,
        norm_klass=nn.LayerNorm,
        shift_tokens=True,
    ):
        super().__init__()
        hidden_dim = int(dim * expansion_factor)
        self.group_size = group_size
        self.causal = causal
        self.shift_tokens = shift_tokens
        # positional embeddings
        self.rotary_pos_emb = rotary_pos_emb
        # dropout
        self.dropout = nn.Dropout(dropout)
        # projections
        self.to_hidden = FFConvM(
            dim_in=dim,
            dim_out=hidden_dim,
            norm_klass=norm_klass,
            dropout=dropout,
        )
        self.to_qk = FFConvM(
            dim_in=dim,
            dim_out=query_key_dim,
            norm_klass=norm_klass,
            dropout=dropout,
        )
        self.qk_offset_scale = OffsetScale(query_key_dim, heads=4)
        # note: dim * 2 here is the width of v and u (hidden_dim / 2), which
        # requires expansion_factor == 4
        self.to_out = FFConvM(
            dim_in=dim * 2,
            dim_out=dim,
            norm_klass=norm_klass,
            dropout=dropout,
        )
        self.gateActivate = nn.Sigmoid()

    def forward(self, x, *, mask=None):
        """
        b - batch
        n - sequence length (within groups)
        g - group dimension
        d - feature dimension (keys)
        e - feature dimension (values)
        i - sequence dimension (source)
        j - sequence dimension (target)
        """
        normed_x = x
        # do token shift - a great, costless trick from an independent AI researcher in Shenzhen
        residual = x
        if self.shift_tokens:
            x_shift, x_pass = normed_x.chunk(2, dim=-1)
            x_shift = F.pad(x_shift, (0, 0, 1, -1), value=0.0)
            normed_x = torch.cat((x_shift, x_pass), dim=-1)
        # initial projections
        v, u = self.to_hidden(normed_x).chunk(2, dim=-1)
        qk = self.to_qk(normed_x)
        # offset and scale
        quad_q, lin_q, quad_k, lin_k = self.qk_offset_scale(qk)
        # forward the padding mask so it is actually applied in cal_attention
        att_v, att_u = self.cal_attention(x, quad_q, lin_q, quad_k, lin_k, v, u, mask=mask)
        out = (att_u * v) * self.gateActivate(att_v * u)
        x = x + self.to_out(out)
        return x
    def cal_attention(self, x, quad_q, lin_q, quad_k, lin_k, v, u, mask=None):
        b, n, device, g = x.shape[0], x.shape[-2], x.device, self.group_size
        if exists(mask):
            lin_mask = rearrange(mask, "... -> ... 1")
            lin_k = lin_k.masked_fill(~lin_mask, 0.0)
        # rotate queries and keys
        if exists(self.rotary_pos_emb):
            quad_q, lin_q, quad_k, lin_k = map(
                self.rotary_pos_emb.rotate_queries_or_keys,
                (quad_q, lin_q, quad_k, lin_k),
            )
        # padding for groups
        padding = padding_to_multiple_of(n, g)
        if padding > 0:
            quad_q, quad_k, lin_q, lin_k, v, u = map(
                lambda t: F.pad(t, (0, 0, 0, padding), value=0.0),
                (quad_q, quad_k, lin_q, lin_k, v, u),
            )
            mask = default(mask, torch.ones((b, n), device=device, dtype=torch.bool))
            mask = F.pad(mask, (0, padding), value=False)
        # group along sequence
        quad_q, quad_k, lin_q, lin_k, v, u = map(
            lambda t: rearrange(t, "b (g n) d -> b g n d", n=self.group_size),
            (quad_q, quad_k, lin_q, lin_k, v, u),
        )
        if exists(mask):
            mask = rearrange(mask, "b (g j) -> b g 1 j", j=g)
        # calculate quadratic attention output
        sim = einsum("... i d, ... j d -> ... i j", quad_q, quad_k) / g
        attn = F.relu(sim) ** 2
        attn = self.dropout(attn)
        if exists(mask):
            attn = attn.masked_fill(~mask, 0.0)
        if self.causal:
            causal_mask = torch.ones((g, g), dtype=torch.bool, device=device).triu(1)
            attn = attn.masked_fill(causal_mask, 0.0)
        quad_out_v = einsum("... i j, ... j d -> ... i d", attn, v)
        quad_out_u = einsum("... i j, ... j d -> ... i d", attn, u)
        # calculate linear attention output
        if self.causal:
            lin_kv = einsum("b g n d, b g n e -> b g d e", lin_k, v) / g
            # exclusive cumulative sum along group dimension
            lin_kv = lin_kv.cumsum(dim=1)
            lin_kv = F.pad(lin_kv, (0, 0, 0, 0, 1, -1), value=0.0)
            lin_out_v = einsum("b g d e, b g n d -> b g n e", lin_kv, lin_q)
            lin_ku = einsum("b g n d, b g n e -> b g d e", lin_k, u) / g
            # exclusive cumulative sum along group dimension
            lin_ku = lin_ku.cumsum(dim=1)
            lin_ku = F.pad(lin_ku, (0, 0, 0, 0, 1, -1), value=0.0)
            lin_out_u = einsum("b g d e, b g n d -> b g n e", lin_ku, lin_q)
        else:
            lin_kv = einsum("b g n d, b g n e -> b d e", lin_k, v) / n
            lin_out_v = einsum("b g n d, b d e -> b g n e", lin_q, lin_kv)
            lin_ku = einsum("b g n d, b g n e -> b d e", lin_k, u) / n
            lin_out_u = einsum("b g n d, b d e -> b g n e", lin_q, lin_ku)
        # fold back groups into full sequence, and excise out padding
        return map(
            lambda t: rearrange(t, "b g n d -> b (g n) d")[:, :n],
            (quad_out_v + lin_out_v, quad_out_u + lin_out_u),
        )
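# A minimal end-to-end sketch (all sizes assumed for illustration): the block is
# residual, so input and output shapes match; expansion_factor=4.0 is used so
# that v and u have dim * 2 features, matching to_out, and group_size only has
# to divide the internally padded length, not the input length.
def _demo_flash_block():
    block = FLASH_ShareA_FFConvM(
        dim=64, group_size=64, query_key_dim=32, expansion_factor=4.0, causal=True
    )
    x = torch.randn(2, 100, 64)    # (batch, time, dim)
    assert block(x).shape == x.shape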