from functools import wraps
from math import pi

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce
from torch import einsum, nn

def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


def cache_fn(f):
    # Cache the module returned by a layer-constructing function so that repeated
    # calls with the same key reuse (weight-tie) the same module instance.
    cache = dict()

    @wraps(f)
    def cached_fn(*args, _cache=True, key=None, **kwargs):
        if not _cache:
            return f(*args, **kwargs)
        nonlocal cache
        if key in cache:
            return cache[key]
        result = f(*args, **kwargs)
        cache[key] = result
        return result

    return cached_fn


def fourier_encode(x, max_freq, num_bands=4):
    x = x.unsqueeze(-1)
    device, dtype, orig_x = x.device, x.dtype, x

    scales = torch.linspace(1.0, max_freq / 2, num_bands, device=device, dtype=dtype)
    scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)]

    x = x * scales * pi
    x = torch.cat([x.sin(), x.cos()], dim=-1)
    x = torch.cat((x, orig_x), dim=-1)
    return x
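
# Illustrative shape note (a sketch, not part of the original code): for an input
# of shape (..., d), fourier_encode returns (..., d, 2 * num_bands + 1) -- sin and
# cos features for each of `num_bands` frequencies plus the raw value, e.g.
#
#     enc = fourier_encode(torch.rand(2, 16, 2), max_freq=10.0, num_bands=6)
#     enc.shape  # torch.Size([2, 16, 2, 13])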


class PreNorm(nn.Module):
    def __init__(self, dim, fn, context_dim=None):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)
        self.norm_context = nn.LayerNorm(context_dim) if exists(context_dim) else None

    def forward(self, x, **kwargs):
        x = self.norm(x)

        if exists(self.norm_context):
            context = kwargs["context"]
            normed_context = self.norm_context(context)
            kwargs.update(context=normed_context)

        return self.fn(x, **kwargs)


class GEGLU(nn.Module):
    def forward(self, x):
        x, gates = x.chunk(2, dim=-1)
        return x * F.gelu(gates)


class FeedForward(nn.Module):
    def __init__(self, dim, mult=4, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            # The first projection is doubled because GEGLU splits it into a value
            # half and a gating half.
            nn.Linear(dim, dim * mult * 2),
            GEGLU(),
            nn.Linear(dim * mult, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head**-0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)

        self.dropout = nn.Dropout(dropout)
        self.to_out = nn.Linear(inner_dim, query_dim)

    def forward(self, x, context=None, mask=None):
        h = self.heads

        q = self.to_q(x)
        # Self attention when no context is given, cross attention otherwise.
        context = default(context, x)
        k, v = self.to_kv(context).chunk(2, dim=-1)

        # Fold the heads into the batch dimension: (b, n, h * d) -> (b * h, n, d).
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))

        sim = einsum("b i d, b j d -> b i j", q, k) * self.scale

        if exists(mask):
            mask = rearrange(mask, "b ... -> b (...)")
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, "b j -> (b h) () j", h=h)
            sim.masked_fill_(~mask, max_neg_value)

        attn = sim.softmax(dim=-1)
        attn = self.dropout(attn)

        out = einsum("b i j, b j d -> b i d", attn, v)
        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
        return self.to_out(out)


class Perceiver(nn.Module):
    def __init__(
        self,
        *,
        num_freq_bands,
        depth,
        max_freq,
        input_channels=3,
        input_axis=2,
        num_latents=512,
        latent_dim=512,
        cross_heads=1,
        latent_heads=8,
        cross_dim_head=64,
        latent_dim_head=64,
        num_classes=1000,
        attn_dropout=0.0,
        ff_dropout=0.0,
        weight_tie_layers=False,
        fourier_encode_data=True,
        self_per_cross_attn=1,
        final_classifier_head=True,
        pool="mean",
        latent_init=None,
    ):
        """The shape of the final attention mechanism will be:
        depth * (cross attention -> self_per_cross_attn * self attention)

        See the usage sketch at the bottom of this file for an example forward pass.

        Args:
            num_freq_bands: Number of freq bands, with original value (2 * K + 1).
            depth: Depth of net.
            max_freq: Maximum frequency, a hyperparameter depending on how
                fine the data is.
            input_channels: Number of channels for each token of the input.
            input_axis: Number of axes for input data (2 for images, 3 for video).
            num_latents: Number of latents, or induced set points, or centroids.
                Different papers give it different names.
            latent_dim: Latent dimension.
            cross_heads: Number of heads for cross attention. Paper said 1.
            latent_heads: Number of heads for latent self attention, 8.
            cross_dim_head: Number of dimensions per cross attention head.
            latent_dim_head: Number of dimensions per latent self attention head.
            num_classes: Output number of classes.
            attn_dropout: Attention dropout.
            ff_dropout: Feedforward dropout.
            weight_tie_layers: Whether to weight tie layers (optional).
            fourier_encode_data: Whether to auto-fourier encode the data, using
                the input_axis given. Defaults to True, but can be turned off
                if you are fourier encoding the data yourself.
            self_per_cross_attn: Number of self attention blocks per cross attn.
            final_classifier_head: Mean pool and project embeddings to the number
                of classes (num_classes) at the end.
            pool: Pooling used by the classifier head: "mean" (default), "cat", or "mlp".
            latent_init: Optional path to a saved tensor used to initialize the
                latent array.
        """
        super().__init__()
        self.input_axis = input_axis
        self.max_freq = max_freq
        self.num_freq_bands = num_freq_bands
        self.self_per_cross_attn = self_per_cross_attn

        self.fourier_encode_data = fourier_encode_data
        fourier_channels = (
            (input_axis * ((num_freq_bands * 2) + 1)) * 2 if fourier_encode_data else 0
        )
        input_dim = fourier_channels + input_channels

        self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
        if latent_init is not None:
            # Optionally initialize the latent array from a saved feature tensor.
            latent_init_feat = torch.load(latent_init)
            if not isinstance(latent_init_feat, torch.Tensor):
                latent_init_feat = torch.Tensor(latent_init_feat)
            if len(latent_init_feat.shape) == 3:
                # Drop a leading batch dimension if one is present.
                latent_init_feat = latent_init_feat[0]
            with torch.no_grad():
                self.latents.copy_(latent_init_feat)
            print(f"loaded latent features from {latent_init}")

        get_cross_attn = lambda: PreNorm(
            latent_dim,
            Attention(
                latent_dim,
                input_dim,
                heads=cross_heads,
                dim_head=cross_dim_head,
                dropout=attn_dropout,
            ),
            context_dim=input_dim,
        )
        get_cross_ff = lambda: PreNorm(
            latent_dim, FeedForward(latent_dim, dropout=ff_dropout)
        )
        get_latent_attn = lambda: PreNorm(
            latent_dim,
            Attention(
                latent_dim,
                heads=latent_heads,
                dim_head=latent_dim_head,
                dropout=attn_dropout,
            ),
        )
        get_latent_ff = lambda: PreNorm(
            latent_dim, FeedForward(latent_dim, dropout=ff_dropout)
        )

        # Wrapping the constructors in cache_fn lets layers be reused (weight-tied)
        # across depth when weight_tie_layers is set.
        get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff = map(
            cache_fn, (get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff)
        )

        self.layers = nn.ModuleList([])
        for i in range(depth):
            should_cache = i > 0 and weight_tie_layers
            cache_args = {"_cache": should_cache}

            self_attns = nn.ModuleList([])

            for block_ind in range(self_per_cross_attn):
                self_attns.append(
                    nn.ModuleList(
                        [
                            get_latent_attn(**cache_args, key=block_ind),
                            get_latent_ff(**cache_args, key=block_ind),
                        ]
                    )
                )
            if self_per_cross_attn == 0:
                # No latent self attention blocks: fall back to a single feedforward.
                # A fixed cache key is used because the loop variable above never binds.
                self_attns.append(get_latent_ff(**cache_args, key="ff_only"))

            self.layers.append(
                nn.ModuleList(
                    [
                        get_cross_attn(**cache_args),
                        get_cross_ff(**cache_args),
                        self_attns,
                    ]
                )
            )

        if final_classifier_head:
            if pool == "cat":
                self.to_logits = nn.Sequential(
                    Rearrange("b n d -> b (n d)"),
                    nn.LayerNorm(num_latents * latent_dim),
                    nn.Linear(num_latents * latent_dim, num_classes),
                )
            elif pool == "mlp":
                self.to_logits = nn.Sequential(
                    Reduce("b n d -> b d", "mean"),
                    nn.LayerNorm(latent_dim),
                    nn.Linear(latent_dim, latent_dim),
                    nn.ReLU(),
                    nn.LayerNorm(latent_dim),
                    nn.Linear(latent_dim, num_classes),
                )
            else:
                self.to_logits = nn.Sequential(
                    Reduce("b n d -> b d", pool),
                    nn.LayerNorm(latent_dim),
                    nn.Linear(latent_dim, num_classes),
                )

    def forward(self, h, label=None, mask=None, pretrain=False, coords=None):
        b, *axis, _, device, dtype = *h.shape, h.device, h.dtype
        assert (
            len(axis) == self.input_axis
        ), "input data must have the right number of axes"

        if self.fourier_encode_data:
            # Fourier encode the provided coordinates and concatenate them to the
            # input features along the channel dimension.
            assert exists(coords), "coords must be provided when fourier_encode_data is True"
            enc_pos = fourier_encode(coords, self.max_freq, self.num_freq_bands)
            enc_pos = rearrange(enc_pos, "... n d -> ... (n d)")

            h = torch.cat((h, enc_pos), dim=-1)

        # Flatten all spatial axes into a single token axis.
        h = rearrange(h, "b ... d -> b (...) d")

        # Broadcast the learned latent array across the batch.
        x = repeat(self.latents, "n d -> b n d", b=b)

        for cross_attn, cross_ff, self_attns in self.layers:
            # Cross attention from the latents to the input tokens, with residuals.
            x = cross_attn(x, context=h, mask=mask) + x
            x = cross_ff(x) + x

            if self.self_per_cross_attn > 0:
                for self_attn, self_ff in self_attns:
                    x = self_attn(x) + x
                    x = self_ff(x) + x
            else:
                # Only a feedforward block was registered for this layer.
                x = self_attns[0](x) + x

        if pretrain:
            # Return the mean-pooled latent embedding instead of class logits.
            return x.mean(dim=1)

        logits = self.to_logits(x)
        Y_hat = torch.topk(logits, 1, dim=1)[1]
        Y_prob = F.softmax(logits, dim=1)
        return logits, Y_prob, Y_hat
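

if __name__ == "__main__":
    # Minimal usage sketch (not from the original code). With fourier_encode_data=True,
    # the forward pass expects `coords` to carry 2 * input_axis values per position so
    # that the Fourier features match the `fourier_channels` computed in __init__; the
    # shapes below are illustrative assumptions, not a documented interface.
    model = Perceiver(
        num_freq_bands=6,
        depth=2,
        max_freq=10.0,
        input_channels=3,
        input_axis=2,
        num_latents=32,
        latent_dim=64,
        num_classes=10,
    )
    h = torch.randn(1, 8, 8, 3)      # (batch, height, width, channels)
    coords = torch.rand(1, 8, 8, 4)  # assumed layout: (batch, height, width, 2 * input_axis)
    logits, Y_prob, Y_hat = model(h, coords=coords)
    print(logits.shape, Y_prob.shape, Y_hat.shape)
    # Expected: torch.Size([1, 10]) torch.Size([1, 10]) torch.Size([1, 1])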