Prasanna-ETH committed
Commit f052fd1 · verified · 1 Parent(s): 907f0a6

Upload app (1).py

Files changed (1)
  1. app (1).py +75 -0
app (1).py ADDED
@@ -0,0 +1,75 @@
+ import gradio as gr
+ import torch
+ import torchaudio
+ import numpy as np
+ from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
+ import torchaudio.transforms as T
+
+ MODEL_ID = "Zeyadd-Mostaffa/wav2vec_checkpoints"
+
+ # 1) Load the model & feature extractor
+ feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_ID)
+ model = AutoModelForAudioClassification.from_pretrained(MODEL_ID)
+ model.eval()
+
+ # Optionally use a GPU if one is available
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model.to(device)
+
+ label_names = ["fake", "real"]  # matches the model's label2id = {"fake": 0, "real": 1}
+
+
+ def classify_audio(audio_file):
+     """
+     audio_file: path to the uploaded file (WAV, MP3, etc.)
+     Returns: "fake" or "real"
+     """
+
+     # 2) Load the audio file
+     # torchaudio returns (waveform, sample_rate)
+     waveform, sr = torchaudio.load(audio_file)
+
+     # If stereo, average the channels down to mono
+     if waveform.shape[0] > 1:
+         waveform = torch.mean(waveform, dim=0, keepdim=True)
+     waveform = waveform.squeeze()  # (samples,)
+
+     # 3) Resample to 16 kHz if needed (the rate the feature extractor expects)
+     if sr != 16000:
+         resampler = T.Resample(sr, 16000)
+         waveform = resampler(waveform)
+         sr = 16000
+
+
+     # 4) Preprocess with the feature extractor
+     inputs = feature_extractor(
+         waveform.numpy(),
+         sampling_rate=sr,
+         return_tensors="pt",
+         truncation=True,
+         max_length=int(16000 * 6.0),  # 6-second maximum
+     )
+
+     # Move everything to the model's device
+     input_values = inputs["input_values"].to(device)
+
+     with torch.no_grad():
+         logits = model(input_values).logits
+         pred_id = torch.argmax(logits, dim=-1).item()
+
+     # 5) Return the label text
+     predicted_label = label_names[pred_id]
+     return predicted_label
+
+
+ # 6) Build the Gradio interface
+ demo = gr.Interface(
+     fn=classify_audio,
+     inputs=gr.Audio(type="filepath"),
+     outputs="text",
+     title="Wav2Vec2 Deepfake Detection",
+     description="Upload an audio sample to check if it is fake or real.",
+ )
+
+ if __name__ == "__main__":
+     demo.launch()
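For a quick smoke test of the classifier without launching the browser UI, classify_audio can be called directly. A minimal sketch, assuming the uploaded file has been saved under an importable name such as app.py and that a short local clip exists at sample.wav (both names are hypothetical, not part of this commit):

# test_app.py — smoke-test the classifier from the command line.
# Assumes "app (1).py" was saved as app.py and that sample.wav exists locally.
# Importing app loads the model and builds the interface, but does not launch it,
# because demo.launch() is guarded by the __main__ check above.
from app import classify_audio

print(classify_audio("sample.wav"))  # prints "fake" or "real"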