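# Gradio demo for audio deepfake detection.
# Pipeline: truncated wav2vec2-xls-r-2b features (first 9 transformer layers)
# -> mean-pooled per 10-second segment -> pre-fitted logistic-regression
# classifier with a stored scaler and EER decision threshold.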
import json

import gradio as gr
import joblib
import librosa
import torch
from scipy.special import expit
from transformers import AutoFeatureExtractor, Wav2Vec2Model


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class CustomWav2Vec2Model(Wav2Vec2Model):
    """wav2vec2 model whose transformer encoder is truncated to its first 9 layers."""

    def __init__(self, config):
        super().__init__(config)
        # Keep only the first 9 encoder layers; the classifier uses hidden state 9,
        # so the deeper layers are never needed.
        self.encoder.layers = self.encoder.layers[:9]

truncated_model = CustomWav2Vec2Model.from_pretrained("facebook/wav2vec2-xls-r-2b")

class HuggingFaceFeatureExtractor:
    def __init__(self, model, feature_extractor_name):
        self.device = device
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_name)
        self.model = model
        self.model.eval()
        self.model.to(self.device)

    def __call__(self, audio, sr):
        inputs = self.feature_extractor(
            audio,
            sampling_rate=sr,
            return_tensors="pt",
            padding=True,
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)
        # hidden_states[0] is the encoder input (feature projection output), so
        # index 9 is the output of the 9th and last remaining transformer layer.
        return outputs.hidden_states[9]

FEATURE_EXTRACTOR = HuggingFaceFeatureExtractor(truncated_model, "facebook/wav2vec2-xls-r-2b")
# Pre-fitted logistic-regression classifier, score scaler, and EER threshold.
classifier, scaler, thresh = joblib.load('logreg_margin_pruning_ALL_with_scaler+threshold.joblib')

def segment_audio(audio, sr, segment_duration):
    segment_samples = int(segment_duration * sr)
    total_samples = len(audio)
    segments = [audio[i:i + segment_samples] for i in range(0, total_samples, segment_samples)]
    # Drop segments shorter than 0.7 s, which can cause problems inside
    # wav2vec2's convolutional feature extractor.
    return [seg for seg in segments if len(seg) > 0.7 * sr]

def process_audio(input_data, segment_duration=10):
    # Load at 16 kHz, the sampling rate wav2vec2 expects; keep only the first
    # channel if the file is not mono.
    audio, sr = librosa.load(input_data, sr=16000)
    if audio.ndim > 1:
        audio = audio[0]
    segments = segment_audio(audio, sr, segment_duration)
    segment_predictions = []
    confidence_scores = []
    eer_threshold = thresh - 5e-3  # small tolerance below the stored EER threshold

    for segment in segments:
        features = FEATURE_EXTRACTOR(segment, sr)
        # Mean-pool the frame-level features over time to get one vector per segment.
        features_avg = torch.mean(features, dim=1).cpu().numpy().reshape(1, -1)
        decision_score = classifier.decision_function(features_avg)
        decision_score_scaled = scaler.transform(decision_score.reshape(-1, 1)).flatten()
        decision_value = decision_score_scaled[0]
        # 1 = real, 0 = fake, decided against the (scaled) EER threshold.
        pred = 1 if decision_value >= eer_threshold else 0

        # Report a sigmoid of the raw decision score as the confidence for the
        # predicted class.
        if pred == 1:
            confidence_percentage = expit(decision_score).item()
        else:
            confidence_percentage = 1 - expit(decision_score).item()

        segment_predictions.append(pred)
        confidence_scores.append(confidence_percentage)

    # Clip-level label: majority vote over the per-segment predictions.
    output_dict = {
        "label": "real" if sum(segment_predictions) > (len(segment_predictions) / 2) else "fake",
        "segments": [
            {
                "segment": idx + 1,
                "prediction": "real" if pred == 1 else "fake",
                "confidence": round(conf * 100, 2)
            }
            for idx, (pred, conf) in enumerate(zip(segment_predictions, confidence_scores))
        ]
    }

    json_output = json.dumps(output_dict, indent=4)
    print(json_output)
    return json_output
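
# Example of the returned JSON (values illustrative):
# {
#     "label": "real",
#     "segments": [
#         {"segment": 1, "prediction": "real", "confidence": 97.42},
#         {"segment": 2, "prediction": "fake", "confidence": 88.10}
#     ]
# }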

def gradio_interface(audio):
    if audio:
        return process_audio(audio)
    return "Please upload an audio file."

interface = gr.Interface(
    fn=gradio_interface,
    inputs=[gr.Audio(type="filepath", label="Upload Audio")],
    outputs="text",
    title="SOL2 Audio Deepfake Detection Demo",
    description="Upload an audio file to check whether it is AI-generated.",
)

interface.launch(share=True)

# Local smoke test (bypasses the UI):
# print(process_audio('SSL_scripts/1.wav'))