Spaces: Running on Zero

Upload 3 files

- app.py +16 -59
- dataset.py +24 -122
- model.py +52 -363
app.py
CHANGED
@@ -2,9 +2,11 @@ import spaces
 import gradio as gr
 import os
 import torch
-from model import
+from model import SpoofVerificationModel  # custom model module
 import dataset  # custom dataset module
 from huggingface_hub import hf_hub_download
+from transformers import AutoFeatureExtractor
+
 
 @spaces.GPU
 def dummy():  # just a dummy
@@ -14,7 +16,7 @@ def dummy():  # just a dummy
 def load_model():
     checkpoint_path = hf_hub_download(
         repo_id="amphion/deepfake_detection",
-        filename="
+        filename="checkpoints_w2v-bert_SpoofVerification_MultiDataset/model_checkpoint_4_new.pth",
         repo_type="model"
     )
     if not os.path.exists(checkpoint_path):
@@ -33,12 +35,12 @@ def detect_on_gpu(dataset):
     print(f"Using device: {device}")
 
     print("Initializing model...")
-    model =
+    model = SpoofVerificationModel().to(device)
 
     print(f"Loading model weights: {checkpoint_path}")
     checkpoint = torch.load(checkpoint_path, map_location=device)
     model_state_dict = checkpoint['model_state_dict']
-    threshold = 0.
+    threshold = 0.5
     print(f"Detection threshold set to: {threshold}")
 
     # Handle the keys of the model state dict
@@ -53,45 +55,18 @@ def detect_on_gpu(dataset):
     model.eval()
     print("Model loaded, entering eval mode")
 
+    feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
+
     print("\nStarting to process audio data...")
     with torch.no_grad():
         for batch_idx, batch in enumerate(dataset):
             print(f"\nProcessing batch {batch_idx + 1}")
-
-            print(f"Main feature shape: {main_features['input_features'].shape}")
-
-            if len(batch['prompt_features']) > 0:
-                print("\nPreparing prompt features...")
-                prompt_features = [{
-                    'input_features': pf['input_features'].to(device),
-                    'attention_mask': pf['attention_mask'].to(device)
-                } for pf in batch['prompt_features']]
-                print(f"Number of prompt features: {len(prompt_features)}")
-                print(f"First prompt feature shape: {prompt_features[0]['input_features'].shape}")
-
-                print("\nPreparing prompt labels...")
-                prompt_labels = batch['prompt_labels'].to(device)
-                print(f"Prompt label shape: {prompt_labels.shape}")
-                print(f"Prompt label values: {prompt_labels}")
-            else:
-                prompt_features = []
-                prompt_labels = []
-
-            print("\nRunning model inference...")
-            outputs = model({
-                'main_features': main_features,
-                'prompt_features': prompt_features,
-                'prompt_labels': prompt_labels
-            })
-
-            print("\nProcessing model outputs...")
-            avg_scores = outputs['avg_logits'].softmax(dim=-1)
-            deepfake_scores = avg_scores[:, 1].cpu()
+            waveforms = batch['waveforms'].numpy()  # [B, T]
+            features = feature_extractor(waveforms, sampling_rate=16000, return_attention_mask=True, padding_value=0, return_tensors="pt").to(device)
+            outputs = model(features)
+            deepfake_logits = outputs['deepfake_logits']
+
+            deepfake_scores = deepfake_logits.float().softmax(dim=-1)[:, 1].contiguous()
             is_fake = deepfake_scores[0].item() > threshold
 
             result = {"is_fake": is_fake, "confidence": deepfake_scores[0] if is_fake else 1-deepfake_scores[0]}
@@ -101,28 +76,10 @@ def detect_on_gpu(dataset):
     print("\n=== Detection finished ===")
     return result
 
-
-# def audio_deepfake_detection(demonstrations, query_audio_path):
-#     demonstration_paths = [audio[0] for audio in demonstrations if audio[0] is not None]
-#     demonstration_labels = [audio[1] for audio in demonstrations if audio[1] is not None]
-#     if len(demonstration_paths) != len(demonstration_labels):
-#         demonstration_labels = demonstration_labels[:len(demonstration_paths)]
-
-#     # Build the dataset
-#     audio_dataset = dataset.DemoDataset(demonstration_paths, demonstration_labels, query_audio_path)
-
-#     # Call the GPU detection function
-#     result = detect_on_gpu(audio_dataset)
-
-#     return {
-#         "Is AI Generated": result["is_fake"],
-#         "Confidence": f"{100*result['confidence']:.2f}%"
-#     }
-# 0 demonstrations
-def audio_deepfake_detection(query_audio_path):
+def audio_deepfake_detection(audio_path):
 
     # Build the dataset
-    audio_dataset = dataset.DemoDataset(
+    audio_dataset = dataset.DemoDataset(audio_path)
 
     # Call the GPU detection function
     result = detect_on_gpu(audio_dataset)
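Taken together, the new app.py scores a single clip directly: the padded waveform goes through the facebook/w2v-bert-2.0 feature extractor, into SpoofVerificationModel, and the softmax score for the fake class is compared against the 0.5 threshold. A minimal standalone sketch of that path, using the pad helper and SpoofVerificationModel introduced in dataset.py and model.py below; the backbone path and example.wav are assumptions, and the fine-tuned checkpoint that load_model() downloads is not applied here.

# Sketch of the new per-clip scoring path; "facebook/w2v-bert-2.0" as the
# backbone path and "example.wav" are placeholders, not values from this diff,
# and the fine-tuned checkpoint from load_model() is not loaded here.
import librosa
import torch
from transformers import AutoFeatureExtractor

from dataset import pad
from model import SpoofVerificationModel

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load one clip at 16 kHz and clip/repeat-pad it to the fixed length (64600 samples).
waveform, _ = librosa.load("example.wav", sr=16000)
waveform = pad(waveform)

extractor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
features = extractor(
    waveform[None, :],          # [B, T], same layout as batch['waveforms'].numpy() above
    sampling_rate=16000,
    return_attention_mask=True,
    padding_value=0,
    return_tensors="pt",
).to(device)

model = SpoofVerificationModel("facebook/w2v-bert-2.0").to(device).eval()
with torch.no_grad():
    outputs = model(features)
    fake_score = outputs["deepfake_logits"].float().softmax(dim=-1)[0, 1].item()

print(f"deepfake score: {fake_score:.3f} (flagged as fake if > 0.5)")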
dataset.py
CHANGED
@@ -1,133 +1,35 @@
-import torch
 from torch.utils.data import Dataset
-from transformers import AutoFeatureExtractor
-import os
 import librosa
 import numpy as np
+from torch import Tensor
 
-class DemoDataset(Dataset):
-    def __init__(self, demonstration_paths, demonstration_labels, query_path, sample_rate=16000):
-        self.sample_rate = sample_rate
-        self.query_path = query_path
-
-        # Convert to list if single path
-        self.demonstration_paths = demonstration_paths
-        self.demonstration_labels = [0 if label == 'bonafide' else 1 for label in demonstration_labels]
-
-        # Load feature extractor
-        self.feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
-
 
-        x_len = x.shape[0]
-        if x_len >= max_len:
+def pad(x, max_len=64600, random_clip=True):
+    x_len = x.shape[0]
+    if x_len > max_len:
+        # random clip
+        if random_clip:
+            start_idx = np.random.randint(0, x_len - max_len)
+            return x[start_idx:start_idx + max_len]
+        else:
             return x[:max_len]
-
+    # need to pad
+    num_repeats = int(max_len / x_len)+1
+    padded_x = np.tile(x, (1, num_repeats))[:, :max_len][0]
+    return padded_x
+
+class DemoDataset(Dataset):
+    def __init__(self, path):
+        self.path = path
 
     def __len__(self):
-        return 1
+        return 1
 
     def __getitem__(self, idx):
-
-
-        if len(query_waveform.shape) == 1:
-            query_waveform = query_waveform.unsqueeze(0)
-
-        # Extract features for query audio
-        main_features = self.feature_extractor(
-            query_waveform,
-            sampling_rate=self.sample_rate,
-            padding=True,
-            return_attention_mask=True,
-            return_tensors="pt"
-        )
-
-        # Process demonstration audios
-        prompt_features = []
-        for demo_path in self.demonstration_paths:
-            # Load demonstration audio
-            demo_waveform = self.load_pad(demo_path)
-            demo_waveform = torch.from_numpy(demo_waveform).float()
-            if len(demo_waveform.shape) == 1:
-                demo_waveform = demo_waveform.unsqueeze(0)
-
-            # Extract features
-            prompt_feature = self.feature_extractor(
-                demo_waveform,
-                sampling_rate=self.sample_rate,
-                padding=True,
-                return_attention_mask=True,
-                return_tensors="pt"
-            )
-            prompt_features.append(prompt_feature)
-
-        prompt_labels = torch.tensor([self.demonstration_labels], dtype=torch.long)
-
-        return {
-            'main_features': main_features,
-            'prompt_features': prompt_features,
-            'prompt_labels': prompt_labels,
-            'file_name': os.path.basename(self.query_path),
-            'file_path': self.query_path
-        }
+        waveform, sample_rate = librosa.load(self.path, sr=16000)
+        waveform_pad = pad(waveform)
+        waveform_tensor = Tensor(waveform_pad)
 
-
-        Args:
-            batch: List containing dictionaries with:
-                - main_features: feature extractor output
-                - prompt_features: list of feature extractor outputs
-                - file_name: file name
-                - file_path: file path
-        """
-        batch_size = len(batch)
-
-        # Process main features
-        main_features_keys = batch[0]['main_features'].keys()
-        main_features = {}
-        for key in main_features_keys:
-            main_features[key] = torch.cat([item['main_features'][key] for item in batch], dim=0)
-
-        # Get number of prompts
-        num_prompts = len(batch[0]['prompt_features'])
-
-        # Process prompt features
-        prompt_features = []
-        for i in range(num_prompts):
-            prompt_feature = {}
-            for key in main_features_keys:
-                prompt_feature[key] = torch.cat([item['prompt_features'][i][key] for item in batch], dim=0)
-            prompt_features.append(prompt_feature)
-
-        # Collect file names and paths
-        file_names = [item['file_name'] for item in batch]
-        file_paths = [item['file_path'] for item in batch]
-
-        # Make sure prompt_labels has the right shape [batch_size, num_prompts]
-        prompt_labels = torch.cat([item['prompt_labels'] for item in batch], dim=0)
-
-        return {
-            'main_features': main_features,
-            'prompt_features': prompt_features,
-            'prompt_labels': prompt_labels,
-            'file_names': file_names,
-            'file_paths': file_paths
-        }
-
-if __name__ == '__main__':
-    # Test the dataset
-    demo_paths = ["examples/demo1.wav", "examples/demo2.wav"]
-    query_path = "examples/query.wav"
-
-    dataset = DemoDataset(demo_paths, query_path)
-    print(dataset[0])
+
+        return {
+            'waveforms': waveform_tensor,
+        }
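The rewritten DemoDataset drops the demonstration/prompt machinery entirely: it holds a single file path, and __getitem__ returns a fixed-length waveform tensor produced by the new pad helper (clips longer than max_len are cut, optionally at a random offset; shorter clips are tiled until they reach max_len). A small sketch of the two pad cases, with made-up clip lengths:

import numpy as np

from dataset import pad

# A 1 s clip at 16 kHz (16000 samples) is tiled to cover max_len = 64600 samples.
short_clip = np.random.randn(16000).astype(np.float32)
print(pad(short_clip).shape)                      # (64600,)

# A 10 s clip (160000 samples) is cut down to max_len; random_clip=False keeps the head.
long_clip = np.random.randn(160000).astype(np.float32)
print(pad(long_clip, random_clip=False).shape)    # (64600,)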
model.py
CHANGED
@@ -1,380 +1,69 @@
-import torch
 import torch.nn as nn
 from transformers import Wav2Vec2BertModel
-from llama_nar import LlamaNAREmb
-from transformers import LlamaConfig
-import time
-import torch.nn.functional as F
-from huggingface_hub import hf_hub_download
 
 
-class
-    def __init__(self):
-        super().__init__()
-
-        self.
-
-                param.requires_grad = False
-
-            # Freeze the K,V projections in multi-head attention
-            if any(proj in name for proj in ['linear_k', 'linear_v']):
-                param.requires_grad = False
-
-            # Freeze distance_embedding
-            if 'distance_embedding' in name:
-                param.requires_grad = False
-
-            # Freeze all convolution-related modules
-            if any(conv_name in name for conv_name in [
-                'conv_module', 'pointwise_conv', 'depthwise_conv',
-                'feature_extractor', 'pos_conv_embed', 'conv_layers'
-            ]):
-                param.requires_grad = False
-
-        # 3. Shrink the Llama model
-        self.llama_nar = LlamaNAREmb(
-            config=LlamaConfig(
-                hidden_size=512,
-                num_attention_heads=8,
-                num_hidden_layers=8,
-            ),
-            num_heads=8,
-            num_layers=8,
-            hidden_size=512
-        )
-
-        # 4. Down-projection layer
-        self.projection = nn.Sequential(
-            nn.Linear(1024, 512),
-            nn.LayerNorm(512)
-        )
-
-        # 5. Simplified classification head
-        self.classifier = nn.Sequential(
-            nn.Linear(512, 128),
+class SpoofVerificationModel(nn.Module):
+    def __init__(self, w2v_path, num_types=49):
+        super(SpoofVerificationModel, self).__init__()
+
+        self.wav2vec2 = Wav2Vec2BertModel.from_pretrained(w2v_path, output_hidden_states=True)
+        self.wav2vec_config = self.wav2vec2.config
+
+        self.deepfake_embed = nn.Linear(self.wav2vec2.config.hidden_size, 1024)
+        self.type_embed = nn.Linear(self.wav2vec2.config.hidden_size, 1024)
+
+        self.deepfake_classifier = nn.Sequential(
             nn.ReLU(),
-            nn.
-            nn.Linear(128, 2)
+            nn.Linear(1024, 2)
         )
-
-        # 6. Smaller embedding dimension
-        self.label_embedding = nn.Embedding(num_embeddings=2, embedding_dim=512)
-
-        # 7. Simplified feature-processing layer
-        self.feature_processor = nn.Sequential(
-            nn.Linear(512, 512),
-            nn.LayerNorm(512),
+        self.type_classifier = nn.Sequential(
             nn.ReLU(),
-            nn.
+            nn.Linear(1024, num_types)
         )
-
-        #
-
-
-        #
-
-
-        # Make sure the sequence length is divisible by the factor
-        new_len = seq_len // factor
-        padded_len = new_len * factor
-
-        if seq_len > padded_len:
-            sequence = sequence[:, :padded_len, :]
-
-        # Reshape and average-pool  [batch_size, new_len, factor, hidden_size]
-        reshaped = sequence.reshape(batch_size, new_len, factor, hidden_size)
-        downsampled = torch.mean(reshaped, dim=2)  # [batch_size, new_len, hidden_size]
-        return downsampled
-
-        # 1. Take the last hidden layer and downsample it
-        last_layer = hidden_states[-1]  # [batch_size, seq_len, 1024]
-        downsampled_features = downsample_sequence(last_layer)  # [batch_size, seq_len//10, 1024]
-
-        # 2. Project down to 512 dims
-        projected_features = self.projection(downsampled_features)  # [batch_size, seq_len//10, 512]
-
-        return projected_features  # no unsqueeze needed, the sequence dimension is kept
-
-    def forward(self,
-
-
-    )
-
-        batch_size, num_prompts = batch['prompt_labels'].shape
-
-        # Reshape features for batched processing
-        prompt_features = batch['prompt_features']
-        all_prompt_outputs = []
-
-        for i in range(num_prompts):
-            prompt_output = self.wav2vec2bert(
-                **prompt_features[i]
-            )
-            all_prompt_outputs.append(self._fuse_layers(prompt_output.hidden_states))
-
-        if all_prompt_outputs:
-            fused_prompts = torch.stack([
-                self.feature_processor(p) for p in all_prompt_outputs
-            ], dim=1)  # [batch_size, num_prompts, seq_len, hidden_size]
-
-            # Get label embeddings and expand them to the corresponding sequence length
-            label_embs = self.label_embedding(batch['prompt_labels'])  # [batch_size, num_prompts, 512]
-
-            prompt_embeddings = []
-            for i in range(batch_size):
-                sequence = []
-
-                # Append the demonstration prompts
-                for j in range(num_prompts):
-                    prompt_seq_len = fused_prompts[i, j].size(0)  # sequence length of the current prompt
-
-                    sequence.append(self.special_tokens[1].expand(1, -1))  # [PROMPT]
-                    sequence.append(self.special_tokens[2].expand(1, -1))  # [AUDIO]
-                    sequence.append(fused_prompts[i, j])  # [seq_len, hidden_size]
-                    sequence.append(self.special_tokens[3].expand(1, -1))  # [LABEL]
-
-                    # Expand the label embedding to the same length as the audio features
-                    expanded_label = label_embs[i, j].unsqueeze(0).expand(prompt_seq_len, -1)
-                    sequence.append(expanded_label)  # [seq_len, hidden_size]
-
-                    sequence.append(self.special_tokens[0].expand(1, -1))  # [SEP]
-
-                # Append the main features to be classified
-                main_seq_len = fused_features[i].size(0)  # sequence length of the main features
-                sequence.append(self.special_tokens[1].expand(1, -1))  # [PROMPT]
-                sequence.append(self.special_tokens[2].expand(1, -1))  # [AUDIO]
-                sequence.append(fused_features[i])  # [main_seq_len, hidden_size]
-                sequence.append(self.special_tokens[3].expand(1, -1))  # [LABEL]
-                # Zero vectors at the prediction positions, same length as the main features
-                sequence.append(torch.zeros(main_seq_len, fused_features.size(-1)).to(fused_features.device))
-
-                prompt_embeddings.append(torch.cat(sequence, dim=0))
-
-            prompt_embeddings = torch.stack(prompt_embeddings, dim=0)
-
-        else:
-            # Simplified handling of the no-prompt case
-            batch_size = fused_features.size(0)
-            main_seq_len = fused_features.size(1)  # main feature sequence length
-
-            # Build the sequence [batch_size, total_len, hidden_size]
-            prompt_embeddings = torch.cat([
-                self.special_tokens[1].expand(batch_size, 1, -1),  # [PROMPT]
-                self.special_tokens[2].expand(batch_size, 1, -1),  # [AUDIO]
-                fused_features,  # [batch_size, main_seq_len, hidden_size]
-                self.special_tokens[3].expand(batch_size, 1, -1),  # [LABEL]
-                torch.zeros(batch_size, main_seq_len, fused_features.size(-1)).to(fused_features.device)  # prediction positions
-            ], dim=1)
-
-        # Feed into llama_nar
-        output = self.llama_nar(inputs_embeds=prompt_embeddings)
-
-        # Outputs at all prediction positions (the last main_seq_len positions)
-        pred_pos_embeddings = output[:, -main_seq_len:, :]  # [batch_size, main_seq_len, hidden_size]
-        # Classify each frame
-        frame_logits = self.classifier(pred_pos_embeddings)  # [batch_size, main_seq_len, 2]
-
-        # Return both frame-level logits and utterance-level logits (obtained by averaging)
-        avg_embedding = torch.mean(pred_pos_embeddings, dim=1)  # [batch_size, hidden_size]
-        avg_logits = self.classifier(avg_embedding)  # [batch_size, 2]
-
+        # self.deepfake_classifier = nn.Sequential(
+        #     nn.Linear(self.wav2vec2.config.hidden_size, 1024),
+        #     nn.ReLU(),
+        #     nn.Linear(1024, 2)
+        # )
+
+        # self.type_classifier = nn.Sequential(
+        #     nn.Linear(self.wav2vec2.config.hidden_size, 1024),
+        #     nn.ReLU(),
+        #     nn.Linear(1024, num_types)
+        # )
+
+    def forward(self, audio_features):
+
+        audio_features = self.wav2vec2(**audio_features)  # [B, T, D]
+        audio_features = audio_features.last_hidden_state  # (B, T, D)
+        audio_features = audio_features.mean(dim=1)  # (B, D)
+
+        # deepfake_logits = self.deepfake_classifier(audio_features)
+        # type_logits = self.type_classifier(audio_features)
+
+        deepfake_emb = self.deepfake_embed(audio_features)
+        type_emb = self.type_embed(audio_features)
+        deepfake_logits = self.deepfake_classifier(deepfake_emb)
+        type_logits = self.type_classifier(type_emb)
+
         return {
-            '
-            '
+            'deepfake_logits': deepfake_logits,
+            'type_logits': type_logits,
+            'embeddings': audio_features,
+            'deepfake_embed': deepfake_emb,  # newly added embedding output
+            'type_embed': type_emb  # newly added embedding output
         }
 
+        # return {
+        #     'deepfake_logits': deepfake_logits,
+        #     'type_logits': type_logits,
+        #     'embeddings': audio_features
+        # }
+
+    def print_parameters_info(self):
+        print(f"wav2vec2 parameters: {sum(p.numel() for p in self.wav2vec2.parameters())/1e6:.2f}M")
+        print(f"deepfake_classifier parameters: {sum(p.numel() for p in self.deepfake_classifier.parameters())/1e6:.2f}M")
+        print(f"type_classifier parameters: {sum(p.numel() for p in self.type_classifier.parameters())/1e6:.2f}M")
 
-if __name__ == '__main__':
-    import torch
-    from torch.utils.data import DataLoader
-    from dataset.train_MultiDataset import train_MultiDataset, collate_fn
-    from tqdm import tqdm
-    import time
-
-    # Set the device
-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-    print(f"\n=== Using device: {device} ===")
-
-    # Initialize the model
-    print("\n=== Initializing model ===")
-    model = Wav2Vec2BERT_Llama().to(device)
-    model.eval()  # switch to eval mode
-
-    # Print the parameter structure of wav2vec2bert
-    print("\n=== Wav2Vec2BERT parameter structure ===")
-    w2v_params_by_layer = {}
-    total_trainable = 0
-    total_frozen = 0
-
-    for name, param in model.wav2vec2bert.named_parameters():
-        # Top-level layer name
-        layer_name = name.split('.')[0]
-        if layer_name not in w2v_params_by_layer:
-            w2v_params_by_layer[layer_name] = {
-                'trainable_params': 0,
-                'frozen_params': 0,
-                'parameter_names': []
-            }
-
-        # Count the parameters
-        if param.requires_grad:
-            w2v_params_by_layer[layer_name]['trainable_params'] += param.numel()
-            total_trainable += param.numel()
-        else:
-            w2v_params_by_layer[layer_name]['frozen_params'] += param.numel()
-            total_frozen += param.numel()
-
-        w2v_params_by_layer[layer_name]['parameter_names'].append(name)
-
-    # Print per-layer details
-    print("\nPer-layer parameter statistics:")
-    for layer_name, info in w2v_params_by_layer.items():
-        trainable_mb = info['trainable_params'] / 1024 / 1024
-        frozen_mb = info['frozen_params'] / 1024 / 1024
-        total_mb = (info['trainable_params'] + info['frozen_params']) / 1024 / 1024
-
-        print(f"\n{layer_name}:")
-        print(f"  - total parameters: {total_mb:.2f}MB")
-        print(f"  - trainable parameters: {trainable_mb:.2f}MB")
-        print(f"  - frozen parameters: {frozen_mb:.2f}MB")
-        print(f"  - parameter names:")
-        for param_name in info['parameter_names']:
-            print(f"    * {param_name}")
-
-    # Print overall statistics
-    print("\n=== Overall statistics ===")
-    print(f"Total trainable parameters: {total_trainable/1024/1024:.2f}MB")
-    print(f"Total frozen parameters: {total_frozen/1024/1024:.2f}MB")
-    print(f"Total parameters: {(total_trainable + total_frozen)/1024/1024:.2f}MB")
-    print(f"Trainable fraction: {total_trainable/(total_trainable + total_frozen)*100:.2f}%")
-
-    # Count the parameters of each module separately
-    wav2vec2bert_params = sum(p.numel() for p in model.wav2vec2bert.parameters())
-    llama_params = sum(p.numel() for p in model.llama_nar.parameters())
-    other_params = sum(p.numel() for name, p in model.named_parameters()
-                       if not name.startswith('wav2vec2bert.') and not name.startswith('llama_nar.'))
-
-    total_params = wav2vec2bert_params + llama_params + other_params
-
-    print(f"\n=== Parameter counts ===")
-    print(f"Wav2Vec2BERT parameters: {wav2vec2bert_params:,} ({wav2vec2bert_params/1024/1024:.2f}MB)")
-    print(f"LlamaNAR parameters: {llama_params:,} ({llama_params/1024/1024:.2f}MB)")
-    print(f"Other module parameters: {other_params:,} ({other_params/1024/1024:.2f}MB)")
-    print(f"Total parameters: {total_params:,} ({total_params/1024/1024:.2f}MB)")
-
-    # Percentages
-    print(f"\n=== Parameter shares ===")
-    print(f"Wav2Vec2BERT: {wav2vec2bert_params/total_params*100:.2f}%")
-    print(f"LlamaNAR: {llama_params/total_params*100:.2f}%")
-    print(f"Other modules: {other_params/total_params*100:.2f}%")
-
-    # Test running time and memory usage
-    print("\n=== Testing running time and memory usage (batch_size=4) ===")
-    batch_size = 4
-    total_samples = 600000
-
-    # Clear the GPU cache
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-        initial_memory = torch.cuda.memory_allocated() / 1024 / 1024
-        print(f"Initial GPU memory usage: {initial_memory:.2f}MB")
-
-    # Initialize the dataset
-    print("\nInitializing dataset...")
-    ds = train_MultiDataset(max_prompts=3)
-
-    # Create the DataLoader
-    dl = DataLoader(ds,
-                    batch_size=batch_size,
-                    shuffle=True,
-                    collate_fn=collate_fn,
-                    num_workers=4)
-
-    print(f"\nDataset size: {len(ds)}")
-    print(f"Number of batches: {len(dl)}")
-
-    # Average time over a few batches
-    num_test_batches = 10
-    total_time = 0
-    max_memory = 0
-
-    print(f"\nTiming {num_test_batches} batches...")
-    with torch.no_grad():
-        for i, batch in enumerate(tqdm(dl, total=num_test_batches)):
-            if i >= num_test_batches:
-                break
-
-            # Handle dict-typed features correctly
-            main_features = {
-                'input_features': batch['main_features']['input_features'].to(device),
-                'attention_mask': batch['main_features']['attention_mask'].to(device)
-            }
-
-            prompt_features = [{
-                'input_features': pf['input_features'].to(device),
-                'attention_mask': pf['attention_mask'].to(device)
-            } for pf in batch['prompt_features']]
-
-            labels = batch['labels'].to(device)
-            prompt_labels = batch['prompt_labels'].to(device)
-
-            # Record the start time
-            start_time = time.time()
-
-            # Forward pass
-            outputs = model({
-                'main_features': main_features,
-                'prompt_features': prompt_features,
-                'prompt_labels': prompt_labels
-            })
-
-            # Make sure GPU work has finished
-            if torch.cuda.is_available():
-                torch.cuda.synchronize()
-
-            # Record the end time and memory usage
-            end_time = time.time()
-            total_time += (end_time - start_time)
-
-            if torch.cuda.is_available():
-                current_memory = torch.cuda.memory_allocated() / 1024 / 1024
-                max_memory = max(max_memory, current_memory)
-
-            # Print details of the first batch
-            if i == 0:
-                print("\n=== First batch details ===")
-                print(f"Main feature shape: {main_features['input_features'].shape}")
-                print(f"Main mask shape: {main_features['attention_mask'].shape}")
-                print(f"Prompt feature shape: {prompt_features[0]['input_features'].shape}")
-                print(f"Prompt mask shape: {prompt_features[0]['attention_mask'].shape}")
-                print(f"Label shape: {labels.shape}")
-                print(f"Prompt label shape: {prompt_labels.shape}")
-                print(f"Model output shape: {outputs.shape}")
-                print(f"Output logits range: [{outputs.min().item():.3f}, {outputs.max().item():.3f}]")
-
-    # Compute and print the statistics
-    avg_time = total_time / num_test_batches
-    print(f"\n=== Performance statistics ===")
-    print(f"Average time per batch: {avg_time:.4f}s")
-    print(f"Estimated time for {total_samples} samples: {(total_samples/batch_size*avg_time/3600):.2f}h")
-    if torch.cuda.is_available():
-        print(f"Max GPU memory usage: {max_memory:.2f}MB")
-        print(f"GPU memory growth: {max_memory - initial_memory:.2f}MB")
-
-    print("\nDone!")
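The new SpoofVerificationModel mean-pools the Wav2Vec2Bert hidden states over time, projects the pooled vector through separate 1024-dimensional deepfake and type embedding layers, and classifies each one (2 classes for real/fake, num_types=49 for the spoof type). A minimal shape check, assuming facebook/w2v-bert-2.0 as the w2v_path backbone; this uses only the pretrained backbone, with randomly initialized heads rather than the fine-tuned checkpoint from app.py.

# Shape check only; "facebook/w2v-bert-2.0" as w2v_path is an assumption,
# and the heads here are randomly initialized, not the fine-tuned checkpoint.
import numpy as np
import torch
from transformers import AutoFeatureExtractor

from model import SpoofVerificationModel

model = SpoofVerificationModel("facebook/w2v-bert-2.0").eval()
model.print_parameters_info()

extractor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
dummy = extractor(
    np.zeros((1, 64600), dtype=np.float32),  # one silent fixed-length clip
    sampling_rate=16000,
    return_attention_mask=True,
    return_tensors="pt",
)

with torch.no_grad():
    out = model(dummy)

print(out["deepfake_logits"].shape)  # torch.Size([1, 2])
print(out["type_logits"].shape)      # torch.Size([1, 49])
print(out["embeddings"].shape)       # torch.Size([1, 1024]), the w2v-bert hidden size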