ZeyadMostafa22 committed on
Commit c629c7c · 1 Parent(s): db175f8
Files changed (1)
app.py +14 -12
app.py CHANGED
@@ -3,6 +3,7 @@ import torch
 import torchaudio
 import numpy as np
 from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
+import torch.nn.functional as F
 import torchaudio.transforms as T
 
 MODEL_ID = "Zeyadd-Mostaffa/wav2vec_checkpoints"
@@ -18,15 +19,13 @@ model.to(device)
 
 label_names = ["fake", "real"]  # According to your label2id = {"fake": 0, "real": 1}
 
-
 def classify_audio(audio_file):
     """
     audio_file: path to the uploaded file (WAV, MP3, etc.)
-    Returns: "fake" or "real"
+    Returns: predicted label and confidence score
     """
 
     # 2) Load the audio file
-    # torchaudio returns (waveform, sample_rate)
     waveform, sr = torchaudio.load(audio_file)
 
     # If stereo, pick one channel or average
@@ -40,14 +39,13 @@ def classify_audio(audio_file):
         waveform = resampler(waveform)
         sr = 16000
 
-
     # 3) Preprocess with feature_extractor
     inputs = feature_extractor(
        waveform.numpy(),
        sampling_rate=sr,
        return_tensors="pt",
        truncation=True,
-       max_length=int(16000* 6.0),  # 6 second max
+       max_length=int(16000 * 6.0),  # 6 second max
     )
 
     # Move everything to device
@@ -55,20 +53,24 @@
 
     with torch.no_grad():
         logits = model(input_values).logits
-    pred_id = torch.argmax(logits, dim=-1).item()
 
-    # 4) Return label text
-    predicted_label = label_names[pred_id]
-    return predicted_label
+    # 4) Calculate probabilities using softmax
+    probabilities = F.softmax(logits, dim=-1)
+
+    # Get predicted label and confidence
+    confidence, pred_id = torch.max(probabilities, dim=-1)
+    predicted_label = label_names[pred_id.item()]
 
+    # 5) Return label and confidence percentage
+    return f"Prediction: {predicted_label}, Confidence: {confidence.item() * 100:.2f}%"
 
-# 5) Build Gradio interface
+# 6) Build Gradio interface
 demo = gr.Interface(
     fn=classify_audio,
-    inputs=gr.Audio( type="filepath"),
+    inputs=gr.Audio(type="filepath"),
     outputs="text",
     title="Wav2Vec2 Deepfake Detection",
-    description="Upload an audio sample to check if it is fake or real."
+    description="Upload an audio sample to check if it is fake or real, along with confidence."
 )
 
 if __name__ == "__main__":
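
For context on the hunks above: before the model sees any audio, classify_audio loads the file, downmixes stereo to mono, and resamples to 16 kHz. A rough standalone sketch of that path, assuming a hypothetical local file sample.wav; the exact downmix lines sit in parts of app.py the diff does not show, so this is an illustration, not the committed code:

import torchaudio
import torchaudio.transforms as T

waveform, sr = torchaudio.load("sample.wav")  # (channels, samples), sample rate
if waveform.shape[0] > 1:                     # stereo: average channels to mono
    waveform = waveform.mean(dim=0, keepdim=True)
if sr != 16000:                               # wav2vec2 models expect 16 kHz input
    waveform = T.Resample(orig_freq=sr, new_freq=16000)(waveform)
    sr = 16000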
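
The substance of the commit is the new softmax step: instead of returning a bare argmax label, classify_audio now reports the model's probability for its prediction. A minimal sketch of just that step, with made-up logits standing in for the wav2vec2 classifier's output (label_names matches the mapping in app.py):

import torch
import torch.nn.functional as F

label_names = ["fake", "real"]           # same label2id mapping as app.py
logits = torch.tensor([[0.3, 2.1]])      # hypothetical output, shape (batch, num_labels)

probabilities = F.softmax(logits, dim=-1)               # each row now sums to 1
confidence, pred_id = torch.max(probabilities, dim=-1)  # max probability and its index
print(f"Prediction: {label_names[pred_id.item()]}, "
      f"Confidence: {confidence.item() * 100:.2f}%")
# -> Prediction: real, Confidence: 85.81%

Since the two softmax outputs sum to 1, the reported confidence is always at least 50% for this binary head; it is the model's own probability estimate, not a calibrated accuracy guarantee.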