Commit f3b0a90 · 1 Parent(s): a77a8ee
Committed by werning

Add code and weights
example.ipynb ADDED
@@ -0,0 +1,116 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "09225787-6a4b-4484-b00b-d0f731915a81",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from models.baseline import Network\n",
+ "from models.mel import AugmentMelSTFT\n",
+ "import soundfile as sf\n",
+ "import torch\n",
+ "import numpy as np"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c377b699-2c2e-468e-88b0-6767338988c8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "audio_path = \"/path/to/audio.wav\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "fa950347-df0d-4135-801a-d54525c57e58",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from IPython.display import display, Audio\n",
+ "\n",
+ "display(Audio(audio_path))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "79faad26-0f20-439d-b152-10f4666db41d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mel = AugmentMelSTFT().eval()\n",
+ "model = Network.from_pretrained(\"split5\").eval()\n",
+ "\n",
+ "audio, sr = sf.read(audio_path, dtype=np.float32)\n",
+ "assert sr == 32_000\n",
+ "\n",
+ "audio = torch.as_tensor(audio)\n",
+ "\n",
+ "# audio.shape: (samples,)\n",
+ "\n",
+ "audio = audio.unsqueeze(0)\n",
+ "\n",
+ "# audio.shape: (1, samples)\n",
+ "\n",
+ "with torch.no_grad():\n",
+ "    mel_spec = mel(audio)\n",
+ "\n",
+ "# mel_spec.shape: (1, mel_bins, frames)\n",
+ "\n",
+ "mel_spec = mel_spec.unsqueeze(0)\n",
+ "\n",
+ "with torch.no_grad():\n",
+ "    logits = model(mel_spec)\n",
+ "\n",
+ "# logits.shape: (1, classes)\n",
+ "\n",
+ "logits = logits.squeeze(0)\n",
+ "\n",
+ "tau2022_classes = [\n",
+ "    \"airport\",\n",
+ "    \"bus\",\n",
+ "    \"metro\",\n",
+ "    \"metro_station\",\n",
+ "    \"park\",\n",
+ "    \"public_square\",\n",
+ "    \"shopping_mall\",\n",
+ "    \"street_pedestrian\",\n",
+ "    \"street_traffic\",\n",
+ "    \"tram\"\n",
+ "]\n",
+ "\n",
+ "best_prediction_idx = torch.argmax(logits)\n",
+ "\n",
+ "scores = torch.softmax(logits, dim=0)\n",
+ "\n",
+ "print(f\"Prediction: {tau2022_classes[best_prediction_idx]} (score: {scores[best_prediction_idx]:0.2f})\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
models/baseline.py ADDED
@@ -0,0 +1,260 @@
+ import torch
+ import torch.nn as nn
+ from torchvision.ops.misc import Conv2dNormActivation
+
+ from .helpers.utils import make_divisible
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.configuration_utils import PretrainedConfig
+
+
+ def initialize_weights(m):
+     if isinstance(m, nn.Conv2d):
+         nn.init.kaiming_normal_(m.weight, mode="fan_out")
+         if m.bias is not None:
+             nn.init.zeros_(m.bias)
+     elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm)):
+         nn.init.ones_(m.weight)
+         nn.init.zeros_(m.bias)
+     elif isinstance(m, nn.Linear):
+         nn.init.normal_(m.weight, 0, 0.01)
+         if m.bias is not None:
+             nn.init.zeros_(m.bias)
+
+
+ class Block(nn.Module):
+     def __init__(self, in_channels, out_channels, expansion_rate, stride):
+         super().__init__()
+         exp_channels = make_divisible(in_channels * expansion_rate, 8)
+
+         # create the three factorized convs that make up the inverted bottleneck block
+         exp_conv = Conv2dNormActivation(
+             in_channels,
+             exp_channels,
+             kernel_size=1,
+             stride=1,
+             norm_layer=nn.BatchNorm2d,
+             activation_layer=nn.ReLU,
+             inplace=False,
+         )
+
+         # depthwise convolution with possible stride
+         depth_conv = Conv2dNormActivation(
+             exp_channels,
+             exp_channels,
+             kernel_size=3,
+             stride=stride,
+             padding=1,
+             groups=exp_channels,
+             norm_layer=nn.BatchNorm2d,
+             activation_layer=nn.ReLU,
+             inplace=False,
+         )
+
+         proj_conv = Conv2dNormActivation(
+             exp_channels,
+             out_channels,
+             kernel_size=1,
+             stride=1,
+             norm_layer=nn.BatchNorm2d,
+             activation_layer=None,
+             inplace=False,
+         )
+         self.after_block_activation = nn.ReLU()
+
+         if in_channels == out_channels:
+             self.use_shortcut = True
+             if stride == 1 or stride == (1, 1):
+                 self.shortcut = nn.Sequential()
+             else:
+                 # average pooling required for shortcut
+                 self.shortcut = nn.Sequential(
+                     nn.AvgPool2d(kernel_size=3, stride=stride, padding=1),
+                     nn.Sequential(),
+                 )
+         else:
+             self.use_shortcut = False
+
+         self.block = nn.Sequential(exp_conv, depth_conv, proj_conv)
+
+     def forward(self, x):
+         if self.use_shortcut:
+             x = self.block(x) + self.shortcut(x)
+         else:
+             x = self.block(x)
+         x = self.after_block_activation(x)
+         return x
+
+
+ class NetworkConfig(PretrainedConfig):
+     def __init__(
+         self,
+         n_classes=10,
+         in_channels=1,
+         base_channels=32,
+         channels_multiplier=2.3,
+         expansion_rate=3.0,
+         n_blocks=(3, 2, 1),
+         strides=dict(b2=(1, 1), b3=(1, 2), b4=(2, 1)),
+         add_feats=False,
+         *args,
+         **kwargs,
+     ):
+         super().__init__(*args, **kwargs)
+         self.n_classes = n_classes
+         self.in_channels = in_channels
+         self.base_channels = base_channels
+         self.channels_multiplier = channels_multiplier
+         self.expansion_rate = expansion_rate
+         self.n_blocks = n_blocks
+         self.strides = strides
+         self.add_feats = add_feats
+
+
+ class Network(PreTrainedModel):
+     config_class = NetworkConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         n_classes = config.n_classes
+         in_channels = config.in_channels
+         base_channels = config.base_channels
+         channels_multiplier = config.channels_multiplier
+         expansion_rate = config.expansion_rate
+         n_blocks = config.n_blocks
+         strides = config.strides
+         n_stages = len(n_blocks)
+
+         self.add_feats = config.add_feats
+
+         base_channels = make_divisible(base_channels, 8)
+         channels_per_stage = [base_channels] + [
+             make_divisible(base_channels * channels_multiplier**stage_id, 8)
+             for stage_id in range(n_stages)
+         ]
+         self.total_block_count = 0
+
+         self.in_c = nn.Sequential(
+             Conv2dNormActivation(
+                 in_channels,
+                 channels_per_stage[0] // 4,
+                 activation_layer=torch.nn.ReLU,
+                 kernel_size=3,
+                 stride=2,
+                 inplace=False,
+             ),
+             Conv2dNormActivation(
+                 channels_per_stage[0] // 4,
+                 channels_per_stage[0],
+                 activation_layer=torch.nn.ReLU,
+                 kernel_size=3,
+                 stride=2,
+                 inplace=False,
+             ),
+         )
+
+         self.stages = nn.Sequential()
+         for stage_id in range(n_stages):
+             stage = self._make_stage(
+                 channels_per_stage[stage_id],
+                 channels_per_stage[stage_id + 1],
+                 n_blocks[stage_id],
+                 strides=strides,
+                 expansion_rate=expansion_rate,
+             )
+             self.stages.add_module(f"s{stage_id + 1}", stage)
+
+         ff_list = []
+         ff_list += [
+             nn.Conv2d(
+                 channels_per_stage[-1],
+                 n_classes,
+                 kernel_size=(1, 1),
+                 stride=(1, 1),
+                 padding=0,
+                 bias=False,
+             ),
+             nn.BatchNorm2d(n_classes),
+         ]
+
+         ff_list.append(nn.AdaptiveAvgPool2d((1, 1)))
+
+         self.feed_forward = nn.Sequential(*ff_list)
+
+         self.apply(initialize_weights)
+
+     def _make_stage(self, in_channels, out_channels, n_blocks, strides, expansion_rate):
+         stage = nn.Sequential()
+         for index in range(n_blocks):
+             block_id = self.total_block_count + 1
+             bname = f"b{block_id}"
+             self.total_block_count = self.total_block_count + 1
+             if bname in strides:
+                 stride = strides[bname]
+             else:
+                 stride = (1, 1)
+
+             block = self._make_block(
+                 in_channels, out_channels, stride=stride, expansion_rate=expansion_rate
+             )
+             stage.add_module(bname, block)
+
+             in_channels = out_channels
+         return stage
+
+     def _make_block(self, in_channels, out_channels, stride, expansion_rate):
+
+         block = Block(in_channels, out_channels, expansion_rate, stride)
+         return block
+
+     def _forward_conv(self, x):
+         x = self.in_c(x)
+         x = self.stages(x)
+         return x
+
+     def forward(self, x):
+         y = self._forward_conv(x)
+         x = self.feed_forward(y)
+         logits = x.squeeze(2).squeeze(2)
+         if self.add_feats:
+             return logits, y
+         else:
+             return logits
+
+
+ def get_model(
+     n_classes=10,
+     in_channels=1,
+     base_channels=32,
+     channels_multiplier=2.3,
+     expansion_rate=3.0,
+     n_blocks=(3, 2, 1),
+     strides=None,
+     add_feats=False,
+ ):
+     """
+     @param n_classes: number of the classes to predict
+     @param in_channels: input channels to the network, for audio it is by default 1
+     @param base_channels: number of channels after in_conv
+     @param channels_multiplier: controls the increase in the width of the network after each stage
+     @param expansion_rate: determines the expansion rate in inverted bottleneck blocks
+     @param n_blocks: number of blocks that should exist in each stage
+     @param strides: default value set below
+     @return: full neural network model based on the specified configs
+     """
+
+     if strides is None:
+         strides = dict(b2=(1, 1), b3=(1, 2), b4=(2, 1))
+
+     model_config = {
+         "n_classes": n_classes,
+         "in_channels": in_channels,
+         "base_channels": base_channels,
+         "channels_multiplier": channels_multiplier,
+         "expansion_rate": expansion_rate,
+         "n_blocks": n_blocks,
+         "strides": strides,
+         "add_feats": add_feats,
+     }
+
+     m = Network(NetworkConfig(**model_config))
+     return m
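A minimal usage sketch for models/baseline.py, assuming the repository is checked out locally so that the models package and the split*/ checkpoint folders are importable from the repository root:

from models.baseline import Network, get_model

# Fresh, randomly initialized network with the default configuration (n_classes=10, in_channels=1).
model = get_model()

# One of the checkpoints added in this commit; "split5" resolves to the local split5/ folder
# (config.json + model.safetensors) when running from the repository root.
pretrained = Network.from_pretrained("split5").eval()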
models/helpers/utils.py ADDED
@@ -0,0 +1,11 @@
+ from typing import Optional
+
+
+ def make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
+     if min_value is None:
+         min_value = divisor
+     new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+     # Make sure that round down does not go down by more than 10%.
+     if new_v < 0.9 * v:
+         new_v += divisor
+     return new_v
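For orientation, a quick worked example of make_divisible as it is used in models/baseline.py, where per-stage channel counts are rounded to multiples of 8 without dropping more than 10% below the requested value:

from models.helpers.utils import make_divisible

print(make_divisible(32 * 2.3, 8))  # 73.6 -> 72, the nearest multiple of 8 within 10% of 73.6
print(make_divisible(7, 8))         # 7 -> 8, never below min_value (which defaults to the divisor)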
models/mel.py ADDED
@@ -0,0 +1,97 @@
+ import torch
+ import torch.nn as nn
+ import torchaudio
+
+
+ class AugmentMelSTFT(nn.Module):
+     def __init__(self, n_mels=128, sr=32000, win_length=800, hopsize=320, n_fft=1024, freqm=48, timem=192,
+                  fmin=0.0, fmax=None, fmin_aug_range=10, fmax_aug_range=2000, norm_mel=False):
+         """
+         :param n_mels: number of mel bins
+         :param sr: sampling rate used (same as passed as argument to dataset)
+         :param win_length: fft window length in samples
+         :param hopsize: fft hop size in samples
+         :param n_fft: length of fft
+         :param freqm: maximum possible length of mask along frequency dimension
+         :param timem: maximum possible length of mask along time dimension
+         :param fmin: minimum frequency used
+         :param fmax: maximum frequency used
+         :param fmin_aug_range: randomly changes min frequency
+         :param fmax_aug_range: randomly changes max frequency
+         """
+         torch.nn.Module.__init__(self)
+         # adapted from: https://github.com/CPJKU/kagglebirds2020/commit/70f8308b39011b09d41eb0f4ace5aa7d2b0e806e
+
+         self.win_length = win_length
+         self.n_mels = n_mels
+         self.n_fft = n_fft
+         self.sr = sr
+         self.fmin = fmin
+         if fmax is None:
+             # fmax is by default set to sampling_rate/2 -> Nyquist!
+             fmax = sr // 2 - fmax_aug_range // 2
+             print(f"Warning: FMAX is None setting to {fmax} ")
+         self.fmax = fmax
+         self.hopsize = hopsize
+         # buffers are not trained by the optimizer, persistent=False also avoids adding it to the model's state dict
+         self.register_buffer('window',
+                              torch.hann_window(win_length, periodic=False),
+                              persistent=False)
+         assert fmin_aug_range >= 1, f"fmin_aug_range={fmin_aug_range} should be >=1; 1 means no augmentation"
+         assert fmax_aug_range >= 1, f"fmax_aug_range={fmax_aug_range} should be >=1; 1 means no augmentation"
+         self.fmin_aug_range = fmin_aug_range
+         self.fmax_aug_range = fmax_aug_range
+
+         self.register_buffer("preemphasis_coefficient", torch.as_tensor([[[-.97, 1]]]), persistent=False)
+         if freqm == 0:
+             self.freqm = torch.nn.Identity()
+         else:
+             self.freqm = torchaudio.transforms.FrequencyMasking(freqm, iid_masks=True)
+         if timem == 0:
+             self.timem = torch.nn.Identity()
+         else:
+             self.timem = torchaudio.transforms.TimeMasking(timem, iid_masks=True)
+         self.norm_mel = norm_mel
+
+     def forward(self, x):
+         # shape: batch size x samples
+         # majority of energy located in lower end of the spectrum, pre-emphasis compensates for the average spectral
+         # shape
+         x = nn.functional.conv1d(x.unsqueeze(1), self.preemphasis_coefficient).squeeze(1)
+
+         # Short-Time Fourier Transform using Hanning window
+         x = torch.stft(x, self.n_fft, hop_length=self.hopsize, win_length=self.win_length,
+                        center=True, normalized=False, window=self.window, return_complex=False)
+         # shape: batch size x freqs (n_fft/2 + 1) x timeframes (samples/hop_length) x 2 (real and imaginary components)
+
+         # calculate power spectrum
+         x = (x ** 2).sum(dim=-1)
+         fmin = self.fmin + torch.randint(self.fmin_aug_range, (1,)).item()
+         fmax = self.fmax + self.fmax_aug_range // 2 - torch.randint(self.fmax_aug_range, (1,)).item()
+
+         if not self.training:
+             # don't augment eval data
+             fmin = self.fmin
+             fmax = self.fmax
+
+         # create mel filterbank
+         mel_basis, _ = torchaudio.compliance.kaldi.get_mel_banks(self.n_mels, self.n_fft, self.sr,
+                                                                  fmin, fmax, vtln_low=100.0, vtln_high=-500.,
+                                                                  vtln_warp_factor=1.0)
+         mel_basis = torch.as_tensor(torch.nn.functional.pad(mel_basis, (0, 1), mode='constant', value=0),
+                                     device=x.device)
+         if self.norm_mel:
+             mel_basis = mel_basis / mel_basis.sum(1)[:, None]
+         # apply mel filterbank to power spectrogram
+         with torch.cuda.amp.autocast(enabled=False):
+             melspec = torch.matmul(mel_basis, x)
+         # calculate log mel spectrogram
+         melspec = (melspec + 0.00001).log()
+
+         if self.training:
+             # don't augment eval data
+             melspec = self.freqm(melspec)
+             melspec = self.timem(melspec)
+
+         melspec = (melspec + 4.5) / 5.  # fast normalization
+         return melspec
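A small shape check for AugmentMelSTFT, a sketch using the defaults above (32 kHz input, n_mels=128, hop size 320); in eval mode the random fmin/fmax jitter and the frequency/time masking are disabled:

import torch
from models.mel import AugmentMelSTFT

mel = AugmentMelSTFT().eval()
waveform = torch.randn(2, 32_000)  # (batch, samples): two one-second clips at 32 kHz
with torch.no_grad():
    spec = mel(waveform)
print(spec.shape)  # (2, 128, ~100): (batch, n_mels, frames), roughly samples / hopsize frames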
split10/config.json ADDED
@@ -0,0 +1,32 @@
+ {
+   "add_feats": false,
+   "architectures": [
+     "Network"
+   ],
+   "base_channels": 32,
+   "channels_multiplier": 1.8,
+   "expansion_rate": 2.1,
+   "in_channels": 1,
+   "n_blocks": [
+     3,
+     2,
+     1
+   ],
+   "n_classes": 10,
+   "strides": {
+     "b2": [
+       1,
+       1
+     ],
+     "b3": [
+       1,
+       2
+     ],
+     "b4": [
+       2,
+       1
+     ]
+   },
+   "torch_dtype": "float16",
+   "transformers_version": "4.37.1"
+ }
split10/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6b57c03e4b9ecd67b1c33d8ed471bed10b4e5dfe872b2fa9bb1bb6bc3405b9b0
+ size 139504
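Each split*/ folder pairs a config.json like the one above with a model.safetensors weight file stored via Git LFS (about 140 kB each). A hedged loading sketch, assuming git-lfs has materialized the weights in a local checkout:

from models.baseline import Network

# Any of the provided checkpoints can be loaded by folder name:
# split5, split10, split25, split50, split100.
model = Network.from_pretrained("split10").eval()
print(sum(p.numel() for p in model.parameters()))  # parameter count of the small baseline network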
split100/config.json ADDED
@@ -0,0 +1,32 @@
+ {
+   "add_feats": false,
+   "architectures": [
+     "Network"
+   ],
+   "base_channels": 32,
+   "channels_multiplier": 1.8,
+   "expansion_rate": 2.1,
+   "in_channels": 1,
+   "n_blocks": [
+     3,
+     2,
+     1
+   ],
+   "n_classes": 10,
+   "strides": {
+     "b2": [
+       1,
+       1
+     ],
+     "b3": [
+       1,
+       2
+     ],
+     "b4": [
+       2,
+       1
+     ]
+   },
+   "torch_dtype": "float16",
+   "transformers_version": "4.37.1"
+ }
split100/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:02791b0e81bad11f603806e00252e76ceb4fb62f0d4880c115ed95513262b172
+ size 139504
split25/config.json ADDED
@@ -0,0 +1,32 @@
+ {
+   "add_feats": false,
+   "architectures": [
+     "Network"
+   ],
+   "base_channels": 32,
+   "channels_multiplier": 1.8,
+   "expansion_rate": 2.1,
+   "in_channels": 1,
+   "n_blocks": [
+     3,
+     2,
+     1
+   ],
+   "n_classes": 10,
+   "strides": {
+     "b2": [
+       1,
+       1
+     ],
+     "b3": [
+       1,
+       2
+     ],
+     "b4": [
+       2,
+       1
+     ]
+   },
+   "torch_dtype": "float16",
+   "transformers_version": "4.37.1"
+ }
split25/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:10a46c2d2b99b39965a4fa4fbac6171275efb01bd7da08a7f768779df5effba0
+ size 139504
split5/config.json ADDED
@@ -0,0 +1,32 @@
+ {
+   "add_feats": false,
+   "architectures": [
+     "Network"
+   ],
+   "base_channels": 32,
+   "channels_multiplier": 1.8,
+   "expansion_rate": 2.1,
+   "in_channels": 1,
+   "n_blocks": [
+     3,
+     2,
+     1
+   ],
+   "n_classes": 10,
+   "strides": {
+     "b2": [
+       1,
+       1
+     ],
+     "b3": [
+       1,
+       2
+     ],
+     "b4": [
+       2,
+       1
+     ]
+   },
+   "torch_dtype": "float16",
+   "transformers_version": "4.37.1"
+ }
split5/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5401136f83a10102de48f37c724a50e188e4b89186177a9ceb2945ddc57f5b49
+ size 139504
split50/config.json ADDED
@@ -0,0 +1,32 @@
+ {
+   "add_feats": false,
+   "architectures": [
+     "Network"
+   ],
+   "base_channels": 32,
+   "channels_multiplier": 1.8,
+   "expansion_rate": 2.1,
+   "in_channels": 1,
+   "n_blocks": [
+     3,
+     2,
+     1
+   ],
+   "n_classes": 10,
+   "strides": {
+     "b2": [
+       1,
+       1
+     ],
+     "b3": [
+       1,
+       2
+     ],
+     "b4": [
+       2,
+       1
+     ]
+   },
+   "torch_dtype": "float16",
+   "transformers_version": "4.37.1"
+ }
split50/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:513d8bc217266fbb3a3f7354e3884eb5478550c47d9daa44cf1cdb7b08c54984
+ size 139504