from math import pi

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce
from einops.layers.torch import Rearrange
from einops_exts import rearrange_many
from torch import Tensor, einsum


def default(val, d):
    """Return `val` if it is not None, otherwise the fallback `d`."""
    return val if val is not None else d


class AdaLayerNorm(nn.Module):
    """Layer norm whose affine parameters (gamma, beta) are predicted from a style vector."""

    def __init__(self, style_dim, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps
        self.fc = nn.Linear(style_dim, channels * 2)

    def forward(self, x, s):
        x = x.transpose(-1, -2)
        x = x.transpose(1, -1)

        h = self.fc(s)
        h = h.view(h.size(0), h.size(1), 1)
        gamma, beta = torch.chunk(h, chunks=2, dim=1)
        gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)

        x = F.layer_norm(x, (self.channels,), eps=self.eps)
        x = (1 + gamma) * x + beta
        return x.transpose(1, -1).transpose(-1, -2)


class StyleTransformer1d(nn.Module):  # artificial_stylets / models.py
    def __init__(
        self,
        num_layers: int,
        channels: int,
        num_heads: int,
        head_features: int,
        multiplier: int,
        use_context_time: bool = True,
        use_rel_pos: bool = False,  # kept for interface compatibility; no relative positions are used
        context_features_multiplier: int = 1,
        context_features=None,
        context_embedding_features=None,
        embedding_max_length=512,
    ):
        super().__init__()

        self.blocks = nn.ModuleList(
            [
                StyleTransformerBlock(
                    features=channels + context_embedding_features,
                    head_features=head_features,
                    num_heads=num_heads,
                    multiplier=multiplier,
                    style_dim=context_features,
                    use_rel_pos=use_rel_pos,
                )
                for _ in range(num_layers)
            ]
        )

        self.to_out = nn.Sequential(
            Rearrange("b t c -> b c t"),
            nn.Conv1d(
                in_channels=channels + context_embedding_features,
                out_channels=channels,
                kernel_size=1,
            ),
        )

        use_context_features = context_features is not None
        self.use_context_features = use_context_features
        self.use_context_time = use_context_time

        if use_context_time or use_context_features:
            context_mapping_features = channels + context_embedding_features
            self.to_mapping = nn.Sequential(
                nn.Linear(context_mapping_features, context_mapping_features),
                nn.GELU(),
                nn.Linear(context_mapping_features, context_mapping_features),
                nn.GELU(),
            )

        if use_context_time:
            self.to_time = nn.Sequential(
                TimePositionalEmbedding(
                    dim=channels, out_features=context_mapping_features
                ),
                nn.GELU(),
            )

        if use_context_features:
            self.to_features = nn.Sequential(
                nn.Linear(
                    in_features=context_features,
                    out_features=context_mapping_features,
                ),
                nn.GELU(),
            )

        # Non-speaker-aware lookup: the embedding would index only the
        # time-frame position [0, 1, 2, ..., num_asr_time_frames].
        # self.fixed_embedding = FixedEmbedding(
        #     max_length=embedding_max_length, features=context_embedding_features
        # )

    def get_mapping(self, time=None, features=None):
        """Combine context time and context features into a single mapping vector."""
        items, mapping = [], None

        # Compute time features
        if self.use_context_time:
            items += [self.to_time(time)]

        # Compute context features
        if self.use_context_features:
            items += [self.to_features(features)]

        # Compute joint mapping
        if self.use_context_time or self.use_context_features:
            mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
            mapping = self.to_mapping(mapping)

        return mapping

    def forward(self, x, time, embedding=None, features=None):
        # Called by the enclosing model's forward().
        mapping = self.get_mapping(time, features)

        # Broadcast the latent over the embedding's time axis and concatenate.
        x = torch.cat([x.expand(-1, embedding.size(1), -1), embedding], dim=-1)
        mapping = mapping.unsqueeze(1).expand(-1, embedding.size(1), -1)

        for block in self.blocks:
            x = x + mapping
            x = block(x, features)

        x = x.mean(dim=1).unsqueeze(1)
        x = self.to_out(x)
        x = x.transpose(-1, -2)
        return x


class StyleTransformerBlock(nn.Module):
    def __init__(
        self,
        features: int,
        num_heads: int,
        head_features: int,
        style_dim: int,
        multiplier: int,
        use_rel_pos: bool,
        context_features=None,
    ):
        super().__init__()

        self.use_cross_attention = (context_features is not None) and (
            context_features > 0
        )

        self.attention = StyleAttention(
            features=features,
            style_dim=style_dim,
            num_heads=num_heads,
            head_features=head_features,
        )

        if self.use_cross_attention:
            # Cross-attention is not supported in this configuration.
            raise ValueError

        self.feed_forward = FeedForward(features=features, multiplier=multiplier)

    def forward(self, x: Tensor, s: Tensor, *, context=None) -> Tensor:
        x = self.attention(x, s) + x
        if self.use_cross_attention:
            raise ValueError
            # x = self.cross_attention(x, s, context=context) + x
        x = self.feed_forward(x) + x
        return x


class StyleAttention(nn.Module):
    def __init__(
        self,
        features: int,
        *,
        style_dim: int,
        head_features: int,
        num_heads: int,
        context_features=None,
    ):
        super().__init__()
        self.context_features = context_features
        mid_features = head_features * num_heads
        context_features = default(context_features, features)

        self.norm = AdaLayerNorm(style_dim, features)
        self.norm_context = AdaLayerNorm(style_dim, context_features)
        self.to_q = nn.Linear(
            in_features=features, out_features=mid_features, bias=False
        )
        self.to_kv = nn.Linear(
            in_features=context_features, out_features=mid_features * 2, bias=False
        )
        self.attention = AttentionBase(
            features, num_heads=num_heads, head_features=head_features
        )

    def forward(self, x, s, *, context=None):
        if context is not None:
            raise ValueError
        context = default(context, x)
        x, context = self.norm(x, s), self.norm_context(context, s)
        q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))
        return self.attention(q, k, v)


def FeedForward(features, multiplier):
    mid_features = features * multiplier
    return nn.Sequential(
        nn.Linear(in_features=features, out_features=mid_features),
        nn.GELU(),
        nn.Linear(in_features=mid_features, out_features=features),
    )


class AttentionBase(nn.Module):
    def __init__(self, features, *, head_features, num_heads, out_features=None):
        super().__init__()
        self.scale = head_features ** -0.5
        self.num_heads = num_heads
        mid_features = head_features * num_heads
        out_features = default(out_features, features)
        self.to_out = nn.Linear(in_features=mid_features, out_features=out_features)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        # Split heads
        q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads)

        # Compute similarity matrix (no relative positional bias is applied)
        sim = einsum("... n d, ... m d -> ... n m", q, k)
        sim = sim * self.scale

        # Get attention matrix with softmax
        attn = sim.softmax(dim=-1)

        # Compute values
        out = einsum("... n m, ... m d -> ... n d", attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)


class Attention(nn.Module):
    def __init__(
        self,
        features,
        *,
        head_features,
        num_heads,
        out_features=None,
        context_features=None,
    ):
        super().__init__()
        self.context_features = context_features
        mid_features = head_features * num_heads
        context_features = default(context_features, features)

        self.norm = nn.LayerNorm(features)
        self.norm_context = nn.LayerNorm(context_features)
        self.to_q = nn.Linear(
            in_features=features, out_features=mid_features, bias=False
        )
        self.to_kv = nn.Linear(
            in_features=context_features, out_features=mid_features * 2, bias=False
        )
        self.attention = AttentionBase(
            features,
            out_features=out_features,
            num_heads=num_heads,
            head_features=head_features,
        )

    def forward(self, x: Tensor, *, context=None) -> Tensor:
        # Use the context if provided, otherwise self-attend
        context = default(context, x)
        # Normalize, then compute q from the input and k, v from the context
        x, context = self.norm(x), self.norm_context(context)
        q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))
        # Compute and return attention
        return self.attention(q, k, v)


class LearnedPositionalEmbedding(nn.Module):
    """Used for continuous time."""

    def __init__(self, dim: int):
        super().__init__()
        assert (dim % 2) == 0
        half_dim = dim // 2
        self.weights = nn.Parameter(torch.randn(half_dim))

    def forward(self, x: Tensor) -> Tensor:
        x = rearrange(x, "b -> b 1")
        freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
        fouriered = torch.cat((x, fouriered), dim=-1)
        return fouriered


def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
    return nn.Sequential(
        LearnedPositionalEmbedding(dim),
        nn.Linear(in_features=dim + 1, out_features=out_features),
    )