File size: 2,603 Bytes
34146f0
 
 
 
b82b421
c63d90f
b82b421
34146f0
c63d90f
b82b421
 
 
 
 
 
34146f0
b82b421
34146f0
b82b421
34146f0
b82b421
34146f0
b82b421
 
 
 
 
34146f0
b82b421
 
 
 
 
34146f0
 
 
b82b421
 
 
 
 
 
 
 
 
 
 
 
 
 
34146f0
b82b421
 
 
 
 
34146f0
 
b82b421
 
 
 
 
 
 
 
 
 
34146f0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import torch.nn as nn
from transformers import Wav2Vec2BertModel


class SpoofVerificationModel(nn.Module):
    """Audio spoof/deepfake detector with two task heads.

    A pretrained Wav2Vec2-BERT encoder produces frame-level features that are
    mean-pooled over time, projected into two independent 1024-d embedding
    spaces, and classified by two small heads:

      * deepfake head: 2-way real/fake decision
      * type head: ``num_types``-way spoof-type classification
    """

    def __init__(self, w2v_path='facebook/w2v-bert-2.0', num_types=59):
        """
        Args:
            w2v_path: HuggingFace model id or local path of the
                Wav2Vec2-BERT backbone checkpoint.
            num_types: number of spoof-type classes for the type head.
        """
        super().__init__()

        # Backbone encoder. output_hidden_states=True exposes intermediate
        # layers, although forward() only consumes last_hidden_state.
        self.wav2vec2 = Wav2Vec2BertModel.from_pretrained(w2v_path, output_hidden_states=True)
        self.wav2vec_config = self.wav2vec2.config

        # Separate projections so each task gets its own embedding space on
        # top of the shared backbone features.
        self.deepfake_embed = nn.Linear(self.wav2vec2.config.hidden_size, 1024)
        self.type_embed = nn.Linear(self.wav2vec2.config.hidden_size, 1024)

        self.deepfake_classifier = nn.Sequential(
            nn.ReLU(),
            nn.Linear(1024, 2)
        )
        self.type_classifier = nn.Sequential(
            nn.ReLU(),
            nn.Linear(1024, num_types)
        )

    def forward(self, audio_features):
        """Run the model on one batch of preprocessed audio features.

        Args:
            audio_features: dict of keyword arguments forwarded verbatim to
                the Wav2Vec2-BERT backbone (presumably ``input_features`` /
                ``attention_mask`` from its feature extractor — confirm
                against the caller's preprocessing).

        Returns:
            dict with:
                'deepfake_logits': (B, 2) real/fake logits
                'type_logits':     (B, num_types) spoof-type logits
                'embeddings':      (B, hidden_size) mean-pooled backbone features
                'deepfake_embed':  (B, 1024) deepfake-head embedding
                'type_embed':      (B, 1024) type-head embedding
        """
        audio_features = self.wav2vec2(**audio_features)
        audio_features = audio_features.last_hidden_state  # (B, T, D)
        audio_features = audio_features.mean(dim=1)  # temporal mean pooling -> (B, D)

        deepfake_emb = self.deepfake_embed(audio_features)
        type_emb = self.type_embed(audio_features)
        deepfake_logits = self.deepfake_classifier(deepfake_emb)
        type_logits = self.type_classifier(type_emb)

        return {
            'deepfake_logits': deepfake_logits,
            'type_logits': type_logits,
            'embeddings': audio_features,
            'deepfake_embed': deepfake_emb,  # per-task embeddings, e.g. for metric losses
            'type_embed': type_emb
        }

    def print_parameters_info(self):
        """Print parameter counts (in millions) of the main submodules."""
        print(f"wav2vec2 parameters: {sum(p.numel() for p in self.wav2vec2.parameters())/1e6:.2f}M")
        print(f"deepfake_classifier parameters: {sum(p.numel() for p in self.deepfake_classifier.parameters())/1e6:.2f}M")
        print(f"type_classifier parameters: {sum(p.numel() for p in self.type_classifier.parameters())/1e6:.2f}M")