# NOTE(review): removed non-Python export residue ("Spaces: Starting on T4 ...")
# that preceded the module and broke parsing.
from functools import partial

import torch
import torch.nn.functional as F
from einops import rearrange
from rotary_embedding_torch import RotaryEmbedding, broadcat
from torch import nn

from celle.attention import Attention
from celle.reversible import SequentialSequence
from celle.utils import exists, default, cast_tuple
# LayerScale, from CaiT: https://arxiv.org/abs/2103.17239
class LayerScale(nn.Module):
    """Scale the wrapped module's output by a learned per-channel gain.

    The gain is initialized to a small depth-dependent epsilon following
    the CaiT paper: shallower stacks start with a larger scale.
    """

    def __init__(self, dim, depth, fn):
        super().__init__()
        # depth-dependent init epsilon (CaiT schedule)
        if depth <= 18:
            eps = 0.1
        elif depth <= 24:
            eps = 1e-5
        else:
            eps = 1e-6
        self.scale = nn.Parameter(torch.full((1, 1, dim), eps))
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.scale
# layer norm
class PreNorm(nn.Module):
    """Apply LayerNorm before the wrapped module (pre-norm convention).

    `norm_out` is an Identity placeholder, so the output currently passes
    through unchanged.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.norm_out = nn.Identity()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.norm_out(self.fn(self.norm(x), **kwargs))
# feed forward
class GEGLU(nn.Module):
    """GEGLU gate: halve the last dim, gating one half with GELU of the other."""

    def forward(self, x):
        values, gates = torch.chunk(x, 2, dim=-1)
        return values * F.gelu(gates)
class FeedForward(nn.Module):
    """Position-wise feed-forward block with a GEGLU gate.

    Projects dim -> 2 * (dim * mult), halves via GEGLU, then projects
    back to dim.

    Args:
        dim: input/output feature dimension.
        dropout: dropout probability applied after the gate.
        mult: hidden-width multiplier. Defaults to 4.0 (a float), so the
            hidden width is truncated to int before building the layers —
            ``nn.Linear`` requires integer feature sizes, and the original
            ``dim * mult`` passed a float and raised a TypeError.
    """

    def __init__(self, dim, dropout=0.0, mult=4.0):
        super().__init__()
        inner_dim = int(dim * mult)  # cast: mult may be float (default 4.0)
        self.net = nn.Sequential(
            nn.Linear(dim, inner_dim * 2),
            GEGLU(),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim),
        )

    def forward(self, x):
        return self.net(x)
# main transformer class
class Transformer(nn.Module):
    """Stack of pre-norm, layer-scaled attention / feed-forward pairs with
    optional rotary position embeddings over a text+image token sequence.

    Args:
        dim: model dimension.
        depth: number of (attention, feed-forward) layer pairs.
        seq_len: total sequence length (text tokens + image tokens).
        causal: use causal attention.
        heads: number of attention heads.
        dim_head: per-head dimension.
        ff_mult: feed-forward hidden-width multiplier.
        attn_dropout: attention dropout rate.
        ff_dropout: feed-forward dropout rate.
        image_fmap_size: side length of the (tiled) image feature map;
            required when ``rotary_emb`` is True.
        num_images: number of images tiled into the feature map; required
            (and must be >= 1) when ``rotary_emb`` is True.
        stable: passed to Attention (presumably a stable-softmax flag —
            TODO confirm against celle.attention).
        rotary_emb: build and register rotary positional frequencies.
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        seq_len,
        causal=True,
        heads=8,
        dim_head=64,
        ff_mult=4,
        attn_dropout=0.0,
        ff_dropout=0.0,
        image_fmap_size=None,
        num_images=None,
        stable=False,
        rotary_emb=True,
    ):
        super().__init__()
        layers = nn.ModuleList([])
        self.seq_len = seq_len
        self.image_fmap_size = image_fmap_size

        for ind in range(depth):
            attn_class = partial(Attention, stable=stable)
            attn = attn_class(
                dim,
                causal=causal,
                seq_len=seq_len,
                heads=heads,
                dim_head=dim_head,
                dropout=attn_dropout,
            )
            ff = FeedForward(dim, mult=ff_mult, dropout=ff_dropout)
            layers.append(
                nn.ModuleList(
                    [
                        # ind + 1 is the 1-based depth used for LayerScale init
                        LayerScale(dim, ind + 1, PreNorm(dim, attn)),
                        LayerScale(dim, ind + 1, PreNorm(dim, ff)),
                    ]
                )
            )

        # route the mask / rotary_pos_emb kwargs to the attention layer only
        # (True for attn, False for ff) in each pair
        route_attn = ((True, False),) * depth
        attn_route_map = {
            "mask": route_attn,
            "rotary_pos_emb": route_attn,
        }
        self.layers = SequentialSequence(layers, args_route=attn_route_map)

        # generate positional embeddings for rotary
        pos_emb = None
        if rotary_emb:
            rot_dim = dim_head // 3
            img_seq_len = ((image_fmap_size // num_images) ** 2) * num_images
            text_len = seq_len - img_seq_len + 1

            text_pos_emb = RotaryEmbedding(dim=rot_dim)
            img_axial_pos_emb = RotaryEmbedding(dim=rot_dim, freqs_for="pixel")

            text_freqs = text_pos_emb(torch.arange(text_len))
            # image is given a position far away from text
            img_to_text_freqs = text_pos_emb(torch.full((img_seq_len,), 8192))
            text_freqs = torch.cat((text_freqs, img_to_text_freqs), dim=0)

            img_freqs_axial = img_axial_pos_emb(
                torch.linspace(-1, 1, steps=image_fmap_size)
            )

            if num_images > 1:
                # split the axial freqs per image, build each image's 2D
                # (h, w) frequency grid, then flatten and concatenate
                split_img_freqs_axial = torch.split(
                    img_freqs_axial, image_fmap_size // num_images, dim=0
                )
                split_img_freqs = [
                    broadcat(
                        (
                            rearrange(img_freqs_axial_per_image, "i d -> i () d"),
                            rearrange(img_freqs_axial_per_image, "j d -> () j d"),
                        ),
                        dim=-1,
                    )
                    for img_freqs_axial_per_image in split_img_freqs_axial
                ]
                split_img_freqs = [
                    rearrange(img_freqs_per_image, "h w d -> (h w) d")
                    for img_freqs_per_image in split_img_freqs
                ]
                # concat per-image image freqs
                img_freqs = torch.cat(split_img_freqs, dim=0)
            elif num_images == 1:
                img_freqs = broadcat(
                    (
                        rearrange(img_freqs_axial, "i d -> i () d"),
                        rearrange(img_freqs_axial, "j d -> () j d"),
                    ),
                    dim=-1,
                )
                img_freqs = rearrange(img_freqs, "h w d -> (h w) d")
            else:
                # was `assert False`; assert is stripped under -O, so raise
                raise ValueError("num_images must be int greater than 0")

            self.img_axial_pos_emb = img_axial_pos_emb
            self.text_pos_emb = text_pos_emb

            # text is given a position of -10, far from the image axial
            # positions which lie in [-1, 1]
            text_axial_freqs = img_axial_pos_emb(torch.full((text_len,), -10.0))
            text_axial_freqs = torch.cat(
                (text_axial_freqs, text_axial_freqs), dim=-1
            )
            img_freqs = torch.cat((text_axial_freqs, img_freqs), dim=0)

            pos_emb = torch.cat((text_freqs, img_freqs), dim=-1)
            pos_emb = rearrange(pos_emb, "n d -> () n d")

        # registered even when None so forward can pass it unconditionally
        self.register_buffer("pos_emb", pos_emb)

    def forward(self, x, **kwargs):
        return self.layers(x, rotary_pos_emb=self.pos_emb, **kwargs)