awsaf49's picture
Initial Commit
3f50570
raw
history blame
5.07 kB
from sonics.models.spectttra import SpecTTTra
from sonics.models.vit import ViT
from sonics.layers.feature import FeatureExtractor
from sonics.layers.augment import AugmentLayer
import torch.nn as nn
import torch.nn.functional as F
import timm
def use_global_pool(model_name):
    """
    Decide whether global pooling is required for the given model name.

    Returns False as soon as one of the marker substrings (currently only
    "timm") occurs in ``model_name``; returns True otherwise.
    """
    markers_without_pooling = ("timm",)
    for marker in markers_without_pooling:
        if marker in model_name:
            return False
    return True
def get_embed_dim(model_name, encoder):
    """
    Get the embedding dimension of the encoder.

    Args:
        model_name: Model identifier; a name containing "timm" selects the
            timm attribute lookup, anything else reads ``encoder.embed_dim``.
        encoder: Encoder object exposing the relevant size attribute.

    Returns:
        The value of ``encoder.embed_dim`` for non-timm names. For timm names,
        ``encoder.head_hidden_size`` when it is present and not None, otherwise
        ``encoder.num_features``.

    Raises:
        AttributeError: If the encoder exposes none of the expected attributes.
    """
    if "timm" not in model_name:
        return encoder.embed_dim

    # NOTE(review): not every timm release/model appears to define
    # `head_hidden_size`; fall back to `num_features` instead of failing with
    # AttributeError. Confirm against the timm version pinned by the project.
    head_dim = getattr(encoder, "head_hidden_size", None)
    if head_dim is not None:
        return head_dim
    return encoder.num_features
def use_init_weights(model_name):
    """
    Decide whether custom weight initialization is required for the model name.

    Returns False when ``model_name`` contains any marker substring (currently
    only "timm"), and True otherwise.
    """
    markers_with_own_init = ("timm",)
    return not any(marker in model_name for marker in markers_with_own_init)
class AudioClassifier(nn.Module):
    """
    Audio classifier assembled from a feature extractor, an augmentation layer
    (applied in training mode only), an encoder and a linear classification head.

    The encoder is chosen from ``cfg.model.name``: "SpecTTTra", "ViT", or any
    name containing "timm" (built through ``timm.create_model``).
    """

    def __init__(self, cfg):
        """
        Build every sub-module from the configuration object.

        Reads ``cfg.model.name``, ``cfg.model.input_shape`` and
        ``cfg.num_classes`` directly; ``cfg`` is also passed to
        ``FeatureExtractor``, ``AugmentLayer`` and ``get_encoder``. The optional
        ``cfg.model.use_init_weights`` flag defaults to True.
        """
        super().__init__()
        self.model_name = cfg.model.name
        self.input_shape = cfg.model.input_shape
        self.num_classes = cfg.num_classes

        self.ft_extractor = FeatureExtractor(cfg)
        self.augment = AugmentLayer(cfg)
        self.encoder = self.get_encoder(cfg)
        self.embed_dim = get_embed_dim(self.model_name, self.encoder)
        self.classifier = nn.Linear(self.embed_dim, self.num_classes)

        self.use_init_weights = getattr(cfg.model, "use_init_weights", True)
        # Custom initialization runs only when the config flag is truthy AND the
        # module-level use_init_weights() check agrees for this model name.
        if self.use_init_weights and use_init_weights(self.model_name):
            self.initialize_weights()

    def get_encoder(self, cfg):
        """
        Construct the encoder selected by ``cfg.model.name``.

        Returns:
            A ``SpecTTTra`` or ``ViT`` instance, or the object returned by
            ``timm.create_model`` when the name contains "timm".

        Raises:
            ValueError: If the model name matches none of the supported options.
        """
        model_cfg = cfg.model
        name = model_cfg.name

        def optional_hparams():
            # Settings shared by the SpecTTTra and ViT branches; each one falls
            # back to the listed default when the config does not define it.
            return dict(
                pos_drop_rate=getattr(model_cfg, "pos_drop_rate", 0.0),
                attn_drop_rate=getattr(model_cfg, "attn_drop_rate", 0.0),
                proj_drop_rate=getattr(model_cfg, "proj_drop_rate", 0.0),
                mlp_ratio=getattr(model_cfg, "mlp_ratio", 4.0),
            )

        if name == "SpecTTTra":
            return SpecTTTra(
                input_spec_dim=model_cfg.input_shape[0],
                input_temp_dim=model_cfg.input_shape[1],
                embed_dim=model_cfg.embed_dim,
                t_clip=model_cfg.t_clip,
                f_clip=model_cfg.f_clip,
                num_heads=model_cfg.num_heads,
                num_layers=model_cfg.num_layers,
                pre_norm=model_cfg.pre_norm,
                pe_learnable=model_cfg.pe_learnable,
                **optional_hparams(),
            )

        if name == "ViT":
            return ViT(
                image_size=model_cfg.input_shape,
                patch_size=model_cfg.patch_size,
                embed_dim=model_cfg.embed_dim,
                num_heads=model_cfg.num_heads,
                num_layers=model_cfg.num_layers,
                pe_learnable=model_cfg.pe_learnable,
                patch_norm=getattr(model_cfg, "patch_norm", False),
                **optional_hparams(),
            )

        if "timm" in name:
            # Every "timm-" occurrence is stripped before the name is handed to timm.
            timm_name = name.replace("timm-", "")
            return timm.create_model(
                timm_name,
                pretrained=model_cfg.pretrained,
                in_chans=1,
                num_classes=0,
            )

        raise ValueError(f"Model {cfg.model.name} not supported in V1.")

    def forward(self, audio, y=None):
        """
        Run the full pipeline on ``audio``.

        Args:
            audio: Input passed to the feature extractor.
            y: Optional targets; in training mode they are passed through the
                augmentation layer together with the spectrogram.

        Returns:
            ``preds`` when ``y`` is None, otherwise the tuple ``(preds, y)``.
        """
        spec = self.ft_extractor(audio)  # shape: (batch_size, n_mels, n_frames)

        # Augmentation is applied in training mode only and may replace y.
        if self.training:
            spec, y = self.augment(spec, y)

        # Add a channel axis, then resize to the configured input shape.
        spec = F.interpolate(
            spec.unsqueeze(1),  # shape: (batch_size, 1, n_mels, n_frames)
            size=tuple(self.input_shape),
            mode="bilinear",
        )

        features = self.encoder(spec)
        if use_global_pool(self.model_name):
            embeds = features.mean(dim=1)
        else:
            embeds = features

        preds = self.classifier(embeds)
        if y is None:
            return preds
        return (preds, y)

    def initialize_weights(self):
        """
        Initialize parameters of every sub-module returned by ``named_modules()``.

        - ``nn.Linear`` whose name starts with "classifier": weight and bias zeroed.
        - Other ``nn.Linear``: Xavier-uniform weight, bias ~ normal(std=1e-6).
        - ``nn.Conv1d`` / ``nn.Conv2d``: Kaiming-normal weight (fan_out, relu),
          bias zeroed.
        - Any other module exposing ``init_weights``: that method is called.
        """
        for name, module in self.named_modules():
            if isinstance(module, nn.Linear):
                if name.startswith("classifier"):
                    nn.init.zeros_(module.weight)
                    nn.init.constant_(module.bias, 0.0)
                else:
                    nn.init.xavier_uniform_(module.weight)
                    if module.bias is not None:
                        nn.init.normal_(module.bias, std=1e-6)
                continue

            if isinstance(module, (nn.Conv2d, nn.Conv1d)):
                nn.init.kaiming_normal_(
                    module.weight, mode="fan_out", nonlinearity="relu"
                )
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
                continue

            if hasattr(module, "init_weights"):
                module.init_weights()