# j-tobias
# added Model Cards
# 61ba593
# raw
# history blame
# 3.73 kB
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor
from transformers import pipeline
from dataset import Dataset
from utils import data
class Model:
    """Wrap two Hugging Face ASR checkpoints behind a select/load/process API.

    Supported checkpoints (hub ids) are listed in ``self.options``:
    a Whisper tiny English model (run via a transformers ``pipeline``) and a
    Speech2Text LibriSpeech model (run via model + processor directly).
    """

    def __init__(self):
        # Supported model checkpoints (Hugging Face hub ids).
        self.options = [
            "openai/whisper-tiny.en",
            "facebook/s2t-medium-librispeech-asr",
        ]
        self.selected = None   # currently selected option, or None
        self.pipeline = None   # ASR pipeline (whisper path only)
        self.normalize = None  # tokenizer text normalizer (whisper path only)
        self.model = None      # seq2seq model (s2t path only)
        self.processor = None  # feature/text processor (s2t path only)

    def get_options(self):
        """Return the list of supported model ids."""
        return self.options

    def load(self, option: str = None):
        """Load weights for *option*, defaulting to the selected model.

        Also records *option* as the selected model, so a subsequent
        ``process`` call dispatches to the implementation that was
        actually loaded.

        Raises:
            ValueError: if no model is selected and none is given, or if
                *option* is not one of ``self.options``.
        """
        if option is None:
            if self.selected is None:
                raise ValueError("No model selected. Please first select a model")
            option = self.selected
        if option not in self.options:
            raise ValueError(f"Selected Option is not a valid value, see: {self.options}")
        # BUGFIX: keep `selected` in sync with what was actually loaded.
        # Previously, load("...") without a prior select() left selected=None
        # and process() raised even though a model was loaded.
        self.selected = option
        if option == "openai/whisper-tiny.en":
            # NOTE(review): device=0 assumes a CUDA device is available — confirm.
            self.pipeline = pipeline("automatic-speech-recognition", model="openai/whisper-tiny.en", device=0)
            self.normalize = self.pipeline.tokenizer.normalize
        elif option == "facebook/s2t-medium-librispeech-asr":
            self.model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-medium-librispeech-asr")
            self.processor = Speech2TextProcessor.from_pretrained("facebook/s2t-medium-librispeech-asr", do_upper_case=True)

    def select(self, option: str = None):
        """Record *option* as the selected model (weights are loaded later).

        Raises:
            ValueError: if *option* is not one of ``self.options``.
        """
        if option not in self.options:
            raise ValueError(f"This value is not an option, please see: {self.options}")
        self.selected = option

    def process(self, dataset: "Dataset"):
        """Transcribe *dataset* with the selected model.

        Returns:
            tuple[list, list]: ``(references, predictions)`` — the ground-truth
            texts and the model transcriptions, in corresponding order.

        Raises:
            ValueError: if no model has been selected, or the selected value
                has no matching implementation.
        """
        if self.selected is None:
            raise ValueError("No Model is yet selected. Please select a model first")
        if self.selected == "openai/whisper-tiny.en":
            return self._process_openai_whisper_tiny_en(dataset)
        if self.selected == "facebook/s2t-medium-librispeech-asr":
            return self._process_facebook_s2t_medium(dataset)
        # BUGFIX: previously this fell through and hit an UnboundLocalError;
        # fail loudly instead if `selected` is ever an unexpected value.
        raise ValueError(f"Selected Option is not a valid value, see: {self.options}")

    def _process_openai_whisper_tiny_en(self, dataset: "Dataset"):
        """Run the Whisper pipeline over *dataset* via streamed inference."""

        def normalise(batch):
            # Attach the normalized reference text so it can be compared
            # against the (also normalized) pipeline output.
            batch["norm_text"] = self.normalize(dataset._get_text(batch))
            return batch

        dataset.normalised(normalise)
        filtered = dataset.filter("norm_text")
        predictions = []
        references = []
        # Run streamed inference; `data(...)` yields audio samples for the pipeline.
        for out in self.pipeline(data(filtered), batch_size=16):
            predictions.append(self.normalize(out["text"]))
            references.append(out["reference"][0])
        return references, predictions

    def _process_facebook_s2t_medium(self, dataset: "Dataset", limit: int = 100):
        """Run the Speech2Text model over the first *limit* samples of *dataset*.

        *limit* defaults to the previously hard-coded cap of 100 samples.
        """

        def map_to_pred(batch):
            # assumes batch["audio"]["array"] is 16 kHz mono audio — TODO confirm
            features = self.processor(batch["audio"]["array"], sampling_rate=16000, padding=True, return_tensors="pt")
            gen_tokens = self.model.generate(
                input_features=features.input_features,
                attention_mask=features.attention_mask,
            )
            batch["transcription"] = self.processor.batch_decode(gen_tokens, skip_special_tokens=True)[0]
            return batch

        dataset.dataset = dataset.dataset.take(limit)
        result = dataset.dataset.map(map_to_pred, remove_columns=["audio"])
        predictions = []
        references = []
        dataset._check_text()
        text_column = dataset.text
        for sample in result:
            predictions.append(sample["transcription"])
            references.append(sample[text_column])
        return references, predictions