ZeyadMostafa22 committed on
Commit 0100779 · 1 Parent(s): 709b43f
Files changed (1)
  1. app.py +75 -1
app.py CHANGED
@@ -1,3 +1,77 @@
 import gradio as gr
+import torch
+import torchaudio
+import numpy as np
+from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
+import torch.nn.functional as F
+import torchaudio.transforms as T
 
-gr.load("models/Zeyadd-Mostaffa/Deepfake-Audio-Detection-v1").launch()
+MODEL_ID = "Zeyadd-Mostaffa/Deepfake-Audio-Detection-v1"
+
+# 1) Load the model and feature extractor
+feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_ID)
+model = AutoModelForAudioClassification.from_pretrained(MODEL_ID)
+model.eval()
+
+# Optionally use a GPU if available
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model.to(device)
+
+label_names = ["fake", "real"]  # matches the model's label2id = {"fake": 0, "real": 1}
+
+def classify_audio(audio_file):
+    """
+    audio_file: path to the uploaded file (WAV, MP3, etc.)
+    Returns: the predicted label and a confidence score.
+    """
+
+    # 2) Load the audio file
+    waveform, sr = torchaudio.load(audio_file)
+
+    # If stereo, average the channels down to mono
+    if waveform.shape[0] > 1:
+        waveform = torch.mean(waveform, dim=0, keepdim=True)
+    waveform = waveform.squeeze()  # (samples,)
+
+    # 3) Resample to 16 kHz if needed
+    if sr != 16000:
+        resampler = T.Resample(sr, 16000)
+        waveform = resampler(waveform)
+        sr = 16000
+
+    # 4) Preprocess with the feature extractor
+    inputs = feature_extractor(
+        waveform.numpy(),
+        sampling_rate=sr,
+        return_tensors="pt",
+        truncation=True,
+        max_length=int(16000 * 6.0),  # 6-second maximum
+    )
+
+    # Move the inputs to the model's device
+    input_values = inputs["input_values"].to(device)
+
+    with torch.no_grad():
+        logits = model(input_values).logits
+
+    # 5) Convert the logits to probabilities with softmax
+    probabilities = F.softmax(logits, dim=-1)
+
+    # Get the predicted label and its confidence
+    confidence, pred_id = torch.max(probabilities, dim=-1)
+    predicted_label = label_names[pred_id.item()]
+
+    # 6) Return the label and the confidence as a percentage
+    return f"Prediction: {predicted_label}, Confidence: {confidence.item() * 100:.2f}%"
+
+# 7) Build the Gradio interface
+demo = gr.Interface(
+    fn=classify_audio,
+    inputs=gr.Audio(type="filepath"),
+    outputs="text",
+    title="Wav2Vec2 Deepfake Detection",
+    description="Upload an audio sample to check whether it is fake or real, along with a confidence score."
+)
+
+if __name__ == "__main__":
+    demo.launch()
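
For a quick sanity check before redeploying the Space, classify_audio can be exercised without the UI by importing it from app.py. A minimal sketch, assuming a local clip named sample.wav (the filename and the smoke_test.py helper are illustrative, not part of this commit):

# smoke_test.py: hypothetical helper, not part of the commit above.
# Assumes app.py sits in the same directory and "sample.wav" is any
# audio file torchaudio can decode.
from app import classify_audio

print(classify_audio("sample.wav"))
# Output has the form: Prediction: <fake|real>, Confidence: <percent>%

Because demo.launch() is guarded by the __main__ check, importing app loads the model and builds the interface but does not start the Gradio server.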