import spaces
import gradio as gr
import os
import torch
from model import Wav2Vec2BERT_Llama  # custom model module
import dataset  # custom dataset module
from huggingface_hub import hf_hub_download


def dummy():  # just a dummy placeholder, intentionally unused
    pass
# Download the model checkpoint from the Hugging Face Hub
def load_model():
    checkpoint_path = hf_hub_download(
        repo_id="amphion/deepfake_detection",
        filename="checkpoints_wav2vec2bert_ft_llama_labels_ASVspoof2019_RandomPrompts_6/model_checkpoint.pth",
        repo_type="model"
    )
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    return checkpoint_path


checkpoint_path = load_model()
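# Note: hf_hub_download stores the file in the local Hugging Face cache
# (~/.cache/huggingface by default) and returns the path to that cached copy,
# so repeated launches of the Space reuse the existing download.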
# Detection runs under the GPU decorator so ZeroGPU allocates a device for each call
@spaces.GPU
def detect_on_gpu(dataset):
    """Run audio deepfake detection on the GPU."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Wav2Vec2BERT_Llama().to(device)

    # Load model weights
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model_state_dict = checkpoint['model_state_dict']
    threshold = 0.9996

    # Normalize state-dict keys: add or strip the DataParallel 'module.' prefix as needed
    if hasattr(model, 'module') and not any(key.startswith('module.') for key in model_state_dict.keys()):
        model_state_dict = {'module.' + key: value for key, value in model_state_dict.items()}
    elif not hasattr(model, 'module') and any(key.startswith('module.') for key in model_state_dict.keys()):
        model_state_dict = {key.replace('module.', ''): value for key, value in model_state_dict.items()}
    model.load_state_dict(model_state_dict)
    model.eval()
    with torch.no_grad():
        for batch in dataset:
            main_features = {
                'input_features': batch['main_features']['input_features'].to(device),
                'attention_mask': batch['main_features']['attention_mask'].to(device)
            }
            prompt_features = [{
                'input_features': pf['input_features'].to(device),
                'attention_mask': pf['attention_mask'].to(device)
            } for pf in batch['prompt_features']]
            prompt_labels = batch['prompt_labels'].to(device)

            outputs = model({
                'main_features': main_features,
                'prompt_features': prompt_features,
                'prompt_labels': prompt_labels
            })

            # Softmax over the two classes; column 1 is the deepfake (spoof) probability
            avg_scores = outputs['avg_logits'].softmax(dim=-1)
            deepfake_scores = avg_scores[:, 1].cpu()
            # Convert to plain Python types so the result is JSON-serializable
            is_fake = bool(deepfake_scores[0] > threshold)
            result = {"is_fake": is_fake, "confidence": float(deepfake_scores[0])}
    return result
# Main audio deepfake detection entry point
def audio_deepfake_detection(demonstrations, query_audio_path):
    demonstration_paths = [audio[0] for audio in demonstrations if audio[0] is not None]
    demonstration_labels = [audio[1] for audio in demonstrations if audio[1] is not None]
    if len(demonstration_paths) != len(demonstration_labels):
        demonstration_labels = demonstration_labels[:len(demonstration_paths)]

    # Build the dataset from the demonstration audios and the query audio
    audio_dataset = dataset.DemoDataset(demonstration_paths, demonstration_labels, query_audio_path)

    # Run detection on the GPU
    result = detect_on_gpu(audio_dataset)

    return {
        "Is AI Generated": result["is_fake"],
        "Confidence": f"{result['confidence']:.2%}"
    }
# Gradio interface
def gradio_ui():
    def detection_wrapper(demonstration_audio1, label1, demonstration_audio2, label2, demonstration_audio3, label3, query_audio):
        demonstrations = [
            (demonstration_audio1, label1),
            (demonstration_audio2, label2),
            (demonstration_audio3, label3),
        ]
        return audio_deepfake_detection(demonstrations, query_audio)

    interface = gr.Interface(
        fn=detection_wrapper,
        inputs=[
            gr.Audio(sources=["upload"], type="filepath", label="Demonstration Audio 1"),
            gr.Dropdown(choices=["bonafide", "spoof"], value="bonafide", label="Label 1"),
            gr.Audio(sources=["upload"], type="filepath", label="Demonstration Audio 2"),
            gr.Dropdown(choices=["bonafide", "spoof"], value="bonafide", label="Label 2"),
            gr.Audio(sources=["upload"], type="filepath", label="Demonstration Audio 3"),
            gr.Dropdown(choices=["bonafide", "spoof"], value="bonafide", label="Label 3"),
            gr.Audio(sources=["upload"], type="filepath", label="Query Audio (Audio for Detection)")
        ],
        outputs=gr.JSON(label="Detection Results"),
        title="Audio Deepfake Detection System",
        description="Upload demonstration audios and a query audio to detect whether the query is AI-generated.",
    )
    return interface
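# A minimal sketch (hypothetical file paths) of calling the detection pipeline
# directly from Python, bypassing the Gradio UI:
#
#     result = audio_deepfake_detection(
#         demonstrations=[("bonafide_demo.wav", "bonafide"), ("spoof_demo.wav", "spoof")],
#         query_audio_path="query.wav",
#     )
#     print(result)  # e.g. {"Is AI Generated": False, "Confidence": "12.34%"}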
if __name__ == "__main__":
    demo = gradio_ui()
    demo.launch()