from collections import OrderedDict
from typing import List, Union, Dict

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence

HIDDEN_DIM = 8

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # The model needs to be an nn.Module for finetuning; this is not required for representation extraction
        self.model1 = nn.Linear(1, HIDDEN_DIM)
        self.model2 = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)

    def forward(self, wavs, upstream_feature_selection="hidden_states"):
        # You can do task-specific pre-/post-processing based on upstream_feature_selection
        hidden = self.model1(wavs)
        # hidden: (batch_size, max_len, hidden_dim)

        feature = self.model2(hidden)
        # feature: (batch_size, max_len, hidden_dim)

        return [hidden, feature]

class UpstreamExpert(nn.Module):
    def __init__(
        self,
        ckpt: str = "./model.pt",
        upstream_feature_selection: str = "hidden_states",
        **kwargs):
        """
        Args:
            ckpt:
                The checkpoint path for loading your pretrained weights.
                Should be fixed as model.pt for the SUPERB Challenge.
            upstream_feature_selection:
                The value can be 'hidden_states', 'PR', 'SID', 'ER', 'ASR',
                'QbE', 'ASV', 'SD', 'ST', 'SE', 'SS', 'secret', or others (new tasks).
                You can use it to control which task-specific pre-/post-processing to do.
        """
        super().__init__()
        self.name = "[Example UpstreamExpert]"
        self.upstream_feature_selection = upstream_feature_selection

        # You can use ckpt to load your pretrained weights
        ckpt = torch.load(ckpt, map_location="cpu")
        self.model = Model()
        self.model.load_state_dict(ckpt)

    def get_downsample_rates(self, key: str) -> int:
        """
        Since this example upstream does no downsampling,
        all keys' corresponding representations have a downsample rate of 1.
        E.g. a representation with a 10ms stride would have a downsample rate of 160 (input wavs are all 16kHz).
        """
        return 1
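
    # Worked example (illustration only): with 16 kHz input and a 10 ms frame stride,
    # one output frame would cover 16000 samples/s * 0.010 s = 160 samples, i.e. a
    # downsample rate of 160. This example upstream emits one frame per input sample,
    # hence the rate of 1 returned above.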

    def forward(self, wavs: List[Tensor]) -> Dict[str, List[Tensor]]:
        """
        When the returned Dict contains a List with more than one Tensor,
        those Tensors should have the same shape so that a weighted-sum can be trained over them.
        """

        wavs = pad_sequence(wavs, batch_first=True).unsqueeze(-1)
        # wavs: (batch_size, max_len, 1)

        hidden_states = self.model(wavs, upstream_feature_selection=self.upstream_feature_selection)

        # Deprecated! Do not do any task-specific post-processing below.
        # Use the init arg "upstream_feature_selection" to control which task-specific pre-/post-processing to do.
        # The "hidden_states" key will be used as the default in many cases.
        # Other keys in this example are presented for the SUPERB Challenge.
        return {
            "hidden_states": hidden_states,
        }
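
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original template): a minimal, assumed example
# of how this UpstreamExpert might be exercised, using a randomly initialized
# checkpoint saved as "./model.pt" and a few variable-length 16kHz waveforms.
# The shapes and file name below are illustrative assumptions only.
if __name__ == "__main__":
    # Save a dummy checkpoint so UpstreamExpert can load it.
    torch.save(Model().state_dict(), "./model.pt")

    expert = UpstreamExpert(ckpt="./model.pt", upstream_feature_selection="hidden_states")

    # Three waveforms of different lengths; padding is handled by pad_sequence inside forward.
    wavs = [torch.randn(16000), torch.randn(12345), torch.randn(8000)]

    with torch.no_grad():
        outputs = expert(wavs)

    # Each entry in "hidden_states" is (batch_size, max_len, HIDDEN_DIM);
    # all entries share the same shape, so a weighted-sum can be learned over them.
    for i, h in enumerate(outputs["hidden_states"]):
        print(f"layer {i}: {tuple(h.shape)}")

    # Downstream models typically learn a weighted sum over the same-shaped layers:
    weights = torch.softmax(torch.zeros(len(outputs["hidden_states"])), dim=0)
    weighted = sum(w * h for w, h in zip(weights, outputs["hidden_states"]))
    print("weighted-sum feature:", tuple(weighted.shape))

    # get_downsample_rates tells downstream code how many input samples map to one output frame.
    print("downsample rate:", expert.get_downsample_rates("hidden_states"))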