import gradio as gr
import torch
import torchaudio
import torchaudio.transforms as T
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification

MODEL_ID = "Zeyadd-Mostaffa/wav2vec_checkpoints"

# 1) Load model & feature extractor
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_ID)
model = AutoModelForAudioClassification.from_pretrained(MODEL_ID)
model.eval()

# Use the GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

label_names = ["fake", "real"]  # matches the model's label2id = {"fake": 0, "real": 1}
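
# A sketch of reading the label mapping from the checkpoint instead of
# hard-coding it (assumes this checkpoint's config populates id2label):
# label_names = [model.config.id2label[i] for i in range(model.config.num_labels)]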


def classify_audio(audio_file):
    """
    audio_file: path to the uploaded file (WAV, MP3, etc.)
    Returns: "fake" or "real"
    """

    # 2) Load the audio file; torchaudio returns (waveform, sample_rate)
    waveform, sr = torchaudio.load(audio_file)

    # If multi-channel, average the channels down to mono
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    waveform = waveform.squeeze()  # (samples,)

    # 3) Resample to 16 kHz if needed (wav2vec 2.0 expects 16 kHz input)
    if sr != 16000:
        resampler = T.Resample(sr, 16000)
        waveform = resampler(waveform)
        sr = 16000

    # 4) Preprocess with the feature extractor
    inputs = feature_extractor(
        waveform.numpy(),
        sampling_rate=sr,
        return_tensors="pt",
        truncation=True,
        max_length=int(16000 * 6.0),  # cap at 6 seconds of 16 kHz audio
    )
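
    # truncation=True clips audio longer than max_length; a single example
    # needs no padding, so shorter clips pass through at their natural length.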

    # Move the inputs to the same device as the model
    input_values = inputs["input_values"].to(device)

    with torch.no_grad():
        logits = model(input_values).logits
        pred_id = torch.argmax(logits, dim=-1).item()

    # 5) Return the label text
    predicted_label = label_names[pred_id]
    return predicted_label
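

# A minimal sketch (an assumed extension, not part of the original app) of
# surfacing per-class probabilities instead of a bare label: softmax over the
# logits yields a {label: probability} dict that a gr.Label output could render.
def classify_audio_with_scores(audio_file):
    waveform, sr = torchaudio.load(audio_file)
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)  # downmix to mono
    waveform = waveform.squeeze()
    if sr != 16000:
        waveform = T.Resample(sr, 16000)(waveform)
        sr = 16000
    inputs = feature_extractor(
        waveform.numpy(),
        sampling_rate=sr,
        return_tensors="pt",
        truncation=True,
        max_length=int(16000 * 6.0),
    )
    with torch.no_grad():
        logits = model(inputs["input_values"].to(device)).logits
    probs = torch.softmax(logits, dim=-1).squeeze()
    return {label_names[i]: float(probs[i]) for i in range(len(label_names))}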


# 6) Build the Gradio interface
demo = gr.Interface(
    fn=classify_audio,
    inputs=gr.Audio(type="filepath"),
    outputs="text",
    title="Sigma One - Deepfake Audio Detection",
    description="Upload an audio sample to check if it is fake or real."
)

if __name__ == "__main__":
    demo.launch()
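
# Launched with `python app.py`, Gradio serves the demo locally (by default at
# http://127.0.0.1:7860). For a quick smoke test without the UI, classify_audio
# can also be called directly, e.g. print(classify_audio("sample.wav")), where
# "sample.wav" stands in for any local audio file.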