wli3221134 commited on
Commit
88a8fb2
·
verified ·
1 Parent(s): c16a4df

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -84
app.py CHANGED
@@ -1,32 +1,37 @@
1
  import gradio as gr
2
  import os
3
- import dataset
4
  import torch
5
- from model import Wav2Vec2BERT_Llama
 
6
 
7
- # init
8
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
 
10
- # # init model
11
- # model = Wav2Vec2BERT_Llama().to(device)
12
- # checkpoint_path = "ckpt/model_checkpoint.pth"
13
- # if os.path.exists(checkpoint_path):
14
- # checkpoint = torch.load(checkpoint_path)
15
- # model_state_dict = checkpoint['model_state_dict']
16
-
17
- # # 处理模型状态字典
18
- # if hasattr(model, 'module') and not any(key.startswith('module.') for key in model_state_dict.keys()):
19
- # model_state_dict = {'module.' + key: value for key, value in model_state_dict.items()}
20
- # elif not hasattr(model, 'module') and any(key.startswith('module.') for key in model_state_dict.keys()):
21
- # model_state_dict = {key.replace('module.', ''): value for key, value in model_state_dict.items()}
22
-
23
- # model.load_state_dict(model_state_dict)
24
- # model.eval()
25
- # else:
26
- # raise FileNotFoundError(f"Not found checkpoint: {checkpoint_path}")
27
 
 
 
 
 
 
28
 
 
 
 
 
 
 
 
 
 
29
  def detect(dataset, model):
 
30
  with torch.no_grad():
31
  for batch in dataset:
32
  main_features = {
@@ -37,81 +42,65 @@ def detect(dataset, model):
37
  'input_features': pf['input_features'].to(device),
38
  'attention_mask': pf['attention_mask'].to(device)
39
  } for pf in batch['prompt_features']]
40
-
41
 
 
 
 
 
42
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
- def audio_deepfake_detection(demonstrations_container, audio_path):
45
- """Audio deepfake detection function"""
46
- # 收集所有demonstration的路径和标签
47
- demonstration_paths = []
48
- for child in demonstrations_container.children:
49
- if isinstance(child, gr.Row):
50
- audio = child.children[0].children[0].value
51
- if audio is not None:
52
- demonstration_paths.append(audio)
53
 
54
- print("Demonstration audio paths: {}".format(demonstration_paths))
55
- print("Query audio path: {}".format(audio_path))
56
 
57
- # dataset
58
- dataset = dataset.DemoDataset(demonstration_paths, audio_path)
59
- # Example return value, modify according to your model
60
- result = detect(dataset, model)
61
-
62
- # Return detection results and confidence scores
63
  return {
64
  "Is AI Generated": result["is_fake"],
65
  "Confidence": f"{result['confidence']:.2f}%"
66
  }
67
 
68
- with gr.Blocks() as demo:
69
- gr.Markdown(
70
- """
71
- # Audio Deepfake Detection System
72
-
73
- This demo helps you detect whether an audio clip is AI-generated or authentic.
74
- """
75
- )
76
-
77
- with gr.Column() as demonstrations_container:
78
- gr.Markdown("## Demonstration Audios (Optional)")
79
-
80
- # 创建3个固定的demonstration组件
81
- for i in range(3):
82
- with gr.Row():
83
- with gr.Column(scale=8):
84
- audio = gr.Audio(
85
- sources=["upload"],
86
- type="filepath",
87
- label=f"Demonstration Audio {i+1}"
88
- )
89
- with gr.Column(scale=3):
90
- label = gr.Dropdown(
91
- choices=["bonafide", "spoof"],
92
- value="bonafide",
93
- label="Label"
94
- )
95
-
96
- # Query audio input component
97
- query_audio_input = gr.Audio(
98
- sources=["upload"],
99
- label="Query Audio (Audio for Detection)",
100
- type="filepath",
101
- )
102
-
103
- # Submit button
104
- submit_btn = gr.Button(value="Start Detection", variant="primary")
105
-
106
- # Output results
107
- output_labels = gr.Json(label="Detection Results")
108
 
109
- # Set click event
110
- submit_btn.click(
111
- fn=audio_deepfake_detection,
112
- inputs=[demonstrations_container, query_audio_input],
113
- outputs=[output_labels]
 
 
 
 
 
 
 
 
 
 
114
  )
 
115
 
116
  if __name__ == "__main__":
117
- demo.launch(share=False)
 
 
1
  import gradio as gr
2
  import os
 
3
  import torch
4
+ from model import Wav2Vec2BERT_Llama # 自定义模型模块
5
+ import dataset # 自定义数据集模块
6
 
7
+ # 初始化设备
8
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
 
10
+ # 初始化模型
11
+ # def load_model():
12
+ # model = Wav2Vec2BERT_Llama().to(device)
13
+ # checkpoint_path = "ckpt/model_checkpoint.pth"
14
+ # if os.path.exists(checkpoint_path):
15
+ # checkpoint = torch.load(checkpoint_path)
16
+ # model_state_dict = checkpoint['model_state_dict']
 
 
 
 
 
 
 
 
 
 
17
 
18
+ # # 处理模型状态字典的 key
19
+ # if hasattr(model, 'module') and not any(key.startswith('module.') for key in model_state_dict.keys()):
20
+ # model_state_dict = {'module.' + key: value for key, value in model_state_dict.items()}
21
+ # elif not hasattr(model, 'module') and any(key.startswith('module.') for key in model_state_dict.keys()):
22
+ # model_state_dict = {key.replace('module.', ''): value for key, value in model_state_dict.items()}
23
 
24
+ # model.load_state_dict(model_state_dict)
25
+ # model.eval()
26
+ # else:
27
+ # raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
28
+ # return model
29
+
30
+ # model = load_model()
31
+
32
+ # 检测函数
33
  def detect(dataset, model):
34
+ """进行音频伪造检测"""
35
  with torch.no_grad():
36
  for batch in dataset:
37
  main_features = {
 
42
  'input_features': pf['input_features'].to(device),
43
  'attention_mask': pf['attention_mask'].to(device)
44
  } for pf in batch['prompt_features']]
 
45
 
46
+ # 模型的前向传播逻辑 (需要补充具体实现)
47
+ # 假设 result 是模型返回的结果
48
+ result = {"is_fake": True, "confidence": 85.5} # 示例返回值
49
+ return result
50
 
51
+ # 音频伪造检测主函数
52
+ def audio_deepfake_detection(demonstrations, query_audio_path):
53
+ """
54
+ 音频伪造检测函数
55
+ :param demonstrations: 演示音频路径和标签的列表
56
+ :param query_audio_path: 查询音频路径
57
+ :return: 检测结果
58
+ """
59
+ demonstration_paths = [audio[0] for audio in demonstrations if audio[0] is not None]
60
+ print(f"Demonstration audio paths: {demonstration_paths}")
61
+ print(f"Query audio path: {query_audio_path}")
62
 
63
+ # 数据集处理
64
+ audio_dataset = dataset.DemoDataset(demonstration_paths, query_audio_path)
 
 
 
 
 
 
 
65
 
66
+ # 调用检测函数
67
+ result = detect(audio_dataset, model)
68
 
69
+ # 返回结果
 
 
 
 
 
70
  return {
71
  "Is AI Generated": result["is_fake"],
72
  "Confidence": f"{result['confidence']:.2f}%"
73
  }
74
 
75
+ # Gradio 界面
76
+ def gradio_ui():
77
+ def detection_wrapper(demonstration_audio1, label1, demonstration_audio2, label2, demonstration_audio3, label3, query_audio):
78
+ # 将输入音频和标签封装成列表
79
+ demonstrations = [
80
+ (demonstration_audio1, label1),
81
+ (demonstration_audio2, label2),
82
+ (demonstration_audio3, label3),
83
+ ]
84
+ return audio_deepfake_detection(demonstrations, query_audio)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
+ # 构建 Gradio 界面
87
+ interface = gr.Interface(
88
+ fn=detection_wrapper, # 主函数
89
+ inputs=[
90
+ gr.Audio(source="upload", type="filepath", label="Demonstration Audio 1"),
91
+ gr.Dropdown(choices=["bonafide", "spoof"], value="bonafide", label="Label 1"),
92
+ gr.Audio(source="upload", type="filepath", label="Demonstration Audio 2"),
93
+ gr.Dropdown(choices=["bonafide", "spoof"], value="bonafide", label="Label 2"),
94
+ gr.Audio(source="upload", type="filepath", label="Demonstration Audio 3"),
95
+ gr.Dropdown(choices=["bonafide", "spoof"], value="bonafide", label="Label 3"),
96
+ gr.Audio(source="upload", type="filepath", label="Query Audio (Audio for Detection)")
97
+ ],
98
+ outputs=gr.JSON(label="Detection Results"),
99
+ title="Audio Deepfake Detection System",
100
+ description="Upload demonstration audios and a query audio to detect whether the query is AI-generated.",
101
  )
102
+ return interface
103
 
104
  if __name__ == "__main__":
105
+ demo = gradio_ui()
106
+ demo.launch()