Add code and weights

- example.ipynb +116 -0
- models/baseline.py +260 -0
- models/helpers/utils.py +11 -0
- models/mel.py +97 -0
- split10/config.json +32 -0
- split10/model.safetensors +3 -0
- split100/config.json +32 -0
- split100/model.safetensors +3 -0
- split25/config.json +32 -0
- split25/model.safetensors +3 -0
- split5/config.json +32 -0
- split5/model.safetensors +3 -0
- split50/config.json +32 -0
- split50/model.safetensors +3 -0
example.ipynb
ADDED
@@ -0,0 +1,116 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09225787-6a4b-4484-b00b-d0f731915a81",
   "metadata": {},
   "outputs": [],
   "source": [
    "from models.baseline import Network\n",
    "from models.mel import AugmentMelSTFT\n",
    "import soundfile as sf\n",
    "import torch\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c377b699-2c2e-468e-88b0-6767338988c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "audio_path = \"/path/to/audio.wav\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa950347-df0d-4135-801a-d54525c57e58",
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import display, Audio\n",
    "\n",
    "display(Audio(audio_path))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79faad26-0f20-439d-b152-10f4666db41d",
   "metadata": {},
   "outputs": [],
   "source": [
    "mel = AugmentMelSTFT().eval()\n",
    "model = Network.from_pretrained(\"split5\").eval()\n",
    "\n",
    "audio, sr = sf.read(audio_path, dtype=np.float32)\n",
    "assert sr == 32_000\n",
    "\n",
    "audio = torch.as_tensor(audio)\n",
    "\n",
    "# audio.shape: (samples,)\n",
    "\n",
    "audio = audio.unsqueeze(0)\n",
    "\n",
    "# audio.shape: (1, samples)\n",
    "\n",
    "with torch.no_grad():\n",
    "    mel_spec = mel(audio)\n",
    "\n",
    "# mel_spec.shape: (1, mel_bins, frames)\n",
    "\n",
    "mel_spec = mel_spec.unsqueeze(0)\n",
    "\n",
    "# mel_spec.shape: (1, 1, mel_bins, frames)\n",
    "\n",
    "with torch.no_grad():\n",
    "    logits = model(mel_spec)\n",
    "\n",
    "# logits.shape: (1, classes)\n",
    "\n",
    "logits = logits.squeeze(0)\n",
    "\n",
    "tau2022_classes = [\n",
    "    \"airport\",\n",
    "    \"bus\",\n",
    "    \"metro\",\n",
    "    \"metro_station\",\n",
    "    \"park\",\n",
    "    \"public_square\",\n",
    "    \"shopping_mall\",\n",
    "    \"street_pedestrian\",\n",
    "    \"street_traffic\",\n",
    "    \"tram\"\n",
    "]\n",
    "\n",
    "best_prediction_idx = torch.argmax(logits)\n",
    "\n",
    "scores = torch.softmax(logits, dim=0)\n",
    "\n",
    "print(f\"Prediction: {tau2022_classes[best_prediction_idx]} (score: {scores[best_prediction_idx]:0.2f})\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
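Note: the final cell keeps only the argmax. If a ranked view of the class scores is useful, a small extension of that cell (a sketch, not part of the committed notebook) could be:

    # hypothetical follow-up cell: print the top-3 scene classes
    top = torch.topk(scores, k=3)
    for score, idx in zip(top.values, top.indices):
        print(f"{tau2022_classes[idx]}: {score:0.2f}")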
models/baseline.py
ADDED
@@ -0,0 +1,260 @@
import torch
import torch.nn as nn
from torchvision.ops.misc import Conv2dNormActivation

from .helpers.utils import make_divisible
from transformers.modeling_utils import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig


def initialize_weights(m):
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode="fan_out")
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm)):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, 0, 0.01)
        if m.bias is not None:
            nn.init.zeros_(m.bias)


class Block(nn.Module):
    def __init__(self, in_channels, out_channels, expansion_rate, stride):
        super().__init__()
        exp_channels = make_divisible(in_channels * expansion_rate, 8)

        # create the three factorized convs that make up the inverted bottleneck block
        exp_conv = Conv2dNormActivation(
            in_channels,
            exp_channels,
            kernel_size=1,
            stride=1,
            norm_layer=nn.BatchNorm2d,
            activation_layer=nn.ReLU,
            inplace=False,
        )

        # depthwise convolution with possible stride
        depth_conv = Conv2dNormActivation(
            exp_channels,
            exp_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=exp_channels,
            norm_layer=nn.BatchNorm2d,
            activation_layer=nn.ReLU,
            inplace=False,
        )

        # pointwise projection back to out_channels, no activation
        proj_conv = Conv2dNormActivation(
            exp_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            norm_layer=nn.BatchNorm2d,
            activation_layer=None,
            inplace=False,
        )
        self.after_block_activation = nn.ReLU()

        if in_channels == out_channels:
            self.use_shortcut = True
            if stride == 1 or stride == (1, 1):
                self.shortcut = nn.Sequential()
            else:
                # average pooling required for the shortcut to match the strided branch
                self.shortcut = nn.Sequential(
                    nn.AvgPool2d(kernel_size=3, stride=stride, padding=1),
                    nn.Sequential(),
                )
        else:
            self.use_shortcut = False

        self.block = nn.Sequential(exp_conv, depth_conv, proj_conv)

    def forward(self, x):
        if self.use_shortcut:
            x = self.block(x) + self.shortcut(x)
        else:
            x = self.block(x)
        x = self.after_block_activation(x)
        return x


class NetworkConfig(PretrainedConfig):
    def __init__(
        self,
        n_classes=10,
        in_channels=1,
        base_channels=32,
        channels_multiplier=2.3,
        expansion_rate=3.0,
        n_blocks=(3, 2, 1),
        strides=dict(b2=(1, 1), b3=(1, 2), b4=(2, 1)),
        add_feats=False,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.n_classes = n_classes
        self.in_channels = in_channels
        self.base_channels = base_channels
        self.channels_multiplier = channels_multiplier
        self.expansion_rate = expansion_rate
        self.n_blocks = n_blocks
        self.strides = strides
        self.add_feats = add_feats


class Network(PreTrainedModel):
    config_class = NetworkConfig

    def __init__(self, config):
        super().__init__(config)
        n_classes = config.n_classes
        in_channels = config.in_channels
        base_channels = config.base_channels
        channels_multiplier = config.channels_multiplier
        expansion_rate = config.expansion_rate
        n_blocks = config.n_blocks
        strides = config.strides
        n_stages = len(n_blocks)

        self.add_feats = config.add_feats

        base_channels = make_divisible(base_channels, 8)
        channels_per_stage = [base_channels] + [
            make_divisible(base_channels * channels_multiplier**stage_id, 8)
            for stage_id in range(n_stages)
        ]
        self.total_block_count = 0

        self.in_c = nn.Sequential(
            Conv2dNormActivation(
                in_channels,
                channels_per_stage[0] // 4,
                activation_layer=torch.nn.ReLU,
                kernel_size=3,
                stride=2,
                inplace=False,
            ),
            Conv2dNormActivation(
                channels_per_stage[0] // 4,
                channels_per_stage[0],
                activation_layer=torch.nn.ReLU,
                kernel_size=3,
                stride=2,
                inplace=False,
            ),
        )

        self.stages = nn.Sequential()
        for stage_id in range(n_stages):
            stage = self._make_stage(
                channels_per_stage[stage_id],
                channels_per_stage[stage_id + 1],
                n_blocks[stage_id],
                strides=strides,
                expansion_rate=expansion_rate,
            )
            self.stages.add_module(f"s{stage_id + 1}", stage)

        ff_list = []
        ff_list += [
            nn.Conv2d(
                channels_per_stage[-1],
                n_classes,
                kernel_size=(1, 1),
                stride=(1, 1),
                padding=0,
                bias=False,
            ),
            nn.BatchNorm2d(n_classes),
        ]

        ff_list.append(nn.AdaptiveAvgPool2d((1, 1)))

        self.feed_forward = nn.Sequential(*ff_list)

        self.apply(initialize_weights)

    def _make_stage(self, in_channels, out_channels, n_blocks, strides, expansion_rate):
        stage = nn.Sequential()
        for index in range(n_blocks):
            block_id = self.total_block_count + 1
            bname = f"b{block_id}"
            self.total_block_count = self.total_block_count + 1
            if bname in strides:
                stride = strides[bname]
            else:
                stride = (1, 1)

            block = self._make_block(
                in_channels, out_channels, stride=stride, expansion_rate=expansion_rate
            )
            stage.add_module(bname, block)

            in_channels = out_channels
        return stage

    def _make_block(self, in_channels, out_channels, stride, expansion_rate):
        block = Block(in_channels, out_channels, expansion_rate, stride)
        return block

    def _forward_conv(self, x):
        x = self.in_c(x)
        x = self.stages(x)
        return x

    def forward(self, x):
        y = self._forward_conv(x)
        x = self.feed_forward(y)
        logits = x.squeeze(2).squeeze(2)
        if self.add_feats:
            return logits, y
        else:
            return logits


def get_model(
    n_classes=10,
    in_channels=1,
    base_channels=32,
    channels_multiplier=2.3,
    expansion_rate=3.0,
    n_blocks=(3, 2, 1),
    strides=None,
    add_feats=False,
):
    """
    @param n_classes: number of classes to predict
    @param in_channels: number of input channels to the network; for audio this is 1 by default
    @param base_channels: number of channels after in_c
    @param channels_multiplier: controls the increase in the width of the network after each stage
    @param expansion_rate: determines the expansion rate in inverted bottleneck blocks
    @param n_blocks: number of blocks in each stage
    @param strides: strides for specific blocks; the default value is set below
    @return: full neural network model based on the specified configs
    """

    if strides is None:
        strides = dict(b2=(1, 1), b3=(1, 2), b4=(2, 1))

    model_config = {
        "n_classes": n_classes,
        "in_channels": in_channels,
        "base_channels": base_channels,
        "channels_multiplier": channels_multiplier,
        "expansion_rate": expansion_rate,
        "n_blocks": n_blocks,
        "strides": strides,
        "add_feats": add_feats,
    }

    m = Network(NetworkConfig(**model_config))
    return m
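For orientation, the network consumes a 4-D batch of log-mel spectrograms and returns one logit per class. A minimal smoke test (a sketch, not part of the file; the shapes are illustrative, with 128 mel bins matching the AugmentMelSTFT default):

    model = get_model()                  # defaults: 10 classes, 1 input channel
    dummy = torch.zeros(1, 1, 128, 100)  # (batch, channels, mel_bins, frames)
    with torch.no_grad():
        logits = model(dummy)
    print(logits.shape)                  # torch.Size([1, 10])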
models/helpers/utils.py
ADDED
@@ -0,0 +1,11 @@
from typing import Optional


def make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v
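make_divisible rounds a (possibly fractional) channel count to the nearest multiple of divisor, never returning less than min_value and never shrinking the input by more than 10%. A few illustrative values (inputs chosen for this sketch, not taken from the repo):

    assert make_divisible(73.6, 8) == 72  # 32 * 2.3, the default base_channels * channels_multiplier
    assert make_divisible(30, 8) == 32    # rounds to the nearest multiple of 8
    assert make_divisible(23, 16) == 32   # 16 would be a >10% drop from 23, so the guard rounds up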
models/mel.py
ADDED
@@ -0,0 +1,97 @@
import torch
import torch.nn as nn
import torchaudio


class AugmentMelSTFT(nn.Module):
    def __init__(self, n_mels=128, sr=32000, win_length=800, hopsize=320, n_fft=1024, freqm=48, timem=192,
                 fmin=0.0, fmax=None, fmin_aug_range=10, fmax_aug_range=2000, norm_mel=False):
        """
        :param n_mels: number of mel bins
        :param sr: sampling rate used (same as passed as argument to dataset)
        :param win_length: fft window length in samples
        :param hopsize: fft hop size in samples
        :param n_fft: length of fft
        :param freqm: maximum possible length of mask along frequency dimension
        :param timem: maximum possible length of mask along time dimension
        :param fmin: minimum frequency used
        :param fmax: maximum frequency used
        :param fmin_aug_range: randomly changes min frequency
        :param fmax_aug_range: randomly changes max frequency
        """
        torch.nn.Module.__init__(self)
        # adapted from: https://github.com/CPJKU/kagglebirds2020/commit/70f8308b39011b09d41eb0f4ace5aa7d2b0e806e

        self.win_length = win_length
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.sr = sr
        self.fmin = fmin
        if fmax is None:
            # fmax defaults to sampling_rate / 2 -> Nyquist!
            fmax = sr // 2 - fmax_aug_range // 2
            print(f"Warning: FMAX is None, setting it to {fmax}")
        self.fmax = fmax
        self.hopsize = hopsize
        # buffers are not trained by the optimizer; persistent=False also keeps them out of the model's state dict
        self.register_buffer('window',
                             torch.hann_window(win_length, periodic=False),
                             persistent=False)
        assert fmin_aug_range >= 1, f"fmin_aug_range={fmin_aug_range} should be >=1; 1 means no augmentation"
        assert fmax_aug_range >= 1, f"fmax_aug_range={fmax_aug_range} should be >=1; 1 means no augmentation"
        self.fmin_aug_range = fmin_aug_range
        self.fmax_aug_range = fmax_aug_range

        self.register_buffer("preemphasis_coefficient", torch.as_tensor([[[-.97, 1]]]), persistent=False)
        if freqm == 0:
            self.freqm = torch.nn.Identity()
        else:
            self.freqm = torchaudio.transforms.FrequencyMasking(freqm, iid_masks=True)
        if timem == 0:
            self.timem = torch.nn.Identity()
        else:
            self.timem = torchaudio.transforms.TimeMasking(timem, iid_masks=True)
        self.norm_mel = norm_mel

    def forward(self, x):
        # shape: batch size x samples
        # most of the energy sits at the lower end of the spectrum; pre-emphasis compensates for the
        # average spectral shape
        x = nn.functional.conv1d(x.unsqueeze(1), self.preemphasis_coefficient).squeeze(1)

        # Short-Time Fourier Transform using a Hann window
        x = torch.stft(x, self.n_fft, hop_length=self.hopsize, win_length=self.win_length,
                       center=True, normalized=False, window=self.window, return_complex=False)
        # shape: batch size x freqs (n_fft/2 + 1) x timeframes (samples/hop_length) x 2 (real and imaginary components)

        # calculate the power spectrum
        x = (x ** 2).sum(dim=-1)
        fmin = self.fmin + torch.randint(self.fmin_aug_range, (1,)).item()
        fmax = self.fmax + self.fmax_aug_range // 2 - torch.randint(self.fmax_aug_range, (1,)).item()

        if not self.training:
            # don't augment eval data
            fmin = self.fmin
            fmax = self.fmax

        # create the mel filterbank
        mel_basis, _ = torchaudio.compliance.kaldi.get_mel_banks(self.n_mels, self.n_fft, self.sr,
                                                                 fmin, fmax, vtln_low=100.0, vtln_high=-500.,
                                                                 vtln_warp_factor=1.0)
        mel_basis = torch.as_tensor(torch.nn.functional.pad(mel_basis, (0, 1), mode='constant', value=0),
                                    device=x.device)
        if self.norm_mel:
            mel_basis = mel_basis / mel_basis.sum(1)[:, None]
        # apply the mel filterbank to the power spectrogram
        with torch.cuda.amp.autocast(enabled=False):
            melspec = torch.matmul(mel_basis, x)
        # calculate the log mel spectrogram
        melspec = (melspec + 0.00001).log()

        if self.training:
            # frequency and time masking are applied only on training data
            melspec = self.freqm(melspec)
            melspec = self.timem(melspec)

        melspec = (melspec + 4.5) / 5.  # fast normalization
        return melspec
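As a quick sanity check of the output shape (a sketch assuming the defaults: 32 kHz input, 128 mel bins, hop size 320; the 10-second duration is just an example):

    mel = AugmentMelSTFT().eval()      # eval() disables the random fmin/fmax jitter and masking
    waveform = torch.zeros(1, 320000)  # (batch, samples): 10 s at 32 kHz
    with torch.no_grad():
        spec = mel(waveform)
    print(spec.shape)                  # torch.Size([1, 128, 1000])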
split10/config.json
ADDED
@@ -0,0 +1,32 @@
{
  "add_feats": false,
  "architectures": [
    "Network"
  ],
  "base_channels": 32,
  "channels_multiplier": 1.8,
  "expansion_rate": 2.1,
  "in_channels": 1,
  "n_blocks": [
    3,
    2,
    1
  ],
  "n_classes": 10,
  "strides": {
    "b2": [
      1,
      1
    ],
    "b3": [
      1,
      2
    ],
    "b4": [
      2,
      1
    ]
  },
  "torch_dtype": "float16",
  "transformers_version": "4.37.1"
}
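The five split*/config.json files below are identical; the checkpoints differ only in their trained weights. Any of the directories can therefore be loaded the same way the notebook loads split5, e.g. (local checkpoint directory):

    model = Network.from_pretrained("split10").eval()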
split10/model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6b57c03e4b9ecd67b1c33d8ed471bed10b4e5dfe872b2fa9bb1bb6bc3405b9b0
size 139504
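The model.safetensors entries are Git LFS pointer files; the 139,504-byte float16 weight blobs themselves are fetched by an LFS-enabled clone, e.g.:

    git lfs install
    git lfs pull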
split100/config.json
ADDED
@@ -0,0 +1,32 @@
{
  "add_feats": false,
  "architectures": [
    "Network"
  ],
  "base_channels": 32,
  "channels_multiplier": 1.8,
  "expansion_rate": 2.1,
  "in_channels": 1,
  "n_blocks": [
    3,
    2,
    1
  ],
  "n_classes": 10,
  "strides": {
    "b2": [
      1,
      1
    ],
    "b3": [
      1,
      2
    ],
    "b4": [
      2,
      1
    ]
  },
  "torch_dtype": "float16",
  "transformers_version": "4.37.1"
}
split100/model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:02791b0e81bad11f603806e00252e76ceb4fb62f0d4880c115ed95513262b172
size 139504
split25/config.json
ADDED
@@ -0,0 +1,32 @@
{
  "add_feats": false,
  "architectures": [
    "Network"
  ],
  "base_channels": 32,
  "channels_multiplier": 1.8,
  "expansion_rate": 2.1,
  "in_channels": 1,
  "n_blocks": [
    3,
    2,
    1
  ],
  "n_classes": 10,
  "strides": {
    "b2": [
      1,
      1
    ],
    "b3": [
      1,
      2
    ],
    "b4": [
      2,
      1
    ]
  },
  "torch_dtype": "float16",
  "transformers_version": "4.37.1"
}
split25/model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:10a46c2d2b99b39965a4fa4fbac6171275efb01bd7da08a7f768779df5effba0
size 139504
split5/config.json
ADDED
@@ -0,0 +1,32 @@
{
  "add_feats": false,
  "architectures": [
    "Network"
  ],
  "base_channels": 32,
  "channels_multiplier": 1.8,
  "expansion_rate": 2.1,
  "in_channels": 1,
  "n_blocks": [
    3,
    2,
    1
  ],
  "n_classes": 10,
  "strides": {
    "b2": [
      1,
      1
    ],
    "b3": [
      1,
      2
    ],
    "b4": [
      2,
      1
    ]
  },
  "torch_dtype": "float16",
  "transformers_version": "4.37.1"
}
split5/model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5401136f83a10102de48f37c724a50e188e4b89186177a9ceb2945ddc57f5b49
size 139504
split50/config.json
ADDED
@@ -0,0 +1,32 @@
{
  "add_feats": false,
  "architectures": [
    "Network"
  ],
  "base_channels": 32,
  "channels_multiplier": 1.8,
  "expansion_rate": 2.1,
  "in_channels": 1,
  "n_blocks": [
    3,
    2,
    1
  ],
  "n_classes": 10,
  "strides": {
    "b2": [
      1,
      1
    ],
    "b3": [
      1,
      2
    ],
    "b4": [
      2,
      1
    ]
  },
  "torch_dtype": "float16",
  "transformers_version": "4.37.1"
}
split50/model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:513d8bc217266fbb3a3f7354e3884eb5478550c47d9daa44cf1cdb7b08c54984
size 139504