from functools import wraps
from math import log, pi
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce
from torch import einsum, nn
def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d
def cache_fn(f):
    cache = dict()

    @wraps(f)
    def cached_fn(*args, _cache=True, key=None, **kwargs):
        if not _cache:
            return f(*args, **kwargs)
        nonlocal cache
        if key in cache:
            return cache[key]
        result = f(*args, **kwargs)
        cache[key] = result
        return result

    return cached_fn
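
# cache_fn memoises a module constructor: the first call with a given `key` and
# `_cache=True` builds and stores the module, and later calls with the same key
# return the stored instance. Perceiver uses this below so that, when
# `weight_tie_layers=True`, every layer after the first reuses (weight-ties) the
# cached cross- and self-attention blocks.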
def fourier_encode(x, max_freq, num_bands=4):
    x = x.unsqueeze(-1)
    device, dtype, orig_x = x.device, x.dtype, x
    scales = torch.linspace(1.0, max_freq / 2, num_bands, device=device, dtype=dtype)
    scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)]
    x = x * scales * pi
    x = torch.cat([x.sin(), x.cos()], dim=-1)
    x = torch.cat((x, orig_x), dim=-1)
    return x
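
# Shape note: fourier_encode expands each scalar coordinate into
# 2 * num_bands + 1 features, i.e. an input of shape (...,) becomes
# (..., 2 * num_bands + 1): sin and cos at `num_bands` linearly spaced
# frequencies in [1, max_freq / 2], concatenated with the raw coordinate.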
class PreNorm(nn.Module):
    def __init__(self, dim, fn, context_dim=None):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)
        self.norm_context = nn.LayerNorm(context_dim) if exists(context_dim) else None

    def forward(self, x, **kwargs):
        x = self.norm(x)
        if exists(self.norm_context):
            context = kwargs["context"]
            normed_context = self.norm_context(context)
            kwargs.update(context=normed_context)
        return self.fn(x, **kwargs)
class GEGLU(nn.Module):
    def forward(self, x):
        x, gates = x.chunk(2, dim=-1)
        return x * F.gelu(gates)
class FeedForward(nn.Module):
    def __init__(self, dim, mult=4, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult * 2),
            GEGLU(),
            nn.Linear(dim * mult, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
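
# Shape note: the first Linear widens dim -> dim * mult * 2 so that GEGLU can
# split it into a value half and a gate half (value * GELU(gate)), leaving
# dim * mult features for the second Linear to project back down to dim.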
class Attention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)
        self.scale = dim_head**-0.5
        self.heads = heads
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.to_out = nn.Linear(inner_dim, query_dim)

    def forward(self, x, context=None, mask=None):
        h = self.heads
        q = self.to_q(x)
        context = default(context, x)
        k, v = self.to_kv(context).chunk(2, dim=-1)
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
        sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
        if exists(mask):
            mask = rearrange(mask, "b ... -> b (...)")
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, "b j -> (b h) () j", h=h)
            sim.masked_fill_(~mask, max_neg_value)
        # attention, what we cannot get enough of
        attn = sim.softmax(dim=-1)
        attn = self.dropout(attn)
        out = einsum("b i j, b j d -> b i d", attn, v)
        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
        return self.to_out(out)
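
# Shape note: given queries x of shape (b, n, query_dim) and an optional context
# of shape (b, m, context_dim), Attention returns (b, n, query_dim). With no
# context it reduces to self-attention; an optional boolean mask of shape (b, m)
# (True = keep) hides context positions before the softmax.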
class Perceiver(nn.Module):
    def __init__(
        self,
        *,
        num_freq_bands,
        depth,
        max_freq,
        input_channels=3,
        input_axis=2,
        num_latents=512,
        latent_dim=512,
        cross_heads=1,
        latent_heads=8,
        cross_dim_head=64,
        latent_dim_head=64,
        num_classes=1000,
        attn_dropout=0.0,
        ff_dropout=0.0,
        weight_tie_layers=False,
        fourier_encode_data=True,
        self_per_cross_attn=1,
        final_classifier_head=True,
        pool="mean",
        latent_init=None,
    ):
"""The shape of the final attention mechanism will be:
depth * (cross attention -> self_per_cross_attn * self attention)
Args:
num_freq_bands: Number of freq bands, with original value (2 * K + 1)
depth: Depth of net.
max_freq: Maximum frequency, hyperparameter depending on how
fine the data is.
freq_base: Base for the frequency
input_channels: Number of channels for each token of the input.
input_axis: Number of axes for input data (2 for images, 3 for video)
num_latents: Number of latents, or induced set points, or centroids.
Different papers giving it different names.
latent_dim: Latent dimension.
cross_heads: Number of heads for cross attention. Paper said 1.
latent_heads: Number of heads for latent self attention, 8.
cross_dim_head: Number of dimensions per cross attention head.
latent_dim_head: Number of dimensions per latent self attention head.
num_classes: Output number of classes.
attn_dropout: Attention dropout
ff_dropout: Feedforward dropout
weight_tie_layers: Whether to weight tie layers (optional).
fourier_encode_data: Whether to auto-fourier encode the data, using
the input_axis given. defaults to True, but can be turned off
if you are fourier encoding the data yourself.
self_per_cross_attn: Number of self attention blocks per cross attn.
final_classifier_head: mean pool and project embeddings to number of classes (num_classes) at the end
"""
        super().__init__()
        self.input_axis = input_axis
        self.max_freq = max_freq
        self.num_freq_bands = num_freq_bands
        self.self_per_cross_attn = self_per_cross_attn
        self.fourier_encode_data = fourier_encode_data
        fourier_channels = (
            (input_axis * ((num_freq_bands * 2) + 1)) * 2 if fourier_encode_data else 0
        )
        input_dim = fourier_channels + input_channels
        self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
        if latent_init is not None:
            latent_init_feat = torch.load(latent_init)
            if not isinstance(latent_init_feat, torch.Tensor):
                latent_init_feat = torch.as_tensor(latent_init_feat)
            if latent_init_feat.dim() == 3:
                latent_init_feat = latent_init_feat[0]
            with torch.no_grad():
                self.latents.copy_(latent_init_feat)
            print(f"loaded latent initialisation from {latent_init}")
        get_cross_attn = lambda: PreNorm(
            latent_dim,
            Attention(
                latent_dim,
                input_dim,
                heads=cross_heads,
                dim_head=cross_dim_head,
                dropout=attn_dropout,
            ),
            context_dim=input_dim,
        )
        get_cross_ff = lambda: PreNorm(
            latent_dim, FeedForward(latent_dim, dropout=ff_dropout)
        )
        get_latent_attn = lambda: PreNorm(
            latent_dim,
            Attention(
                latent_dim,
                heads=latent_heads,
                dim_head=latent_dim_head,
                dropout=attn_dropout,
            ),
        )
        get_latent_ff = lambda: PreNorm(
            latent_dim, FeedForward(latent_dim, dropout=ff_dropout)
        )
        get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff = map(
            cache_fn, (get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff)
        )
        self.layers = nn.ModuleList([])
        for i in range(depth):
            should_cache = i > 0 and weight_tie_layers
            cache_args = {"_cache": should_cache}
            self_attns = nn.ModuleList([])
            for block_ind in range(self_per_cross_attn):
                self_attns.append(
                    nn.ModuleList(
                        [
                            get_latent_attn(**cache_args, key=block_ind),
                            get_latent_ff(**cache_args, key=block_ind),
                        ]
                    )
                )
            if self_per_cross_attn == 0:
                # with no self-attention blocks, fall back to a single latent
                # feed-forward; block_ind is never bound above, so use key 0
                self_attns.append(get_latent_ff(**cache_args, key=0))
            self.layers.append(
                nn.ModuleList(
                    [
                        get_cross_attn(**cache_args),
                        get_cross_ff(**cache_args),
                        self_attns,
                    ]
                )
            )
        if final_classifier_head:
            if pool == "cat":
                self.to_logits = nn.Sequential(
                    Rearrange("b n d -> b (n d)"),
                    nn.LayerNorm(num_latents * latent_dim),
                    nn.Linear(num_latents * latent_dim, num_classes),
                )
            elif pool == "mlp":
                self.to_logits = nn.Sequential(
                    Reduce("b n d -> b d", "mean"),
                    nn.LayerNorm(latent_dim),
                    nn.Linear(latent_dim, latent_dim),
                    nn.ReLU(),
                    nn.LayerNorm(latent_dim),
                    nn.Linear(latent_dim, num_classes),
                )
            else:
                self.to_logits = nn.Sequential(
                    Reduce("b n d -> b d", pool),
                    nn.LayerNorm(latent_dim),
                    nn.Linear(latent_dim, num_classes),
                )
        else:
            # without a classifier head, pass the latents through unchanged
            self.to_logits = nn.Identity()
    def forward(self, h, label=None, mask=None, pretrain=False, coords=None):
        b, *axis, _, device, dtype = *h.shape, h.device, h.dtype
        assert (
            len(axis) == self.input_axis
        ), "input data must have the right number of axes"
        if self.fourier_encode_data:
            # fourier-encode the provided coordinates (typically scaled to [-1, 1])
            # and append them to the token features; the commented-out block below
            # computes grid-based positions for all axes instead
            # axis_pos = list(map(lambda size: torch.linspace(-1., 1., steps=size, device=device, dtype=dtype), axis))
            # pos = torch.stack(torch.meshgrid(*axis_pos, indexing = 'ij'), dim = -1)
            # enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands)
            # enc_pos = rearrange(enc_pos, '... n d -> ... (n d)')
            # enc_pos = repeat(enc_pos, '... -> b ...', b = b)
            enc_pos = fourier_encode(coords, self.max_freq, self.num_freq_bands)
            enc_pos = rearrange(enc_pos, "... n d -> ... (n d)")
            h = torch.cat((h, enc_pos), dim=-1)
        # flatten the spatial axes into a single token axis
        h = rearrange(h, "b ... d -> b (...) d")
        x = repeat(self.latents, "n d -> b n d", b=b)
        # layers
        for cross_attn, cross_ff, self_attns in self.layers:
            x = cross_attn(x, context=h, mask=mask) + x
            x = cross_ff(x) + x
            if self.self_per_cross_attn > 0:
                for self_attn, self_ff in self_attns:
                    x = self_attn(x) + x
                    x = self_ff(x) + x
            else:
                x = self_attns[0](x) + x
        # allow for fetching embeddings
        if pretrain:
            return x.mean(dim=1)
        # to logits
        logits = self.to_logits(x)
        Y_hat = torch.topk(logits, 1, dim=1)[1]
        Y_prob = F.softmax(logits, dim=1)
        return logits, Y_prob, Y_hat
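

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the model definition). It
# assumes a bag-of-tokens setup with input_axis=1 and 2-D (x, y) coordinates
# per token, which matches the fourier-channel formula in __init__ (coords are
# expected to carry 2 * input_axis components per token). All sizes below are
# arbitrary.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = Perceiver(
        num_freq_bands=6,
        depth=2,
        max_freq=10.0,
        input_channels=1024,  # pre-extracted per-token feature dimension
        input_axis=1,
        num_latents=64,
        latent_dim=128,
        num_classes=2,
    )
    feats = torch.randn(1, 1000, 1024)       # (batch, tokens, input_channels)
    coords = torch.rand(1, 1000, 2) * 2 - 1  # (batch, tokens, 2), scaled to [-1, 1]
    logits, Y_prob, Y_hat = model(feats, coords=coords)
    print(logits.shape, Y_prob.shape, Y_hat.shape)  # (1, 2) (1, 2) (1, 1)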