Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -1,32 +1,37 @@
|
|
1 |
import gradio as gr
|
2 |
import os
|
3 |
-
import dataset
|
4 |
import torch
|
5 |
-
from model import Wav2Vec2BERT_Llama
|
|
|
6 |
|
7 |
-
#
|
8 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
9 |
|
10 |
-
#
|
11 |
-
#
|
12 |
-
#
|
13 |
-
#
|
14 |
-
#
|
15 |
-
#
|
16 |
-
|
17 |
-
# # 处理模型状态字典
|
18 |
-
# if hasattr(model, 'module') and not any(key.startswith('module.') for key in model_state_dict.keys()):
|
19 |
-
# model_state_dict = {'module.' + key: value for key, value in model_state_dict.items()}
|
20 |
-
# elif not hasattr(model, 'module') and any(key.startswith('module.') for key in model_state_dict.keys()):
|
21 |
-
# model_state_dict = {key.replace('module.', ''): value for key, value in model_state_dict.items()}
|
22 |
-
|
23 |
-
# model.load_state_dict(model_state_dict)
|
24 |
-
# model.eval()
|
25 |
-
# else:
|
26 |
-
# raise FileNotFoundError(f"Not found checkpoint: {checkpoint_path}")
|
27 |
|
|
|
|
|
|
|
|
|
|
|
28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
def detect(dataset, model):
|
|
|
30 |
with torch.no_grad():
|
31 |
for batch in dataset:
|
32 |
main_features = {
|
@@ -37,81 +42,65 @@ def detect(dataset, model):
|
|
37 |
'input_features': pf['input_features'].to(device),
|
38 |
'attention_mask': pf['attention_mask'].to(device)
|
39 |
} for pf in batch['prompt_features']]
|
40 |
-
|
41 |
|
|
|
|
|
|
|
|
|
42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
|
44 |
-
|
45 |
-
|
46 |
-
# 收集所有demonstration的路径和标签
|
47 |
-
demonstration_paths = []
|
48 |
-
for child in demonstrations_container.children:
|
49 |
-
if isinstance(child, gr.Row):
|
50 |
-
audio = child.children[0].children[0].value
|
51 |
-
if audio is not None:
|
52 |
-
demonstration_paths.append(audio)
|
53 |
|
54 |
-
|
55 |
-
|
56 |
|
57 |
-
#
|
58 |
-
dataset = dataset.DemoDataset(demonstration_paths, audio_path)
|
59 |
-
# Example return value, modify according to your model
|
60 |
-
result = detect(dataset, model)
|
61 |
-
|
62 |
-
# Return detection results and confidence scores
|
63 |
return {
|
64 |
"Is AI Generated": result["is_fake"],
|
65 |
"Confidence": f"{result['confidence']:.2f}%"
|
66 |
}
|
67 |
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
#
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
gr.Markdown("## Demonstration Audios (Optional)")
|
79 |
-
|
80 |
-
# 创建3个固定的demonstration组件
|
81 |
-
for i in range(3):
|
82 |
-
with gr.Row():
|
83 |
-
with gr.Column(scale=8):
|
84 |
-
audio = gr.Audio(
|
85 |
-
sources=["upload"],
|
86 |
-
type="filepath",
|
87 |
-
label=f"Demonstration Audio {i+1}"
|
88 |
-
)
|
89 |
-
with gr.Column(scale=3):
|
90 |
-
label = gr.Dropdown(
|
91 |
-
choices=["bonafide", "spoof"],
|
92 |
-
value="bonafide",
|
93 |
-
label="Label"
|
94 |
-
)
|
95 |
-
|
96 |
-
# Query audio input component
|
97 |
-
query_audio_input = gr.Audio(
|
98 |
-
sources=["upload"],
|
99 |
-
label="Query Audio (Audio for Detection)",
|
100 |
-
type="filepath",
|
101 |
-
)
|
102 |
-
|
103 |
-
# Submit button
|
104 |
-
submit_btn = gr.Button(value="Start Detection", variant="primary")
|
105 |
-
|
106 |
-
# Output results
|
107 |
-
output_labels = gr.Json(label="Detection Results")
|
108 |
|
109 |
-
#
|
110 |
-
|
111 |
-
fn=
|
112 |
-
inputs=[
|
113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
)
|
|
|
115 |
|
116 |
if __name__ == "__main__":
|
117 |
-
demo
|
|
|
|
1 |
import gradio as gr
|
2 |
import os
|
|
|
3 |
import torch
|
4 |
+
from model import Wav2Vec2BERT_Llama # 自定义模型模块
|
5 |
+
import dataset # 自定义数据集模块
|
6 |
|
7 |
+
# 初始化设备
|
8 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
9 |
|
10 |
+
# 初始化模型
|
11 |
+
# def load_model():
|
12 |
+
# model = Wav2Vec2BERT_Llama().to(device)
|
13 |
+
# checkpoint_path = "ckpt/model_checkpoint.pth"
|
14 |
+
# if os.path.exists(checkpoint_path):
|
15 |
+
# checkpoint = torch.load(checkpoint_path)
|
16 |
+
# model_state_dict = checkpoint['model_state_dict']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
+
# # 处理模型状态字典的 key
|
19 |
+
# if hasattr(model, 'module') and not any(key.startswith('module.') for key in model_state_dict.keys()):
|
20 |
+
# model_state_dict = {'module.' + key: value for key, value in model_state_dict.items()}
|
21 |
+
# elif not hasattr(model, 'module') and any(key.startswith('module.') for key in model_state_dict.keys()):
|
22 |
+
# model_state_dict = {key.replace('module.', ''): value for key, value in model_state_dict.items()}
|
23 |
|
24 |
+
# model.load_state_dict(model_state_dict)
|
25 |
+
# model.eval()
|
26 |
+
# else:
|
27 |
+
# raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
|
28 |
+
# return model
|
29 |
+
|
30 |
+
# model = load_model()
|
31 |
+
|
32 |
+
# 检测函数
|
33 |
def detect(dataset, model):
|
34 |
+
"""进行音频伪造检测"""
|
35 |
with torch.no_grad():
|
36 |
for batch in dataset:
|
37 |
main_features = {
|
|
|
42 |
'input_features': pf['input_features'].to(device),
|
43 |
'attention_mask': pf['attention_mask'].to(device)
|
44 |
} for pf in batch['prompt_features']]
|
|
|
45 |
|
46 |
+
# 模型的前向传播逻辑 (需要补充具体实现)
|
47 |
+
# 假设 result 是模型返回的结果
|
48 |
+
result = {"is_fake": True, "confidence": 85.5} # 示例返回值
|
49 |
+
return result
|
50 |
|
51 |
+
# 音频伪造检测主函数
|
52 |
+
def audio_deepfake_detection(demonstrations, query_audio_path):
|
53 |
+
"""
|
54 |
+
音频伪造检测函数
|
55 |
+
:param demonstrations: 演示音频路径和标签的列表
|
56 |
+
:param query_audio_path: 查询音频路径
|
57 |
+
:return: 检测结果
|
58 |
+
"""
|
59 |
+
demonstration_paths = [audio[0] for audio in demonstrations if audio[0] is not None]
|
60 |
+
print(f"Demonstration audio paths: {demonstration_paths}")
|
61 |
+
print(f"Query audio path: {query_audio_path}")
|
62 |
|
63 |
+
# 数据集处理
|
64 |
+
audio_dataset = dataset.DemoDataset(demonstration_paths, query_audio_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
+
# 调用检测函数
|
67 |
+
result = detect(audio_dataset, model)
|
68 |
|
69 |
+
# 返回结果
|
|
|
|
|
|
|
|
|
|
|
70 |
return {
|
71 |
"Is AI Generated": result["is_fake"],
|
72 |
"Confidence": f"{result['confidence']:.2f}%"
|
73 |
}
|
74 |
|
75 |
+
# Gradio 界面
|
76 |
+
def gradio_ui():
|
77 |
+
def detection_wrapper(demonstration_audio1, label1, demonstration_audio2, label2, demonstration_audio3, label3, query_audio):
|
78 |
+
# 将输入音频和标签封装成列表
|
79 |
+
demonstrations = [
|
80 |
+
(demonstration_audio1, label1),
|
81 |
+
(demonstration_audio2, label2),
|
82 |
+
(demonstration_audio3, label3),
|
83 |
+
]
|
84 |
+
return audio_deepfake_detection(demonstrations, query_audio)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
|
86 |
+
# 构建 Gradio 界面
|
87 |
+
interface = gr.Interface(
|
88 |
+
fn=detection_wrapper, # 主函数
|
89 |
+
inputs=[
|
90 |
+
gr.Audio(source="upload", type="filepath", label="Demonstration Audio 1"),
|
91 |
+
gr.Dropdown(choices=["bonafide", "spoof"], value="bonafide", label="Label 1"),
|
92 |
+
gr.Audio(source="upload", type="filepath", label="Demonstration Audio 2"),
|
93 |
+
gr.Dropdown(choices=["bonafide", "spoof"], value="bonafide", label="Label 2"),
|
94 |
+
gr.Audio(source="upload", type="filepath", label="Demonstration Audio 3"),
|
95 |
+
gr.Dropdown(choices=["bonafide", "spoof"], value="bonafide", label="Label 3"),
|
96 |
+
gr.Audio(source="upload", type="filepath", label="Query Audio (Audio for Detection)")
|
97 |
+
],
|
98 |
+
outputs=gr.JSON(label="Detection Results"),
|
99 |
+
title="Audio Deepfake Detection System",
|
100 |
+
description="Upload demonstration audios and a query audio to detect whether the query is AI-generated.",
|
101 |
)
|
102 |
+
return interface
|
103 |
|
104 |
if __name__ == "__main__":
|
105 |
+
demo = gradio_ui()
|
106 |
+
demo.launch()
|