j-tobias committed · Commit 09b2769 · Parent(s): e3bf44e

updated backend

Files changed:
- .codetogether.ignore +0 -1
- __pycache__/dataset.cpython-310.pyc +0 -0
- __pycache__/model.cpython-310.pyc +0 -0
- __pycache__/processing.cpython-310.pyc +0 -0
- __pycache__/utils.cpython-310.pyc +0 -0
- app.py +42 -63
- dataset.py +0 -93
- model.py +0 -122
- processing.py +194 -0
- utils.py +11 -0
.codetogether.ignore DELETED
@@ -1 +0,0 @@
-credentials.json
__pycache__/dataset.cpython-310.pyc ADDED
Binary file (3.34 kB)

__pycache__/model.cpython-310.pyc ADDED
Binary file (3.74 kB)

__pycache__/processing.cpython-310.pyc ADDED
Binary file (4.24 kB)

__pycache__/utils.cpython-310.pyc ADDED
Binary file (1.04 kB)
app.py CHANGED
@@ -1,9 +1,5 @@
 import gradio as gr
-from dataset import Dataset
-from model import Model
-from utils import compute_wer
-import plotly.graph_objs as go
-
+from processing import run
 
 # from utils import hf_login
 # hf_login()
@@ -14,73 +10,58 @@ import os
 hf_token = os.getenv("HF_Token")
 login(hf_token)
 
-dataset = Dataset()
-models = Model()
+MODEL_OPTIONS = ["openai/whisper-tiny.en", "facebook/s2t-medium-librispeech-asr"]
+DATASET_OPTIONS = ["Common Voice", "VoxPopuli", "OWN Recoding/Sample"]
 
-def run_tests(dataset_choice:str, model:str):
-
-    MoDeL = Model()
-    MoDeL.select(model)
-    MoDeL.load()
-    DaTaSeT = Dataset(100)
-    DaTaSeT.load(dataset_choice)
-    references, predictions = MoDeL.process(DaTaSeT)
-    wer = compute_wer(references=references, predictions=predictions)
-    return wer
-
-def eval(data_subset:str, model_1:str, model_2:str)->str:
-
-    wer_result_1 = run_tests(data_subset, model_1)
-    wer_result_2 = run_tests(data_subset, model_2)
-
-    results_md = f"""#### {model_1}
-- WER Score: {wer_result_1}
-
-#### {model_2}
-- WER Score: {wer_result_2}"""
-
-    # Create the bar plot
-    fig = go.Figure(
-        data=[
-            go.Bar(x=[f"{model_1}"], y=[wer_result_1]),
-            go.Bar(x=[f"{model_2}"], y=[wer_result_2]),
-        ]
-    )
-
-    # Update the layout for better visualization
-    fig.update_layout(
-        title="Comparison of Two Models",
-        xaxis_title="Models",
-        yaxis_title="Value",
-        barmode="group",
-    )
-
-    return results_md, fig
+# def eval(data_subset:str, model_1:str, model_2:str, own_audio, own_transcription:str)->str:
+
+#     print("OWN AUDIO: ", type(own_audio), own_audio)
+
+#     wer_result_1, wer_result_2, references, transcriptions1, transcriptions2 = run(data_subset, model_1, model_2, own_audio, own_transcription)
+
+#     results_md = f"""#### {model_1}
+#     - WER Score: {wer_result_1}
+
+#     #### {model_2}
+#     - WER Score: {wer_result_2}"""
+
+#     # Create the bar plot
+#     fig = go.Figure(
+#         data=[
+#             go.Bar(x=[f"{model_1}"], y=[wer_result_1]),
+#             go.Bar(x=[f"{model_2}"], y=[wer_result_2]),
+#         ]
+#     )
+
+#     # Update the layout for better visualization
+#     fig.update_layout(
+#         title="Comparison of Two Models",
+#         xaxis_title="Models",
+#         yaxis_title="Value",
+#         barmode="group",
+#     )
+
+#     return results_md, fig
 
 def get_card(selected_model:str)->str:
 
-    print("Selected Model for Card: ", selected_model)
     with open("cards.txt", "r") as f:
         cards = f.read()
 
-    print(cards)
-
     cards = cards.split("@@")
     for card in cards:
-        print("CARD: ", card)
         if "ID: "+selected_model in card:
             return card
 
     return "Unknown Model"
 
-def is_own(
-    if
-
-    return
-    own_audio = None
-    own_transcription = None
-    return own_audio, own_transcription
+def is_own(selected_option):
+    if selected_option == "OWN Recoding/Sample":
+        return gr.update(visible=True), gr.update(visible=True)
+    else:
+        return gr.update(visible=False), gr.update(visible=False)
 
 with gr.Blocks() as demo:
@@ -106,31 +87,29 @@ Happy experimenting and comparing! 🚀""")
         pass
     with gr.Column(scale=5):
         data_subset = gr.Radio(
-            value="
-            choices=
+            value="Common Voice",
+            choices=DATASET_OPTIONS,
             label="Data subset / Own Sample",
         )
+        own_audio = gr.Audio(visible=False)
+        own_transcription = gr.TextArea(lines=2, visible=False)
+        data_subset.change(is_own, inputs=[data_subset], outputs=[own_audio, own_transcription])
     with gr.Column(scale=1):
         pass
 
-    with gr.Row():
-        own_audio = gr.Audio(sources=['microphone'],streaming=False,visible=False)
-        own_transcription = gr.TextArea(lines=2, visible=False)
-        data_subset.change(is_own, inputs=[data_subset], outputs=[own_audio, own_transcription])
-
     with gr.Row():
 
         with gr.Column(scale=1):
             model_1 = gr.Dropdown(
-                choices=
+                choices=MODEL_OPTIONS,
                 label="Select Model"
            )
            model_1_card = gr.Markdown("")
 
        with gr.Column(scale=1):
            model_2 = gr.Dropdown(
-                choices=
+                choices=MODEL_OPTIONS,
                label="Select Model"
            )
            model_2_card = gr.Markdown("")
@@ -148,6 +127,6 @@ Happy experimenting and comparing! 🚀""")
     gr.Markdown('## <p style="text-align: center;">Results</p>')
     results_md = gr.Markdown("")
     results_plot = gr.Plot(show_label=False)
-    eval_btn.click(
+    eval_btn.click(run, [data_subset, model_1, model_2, own_audio, own_transcription], [results_md, results_plot])
 
 demo.launch(debug=True)
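Note on the rewired click handler: `run` in processing.py is a generator (it ends in `yield results_md, fig`), and Gradio streams each yielded tuple to the output components as it arrives. A minimal sketch of that callback pattern, with illustrative component names that are not from app.py:

```python
import gradio as gr

def evaluate(text):
    # Gradio pushes every yielded value to the outputs as it is produced,
    # so the UI can show intermediate state before the handler finishes.
    yield "running..."
    yield f"done: {text.upper()}"

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Input")
    btn = gr.Button("Evaluate")
    out = gr.Markdown("")
    btn.click(evaluate, [inp], [out])

demo.launch()
```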
dataset.py DELETED
@@ -1,93 +0,0 @@
-from datasets import load_dataset
-from datasets import Audio
-
-
-
-class Dataset:
-
-    def __init__(self, n:int = 100):
-
-        self.n = n
-        self.options = ['LibriSpeech Clean', 'LibriSpeech Other', 'Common Voice', 'VoxPopuli', 'TEDLIUM', 'GigaSpeech', 'SPGISpeech', 'AMI', 'OWN']
-        self.selected = None
-        self.dataset = None
-        self.text = None
-
-    def get_options(self):
-        return self.options
-
-    def _check_text(self):
-        sample = next(iter(self.dataset))
-        print(sample)
-        self._get_text(sample)
-
-    def _get_text(self, sample):
-        if "text" in sample:
-            self.text = "text"
-            return sample["text"]
-        elif "sentence" in sample:
-            self.text = "sentence"
-            return sample["sentence"]
-        elif "normalized_text" in sample:
-            self.text = "normalized_text"
-            return sample["normalized_text"]
-        elif "transcript" in sample:
-            self.text = "transcript"
-            return sample["transcript"]
-        else:
-            raise ValueError(f"Sample: {sample.keys()} has no transcript.")
-
-    def filter(self, input_column:str = None):
-
-        if input_column is None:
-            if self.text is not None:
-                input_column = self.text
-            else:
-                input_column = self._check_text()
-
-        def is_target_text_in_range(ref):
-            if ref.strip() == "ignore time segment in scoring":
-                return False
-            else:
-                return ref.strip() != ""
-
-        self.dataset = self.dataset.filter(is_target_text_in_range, input_columns=[input_column])
-        return self.dataset
-
-    def normalised(self, normalise):
-        self.dataset = self.dataset.map(normalise)
-
-    def _select(self, option:str):
-        if option not in self.options:
-            raise ValueError(f"This value is not an option, please see: {self.options}")
-        self.selected = option
-
-    def _preprocess(self):
-
-        self.dataset = self.dataset.take(self.n)
-        self.dataset = self.dataset.cast_column("audio", Audio(sampling_rate=16000))
-
-    def load(self, option:str = None):
-
-        self._select(option)
-
-        if option == "OWN":
-            pass
-        elif option == "LibriSpeech Clean":
-            self.dataset = load_dataset("librispeech_asr", "all", split="test.clean", streaming=True)
-        elif option == "LibriSpeech Other":
-            self.dataset = load_dataset("librispeech_asr", "all", split="test.other", streaming=True)
-        elif option == "Common Voice":
-            self.dataset = load_dataset("mozilla-foundation/common_voice_11_0", "en", revision="streaming", split="test", streaming=True, token=True, trust_remote_code=True)
-        elif option == "VoxPopuli":
-            self.dataset = load_dataset("facebook/voxpopuli", "en", split="test", streaming=True, trust_remote_code=True)
-        elif option == "TEDLIUM":
-            self.dataset = load_dataset("LIUM/tedlium", "release3", split="test", streaming=True, trust_remote_code=True)
-        elif option == "GigaSpeech":
-            self.dataset = load_dataset("speechcolab/gigaspeech", "xs", split="test", streaming=True, token=True, trust_remote_code=True)
-        elif option == "SPGISpeech":
-            self.dataset = load_dataset("kensho/spgispeech", "S", split="test", streaming=True, token=True, trust_remote_code=True)
-        elif option == "AMI":
-            self.dataset = load_dataset("edinburghcstr/ami", "ihm", split="test", streaming=True, trust_remote_code=True)
-
-        self._preprocess()
model.py DELETED
@@ -1,122 +0,0 @@
-# from transformers import WhisperProcessor, WhisperForConditionalGeneration
-from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor
-from transformers import pipeline
-
-# import nemo.collections.asr as nemo_asr
-
-from dataset import Dataset
-from utils import data
-
-
-
-class Model:
-
-
-    def __init__(self):
-
-        self.options = [
-            "openai/whisper-tiny.en",
-            "facebook/s2t-medium-librispeech-asr",
-            #"nvidia/stt_en_fastconformer_ctc_large"
-        ]
-        self.selected = None
-        self.pipeline = None
-        self.normalize = None
-
-    def get_options(self):
-        return self.options
-
-    def load(self, option:str = None):
-
-        if option is None:
-            if self.selected is None:
-                raise ValueError("No model selected. Please first select a model")
-            option = self.selected
-
-        if option not in self.options:
-            raise ValueError(f"Selected Option is not a valid value, see: {self.options}")
-
-        if option == "openai/whisper-tiny.en":
-            self.pipeline = pipeline("automatic-speech-recognition", model="openai/whisper-tiny.en")
-            self.normalize = self.pipeline.tokenizer.normalize
-
-        elif option == "facebook/s2t-medium-librispeech-asr":
-            self.model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-medium-librispeech-asr")
-            self.processor = Speech2TextProcessor.from_pretrained("facebook/s2t-medium-librispeech-asr", do_upper_case=True)
-
-        # elif option == "nvidia/stt_en_fastconformer_ctc_large":
-        #     self.model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(model_name="nvidia/stt_en_fastconformer_ctc_large")
-
-    def select(self, option:str=None):
-        if option not in self.options:
-            raise ValueError(f"This value is not an option, please see: {self.options}")
-        self.selected = option
-
-    def process(self, dataset:Dataset):
-
-        if self.selected is None:
-            raise ValueError("No Model is yet selected. Please select a model first")
-
-        if self.selected == "openai/whisper-tiny.en":
-            references, predictions = self._process_openai_whisper_tiny_en(dataset)
-        elif self.selected == "facebook/s2t-medium-librispeech-asr":
-            references, predictions = self._process_facebook_s2t_medium(dataset)
-        # elif self.selected == "nvidia/stt_en_fastconformer_ctc_large":
-        #     references, predictions = self._process_facebook_s2t_medium(dataset)
-
-        return references, predictions
-
-    def _process_openai_whisper_tiny_en(self, DaTaSeT:Dataset):
-
-        def normalise(batch):
-            batch["norm_text"] = self.normalize(DaTaSeT._get_text(batch))
-            return batch
-
-        DaTaSeT.normalised(normalise)
-        dataset = DaTaSeT.filter("norm_text")
-
-        predictions = []
-        references = []
-
-        # run streamed inference
-        for out in self.pipeline(data(dataset), batch_size=16):
-            predictions.append(self.normalize(out["text"]))
-            references.append(out["reference"][0])
-
-        return references, predictions
-
-    def _process_facebook_s2t_medium(self, DaTaSeT:Dataset):
-
-        def map_to_pred(batch):
-            features = self.processor(batch["audio"]["array"], sampling_rate=16000, padding=True, return_tensors="pt")
-            input_features = features.input_features
-            attention_mask = features.attention_mask
-
-            gen_tokens = self.model.generate(input_features=input_features, attention_mask=attention_mask)
-            batch["transcription"] = self.processor.batch_decode(gen_tokens, skip_special_tokens=True)[0]
-            return batch
-
-        DaTaSeT.dataset = DaTaSeT.dataset.take(100)
-        result = DaTaSeT.dataset.map(map_to_pred, remove_columns=["audio"])
-
-        predictions = []
-        references = []
-
-        DaTaSeT._check_text()
-        text_column = DaTaSeT.text
-
-        for sample in result:
-            predictions.append(sample['transcription'])
-            references.append(sample[text_column])
-
-        return references, predictions
-
-    def _process_stt_en_fastconformer_ctc_large(self, DaTaSeT:Dataset):
-
-
-        self.model.transcribe(['2086-149220-0033.wav'])
-
-        predictions = []
-        references = []
-
-        return references, predictions
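Aside on the deleted `_process_openai_whisper_tiny_en`: it relied on a `data(dataset)` generator from utils.py (not shown in this commit) and on the transformers pipeline echoing extra keys back on each output, which is why the reference is read as `out["reference"][0]`. A sketch of the shape that generator presumably had, assuming the standard streamed-inference recipe:

```python
from transformers import pipeline

asr = pipeline("automatic-speech-recognition", model="openai/whisper-tiny.en")

def data(dataset):
    # Keys the pipeline does not consume (here "reference") are batched into
    # lists and returned on each output dict, hence out["reference"][0].
    for sample in dataset:
        yield {**sample["audio"], "reference": sample["norm_text"]}

# for out in asr(data(dataset), batch_size=16):
#     prediction = out["text"]
#     reference = out["reference"][0]
```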
processing.py ADDED
@@ -0,0 +1,194 @@
+from transformers import WhisperProcessor, WhisperForConditionalGeneration
+from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor
+import plotly.graph_objs as go
+from datasets import load_dataset
+from datasets import Audio
+from transformers import pipeline
+import evaluate
+import librosa
+import numpy as np
+
+wer_metric = evaluate.load("wer")
+
+def run(data_subset:str, model_1:str, model_2:str, own_audio, own_transcription:str):
+
+    if data_subset is None:
+        raise ValueError("No Dataset selected")
+    if model_1 is None:
+        raise ValueError("No Model 1 selected")
+    if model_2 is None:
+        raise ValueError("No Model 2 selected")
+
+    if data_subset == "Common Voice":
+        dataset, text_column = load_Common_Voice()
+    elif data_subset == "VoxPopuli":
+        dataset, text_column = load_Vox_Populi()
+    elif data_subset == "OWN Recoding/Sample":
+        sr, audio = own_audio
+        audio = audio.astype(np.float32) / 32768.0
+        print("AUDIO: ", type(audio), audio)
+        audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)
+    else:
+        # fall back to Common Voice if no known subset was selected
+        dataset, text_column = load_Common_Voice()
+    print("Dataset Loaded")
+
+    # check if models are the same
+    model1, processor1 = load_model(model_1)
+    model2, processor2 = load_model(model_2)
+    print("Models Loaded")
+
+    if data_subset == "OWN Recoding/Sample":
+        sample = {"audio":{"array":audio,"sampling_rate":16000}}
+        transcription1 = model_compute(model1, processor1, sample, model_1)
+        transcription2 = model_compute(model2, processor2, sample, model_2)
+
+        transcriptions1 = [transcription1]
+        transcriptions2 = [transcription2]
+        references = [own_transcription]
+
+        wer1 = compute_wer(references, transcriptions1)
+        wer2 = compute_wer(references, transcriptions2)
+
+        results_md = f"""#### {model_1}
+- WER Score: {wer1}
+
+#### {model_2}
+- WER Score: {wer2}"""
+
+        # Create the bar plot
+        fig = go.Figure(
+            data=[
+                go.Bar(x=[f"{model_1}"], y=[wer1]),
+                go.Bar(x=[f"{model_2}"], y=[wer2]),
+            ]
+        )
+        # Update the layout for better visualization
+        fig.update_layout(
+            title="Comparison of Two Models",
+            xaxis_title="Models",
+            yaxis_title="Value",
+            barmode="group",
+        )
+
+        yield results_md, fig
+
+    else:
+        references = []
+        transcriptions1 = []
+        transcriptions2 = []
+        counter = 0
+        for sample in dataset:
+            print(counter)
+            counter += 1
+
+            references.append(sample[text_column])
+
+            if model_1 == model_2:
+                transcription = model_compute(model1, processor1, sample, model_1)
+
+                transcriptions1.append(transcription)
+                transcriptions2.append(transcription)
+            else:
+                transcriptions1.append(model_compute(model1, processor1, sample, model_1))
+                transcriptions2.append(model_compute(model2, processor2, sample, model_2))
+
+        wer1 = compute_wer(references, transcriptions1)
+        wer2 = compute_wer(references, transcriptions2)
+
+        results_md = f"""#### {model_1}
+- WER Score: {wer1}
+
+#### {model_2}
+- WER Score: {wer2}"""
+
+        # Create the bar plot
+        fig = go.Figure(
+            data=[
+                go.Bar(x=[f"{model_1}"], y=[wer1]),
+                go.Bar(x=[f"{model_2}"], y=[wer2]),
+            ]
+        )
+
+        # Update the layout for better visualization
+        fig.update_layout(
+            title="Comparison of Two Models",
+            xaxis_title="Models",
+            yaxis_title="Value",
+            barmode="group",
+        )
+
+        yield results_md, fig
+
+
+
+
+# DATASET LOADERS
+def load_Common_Voice():
+    dataset = load_dataset("mozilla-foundation/common_voice_11_0", "en", revision="streaming", split="test", streaming=True, token=True, trust_remote_code=True)
+    text_column = "sentence"
+    dataset = dataset.take(100)
+    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
+    dataset = list(dataset)
+    return dataset, text_column
+
+def load_Vox_Populi():
+    dataset = load_dataset("facebook/voxpopuli", "en", split="test", streaming=True, trust_remote_code=True)
+    print(next(iter(dataset)))
+    text_column = "raw_text"
+    dataset = dataset.take(100)
+    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
+    dataset = list(dataset)
+    return dataset, text_column
+
+
+
+# MODEL LOADERS
+def load_model(model_id:str):
+    if model_id == "openai/whisper-tiny.en":
+        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
+        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
+    elif model_id == "facebook/s2t-medium-librispeech-asr":
+        model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-medium-librispeech-asr")
+        processor = Speech2TextProcessor.from_pretrained("facebook/s2t-medium-librispeech-asr", do_upper_case=True)
+    else:
+        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
+        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
+
+    return model, processor
+
+
+# MODEL INFERENCE
+def model_compute(model, processor, sample, model_id):
+
+    if model_id == "openai/whisper-tiny.en":
+        sample = sample["audio"]
+        input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").input_features
+        predicted_ids = model.generate(input_features)
+        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
+        return transcription[0]
+    elif model_id == "facebook/s2t-medium-librispeech-asr":
+        sample = sample["audio"]
+        features = processor(sample["array"], sampling_rate=16000, padding=True, return_tensors="pt")
+        input_features = features.input_features
+        attention_mask = features.attention_mask
+        gen_tokens = model.generate(input_features=input_features, attention_mask=attention_mask)
+        transcription = processor.batch_decode(gen_tokens, skip_special_tokens=True)[0]
+        return transcription
+
+    else:
+        return model(sample)
+
+# UTILS
+def compute_wer(references, predictions):
+    wer = wer_metric.compute(references=references, predictions=predictions)
+    wer = round(100 * wer, 2)
+    return wer
+
+
+# print(load_Vox_Populi())
+# print(run("Common Voice", "openai/whisper-tiny.en", "openai/whisper-tiny.en", None, None))
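For reference, `compute_wer` wraps the `evaluate` library's WER metric and reports it as a percentage. A quick sanity check of those semantics, with toy strings worked out by hand:

```python
import evaluate

wer_metric = evaluate.load("wer")
references = ["the cat sat on the mat"]
predictions = ["the cat sat on a mat"]

# One substitution against six reference words: WER = 1/6,
# which compute_wer would report as 16.67 after rounding.
wer = wer_metric.compute(references=references, predictions=predictions)
print(round(100 * wer, 2))  # 16.67
```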
utils.py CHANGED
@@ -22,3 +22,14 @@ def compute_wer(references, predictions):
     return wer
 
 
+# def run_tests (dataset_choice:str, model:str):
+
+#     MoDeL = Model()
+#     MoDeL.select(model)
+#     MoDeL.load()
+#     DaTaSeT = Dataset(100)
+#     DaTaSeT.load(dataset_choice)
+#     references, predictions = MoDeL.process(DaTaSeT)
+#     wer = compute_wer(references=references, predictions=predictions)
+#     return wer
+