ZeyadMostafa22 committed
Commit db175f8 · 1 Parent(s): a1e9e88

final version

Files changed (1)
  1. app.py +9 -1
app.py CHANGED
@@ -3,6 +3,7 @@ import torch
 import torchaudio
 import numpy as np
 from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
+import torchaudio.transforms as T
 
 MODEL_ID = "Zeyadd-Mostaffa/wav2vec_checkpoints"
 
@@ -33,13 +34,20 @@ def classify_audio(audio_file):
     waveform = torch.mean(waveform, dim=0, keepdim=True)
     waveform = waveform.squeeze()  # (samples,)
 
+    # Resample to 16 kHz if the input arrives at a different rate
+    if sr != 16000:
+        resampler = T.Resample(sr, 16000)
+        waveform = resampler(waveform)
+        sr = 16000
+
+
     # 3) Preprocess with feature_extractor
     inputs = feature_extractor(
         waveform.numpy(),
         sampling_rate=sr,
         return_tensors="pt",
         truncation=True,
-        max_length=int(feature_extractor.sampling_rate * 6.0),  # 6 second max
+        max_length=int(16000 * 6.0),  # 6 second max
     )
 
     # Move everything to device
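For reference, the resampling step introduced in this commit can be exercised on its own. The snippet below is a minimal sketch and not part of the commit: it fabricates a 2-second mono clip at 8 kHz (the sample rate and tensor shapes are illustrative only) and applies the same torchaudio.transforms.Resample logic as classify_audio.

import torch
import torchaudio.transforms as T

# Illustrative stand-in for a decoded audio clip: 2 seconds of noise at 8 kHz, mono.
sr = 8000
waveform = torch.randn(sr * 2)

# Same guard as in the committed change: only resample when the rate differs from 16 kHz.
if sr != 16000:
    resampler = T.Resample(sr, 16000)
    waveform = resampler(waveform)
    sr = 16000

print(sr, waveform.shape)  # 16000, 32000 samples (2 s at the new rate)

With the waveform guaranteed to be at 16 kHz, the hard-coded max_length=int(16000 * 6.0) in the feature extractor call works out to 96000 samples, i.e. a 6-second cap, which matches the earlier feature_extractor.sampling_rate * 6.0 expression whenever the extractor is configured for 16 kHz.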