Spaces:
Running
on
Zero
Running
on
Zero
File size: 8,289 Bytes
9172422 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 |
import torch
from einops import rearrange
from torch import nn
from .blocks import AdaRMSNorm
from .transformer import Attention, FeedForward, RotaryEmbedding, LayerNorm
def checkpoint(function, *args, **kwargs):
    """Call ``function(*args, **kwargs)`` through ``torch.utils.checkpoint``.

    The non-reentrant implementation (``use_reentrant=False``) is selected
    unless the caller supplies ``use_reentrant`` explicitly; every other
    positional and keyword argument is forwarded unchanged.
    """
    # Caller-supplied options override the default because they are unpacked last.
    options = {"use_reentrant": False, **kwargs}
    return torch.utils.checkpoint.checkpoint(function, *args, **options)
# Adapted from https://github.com/lucidrains/local-attention/blob/master/local_attention/transformer.py
class ContinuousLocalTransformer(nn.Module):
    """Stack of pre-norm transformer layers over a continuous feature sequence.

    Each of the ``depth`` layers applies, each step with a residual connection:

    1. norm -> self-attention (with ``mask`` and rotary position embeddings),
    2. cross-attention on ``cross_attn_cond`` (only when configured; no norm),
    3. norm -> feed-forward.

    Norms are ``AdaRMSNorm`` when ``cond_dim > 0`` and ``LayerNorm`` otherwise.
    Every sub-module call is wrapped in activation checkpointing.

    Args:
        dim: Hidden width used inside every layer.
        depth: Number of layers.
        dim_in: If given, inputs are projected from ``dim_in`` to ``dim`` with an
            ``nn.Linear``; otherwise they pass through unchanged.
        dim_out: If given, outputs are projected from ``dim`` to ``dim_out`` with
            an ``nn.Linear``; otherwise they pass through unchanged.
        causal: Forwarded to each self-attention ``Attention`` module.
        local_attn_window_size: Stored on the module and forwarded to each
            self-attention module as ``natten_kernel_size``.
        heads: Only used to derive the per-head width ``dim // heads``, which is
            passed to ``Attention`` as ``dim_heads``.
        ff_mult: Forwarded to ``FeedForward`` as ``mult``.
        cond_dim: ``> 0`` selects ``AdaRMSNorm(dim, cond_dim)`` for both norms
            (presumably the feature width of ``cond`` — confirm in AdaRMSNorm).
        cross_attn_cond_dim: ``> 0`` adds a cross-attention ``Attention`` per
            layer with ``dim_context=cross_attn_cond_dim``; otherwise an
            ``nn.Identity`` placeholder keeps the 5-module layer layout that
            ``forward`` unpacks.
        **kwargs: Accepted and ignored, so wrappers can forward extra options.
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        dim_in = None,
        dim_out = None,
        causal = False,
        local_attn_window_size = 64,
        heads = 8,
        ff_mult = 2,
        cond_dim = 0,
        cross_attn_cond_dim = 0,
        **kwargs
    ):
        super().__init__()

        dim_head = dim // heads

        self.layers = nn.ModuleList([])

        self.project_in = nn.Linear(dim_in, dim) if dim_in is not None else nn.Identity()
        self.project_out = nn.Linear(dim, dim_out) if dim_out is not None else nn.Identity()

        self.local_attn_window_size = local_attn_window_size
        self.cond_dim = cond_dim
        self.cross_attn_cond_dim = cross_attn_cond_dim

        # Sized from half the per-head width, with a floor of 32.
        self.rotary_pos_emb = RotaryEmbedding(max(dim_head // 2, 32))

        def make_norm():
            # Conditioned norm when a conditioning size is configured, plain LayerNorm otherwise.
            return AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim)

        # Sub-modules are built in a fixed order; constructors draw from the
        # global RNG, so reordering would change initial weights for a given seed.
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                make_norm(),
                Attention(
                    dim=dim,
                    dim_heads=dim_head,
                    causal=causal,
                    zero_init_output=True,
                    natten_kernel_size=local_attn_window_size,
                ),
                Attention(
                    dim=dim,
                    dim_heads=dim_head,
                    dim_context = cross_attn_cond_dim,
                    zero_init_output=True
                ) if self.cross_attn_cond_dim > 0 else nn.Identity(),
                make_norm(),
                FeedForward(dim = dim, mult = ff_mult, no_bias=True)
            ]))

    def forward(self, x, mask = None, cond = None, cross_attn_cond = None, cross_attn_cond_mask = None, prepend_cond = None):
        """Run the layer stack.

        Args:
            x: Input sequence. ``project_in`` acts on the last axis, and the
                sequence axis is dim 1 (``prepend_cond`` is concatenated there
                and rotary embeddings are sized from ``x.shape[1]``). Presumably
                ``(batch, seq_len, features)`` — confirm with callers.
            mask: Forwarded to every self-attention call; cross-attention does
                not receive it.
            cond: When not None, passed as a second positional argument to every
                norm. Presumably required exactly when ``cond_dim > 0``.
            cross_attn_cond: Context tensor for the cross-attention modules.
            cross_attn_cond_mask: Forwarded to cross-attention as ``context_mask``.
            prepend_cond: Optional sequence concatenated in front of the
                projected input along dim 1 before the layers run.

        Returns:
            ``project_out`` applied to the final hidden states, including any
            positions that came from ``prepend_cond``.

        Raises:
            ValueError: If ``cross_attn_cond`` is given but the module was built
                with ``cross_attn_cond_dim <= 0`` (no cross-attention layers).
        """
        # Without cross-attention the slot holds nn.Identity, whose forward takes
        # no `context=` keyword; fail with an explicit message instead.
        if cross_attn_cond is not None and not self.cross_attn_cond_dim > 0:
            raise ValueError(
                "cross_attn_cond was passed, but this ContinuousLocalTransformer was "
                "built with cross_attn_cond_dim <= 0 and has no cross-attention layers"
            )

        x = checkpoint(self.project_in, x)

        if prepend_cond is not None:
            # NOTE(review): `mask` is not extended to cover these prepended
            # positions and they are not stripped from the output — confirm intended.
            x = torch.cat([prepend_cond, x], dim=1)

        pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1])

        # Extra positional argument for the norms: the conditioning tensor, if any.
        norm_args = () if cond is None else (cond,)

        for attn_norm, attn, xattn, ff_norm, ff in self.layers:
            residual = x
            x = checkpoint(attn_norm, x, *norm_args)
            x = checkpoint(attn, x, mask = mask, rotary_pos_emb=pos_emb) + residual

            if cross_attn_cond is not None:
                x = checkpoint(xattn, x, context=cross_attn_cond, context_mask=cross_attn_cond_mask) + x

            residual = x
            x = checkpoint(ff_norm, x, *norm_args)
            x = checkpoint(ff, x) + residual

        return checkpoint(self.project_out, x)
class TransformerDownsampleBlock1D(nn.Module):
    """Transformer stage followed by a learned reduction of sequence length.

    After the transformer runs at width ``embed_dim``, every group of
    ``downsample_ratio`` consecutive steps is folded into the channel axis and
    linearly projected back to ``embed_dim``. Extra ``**kwargs`` are forwarded
    to ``ContinuousLocalTransformer``.
    """

    def __init__(self, in_channels, embed_dim = 768, depth = 3, heads = 12,
                 downsample_ratio = 2, local_attn_window_size = 64, **kwargs):
        super().__init__()

        self.downsample_ratio = downsample_ratio

        # Keep this construction order: module constructors draw from the
        # global RNG, so reordering would change initial weights for a seed.
        self.transformer = ContinuousLocalTransformer(
            dim=embed_dim,
            depth=depth,
            heads=heads,
            local_attn_window_size=local_attn_window_size,
            **kwargs,
        )

        if in_channels == embed_dim:
            self.project_in = nn.Identity()
        else:
            self.project_in = nn.Linear(in_channels, embed_dim, bias=False)

        folded_width = embed_dim * self.downsample_ratio
        self.project_down = nn.Linear(folded_width, embed_dim, bias=False)

    def forward(self, x):
        """Encode ``x`` and shorten its sequence axis by ``downsample_ratio``.

        ``x`` is laid out ``(b, seq, c)`` per the rearrange pattern below; einops
        requires ``seq`` to be divisible by ``downsample_ratio``.
        """
        hidden = self.transformer(checkpoint(self.project_in, x))

        # Fold each run of `downsample_ratio` steps into the channel axis.
        folded = rearrange(hidden, "b (n r) c -> b n (c r)", r=self.downsample_ratio)

        return checkpoint(self.project_down, folded)
class TransformerUpsampleBlock1D(nn.Module):
    """Learned expansion of sequence length followed by a transformer stage.

    Inputs are projected to ``embed_dim``, widened to
    ``embed_dim * upsample_ratio`` channels, and those channels are unfolded
    into ``upsample_ratio`` extra steps per position before the transformer
    runs. Extra ``**kwargs`` are forwarded to ``ContinuousLocalTransformer``.
    """

    def __init__(self, in_channels, embed_dim, depth = 3, heads = 12,
                 upsample_ratio = 2, local_attn_window_size = 64, **kwargs):
        super().__init__()

        self.upsample_ratio = upsample_ratio

        # Keep this construction order: module constructors draw from the
        # global RNG, so reordering would change initial weights for a seed.
        self.transformer = ContinuousLocalTransformer(
            dim=embed_dim,
            depth=depth,
            heads=heads,
            local_attn_window_size=local_attn_window_size,
            **kwargs,
        )

        if in_channels == embed_dim:
            self.project_in = nn.Identity()
        else:
            self.project_in = nn.Linear(in_channels, embed_dim, bias=False)

        widened_width = embed_dim * self.upsample_ratio
        self.project_up = nn.Linear(embed_dim, widened_width, bias=False)

    def forward(self, x):
        """Lengthen the sequence axis of ``x`` by ``upsample_ratio`` and decode it.

        ``x`` is laid out ``(b, seq, c)`` per the rearrange pattern below.
        """
        widened = checkpoint(self.project_up, checkpoint(self.project_in, x))

        # Unfold groups of channels into `upsample_ratio` new steps per position.
        unfolded = rearrange(widened, "b n (c r) -> b (n r) c", r=self.upsample_ratio)

        return self.transformer(unfolded)
class TransformerEncoder1D(nn.Module):
    """Hierarchical transformer encoder built from ``TransformerDownsampleBlock1D`` stages.

    Input and output are channels-first, ``(b, c, n)``, per the ``rearrange``
    calls in ``forward``. Stage ``i`` runs at width ``embed_dims[i]`` with
    ``heads[i]`` / ``depths[i]`` and shortens the sequence by ``ratios[i]``.
    The number of stages is ``len(depths)``.

    Args:
        in_channels: Channel count of the input.
        out_channels: Channel count of the output.
        embed_dims: Per-stage widths. ``project_out`` reads ``embed_dims[-1]``,
            so the last entry should be the final stage's width.
        heads: Per-stage head counts.
        depths: Per-stage layer counts; its length sets the number of stages.
        ratios: Per-stage downsampling ratios.
        local_attn_window_size: Forwarded to every stage.
        **kwargs: Forwarded to every stage.

    Raises:
        ValueError: If ``embed_dims``, ``heads`` or ``ratios`` has fewer entries
            than ``depths``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        embed_dims = (96, 192, 384, 768),
        heads = (12, 12, 12, 12),
        depths = (3, 3, 3, 3),
        ratios = (2, 2, 2, 2),
        local_attn_window_size = 64,
        **kwargs
    ):
        # Defaults are tuples (immutable) rather than shared mutable lists; they
        # are only indexed/iterated, so lists passed by callers still work.
        super().__init__()

        num_stages = len(depths)
        # Fail early with a clear message instead of an IndexError mid-construction.
        for arg_name, values in (("embed_dims", embed_dims), ("heads", heads), ("ratios", ratios)):
            if len(values) < num_stages:
                raise ValueError(
                    f"{arg_name} has {len(values)} entries but depths defines {num_stages} stages"
                )

        layers = []
        prev_dim = None
        # zip stops after len(depths) stages because the other sequences were checked above.
        for embed_dim, stage_heads, depth, ratio in zip(embed_dims, heads, depths, ratios):
            layers.append(
                TransformerDownsampleBlock1D(
                    # The first stage receives project_in output, already at its own width.
                    in_channels = embed_dim if prev_dim is None else prev_dim,
                    embed_dim = embed_dim,
                    heads = stage_heads,
                    depth = depth,
                    downsample_ratio = ratio,
                    local_attn_window_size = local_attn_window_size,
                    **kwargs
                )
            )
            prev_dim = embed_dim

        self.layers = nn.Sequential(*layers)

        self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False)
        self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False)

    def forward(self, x):
        """Encode ``x`` laid out ``(b, in_channels, n)`` into ``(b, out_channels, n')``.

        ``n'`` is ``n`` divided by each stage's downsampling ratio.
        """
        x = rearrange(x, "b c n -> b n c")
        x = checkpoint(self.project_in, x)
        x = self.layers(x)
        x = checkpoint(self.project_out, x)
        x = rearrange(x, "b n c -> b c n")
        return x
class TransformerDecoder1D(nn.Module):
    """Hierarchical transformer decoder built from ``TransformerUpsampleBlock1D`` stages.

    Input and output are channels-first, ``(b, c, n)``, per the ``rearrange``
    calls in ``forward``. Stage ``i`` runs at width ``embed_dims[i]`` with
    ``heads[i]`` / ``depths[i]`` and lengthens the sequence by ``ratios[i]``.
    The number of stages is ``len(depths)``.

    Args:
        in_channels: Channel count of the input.
        out_channels: Channel count of the output.
        embed_dims: Per-stage widths. ``project_out`` reads ``embed_dims[-1]``,
            so the last entry should be the final stage's width.
        heads: Per-stage head counts.
        depths: Per-stage layer counts; its length sets the number of stages.
        ratios: Per-stage upsampling ratios.
        local_attn_window_size: Forwarded to every stage.
        **kwargs: Forwarded to every stage.

    Raises:
        ValueError: If ``embed_dims``, ``heads`` or ``ratios`` has fewer entries
            than ``depths``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        embed_dims = (768, 384, 192, 96),
        heads = (12, 12, 12, 12),
        depths = (3, 3, 3, 3),
        ratios = (2, 2, 2, 2),
        local_attn_window_size = 64,
        **kwargs
    ):
        # Defaults are tuples (immutable) rather than shared mutable lists; they
        # are only indexed/iterated, so lists passed by callers still work.
        super().__init__()

        num_stages = len(depths)
        # Fail early with a clear message instead of an IndexError mid-construction.
        for arg_name, values in (("embed_dims", embed_dims), ("heads", heads), ("ratios", ratios)):
            if len(values) < num_stages:
                raise ValueError(
                    f"{arg_name} has {len(values)} entries but depths defines {num_stages} stages"
                )

        layers = []
        prev_dim = None
        # zip stops after len(depths) stages because the other sequences were checked above.
        for embed_dim, stage_heads, depth, ratio in zip(embed_dims, heads, depths, ratios):
            layers.append(
                TransformerUpsampleBlock1D(
                    # The first stage receives project_in output, already at its own width.
                    in_channels = embed_dim if prev_dim is None else prev_dim,
                    embed_dim = embed_dim,
                    heads = stage_heads,
                    depth = depth,
                    upsample_ratio = ratio,
                    local_attn_window_size = local_attn_window_size,
                    **kwargs
                )
            )
            prev_dim = embed_dim

        self.layers = nn.Sequential(*layers)

        self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False)
        self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False)

    def forward(self, x):
        """Decode ``x`` laid out ``(b, in_channels, n)`` into ``(b, out_channels, n')``.

        ``n'`` is ``n`` multiplied by each stage's upsampling ratio.
        """
        x = rearrange(x, "b c n -> b n c")
        x = checkpoint(self.project_in, x)
        x = self.layers(x)
        x = checkpoint(self.project_out, x)
        x = rearrange(x, "b n c -> b c n")
        return x