Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,812 Bytes
be88838 0a1370c be88838 0a1370c be88838 0a1370c be88838 0a1370c be88838 0a1370c be88838 0a1370c be88838 0a1370c be88838 0a1370c be88838 0a1370c be88838 0a1370c be88838 0a1370c be88838 0a1370c be88838 0a1370c fd38570 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class FeedForward(nn.Module):
    """Pre-norm MLP block that includes its own residual connection.

    Computes ``x + fc2(GELU(fc1(LayerNorm(x))))``. The hidden width is
    ``int(dim * mult)``, so a fractional ``mult`` is truncated.

    Args:
        dim: Size of the last (feature) dimension of the input.
        mult: Hidden-width multiplier relative to ``dim``.
    """

    def __init__(self, dim, mult=4):
        super().__init__()
        hidden_dim = int(dim * mult)
        # Submodules are created in this exact order: it determines the
        # state_dict key names/order and the RNG draws of default init.
        self.norm = nn.LayerNorm(dim)
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, dim)
        # Overwrite the default weight init with Xavier-uniform (fc1 first,
        # then fc2); biases keep nn.Linear's default initialisation.
        for linear in (self.fc1, self.fc2):
            nn.init.xavier_uniform_(linear.weight)

    def forward(self, x):
        """Return ``x`` plus the MLP applied to ``LayerNorm(x)``."""
        hidden = self.act(self.fc1(self.norm(x)))
        return x + self.fc2(hidden)
def reshape_tensor(x, heads):
    """Split the last dimension of ``x`` into ``heads`` attention heads.

    Args:
        x: 3-D tensor ``(batch, length, heads * head_dim)``.
        heads: Number of heads; ``view`` fails if it does not divide the
            last dimension.

    Returns:
        A view (no copy) of shape ``(batch, heads, length, head_dim)``.
    """
    batch, length, _ = x.shape  # unpacking also rejects non-3-D input
    per_head = x.view(batch, length, heads, -1)
    return per_head.permute(0, 2, 1, 3)
class PerceiverAttention(nn.Module):
    """Multi-head attention from a set of latents onto ``[x; latents]``.

    Queries are projected from the layer-normalised latents. Keys and values
    are projected from the concatenation (along the sequence axis, dim -2) of
    the layer-normalised input features and the layer-normalised latents.
    No residual connection is added here.

    Args:
        dim: Feature size of both ``x`` and ``latents`` (and of the output).
        dim_head: Per-head feature size.
        heads: Number of attention heads. ``heads * dim_head`` need not equal
            ``dim``; ``to_out`` projects the merged heads back to ``dim``.
    """

    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head**-0.5  # standard 1/sqrt(d) logit scaling
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads
        # Creation/init order is kept as-is: it fixes state_dict keys and the
        # RNG draws of random initialisation.
        self.norm1 = nn.LayerNorm(dim)  # applied to x
        self.norm2 = nn.LayerNorm(dim)  # applied to latents
        self.to_q = nn.Linear(dim, inner_dim)
        self.to_kv = nn.Linear(dim, inner_dim * 2)  # keys and values fused
        self.to_out = nn.Linear(inner_dim, dim)
        nn.init.xavier_uniform_(self.to_q.weight)
        nn.init.xavier_uniform_(self.to_kv.weight)
        nn.init.xavier_uniform_(self.to_out.weight)

    def forward(self, x, latents):
        """Attend from ``latents`` to the concatenation of ``x`` and ``latents``.

        Args:
            x: Input features ``(batch, n_x, dim)``.
            latents: Latent queries ``(batch, n_latents, dim)``.

        Returns:
            Tensor of shape ``(batch, n_latents, dim)``.
        """
        x = self.norm1(x)
        latents = self.norm2(latents)
        batch, n_latents, _ = latents.shape

        q = self.to_q(latents)
        # Keys/values see both the input features and the latents themselves.
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)
        q, k, v = map(lambda t: reshape_tensor(t, self.heads), (q, k, v))

        attn_score = (q @ k.transpose(-2, -1)) * self.scale
        attn_weight = F.softmax(attn_score, dim=-1)
        # Merge heads: (b, h, n, d) -> (b, n, h*d). Reshaping to latents.shape
        # assumed h*d == dim and failed otherwise; -1 keeps the inner width
        # that to_out expects.
        out = (attn_weight @ v).transpose(1, 2).reshape(batch, n_latents, -1)
        return self.to_out(out)
class Resampler(nn.Module):
    """Resample a sequence of embeddings into ``num_queries`` output tokens.

    A learned set of latent queries is refined by ``depth`` rounds of
    ``PerceiverAttention`` (over the projected input and the latents) and
    ``FeedForward``, then projected to ``output_dim`` and layer-normalised.

    Args:
        dim: Internal latent width.
        depth: Number of attention + feed-forward rounds.
        dim_head: Per-head width for each attention layer.
        heads: Number of attention heads.
        num_queries: Number of learned latent queries (output tokens).
        embedding_dim: Last-dimension size of the input ``x``.
        output_dim: Last-dimension size of the output.
        ff_mult: Hidden-width multiplier for each ``FeedForward``.
    """

    def __init__(self, dim=1024, depth=8, dim_head=64, heads=16, num_queries=8,
                 embedding_dim=768, output_dim=1024, ff_mult=4):
        super().__init__()
        # Attribute names and creation/init order are preserved: they fix the
        # state_dict keys and the RNG draws of random initialisation.
        latent_init = torch.empty(1, num_queries, dim)
        self.latents = nn.Parameter(latent_init)
        nn.init.normal_(self.latents, mean=0, std=dim**-0.5)

        self.proj_in = nn.Linear(embedding_dim, dim)
        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        blocks = []
        for _ in range(depth):
            attention = PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads)
            feed_forward = FeedForward(dim=dim, mult=ff_mult)
            blocks.append(nn.ModuleList([attention, feed_forward]))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Map ``x`` to ``num_queries`` tokens per batch element.

        Args:
            x: Input embeddings, presumably ``(batch, seq, embedding_dim)``;
                the attention layers require a 3-D tensor after ``proj_in``.

        Returns:
            Tensor of shape ``(batch, num_queries, output_dim)``.
        """
        batch_size = x.size(0)
        x = self.proj_in(x)
        latents = self.latents.repeat(batch_size, 1, 1)
        for attention, feed_forward in self.layers:
            latents = latents + attention(x, latents)
            # NOTE(review): FeedForward.forward already returns
            # ``input + mlp(input)``, so this outer residual makes the update
            # ``2 * latents + mlp(latents)``; the residual stream doubles every
            # layer. Left as-is because changing it would alter the outputs of
            # any weights trained with this code. Confirm it was intentional.
            latents = latents + feed_forward(latents)
        return self.norm_out(self.proj_out(latents))