import torch
import torch.nn as nn

from sonics.layers import Transformer
from sonics.layers.tokenizer import STTokenizer

class SpecTTTra(nn.Module):
    """Spectro-Temporal Transformer (SpecTTTra).

    Tokenizes an input spectrogram into temporal and spectral clip tokens
    (via STTokenizer) and encodes the combined token sequence with a
    Transformer.
    """

    def __init__(
        self,
        input_spec_dim,
        input_temp_dim,
        embed_dim,
        t_clip,
        f_clip,
        num_heads,
        num_layers,
        pre_norm=False,
        pe_learnable=False,
        pos_drop_rate=0.0,
        attn_drop_rate=0.0,
        proj_drop_rate=0.0,
        mlp_ratio=4.0,
    ):
        super().__init__()
        self.input_spec_dim = input_spec_dim
        self.input_temp_dim = input_temp_dim
        self.embed_dim = embed_dim
        self.t_clip = t_clip
        self.f_clip = f_clip
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.pre_norm = pre_norm  # applied after tokenization, before the transformer (used in CLIP)
        self.pe_learnable = pe_learnable  # learned positional encoding
        self.pos_drop_rate = pos_drop_rate
        self.attn_drop_rate = attn_drop_rate
        self.proj_drop_rate = proj_drop_rate
        self.mlp_ratio = mlp_ratio
        self.st_tokenizer = STTokenizer(
            input_spec_dim,
            input_temp_dim,
            t_clip,
            f_clip,
            embed_dim,
            pre_norm=pre_norm,
            pe_learnable=pe_learnable,
        )
        self.pos_drop = nn.Dropout(p=pos_drop_rate)
        self.transformer = Transformer(
            embed_dim,
            num_heads,
            num_layers,
            attn_drop=self.attn_drop_rate,
            proj_drop=self.proj_drop_rate,
            mlp_ratio=self.mlp_ratio,
        )
    def forward(self, x):
        # Squeeze the channel dimension if it exists: (B, 1, F, T) -> (B, F, T)
        if x.dim() == 4:
            x = x.squeeze(1)

        # Spectro-temporal tokenization
        spectro_temporal_tokens = self.st_tokenizer(x)

        # Positional dropout
        spectro_temporal_tokens = self.pos_drop(spectro_temporal_tokens)

        # Transformer
        output = self.transformer(spectro_temporal_tokens)  # shape: (B, T/t + F/f, dim)
        return output
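

# STTokenizer is defined in sonics.layers.tokenizer and its source is not
# shown here. As a rough, hedged mental model (an assumption, not the actual
# sonics implementation): it slices the (F, T) spectrogram into T/t temporal
# clips and F/f spectral clips and linearly projects each clip to embed_dim,
# yielding T/t + F/f tokens. A minimal self-contained sketch:
class NaiveSTTokenizer(nn.Module):  # hypothetical illustration only
    def __init__(self, input_spec_dim, input_temp_dim, t_clip, f_clip, embed_dim):
        super().__init__()
        self.t_clip, self.f_clip = t_clip, f_clip
        # A temporal clip flattens to F * t values; a spectral clip to T * f.
        self.t_proj = nn.Linear(input_spec_dim * t_clip, embed_dim)
        self.f_proj = nn.Linear(input_temp_dim * f_clip, embed_dim)

    def forward(self, x):  # x: (B, F, T)
        B, F, T = x.shape
        # Temporal tokens: chop the time axis into T // t_clip clips.
        t_tok = x[:, :, : (T // self.t_clip) * self.t_clip]
        t_tok = t_tok.reshape(B, F, -1, self.t_clip)      # (B, F, T//t, t)
        t_tok = t_tok.permute(0, 2, 1, 3).flatten(2)      # (B, T//t, F*t)
        t_tok = self.t_proj(t_tok)                        # (B, T//t, D)
        # Spectral tokens: chop the frequency axis into F // f_clip clips.
        f_tok = x[:, : (F // self.f_clip) * self.f_clip, :]
        f_tok = f_tok.reshape(B, -1, self.f_clip, T)      # (B, F//f, f, T)
        f_tok = f_tok.flatten(2)                          # (B, F//f, f*T)
        f_tok = self.f_proj(f_tok)                        # (B, F//f, D)
        return torch.cat([t_tok, f_tok], dim=1)           # (B, T//t + F//f, D)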

# Example usage:
input_spec_dim = 384  # spectral size (frequency bins, F) of the input
input_temp_dim = 128  # temporal size (time frames, T) of the input
embed_dim = 512
t_clip = 20  # temporal clip size t: frames grouped into each temporal token
f_clip = 10  # spectral clip size f: bins grouped into each spectral token
num_heads = 8
num_layers = 6
mlp_ratio = 1.0  # feedforward dim = mlp_ratio * embed_dim = 512
num_classes = 10  # for an optional downstream classification head
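
# Hedged end-to-end check. Assumed input layout: a (B, 1, F, T) mel
# spectrogram whose channel dim is squeezed in forward(); the shapes here
# are illustrative, not taken from the sonics docs.
model = SpecTTTra(
    input_spec_dim,
    input_temp_dim,
    embed_dim,
    t_clip,
    f_clip,
    num_heads,
    num_layers,
    mlp_ratio=mlp_ratio,
)
x = torch.randn(2, 1, input_spec_dim, input_temp_dim)  # (B, 1, F, T)
tokens = model(x)  # (B, T/t + F/f, embed_dim)
print(tokens.shape)

# The model returns token embeddings, not class scores; for classification,
# pool the tokens and attach a head (num_classes from above):
head = nn.Linear(embed_dim, num_classes)
logits = head(tokens.mean(dim=1))  # (B, num_classes)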