wli3221134's picture
Update app.py
af54972 verified
import spaces
import gradio as gr
import os
import torch
from model import SpoofVerificationModel # 自定义模型模块
import dataset # 自定义数据集模块
from huggingface_hub import hf_hub_download
from transformers import AutoFeatureExtractor
@spaces.GPU
def dummy():
    """Do nothing; this function exists only to carry the ``@spaces.GPU`` decorator.

    NOTE(review): presumably kept so the Hugging Face ZeroGPU runtime finds a
    GPU-decorated function at import time — TODO confirm it is still required.
    """
# Checkpoint download helper (the detector weights live on the Hugging Face Hub).
def load_model(
    repo_id="amphion/deepfake_detection",
    filename="checkpoints_w2v-bert_SpoofVerification_MultiDataset/model_checkpoint_4_new.pth",
    repo_type="model",
):
    """Download the spoof-verification checkpoint and return its local path.

    Args:
        repo_id: Hub repository holding the checkpoint. Defaults to the
            repository this app was built against.
        filename: Path of the checkpoint file inside ``repo_id``.
        repo_type: Hub repository type passed to ``hf_hub_download``.

    Returns:
        str: Local filesystem path of the downloaded checkpoint.

    Raises:
        FileNotFoundError: if the returned path does not exist on disk.
    """
    checkpoint_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        repo_type=repo_type,
    )
    # Defensive check: fail early with a clear message rather than inside torch.load.
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    return checkpoint_path


# Resolved once at import time; detect_on_gpu loads weights from this module-level path.
checkpoint_path = load_model()
def _align_state_dict_keys(model, state_dict):
    """Add or strip the leading 'module.' key prefix so ``state_dict`` matches ``model``.

    The 'module.' prefix is what PyTorch's DataParallel / DistributedDataParallel
    wrappers typically add to parameter names. If ``model`` exposes a ``module``
    attribute but the checkpoint keys lack the prefix, it is added. If ``model``
    has no ``module`` attribute but the keys carry it, it is removed. Otherwise
    ``state_dict`` is returned unchanged.
    """
    prefix = "module."
    model_is_wrapped = hasattr(model, "module")
    keys_are_prefixed = any(key.startswith(prefix) for key in state_dict.keys())
    if model_is_wrapped and not keys_are_prefixed:
        print("添加 'module.' 前缀到状态字典的 key")
        return {prefix + key: value for key, value in state_dict.items()}
    if not model_is_wrapped and keys_are_prefixed:
        print("移除状态字典 key 中的 'module.' 前缀")
        # Strip only the *leading* prefix: str.replace would also delete any
        # 'module.' that appears later inside a key name.
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }
    return state_dict


# The detection entry point runs under the @spaces.GPU decorator.
@spaces.GPU
def detect_on_gpu(audio_path, threshold=0.5):
    """Run audio deepfake detection on one file (on GPU when CUDA is available).

    The dataset, model, checkpoint weights and feature extractor are all built
    inside this call. Only the first item yielded by ``dataset.DemoDataset`` is
    scored.

    Args:
        audio_path: Path to the audio file; handed directly to
            ``dataset.DemoDataset``.
        threshold: Cut-off on the class-1 ("fake") softmax score; a score
            strictly greater than it is reported as fake. Defaults to 0.5.

    Returns:
        dict: ``{"is_fake": bool, "confidence": float}``. ``confidence`` is the
        class-1 score when fake, otherwise ``1 - score``.

    Raises:
        ValueError: if the dataset yields no items for ``audio_path``.
    """
    print("\n=== 开始音频检测 ===")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"使用设备: {device}")
    # The dataset is constructed inside the GPU call, next to the model.
    audio_dataset = dataset.DemoDataset(audio_path)

    print("正在初始化模型...")
    model = SpoofVerificationModel().to(device)
    print(f"正在加载模型权重: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model_state_dict = checkpoint['model_state_dict']
    print(f"检测阈值设置为: {threshold}")
    model.load_state_dict(_align_state_dict_keys(model, model_state_dict))
    model.eval()
    print("模型加载完成,进入评估模式")

    feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
    print("\n开始处理音频数据...")
    with torch.no_grad():
        # Score only the first item (the original loop broke after batch 1).
        # An empty dataset previously surfaced as an UnboundLocalError.
        try:
            batch = next(iter(audio_dataset))
        except StopIteration:
            raise ValueError(f"No audio data could be loaded from: {audio_path}") from None
        print("\n处理批次 1")

        waveforms = batch['waveforms']
        if len(waveforms.shape) == 1:
            # Promote a 1-D waveform to a batch of one.
            waveforms = waveforms.unsqueeze(0)
        print('shape:', waveforms.shape)

        # NOTE(review): assumes DemoDataset yields 16 kHz audio — TODO confirm.
        features = feature_extractor(
            waveforms.numpy(),  # [B, T]
            sampling_rate=16000,
            return_attention_mask=True,
            padding_value=0,
            return_tensors="pt",
        ).to(device)
        outputs = model(features)
        deepfake_scores = outputs['deepfake_logits'].float().softmax(dim=-1)[:, 1]
        # Return a plain Python float rather than a (possibly CUDA) tensor
        # from the @spaces.GPU call.
        score = deepfake_scores[0].item()

    is_fake = score > threshold
    result = {"is_fake": is_fake, "confidence": score if is_fake else 1 - score}
    print("\n=== 检测完成 ===")
    return result
def audio_deepfake_detection(audio_path):
    """Detect whether ``audio_path`` is AI-generated and format the result for display.

    Args:
        audio_path: Local path of the uploaded audio file (Gradio ``filepath``
            input). ``None``/empty means nothing was uploaded.

    Returns:
        dict: Two bilingual display strings: the yes/no verdict and the
        confidence as a percentage with two decimals.

    Raises:
        gr.Error: if no audio file was provided.
    """
    # Gradio passes None when the user submits without uploading; fail fast with
    # a readable message instead of requesting a GPU slot for nothing.
    if not audio_path:
        raise gr.Error("请先上传音频 / Please upload an audio file first.")

    # The dataset is built inside detect_on_gpu, so the raw path is passed through.
    result = detect_on_gpu(audio_path)
    is_fake = "是/Yes" if result["is_fake"] else "否/No"
    # float() accepts both a Python float and a 0-dim tensor.
    confidence = f"{100 * float(result['confidence']):.2f}%"
    return {
        "是否为AI生成/Is AI Generated": is_fake,
        "检测可信度/Confidence": confidence,
    }
# Gradio interface
def gradio_ui():
    """Build (but do not launch) the single-input Gradio detection interface.

    Returns:
        gr.Interface: An interface with one uploaded-audio input and a JSON
        output that shows the result of ``audio_deepfake_detection``.
    """

    def detection_wrapper(query_audio):
        # Thin adapter: keeps the callback's parameter named `query_audio`.
        return audio_deepfake_detection(query_audio)

    interface = gr.Interface(
        fn=detection_wrapper,
        inputs=[
            gr.Audio(sources=["upload"], type="filepath", label="测试音频 / Test Audio")
        ],
        outputs=gr.JSON(label="检测结果 / Detection Results"),
        title="音频伪造检测系统 / Audio Deepfake Detection System",
        description="上传一个测试音频以检测该音频是否为AI生成。/ Upload a test audio to detect whether the audio is AI-generated.",
        article=(
            "由香港中文大学(深圳)武执政教授团队开发。"
            "Developed by a team led by Prof Zhizheng Wu from the Chinese University of Hong Kong, Shenzhen."
            "\n\n"
            "本系统用于检测音频是否为AI生成,适用于研究和教育目的。"
            "This system is designed to detect whether an audio is AI-generated, "
            "and is intended for research and educational purposes."
        )
    )
    return interface
if __name__ == "__main__":
    # Script entry point: build the interface, then start the Gradio server.
    demo = gradio_ui()
    demo.launch()