from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor
from transformers import pipeline

from dataset import Dataset
from utils import data


class Model:
    """Wraps the supported ASR checkpoints and runs them over a Dataset,
    returning (references, predictions) pairs for downstream evaluation."""

    def __init__(self):
        # Supported checkpoints; one of these must be passed to select().
        self.options = [
            "openai/whisper-tiny.en",
            "facebook/s2t-medium-librispeech-asr"
        ]
        self.selected = None
        # Populated by load(): Whisper uses a pipeline plus its tokenizer's text
        # normalizer, while the Speech2Text checkpoint uses a model/processor pair.
        self.pipeline = None
        self.normalize = None
        self.model = None
        self.processor = None

    def get_options(self):
        return self.options

    def load(self, option: str = None):
        """Load the given checkpoint, or the previously selected one if none is given."""
        if option is None:
            if self.selected is None:
                raise ValueError("No model selected. Please select a model first")
            option = self.selected

        if option not in self.options:
            raise ValueError(f"Selected option is not a valid value, see: {self.options}")

        # Keep the selection in sync with whatever checkpoint is actually loaded.
        self.selected = option

        if option == "openai/whisper-tiny.en":
            # device=0 assumes a CUDA GPU; use device=-1 to run on CPU.
            self.pipeline = pipeline("automatic-speech-recognition", model="openai/whisper-tiny.en", device=0)
            self.normalize = self.pipeline.tokenizer.normalize

        elif option == "facebook/s2t-medium-librispeech-asr":
            self.model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-medium-librispeech-asr")
            self.processor = Speech2TextProcessor.from_pretrained("facebook/s2t-medium-librispeech-asr", do_upper_case=True)

    def select(self, option: str = None):
        if option not in self.options:
            raise ValueError(f"This value is not an option, please see: {self.options}")
        self.selected = option

    def process(self, dataset: Dataset):
        """Run the loaded model over the dataset and return (references, predictions)."""
        if self.selected is None:
            raise ValueError("No model is selected yet. Please select a model first")
        if self.pipeline is None and self.model is None:
            raise ValueError("The selected model has not been loaded. Please call load() first")

        if self.selected == "openai/whisper-tiny.en":
            references, predictions = self._process_openai_whisper_tiny_en(dataset)
        elif self.selected == "facebook/s2t-medium-librispeech-asr":
            references, predictions = self._process_facebook_s2t_medium(dataset)

        return references, predictions

    def _process_openai_whisper_tiny_en(self, ds: Dataset):

        def normalise(batch):
            # Normalize the reference transcript with Whisper's own text normalizer
            # so references and predictions are compared in the same form.
            batch["norm_text"] = self.normalize(ds._get_text(batch))
            return batch

        ds.normalised(normalise)
        dataset = ds.filter("norm_text")

        predictions = []
        references = []

        # Run streamed inference: the pipeline consumes the generator from utils.data
        # and passes the reference text through alongside each transcription.
        for out in self.pipeline(data(dataset), batch_size=16):
            predictions.append(self.normalize(out["text"]))
            references.append(out["reference"][0])

        return references, predictions
    
    def _process_facebook_s2t_medium(self, ds: Dataset):

        def map_to_pred(batch):
            # The checkpoint expects 16 kHz audio; the processor converts the raw
            # waveform into input features and an attention mask for generation.
            features = self.processor(batch["audio"]["array"], sampling_rate=16000, padding=True, return_tensors="pt")
            input_features = features.input_features
            attention_mask = features.attention_mask

            gen_tokens = self.model.generate(input_features=input_features, attention_mask=attention_mask)
            batch["transcription"] = self.processor.batch_decode(gen_tokens, skip_special_tokens=True)[0]
            return batch

        # Limit the run to the first 100 samples to keep inference time manageable.
        ds.dataset = ds.dataset.take(100)
        result = ds.dataset.map(map_to_pred, remove_columns=["audio"])

        predictions = []
        references = []

        ds._check_text()
        text_column = ds.text

        for sample in result:
            predictions.append(sample["transcription"])
            references.append(sample[text_column])

        return references, predictions
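

# Hypothetical usage sketch, not part of the original module: the Dataset()
# construction below is an assumption (the real constructor in dataset.py may
# take arguments), and `evaluate` is an extra dependency used only for scoring.
if __name__ == "__main__":
    import evaluate

    ds = Dataset()  # assumption: adjust to the actual Dataset constructor
    asr = Model()
    asr.select("openai/whisper-tiny.en")
    asr.load()
    references, predictions = asr.process(ds)

    # Word error rate between the normalized references and predictions.
    wer_metric = evaluate.load("wer")
    print("WER:", wer_metric.compute(references=references, predictions=predictions))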