File size: 3,066 Bytes
17c79b3 958473e fa98f1c 958473e 17c79b3 958473e 17c79b3 958473e fa98f1c de6da40 fa98f1c 958473e 17c79b3 de6da40 17c79b3 fa98f1c de6da40 17c79b3 de6da40 958473e c9d4907 fa98f1c 17c79b3 fa98f1c 17c79b3 fa98f1c 17c79b3 de6da40 17c79b3 de6da40 17c79b3 958473e fa98f1c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
from collections import OrderedDict
from typing import List, Union, Dict
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence
HIDDEN_DIM = 8  # width of every representation produced by Model


class Model(nn.Module):
    """Toy two-layer linear network standing in for a real upstream model.

    It maps a 1-dimensional input feature to ``HIDDEN_DIM`` features and then
    applies a second ``HIDDEN_DIM -> HIDDEN_DIM`` linear layer. The output of
    every layer is returned, so callers can treat them as layer-wise hidden
    states. Being an ``nn.Module`` matters for finetuning; it is not needed
    for pure representation extraction.
    """

    def __init__(self):
        super().__init__()
        # Keep these attribute names stable: they determine the parameter keys
        # reported by state_dict() (e.g. "model1.weight"), which any saved
        # checkpoint must match. Construction order also fixes the random
        # initialisation sequence.
        self.model1 = nn.Linear(1, HIDDEN_DIM)
        self.model2 = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)

    def forward(self, wavs, upstream_feature_selection="hidden_states"):
        """Run the layers in sequence and collect every intermediate output.

        Args:
            wavs: Input tensor whose last dimension is 1 (required by
                ``model1``); presumably (batch_size, max_len, 1) — TODO confirm
                against callers.
            upstream_feature_selection: Hook for task-specific pre-/post-
                processing; unused in this example.

        Returns:
            A list ``[hidden, feature]`` with the output of ``model1`` and of
            ``model2``. Each has ``HIDDEN_DIM`` features in its last dimension.
        """
        layer_outputs = []
        current = wavs
        for layer in (self.model1, self.model2):
            current = layer(current)
            layer_outputs.append(current)
        return layer_outputs
class UpstreamExpert(nn.Module):
    """Example SUPERB upstream expert wrapping the toy ``Model``.

    At construction it loads pretrained weights from a checkpoint file into a
    fresh ``Model``. In ``forward`` it zero-pads a batch of waveforms and
    returns the model's per-layer outputs under the ``"hidden_states"`` key.
    """

    def __init__(
        self,
        ckpt: str = "./model.pt",
        upstream_feature_selection: str = "hidden_states",
        **kwargs,
    ):
        """
        Args:
            ckpt:
                Path of the checkpoint holding the pretrained weights
                (a state_dict for ``Model``). For the SUPERB Challenge this
                should stay fixed as model.pt.
            upstream_feature_selection:
                Selects task-specific pre-/post-processing. Expected values
                include 'hidden_states', 'PR', 'SID', 'ER', 'ASR', 'QbE',
                'ASV', 'SD', 'ST', 'SE', 'SS', 'secret', or a new task name.
            **kwargs:
                Accepted for interface compatibility; ignored here.
        """
        super().__init__()
        self.name = "[Example UpstreamExpert]"
        self.upstream_feature_selection = upstream_feature_selection

        # NOTE(review): torch.load unpickles the file, so only load checkpoints
        # from trusted sources (weights_only=True is available on newer torch).
        state_dict = torch.load(ckpt, map_location="cpu")
        self.model = Model()
        self.model.load_state_dict(state_dict)

    def get_downsample_rates(self, key: str) -> int:
        """Return the downsample rate of the representation named ``key``.

        This example performs no downsampling, so every key reports a rate of
        1. For comparison, with 16 kHz input a representation with a 10 ms
        stride would have a downsample rate of 160.
        """
        return 1

    def forward(self, wavs: List[Tensor]) -> Dict[str, List[Tensor]]:
        """Encode a batch of waveforms into layer-wise hidden states.

        Args:
            wavs: Waveform tensors of possibly different lengths; presumably
                1-D each — TODO confirm against callers.

        Returns:
            ``{"hidden_states": [...]}``, holding the list returned by the
            wrapped model. When that list has several tensors, they must share
            one shape so a weighted sum can be trained over them.
        """
        # Zero-pad to the longest waveform, then append a unit feature axis:
        # (batch_size, max_len) -> (batch_size, max_len, 1) for 1-D inputs.
        padded = pad_sequence(wavs, batch_first=True)
        model_input = padded.unsqueeze(-1)
        layer_outputs = self.model(
            model_input,
            upstream_feature_selection=self.upstream_feature_selection,
        )
        # Do not add task-specific post-processing here (deprecated); route it
        # through the ``upstream_feature_selection`` constructor argument.
        return {"hidden_states": layer_outputs}
|