File size: 4,852 Bytes
5db4ee9
0a63b23
719b808
 
88a8fb2
 
719b808
ca6274c
e7b9619
 
 
0a63b23
4a81ee5
719b808
 
 
c24b00b
 
719b808
4a81ee5
 
 
719b808
4a81ee5
719b808
4a81ee5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
719b808
4a81ee5
 
88a8fb2
34146f0
 
 
 
 
 
 
 
 
 
 
340d6c0
 
 
 
 
 
 
4a81ee5
 
 
 
88a8fb2
34146f0
4a81ee5
88a8fb2
 
340d6c0
18b7b66
 
4a81ee5
88a8fb2
340d6c0
4a81ee5
 
 
 
ca6274c
 
 
 
 
88a8fb2
 
 
 
 
 
 
 
340d6c0
ca6274c
88a8fb2
18c21b7
88a8fb2
18c21b7
88a8fb2
18c21b7
88a8fb2
18c21b7
88a8fb2
18c21b7
88a8fb2
 
 
 
ca6274c
88a8fb2
ca6274c
 
88a8fb2
719b808
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import spaces
import gradio as gr
import os
import torch
from model import Wav2Vec2BERT_Llama  # 自定义模型模块
import dataset  # 自定义数据集模块
from huggingface_hub import hf_hub_download

@spaces.GPU
def dummy():
    """Do nothing; a placeholder function wrapped by ``spaces.GPU``.

    NOTE(review): presumably present so that a GPU-decorated function is
    registered as soon as the module is imported -- confirm before removing.
    """

# Fetch the detector checkpoint from the Hugging Face Hub.
def load_model():
    """Download the model checkpoint and return its local file path.

    Returns:
        The path returned by ``hf_hub_download`` for the checkpoint file.

    Raises:
        FileNotFoundError: If the returned path does not exist on disk.
    """
    local_file = hf_hub_download(
        repo_type="model",
        filename="checkpoints_wav2vec2bert_ft_llama_labels_ASVspoof2019_RandomPrompts_6/model_checkpoint.pth",
        repo_id="amphion/deepfake_detection",
    )
    # Sanity-check the download result before handing the path to callers.
    if os.path.exists(local_file):
        return local_file
    raise FileNotFoundError(f"Checkpoint not found: {local_file}")

# Resolve the checkpoint once at import time; detect_on_gpu reads this
# module-level path each time it loads the model weights.
checkpoint_path = load_model()

# The detection routine lives under the spaces.GPU decorator.
@spaces.GPU
def detect_on_gpu(dataset):
    """Run audio deepfake detection over ``dataset`` and return the verdict.

    A fresh model is built, the checkpoint at the module-level
    ``checkpoint_path`` is loaded into it, and every batch yielded by
    ``dataset`` is scored. The verdict of the last batch is returned.

    NOTE: the parameter name shadows the module-level ``dataset`` import
    inside this function; it is kept for backward compatibility.

    Args:
        dataset: Iterable of batches. Each batch is a mapping with
            ``main_features`` (a mapping holding ``input_features`` and
            ``attention_mask``), ``prompt_features`` (an iterable of such
            mappings) and ``prompt_labels``.

    Returns:
        dict: ``{"is_fake": bool, "confidence": float}`` where
        ``confidence`` is the softmax probability at class index 1 for the
        first item of the batch.

    Raises:
        ValueError: If ``dataset`` yields no batches.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Wav2Vec2BERT_Llama().to(device)

    # Load the model weights.
    # SECURITY NOTE(review): torch.load unpickles the file. The checkpoint
    # comes from a fixed Hub repository here; never point this at an
    # untrusted file without restricting what may be unpickled.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model_state_dict = checkpoint['model_state_dict']
    threshold = 0.9996

    # Reconcile the 'module.' key prefix between checkpoint and model.
    prefix = 'module.'
    wrapped_model = hasattr(model, 'module')
    prefixed_keys = any(key.startswith(prefix) for key in model_state_dict)
    if wrapped_model and not prefixed_keys:
        model_state_dict = {
            prefix + key: value for key, value in model_state_dict.items()
        }
    elif not wrapped_model and prefixed_keys:
        # Strip only a *leading* prefix; a blanket str.replace would also
        # mangle keys that contain 'module.' somewhere in the middle.
        model_state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in model_state_dict.items()
        }

    model.load_state_dict(model_state_dict)
    model.eval()

    result = None
    with torch.no_grad():
        for batch in dataset:
            main_features = {
                'input_features': batch['main_features']['input_features'].to(device),
                'attention_mask': batch['main_features']['attention_mask'].to(device)
            }
            prompt_features = [{
                'input_features': pf['input_features'].to(device),
                'attention_mask': pf['attention_mask'].to(device)
            } for pf in batch['prompt_features']]

            prompt_labels = batch['prompt_labels'].to(device)
            outputs = model({
                'main_features': main_features,
                'prompt_features': prompt_features,
                'prompt_labels': prompt_labels
            })

            avg_scores = outputs['avg_logits'].softmax(dim=-1)
            deepfake_scores = avg_scores[:, 1].cpu()
            # Convert to native Python types so the result does not carry
            # tensors out of this function.
            result = {
                "is_fake": bool(deepfake_scores[0] > threshold),
                "confidence": float(deepfake_scores[0]),
            }

    if result is None:
        # Previously this surfaced as an UnboundLocalError on `result`.
        raise ValueError("dataset yielded no batches; nothing to score")
    return result

# Helper: keep demonstration audios and their labels aligned.
def _pair_demonstrations(demonstrations):
    """Return aligned ``(paths, labels)`` lists from (path, label) pairs.

    A pair is kept only when both its path and its label are present, so a
    missing audio can never shift the labels of the remaining audios.
    """
    complete = [
        pair for pair in demonstrations
        if pair[0] is not None and pair[1] is not None
    ]
    return [pair[0] for pair in complete], [pair[1] for pair in complete]


# Main entry point for audio deepfake detection.
def audio_deepfake_detection(demonstrations, query_audio_path):
    """Classify the query audio using labelled demonstration audios.

    Args:
        demonstrations: Iterable of ``(audio_path, label)`` pairs. Pairs
            with a missing path or label are ignored.
        query_audio_path: Path of the audio to classify.

    Returns:
        dict: ``{"Is AI Generated": bool, "Confidence": str}`` where the
        confidence is the detector's score rendered as a percentage.

    Raises:
        ValueError: If ``query_audio_path`` is ``None``.
    """
    if query_audio_path is None:
        raise ValueError("A query audio file is required for detection.")

    demonstration_paths, demonstration_labels = _pair_demonstrations(demonstrations)

    # Build the dataset from the aligned demonstrations and the query.
    audio_dataset = dataset.DemoDataset(demonstration_paths, demonstration_labels, query_audio_path)

    # Run the GPU detection routine.
    result = detect_on_gpu(audio_dataset)

    # float()/bool() accept both native values and 0-dim tensors.
    confidence = float(result["confidence"])
    return {
        "Is AI Generated": bool(result["is_fake"]),
        # The score is a probability in [0, 1]; scale it before adding '%'.
        "Confidence": f"{confidence * 100:.2f}%",
    }

# Gradio interface
def gradio_ui():
    """Build and return the Gradio interface for the detector.

    The interface takes three (audio, label) demonstration inputs plus one
    query audio, and renders the result of ``audio_deepfake_detection`` as
    JSON.
    """
    def audio_input(caption):
        # Upload-only audio component that hands the callback a file path.
        return gr.Audio(sources=["upload"], type="filepath", label=caption)

    def label_input(caption):
        # Ground-truth label selector for one demonstration audio.
        return gr.Dropdown(choices=["bonafide", "spoof"], value="bonafide", label=caption)

    def detection_wrapper(demonstration_audio1, label1, demonstration_audio2, label2, demonstration_audio3, label3, query_audio):
        # Regroup the flat Gradio arguments into (audio, label) pairs.
        audios = (demonstration_audio1, demonstration_audio2, demonstration_audio3)
        labels = (label1, label2, label3)
        return audio_deepfake_detection(list(zip(audios, labels)), query_audio)

    components = [
        audio_input("Demonstration Audio 1"),
        label_input("Label 1"),
        audio_input("Demonstration Audio 2"),
        label_input("Label 2"),
        audio_input("Demonstration Audio 3"),
        label_input("Label 3"),
        audio_input("Query Audio (Audio for Detection)"),
    ]
    return gr.Interface(
        fn=detection_wrapper,
        inputs=components,
        outputs=gr.JSON(label="Detection Results"),
        title="Audio Deepfake Detection System",
        description="Upload demonstration audios and a query audio to detect whether the query is AI-generated.",
    )

# Build the interface and start the Gradio server only when this file is
# executed as a script (not when it is imported).
if __name__ == "__main__":
    demo = gradio_ui()
    demo.launch()