# Provenance (Hugging Face file-viewer artifacts, converted to a comment so the
# module parses): uploaded by j-tobias, commit f3d14a8 ("cleaned"), 4.4 kB.
# from transformers import WhisperProcessor, WhisperForConditionalGeneration
from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor
from transformers import pipeline
# import nemo.collections.asr as nemo_asr
from dataset import Dataset
from utils import data
class Model:
    """Wrapper around several speech-to-text backends.

    Holds a fixed list of supported model identifiers; a model must be
    ``select()``-ed and ``load()``-ed before ``process()`` can transcribe
    a ``Dataset``.
    """

    def __init__(self):
        # Supported model identifiers. The NVIDIA entry is listed but its
        # NeMo-based implementation is commented out at module top.
        self.options = [
            "openai/whisper-tiny.en",
            "facebook/s2t-medium-librispeech-asr",
            "nvidia/stt_en_fastconformer_ctc_large"
        ]
        self.selected = None    # currently selected option string, or None
        self.pipeline = None    # HF pipeline (whisper backend only)
        self.normalize = None   # tokenizer text normalizer (whisper backend only)

    def get_options(self):
        """Return the list of supported model identifiers."""
        return self.options

    def select(self, option: str = None):
        """Mark *option* as the active model.

        Raises:
            ValueError: if *option* is not one of ``self.options``.
        """
        if option not in self.options:
            raise ValueError(f"This value is not an option, please see: {self.options}")
        self.selected = option

    def load(self, option: str = None):
        """Instantiate model/processor objects for *option*.

        Falls back to the previously selected option when *option* is None.

        Raises:
            ValueError: no option given and none selected, or unknown option.
            NotImplementedError: the NVIDIA backend is not implemented yet.
        """
        if option is None:
            if self.selected is None:
                raise ValueError("No model selected. Please first select a model")
            option = self.selected
        if option not in self.options:
            raise ValueError(f"Selected Option is not a valid value, see: {self.options}")
        if option == "openai/whisper-tiny.en":
            # NOTE(review): device=0 assumes a CUDA GPU — fails on CPU-only hosts;
            # consider making the device configurable.
            self.pipeline = pipeline("automatic-speech-recognition", model="openai/whisper-tiny.en", device=0)
            self.normalize = self.pipeline.tokenizer.normalize
        elif option == "facebook/s2t-medium-librispeech-asr":
            self.model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-medium-librispeech-asr")
            self.processor = Speech2TextProcessor.from_pretrained("facebook/s2t-medium-librispeech-asr", do_upper_case=True)
        else:
            # Previously this branch silently did nothing, leaving the model
            # unusable; fail loudly instead (NeMo code is commented out above).
            raise NotImplementedError(f"Loading is not implemented for {option!r}")

    def process(self, dataset: "Dataset"):
        """Transcribe *dataset* with the selected model.

        Returns:
            tuple[list, list]: ``(references, predictions)`` — parallel lists
            of reference transcripts and model outputs.

        Raises:
            ValueError: if no model has been selected.
            NotImplementedError: the selected backend has no processing path.
        """
        if self.selected is None:
            raise ValueError("No Model is yet selected. Please select a model first")
        if self.selected == "openai/whisper-tiny.en":
            return self._process_openai_whisper_tiny_en(dataset)
        if self.selected == "facebook/s2t-medium-librispeech-asr":
            return self._process_facebook_s2t_medium(dataset)
        # BUG FIX: unhandled options previously fell through to the return with
        # (references, predictions) unbound, raising a confusing NameError.
        raise NotImplementedError(f"No processing path for {self.selected!r}")

    def _process_openai_whisper_tiny_en(self, ds: "Dataset"):
        """Stream the dataset through the whisper pipeline, normalizing both sides."""
        def normalise(batch):
            batch["norm_text"] = self.normalize(ds._get_text(batch))
            return batch

        ds.normalised(normalise)
        filtered = ds.filter("norm_text")
        predictions = []
        references = []
        # Run streamed inference, batching audio through the pipeline.
        for out in self.pipeline(data(filtered), batch_size=16):
            predictions.append(self.normalize(out["text"]))
            references.append(out["reference"][0])
        return references, predictions

    def _process_facebook_s2t_medium(self, ds: "Dataset"):
        """Transcribe the first 100 samples with the S2T model."""
        def map_to_pred(batch):
            # assumes batch["audio"]["array"] is 16 kHz mono audio — TODO confirm
            features = self.processor(batch["audio"]["array"], sampling_rate=16000, padding=True, return_tensors="pt")
            gen_tokens = self.model.generate(
                input_features=features.input_features,
                attention_mask=features.attention_mask,
            )
            batch["transcription"] = self.processor.batch_decode(gen_tokens, skip_special_tokens=True)[0]
            return batch

        # Cap evaluation at 100 samples to keep runtime reasonable.
        ds.dataset = ds.dataset.take(100)
        result = ds.dataset.map(map_to_pred, remove_columns=["audio"])
        predictions = []
        references = []
        ds._check_text()
        text_column = ds.text
        for sample in result:
            predictions.append(sample["transcription"])
            references.append(sample[text_column])
        return references, predictions

    def _process_stt_en_fastconformer_ctc_large(self, ds: "Dataset"):
        """Stub for the NeMo FastConformer backend.

        NOTE(review): transcribes a hard-coded sample file, discards the result,
        and returns empty lists; the NeMo model loading is commented out at the
        module top, so ``self.model`` is never a NeMo model here.
        """
        self.model.transcribe(['2086-149220-0033.wav'])
        predictions = []
        references = []
        return references, predictions