update preprocessing
README.md
CHANGED
@@ -107,6 +107,19 @@ class PowerToDB(torch.nn.Module):
         return log_spec


+
+# Initialize the transformations
+
+spectrogram_converter = torchaudio.transforms.Spectrogram(
+    n_fft=1024, hop_length=320, power=2.0
+)
+mel_converter = torchaudio.transforms.MelScale(
+    n_mels=128, n_stft=513, sample_rate=32_000
+)
+normalizer = transforms.Normalize((-4.268,), (4.569,))
+powerToDB = PowerToDB(top_db=80)
+
+
 def preprocess(audio, sample_rate_of_audio):
     """
     Preprocess the audio to the format that the model expects
@@ -115,30 +128,28 @@ def preprocess(audio, sample_rate_of_audio):
     - Normalize the melscale spectrogram with mean: -4.268, std: 4.569 (from AudioSet)

     """
-
-
-
-
-    )
-    audio = resample(audio)
-    spectrogram = torchaudio.transforms.Spectrogram(
-        n_fft=1024, hop_length=320, power=2.0
-    )(audio)
-    melspec = torchaudio.transforms.MelScale(n_mels=128, n_stft=513)(spectrogram)
+    # convert waveform to spectrogram
+    spectrogram = spectrogram_converter(audio)
+    spectrogram = spectrogram.to(torch.float32)
+    melspec = mel_converter(spectrogram)
     dbscale = powerToDB(melspec)
-    normalized_dbscale =
+    normalized_dbscale = normalizer(dbscale)
+    # add dimension 3 from left
+    normalized_dbscale = normalized_dbscale.unsqueeze(-3)
     return normalized_dbscale

 preprocessed_audio = preprocess(audio, sample_rate)
+print("Preprocessed_audio shape:", preprocessed_audio.shape)
+

-
+
+logits = model(preprocessed_audio).logits
 print("Logits shape: ", logits.shape)

 top5 = torch.topk(logits, 5)
 print("Top 5 logits:", top5.values)
 print("Top 5 predicted classes:")
 print([model.config.id2label[i] for i in top5.indices.squeeze().tolist()])
-
 ```

 ## Model Source
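For readers trying the updated snippet outside the README, note that it assumes `audio`, `sample_rate`, and `model` are already defined earlier in the README, and that the waveform arrives at 32 kHz: the in-function resampling was removed, and `MelScale` is now constructed with `sample_rate=32_000`. The driver below is a minimal sketch under those assumptions; the file path, the mono mixdown, and the explicit `Resample` step are illustrative additions, not part of this commit.

```python
# Hypothetical driver for the updated preprocessing (not part of the commit).
# `preprocess` and `model` are assumed to be the objects defined in the README.
import torch
import torchaudio

# Load a waveform; "example.wav" is a placeholder path.
audio, sample_rate = torchaudio.load("example.wav")  # audio: [channels, time]
audio = audio.mean(dim=0, keepdim=True)              # mix down to mono, keep the channel dim

# Resample to the 32 kHz rate that MelScale(sample_rate=32_000) is configured for.
if sample_rate != 32_000:
    audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=32_000)(audio)
    sample_rate = 32_000

preprocessed_audio = preprocess(audio, sample_rate)  # expected shape: [1, 1, 128, n_frames]

with torch.no_grad():
    logits = model(preprocessed_audio).logits
```

Instantiating `Spectrogram`, `MelScale`, `Normalize`, and `PowerToDB` once at module scope, as this commit does, also avoids rebuilding the transforms on every `preprocess` call.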