ZeyadMostafa22 committed on
Commit 4dc47c8 · Parent(s): c629c7c
Files changed (1):
  1. app.py +12 -14

app.py CHANGED
@@ -3,7 +3,6 @@ import torch
 import torchaudio
 import numpy as np
 from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
-import torch.nn.functional as F
 import torchaudio.transforms as T
 
 MODEL_ID = "Zeyadd-Mostaffa/wav2vec_checkpoints"
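For context: the hunk above only shows `MODEL_ID` and the imports, and the next hunk's header shows `model.to(device)`, which together imply the usual transformers loading boilerplate in the unchanged part of app.py. A minimal sketch of that presumed setup (inferred, not shown in the diff):

```python
import torch
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification

MODEL_ID = "Zeyadd-Mostaffa/wav2vec_checkpoints"

# Presumed loading code: the diff only shows MODEL_ID and model.to(device),
# so the exact lines in app.py may differ.
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_ID)
model = AutoModelForAudioClassification.from_pretrained(MODEL_ID)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
```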
@@ -19,13 +18,15 @@ model.to(device)
 
 label_names = ["fake", "real"]  # According to your label2id = {"fake": 0, "real": 1}
 
+
 def classify_audio(audio_file):
     """
     audio_file: path to the uploaded file (WAV, MP3, etc.)
-    Returns: predicted label and confidence score
+    Returns: "fake" or "real"
     """
 
     # 2) Load the audio file
+    # torchaudio returns (waveform, sample_rate)
     waveform, sr = torchaudio.load(audio_file)
 
     # If stereo, pick one channel or average
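The stereo-to-mono handling itself sits in unchanged lines outside this hunk, so the exact reduction is an assumption; a self-contained sketch of the load step as it likely looks:

```python
import torch
import torchaudio

def load_mono(audio_file: str):
    # torchaudio.load returns (waveform, sample_rate); waveform is (channels, samples)
    waveform, sr = torchaudio.load(audio_file)
    # If stereo, average the channels down to mono (assumed; the diff only
    # shows the comment, not the actual reduction used in app.py)
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    return waveform, sr
```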
@@ -39,13 +40,14 @@ def classify_audio(audio_file):
     waveform = resampler(waveform)
     sr = 16000
 
+
     # 3) Preprocess with feature_extractor
     inputs = feature_extractor(
         waveform.numpy(),
         sampling_rate=sr,
         return_tensors="pt",
         truncation=True,
-        max_length=int(16000 * 6.0),  # 6 second max
+        max_length=int(16000* 6.0),  # 6 second max
     )
 
     # Move everything to device
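The `resampler` comes from unchanged lines (presumably `T.Resample` down to 16 kHz, matching `sr = 16000`). A standalone sketch of this preprocessing path, using a public checkpoint as a stand-in since the app's own checkpoint may not be accessible:

```python
import torch
import torchaudio.transforms as T
from transformers import AutoFeatureExtractor

# Stand-in checkpoint; the app uses "Zeyadd-Mostaffa/wav2vec_checkpoints"
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")

waveform = torch.randn(1, 44100)                 # one second of dummy audio at 44.1 kHz
resampler = T.Resample(orig_freq=44100, new_freq=16000)
waveform = resampler(waveform)                   # now (1, 16000)

inputs = feature_extractor(
    waveform.squeeze().numpy(),
    sampling_rate=16000,
    return_tensors="pt",
    truncation=True,
    max_length=int(16000 * 6.0),                 # cap input at 6 s of samples
)
print(inputs.input_values.shape)                 # torch.Size([1, 16000])
```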
@@ -53,24 +55,20 @@
 
     with torch.no_grad():
         logits = model(input_values).logits
+    pred_id = torch.argmax(logits, dim=-1).item()
 
-    # 4) Calculate probabilities using softmax
-    probabilities = F.softmax(logits, dim=-1)
-
-    # Get predicted label and confidence
-    confidence, pred_id = torch.max(probabilities, dim=-1)
-    predicted_label = label_names[pred_id.item()]
+    # 4) Return label text
+    predicted_label = label_names[pred_id]
+    return predicted_label
 
-    # 5) Return label and confidence percentage
-    return f"Prediction: {predicted_label}, Confidence: {confidence.item() * 100:.2f}%"
 
-# 6) Build Gradio interface
+# 5) Build Gradio interface
 demo = gr.Interface(
     fn=classify_audio,
-    inputs=gr.Audio(type="filepath"),
+    inputs=gr.Audio( type="filepath"),
     outputs="text",
     title="Wav2Vec2 Deepfake Detection",
-    description="Upload an audio sample to check if it is fake or real, along with confidence."
+    description="Upload an audio sample to check if it is fake or real."
 )
 
 if __name__ == "__main__":
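The substantive change in this hunk: the softmax/confidence computation is dropped in favor of a bare argmax, so the app now returns only the label. Both paths pick the same label; a minimal side-by-side sketch with dummy logits:

```python
import torch
import torch.nn.functional as F

label_names = ["fake", "real"]
logits = torch.tensor([[2.1, -0.3]])    # dummy logits for a single clip

# Old path: softmax to get a confidence alongside the label
probabilities = F.softmax(logits, dim=-1)
confidence, pred_id = torch.max(probabilities, dim=-1)
print(f"Prediction: {label_names[pred_id.item()]}, "
      f"Confidence: {confidence.item() * 100:.2f}%")   # Prediction: fake, Confidence: 91.68%

# New path: argmax alone; same label, no confidence reported
pred_id = torch.argmax(logits, dim=-1).item()
print(label_names[pred_id])                            # fake
```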
 
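One note on the interface: `gr.Audio(type="filepath")` makes Gradio pass the upload to `classify_audio` as a temp-file path string, which is what `torchaudio.load` expects. The diff context truncates right after `if __name__ == "__main__":`, so the launch line is not shown; a runnable sketch with a stub classifier and a presumed ending:

```python
import gradio as gr

def classify_audio(audio_file):
    # Stub standing in for the real classifier; the app returns "fake" or "real"
    return "real"

# type="filepath" hands the function a path string rather than raw samples
demo = gr.Interface(
    fn=classify_audio,
    inputs=gr.Audio(type="filepath"),
    outputs="text",
    title="Wav2Vec2 Deepfake Detection",
    description="Upload an audio sample to check if it is fake or real.",
)

if __name__ == "__main__":
    demo.launch()   # presumed: the diff is cut off before this line
```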