import spaces
import gradio as gr
import os
import torch
from model import SpoofVerificationModel  # project-local model module
import dataset  # project-local dataset module
from huggingface_hub import hf_hub_download
from transformers import AutoFeatureExtractor


@spaces.GPU
def dummy():  # just a dummy
    """No-op function decorated with ``@spaces.GPU``.

    NOTE(review): presumably kept so the Spaces GPU runtime sees a decorated
    function at import time — confirm before removing.
    """
    pass


def load_model():
    """Download the detection checkpoint from the Hugging Face Hub.

    Returns:
        Local filesystem path of the downloaded checkpoint file.

    Raises:
        FileNotFoundError: if the path returned by the download does not exist.
    """
    checkpoint_path = hf_hub_download(
        repo_id="amphion/deepfake_detection",
        filename="checkpoints_w2v-bert_SpoofVerification_MultiDataset/model_checkpoint_4_new.pth",
        repo_type="model",
    )
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    return checkpoint_path


# Resolved once at import time; every request reuses this local file path.
checkpoint_path = load_model()

_MODULE_PREFIX = "module."


def _align_state_dict_keys(model, state_dict):
    """Add or strip a leading 'module.' prefix so state-dict keys match *model*.

    Checkpoints saved from a wrapper that exposes ``.module`` carry keys
    prefixed with 'module.'; a bare model expects them without it (and vice
    versa). Returns the (possibly rewritten) state dict.
    """
    model_is_wrapped = hasattr(model, "module")
    keys_are_prefixed = any(key.startswith(_MODULE_PREFIX) for key in state_dict)

    if model_is_wrapped and not keys_are_prefixed:
        print("添加 'module.' 前缀到状态字典的 key")
        return {_MODULE_PREFIX + key: value for key, value in state_dict.items()}

    if not model_is_wrapped and keys_are_prefixed:
        print("移除状态字典 key 中的 'module.' 前缀")
        # Strip only the *leading* prefix: str.replace would also rewrite any
        # "module." that appears deeper inside a parameter path.
        return {
            (key[len(_MODULE_PREFIX):] if key.startswith(_MODULE_PREFIX) else key): value
            for key, value in state_dict.items()
        }

    return state_dict


@spaces.GPU
def detect_on_gpu(audio_path, threshold=0.5):
    """Run audio deepfake detection on the GPU (falls back to CPU).

    The model, its weights and the feature extractor are loaded inside this
    function so that all CUDA work happens within the ``@spaces.GPU`` call.

    Args:
        audio_path: Path to the audio file; passed to ``dataset.DemoDataset``.
        threshold: The input is classified as fake when its class-1 softmax
            score is strictly greater than this value. Defaults to 0.5.

    Returns:
        dict with keys:
            "is_fake": bool verdict.
            "confidence": float in [0, 1]; the class-1 score when fake,
                otherwise ``1 - score``.

    Raises:
        ValueError: if the dataset built from *audio_path* yields no items.
    """
    print("\n=== 开始音频检测 ===")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"使用设备: {device}")

    # Dataset construction happens inside the GPU function.
    audio_dataset = dataset.DemoDataset(audio_path)

    print("正在初始化模型...")
    model = SpoofVerificationModel().to(device)
    print(f"正在加载模型权重: {checkpoint_path}")
    # NOTE(review): torch.load unpickles arbitrary objects. This is acceptable
    # only because the checkpoint comes from a fixed Hub repo; never point
    # this at user-supplied files.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model_state_dict = checkpoint['model_state_dict']

    print(f"检测阈值设置为: {threshold}")

    model_state_dict = _align_state_dict_keys(model, model_state_dict)
    model.load_state_dict(model_state_dict)
    model.eval()
    print("模型加载完成,进入评估模式")

    feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")

    print("\n开始处理音频数据...")
    with torch.no_grad():
        # Only the first item produced by the dataset is scored.
        batch = next(iter(audio_dataset), None)
        if batch is None:
            raise ValueError(f"No audio data could be read from: {audio_path}")

        print(f"\n处理批次 {1}")
        # assumes batch['waveforms'] is a tensor of shape [T] or [B, T] — TODO
        # confirm against dataset.DemoDataset. A single clip gets a batch dim.
        if len(batch['waveforms'].shape) == 1:
            batch['waveforms'] = batch['waveforms'].unsqueeze(0)
        print('shape:', batch['waveforms'].shape)

        waveforms = batch['waveforms'].numpy()  # [B, T]
        # NOTE(review): sampling_rate is fixed at 16 kHz; presumably
        # DemoDataset resamples to that rate — confirm.
        features = feature_extractor(
            waveforms,
            sampling_rate=16000,
            return_attention_mask=True,
            padding_value=0,
            return_tensors="pt",
        ).to(device)

        outputs = model(features)
        deepfake_logits = outputs['deepfake_logits']
        # Softmax over classes; column 1 is the score compared to the threshold.
        deepfake_scores = deepfake_logits.float().softmax(dim=-1)[:, 1].contiguous()

        # Convert to a plain Python float so no (CUDA) tensor leaves the GPU call.
        fake_score = deepfake_scores[0].item()
        is_fake = fake_score > threshold
        result = {
            "is_fake": is_fake,
            "confidence": fake_score if is_fake else 1 - fake_score,
        }

    print("\n=== 检测完成 ===")
    return result


def audio_deepfake_detection(audio_path):
    """Detect whether an audio file is AI-generated and format the result.

    Args:
        audio_path: Path to the uploaded audio file. It is ``None`` when the
            user submits without uploading.

    Returns:
        dict of bilingual display labels to the verdict and a percentage
        confidence string.

    Raises:
        gr.Error: if no audio file was provided (shown to the user by Gradio).
    """
    if audio_path is None:
        raise gr.Error("请先上传测试音频 / Please upload a test audio first.")

    # The raw path is handed straight to the GPU function, which builds the dataset.
    result = detect_on_gpu(audio_path)

    is_fake = "是/Yes" if result["is_fake"] else "否/No"
    confidence = f"{100*result['confidence']:.2f}%"
    return {
        "是否为AI生成/Is AI Generated": is_fake,
        "检测可信度/Confidence": confidence,
    }


def gradio_ui():
    """Build the single-input Gradio interface for audio deepfake detection.

    Returns:
        A ``gr.Interface`` with one uploaded-audio input and a JSON output.
    """
    def detection_wrapper(query_audio):
        return audio_deepfake_detection(query_audio)

    interface = gr.Interface(
        fn=detection_wrapper,
        inputs=[
            gr.Audio(sources=["upload"], type="filepath", label="测试音频 / Test Audio")
        ],
        outputs=gr.JSON(label="检测结果 / Detection Results"),
        title="音频伪造检测系统 / Audio Deepfake Detection System",
        description="上传一个测试音频以检测该音频是否为AI生成。/ Upload a test audio to detect whether the audio is AI-generated.",
        article=(
            "由香港中文大学(深圳)武执政教授团队开发。"
            "Developed by a team led by Prof Zhizheng Wu from the Chinese University of Hong Kong, Shenzhen."
            "\n\n"
            "本系统用于检测音频是否为AI生成,适用于研究和教育目的。"
            "This system is designed to detect whether an audio is AI-generated, "
            "and is intended for research and educational purposes."
        ),
    )
    return interface


if __name__ == "__main__":
    demo = gradio_ui()
    demo.launch()