|
|
|
from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor |
|
from transformers import pipeline |
|
|
|
|
|
|
|
from dataset import Dataset |
|
from utils import data |
|
|
|
|
|
|
|
class Model:
    """Unified front-end for several ASR checkpoints.

    Workflow: pick a checkpoint from ``get_options()``, ``select()`` it,
    ``load()`` it, then ``process()`` a dataset to obtain parallel
    ``(references, predictions)`` lists suitable for WER-style evaluation.
    """

    def __init__(self):
        # Supported checkpoints; select()/load() validate against this list.
        self.options = [
            "openai/whisper-tiny.en",
            "facebook/s2t-medium-librispeech-asr",
            "nvidia/stt_en_fastconformer_ctc_large"
        ]
        self.selected = None   # currently selected checkpoint name, or None
        self.pipeline = None   # HF ASR pipeline (whisper path only)
        self.normalize = None  # tokenizer text normalizer (whisper path only)

    def get_options(self):
        """Return the list of supported checkpoint identifiers."""
        return self.options

    def load(self, option: str = None):
        """Instantiate the model/processor for ``option``.

        When ``option`` is None, falls back to the previously selected one.

        Raises:
            ValueError: if nothing is selected or ``option`` is unknown.
            NotImplementedError: if loading for ``option`` is not implemented
                yet (currently the NVIDIA FastConformer checkpoint).
        """
        if option is None:
            if self.selected is None:
                raise ValueError("No model selected. Please first select a model")
            option = self.selected

        if option not in self.options:
            raise ValueError(f"Selected Option is not a valid value, see: {self.options}")

        if option == "openai/whisper-tiny.en":
            self.pipeline = pipeline("automatic-speech-recognition", model="openai/whisper-tiny.en", device=0)
            self.normalize = self.pipeline.tokenizer.normalize
        elif option == "facebook/s2t-medium-librispeech-asr":
            self.model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-medium-librispeech-asr")
            self.processor = Speech2TextProcessor.from_pretrained("facebook/s2t-medium-librispeech-asr", do_upper_case=True)
        else:
            # BUGFIX: previously this fell through silently, leaving nothing
            # loaded; the downstream handler would then fail on the missing
            # self.model. Fail fast and explicitly instead.
            raise NotImplementedError(f"Loading is not implemented for: {option}")

    def select(self, option: str = None):
        """Mark ``option`` as the active checkpoint (must be in ``options``).

        Raises:
            ValueError: if ``option`` is not one of ``self.options``.
        """
        if option not in self.options:
            raise ValueError(f"This value is not an option, please see: {self.options}")
        self.selected = option

    def process(self, dataset: "Dataset"):
        """Transcribe ``dataset`` with the selected model.

        Returns:
            tuple: ``(references, predictions)`` — two parallel lists of strings.

        Raises:
            ValueError: if no model has been selected yet, or the selected
                option has no processing implementation.
        """
        if self.selected is None:
            raise ValueError("No Model is yet selected. Please select a model first")

        if self.selected == "openai/whisper-tiny.en":
            return self._process_openai_whisper_tiny_en(dataset)
        if self.selected == "facebook/s2t-medium-librispeech-asr":
            return self._process_facebook_s2t_medium(dataset)
        if self.selected == "nvidia/stt_en_fastconformer_ctc_large":
            # BUGFIX: this option was listed but never dispatched, so process()
            # crashed with UnboundLocalError instead of calling its handler.
            return self._process_stt_en_fastconformer_ctc_large(dataset)

        # Unreachable while select() validates, but fail loudly if the option
        # list and this dispatch ever drift apart again.
        raise ValueError(f"No processing implemented for: {self.selected}")

    def _process_openai_whisper_tiny_en(self, dataset: "Dataset"):
        """Whisper path: normalize references, stream batches through the pipeline."""

        def normalise(batch):
            # Store the tokenizer-normalized reference next to the raw text.
            batch["norm_text"] = self.normalize(dataset._get_text(batch))
            return batch

        dataset.normalised(normalise)
        filtered = dataset.filter("norm_text")

        predictions = []
        references = []

        # data() adapts the dataset for the pipeline; each output carries both
        # the model transcription and the reference text.
        for out in self.pipeline(data(filtered), batch_size=16):
            predictions.append(self.normalize(out["text"]))
            references.append(out["reference"][0])

        return references, predictions

    def _process_facebook_s2t_medium(self, dataset: "Dataset"):
        """S2T path: generate a transcription per sample via map()."""

        def map_to_pred(batch):
            # NOTE(review): assumes 16 kHz audio — confirm the dataset's rate.
            features = self.processor(batch["audio"]["array"], sampling_rate=16000, padding=True, return_tensors="pt")
            input_features = features.input_features
            attention_mask = features.attention_mask

            gen_tokens = self.model.generate(input_features=input_features, attention_mask=attention_mask)
            batch["transcription"] = self.processor.batch_decode(gen_tokens, skip_special_tokens=True)[0]
            return batch

        # Only the first 100 samples are evaluated to keep runtime bounded.
        dataset.dataset = dataset.dataset.take(100)
        result = dataset.dataset.map(map_to_pred, remove_columns=["audio"])

        predictions = []
        references = []

        dataset._check_text()
        text_column = dataset.text

        for sample in result:
            predictions.append(sample['transcription'])
            references.append(sample[text_column])

        return references, predictions

    def _process_stt_en_fastconformer_ctc_large(self, dataset: "Dataset"):
        """FastConformer path — incomplete stub.

        TODO: ``load()`` never initializes ``self.model`` for this checkpoint,
        the transcription below uses a hard-coded file whose result is
        discarded, and ``dataset`` is unused; currently always returns two
        empty lists.
        """
        self.model.transcribe(['2086-149220-0033.wav'])

        predictions = []
        references = []

        return references, predictions
|
|