ZeyadMostafa22 committed
Commit 87a4e5a · 1 Parent(s): b09c8e8

Add application file

Files changed (1): app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
+ import gradio as gr
+ import torch
+ import torchaudio
+ from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
+
+ MODEL_ID = "Zeyadd-Mostaffa/wav2vec_checkpoints"
+
+ # 1) Load model & feature extractor
+ feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_ID)
+ model = AutoModelForAudioClassification.from_pretrained(MODEL_ID)
+ model.eval()
+
+ # Use GPU if available
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model.to(device)
+
+ label_names = ["fake", "real"]  # matches the checkpoint's label2id = {"fake": 0, "real": 1}
+
+
+ def classify_audio(audio_file):
+     """
+     audio_file: path to the uploaded file (WAV, MP3, etc.)
+     Returns: "fake" or "real"
+     """
+     # 2) Load the audio file
+     # torchaudio returns (waveform, sample_rate)
+     waveform, sr = torchaudio.load(audio_file)
+
+     # If stereo, average the channels down to mono
+     if waveform.shape[0] > 1:
+         waveform = torch.mean(waveform, dim=0, keepdim=True)
+     waveform = waveform.squeeze()  # (samples,)
+
+     # Resample if needed: the feature extractor expects a fixed rate
+     # (16 kHz for wav2vec2) and raises an error for any other rate
+     if sr != feature_extractor.sampling_rate:
+         waveform = torchaudio.functional.resample(waveform, sr, feature_extractor.sampling_rate)
+         sr = feature_extractor.sampling_rate
+
+     # 3) Preprocess with the feature extractor
+     inputs = feature_extractor(
+         waveform.numpy(),
+         sampling_rate=sr,
+         return_tensors="pt",
+         truncation=True,
+         max_length=int(feature_extractor.sampling_rate * 6.0),  # 6-second max
+     )
+
+     # Move the input tensor to the same device as the model
+     input_values = inputs["input_values"].to(device)
+
+     with torch.no_grad():
+         logits = model(input_values).logits
+     pred_id = torch.argmax(logits, dim=-1).item()
+
+     # 4) Return the label text
+     predicted_label = label_names[pred_id]
+     return predicted_label
+
+
+ # 5) Build the Gradio interface
+ demo = gr.Interface(
+     fn=classify_audio,
+     inputs=gr.Audio(source="upload", type="filepath"),  # Gradio 3.x API; Gradio 4+ renamed this to sources=["upload"]
+     outputs="text",
+     title="Wav2Vec2 Deepfake Detection",
+     description="Upload an audio sample to check if it is fake or real.",
+ )
+
+ if __name__ == "__main__":
+     demo.launch()
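
If you want confidence scores instead of a bare label, the same forward pass can be post-processed with a softmax. Below is a minimal sketch under the assumptions that it lives in the same app.py (it reuses the feature_extractor, model, and device loaded above) and that the checkpoint's config carries an id2label mapping; classify_audio_with_scores is a hypothetical name, not part of this commit.

import torch
import torchaudio
import torch.nn.functional as F

def classify_audio_with_scores(audio_file):
    # Identical preprocessing to classify_audio above
    waveform, sr = torchaudio.load(audio_file)
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    waveform = waveform.squeeze()
    if sr != feature_extractor.sampling_rate:
        waveform = torchaudio.functional.resample(waveform, sr, feature_extractor.sampling_rate)
        sr = feature_extractor.sampling_rate
    inputs = feature_extractor(
        waveform.numpy(),
        sampling_rate=sr,
        return_tensors="pt",
        truncation=True,
        max_length=int(feature_extractor.sampling_rate * 6.0),
    )
    with torch.no_grad():
        logits = model(inputs["input_values"].to(device)).logits
    # Softmax turns the logits into per-class probabilities
    probs = F.softmax(logits, dim=-1).squeeze()
    # id2label maps class index -> label string ("fake"/"real" for this checkpoint)
    return {model.config.id2label[i]: probs[i].item() for i in range(probs.shape[-1])}

Returning this dict of label-to-probability pairs with gr.Label(num_top_classes=2) as the outputs component would render the scores as a bar chart in place of the plain-text output.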