File size: 4,401 Bytes
ecc69a8
752ce9b
 
 
f3d14a8
ecc69a8
752ce9b
 
 
 
 
 
 
 
 
 
 
 
ecc69a8
 
752ce9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f3d14a8
 
752ce9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f3d14a8
 
752ce9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61ba593
 
 
752ce9b
 
61ba593
752ce9b
f3d14a8
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# from transformers import WhisperProcessor, WhisperForConditionalGeneration
from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor
from transformers import pipeline

# import nemo.collections.asr as nemo_asr

from dataset import Dataset
from utils import data



class Model:
    """Unified wrapper around several automatic-speech-recognition models.

    Usage: ``select()`` one of the identifiers from ``get_options()``,
    ``load()`` it, then ``process()`` a Dataset to obtain a
    ``(references, predictions)`` pair for evaluation.

    The NVIDIA FastConformer entry is listed but its NeMo-based
    implementation is currently disabled (see the commented-out code).
    """

    def __init__(self):
        # Model identifiers accepted by select()/load().
        self.options = [
            "openai/whisper-tiny.en",
            "facebook/s2t-medium-librispeech-asr",
            "nvidia/stt_en_fastconformer_ctc_large"
        ]
        self.selected = None   # currently selected model id (str or None)
        self.pipeline = None   # HF pipeline (whisper path only)
        self.normalize = None  # text normalizer bound from the pipeline tokenizer
        self.model = None      # raw model object (s2t path only)
        self.processor = None  # feature processor (s2t path only)

    def get_options(self):
        """Return the list of supported model identifiers."""
        return self.options

    def load(self, option: str = None):
        """Instantiate the model/pipeline for *option*.

        Falls back to the previously selected model when *option* is None.

        Raises:
            ValueError: no model selected, or *option* is not a known id.
            NotImplementedError: *option* is listed but has no load path.
        """
        if option is None:
            if self.selected is None:
                raise ValueError("No model selected. Please first select a model")
            option = self.selected

        if option not in self.options:
            raise ValueError(f"Selected Option is not a valid value, see: {self.options}")

        if option == "openai/whisper-tiny.en":
            self.pipeline = pipeline("automatic-speech-recognition", model="openai/whisper-tiny.en", device=0)
            # Whisper's tokenizer ships a text normalizer; reuse it so that
            # references and predictions are normalised identically.
            self.normalize = self.pipeline.tokenizer.normalize
        elif option == "facebook/s2t-medium-librispeech-asr":
            self.model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-medium-librispeech-asr")
            # do_upper_case matches LibriSpeech's upper-case reference transcripts.
            self.processor = Speech2TextProcessor.from_pretrained("facebook/s2t-medium-librispeech-asr", do_upper_case=True)
        else:
            # BUG FIX: previously this method silently did nothing for any
            # other option (e.g. the NVIDIA model), leaving the instance
            # unloaded; fail loudly instead.
            raise NotImplementedError(f"Loading is not implemented for: {option}")

        # elif option == "nvidia/stt_en_fastconformer_ctc_large":
        #     self.model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(model_name="nvidia/stt_en_fastconformer_ctc_large")

    def select(self, option: str = None):
        """Record *option* as the active model (must be in self.options)."""
        if option not in self.options:
            raise ValueError(f"This value is not an option, please see: {self.options}")
        self.selected = option

    def process(self, dataset: Dataset):
        """Transcribe *dataset* with the currently selected model.

        Returns:
            tuple[list, list]: ``(references, predictions)``.

        Raises:
            ValueError: no model has been selected yet.
            NotImplementedError: the selected model has no processing path.
        """
        if self.selected is None:
            raise ValueError("No Model is yet selected. Please select a model first")

        if self.selected == "openai/whisper-tiny.en":
            return self._process_openai_whisper_tiny_en(dataset)
        if self.selected == "facebook/s2t-medium-librispeech-asr":
            return self._process_facebook_s2t_medium(dataset)
        # elif self.selected == "nvidia/stt_en_fastconformer_ctc_large":
        #     return self._process_stt_en_fastconformer_ctc_large(dataset)

        # BUG FIX: previously any other selection fell through and hit an
        # UnboundLocalError on `references`/`predictions`.
        raise NotImplementedError(f"Processing is not implemented for: {self.selected}")

    def _process_openai_whisper_tiny_en(self, ds: Dataset):
        """Streamed inference with the HF whisper pipeline.

        Normalises both references and predictions with the whisper
        tokenizer's normalizer before returning them.
        """
        def normalise(batch):
            batch["norm_text"] = self.normalize(ds._get_text(batch))
            return batch

        ds.normalised(normalise)
        filtered = ds.filter("norm_text")

        predictions = []
        references = []

        # Run streamed inference; data() yields audio samples with their
        # reference text surfaced back on each pipeline output.
        for out in self.pipeline(data(filtered), batch_size=16):
            predictions.append(self.normalize(out["text"]))
            references.append(out["reference"][0])

        return references, predictions

    def _process_facebook_s2t_medium(self, ds: Dataset, sample_size: int = 100):
        """Map-based inference with the Speech2Text model.

        Only the first *sample_size* examples are transcribed to keep
        runtime bounded (previously a hard-coded 100).
        """
        def map_to_pred(batch):
            features = self.processor(batch["audio"]["array"], sampling_rate=16000, padding=True, return_tensors="pt")
            gen_tokens = self.model.generate(
                input_features=features.input_features,
                attention_mask=features.attention_mask,
            )
            batch["transcription"] = self.processor.batch_decode(gen_tokens, skip_special_tokens=True)[0]
            return batch

        # Limit the evaluation to a manageable subset of the stream.
        ds.dataset = ds.dataset.take(sample_size)
        result = ds.dataset.map(map_to_pred, remove_columns=["audio"])

        predictions = []
        references = []

        ds._check_text()  # resolves which column holds the transcript text
        text_column = ds.text

        for sample in result:
            predictions.append(sample['transcription'])
            references.append(sample[text_column])

        return references, predictions

    def _process_stt_en_fastconformer_ctc_large(self, ds: Dataset):
        """Stub for the NVIDIA FastConformer model.

        NOTE(review): currently transcribes a hard-coded file, discards the
        result, and returns empty lists — it is not wired into process()
        and requires the (commented-out) NeMo load path to set self.model.
        """
        self.model.transcribe(['2086-149220-0033.wav'])

        predictions = []
        references = []

        return references, predictions