ZeyadMostafa22
committed on
Commit
·
db175f8
1
Parent(s):
a1e9e88
final version
Browse files
app.py
CHANGED
@@ -3,6 +3,7 @@ import torch
|
|
3 |
import torchaudio
|
4 |
import numpy as np
|
5 |
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
|
|
|
6 |
|
7 |
MODEL_ID = "Zeyadd-Mostaffa/wav2vec_checkpoints"
|
8 |
|
@@ -33,13 +34,20 @@ def classify_audio(audio_file):
|
|
33 |
waveform = torch.mean(waveform, dim=0, keepdim=True)
|
34 |
waveform = waveform.squeeze() # (samples,)
|
35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
# 3) Preprocess with feature_extractor
|
37 |
inputs = feature_extractor(
|
38 |
waveform.numpy(),
|
39 |
sampling_rate=sr,
|
40 |
return_tensors="pt",
|
41 |
truncation=True,
|
42 |
-
max_length=int(
|
43 |
)
|
44 |
|
45 |
# Move everything to device
|
|
|
3 |
import torchaudio
|
4 |
import numpy as np
|
5 |
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
|
6 |
+
import torchaudio.transforms as T
|
7 |
|
8 |
MODEL_ID = "Zeyadd-Mostaffa/wav2vec_checkpoints"
|
9 |
|
|
|
34 |
waveform = torch.mean(waveform, dim=0, keepdim=True)
|
35 |
waveform = waveform.squeeze() # (samples,)
|
36 |
|
37 |
+
# 3) Resample if needed
|
38 |
+
if sr != 16000:
|
39 |
+
resampler = T.Resample(sr, 16000)
|
40 |
+
waveform = resampler(waveform)
|
41 |
+
sr = 16000
|
42 |
+
|
43 |
+
|
44 |
# 3) Preprocess with feature_extractor
|
45 |
inputs = feature_extractor(
|
46 |
waveform.numpy(),
|
47 |
sampling_rate=sr,
|
48 |
return_tensors="pt",
|
49 |
truncation=True,
|
50 |
+
max_length=int(16000* 6.0), # 6 second max
|
51 |
)
|
52 |
|
53 |
# Move everything to device
|