import gradio as gr
import torch
import torchaudio
import torchaudio.transforms as T
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification

MODEL_ID = "Zeyadd-Mostaffa/wav2vec_checkpoints"

# Load the feature extractor and the fine-tuned classifier once at startup.
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_ID)
model = AutoModelForAudioClassification.from_pretrained(MODEL_ID)
model.eval()

# Run inference on GPU when available, otherwise on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

label_names = ["fake", "real"]
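# Note: the hard-coded order above assumes the checkpoint maps id 0 -> "fake"
# and id 1 -> "real". If the checkpoint's config defines its own labels, they
# could be read from it instead, e.g.:
#     label_names = [model.config.id2label[i] for i in range(model.config.num_labels)]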


def classify_audio(audio_file):
    """
    Classify an uploaded audio clip as "fake" or "real".

    audio_file: path to the uploaded file (WAV, MP3, etc.)
    Returns: "fake" or "real"
    """
    # Load the audio; waveform has shape (channels, samples).
    waveform, sr = torchaudio.load(audio_file)

    # Downmix multi-channel audio to mono, then drop the channel dimension.
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    waveform = waveform.squeeze()

    # Wav2Vec2 expects 16 kHz input; resample anything else.
    if sr != 16000:
        resampler = T.Resample(sr, 16000)
        waveform = resampler(waveform)
        sr = 16000
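    # Aside: the 16 kHz target is hard-coded here; it could also be read from
    # the extractor itself, e.g. target_sr = feature_extractor.sampling_rate.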

    # Convert the raw waveform to model inputs, truncating to at most
    # 6 seconds (96,000 samples at 16 kHz) to bound memory use.
    inputs = feature_extractor(
        waveform.numpy(),
        sampling_rate=sr,
        return_tensors="pt",
        truncation=True,
        max_length=int(16000 * 6.0),
    )

    input_values = inputs["input_values"].to(device)
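    # Aside: with padding enabled, Wav2Vec2 extractors can also return an
    # "attention_mask"; if present, it could be forwarded to the model, e.g.:
    #     model(input_values, attention_mask=inputs["attention_mask"].to(device))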

    # Forward pass without gradient tracking; pick the highest-scoring class.
    with torch.no_grad():
        logits = model(input_values).logits
    pred_id = torch.argmax(logits, dim=-1).item()
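    # Sketch (not in the original script): a confidence score could be derived
    # from the logits with a softmax, e.g.:
    #     probs = torch.softmax(logits, dim=-1)[0]
    #     confidence = probs[pred_id].item()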

    predicted_label = label_names[pred_id]
    return predicted_label


# Simple Gradio UI: one audio input mapped to a text verdict.
demo = gr.Interface(
    fn=classify_audio,
    inputs=gr.Audio(type="filepath"),
    outputs="text",
    title="Wav2Vec2 Deepfake Detection",
    description="Upload an audio sample to check whether it is fake or real.",
)
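# Note: Gradio's launch(share=True) option would additionally create a
# temporary public link when running locally; it is not used here.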

if __name__ == "__main__":
    demo.launch()