ASR_Model_Comparison / processing.py
j-tobias
updated backend
09b2769
raw
history blame
6.83 kB
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor
import plotly.graph_objs as go
from datasets import load_dataset
from datasets import Audio
from transformers import pipeline
import evaluate
import librosa
import numpy as np
# Word-error-rate metric from Hugging Face `evaluate`, loaded once at import
# time and reused by compute_wer() for every comparison.
wer_metric = evaluate.load(path="wer")
def run(data_subset:str, model_1:str, model_2:str, own_audio, own_transcription:str):
    """Transcribe a dataset subset (or the user's own recording) with two ASR
    models and yield a Markdown WER summary plus a comparison bar chart.

    Args:
        data_subset: "Common Voice", "VoxPopuli" or "OWN Recoding/Sample";
            any other non-None value falls back to Common Voice.
        model_1: Id of the first model (passed to ``load_model``).
        model_2: Id of the second model.
        own_audio: ``(sampling_rate, samples)`` tuple, only used for the
            "OWN Recoding/Sample" subset. Presumably the value of a Gradio
            Audio component in numpy mode - TODO confirm against the UI.
        own_transcription: Reference text for ``own_audio``.

    Yields:
        A single ``(results_md, fig)`` tuple.

    Raises:
        ValueError: if the subset or a model is not selected, or if the own
            recording or its reference transcription is missing.
    """
    if data_subset is None:
        raise ValueError("No Dataset selected")
    if model_1 is None:
        raise ValueError("No Model 1 selected")
    if model_2 is None:
        raise ValueError("No Model 2 selected")

    if data_subset == "OWN Recoding/Sample":
        # Validate user input before the (slow) model download so a missing
        # recording/reference gives a clear error instead of a crash later.
        if own_audio is None:
            raise ValueError("No audio provided")
        if not own_transcription:
            raise ValueError("No reference transcription provided")
        samples = [{"audio": {"array": _prepare_own_audio(own_audio), "sampling_rate": 16000}}]
        references = [own_transcription]
    else:
        samples, text_column = _load_subset(data_subset)
        references = [sample[text_column] for sample in samples]
    print("Dataset Loaded")

    same_model = model_1 == model_2
    model1, processor1 = load_model(model_1)
    if same_model:
        # Same checkpoint selected twice: reuse it instead of loading it again.
        model2, processor2 = model1, processor1
    else:
        model2, processor2 = load_model(model_2)
    print("Models Loaded")

    transcriptions1 = []
    transcriptions2 = []
    for counter, sample in enumerate(samples):
        print(counter)
        transcription1 = model_compute(model1, processor1, sample, model_1)
        if same_model:
            # Identical model and input: skip the redundant second inference.
            transcription2 = transcription1
        else:
            transcription2 = model_compute(model2, processor2, sample, model_2)
        transcriptions1.append(transcription1)
        transcriptions2.append(transcription2)

    wer1 = compute_wer(references, transcriptions1)
    wer2 = compute_wer(references, transcriptions2)
    yield _render_results(model_1, model_2, wer1, wer2)


def _load_subset(data_subset):
    """Return ``(dataset, text_column)`` for a named (non-own) dataset subset.

    "VoxPopuli" loads VoxPopuli; "Common Voice" and any unrecognised name
    load Common Voice.
    """
    if data_subset == "VoxPopuli":
        return load_Vox_Populi()
    return load_Common_Voice()


def _prepare_own_audio(own_audio, target_sr=16000):
    """Convert an ``(sampling_rate, samples)`` tuple into mono float32 audio
    resampled to ``target_sr``."""
    sr, audio = own_audio
    audio = np.asarray(audio)
    if np.issubdtype(audio.dtype, np.integer):
        info = np.iinfo(audio.dtype)
        # Map integer PCM onto [-1, 1): int16 is divided by 32768 (midpoint 0);
        # unsigned formats such as uint8 are re-centred on their midpoint first.
        half_range = (int(info.max) - int(info.min) + 1) / 2
        midpoint = (int(info.max) + int(info.min) + 1) / 2
        audio = (audio.astype(np.float32) - midpoint) / half_range
    else:
        # Already floating point: assume it is normalised and only fix the dtype.
        audio = audio.astype(np.float32)
    if audio.ndim > 1:
        # Downmix multi-channel input to mono. Assumes a (num_samples,
        # num_channels) layout - TODO confirm against the recording component.
        audio = audio.mean(axis=1)
    return librosa.resample(audio, orig_sr=sr, target_sr=target_sr)


def _render_results(model_1, model_2, wer1, wer2):
    """Return the Markdown WER summary and a bar chart comparing two models."""
    results_md = f"#### {model_1}\n- WER Score: {wer1}\n#### {model_2}\n- WER Score: {wer2}"
    fig = go.Figure(
        data=[
            go.Bar(x=[f"{model_1}"], y=[wer1]),
            go.Bar(x=[f"{model_2}"], y=[wer2]),
        ]
    )
    fig.update_layout(
        title="Comparison of Two Models",
        xaxis_title="Models",
        yaxis_title="Value",
        barmode="group",
    )
    return results_md, fig
# DATASET LOADERS
def load_Common_Voice():
    """Stream the first 100 examples of the English "test" split of
    mozilla-foundation/common_voice_11_0, with audio decoded at 16 kHz.

    ``token=True`` makes `datasets` use the locally stored Hugging Face token
    (presumably needed because the dataset is gated - confirm).

    Returns:
        ``(samples, "sentence")``: the examples as a list of dicts, and the
        name of the column holding the reference transcription.
    """
    streamed = load_dataset(
        "mozilla-foundation/common_voice_11_0",
        "en",
        revision="streaming",
        split="test",
        streaming=True,
        token=True,
        trust_remote_code=True,
    )
    first_clips = streamed.take(100).cast_column("audio", Audio(sampling_rate=16000))
    return list(first_clips), "sentence"
def load_Vox_Populi():
    """Stream the first 100 examples of the English "test" split of
    facebook/voxpopuli, with audio decoded at 16 kHz.

    Returns:
        ``(samples, "raw_text")``: the examples as a list of dicts, and the
        name of the column holding the reference transcription.
    """
    dataset = load_dataset("facebook/voxpopuli", "en", split="test", streaming=True, trust_remote_code=True)
    text_column = "raw_text"
    # Only the first 100 examples are pulled from the stream; casting the
    # audio column resamples every clip to the 16 kHz rate used in this module.
    dataset = dataset.take(100)
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
    dataset = list(dataset)
    return dataset, text_column
# MODEL LOADERS
def load_model(model_id:str):
    """Load ``(model, processor)`` for ``model_id``.

    Supports "facebook/s2t-medium-librispeech-asr" (Speech2Text, with an
    upper-casing processor) and "openai/whisper-tiny.en". Any other id
    silently falls back to whisper-tiny.en.
    """
    whisper_id = "openai/whisper-tiny.en"
    s2t_id = "facebook/s2t-medium-librispeech-asr"

    if model_id == s2t_id:
        return (
            Speech2TextForConditionalGeneration.from_pretrained(s2t_id),
            Speech2TextProcessor.from_pretrained(s2t_id, do_upper_case=True),
        )

    # whisper-tiny.en itself, and the fallback for every unrecognised id.
    return (
        WhisperForConditionalGeneration.from_pretrained(whisper_id),
        WhisperProcessor.from_pretrained(whisper_id),
    )
# MODEL INFERENCE
def model_compute(model, processor, sample, model_id):
    """Transcribe one example and return the transcription string.

    Args:
        model: The model returned by ``load_model`` for ``model_id``.
        processor: The matching processor returned by ``load_model``.
        sample: Example dict providing ``sample["audio"]["array"]`` and
            ``sample["audio"]["sampling_rate"]``.
        model_id: Selects the inference path ("openai/whisper-tiny.en" or
            "facebook/s2t-medium-librispeech-asr").

    Returns:
        The decoded text of the first (only) sequence in the batch.
    """
    if model_id == "openai/whisper-tiny.en":
        audio = sample["audio"]
        input_features = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt").input_features
        predicted_ids = model.generate(input_features)
        return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    if model_id == "facebook/s2t-medium-librispeech-asr":
        audio = sample["audio"]
        # Pass the sample's real rate (consistent with the Whisper branch)
        # rather than a hard-coded 16000 that could mask a mismatch.
        features = processor(audio["array"], sampling_rate=audio["sampling_rate"], padding=True, return_tensors="pt")
        gen_tokens = model.generate(input_features=features.input_features, attention_mask=features.attention_mask)
        # batch_decode yields one string per sequence: return the first
        # string, not (as before) its first character.
        return processor.batch_decode(gen_tokens, skip_special_tokens=True)[0]
    # NOTE(review): no dedicated inference path for other ids. load_model maps
    # unknown ids to whisper-tiny.en, so calling the model on the raw sample
    # dict presumably fails - confirm whether this branch is reachable.
    return model(sample)
# UTILS
def compute_wer(references, predictions):
    """Score ``predictions`` against ``references`` with the module-level WER
    metric and return the result as a percentage rounded to two decimals
    (e.g. a raw score of 0.1234 becomes 12.34)."""
    score = wer_metric.compute(predictions=predictions, references=references)
    return round(score * 100, 2)
# print(load_Vox_Populi())
# print(run("Common Voice", "openai/whisper-tiny.en", "openai/whisper-tiny.en", None, None))