from functools import wraps
from math import pi

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce
from torch import einsum, nn


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


def cache_fn(f):
    cache = dict()

    @wraps(f)
    def cached_fn(*args, _cache=True, key=None, **kwargs):
        if not _cache:
            return f(*args, **kwargs)
        nonlocal cache
        if key in cache:
            return cache[key]
        result = f(*args, **kwargs)
        cache[key] = result
        return result

    return cached_fn


def fourier_encode(x, max_freq, num_bands=4):
    x = x.unsqueeze(-1)
    device, dtype, orig_x = x.device, x.dtype, x

    scales = torch.linspace(1.0, max_freq / 2, num_bands, device=device, dtype=dtype)
    scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)]

    x = x * scales * pi
    x = torch.cat([x.sin(), x.cos()], dim=-1)
    x = torch.cat((x, orig_x), dim=-1)
    return x


class PreNorm(nn.Module):
    def __init__(self, dim, fn, context_dim=None):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)
        self.norm_context = nn.LayerNorm(context_dim) if exists(context_dim) else None

    def forward(self, x, **kwargs):
        x = self.norm(x)

        if exists(self.norm_context):
            context = kwargs["context"]
            normed_context = self.norm_context(context)
            kwargs.update(context=normed_context)

        return self.fn(x, **kwargs)


class GEGLU(nn.Module):
    def forward(self, x):
        x, gates = x.chunk(2, dim=-1)
        return x * F.gelu(gates)


class FeedForward(nn.Module):
    def __init__(self, dim, mult=4, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult * 2),
            GEGLU(),
            nn.Linear(dim * mult, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head**-0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)

        self.dropout = nn.Dropout(dropout)
        self.to_out = nn.Linear(inner_dim, query_dim)

    def forward(self, x, context=None, mask=None):
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k, v = self.to_kv(context).chunk(2, dim=-1)

        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))

        sim = einsum("b i d, b j d -> b i j", q, k) * self.scale

        if exists(mask):
            mask = rearrange(mask, "b ... -> b (...)")
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, "b j -> (b h) () j", h=h)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        attn = sim.softmax(dim=-1)
        attn = self.dropout(attn)

        out = einsum("b i j, b j d -> b i d", attn, v)
        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
        return self.to_out(out)


class Perceiver(nn.Module):
    def __init__(
        self,
        *,
        num_freq_bands,
        depth,
        max_freq,
        input_channels=3,
        input_axis=2,
        num_latents=512,
        latent_dim=512,
        cross_heads=1,
        latent_heads=8,
        cross_dim_head=64,
        latent_dim_head=64,
        num_classes=1000,
        attn_dropout=0.0,
        ff_dropout=0.0,
        weight_tie_layers=False,
        fourier_encode_data=True,
        self_per_cross_attn=1,
        final_classifier_head=True,
        pool="mean",
        latent_init=None,
    ):
        """The shape of the final attention mechanism will be:
        depth * (cross attention -> self_per_cross_attn * self attention)

        Args:
            num_freq_bands: Number of freq bands, with original value (2 * K + 1).
            depth: Depth of net.
            max_freq: Maximum frequency, hyperparameter depending on how fine the data is.
            input_channels: Number of channels for each token of the input.
            input_axis: Number of axes for input data (2 for images, 3 for video).
            num_latents: Number of latents, or induced set points, or centroids.
                Different papers give it different names.
            latent_dim: Latent dimension.
            cross_heads: Number of heads for cross attention. Paper said 1.
            latent_heads: Number of heads for latent self attention, 8.
            cross_dim_head: Number of dimensions per cross attention head.
            latent_dim_head: Number of dimensions per latent self attention head.
            num_classes: Output number of classes.
            attn_dropout: Attention dropout.
            ff_dropout: Feedforward dropout.
            weight_tie_layers: Whether to weight-tie layers (optional).
            fourier_encode_data: Whether to auto-Fourier-encode the data, using the
                input_axis given. Defaults to True, but can be turned off if you are
                Fourier-encoding the data yourself.
            self_per_cross_attn: Number of self attention blocks per cross attention.
            final_classifier_head: Pool and project embeddings to num_classes at the end.
            pool: Pooling used by the classifier head: "mean", "cat", or "mlp".
            latent_init: Optional path to a saved tensor used to initialize the latent array.
        """
        super().__init__()
        self.input_axis = input_axis
        self.max_freq = max_freq
        self.num_freq_bands = num_freq_bands
        self.self_per_cross_attn = self_per_cross_attn

        self.fourier_encode_data = fourier_encode_data
        fourier_channels = (
            (input_axis * ((num_freq_bands * 2) + 1)) * 2 if fourier_encode_data else 0
        )
        input_dim = fourier_channels + input_channels

        self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
        if latent_init is not None:
            latent_init_feat = torch.load(latent_init)
            if not isinstance(latent_init_feat, torch.Tensor):
                latent_init_feat = torch.Tensor(latent_init_feat)
            if len(latent_init_feat.shape) == 3:
                latent_init_feat = latent_init_feat[0]
            with torch.no_grad():
                self.latents.copy_(latent_init_feat)
            print(f"loaded latent features from {latent_init}")

        get_cross_attn = lambda: PreNorm(
            latent_dim,
            Attention(
                latent_dim,
                input_dim,
                heads=cross_heads,
                dim_head=cross_dim_head,
                dropout=attn_dropout,
            ),
            context_dim=input_dim,
        )
        get_cross_ff = lambda: PreNorm(
            latent_dim, FeedForward(latent_dim, dropout=ff_dropout)
        )
        get_latent_attn = lambda: PreNorm(
            latent_dim,
            Attention(
                latent_dim,
                heads=latent_heads,
                dim_head=latent_dim_head,
                dropout=attn_dropout,
            ),
        )
        get_latent_ff = lambda: PreNorm(
            latent_dim, FeedForward(latent_dim, dropout=ff_dropout)
        )
        get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff = map(
            cache_fn, (get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff)
        )

        self.layers = nn.ModuleList([])
        for i in range(depth):
            should_cache = i > 0 and weight_tie_layers
            cache_args = {"_cache": should_cache}

            self_attns = nn.ModuleList([])
            for block_ind in range(self_per_cross_attn):
                self_attns.append(
                    nn.ModuleList(
                        [
                            get_latent_attn(**cache_args, key=block_ind),
                            get_latent_ff(**cache_args, key=block_ind),
                        ]
                    )
                )
            if self_per_cross_attn == 0:
                # no latent self attention blocks: keep a single feedforward so the
                # latent stream is still transformed between cross attentions
                # (key fixed to 0 since the loop above never ran)
                self_attns.append(get_latent_ff(**cache_args, key=0))

            self.layers.append(
                nn.ModuleList(
                    [
                        get_cross_attn(**cache_args),
                        get_cross_ff(**cache_args),
                        self_attns,
                    ]
                )
            )

        if final_classifier_head:
            if pool == "cat":
                self.to_logits = nn.Sequential(
                    Rearrange("b n d -> b (n d)"),
                    nn.LayerNorm(num_latents * latent_dim),
                    nn.Linear(num_latents * latent_dim, num_classes),
                )
            elif pool == "mlp":
                self.to_logits = nn.Sequential(
                    Reduce("b n d -> b d", "mean"),
                    nn.LayerNorm(latent_dim),
                    nn.Linear(latent_dim, latent_dim),
                    nn.ReLU(),
                    nn.LayerNorm(latent_dim),
                    nn.Linear(latent_dim, num_classes),
                )
            else:
                self.to_logits = nn.Sequential(
                    Reduce("b n d -> b d", pool),
                    nn.LayerNorm(latent_dim),
                    nn.Linear(latent_dim, num_classes),
                )
        else:
            # no classifier head: return latent embeddings unchanged
            self.to_logits = nn.Identity()

    def forward(self, h, label=None, mask=None, pretrain=False, coords=None):
        b, *axis, _, device, dtype = *h.shape, h.device, h.dtype
        assert (
            len(axis) == self.input_axis
        ), "input data must have the right number of axes"

        if self.fourier_encode_data:
            # calculate fourier encoded positions in the range of [-1, 1], for all axis
            # axis_pos = list(map(lambda size: torch.linspace(-1., 1., steps=size, device=device, dtype=dtype), axis))
            # pos = torch.stack(torch.meshgrid(*axis_pos, indexing = 'ij'), dim = -1)
            # enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands)
            # enc_pos = rearrange(enc_pos, '... n d -> ... (n d)')
            # enc_pos = repeat(enc_pos, '... -> b ...', b = b)
            assert exists(coords), "coords must be provided when fourier_encode_data is True"
            enc_pos = fourier_encode(coords, self.max_freq, self.num_freq_bands)
            enc_pos = rearrange(enc_pos, "... n d -> ... (n d)")
            h = torch.cat((h, enc_pos), dim=-1)

        # concat to channels of data and flatten axis
        h = rearrange(h, "b ... d -> b (...) d")

        x = repeat(self.latents, "n d -> b n d", b=b)

        # layers
        for cross_attn, cross_ff, self_attns in self.layers:
            x = cross_attn(x, context=h, mask=mask) + x
            x = cross_ff(x) + x

            if self.self_per_cross_attn > 0:
                for self_attn, self_ff in self_attns:
                    x = self_attn(x) + x
                    x = self_ff(x) + x
            else:
                x = self_attns[0](x) + x

        # allow for fetching embeddings
        if pretrain:
            return x.mean(dim=1)

        # to logits
        logits = self.to_logits(x)
        Y_hat = torch.topk(logits, 1, dim=1)[1]
        Y_prob = F.softmax(logits, dim=1)
        return logits, Y_prob, Y_hat
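

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The shapes and hyperparameters
# below are assumptions, not values prescribed by this module: it treats the
# input as a bag of token features with normalized (x, y) coordinates, i.e.
# input_axis=1 and a coords tensor whose last dimension is 2 * input_axis,
# which matches the fourier_channels computation in __init__.
if __name__ == "__main__":
    model = Perceiver(
        num_freq_bands=6,
        depth=2,
        max_freq=10.0,
        input_channels=1024,  # assumed per-token feature dimension
        input_axis=1,
        num_latents=256,
        latent_dim=512,
        num_classes=2,
        pool="mean",
    )

    b, n = 2, 100                          # assumed: 2 bags of 100 tokens each
    h = torch.randn(b, n, 1024)            # token features
    coords = torch.rand(b, n, 2) * 2 - 1   # (x, y) positions in [-1, 1]

    logits, Y_prob, Y_hat = model(h, coords=coords)
    print(logits.shape, Y_prob.shape, Y_hat.shape)  # (2, 2), (2, 2), (2, 1)

    # pretrain=True skips the classifier head and returns mean-pooled latents
    emb = model(h, coords=coords, pretrain=True)
    print(emb.shape)  # (2, 512)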