Spaces:
Sleeping
Sleeping
from sonics.models.spectttra import SpecTTTra | |
from sonics.models.vit import ViT | |
from sonics.layers.feature import FeatureExtractor | |
from sonics.layers.augment import AugmentLayer | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import timm | |
def use_global_pool(model_name):
    """
    Tell whether encoder outputs still need mean-pooling over axis 1.

    Any model name containing the substring "timm" yields False; all other
    names yield True.
    """
    for marker in ("timm",):
        if marker in model_name:
            return False
    return True
def get_embed_dim(model_name, encoder):
    """
    Return the width of the feature vector produced by ``encoder``.

    For model names containing "timm" the value is read from the encoder's
    ``head_hidden_size`` attribute; otherwise from its ``embed_dim``
    attribute. A missing attribute raises AttributeError.
    """
    attr_name = "head_hidden_size" if "timm" in model_name else "embed_dim"
    return getattr(encoder, attr_name)
def use_init_weights(model_name):
    """
    Tell whether the classifier's custom weight initialization may run.

    Any model name containing the substring "timm" yields False (those
    encoders keep the weights they were created with); all other names
    yield True.
    """
    skip_markers = ("timm",)
    should_skip = any(marker in model_name for marker in skip_markers)
    return not should_skip
class AudioClassifier(nn.Module):
    """
    Audio classifier: feature extractor -> (train-time augmentation) ->
    encoder -> linear head.

    The encoder is selected by ``cfg.model.name`` (see ``get_encoder``); its
    embedding width feeds a single ``nn.Linear`` that emits
    ``cfg.num_classes`` logits.
    """

    def __init__(self, cfg):
        super().__init__()
        self.model_name = cfg.model.name
        self.input_shape = cfg.model.input_shape
        self.num_classes = cfg.num_classes
        self.ft_extractor = FeatureExtractor(cfg)
        self.augment = AugmentLayer(cfg)
        self.encoder = self.get_encoder(cfg)
        self.embed_dim = get_embed_dim(self.model_name, self.encoder)
        self.classifier = nn.Linear(self.embed_dim, self.num_classes)
        # Config switch for custom initialization; True when the config omits it.
        self.use_init_weights = getattr(cfg.model, "use_init_weights", True)
        # Run custom init only if the config allows it AND the model family
        # does (the module-level ``use_init_weights`` helper returns False for
        # timm models -- presumably to keep timm's own/pretrained weights).
        # Note: ``self.use_init_weights`` is the config flag, while the bare
        # name below is the module-level function.
        if self.use_init_weights and use_init_weights(self.model_name):
            self.initialize_weights()

    def get_encoder(self, cfg):
        """
        Build the encoder named by ``cfg.model.name``.

        Supported names:
          * "SpecTTTra" -- project SpecTTTra encoder.
          * "ViT" -- project ViT encoder.
          * any name containing "timm" -- a timm model created with the
            "timm-" prefix stripped, a single input channel and
            ``num_classes=0`` (no classification head).

        Raises:
            ValueError: if the model name is not one of the above.
        """
        if cfg.model.name == "SpecTTTra":
            model = SpecTTTra(
                input_spec_dim=cfg.model.input_shape[0],
                input_temp_dim=cfg.model.input_shape[1],
                embed_dim=cfg.model.embed_dim,
                t_clip=cfg.model.t_clip,
                f_clip=cfg.model.f_clip,
                num_heads=cfg.model.num_heads,
                num_layers=cfg.model.num_layers,
                pre_norm=cfg.model.pre_norm,
                pe_learnable=cfg.model.pe_learnable,
                pos_drop_rate=getattr(cfg.model, "pos_drop_rate", 0.0),
                attn_drop_rate=getattr(cfg.model, "attn_drop_rate", 0.0),
                proj_drop_rate=getattr(cfg.model, "proj_drop_rate", 0.0),
                mlp_ratio=getattr(cfg.model, "mlp_ratio", 4.0),
            )
        elif cfg.model.name == "ViT":
            model = ViT(
                image_size=cfg.model.input_shape,
                patch_size=cfg.model.patch_size,
                embed_dim=cfg.model.embed_dim,
                num_heads=cfg.model.num_heads,
                num_layers=cfg.model.num_layers,
                pe_learnable=cfg.model.pe_learnable,
                patch_norm=getattr(cfg.model, "patch_norm", False),
                pos_drop_rate=getattr(cfg.model, "pos_drop_rate", 0.0),
                attn_drop_rate=getattr(cfg.model, "attn_drop_rate", 0.0),
                proj_drop_rate=getattr(cfg.model, "proj_drop_rate", 0.0),
                mlp_ratio=getattr(cfg.model, "mlp_ratio", 4.0),
            )
        elif "timm" in cfg.model.name:
            model_name = cfg.model.name.replace("timm-", "")
            model = timm.create_model(
                model_name,
                pretrained=cfg.model.pretrained,
                in_chans=1,  # spectrograms are single-channel "images"
                num_classes=0,  # drop timm's head; self.classifier replaces it
            )
        else:
            raise ValueError(f"Model {cfg.model.name} not supported in V1.")
        return model

    def forward(self, audio, y=None):
        """
        Classify a batch of audio.

        Args:
            audio: raw input handed to ``self.ft_extractor`` (expected format
                is defined by FeatureExtractor -- not visible here).
            y: optional targets. In training mode they pass through
                ``self.augment`` together with the spectrogram and may be
                replaced by its output.

        Returns:
            The logits when ``y`` (after any augmentation) is None;
            otherwise the tuple ``(logits, y)``.
        """
        spec = self.ft_extractor(audio)  # shape: (batch_size, n_mels, n_frames)
        if self.training:
            spec, y = self.augment(spec, y)
        spec = spec.unsqueeze(1)  # shape: (batch_size, 1, n_mels, n_frames)
        # Resize to the fixed input size the encoder was built for.
        spec = F.interpolate(spec, size=tuple(self.input_shape), mode="bilinear")
        features = self.encoder(spec)
        # Non-timm encoders appear to return per-token features and are
        # averaged over dim 1; timm models built with num_classes=0 already
        # return pooled embeddings.
        embeds = features.mean(dim=1) if use_global_pool(self.model_name) else features
        preds = self.classifier(embeds)
        return preds if y is None else (preds, y)

    def initialize_weights(self):
        """
        Apply custom initialization to every submodule.

        * ``classifier`` Linear layers: zero weight and zero bias, so the head
          starts by emitting all-zero logits.
        * other Linear layers: Xavier-uniform weight, N(0, 1e-6) bias.
        * Conv1d / Conv2d: Kaiming-normal weight (fan_out, relu), zero bias.
        * any other module exposing ``init_weights()``: that hook is called.

        NOTE(review): ``named_modules()`` visits a parent before its children,
        so a submodule's own ``init_weights()`` runs first and its Linear/Conv
        children are then re-initialized by the rules above -- confirm this
        override is intended.
        """
        for name, module in self.named_modules():
            if isinstance(module, nn.Linear):
                if name.startswith("classifier"):
                    nn.init.zeros_(module.weight)
                    if module.bias is not None:
                        nn.init.constant_(module.bias, 0.0)
                else:
                    nn.init.xavier_uniform_(module.weight)
                    if module.bias is not None:
                        nn.init.normal_(module.bias, std=1e-6)
            elif isinstance(module, (nn.Conv1d, nn.Conv2d)):
                nn.init.kaiming_normal_(
                    module.weight, mode="fan_out", nonlinearity="relu"
                )
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif hasattr(module, "init_weights"):
                module.init_weights()