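"""Gradio demo for audio deepfake detection.

Loads the Wav2Vec2BERT_Llama checkpoint from the amphion/deepfake_detection
Hugging Face model repo and classifies a query audio clip as bonafide or
AI-generated, conditioned on up to three labeled demonstration clips.
"""
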
import spaces
import gradio as gr
import os
import torch
from model import Wav2Vec2BERT_Llama  # custom model module
import dataset  # custom dataset module
from huggingface_hub import hf_hub_download
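
# Note: @spaces.GPU is how Hugging Face ZeroGPU Spaces request GPU time for a
# function call; this no-op presumably exists only so the Space registers at
# least one such entry point.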
@spaces.GPU
def dummy(): # just a dummy
    pass

# Initialize the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('device:', device)

# Initialize the model
def load_model():
    model = Wav2Vec2BERT_Llama().to(device)
    checkpoint_path = hf_hub_download(
        repo_id="amphion/deepfake_detection",
        filename="checkpoints_wav2vec2bert_ft_llama_labels_ASVspoof2019_RandomPrompts_6/model_checkpoint.pth",
        repo_type="model"
    )
    # checkpoint_path = "ckpt/model_checkpoint.pth"
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        model_state_dict = checkpoint['model_state_dict']
        threshold = 0.9996
        # Reconcile the 'module.' prefix between the checkpoint keys and the model
        # (checkpoints saved from a DataParallel-wrapped model carry this prefix).
        if hasattr(model, 'module') and not any(key.startswith('module.') for key in model_state_dict.keys()):
            model_state_dict = {'module.' + key: value for key, value in model_state_dict.items()}
        elif not hasattr(model, 'module') and any(key.startswith('module.') for key in model_state_dict.keys()):
            model_state_dict = {key.replace('module.', ''): value for key, value in model_state_dict.items()}
        model.load_state_dict(model_state_dict)
        model.eval()
    else:
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    return model, threshold
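
# The model and decision threshold are loaded once at startup and shared by all requests.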
model, threshold = load_model()

# Detection function
def detect(demo_dataset):
    """Run audio deepfake detection over the prepared demo dataset."""
    with torch.no_grad():
        for batch in demo_dataset:
            main_features = {
                'input_features': batch['main_features']['input_features'].to(device),
                'attention_mask': batch['main_features']['attention_mask'].to(device)
            }
            prompt_features = [{
                'input_features': pf['input_features'].to(device),
                'attention_mask': pf['attention_mask'].to(device)
            } for pf in batch['prompt_features']]
            prompt_labels = batch['prompt_labels'].to(device)

            # Forward pass: the model receives the query features together with
            # the demonstration (prompt) features and their labels.
            outputs = model({
                'main_features': main_features,
                'prompt_features': prompt_features,
                'prompt_labels': prompt_labels
            })
            avg_scores = outputs['avg_logits'].softmax(dim=-1)  # [batch_size, 2]
            deepfake_scores = avg_scores[:, 1].cpu()  # [batch_size]

            # Flag the clip as fake when the deepfake probability exceeds the threshold.
            is_fake = bool(deepfake_scores[0].item() > threshold)
            return {"is_fake": is_fake, "confidence": deepfake_scores[0].item()}

# Main audio deepfake detection function
def audio_deepfake_detection(demonstrations, query_audio_path):
"""
音频伪造检测函数
:param demonstrations: 演示音频路径和标签的列表
:param query_audio_path: 查询音频路径
:return: 检测结果
"""
demonstration_paths = [audio[0] for audio in demonstrations if audio[0] is not None]
demonstration_labels = [audio[1] for audio in demonstrations if audio[1] is not None]
if len(demonstration_paths) != len(demonstration_labels):
demonstration_labels = demonstration_labels[:len(demonstration_paths)]
print(f"Demonstration audio paths: {demonstration_paths}")
print(f"Demonstration audio labels: {demonstration_labels}")
print(f"Query audio path: {query_audio_path}")
# 数据集处理
audio_dataset = dataset.DemoDataset(demonstration_paths, demonstration_labels, query_audio_path)
# 调用检测函数
result = detect(audio_dataset)
# 返回结果
return {
"Is AI Generated": result["is_fake"],
"Confidence": f"{result['confidence']:.2f}%"
}

# Gradio interface
def gradio_ui():
    def detection_wrapper(demonstration_audio1, label1, demonstration_audio2, label2,
                          demonstration_audio3, label3, query_audio):
        demonstrations = [
            (demonstration_audio1, label1),
            (demonstration_audio2, label2),
            (demonstration_audio3, label3),
        ]
        return audio_deepfake_detection(demonstrations, query_audio)

    interface = gr.Interface(
        fn=detection_wrapper,
        inputs=[
            gr.Audio(sources=["upload"], type="filepath", label="Demonstration Audio 1"),
            gr.Dropdown(choices=["bonafide", "spoof"], value="bonafide", label="Label 1"),
            gr.Audio(sources=["upload"], type="filepath", label="Demonstration Audio 2"),
            gr.Dropdown(choices=["bonafide", "spoof"], value="bonafide", label="Label 2"),
            gr.Audio(sources=["upload"], type="filepath", label="Demonstration Audio 3"),
            gr.Dropdown(choices=["bonafide", "spoof"], value="bonafide", label="Label 3"),
            gr.Audio(sources=["upload"], type="filepath", label="Query Audio (Audio for Detection)")
        ],
        outputs=gr.JSON(label="Detection Results"),
        title="Audio Deepfake Detection System",
        description="Upload demonstration audios and a query audio to detect whether the query is AI-generated.",
    )
    return interface
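
# Programmatic usage sketch (hypothetical file paths, bypassing the UI):
#   result = audio_deepfake_detection(
#       [("bonafide_ref.wav", "bonafide"), ("spoof_ref.wav", "spoof"), (None, None)],
#       "query.wav",
#   )
#   # -> {"Is AI Generated": <bool>, "Confidence": "<percentage>"}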
if __name__ == "__main__":
    demo = gradio_ui()
    demo.launch()