# Hugging Face Space file (Running on Zero) · file size: 2,603 bytes
import torch.nn as nn
from transformers import Wav2Vec2BertModel
class SpoofVerificationModel(nn.Module):
    """W2V-BERT 2.0 backbone with two heads: binary deepfake detection and spoof-type classification."""

    def __init__(self, w2v_path='facebook/w2v-bert-2.0', num_types=59):
        super(SpoofVerificationModel, self).__init__()
        # Pretrained W2V-BERT 2.0 encoder used as the shared feature extractor.
        self.wav2vec2 = Wav2Vec2BertModel.from_pretrained(w2v_path, output_hidden_states=True)
        self.wav2vec_config = self.wav2vec2.config

        # Separate 1024-dim projection heads for the two tasks.
        self.deepfake_embed = nn.Linear(self.wav2vec2.config.hidden_size, 1024)
        self.type_embed = nn.Linear(self.wav2vec2.config.hidden_size, 1024)

        # Binary real/fake classifier on top of the deepfake embedding.
        self.deepfake_classifier = nn.Sequential(
            nn.ReLU(),
            nn.Linear(1024, 2)
        )
        # Spoof-type classifier (num_types classes) on top of the type embedding.
        self.type_classifier = nn.Sequential(
            nn.ReLU(),
            nn.Linear(1024, num_types)
        )
        # Earlier variant: classifiers applied directly to the pooled encoder
        # output, without the intermediate task-specific embedding layers.
        # self.deepfake_classifier = nn.Sequential(
        #     nn.Linear(self.wav2vec2.config.hidden_size, 1024),
        #     nn.ReLU(),
        #     nn.Linear(1024, 2)
        # )
        # self.type_classifier = nn.Sequential(
        #     nn.Linear(self.wav2vec2.config.hidden_size, 1024),
        #     nn.ReLU(),
        #     nn.Linear(1024, num_types)
        # )
    def forward(self, audio_features):
        # audio_features is the dict produced by the feature extractor
        # (input_features, attention_mask); unpack it into the encoder.
        encoder_out = self.wav2vec2(**audio_features)
        hidden = encoder_out.last_hidden_state   # (B, T, D)
        pooled = hidden.mean(dim=1)              # (B, D), mean-pooled over time

        # Task-specific embeddings and logits.
        deepfake_emb = self.deepfake_embed(pooled)
        type_emb = self.type_embed(pooled)
        deepfake_logits = self.deepfake_classifier(deepfake_emb)
        type_logits = self.type_classifier(type_emb)

        return {
            'deepfake_logits': deepfake_logits,
            'type_logits': type_logits,
            'embeddings': pooled,
            'deepfake_embed': deepfake_emb,  # newly added embedding output
            'type_embed': type_emb           # newly added embedding output
        }
    def print_parameters_info(self):
        # Report parameter counts (in millions) for the backbone and each head.
        print(f"wav2vec2 parameters: {sum(p.numel() for p in self.wav2vec2.parameters())/1e6:.2f}M")
        print(f"deepfake_embed parameters: {sum(p.numel() for p in self.deepfake_embed.parameters())/1e6:.2f}M")
        print(f"type_embed parameters: {sum(p.numel() for p in self.type_embed.parameters())/1e6:.2f}M")
        print(f"deepfake_classifier parameters: {sum(p.numel() for p in self.deepfake_classifier.parameters())/1e6:.2f}M")
        print(f"type_classifier parameters: {sum(p.numel() for p in self.type_classifier.parameters())/1e6:.2f}M")