|
from math import pi

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce
from einops.layers.torch import Rearrange
from einops_exts import rearrange_many
from torch import Tensor, einsum
|
|
|
|
def default(val, d):
    """Return `val` if it is not None, otherwise the fallback `d`."""
    if val is not None:
        return val
    return d
|
|
|
class AdaLayerNorm(nn.Module):
    """Layer normalization whose affine parameters are predicted from a style vector."""

    def __init__(self, style_dim, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        # Predicts a per-channel gamma and beta from the style vector.
        self.fc = nn.Linear(style_dim, channels * 2)

    def forward(self, x, s):
        x = x.transpose(-1, -2)
        x = x.transpose(1, -1)

        h = self.fc(s)
        h = h.view(h.size(0), h.size(1), 1)
        gamma, beta = torch.chunk(h, chunks=2, dim=1)
        gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)

        # Normalize over the channel dimension, then apply the
        # style-conditioned affine transform.
        x = F.layer_norm(x, (self.channels,), eps=self.eps)
        x = (1 + gamma) * x + beta
        return x.transpose(1, -1).transpose(-1, -2)
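
# A minimal shape sketch (the helper below and its dimensions are illustrative
# assumptions, not part of the model): AdaLayerNorm maps activations of shape
# (batch, time, channels) plus a style vector of shape (batch, style_dim) to a
# tensor of the same shape as the input.
def _adalayernorm_shape_example():
    norm = AdaLayerNorm(style_dim=128, channels=64)
    x = torch.randn(2, 50, 64)   # (batch, time, channels)
    s = torch.randn(2, 128)      # (batch, style_dim)
    y = norm(x, s)
    assert y.shape == x.shape
    return y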
|
|
|
class StyleTransformer1d(nn.Module):
    """Transformer over 1d sequences, conditioned on a style vector and a continuous time embedding."""

    def __init__(
        self,
        num_layers: int,
        channels: int,
        num_heads: int,
        head_features: int,
        multiplier: int,
        use_context_time: bool = True,
        use_rel_pos: bool = False,
        context_features_multiplier: int = 1,
        context_features=None,
        context_embedding_features=None,
        embedding_max_length=512,
    ):
        super().__init__()

        self.blocks = nn.ModuleList(
            [
                StyleTransformerBlock(
                    features=channels + context_embedding_features,
                    head_features=head_features,
                    num_heads=num_heads,
                    multiplier=multiplier,
                    style_dim=context_features,
                    use_rel_pos=use_rel_pos,
                )
                for _ in range(num_layers)
            ]
        )

        self.to_out = nn.Sequential(
            Rearrange("b t c -> b c t"),
            nn.Conv1d(
                in_channels=channels + context_embedding_features,
                out_channels=channels,
                kernel_size=1,
            ),
        )

        use_context_features = context_features is not None
        self.use_context_features = use_context_features
        self.use_context_time = use_context_time

        if use_context_time or use_context_features:
            context_mapping_features = channels + context_embedding_features

            self.to_mapping = nn.Sequential(
                nn.Linear(context_mapping_features, context_mapping_features),
                nn.GELU(),
                nn.Linear(context_mapping_features, context_mapping_features),
                nn.GELU(),
            )

        if use_context_time:
            self.to_time = nn.Sequential(
                TimePositionalEmbedding(
                    dim=channels, out_features=context_mapping_features
                ),
                nn.GELU(),
            )

        if use_context_features:
            self.to_features = nn.Sequential(
                nn.Linear(
                    in_features=context_features,
                    out_features=context_mapping_features,
                ),
                nn.GELU(),
            )

    def get_mapping(self, time=None, features=None):
        """Combines the time embedding and the context features into a single mapping."""
        items, mapping = [], None

        if self.use_context_time:
            items += [self.to_time(time)]

        if self.use_context_features:
            items += [self.to_features(features)]

        if self.use_context_time or self.use_context_features:
            mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
            mapping = self.to_mapping(mapping)

        return mapping

    def forward(self, x, time, embedding=None, features=None):
        mapping = self.get_mapping(time, features)

        # Broadcast x along the embedding length and concatenate along the
        # channel dimension; the mapping is added to every position.
        x = torch.cat([x.expand(-1, embedding.size(1), -1), embedding], dim=-1)
        mapping = mapping.unsqueeze(1).expand(-1, embedding.size(1), -1)

        for block in self.blocks:
            x = x + mapping
            x = block(x, features)

        # Pool over time, project back to `channels`, and restore the
        # (batch, time, channels) layout.
        x = x.mean(dim=1).unsqueeze(1)
        x = self.to_out(x)
        x = x.transpose(-1, -2)
        return x
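
# A hedged end-to-end sketch (the helper below and every hyperparameter in it
# are illustrative assumptions, not values from the original configuration):
# x is a (batch, 1, channels) signal, `time` a batch of continuous timesteps,
# `embedding` a (batch, tokens, context_embedding_features) context, and
# `features` a (batch, context_features) style vector.
def _style_transformer_example():
    model = StyleTransformer1d(
        num_layers=2,
        channels=64,
        num_heads=4,
        head_features=16,
        multiplier=2,
        context_features=128,
        context_embedding_features=32,
    )
    x = torch.randn(2, 1, 64)           # (batch, 1, channels)
    time = torch.rand(2)                # continuous time in [0, 1]
    embedding = torch.randn(2, 10, 32)  # (batch, tokens, context_embedding_features)
    features = torch.randn(2, 128)      # (batch, context_features)
    out = model(x, time, embedding=embedding, features=features)
    return out                          # (batch, 1, channels)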
|
|
|
|
|
class StyleTransformerBlock(nn.Module):
    def __init__(
        self,
        features: int,
        num_heads: int,
        head_features: int,
        style_dim: int,
        multiplier: int,
        use_rel_pos: bool,
        context_features=None,
    ):
        super().__init__()

        self.use_cross_attention = (context_features is not None) and (
            context_features > 0
        )

        self.attention = StyleAttention(
            features=features,
            style_dim=style_dim,
            num_heads=num_heads,
            head_features=head_features,
        )

        if self.use_cross_attention:
            raise ValueError("cross-attention is not supported in this implementation")

        self.feed_forward = FeedForward(features=features, multiplier=multiplier)

    def forward(self, x: Tensor, s: Tensor, *, context=None) -> Tensor:
        x = self.attention(x, s) + x
        if self.use_cross_attention:
            raise ValueError("cross-attention is not supported in this implementation")
        x = self.feed_forward(x) + x
        return x
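
# Illustrative single-block sketch (the helper and all sizes below are
# assumptions): the block takes activations of shape (batch, time, features)
# and a style vector of shape (batch, style_dim), and preserves the
# activation shape.
def _style_block_example():
    block = StyleTransformerBlock(
        features=96,
        num_heads=4,
        head_features=16,
        style_dim=128,
        multiplier=2,
        use_rel_pos=False,
    )
    x = torch.randn(2, 10, 96)   # (batch, time, features)
    s = torch.randn(2, 128)      # (batch, style_dim)
    return block(x, s)           # (2, 10, 96)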
|
|
|
class StyleAttention(nn.Module):
    def __init__(
        self,
        features: int,
        *,
        style_dim: int,
        head_features: int,
        num_heads: int,
        context_features=None,
    ):
        super().__init__()
        self.context_features = context_features
        mid_features = head_features * num_heads
        context_features = default(context_features, features)

        self.norm = AdaLayerNorm(style_dim, features)
        self.norm_context = AdaLayerNorm(style_dim, context_features)
        self.to_q = nn.Linear(
            in_features=features, out_features=mid_features, bias=False
        )
        self.to_kv = nn.Linear(
            in_features=context_features, out_features=mid_features * 2, bias=False
        )
        self.attention = AttentionBase(
            features, num_heads=num_heads, head_features=head_features
        )

    def forward(self, x, s, *, context=None):
        if context is not None:
            raise ValueError("external context is not supported; self-attention only")
        context = default(context, x)

        # Style-conditioned normalization of the query and key/value inputs.
        x, context = self.norm(x, s), self.norm_context(context, s)

        q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))

        return self.attention(q, k, v)
|
|
|
|
|
def FeedForward(features, multiplier):
    mid_features = features * multiplier
    return nn.Sequential(
        nn.Linear(in_features=features, out_features=mid_features),
        nn.GELU(),
        nn.Linear(in_features=mid_features, out_features=features),
    )
|
|
|
|
|
class AttentionBase(nn.Module):
    def __init__(
        self,
        features,
        *,
        head_features,
        num_heads,
        out_features=None,
    ):
        super().__init__()
        self.scale = head_features ** -0.5
        self.num_heads = num_heads
        mid_features = head_features * num_heads
        out_features = default(out_features, features)
        self.to_out = nn.Linear(in_features=mid_features, out_features=out_features)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        # Split heads: (b, n, h * d) -> (b, h, n, d).
        q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads)

        # Scaled dot-product attention.
        sim = einsum("... n d, ... m d -> ... n m", q, k)
        sim = sim * self.scale
        attn = sim.softmax(dim=-1)

        out = einsum("... n m, ... m d -> ... n d", attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)
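
# For reference, the attention computed above is plain scaled dot-product
# attention. A small equivalence sketch (the helper below is an assumption for
# illustration and requires PyTorch 2.x, where F.scaled_dot_product_attention
# is available):
def _attention_base_equivalence_sketch():
    base = AttentionBase(features=64, head_features=16, num_heads=4)
    q, k, v = (torch.randn(2, 10, 64) for _ in range(3))
    out = base(q, k, v)

    qh, kh, vh = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=4)
    fused = F.scaled_dot_product_attention(qh, kh, vh)
    fused = base.to_out(rearrange(fused, "b h n d -> b n (h d)"))

    # The two paths should agree up to numerical tolerance.
    assert torch.allclose(out, fused, atol=1e-4)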
|
|
|
|
|
class Attention(nn.Module):
    def __init__(
        self,
        features,
        *,
        head_features,
        num_heads,
        out_features=None,
        context_features=None,
    ):
        super().__init__()
        self.context_features = context_features
        mid_features = head_features * num_heads
        context_features = default(context_features, features)

        self.norm = nn.LayerNorm(features)
        self.norm_context = nn.LayerNorm(context_features)
        self.to_q = nn.Linear(
            in_features=features, out_features=mid_features, bias=False
        )
        self.to_kv = nn.Linear(
            in_features=context_features, out_features=mid_features * 2, bias=False
        )

        self.attention = AttentionBase(
            features,
            out_features=out_features,
            num_heads=num_heads,
            head_features=head_features,
        )

    def forward(self, x: Tensor, *, context=None) -> Tensor:
        # Self-attention by default; cross-attention when `context` is given.
        context = default(context, x)

        x, context = self.norm(x), self.norm_context(context)
        q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))

        return self.attention(q, k, v)
|
|
|
|
|
class LearnedPositionalEmbedding(nn.Module):
    """Used for continuous time: Fourier features with learned frequencies."""

    def __init__(self, dim: int):
        super().__init__()
        assert (dim % 2) == 0
        half_dim = dim // 2
        self.weights = nn.Parameter(torch.randn(half_dim))

    def forward(self, x: Tensor) -> Tensor:
        x = rearrange(x, "b -> b 1")
        freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
        fouriered = torch.cat((x, fouriered), dim=-1)
        return fouriered
|
|
|
|
|
def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
    return nn.Sequential(
        LearnedPositionalEmbedding(dim),
        nn.Linear(in_features=dim + 1, out_features=out_features),
    )
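
# A small usage sketch for the time embedding (the helper and its dimensions
# are illustrative assumptions): a batch of scalar timesteps is lifted to a
# feature vector of size `out_features`.
def _time_embedding_example():
    emb = TimePositionalEmbedding(dim=64, out_features=96)
    t = torch.rand(8)    # (batch,) continuous timesteps
    return emb(t)        # (batch, 96)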
|
|
|
|