import math
from inspect import isfunction

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
import xformers.ops as xops


def exists(val):
    return val is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d
class CrossAttention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim, bias=False),
            nn.Dropout(dropout)
        )

    def forward(self, x, context=None, mask=None):
        # NOTE: `mask` is accepted for interface compatibility but is not applied.
        h = self.heads

        q = self.to_q(x)
        # falls back to self-attention when no context is given
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        # fold heads into the batch dimension: (b, n, h*d) -> (b*h, n, d)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
        out = xops.memory_efficient_attention(q, k, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        return self.to_out(out)
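
# Illustrative usage sketch (not part of the original module): the shapes and the
# CUDA/fp16 placement below are assumptions, since xformers' memory_efficient_attention
# generally requires CUDA tensors.
#
#   attn = CrossAttention(query_dim=1024, context_dim=768, heads=16, dim_head=64).cuda().half()
#   x = torch.randn(2, 4096, 1024, device='cuda', dtype=torch.float16)  # queries: (B, N_q, query_dim)
#   ctx = torch.randn(2, 77, 768, device='cuda', dtype=torch.float16)   # context: (B, N_ctx, context_dim)
#   out = attn(x, context=ctx)                                          # -> (2, 4096, 1024)
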
class BasicTransformerBlock(nn.Module):
    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True):
        super().__init__()
        # NOTE: `context_dim` and `gated_ff` are accepted for interface compatibility
        # but unused; the block performs self-attention with a plain (non-gated) MLP.
        self.self_attn = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)
        self.ff = nn.Sequential(
            nn.Linear(dim, dim * 4, bias=False),
            nn.GELU(),
            nn.Linear(dim * 4, dim, bias=False),
        )
        # `bias=False` for LayerNorm requires a recent PyTorch (the kwarg was added in 2.1).
        self.norm1 = nn.LayerNorm(dim, bias=False)
        self.norm2 = nn.LayerNorm(dim, bias=False)

    def forward(self, x, context=None):
        # pre-norm residual block; `context` is currently ignored
        before_sa = self.norm1(x)
        x = x + self.self_attn(before_sa)
        x = self.ff(self.norm2(x)) + x
        return x
class Transformer(nn.Module):
    """ViT-style multi-view transformer: patchifies N input views, processes all view
    tokens jointly, and predicts a per-pixel `output_dim`-channel map for each view."""
    def __init__(
        self,
        image_size=512,
        patch_size=8,
        input_dim=3,
        inner_dim=1024,
        output_dim=14,
        n_heads=16,
        depth=24,
        dropout=0.,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.input_dim = input_dim
        self.inner_dim = inner_dim
        self.output_dim = output_dim

        # non-overlapping patch embedding via a strided convolution
        self.patchify = nn.Conv2d(input_dim, inner_dim, kernel_size=patch_size, stride=patch_size, padding=0, bias=False)

        num_patches = (image_size // patch_size) ** 2
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, inner_dim))
        # learned per-view embeddings: one for the reference view, one shared by all source views
        self.ref_embed = nn.Parameter(torch.zeros(1, 1, inner_dim))
        self.src_embed = nn.Parameter(torch.zeros(1, 1, inner_dim))

        self.blocks = nn.ModuleList(
            [BasicTransformerBlock(inner_dim, n_heads, inner_dim // n_heads, dropout=dropout)
             for _ in range(depth)]
        )
        self.norm = nn.LayerNorm(inner_dim, bias=False)
        self.unpatchify = nn.Linear(inner_dim, patch_size ** 2 * output_dim, bias=True)

        nn.init.trunc_normal_(self.pos_embed, std=.02)
        nn.init.trunc_normal_(self.ref_embed, std=.02)
        nn.init.trunc_normal_(self.src_embed, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.weight, 1.0)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
    def interpolate_pos_encoding(self, x, w, h):
        """Bicubic-resize the learned positional embedding when the input resolution
        differs from the one it was created for (DINO-style)."""
        npatch = x.shape[-2]
        N = self.pos_embed.shape[-2]
        if npatch == N and w == h:
            return self.pos_embed
        patch_pos_embed = self.pos_embed
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1
        # tokens are flattened row-major over (H // patch, W // patch), so grid rows
        # are scaled by h0 and columns by w0
        patch_pos_embed = F.interpolate(
            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2).contiguous(),
            scale_factor=(h0 / math.sqrt(N), w0 / math.sqrt(N)),
            mode='bicubic',
        )
        assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim).contiguous()
        return patch_pos_embed
    def forward(self, images):
        """
        images: (B, N, C, H, W) — N views per sample; view 0 is the reference view.
        returns: (B, N, H, W, output_dim)
        """
        B, N, _, H, W = images.shape

        # patchify each view independently
        images = rearrange(images, 'b n c h w -> (b n) c h w')
        tokens = self.patchify(images)
        tokens = rearrange(tokens, 'bn c h w -> bn (h w) c')

        # add positional and per-view encodings
        tokens = rearrange(tokens, '(b n) hw c -> b n hw c', b=B)
        tokens = tokens + self.interpolate_pos_encoding(tokens, W, H).unsqueeze(1)
        view_embeds = torch.cat([self.ref_embed, self.src_embed.repeat(1, N - 1, 1)], dim=1)
        tokens = tokens + view_embeds.unsqueeze(2)
        # transformer: process all view tokens jointly
        tokens = rearrange(tokens, 'b n hw c -> b (n hw) c')
        x = tokens
        for layer in self.blocks:
            x = layer(x)

        # unpatchify back to per-view, per-pixel outputs
        x = self.norm(x)
        x = self.unpatchify(x)
        x = rearrange(x, 'b (n h w) c -> b n h w c', n=N, h=H // self.patch_size, w=W // self.patch_size)
        x = rearrange(x, 'b n h w (p q c) -> b n (h p) (w q) c', p=self.patch_size, q=self.patch_size)
        return x
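

# Minimal smoke test. This block is a sketch, not part of the original file: the batch
# size, view count, and the reduced width/depth are arbitrary choices, and CUDA + fp16
# are assumed because xformers' memory_efficient_attention generally requires GPU tensors.
if __name__ == "__main__":
    model = Transformer(image_size=128, patch_size=8, input_dim=3, inner_dim=256,
                        output_dim=14, n_heads=8, depth=2).cuda().half()
    images = torch.randn(1, 2, 3, 128, 128, device="cuda", dtype=torch.float16)  # (B, N, C, H, W)
    with torch.no_grad():
        out = model(images)
    print(out.shape)  # expected: torch.Size([1, 2, 128, 128, 14])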