j-tobias
committed on
Commit
·
09b2769
1
Parent(s):
e3bf44e
updated backend
Browse files- .codetogether.ignore +0 -1
- __pycache__/dataset.cpython-310.pyc +0 -0
- __pycache__/model.cpython-310.pyc +0 -0
- __pycache__/processing.cpython-310.pyc +0 -0
- __pycache__/utils.cpython-310.pyc +0 -0
- app.py +42 -63
- dataset.py +0 -93
- model.py +0 -122
- processing.py +194 -0
- utils.py +11 -0
.codetogether.ignore
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
credentials.json
|
|
|
|
__pycache__/dataset.cpython-310.pyc
ADDED
Binary file (3.34 kB). View file
|
|
__pycache__/model.cpython-310.pyc
ADDED
Binary file (3.74 kB). View file
|
|
__pycache__/processing.cpython-310.pyc
ADDED
Binary file (4.24 kB). View file
|
|
__pycache__/utils.cpython-310.pyc
ADDED
Binary file (1.04 kB). View file
|
|
app.py
CHANGED
@@ -1,9 +1,5 @@
|
|
1 |
import gradio as gr
|
2 |
-
from
|
3 |
-
from model import Model
|
4 |
-
from utils import compute_wer
|
5 |
-
import plotly.graph_objs as go
|
6 |
-
|
7 |
|
8 |
# from utils import hf_login
|
9 |
# hf_login()
|
@@ -14,73 +10,58 @@ import os
|
|
14 |
hf_token = os.getenv("HF_Token")
|
15 |
login(hf_token)
|
16 |
|
17 |
-
dataset = Dataset()
|
18 |
-
models = Model()
|
19 |
|
20 |
-
|
|
|
|
|
21 |
|
22 |
-
|
23 |
-
MoDeL.select(model)
|
24 |
-
MoDeL.load()
|
25 |
-
DaTaSeT = Dataset(100)
|
26 |
-
DaTaSeT.load(dataset_choice)
|
27 |
-
references, predictions = MoDeL.process(DaTaSeT)
|
28 |
-
wer = compute_wer(references=references, predictions=predictions)
|
29 |
-
return wer
|
30 |
|
31 |
-
|
32 |
|
33 |
-
|
34 |
-
wer_result_2 = run_tests(data_subset, model_2)
|
35 |
|
36 |
-
|
37 |
-
|
38 |
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
|
60 |
def get_card(selected_model:str)->str:
|
61 |
|
62 |
-
print("Selected Model for Card: ", selected_model)
|
63 |
with open("cards.txt", "r") as f:
|
64 |
cards = f.read()
|
65 |
|
66 |
-
print(cards)
|
67 |
-
|
68 |
cards = cards.split("@@")
|
69 |
for card in cards:
|
70 |
-
print("CARD: ", card)
|
71 |
if "ID: "+selected_model in card:
|
72 |
return card
|
73 |
|
74 |
return "Unknown Model"
|
75 |
|
76 |
-
def is_own(
|
77 |
-
if
|
78 |
-
|
79 |
-
|
80 |
-
return
|
81 |
-
own_audio = None
|
82 |
-
own_transcription = None
|
83 |
-
return own_audio, own_transcription
|
84 |
|
85 |
with gr.Blocks() as demo:
|
86 |
|
@@ -106,31 +87,29 @@ Happy experimenting and comparing! 🚀""")
|
|
106 |
pass
|
107 |
with gr.Column(scale=5):
|
108 |
data_subset = gr.Radio(
|
109 |
-
value="
|
110 |
-
choices=
|
111 |
label="Data subset / Own Sample",
|
112 |
)
|
|
|
|
|
|
|
113 |
with gr.Column(scale=1):
|
114 |
pass
|
115 |
|
116 |
-
with gr.Row():
|
117 |
-
own_audio = gr.Audio(sources=['microphone'],streaming=False,visible=False)
|
118 |
-
own_transcription = gr.TextArea(lines=2, visible=False)
|
119 |
-
data_subset.change(is_own, inputs=[data_subset], outputs=[own_audio, own_transcription])
|
120 |
-
|
121 |
|
122 |
with gr.Row():
|
123 |
|
124 |
with gr.Column(scale=1):
|
125 |
model_1 = gr.Dropdown(
|
126 |
-
choices=
|
127 |
label="Select Model"
|
128 |
)
|
129 |
model_1_card = gr.Markdown("")
|
130 |
|
131 |
with gr.Column(scale=1):
|
132 |
model_2 = gr.Dropdown(
|
133 |
-
choices=
|
134 |
label="Select Model"
|
135 |
)
|
136 |
model_2_card = gr.Markdown("")
|
@@ -148,6 +127,6 @@ Happy experimenting and comparing! 🚀""")
|
|
148 |
gr.Markdown('## <p style="text-align: center;">Results</p>')
|
149 |
results_md = gr.Markdown("")
|
150 |
results_plot = gr.Plot(show_label=False)
|
151 |
-
eval_btn.click(
|
152 |
|
153 |
demo.launch(debug=True)
|
|
|
1 |
import gradio as gr
|
2 |
+
from processing import run
|
|
|
|
|
|
|
|
|
3 |
|
4 |
# from utils import hf_login
|
5 |
# hf_login()
|
|
|
10 |
hf_token = os.getenv("HF_Token")
|
11 |
login(hf_token)
|
12 |
|
|
|
|
|
13 |
|
14 |
+
MODEL_OPTIONS = ["openai/whisper-tiny.en", "facebook/s2t-medium-librispeech-asr"]
|
15 |
+
DATASET_OPTIONS = ["Common Voice", "VoxPopuli", "OWN Recoding/Sample"]
|
16 |
+
|
17 |
|
18 |
+
# def eval(data_subset:str, model_1:str, model_2:str, own_audio, own_transcription:str)->str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
+
# print("OWN AUDIO: ", type(own_audio), own_audio)
|
21 |
|
22 |
+
# wer_result_1, wer_result_2, references, transcriptions1, transcriptions2 = run(data_subset, model_1, model_2, own_audio, own_transcription)
|
|
|
23 |
|
24 |
+
# results_md = f"""#### {model_1}
|
25 |
+
# - WER Score: {wer_result_1}
|
26 |
|
27 |
+
# #### {model_2}
|
28 |
+
# - WER Score: {wer_result_2}"""
|
29 |
+
|
30 |
+
# # Create the bar plot
|
31 |
+
# fig = go.Figure(
|
32 |
+
# data=[
|
33 |
+
# go.Bar(x=[f"{model_1}"], y=[wer_result_1]),
|
34 |
+
# go.Bar(x=[f"{model_2}"], y=[wer_result_2]),
|
35 |
+
# ]
|
36 |
+
# )
|
37 |
+
|
38 |
+
# # Update the layout for better visualization
|
39 |
+
# fig.update_layout(
|
40 |
+
# title="Comparison of Two Models",
|
41 |
+
# xaxis_title="Models",
|
42 |
+
# yaxis_title="Value",
|
43 |
+
# barmode="group",
|
44 |
+
# )
|
45 |
+
|
46 |
+
# return results_md, fig
|
47 |
|
48 |
def get_card(selected_model:str)->str:
    """Return the model card text for ``selected_model``.

    ``cards.txt`` (read from the current working directory) holds all cards,
    separated by ``@@``. The first card containing ``"ID: <selected_model>"``
    is returned.

    Args:
        selected_model: Model identifier as shown in the model dropdown.

    Returns:
        The matching card text, or ``"Unknown Model"`` when nothing is
        selected or no card mentions the identifier.
    """
    # A cleared dropdown yields None (which would break the string
    # concatenation below) and "" would match the first card with any ID.
    if not selected_model:
        return "Unknown Model"

    with open("cards.txt", "r") as f:
        cards = f.read().split("@@")

    marker = "ID: " + selected_model
    for card in cards:
        if marker in card:
            return card

    return "Unknown Model"
|
59 |
|
60 |
+
def is_own(selected_option):
    """Toggle visibility of the two "own sample" inputs.

    Returns a pair of Gradio updates (one per output component) that are
    visible only when the "OWN Recoding/Sample" option is selected.
    """
    show = selected_option == "OWN Recoding/Sample"
    if show:
        return gr.update(visible=True), gr.update(visible=True)
    return gr.update(visible=False), gr.update(visible=False)
|
|
|
|
|
|
|
65 |
|
66 |
with gr.Blocks() as demo:
|
67 |
|
|
|
87 |
pass
|
88 |
with gr.Column(scale=5):
|
89 |
data_subset = gr.Radio(
|
90 |
+
value="Common Voice",
|
91 |
+
choices=DATASET_OPTIONS,
|
92 |
label="Data subset / Own Sample",
|
93 |
)
|
94 |
+
own_audio = gr.Audio(visible=False)
|
95 |
+
own_transcription = gr.TextArea(lines=2, visible=False)
|
96 |
+
data_subset.change(is_own, inputs=[data_subset], outputs=[own_audio, own_transcription])
|
97 |
with gr.Column(scale=1):
|
98 |
pass
|
99 |
|
|
|
|
|
|
|
|
|
|
|
100 |
|
101 |
with gr.Row():
|
102 |
|
103 |
with gr.Column(scale=1):
|
104 |
model_1 = gr.Dropdown(
|
105 |
+
choices=MODEL_OPTIONS,
|
106 |
label="Select Model"
|
107 |
)
|
108 |
model_1_card = gr.Markdown("")
|
109 |
|
110 |
with gr.Column(scale=1):
|
111 |
model_2 = gr.Dropdown(
|
112 |
+
choices=MODEL_OPTIONS,
|
113 |
label="Select Model"
|
114 |
)
|
115 |
model_2_card = gr.Markdown("")
|
|
|
127 |
gr.Markdown('## <p style="text-align: center;">Results</p>')
|
128 |
results_md = gr.Markdown("")
|
129 |
results_plot = gr.Plot(show_label=False)
|
130 |
+
eval_btn.click(run, [data_subset, model_1, model_2, own_audio, own_transcription], [results_md, results_plot])
|
131 |
|
132 |
demo.launch(debug=True)
|
dataset.py
DELETED
@@ -1,93 +0,0 @@
|
|
1 |
-
from datasets import load_dataset
|
2 |
-
from datasets import Audio
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
class Dataset:
    """Wrapper around a streamed Hugging Face speech dataset.

    Only the first ``n`` samples of the selected test split are used, and the
    ``audio`` column is cast to 16 kHz.

    Attributes are plain instance fields:
        n: number of samples taken from the stream.
        options: names accepted by :meth:`load`.
        selected: the option chosen by the last :meth:`load` call.
        dataset: the streamed dataset (``None`` until loaded / for "OWN").
        text: name of the transcript column once it has been detected.
    """

    # Candidate transcript columns, checked in this order.
    _TEXT_COLUMNS = ("text", "sentence", "normalized_text", "transcript")

    # option -> (positional args, keyword args) passed to ``load_dataset``;
    # ``streaming=True`` is added for every entry in ``load``.
    _SOURCES = {
        "LibriSpeech Clean": (("librispeech_asr", "all"), {"split": "test.clean"}),
        "LibriSpeech Other": (("librispeech_asr", "all"), {"split": "test.other"}),
        "Common Voice": (
            ("mozilla-foundation/common_voice_11_0", "en"),
            {"revision": "streaming", "split": "test", "token": True, "trust_remote_code": True},
        ),
        "VoxPopuli": (("facebook/voxpopuli", "en"), {"split": "test", "trust_remote_code": True}),
        "TEDLIUM": (("LIUM/tedlium", "release3"), {"split": "test", "trust_remote_code": True}),
        "GigaSpeech": (
            ("speechcolab/gigaspeech", "xs"),
            {"split": "test", "token": True, "trust_remote_code": True},
        ),
        "SPGISpeech": (
            ("kensho/spgispeech", "S"),
            {"split": "test", "token": True, "trust_remote_code": True},
        ),
        "AMI": (("edinburghcstr/ami", "ihm"), {"split": "test", "trust_remote_code": True}),
    }

    def __init__(self, n:int = 100):
        self.n = n
        self.options = ['LibriSpeech Clean', 'LibriSpeech Other', 'Common Voice', 'VoxPopuli', 'TEDLIUM', 'GigaSpeech', 'SPGISpeech', 'AMI', 'OWN']
        self.selected = None
        self.dataset = None
        self.text = None

    def get_options(self):
        """Return the list of selectable dataset names."""
        return self.options

    def _check_text(self):
        """Detect the transcript column from the first sample.

        Returns:
            The detected column name (also stored in ``self.text``).

        Raises:
            ValueError: if the sample has none of the known transcript columns.
        """
        sample = next(iter(self.dataset))
        self._get_text(sample)
        # Returning the name lets ``filter`` use the result directly; the
        # previous implicit ``None`` ended up as the filter's input column.
        return self.text

    def _get_text(self, sample):
        """Return the transcript of ``sample`` and remember which column held it.

        Raises:
            ValueError: if no known transcript column is present.
        """
        for column in self._TEXT_COLUMNS:
            if column in sample:
                self.text = column
                return sample[column]
        raise ValueError(f"Sample: {sample.keys()} has no transcript.")

    def filter(self, input_column:str = None):
        """Drop samples whose transcript is blank or a scoring-ignore marker.

        Args:
            input_column: column holding the transcript. When omitted, the
                previously detected column is used, detecting it first if
                necessary.

        Returns:
            The filtered dataset (also stored in ``self.dataset``).
        """
        if input_column is None:
            input_column = self.text if self.text is not None else self._check_text()

        def is_target_text_in_range(ref):
            text = ref.strip()
            return text != "" and text != "ignore time segment in scoring"

        self.dataset = self.dataset.filter(is_target_text_in_range, input_columns=[input_column])
        return self.dataset

    def normalised(self, normalise):
        """Apply ``normalise`` to every sample via ``dataset.map``."""
        self.dataset = self.dataset.map(normalise)

    def _select(self, option:str):
        """Validate ``option`` and store it as the current selection."""
        if option not in self.options:
            raise ValueError(f"This value is not an option, please see: {self.options}")
        self.selected = option

    def _preprocess(self):
        """Keep the first ``n`` samples and cast the audio column to 16 kHz."""
        self.dataset = self.dataset.take(self.n)
        self.dataset = self.dataset.cast_column("audio", Audio(sampling_rate=16000))

    def load(self, option:str = None):
        """Select ``option`` and stream the corresponding dataset.

        Raises:
            ValueError: if ``option`` is not one of ``self.options``.
        """
        self._select(option)

        if option == "OWN":
            # User-supplied audio: there is nothing to download, and
            # preprocessing a missing dataset would fail on ``None.take``.
            return

        args, kwargs = self._SOURCES[option]
        self.dataset = load_dataset(*args, streaming=True, **kwargs)
        self._preprocess()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model.py
DELETED
@@ -1,122 +0,0 @@
|
|
1 |
-
# from transformers import WhisperProcessor, WhisperForConditionalGeneration
|
2 |
-
from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor
|
3 |
-
from transformers import pipeline
|
4 |
-
|
5 |
-
# import nemo.collections.asr as nemo_asr
|
6 |
-
|
7 |
-
from dataset import Dataset
|
8 |
-
from utils import data
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
class Model:
    """Selects, loads and runs one of the supported speech-recognition models.

    Typical use: ``select(option)`` -> ``load()`` -> ``process(dataset)``.

    Instance fields:
        options: identifiers of the supported models.
        selected: identifier chosen via :meth:`select`.
        pipeline / normalize: set when the whisper model is loaded.
        model / processor: set when the s2t model is loaded.
    """

    def __init__(self):
        self.options = [
            "openai/whisper-tiny.en",
            "facebook/s2t-medium-librispeech-asr",
            #"nvidia/stt_en_fastconformer_ctc_large"
        ]
        self.selected = None
        self.pipeline = None
        self.normalize = None
        # Initialised here so that an unloaded model can be detected instead
        # of failing with an AttributeError.
        self.model = None
        self.processor = None

    def get_options(self):
        """Return the list of selectable model identifiers."""
        return self.options

    def load(self, option:str = None):
        """Load the weights for ``option`` (defaults to the selected model).

        Raises:
            ValueError: if no option is given and none was selected, or the
                option is not supported.
        """
        if option is None:
            if self.selected is None:
                raise ValueError("No model selected. Please first select a model")
            option = self.selected

        if option not in self.options:
            raise ValueError(f"Selected Option is not a valid value, see: {self.options}")

        if option == "openai/whisper-tiny.en":
            self.pipeline = pipeline("automatic-speech-recognition", model="openai/whisper-tiny.en")
            self.normalize = self.pipeline.tokenizer.normalize

        elif option == "facebook/s2t-medium-librispeech-asr":
            self.model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-medium-librispeech-asr")
            self.processor = Speech2TextProcessor.from_pretrained("facebook/s2t-medium-librispeech-asr", do_upper_case=True)

        # elif option == "nvidia/stt_en_fastconformer_ctc_large":
        #     self.model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(model_name="nvidia/stt_en_fastconformer_ctc_large")

    def select(self, option:str=None):
        """Store ``option`` as the current model.

        Raises:
            ValueError: if ``option`` is not supported.
        """
        if option not in self.options:
            raise ValueError(f"This value is not an option, please see: {self.options}")
        self.selected = option

    # NOTE: ``Dataset`` annotations are quoted so they are not evaluated when
    # the class body is executed.
    def process(self, dataset:"Dataset"):
        """Transcribe ``dataset`` with the selected model.

        Returns:
            ``(references, predictions)`` as two lists.

        Raises:
            ValueError: if no model is selected or the selected model has not
                been loaded yet.
        """
        if self.selected is None:
            raise ValueError("No Model is yet selected. Please select a model first")

        if self.selected == "openai/whisper-tiny.en":
            if self.pipeline is None:
                raise ValueError("Selected model is not loaded. Please call load() first")
            return self._process_openai_whisper_tiny_en(dataset)

        if self.selected == "facebook/s2t-medium-librispeech-asr":
            if self.model is None or self.processor is None:
                raise ValueError("Selected model is not loaded. Please call load() first")
            return self._process_facebook_s2t_medium(dataset)

        # elif self.selected == "nvidia/stt_en_fastconformer_ctc_large":
        #     references, predictions = self._process_facebook_s2t_medium(dataset)

        raise ValueError(f"Selected Option is not a valid value, see: {self.options}")

    def _process_openai_whisper_tiny_en(self, DaTaSeT:"Dataset"):
        """Run the whisper pipeline over the dataset with normalised text."""

        def normalise(batch):
            batch["norm_text"] = self.normalize(DaTaSeT._get_text(batch))
            return batch

        DaTaSeT.normalised(normalise)
        dataset = DaTaSeT.filter("norm_text")

        predictions = []
        references = []

        # run streamed inference
        for out in self.pipeline(data(dataset), batch_size=16):
            predictions.append(self.normalize(out["text"]))
            references.append(out["reference"][0])

        return references, predictions

    def _process_facebook_s2t_medium(self, DaTaSeT:"Dataset"):
        """Transcribe each sample with the s2t model, one sample at a time."""

        def map_to_pred(batch):
            features = self.processor(batch["audio"]["array"], sampling_rate=16000, padding=True, return_tensors="pt")
            gen_tokens = self.model.generate(
                input_features=features.input_features,
                attention_mask=features.attention_mask,
            )
            batch["transcription"] = self.processor.batch_decode(gen_tokens, skip_special_tokens=True)[0]
            return batch

        DaTaSeT.dataset = DaTaSeT.dataset.take(100)
        result = DaTaSeT.dataset.map(map_to_pred, remove_columns=["audio"])

        # Detect which column carries the reference transcript.
        DaTaSeT._check_text()
        text_column = DaTaSeT.text

        predictions = []
        references = []
        for sample in result:
            predictions.append(sample['transcription'])
            references.append(sample[text_column])

        return references, predictions

    def _process_stt_en_fastconformer_ctc_large(self, DaTaSeT:"Dataset"):
        """Unfinished placeholder for the NeMo fastconformer model.

        NOTE(review): transcribes a hard-coded file and returns empty lists;
        it is not reachable from ``process``.
        """
        self.model.transcribe(['2086-149220-0033.wav'])

        predictions = []
        references = []

        return references, predictions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
processing.py
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import WhisperProcessor, WhisperForConditionalGeneration
|
2 |
+
from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor
|
3 |
+
import plotly.graph_objs as go
|
4 |
+
from datasets import load_dataset
|
5 |
+
from datasets import Audio
|
6 |
+
from transformers import pipeline
|
7 |
+
import evaluate
|
8 |
+
import librosa
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
wer_metric = evaluate.load("wer")
|
12 |
+
|
13 |
+
def _prepare_own_audio(own_audio):
    """Turn a ``(sample_rate, samples)`` recording into 16 kHz mono float32."""
    if own_audio is None:
        raise ValueError("No own audio recorded or uploaded")

    sr, audio = own_audio
    audio = np.asarray(audio)

    if np.issubdtype(audio.dtype, np.integer):
        # Scale by the range of the actual integer width (32768 for int16).
        audio = audio.astype(np.float32) / (float(np.iinfo(audio.dtype).max) + 1.0)
    else:
        audio = audio.astype(np.float32)

    if audio.ndim > 1:
        # assumes multi-channel input is shaped (samples, channels) — TODO confirm
        # against the Gradio Audio component; resampling must run along time.
        audio = audio.mean(axis=1)

    return librosa.resample(audio, orig_sr=sr, target_sr=16000)


def _build_results(model_1, model_2, wer1, wer2):
    """Build the markdown summary and the bar chart for two WER scores."""
    results_md = (
        f"#### {model_1}\n"
        f"- WER Score: {wer1}\n"
        "\n"
        f"#### {model_2}\n"
        f"- WER Score: {wer2}"
    )

    # Create the bar plot
    fig = go.Figure(
        data=[
            go.Bar(x=[f"{model_1}"], y=[wer1]),
            go.Bar(x=[f"{model_2}"], y=[wer2]),
        ]
    )
    # Update the layout for better visualization
    fig.update_layout(
        title="Comparison of Two Models",
        xaxis_title="Models",
        yaxis_title="Value",
        barmode="group",
    )
    return results_md, fig


def run(data_subset:str, model_1:str, model_2:str, own_audio, own_transcription:str):
    """Compare two ASR models on a dataset subset or on the user's own sample.

    This is a generator: it yields a single ``(results_md, fig)`` pair, and
    validation errors are raised when it is first advanced.

    Args:
        data_subset: "Common Voice", "VoxPopuli" or "OWN Recoding/Sample";
            any other non-None value falls back to Common Voice.
        model_1: identifier of the first model.
        model_2: identifier of the second model.
        own_audio: ``(sample_rate, samples)`` tuple; only used for the own
            sample option.
        own_transcription: reference text for the own sample.

    Yields:
        A markdown summary of both WER scores and a plotly bar chart.

    Raises:
        ValueError: if a selection is missing, or the own-sample option is
            chosen without audio or without a reference transcription.
    """
    if data_subset is None:
        raise ValueError("No Dataset selected")
    if model_1 is None:
        raise ValueError("No Model 1 selected")
    if model_2 is None:
        raise ValueError("No Model 2 selected")

    if data_subset == "OWN Recoding/Sample":
        # Validate before any model is downloaded or loaded.
        if not own_transcription or not own_transcription.strip():
            raise ValueError("No transcription provided for own audio")
        audio = _prepare_own_audio(own_audio)
        samples = [{"audio": {"array": audio, "sampling_rate": 16000}}]
        references = [own_transcription]
        text_column = None
    else:
        if data_subset == "VoxPopuli":
            samples, text_column = load_Vox_Populi()
        else:
            # "Common Voice" and any unrecognised subset
            samples, text_column = load_Common_Voice()
        references = []
    print("Dataset Loaded")

    # Load each distinct model only once.
    same_model = model_1 == model_2
    model1, processor1 = load_model(model_1)
    if same_model:
        model2, processor2 = model1, processor1
    else:
        model2, processor2 = load_model(model_2)
    print("Models Loaded")

    transcriptions1 = []
    transcriptions2 = []
    for sample in samples:
        if text_column is not None:
            references.append(sample[text_column])

        first = model_compute(model1, processor1, sample, model_1)
        transcriptions1.append(first)
        if same_model:
            transcriptions2.append(first)
        else:
            transcriptions2.append(model_compute(model2, processor2, sample, model_2))

    wer1 = compute_wer(references, transcriptions1)
    wer2 = compute_wer(references, transcriptions2)

    yield _build_results(model_1, model_2, wer1, wer2)
|
123 |
+
|
124 |
+
|
125 |
+
|
126 |
+
|
127 |
+
|
128 |
+
|
129 |
+
# DATASET LOADERS
|
130 |
+
def load_Common_Voice(n:int = 100):
    """Load the first ``n`` English Common Voice 11.0 test samples.

    The split is streamed, truncated to ``n`` samples, cast to 16 kHz audio
    and materialised into a list.

    Args:
        n: number of samples to take from the stream (default 100).

    Returns:
        ``(samples, text_column)`` where ``text_column`` is ``"sentence"``.
    """
    text_column = "sentence"
    dataset = load_dataset(
        "mozilla-foundation/common_voice_11_0",
        "en",
        revision="streaming",
        split="test",
        streaming=True,
        token=True,
        trust_remote_code=True,
    )
    dataset = dataset.take(n)
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
    return list(dataset), text_column
|
137 |
+
|
138 |
+
def load_Vox_Populi(n:int = 100):
    """Load the first ``n`` English VoxPopuli test samples.

    The split is streamed, truncated to ``n`` samples, cast to 16 kHz audio
    and materialised into a list.

    Args:
        n: number of samples to take from the stream (default 100).

    Returns:
        ``(samples, text_column)`` where ``text_column`` is ``"raw_text"``.
    """
    text_column = "raw_text"
    dataset = load_dataset(
        "facebook/voxpopuli",
        "en",
        split="test",
        streaming=True,
        trust_remote_code=True,
    )
    # The former debug print of the first sample was dropped: it pulled an
    # extra sample from the stream and dumped its raw audio to the log.
    dataset = dataset.take(n)
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
    return list(dataset), text_column
|
146 |
+
|
147 |
+
|
148 |
+
|
149 |
+
|
150 |
+
# MODEL LOADERS
|
151 |
+
def load_model(model_id:str):
    """Return ``(model, processor)`` for ``model_id``.

    "facebook/s2t-medium-librispeech-asr" loads the Speech2Text model; every
    other identifier (including "openai/whisper-tiny.en") loads whisper-tiny.en.
    """
    if model_id == "facebook/s2t-medium-librispeech-asr":
        s2t_model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-medium-librispeech-asr")
        s2t_processor = Speech2TextProcessor.from_pretrained("facebook/s2t-medium-librispeech-asr", do_upper_case=True)
        return s2t_model, s2t_processor

    # whisper is both the explicit choice and the fallback for unknown ids
    whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
    whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
    return whisper_model, whisper_processor
|
163 |
+
|
164 |
+
|
165 |
+
# MODEL INFERENCE
|
166 |
+
def model_compute(model, processor, sample, model_id):
    """Transcribe one sample with the given model.

    Args:
        model: loaded model; for unknown ``model_id`` values it is called
            directly with ``sample``.
        processor: processor matching ``model`` (feature extraction + decoding).
        sample: dict with an ``"audio"`` entry holding ``"array"`` and
            ``"sampling_rate"``.
        model_id: identifier that selects the inference path.

    Returns:
        The transcription string for the two known models, otherwise whatever
        ``model(sample)`` returns.
    """
    if model_id == "openai/whisper-tiny.en":
        audio = sample["audio"]
        input_features = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt").input_features
        predicted_ids = model.generate(input_features)
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
        return transcription[0]

    if model_id == "facebook/s2t-medium-librispeech-asr":
        audio = sample["audio"]
        features = processor(audio["array"], sampling_rate=16000, padding=True, return_tensors="pt")
        gen_tokens = model.generate(
            input_features=features.input_features,
            attention_mask=features.attention_mask,
        )
        # batch_decode returns one string per batch item; take the single
        # item once. Indexing a second time returned only its first character.
        return processor.batch_decode(gen_tokens, skip_special_tokens=True)[0]

    return model(sample)
|
185 |
+
|
186 |
+
# UTILS
|
187 |
+
def compute_wer(references, predictions):
    """Return the word error rate as a percentage rounded to two decimals."""
    raw_score = wer_metric.compute(references=references, predictions=predictions)
    return round(100 * raw_score, 2)
|
191 |
+
|
192 |
+
|
193 |
+
# print(load_Vox_Populi())
|
194 |
+
# print(run("Common Voice", "openai/whisper-tiny.en", "openai/whisper-tiny.en", None, None))
|
utils.py
CHANGED
@@ -22,3 +22,14 @@ def compute_wer(references, predictions):
|
|
22 |
return wer
|
23 |
|
24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
return wer
|
23 |
|
24 |
|
25 |
+
# def run_tests (dataset_choice:str, model:str):
|
26 |
+
|
27 |
+
# MoDeL = Model()
|
28 |
+
# MoDeL.select(model)
|
29 |
+
# MoDeL.load()
|
30 |
+
# DaTaSeT = Dataset(100)
|
31 |
+
# DaTaSeT.load(dataset_choice)
|
32 |
+
# references, predictions = MoDeL.process(DaTaSeT)
|
33 |
+
# wer = compute_wer(references=references, predictions=predictions)
|
34 |
+
# return wer
|
35 |
+
|