import timm
import torch.nn as nn
import torch.nn.functional as F

from sonics.layers.augment import AugmentLayer
from sonics.layers.feature import FeatureExtractor
from sonics.models.spectttra import SpecTTTra
from sonics.models.vit import ViT


def use_global_pool(model_name):
    """
    Check whether the encoder output still needs global (mean) pooling.
    timm encoders created with num_classes=0 already return pooled features.
    """
    no_global_pool = ["timm"]
    return not any(x in model_name for x in no_global_pool)


def get_embed_dim(model_name, encoder):
    """
    Get the embedding dimension of the encoder.
    """
    if "timm" in model_name:
        return encoder.head_hidden_size
    else:
        return encoder.embed_dim


def use_init_weights(model_name):
    """
    Check whether the model needs custom weight initialization.
    timm models ship with their own (pretrained or default) initialization.
    """
    has_init_weights = ["timm"]
    return not any(x in model_name for x in has_init_weights)
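
# Example behavior (illustrative): timm backbones skip both the external
# mean pooling and the custom weight init, while the in-repo encoders
# (SpecTTTra, ViT) get both:
#   use_global_pool("timm-convnext_tiny")  -> False
#   use_global_pool("SpecTTTra")           -> True
#   use_init_weights("timm-convnext_tiny") -> False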


class AudioClassifier(nn.Module):
    def __init__(self, cfg):
        super().__init__()

        self.model_name = cfg.model.name
        self.input_shape = cfg.model.input_shape
        self.num_classes = cfg.num_classes
        self.ft_extractor = FeatureExtractor(cfg)
        self.augment = AugmentLayer(cfg)
        self.encoder = self.get_encoder(cfg)
        self.embed_dim = get_embed_dim(self.model_name, self.encoder)
        self.classifier = nn.Linear(self.embed_dim, self.num_classes)
        self.use_init_weights = getattr(cfg.model, "use_init_weights", True)

        # Initialize weights (skipped for timm models, which come
        # pre-initialized, or when disabled via cfg.model.use_init_weights)
        if self.use_init_weights and use_init_weights(self.model_name):
            self.initialize_weights()

    def get_encoder(self, cfg):
        if cfg.model.name == "SpecTTTra":
            model = SpecTTTra(
                input_spec_dim=cfg.model.input_shape[0],
                input_temp_dim=cfg.model.input_shape[1],
                embed_dim=cfg.model.embed_dim,
                t_clip=cfg.model.t_clip,
                f_clip=cfg.model.f_clip,
                num_heads=cfg.model.num_heads,
                num_layers=cfg.model.num_layers,
                pre_norm=cfg.model.pre_norm,
                pe_learnable=cfg.model.pe_learnable,
                pos_drop_rate=getattr(cfg.model, "pos_drop_rate", 0.0),
                attn_drop_rate=getattr(cfg.model, "attn_drop_rate", 0.0),
                proj_drop_rate=getattr(cfg.model, "proj_drop_rate", 0.0),
                mlp_ratio=getattr(cfg.model, "mlp_ratio", 4.0),
            )
        elif cfg.model.name == "ViT":
            model = ViT(
                image_size=cfg.model.input_shape,
                patch_size=cfg.model.patch_size,
                embed_dim=cfg.model.embed_dim,
                num_heads=cfg.model.num_heads,
                num_layers=cfg.model.num_layers,
                pe_learnable=cfg.model.pe_learnable,
                patch_norm=getattr(cfg.model, "patch_norm", False),
                pos_drop_rate=getattr(cfg.model, "pos_drop_rate", 0.0),
                attn_drop_rate=getattr(cfg.model, "attn_drop_rate", 0.0),
                proj_drop_rate=getattr(cfg.model, "proj_drop_rate", 0.0),
                mlp_ratio=getattr(cfg.model, "mlp_ratio", 4.0),
            )
        elif "timm" in cfg.model.name:
            model_name = cfg.model.name.replace("timm-", "")
            model = timm.create_model(
                model_name,
                pretrained=cfg.model.pretrained,
                in_chans=1,
                num_classes=0,
            )
        else:
            raise ValueError(f"Model {cfg.model.name} not supported in V1.")
        return model

    def forward(self, audio, y=None):
        spec = self.ft_extractor(audio)  # shape: (batch_size, n_mels, n_frames)
        if self.training:
            spec, y = self.augment(spec, y)
        spec = spec.unsqueeze(1)  # shape: (batch_size, 1, n_mels, n_frames)
        # Resize the spectrogram to the encoder's expected input resolution.
        spec = F.interpolate(spec, size=tuple(self.input_shape), mode="bilinear")
        features = self.encoder(spec)
        # Mean-pool over the token dimension unless the encoder (e.g. a timm
        # backbone) has already pooled its output.
        embeds = features.mean(dim=1) if use_global_pool(self.model_name) else features
        preds = self.classifier(embeds)
        return preds if y is None else (preds, y)

    def initialize_weights(self):
        for name, module in self.named_modules():
            if isinstance(module, nn.Linear):
                if name.startswith("classifier"):
                    # Zero-init the classification head so training starts
                    # from neutral (all-zero) logits.
                    nn.init.zeros_(module.weight)
                    nn.init.constant_(module.bias, 0.0)
                else:
                    nn.init.xavier_uniform_(module.weight)
                    if module.bias is not None:
                        nn.init.normal_(module.bias, std=1e-6)
            elif isinstance(module, (nn.Conv2d, nn.Conv1d)):
                nn.init.kaiming_normal_(
                    module.weight, mode="fan_out", nonlinearity="relu"
                )
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif hasattr(module, "init_weights"):
                # Defer to a submodule's own initializer when it defines one.
                module.init_weights()
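

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module). The cfg
# schema below is an assumption reconstructed from the attributes accessed
# above; the real project presumably uses its own config object, and
# FeatureExtractor/AugmentLayer likely require additional fields (sample
# rate, mel parameters, augmentation settings) that are not shown here.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    import torch

    cfg = SimpleNamespace(
        num_classes=2,
        model=SimpleNamespace(
            name="timm-efficientnet_b0",  # hypothetical timm backbone id
            input_shape=(128, 384),  # (n_mels, n_frames) fed to the encoder
            pretrained=False,
        ),
    )
    model = AudioClassifier(cfg).eval()

    audio = torch.randn(2, 16000)  # batch of 2 one-second 16 kHz clips
    with torch.no_grad():
        preds = model(audio)
    print(preds.shape)  # expected: (2, cfg.num_classes)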