wli3221134 committed on
Commit b82b421 · verified · 1 Parent(s): 87c8d2e

Upload 3 files

Files changed (3):
  1. app.py +16 -59
  2. dataset.py +24 -122
  3. model.py +52 -363
app.py CHANGED
@@ -2,9 +2,11 @@ import spaces
 import gradio as gr
 import os
 import torch
-from model import Wav2Vec2BERT_Llama  # custom model module
+from model import SpoofVerificationModel  # custom model module
 import dataset  # custom dataset module
 from huggingface_hub import hf_hub_download
+from transformers import AutoFeatureExtractor
+
 
 @spaces.GPU
 def dummy(): # just a dummy
@@ -14,7 +16,7 @@ def dummy(): # just a dummy
 def load_model():
     checkpoint_path = hf_hub_download(
         repo_id="amphion/deepfake_detection",
-        filename="checkpoints_wav2vec2bert_ft_llama_labels_ASVspoof2019_RandomPrompts_6/model_checkpoint.pth",
+        filename="checkpoints_w2v-bert_SpoofVerification_MultiDataset/model_checkpoint_4_new.pth",
         repo_type="model"
     )
     if not os.path.exists(checkpoint_path):
@@ -33,12 +35,12 @@ def detect_on_gpu(dataset):
     print(f"Using device: {device}")
 
     print("Initializing model...")
-    model = Wav2Vec2BERT_Llama().to(device)
+    model = SpoofVerificationModel().to(device)
 
     print(f"Loading model weights: {checkpoint_path}")
     checkpoint = torch.load(checkpoint_path, map_location=device)
     model_state_dict = checkpoint['model_state_dict']
-    threshold = 0.8
+    threshold = 0.5
     print(f"Detection threshold set to: {threshold}")
 
     # Normalize the keys of the model state dict
@@ -53,45 +55,18 @@ def detect_on_gpu(dataset):
     model.eval()
     print("Model loaded, entering evaluation mode")
 
+    feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
+
     print("\nStarting to process audio data...")
     with torch.no_grad():
         for batch_idx, batch in enumerate(dataset):
             print(f"\nProcessing batch {batch_idx + 1}")
-
-            print("Preparing main features...")
-            main_features = {
-                'input_features': batch['main_features']['input_features'].to(device),
-                'attention_mask': batch['main_features']['attention_mask'].to(device)
-            }
-            print(f"Main features shape: {main_features['input_features'].shape}")
-
-            if len(batch['prompt_features']) > 0:
-                print("\nPreparing prompt features...")
-                prompt_features = [{
-                    'input_features': pf['input_features'].to(device),
-                    'attention_mask': pf['attention_mask'].to(device)
-                } for pf in batch['prompt_features']]
-                print(f"Number of prompt features: {len(prompt_features)}")
-                print(f"First prompt feature shape: {prompt_features[0]['input_features'].shape}")
-
-                print("\nPreparing prompt labels...")
-                prompt_labels = batch['prompt_labels'].to(device)
-                print(f"Prompt labels shape: {prompt_labels.shape}")
-                print(f"Prompt label values: {prompt_labels}")
-            else:
-                prompt_features = []
-                prompt_labels = []
-
-            print("\nRunning model inference...")
-            outputs = model({
-                'main_features': main_features,
-                'prompt_features': prompt_features,
-                'prompt_labels': prompt_labels
-            })
-
-            print("\nProcessing model outputs...")
-            avg_scores = outputs['avg_logits'].softmax(dim=-1)
-            deepfake_scores = avg_scores[:, 1].cpu()
+            waveforms = batch['waveforms'].numpy()  # [B, T]
+            features = feature_extractor(waveforms, sampling_rate=16000, return_attention_mask=True, padding_value=0, return_tensors="pt").to(device)
+            outputs = model(features)
+            deepfake_logits = outputs['deepfake_logits']
+
+            deepfake_scores = deepfake_logits.float().softmax(dim=-1)[:, 1].contiguous()
             is_fake = deepfake_scores[0].item() > threshold
 
             result = {"is_fake": is_fake, "confidence": deepfake_scores[0] if is_fake else 1-deepfake_scores[0]}
@@ -101,28 +76,10 @@ def detect_on_gpu(dataset):
     print("\n=== Detection complete ===")
     return result
 
-# Modified main audio deepfake detection function
-# def audio_deepfake_detection(demonstrations, query_audio_path):
-#     demonstration_paths = [audio[0] for audio in demonstrations if audio[0] is not None]
-#     demonstration_labels = [audio[1] for audio in demonstrations if audio[1] is not None]
-#     if len(demonstration_paths) != len(demonstration_labels):
-#         demonstration_labels = demonstration_labels[:len(demonstration_paths)]
-
-#     # Build the dataset
-#     audio_dataset = dataset.DemoDataset(demonstration_paths, demonstration_labels, query_audio_path)
-
-#     # Run detection on the GPU
-#     result = detect_on_gpu(audio_dataset)
-
-#     return {
-#         "Is AI Generated": result["is_fake"],
-#         "Confidence": f"{100*result['confidence']:.2f}%"
-#     }
-# 0 demonstrations
-def audio_deepfake_detection(query_audio_path):
+def audio_deepfake_detection(audio_path):
 
     # Build the dataset
-    audio_dataset = dataset.DemoDataset([], [], query_audio_path)
+    audio_dataset = dataset.DemoDataset(audio_path)
 
     # Run detection on the GPU
     result = detect_on_gpu(audio_dataset)
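
The updated inference path is: audio file -> dataset.DemoDataset -> AutoFeatureExtractor ("facebook/w2v-bert-2.0") -> SpoofVerificationModel -> deepfake score. A minimal standalone sketch of that flow, assuming the checkpoint has already been loaded into `model` as in detect_on_gpu() above; "query.wav" is a placeholder path, not part of the commit:

import torch
import dataset
from transformers import AutoFeatureExtractor

feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
audio_dataset = dataset.DemoDataset("query.wav")  # single query item, placeholder path

batch = audio_dataset[0]
waveform = batch['waveforms'].numpy()  # 1-D waveform, padded/cropped to 64600 samples
features = feature_extractor(waveform, sampling_rate=16000,
                             return_attention_mask=True, return_tensors="pt")
with torch.no_grad():
    outputs = model(features)  # `model` is the loaded SpoofVerificationModel
    score = outputs['deepfake_logits'].softmax(dim=-1)[0, 1].item()
print("fake" if score > 0.5 else "bonafide", f"score={score:.3f}")
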
dataset.py CHANGED
@@ -1,133 +1,35 @@
-import torch
 from torch.utils.data import Dataset
-from transformers import AutoFeatureExtractor
-import os
 import librosa
 import numpy as np
+from torch import Tensor
 
-class DemoDataset(Dataset):
-    def __init__(self, demonstration_paths, demonstration_labels, query_path, sample_rate=16000):
-        self.sample_rate = sample_rate
-        self.query_path = query_path
-
-        # Convert to list if single path
-        self.demonstration_paths = demonstration_paths
-        self.demonstration_labels = [0 if label == 'bonafide' else 1 for label in demonstration_labels]
-
-        # Load feature extractor
-        self.feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
-
-
-    def load_pad(self, path, max_length=64000):
-        """Load and pad audio file"""
-        X, sr = librosa.load(path, sr=self.sample_rate)
-        X = self.pad(X, max_length)
-        return X
-
-    def pad(self, x, max_len=64000):
-        """Pad audio to fixed length"""
-        x_len = x.shape[0]
-        if x_len >= max_len:
+def pad(x, max_len=64600, random_clip=True):
+    x_len = x.shape[0]
+    if x_len > max_len:
+        # random clip
+        if random_clip:
+            start_idx = np.random.randint(0, x_len - max_len)
+            return x[start_idx:start_idx + max_len]
+        else:
             return x[:max_len]
-        pad_length = max_len - x_len
-        return np.concatenate([x, np.zeros(pad_length)], axis=0)
+    # need to pad
+    num_repeats = int(max_len / x_len) + 1
+    padded_x = np.tile(x, (1, num_repeats))[:, :max_len][0]
+    return padded_x
+
+class DemoDataset(Dataset):
+    def __init__(self, path):
+        self.path = path
 
     def __len__(self):
-        return 1  # Only one query audio
+        return 1
 
     def __getitem__(self, idx):
-        # Load query audio
-        query_waveform = self.load_pad(self.query_path)
-        query_waveform = torch.from_numpy(query_waveform).float()
-        if len(query_waveform.shape) == 1:
-            query_waveform = query_waveform.unsqueeze(0)
-
-        # Extract features for query audio
-        main_features = self.feature_extractor(
-            query_waveform,
-            sampling_rate=self.sample_rate,
-            padding=True,
-            return_attention_mask=True,
-            return_tensors="pt"
-        )
-
-        # Process demonstration audios
-        prompt_features = []
-        for demo_path in self.demonstration_paths:
-            # Load demonstration audio
-            demo_waveform = self.load_pad(demo_path)
-            demo_waveform = torch.from_numpy(demo_waveform).float()
-            if len(demo_waveform.shape) == 1:
-                demo_waveform = demo_waveform.unsqueeze(0)
-
-            # Extract features
-            prompt_feature = self.feature_extractor(
-                demo_waveform,
-                sampling_rate=self.sample_rate,
-                padding=True,
-                return_attention_mask=True,
-                return_tensors="pt"
-            )
-            prompt_features.append(prompt_feature)
-
-        prompt_labels = torch.tensor([self.demonstration_labels], dtype=torch.long)
-
-        return {
-            'main_features': main_features,
-            'prompt_features': prompt_features,
-            'prompt_labels': prompt_labels,
-            'file_name': os.path.basename(self.query_path),
-            'file_path': self.query_path
-        }
-
-def collate_fn(batch):
-    """
-    Collate function for dataloader
-    Args:
-        batch: List containing dictionaries with:
-            - main_features: feature extractor output
-            - prompt_features: list of feature extractor outputs
-            - file_name: file name
-            - file_path: file path
-    """
-    batch_size = len(batch)
-
-    # Process main features
-    main_features_keys = batch[0]['main_features'].keys()
-    main_features = {}
-    for key in main_features_keys:
-        main_features[key] = torch.cat([item['main_features'][key] for item in batch], dim=0)
-
-    # Get number of prompts
-    num_prompts = len(batch[0]['prompt_features'])
-
-    # Process prompt features
-    prompt_features = []
-    for i in range(num_prompts):
-        prompt_feature = {}
-        for key in main_features_keys:
-            prompt_feature[key] = torch.cat([item['prompt_features'][i][key] for item in batch], dim=0)
-        prompt_features.append(prompt_feature)
-
-    # Collect file names and paths
-    file_names = [item['file_name'] for item in batch]
-    file_paths = [item['file_path'] for item in batch]
-
-    # Make sure prompt_labels has the right shape [batch_size, num_prompts]
-    prompt_labels = torch.cat([item['prompt_labels'] for item in batch], dim=0)
-
-    return {
-        'main_features': main_features,
-        'prompt_features': prompt_features,
-        'prompt_labels': prompt_labels,
-        'file_names': file_names,
-        'file_paths': file_paths
-    }
-
-if __name__ == '__main__':
-    # Test the dataset
-    demo_paths = ["examples/demo1.wav", "examples/demo2.wav"]
-    query_path = "examples/query.wav"
-
-    dataset = DemoDataset(demo_paths, query_path)
-    print(dataset[0])
+        waveform, sample_rate = librosa.load(self.path, sr=16000)
+        waveform_pad = pad(waveform)
+        waveform_tensor = Tensor(waveform_pad)
+
+        return {
+            'waveforms': waveform_tensor,
+        }
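
The new module-level pad() repeat-tiles clips shorter than max_len (64600 samples, roughly 4 seconds at 16 kHz) and crops longer ones at a random offset. A toy illustration of both branches, calling the function from dataset.py with a deliberately tiny max_len so the output is easy to read (the arrays are made up):

import numpy as np
from dataset import pad

short = np.arange(5, dtype=np.float32)   # shorter than max_len -> repeat-tiled
long = np.arange(200, dtype=np.float32)  # longer than max_len -> cropped

print(pad(short, max_len=8))                    # [0. 1. 2. 3. 4. 0. 1. 2.]
print(pad(long, max_len=8))                     # a random 8-sample window of `long`
print(pad(long, max_len=8, random_clip=False))  # the first 8 samples

Note that with the default random_clip=True, a query longer than max_len is cropped at a random offset, so repeated calls on the same long file can produce slightly different model inputs.
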
model.py CHANGED
@@ -1,380 +1,69 @@
-import torch
 import torch.nn as nn
 from transformers import Wav2Vec2BertModel
-from llama_nar import LlamaNAREmb
-from transformers import LlamaConfig
-import time
-import torch.nn.functional as F
-from huggingface_hub import hf_hub_download
 
 
-class Wav2Vec2BERT_Llama(nn.Module):
-    def __init__(self):
-        super().__init__()
-
-        # 1. Load the pretrained model
-        self.wav2vec2bert = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", output_hidden_states=True)
-
-        # 2. Selectively freeze parameters
-        for name, param in self.wav2vec2bert.named_parameters():
-            # Freeze all FFN1 layers (keep FFN2 adaptable)
-            if 'ffn1' in name:
-                param.requires_grad = False
-
-            # Freeze the K and V projections in multi-head attention
-            if any(proj in name for proj in ['linear_k', 'linear_v']):
-                param.requires_grad = False
-
-            # Freeze distance_embedding
-            if 'distance_embedding' in name:
-                param.requires_grad = False
-
-            # Freeze all convolution-related modules
-            if any(conv_name in name for conv_name in [
-                'conv_module', 'pointwise_conv', 'depthwise_conv',
-                'feature_extractor', 'pos_conv_embed', 'conv_layers'
-            ]):
-                param.requires_grad = False
-
-        # 3. Shrink the Llama model
-        self.llama_nar = LlamaNAREmb(
-            config=LlamaConfig(
-                hidden_size=512,
-                num_attention_heads=8,
-                num_hidden_layers=8,
-            ),
-            num_heads=8,
-            num_layers=8,
-            hidden_size=512
-        )
-
-        # 4. Down-projection layer
-        self.projection = nn.Sequential(
-            nn.Linear(1024, 512),
-            nn.LayerNorm(512)
-        )
-
-        # 5. Simplified classification head
-        self.classifier = nn.Sequential(
-            nn.Linear(512, 128),
-            nn.ReLU(),
-            nn.Dropout(0.1),
-            nn.Linear(128, 2)
-        )
-
-        # 6. Smaller embedding dimension
-        self.label_embedding = nn.Embedding(num_embeddings=2, embedding_dim=512)
-
-        # 7. Simplified feature-processing layer
-        self.feature_processor = nn.Sequential(
-            nn.Linear(512, 512),
-            nn.LayerNorm(512),
-            nn.ReLU(),
-            nn.Dropout(0.1)
-        )
-
-        # 8. Smaller special-token dimension
-        self.special_tokens = nn.Parameter(torch.randn(4, 512))
-
-    def _fuse_layers(self, hidden_states):
-        # Modified feature-fusion method
-        def downsample_sequence(sequence, factor=10):
-            """Downsample the sequence"""
-            batch_size, seq_len, hidden_size = sequence.shape
-            # Make sure the sequence length is divisible by the factor
-            new_len = seq_len // factor
-            padded_len = new_len * factor
-
-            if seq_len > padded_len:
-                sequence = sequence[:, :padded_len, :]
-
-            # Reshape the tensor and average-pool  [batch_size, new_len, factor, hidden_size]
-            reshaped = sequence.reshape(batch_size, new_len, factor, hidden_size)
-            downsampled = torch.mean(reshaped, dim=2)  # [batch_size, new_len, hidden_size]
-            return downsampled
-
-        # 1. Take the last layer's features and downsample them
-        last_layer = hidden_states[-1]  # [batch_size, seq_len, 1024]
-        downsampled_features = downsample_sequence(last_layer)  # [batch_size, seq_len//10, 1024]
-
-        # 2. Project down to 512 dimensions
-        projected_features = self.projection(downsampled_features)  # [batch_size, seq_len//10, 512]
-
-        return projected_features  # no unsqueeze needed, the sequence dimension is kept
-
-    def forward(self, batch):
-        main_output = self.wav2vec2bert(
-            **batch['main_features']
-        )
-
-        fused_features = self._fuse_layers(main_output.hidden_states)
-        fused_features = self.feature_processor(fused_features)
-
-        if ('prompt_labels' in batch and
-            batch['prompt_labels'] is not None and
-            'prompt_features' in batch and
-            batch['prompt_features'] and
-            len(batch['prompt_features']) > 0):
-
-            batch_size, num_prompts = batch['prompt_labels'].shape
-
-            # Reshape the features for batched processing
-            prompt_features = batch['prompt_features']
-            all_prompt_outputs = []
-
-            for i in range(num_prompts):
-                prompt_output = self.wav2vec2bert(
-                    **prompt_features[i]
-                )
-                all_prompt_outputs.append(self._fuse_layers(prompt_output.hidden_states))
-
-            if all_prompt_outputs:
-                fused_prompts = torch.stack([
-                    self.feature_processor(p) for p in all_prompt_outputs
-                ], dim=1)  # [batch_size, num_prompts, seq_len, hidden_size]
-
-                # Get the label embeddings and expand them to the corresponding sequence length
-                label_embs = self.label_embedding(batch['prompt_labels'])  # [batch_size, num_prompts, 512]
-
-                prompt_embeddings = []
-                for i in range(batch_size):
-                    sequence = []
-
-                    # Add the demonstration prompts
-                    for j in range(num_prompts):
-                        prompt_seq_len = fused_prompts[i, j].size(0)  # sequence length of the current prompt
-
-                        sequence.append(self.special_tokens[1].expand(1, -1))  # [PROMPT]
-                        sequence.append(self.special_tokens[2].expand(1, -1))  # [AUDIO]
-                        sequence.append(fused_prompts[i, j])  # [seq_len, hidden_size]
-                        sequence.append(self.special_tokens[3].expand(1, -1))  # [LABEL]
-
-                        # Expand the label embedding to the same length as the audio features
-                        expanded_label = label_embs[i, j].unsqueeze(0).expand(prompt_seq_len, -1)
-                        sequence.append(expanded_label)  # [seq_len, hidden_size]
-
-                        sequence.append(self.special_tokens[0].expand(1, -1))  # [SEP]
-
-                    # Add the main features to be predicted
-                    main_seq_len = fused_features[i].size(0)  # sequence length of the main features
-                    sequence.append(self.special_tokens[1].expand(1, -1))  # [PROMPT]
-                    sequence.append(self.special_tokens[2].expand(1, -1))  # [AUDIO]
-                    sequence.append(fused_features[i])  # [main_seq_len, hidden_size]
-                    sequence.append(self.special_tokens[3].expand(1, -1))  # [LABEL]
-                    # Use zero vectors at the prediction positions, same length as the main features
-                    sequence.append(torch.zeros(main_seq_len, fused_features.size(-1)).to(fused_features.device))
-
-                    prompt_embeddings.append(torch.cat(sequence, dim=0))
-
-                prompt_embeddings = torch.stack(prompt_embeddings, dim=0)
-
-        else:
-            # Simplified handling of the no-prompt case
-            batch_size = fused_features.size(0)
-            main_seq_len = fused_features.size(1)  # sequence length of the main features
-
-            # Build the sequence [batch_size, total_len, hidden_size]
-            prompt_embeddings = torch.cat([
-                self.special_tokens[1].expand(batch_size, 1, -1),  # [PROMPT]
-                self.special_tokens[2].expand(batch_size, 1, -1),  # [AUDIO]
-                fused_features,  # [batch_size, main_seq_len, hidden_size]
-                self.special_tokens[3].expand(batch_size, 1, -1),  # [LABEL]
-                torch.zeros(batch_size, main_seq_len, fused_features.size(-1)).to(fused_features.device)  # prediction positions
-            ], dim=1)
-
-        # Feed the sequence into llama_nar
-        output = self.llama_nar(inputs_embeds=prompt_embeddings)
-
-        # Take the outputs at all prediction positions (the last main_seq_len positions)
-        pred_pos_embeddings = output[:, -main_seq_len:, :]  # [batch_size, main_seq_len, hidden_size]
-        # Classify every frame
-        frame_logits = self.classifier(pred_pos_embeddings)  # [batch_size, main_seq_len, 2]
-
-        # Return both the frame-level logits and the overall logits (obtained by averaging)
-        avg_embedding = torch.mean(pred_pos_embeddings, dim=1)  # [batch_size, hidden_size]
-        avg_logits = self.classifier(avg_embedding)  # [batch_size, 2]
-
-        return {
-            'frame_logits': frame_logits,  # per-frame prediction scores
-            'avg_logits': avg_logits  # overall prediction scores
-        }
-
-
-if __name__ == '__main__':
-    import torch
-    from torch.utils.data import DataLoader
-    from dataset.train_MultiDataset import train_MultiDataset, collate_fn
-    from tqdm import tqdm
-    import time
-
-    # Set the device
-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-    print(f"\n=== Using device: {device} ===")
-
-    # Initialize the model
-    print("\n=== Initializing model ===")
-    model = Wav2Vec2BERT_Llama().to(device)
-    model.eval()  # set to evaluation mode
-
-    # Print the parameter structure of wav2vec2bert
-    print("\n=== Wav2Vec2BERT parameter structure ===")
-    w2v_params_by_layer = {}
-    total_trainable = 0
-    total_frozen = 0
-
-    for name, param in model.wav2vec2bert.named_parameters():
-        # Get the top-level layer name
-        layer_name = name.split('.')[0]
-        if layer_name not in w2v_params_by_layer:
-            w2v_params_by_layer[layer_name] = {
-                'trainable_params': 0,
-                'frozen_params': 0,
-                'parameter_names': []
-            }
-
-        # Count the parameters
-        if param.requires_grad:
-            w2v_params_by_layer[layer_name]['trainable_params'] += param.numel()
-            total_trainable += param.numel()
-        else:
-            w2v_params_by_layer[layer_name]['frozen_params'] += param.numel()
-            total_frozen += param.numel()
-
-        w2v_params_by_layer[layer_name]['parameter_names'].append(name)
-
-    # Print the details for each layer
-    print("\nPer-layer parameter statistics:")
-    for layer_name, info in w2v_params_by_layer.items():
-        trainable_mb = info['trainable_params'] / 1024 / 1024
-        frozen_mb = info['frozen_params'] / 1024 / 1024
-        total_mb = (info['trainable_params'] + info['frozen_params']) / 1024 / 1024
-
-        print(f"\n{layer_name}:")
-        print(f"  - total parameters: {total_mb:.2f}MB")
-        print(f"  - trainable parameters: {trainable_mb:.2f}MB")
-        print(f"  - frozen parameters: {frozen_mb:.2f}MB")
-        print(f"  - parameter names:")
-        for param_name in info['parameter_names']:
-            print(f"    * {param_name}")
-
-    # Print the overall statistics
-    print("\n=== Overall statistics ===")
-    print(f"Total trainable parameters: {total_trainable/1024/1024:.2f}MB")
-    print(f"Total frozen parameters: {total_frozen/1024/1024:.2f}MB")
-    print(f"Total parameters: {(total_trainable + total_frozen)/1024/1024:.2f}MB")
-    print(f"Trainable parameter ratio: {total_trainable/(total_trainable + total_frozen)*100:.2f}%")
-
-    # Count the parameters of each module separately
-    wav2vec2bert_params = sum(p.numel() for p in model.wav2vec2bert.parameters())
-    llama_params = sum(p.numel() for p in model.llama_nar.parameters())
-    other_params = sum(p.numel() for name, p in model.named_parameters()
-                       if not name.startswith('wav2vec2bert.') and not name.startswith('llama_nar.'))
-
-    total_params = wav2vec2bert_params + llama_params + other_params
-
-    print(f"\n=== Parameter counts ===")
-    print(f"Wav2Vec2BERT parameters: {wav2vec2bert_params:,} ({wav2vec2bert_params/1024/1024:.2f}MB)")
-    print(f"LlamaNAR parameters: {llama_params:,} ({llama_params/1024/1024:.2f}MB)")
-    print(f"Other module parameters: {other_params:,} ({other_params/1024/1024:.2f}MB)")
-    print(f"Total parameters: {total_params:,} ({total_params/1024/1024:.2f}MB)")
-
-    # Compute the percentages
-    print(f"\n=== Parameter share ===")
-    print(f"Wav2Vec2BERT: {wav2vec2bert_params/total_params*100:.2f}%")
-    print(f"LlamaNAR: {llama_params/total_params*100:.2f}%")
-    print(f"Other modules: {other_params/total_params*100:.2f}%")
-
-    # Test runtime and memory usage
-    print("\n=== Testing runtime and memory usage (batch_size=4) ===")
-    batch_size = 4
-    total_samples = 600000
-
-    # Clear the GPU cache
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-        initial_memory = torch.cuda.memory_allocated() / 1024 / 1024
-        print(f"Initial GPU memory usage: {initial_memory:.2f}MB")
-
-    # Initialize the dataset
-    print("\nInitializing dataset...")
-    ds = train_MultiDataset(max_prompts=3)
-
-    # Create the DataLoader
-    dl = DataLoader(ds,
-                    batch_size=batch_size,
-                    shuffle=True,
-                    collate_fn=collate_fn,
-                    num_workers=4)
-
-    print(f"\nDataset size: {len(ds)}")
-    print(f"Number of batches: {len(dl)}")
-
-    # Measure the average time per batch
-    num_test_batches = 10
-    total_time = 0
-    max_memory = 0
-
-    print(f"\nMeasuring the average runtime over {num_test_batches} batches...")
-    with torch.no_grad():
-        for i, batch in enumerate(tqdm(dl, total=num_test_batches)):
-            if i >= num_test_batches:
-                break
-
-            # Handle the dict-typed features correctly
-            main_features = {
-                'input_features': batch['main_features']['input_features'].to(device),
-                'attention_mask': batch['main_features']['attention_mask'].to(device)
-            }
-
-            prompt_features = [{
-                'input_features': pf['input_features'].to(device),
-                'attention_mask': pf['attention_mask'].to(device)
-            } for pf in batch['prompt_features']]
-
-            labels = batch['labels'].to(device)
-            prompt_labels = batch['prompt_labels'].to(device)
-
-            # Record the start time
-            start_time = time.time()
-
-            # Forward pass
-            outputs = model({
-                'main_features': main_features,
-                'prompt_features': prompt_features,
-                'prompt_labels': prompt_labels
-            })
-
-            # Make sure the GPU computation has finished
-            if torch.cuda.is_available():
-                torch.cuda.synchronize()
-
-            # Record the end time and memory usage
-            end_time = time.time()
-            total_time += (end_time - start_time)
-
-            if torch.cuda.is_available():
-                current_memory = torch.cuda.memory_allocated() / 1024 / 1024
-                max_memory = max(max_memory, current_memory)
-
-            # Print detailed info for the first batch
-            if i == 0:
-                print("\n=== Details of the first batch ===")
-                print(f"Main features shape: {main_features['input_features'].shape}")
-                print(f"Main mask shape: {main_features['attention_mask'].shape}")
-                print(f"Prompt features shape: {prompt_features[0]['input_features'].shape}")
-                print(f"Prompt mask shape: {prompt_features[0]['attention_mask'].shape}")
-                print(f"Labels shape: {labels.shape}")
-                print(f"Prompt labels shape: {prompt_labels.shape}")
-                print(f"Model output shape: {outputs.shape}")
-                print(f"Output logits range: [{outputs.min().item():.3f}, {outputs.max().item():.3f}]")
-
-    # Compute and print the statistics
-    avg_time = total_time / num_test_batches
-    print(f"\n=== Performance statistics ===")
-    print(f"Average time per batch: {avg_time:.4f}s")
-    print(f"Estimated time to process {total_samples} samples: {(total_samples/batch_size*avg_time/3600):.2f} hours")
-    if torch.cuda.is_available():
-        print(f"Max GPU memory usage: {max_memory:.2f}MB")
-        print(f"GPU memory growth: {max_memory - initial_memory:.2f}MB")
-
-    print("\nTest complete!")
+
+class SpoofVerificationModel(nn.Module):
+    def __init__(self, w2v_path, num_types=49):
+        super(SpoofVerificationModel, self).__init__()
+
+        self.wav2vec2 = Wav2Vec2BertModel.from_pretrained(w2v_path, output_hidden_states=True)
+        self.wav2vec_config = self.wav2vec2.config
+
+        self.deepfake_embed = nn.Linear(self.wav2vec2.config.hidden_size, 1024)
+        self.type_embed = nn.Linear(self.wav2vec2.config.hidden_size, 1024)
+
+        self.deepfake_classifier = nn.Sequential(
+            nn.ReLU(),
+            nn.Linear(1024, 2)
+        )
+        self.type_classifier = nn.Sequential(
+            nn.ReLU(),
+            nn.Linear(1024, num_types)
+        )
+        # self.deepfake_classifier = nn.Sequential(
+        #     nn.Linear(self.wav2vec2.config.hidden_size, 1024),
+        #     nn.ReLU(),
+        #     nn.Linear(1024, 2)
+        # )
+
+        # self.type_classifier = nn.Sequential(
+        #     nn.Linear(self.wav2vec2.config.hidden_size, 1024),
+        #     nn.ReLU(),
+        #     nn.Linear(1024, num_types)
+        # )
+
+    def forward(self, audio_features):
+
+        audio_features = self.wav2vec2(**audio_features)  # [B, T, D]
+        audio_features = audio_features.last_hidden_state  # (B, T, D)
+        audio_features = audio_features.mean(dim=1)  # (B, D)
+
+        # deepfake_logits = self.deepfake_classifier(audio_features)
+        # type_logits = self.type_classifier(audio_features)
+
+        deepfake_emb = self.deepfake_embed(audio_features)
+        type_emb = self.type_embed(audio_features)
+        deepfake_logits = self.deepfake_classifier(deepfake_emb)
+        type_logits = self.type_classifier(type_emb)
+
+        return {
+            'deepfake_logits': deepfake_logits,
+            'type_logits': type_logits,
+            'embeddings': audio_features,
+            'deepfake_embed': deepfake_emb,  # newly added embedding output
+            'type_embed': type_emb  # newly added embedding output
+        }
+
+        # return {
+        #     'deepfake_logits': deepfake_logits,
+        #     'type_logits': type_logits,
+        #     'embeddings': audio_features
+        # }
+
+    def print_parameters_info(self):
+        print(f"wav2vec2 parameters: {sum(p.numel() for p in self.wav2vec2.parameters())/1e6:.2f}M")
+        print(f"deepfake_classifier parameters: {sum(p.numel() for p in self.deepfake_classifier.parameters())/1e6:.2f}M")
+        print(f"type_classifier parameters: {sum(p.numel() for p in self.type_classifier.parameters())/1e6:.2f}M")
69