from math import pi

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce
from einops.layers.torch import Rearrange
from einops_exts import rearrange_many
from torch import Tensor, einsum

def default(val, d):
    # Return `val` if it is set, otherwise fall back to the default `d`.
    if val is not None:
        return val
    return d

class AdaLayerNorm(nn.Module):
    """Layer norm whose affine parameters (gamma, beta) are predicted from a style vector."""

    def __init__(self, style_dim, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps
        self.fc = nn.Linear(style_dim, channels * 2)

    def forward(self, x, s):
        x = x.transpose(-1, -2)
        x = x.transpose(1, -1)
        h = self.fc(s)
        h = h.view(h.size(0), h.size(1), 1)
        gamma, beta = torch.chunk(h, chunks=2, dim=1)
        gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
        x = F.layer_norm(x, (self.channels,), eps=self.eps)
        x = (1 + gamma) * x + beta
        return x.transpose(1, -1).transpose(-1, -2)
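
# Usage sketch (not part of the original module): AdaLayerNorm normalizes over the
# channel axis and applies a gamma/beta predicted from a per-utterance style vector.
# The shapes below are assumptions chosen only for illustration.
def _example_ada_layer_norm():
    norm = AdaLayerNorm(style_dim=128, channels=512)
    x = torch.randn(2, 100, 512)   # (batch, frames, channels)
    s = torch.randn(2, 128)        # one style vector per batch item
    y = norm(x, s)
    assert y.shape == x.shape
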

class StyleTransformer1d(nn.Module):
    # artificial_stylets / models.py

    def __init__(
        self,
        num_layers: int,
        channels: int,
        num_heads: int,
        head_features: int,
        multiplier: int,
        use_context_time: bool = True,
        use_rel_pos: bool = False,
        context_features_multiplier: int = 1,
        context_features=None,
        context_embedding_features=None,
        embedding_max_length=512,
    ):
        super().__init__()
        self.blocks = nn.ModuleList(
            [
                StyleTransformerBlock(
                    features=channels + context_embedding_features,
                    head_features=head_features,
                    num_heads=num_heads,
                    multiplier=multiplier,
                    style_dim=context_features,
                    use_rel_pos=use_rel_pos,
                    # relative-position options (rel_pos_num_buckets / rel_pos_max_distance) removed
                )
                for _ in range(num_layers)
            ]
        )
        self.to_out = nn.Sequential(
            Rearrange("b t c -> b c t"),
            nn.Conv1d(
                in_channels=channels + context_embedding_features,
                out_channels=channels,
                kernel_size=1,
            ),
        )
        use_context_features = context_features is not None
        self.use_context_features = use_context_features
        self.use_context_time = use_context_time

        if use_context_time or use_context_features:
            # In practice both conditioning signals are enabled; they share one mapping MLP.
            context_mapping_features = channels + context_embedding_features
            self.to_mapping = nn.Sequential(
                nn.Linear(context_mapping_features, context_mapping_features),
                nn.GELU(),
                nn.Linear(context_mapping_features, context_mapping_features),
                nn.GELU(),
            )

        if use_context_time:
            self.to_time = nn.Sequential(
                TimePositionalEmbedding(
                    dim=channels, out_features=context_mapping_features
                ),
                nn.GELU(),
            )

        if use_context_features:
            self.to_features = nn.Sequential(
                nn.Linear(
                    in_features=context_features, out_features=context_mapping_features
                ),
                nn.GELU(),
            )

        # FixedEmbedding removed: it was a speaker-blind lookup indexed only by the
        # time-frame position [0, 1, 2, ..., num-asr-time-frames].

    def get_mapping(self, time=None, features=None):
        """Combine the diffusion-time embedding and the style features into one mapping vector."""
        items, mapping = [], None
        # Compute time features
        if self.use_context_time:
            items += [self.to_time(time)]
        # Compute style features
        if self.use_context_features:
            items += [self.to_features(features)]
        # Sum whichever conditioning signals are present and project them jointly
        if self.use_context_time or self.use_context_features:
            mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
            mapping = self.to_mapping(mapping)
        return mapping

    def forward(self, x, time, embedding=None, features=None):
        mapping = self.get_mapping(time, features)
        # Broadcast the single query frame of x over the embedding's time axis,
        # then concatenate along channels.
        x = torch.cat([x.expand(-1, embedding.size(1), -1), embedding], dim=-1)
        mapping = mapping.unsqueeze(1).expand(-1, embedding.size(1), -1)
        for block in self.blocks:
            x = x + mapping
            x = block(x, features)
        x = x.mean(dim=1).unsqueeze(1)
        x = self.to_out(x)
        x = x.transpose(-1, -2)
        return x
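
# Usage sketch (not part of the original module): the sizes below are assumptions,
# chosen only to illustrate the shapes StyleTransformer1d appears to expect: a
# single query frame x, a continuous time step, a per-frame conditioning embedding,
# and a per-utterance style/feature vector.
def _example_style_transformer1d():
    channels, ctx_emb, ctx_feat = 256, 768, 128
    model = StyleTransformer1d(
        num_layers=2,
        channels=channels,
        num_heads=4,
        head_features=64,
        multiplier=2,
        context_features=ctx_feat,
        context_embedding_features=ctx_emb,
    )
    x = torch.randn(2, 1, channels)            # query frame
    time = torch.rand(2)                       # continuous diffusion time
    embedding = torch.randn(2, 120, ctx_emb)   # per-frame conditioning embedding
    features = torch.randn(2, ctx_feat)        # per-utterance style vector
    out = model(x, time, embedding=embedding, features=features)
    assert out.shape == (2, 1, channels)
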

class StyleTransformerBlock(nn.Module):
    def __init__(
        self,
        features: int,
        num_heads: int,
        head_features: int,
        style_dim: int,
        multiplier: int,
        use_rel_pos: bool,  # kept for API compatibility; no relative-position bias in this variant
        context_features=None,
    ):
        super().__init__()
        self.use_cross_attention = (context_features is not None) and (context_features > 0)
        self.attention = StyleAttention(
            features=features,
            style_dim=style_dim,
            num_heads=num_heads,
            head_features=head_features,
        )
        if self.use_cross_attention:
            # The cross-attention branch was removed from this variant.
            raise ValueError
        self.feed_forward = FeedForward(features=features, multiplier=multiplier)

    def forward(self, x: Tensor, s: Tensor, *, context=None) -> Tensor:
        x = self.attention(x, s) + x
        if self.use_cross_attention:
            raise ValueError
        x = self.feed_forward(x) + x
        return x

class StyleAttention(nn.Module):
    def __init__(
        self,
        features: int,
        *,
        style_dim: int,
        head_features: int,
        num_heads: int,
        context_features=None,
    ):
        super().__init__()
        self.context_features = context_features
        mid_features = head_features * num_heads
        context_features = default(context_features, features)
        self.norm = AdaLayerNorm(style_dim, features)
        self.norm_context = AdaLayerNorm(style_dim, context_features)
        self.to_q = nn.Linear(
            in_features=features, out_features=mid_features, bias=False
        )
        self.to_kv = nn.Linear(
            in_features=context_features, out_features=mid_features * 2, bias=False
        )
        self.attention = AttentionBase(
            features,
            num_heads=num_heads,
            head_features=head_features,
        )

    def forward(self, x, s, *, context=None):
        if context is not None:
            # External cross-attention context is not supported in this variant.
            raise ValueError
        context = default(context, x)
        x, context = self.norm(x, s), self.norm_context(context, s)
        q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))
        return self.attention(q, k, v)

def FeedForward(features, multiplier):
    mid_features = features * multiplier
    return nn.Sequential(
        nn.Linear(in_features=features, out_features=mid_features),
        nn.GELU(),
        nn.Linear(in_features=mid_features, out_features=features),
    )

class AttentionBase(nn.Module):
    def __init__(
        self,
        features,
        *,
        head_features,
        num_heads,
        out_features=None,  # optional; Attention below passes this through
    ):
        super().__init__()
        self.scale = head_features ** -0.5
        self.num_heads = num_heads
        mid_features = head_features * num_heads
        out_features = default(out_features, features)
        self.to_out = nn.Linear(in_features=mid_features, out_features=out_features)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        # Split heads
        q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads)
        # Compute similarity matrix (no relative-position bias in this variant)
        sim = einsum("... n d, ... m d -> ... n m", q, k)
        sim = sim * self.scale
        # Get attention matrix with softmax
        attn = sim.softmax(dim=-1)
        # Compute values
        out = einsum("... n m, ... m d -> ... n d", attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)
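
# Usage sketch (not part of the original module): AttentionBase is plain multi-head
# scaled dot-product attention over already-projected q/k/v; sizes are assumed.
def _example_attention_base():
    num_heads, head_features, features = 4, 64, 512
    attn = AttentionBase(features, head_features=head_features, num_heads=num_heads)
    q = torch.randn(2, 10, num_heads * head_features)   # (batch, queries, heads*dim)
    k = torch.randn(2, 20, num_heads * head_features)   # (batch, keys, heads*dim)
    v = torch.randn(2, 20, num_heads * head_features)
    out = attn(q, k, v)
    assert out.shape == (2, 10, features)
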

class Attention(nn.Module):
    def __init__(
        self,
        features,
        *,
        head_features,
        num_heads,
        out_features=None,
        context_features=None,
    ):
        super().__init__()
        self.context_features = context_features
        mid_features = head_features * num_heads
        context_features = default(context_features, features)
        self.norm = nn.LayerNorm(features)
        self.norm_context = nn.LayerNorm(context_features)
        self.to_q = nn.Linear(
            in_features=features, out_features=mid_features, bias=False
        )
        self.to_kv = nn.Linear(
            in_features=context_features, out_features=mid_features * 2, bias=False
        )
        self.attention = AttentionBase(
            features,
            out_features=out_features,
            num_heads=num_heads,
            head_features=head_features,
        )

    def forward(self, x: Tensor, *, context=None) -> Tensor:
        # A context should be provided when context_features is set; otherwise self-attend.
        context = default(context, x)
        # Normalize, then compute q from the input and k, v from the context
        x, context = self.norm(x), self.norm_context(context)
        q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))
        # Compute and return attention
        return self.attention(q, k, v)
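
# Usage sketch (not part of the original module): with `context` given, Attention
# acts as cross-attention (queries from x, keys/values from context); without it,
# it self-attends. Sizes are assumed for illustration.
def _example_attention():
    attn = Attention(features=512, head_features=64, num_heads=4, context_features=256)
    x = torch.randn(2, 10, 512)
    context = torch.randn(2, 32, 256)
    out = attn(x, context=context)
    assert out.shape == (2, 10, 512)
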

class LearnedPositionalEmbedding(nn.Module):
    """Used for continuous time"""

    def __init__(self, dim: int):
        super().__init__()
        assert (dim % 2) == 0
        half_dim = dim // 2
        self.weights = nn.Parameter(torch.randn(half_dim))

    def forward(self, x: Tensor) -> Tensor:
        x = rearrange(x, "b -> b 1")
        freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
        fouriered = torch.cat((x, fouriered), dim=-1)
        return fouriered


def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
    return nn.Sequential(
        LearnedPositionalEmbedding(dim),
        nn.Linear(in_features=dim + 1, out_features=out_features),
    )
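
# Usage sketch (not part of the original module): a batch of continuous time values
# is lifted to learned Fourier features and projected to the mapping width; the
# sizes below are assumptions.
def _example_time_embedding():
    emb = TimePositionalEmbedding(dim=256, out_features=1024)
    t = torch.rand(8)   # one continuous time value per batch item
    out = emb(t)
    assert out.shape == (8, 1024)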