sonalkum's picture
stable
9172422
raw
history blame
8.29 kB
import torch
from einops import rearrange
from torch import nn
from .blocks import AdaRMSNorm
from .transformer import Attention, FeedForward, RotaryEmbedding, LayerNorm
def checkpoint(function, *args, **kwargs):
kwargs.setdefault("use_reentrant", False)
return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
# Adapted from https://github.com/lucidrains/local-attention/blob/master/local_attention/transformer.py
class ContinuousLocalTransformer(nn.Module):
def __init__(
self,
*,
dim,
depth,
dim_in = None,
dim_out = None,
causal = False,
local_attn_window_size = 64,
heads = 8,
ff_mult = 2,
cond_dim = 0,
cross_attn_cond_dim = 0,
**kwargs
):
super().__init__()
dim_head = dim//heads
self.layers = nn.ModuleList([])
self.project_in = nn.Linear(dim_in, dim) if dim_in is not None else nn.Identity()
self.project_out = nn.Linear(dim, dim_out) if dim_out is not None else nn.Identity()
self.local_attn_window_size = local_attn_window_size
self.cond_dim = cond_dim
self.cross_attn_cond_dim = cross_attn_cond_dim
self.rotary_pos_emb = RotaryEmbedding(max(dim_head // 2, 32))
for _ in range(depth):
self.layers.append(nn.ModuleList([
AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim),
Attention(
dim=dim,
dim_heads=dim_head,
causal=causal,
zero_init_output=True,
natten_kernel_size=local_attn_window_size,
),
Attention(
dim=dim,
dim_heads=dim_head,
dim_context = cross_attn_cond_dim,
zero_init_output=True
) if self.cross_attn_cond_dim > 0 else nn.Identity(),
AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim),
FeedForward(dim = dim, mult = ff_mult, no_bias=True)
]))
def forward(self, x, mask = None, cond = None, cross_attn_cond = None, cross_attn_cond_mask = None, prepend_cond = None):
x = checkpoint(self.project_in, x)
if prepend_cond is not None:
x = torch.cat([prepend_cond, x], dim=1)
pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1])
for attn_norm, attn, xattn, ff_norm, ff in self.layers:
residual = x
if cond is not None:
x = checkpoint(attn_norm, x, cond)
else:
x = checkpoint(attn_norm, x)
x = checkpoint(attn, x, mask = mask, rotary_pos_emb=pos_emb) + residual
if cross_attn_cond is not None:
x = checkpoint(xattn, x, context=cross_attn_cond, context_mask=cross_attn_cond_mask) + x
residual = x
if cond is not None:
x = checkpoint(ff_norm, x, cond)
else:
x = checkpoint(ff_norm, x)
x = checkpoint(ff, x) + residual
return checkpoint(self.project_out, x)
class TransformerDownsampleBlock1D(nn.Module):
def __init__(
self,
in_channels,
embed_dim = 768,
depth = 3,
heads = 12,
downsample_ratio = 2,
local_attn_window_size = 64,
**kwargs
):
super().__init__()
self.downsample_ratio = downsample_ratio
self.transformer = ContinuousLocalTransformer(
dim=embed_dim,
depth=depth,
heads=heads,
local_attn_window_size=local_attn_window_size,
**kwargs
)
self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity()
self.project_down = nn.Linear(embed_dim * self.downsample_ratio, embed_dim, bias=False)
def forward(self, x):
x = checkpoint(self.project_in, x)
# Compute
x = self.transformer(x)
# Trade sequence length for channels
x = rearrange(x, "b (n r) c -> b n (c r)", r=self.downsample_ratio)
# Project back to embed dim
x = checkpoint(self.project_down, x)
return x
class TransformerUpsampleBlock1D(nn.Module):
def __init__(
self,
in_channels,
embed_dim,
depth = 3,
heads = 12,
upsample_ratio = 2,
local_attn_window_size = 64,
**kwargs
):
super().__init__()
self.upsample_ratio = upsample_ratio
self.transformer = ContinuousLocalTransformer(
dim=embed_dim,
depth=depth,
heads=heads,
local_attn_window_size = local_attn_window_size,
**kwargs
)
self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity()
self.project_up = nn.Linear(embed_dim, embed_dim * self.upsample_ratio, bias=False)
def forward(self, x):
# Project to embed dim
x = checkpoint(self.project_in, x)
# Project to increase channel dim
x = checkpoint(self.project_up, x)
# Trade channels for sequence length
x = rearrange(x, "b n (c r) -> b (n r) c", r=self.upsample_ratio)
# Compute
x = self.transformer(x)
return x
class TransformerEncoder1D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
embed_dims = [96, 192, 384, 768],
heads = [12, 12, 12, 12],
depths = [3, 3, 3, 3],
ratios = [2, 2, 2, 2],
local_attn_window_size = 64,
**kwargs
):
super().__init__()
layers = []
for layer in range(len(depths)):
prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0]
layers.append(
TransformerDownsampleBlock1D(
in_channels = prev_dim,
embed_dim = embed_dims[layer],
heads = heads[layer],
depth = depths[layer],
downsample_ratio = ratios[layer],
local_attn_window_size = local_attn_window_size,
**kwargs
)
)
self.layers = nn.Sequential(*layers)
self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False)
self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False)
def forward(self, x):
x = rearrange(x, "b c n -> b n c")
x = checkpoint(self.project_in, x)
x = self.layers(x)
x = checkpoint(self.project_out, x)
x = rearrange(x, "b n c -> b c n")
return x
class TransformerDecoder1D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
embed_dims = [768, 384, 192, 96],
heads = [12, 12, 12, 12],
depths = [3, 3, 3, 3],
ratios = [2, 2, 2, 2],
local_attn_window_size = 64,
**kwargs
):
super().__init__()
layers = []
for layer in range(len(depths)):
prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0]
layers.append(
TransformerUpsampleBlock1D(
in_channels = prev_dim,
embed_dim = embed_dims[layer],
heads = heads[layer],
depth = depths[layer],
upsample_ratio = ratios[layer],
local_attn_window_size = local_attn_window_size,
**kwargs
)
)
self.layers = nn.Sequential(*layers)
self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False)
self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False)
def forward(self, x):
x = rearrange(x, "b c n -> b n c")
x = checkpoint(self.project_in, x)
x = self.layers(x)
x = checkpoint(self.project_out, x)
x = rearrange(x, "b n c -> b c n")
return x