RaphaelSchwinger committed on
Commit
136d36f
·
verified ·
1 Parent(s): cfa17bd

update preprocessing

Browse files
Files changed (1) hide show
  1. README.md +24 -13
README.md CHANGED
@@ -107,6 +107,19 @@ class PowerToDB(torch.nn.Module):
107
  return log_spec
108
 
109
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  def preprocess(audio, sample_rate_of_audio):
111
  """
112
  Preprocess the audio to the format that the model expects
@@ -115,30 +128,28 @@ def preprocess(audio, sample_rate_of_audio):
115
  - Normalize the melscale spectrogram with mean: -4.268, std: 4.569 (from AudioSet)
116
 
117
  """
118
- powerToDB = PowerToDB()
119
- # Resample to 32kHz
120
- resample = torchaudio.transforms.Resample(
121
- orig_freq=sample_rate_of_audio, new_freq=32000
122
- )
123
- audio = resample(audio)
124
- spectrogram = torchaudio.transforms.Spectrogram(
125
- n_fft=1024, hop_length=320, power=2.0
126
- )(audio)
127
- melspec = torchaudio.transforms.MelScale(n_mels=128, n_stft=513)(spectrogram)
128
  dbscale = powerToDB(melspec)
129
- normalized_dbscale = transforms.Normalize((-4.268,), (4.569,))(dbscale)
 
 
130
  return normalized_dbscale
131
 
132
  preprocessed_audio = preprocess(audio, sample_rate)
 
 
133
 
134
- logits = model(preprocessed_audio.unsqueeze(0)).logits
 
135
  print("Logits shape: ", logits.shape)
136
 
137
  top5 = torch.topk(logits, 5)
138
  print("Top 5 logits:", top5.values)
139
  print("Top 5 predicted classes:")
140
  print([model.config.id2label[i] for i in top5.indices.squeeze().tolist()])
141
-
142
  ```
143
 
144
  ## Model Source
 
107
  return log_spec
108
 
109
 
110
+
111
+ # Initialize the transformations
112
+
113
+ spectrogram_converter = torchaudio.transforms.Spectrogram(
114
+ n_fft=1024, hop_length=320, power=2.0
115
+ )
116
+ mel_converter = torchaudio.transforms.MelScale(
117
+ n_mels=128, n_stft=513, sample_rate=32_000
118
+ )
119
+ normalizer = transforms.Normalize((-4.268,), (4.569,))
120
+ powerToDB = PowerToDB(top_db=80)
121
+
122
+
123
  def preprocess(audio, sample_rate_of_audio):
124
  """
125
  Preprocess the audio to the format that the model expects
 
128
  - Normalize the melscale spectrogram with mean: -4.268, std: 4.569 (from AudioSet)
129
 
130
  """
131
+ # convert waveform to spectrogram
132
+ spectrogram = spectrogram_converter(audio)
133
+ spectrogram = spectrogram.to(torch.float32)
134
+ melspec = mel_converter(spectrogram)
 
 
 
 
 
 
135
  dbscale = powerToDB(melspec)
136
+ normalized_dbscale = normalizer(dbscale)
137
  + # insert a channel dimension at position -3 (third axis from the right)
138
+ normalized_dbscale = normalized_dbscale.unsqueeze(-3)
139
  return normalized_dbscale
140
 
141
  preprocessed_audio = preprocess(audio, sample_rate)
142
+ print("Preprocessed_audio shape:", preprocessed_audio.shape)
143
+
144
 
145
+
146
+ logits = model(preprocessed_audio).logits
147
  print("Logits shape: ", logits.shape)
148
 
149
  top5 = torch.topk(logits, 5)
150
  print("Top 5 logits:", top5.values)
151
  print("Top 5 predicted classes:")
152
  print([model.config.id2label[i] for i in top5.indices.squeeze().tolist()])
 
153
  ```
154
 
155
  ## Model Source