"""Wrapper around several pretrained ASR (speech-to-text) models.

Heavy third-party imports (``transformers``) and project helpers are deferred
to the methods that need them, so importing this module is cheap and does not
require the optional dependencies to be installed.
"""
# from transformers import WhisperProcessor, WhisperForConditionalGeneration
# import nemo.collections.asr as nemo_asr


class Model:
    """Select, load, and run one of several supported ASR models."""

    def __init__(self):
        # Model identifiers this wrapper knows how to run.
        self.options = [
            "openai/whisper-tiny.en",
            "facebook/s2t-medium-librispeech-asr",
            "nvidia/stt_en_fastconformer_ctc_large",
        ]
        self.selected = None   # currently selected model identifier
        self.pipeline = None   # transformers ASR pipeline (whisper backend only)
        self.normalize = None  # text normaliser taken from the pipeline tokenizer

    def get_options(self):
        """Return the list of supported model identifiers."""
        return self.options

    def select(self, option: str = None):
        """Mark *option* as the active model.

        Raises:
            ValueError: if *option* is not in ``self.options`` (passing
                ``None`` therefore also raises, matching prior behavior).
        """
        if option not in self.options:
            raise ValueError(f"This value is not an option, please see: {self.options}")
        self.selected = option

    def load(self, option: str = None):
        """Instantiate the model/pipeline for *option* (defaults to the selected one).

        Raises:
            ValueError: if no option is given and none was selected, or if the
                option is not a supported identifier.
        """
        if option is None:
            if self.selected is None:
                raise ValueError("No model selected. Please first select a model")
            option = self.selected
        if option not in self.options:
            raise ValueError(f"Selected Option is not a valid value, see: {self.options}")

        if option == "openai/whisper-tiny.en":
            # Lazy import: transformers is heavy and optional.
            from transformers import pipeline

            self.pipeline = pipeline(
                "automatic-speech-recognition",
                model="openai/whisper-tiny.en",
                device=0,  # NOTE(review): hard-coded GPU 0 — confirm CUDA host
            )
            self.normalize = self.pipeline.tokenizer.normalize
        elif option == "facebook/s2t-medium-librispeech-asr":
            from transformers import (
                Speech2TextForConditionalGeneration,
                Speech2TextProcessor,
            )

            self.model = Speech2TextForConditionalGeneration.from_pretrained(
                "facebook/s2t-medium-librispeech-asr"
            )
            self.processor = Speech2TextProcessor.from_pretrained(
                "facebook/s2t-medium-librispeech-asr", do_upper_case=True
            )
        # elif option == "nvidia/stt_en_fastconformer_ctc_large":
        #     import nemo.collections.asr as nemo_asr
        #     self.model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(
        #         model_name="nvidia/stt_en_fastconformer_ctc_large")

    def process(self, dataset: "Dataset"):
        """Run inference on *dataset* with the currently selected model.

        Returns:
            tuple[list, list]: ``(references, predictions)`` — parallel lists
            of ground-truth and predicted transcriptions.

        Raises:
            ValueError: if no model has been selected.
            NotImplementedError: if the selected model has no processing
                routine (e.g. the NeMo backend, which is currently disabled).
        """
        if self.selected is None:
            raise ValueError("No Model is yet selected. Please select a model first")
        if self.selected == "openai/whisper-tiny.en":
            return self._process_openai_whisper_tiny_en(dataset)
        if self.selected == "facebook/s2t-medium-librispeech-asr":
            return self._process_facebook_s2t_medium(dataset)
        # The original implementation fell through here and crashed with an
        # UnboundLocalError on `references`; fail loudly and clearly instead.
        # elif self.selected == "nvidia/stt_en_fastconformer_ctc_large":
        #     return self._process_stt_en_fastconformer_ctc_large(dataset)
        raise NotImplementedError(f"No processing routine for {self.selected}")

    def _process_openai_whisper_tiny_en(self, dataset: "Dataset"):
        """Streamed whisper-tiny inference; returns ``(references, predictions)``."""
        from utils import data  # project helper; imported lazily

        def normalise(batch):
            batch["norm_text"] = self.normalize(dataset._get_text(batch))
            return batch

        dataset.normalised(normalise)
        filtered = dataset.filter("norm_text")
        predictions = []
        references = []
        # Streamed inference so audio never has to be fully materialised.
        for out in self.pipeline(data(filtered), batch_size=16):
            predictions.append(self.normalize(out["text"]))
            references.append(out["reference"][0])
        return references, predictions

    def _process_facebook_s2t_medium(self, dataset: "Dataset"):
        """S2T inference over the first 100 samples; returns ``(references, predictions)``."""

        def map_to_pred(batch):
            features = self.processor(
                batch["audio"]["array"],
                sampling_rate=16000,
                padding=True,
                return_tensors="pt",
            )
            gen_tokens = self.model.generate(
                input_features=features.input_features,
                attention_mask=features.attention_mask,
            )
            batch["transcription"] = self.processor.batch_decode(
                gen_tokens, skip_special_tokens=True
            )[0]
            return batch

        # Cap the evaluation at 100 samples to keep runtime bounded.
        dataset.dataset = dataset.dataset.take(100)
        result = dataset.dataset.map(map_to_pred, remove_columns=["audio"])
        predictions = []
        references = []
        dataset._check_text()
        text_column = dataset.text
        for sample in result:
            predictions.append(sample["transcription"])
            references.append(sample[text_column])
        return references, predictions

    def _process_stt_en_fastconformer_ctc_large(self, dataset: "Dataset"):
        """Placeholder NeMo routine — currently returns empty lists.

        NOTE(review): transcribes a hard-coded wav file and discards the
        result; the corresponding model loading in ``load`` is commented out,
        so ``self.model`` may not be a NeMo model when this runs.
        """
        self.model.transcribe(["2086-149220-0033.wav"])
        predictions = []
        references = []
        return references, predictions