j-tobias committed
Commit 09b2769 · 1 Parent(s): e3bf44e

updated backend

.codetogether.ignore DELETED
@@ -1 +0,0 @@
-credentials.json
 
 
__pycache__/dataset.cpython-310.pyc ADDED
Binary file (3.34 kB)

__pycache__/model.cpython-310.pyc ADDED
Binary file (3.74 kB)

__pycache__/processing.cpython-310.pyc ADDED
Binary file (4.24 kB)

__pycache__/utils.cpython-310.pyc ADDED
Binary file (1.04 kB)
app.py CHANGED
@@ -1,9 +1,5 @@
 import gradio as gr
-from dataset import Dataset
-from model import Model
-from utils import compute_wer
-import plotly.graph_objs as go
-
+from processing import run
 
 # from utils import hf_login
 # hf_login()
@@ -14,73 +10,58 @@ import os
 hf_token = os.getenv("HF_Token")
 login(hf_token)
 
-dataset = Dataset()
-models = Model()
-
-def run_tests (dataset_choice:str, model:str):
-
-    MoDeL = Model()
-    MoDeL.select(model)
-    MoDeL.load()
-    DaTaSeT = Dataset(100)
-    DaTaSeT.load(dataset_choice)
-    references, predictions = MoDeL.process(DaTaSeT)
-    wer = compute_wer(references=references, predictions=predictions)
-    return wer
-
-def eval(data_subset:str, model_1:str, model_2:str)->str:
-
-    wer_result_1 = run_tests(data_subset, model_1)
-    wer_result_2 = run_tests(data_subset, model_2)
-
-    results_md = f"""#### {model_1}
-    - WER Score: {wer_result_1}
-
-    #### {model_2}
-    - WER Score: {wer_result_2}"""
-
-    # Create the bar plot
-    fig = go.Figure(
-        data=[
-            go.Bar(x=[f"{model_1}"], y=[wer_result_1]),
-            go.Bar(x=[f"{model_2}"], y=[wer_result_2]),
-        ]
-    )
-
-    # Update the layout for better visualization
-    fig.update_layout(
-        title="Comparison of Two Models",
-        xaxis_title="Models",
-        yaxis_title="Value",
-        barmode="group",
-    )
-
-    return results_md, fig
+MODEL_OPTIONS = ["openai/whisper-tiny.en", "facebook/s2t-medium-librispeech-asr"]
+DATASET_OPTIONS = ["Common Voice", "VoxPopuli", "OWN Recording/Sample"]
+
+# def eval(data_subset:str, model_1:str, model_2:str, own_audio, own_transcription:str)->str:
+
+#     print("OWN AUDIO: ", type(own_audio), own_audio)
+
+#     wer_result_1, wer_result_2, references, transcriptions1, transcriptions2 = run(data_subset, model_1, model_2, own_audio, own_transcription)
+
+#     results_md = f"""#### {model_1}
+#     - WER Score: {wer_result_1}
+
+#     #### {model_2}
+#     - WER Score: {wer_result_2}"""
+
+#     # Create the bar plot
+#     fig = go.Figure(
+#         data=[
+#             go.Bar(x=[f"{model_1}"], y=[wer_result_1]),
+#             go.Bar(x=[f"{model_2}"], y=[wer_result_2]),
+#         ]
+#     )
+
+#     # Update the layout for better visualization
+#     fig.update_layout(
+#         title="Comparison of Two Models",
+#         xaxis_title="Models",
+#         yaxis_title="Value",
+#         barmode="group",
+#     )
+
+#     return results_md, fig
 
 def get_card(selected_model:str)->str:
 
-    print("Selected Model for Card: ", selected_model)
     with open("cards.txt", "r") as f:
         cards = f.read()
 
-    print(cards)
-
     cards = cards.split("@@")
     for card in cards:
-        print("CARD: ", card)
         if "ID: "+selected_model in card:
             return card
 
     return "Unknown Model"
 
-def is_own(data_subset:str):
-    if data_subset == "own":
-        own_audio = gr.Audio(sources=['microphone'],streaming=False)
-        own_transcription = gr.TextArea(lines=2)
-        return own_audio, own_transcription
-    own_audio = None
-    own_transcription = None
-    return own_audio, own_transcription
+def is_own(selected_option):
+    if selected_option == "OWN Recording/Sample":
+        return gr.update(visible=True), gr.update(visible=True)
+    else:
+        return gr.update(visible=False), gr.update(visible=False)
 
 with gr.Blocks() as demo:
 
@@ -106,31 +87,29 @@ Happy experimenting and comparing! 🚀""")
             pass
         with gr.Column(scale=5):
             data_subset = gr.Radio(
-                value="LibriSpeech Clean",
-                choices=dataset.get_options(),
+                value="Common Voice",
+                choices=DATASET_OPTIONS,
                 label="Data subset / Own Sample",
             )
+            own_audio = gr.Audio(visible=False)
+            own_transcription = gr.TextArea(lines=2, visible=False)
+            data_subset.change(is_own, inputs=[data_subset], outputs=[own_audio, own_transcription])
        with gr.Column(scale=1):
            pass
 
-    with gr.Row():
-        own_audio = gr.Audio(sources=['microphone'],streaming=False,visible=False)
-        own_transcription = gr.TextArea(lines=2, visible=False)
-        data_subset.change(is_own, inputs=[data_subset], outputs=[own_audio, own_transcription])
-
 
     with gr.Row():
 
        with gr.Column(scale=1):
            model_1 = gr.Dropdown(
-                choices=models.get_options(),
+                choices=MODEL_OPTIONS,
                label="Select Model"
            )
            model_1_card = gr.Markdown("")
 
        with gr.Column(scale=1):
            model_2 = gr.Dropdown(
-                choices=models.get_options(),
+                choices=MODEL_OPTIONS,
                label="Select Model"
            )
            model_2_card = gr.Markdown("")
@@ -148,6 +127,6 @@ Happy experimenting and comparing! 🚀""")
     gr.Markdown('## <p style="text-align: center;">Results</p>')
     results_md = gr.Markdown("")
     results_plot = gr.Plot(show_label=False)
-    eval_btn.click(eval, [data_subset, model_1, model_2], [results_md, results_plot])
+    eval_btn.click(run, [data_subset, model_1, model_2, own_audio, own_transcription], [results_md, results_plot])
 
 demo.launch(debug=True)
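
The new `is_own` handler uses Gradio's update pattern: an event callback returns `gr.update(...)` objects that are applied to the components listed in `outputs`, here toggling the visibility of the microphone widgets. A minimal sketch of the same wiring, with illustrative component names:

```python
import gradio as gr

def toggle(choice):
    # Show the recording widgets only for the "own sample" option.
    show = choice == "OWN Recording/Sample"
    return gr.update(visible=show), gr.update(visible=show)

with gr.Blocks() as sketch:
    choice = gr.Radio(
        ["Common Voice", "VoxPopuli", "OWN Recording/Sample"],
        value="Common Voice", label="Data subset / Own Sample",
    )
    audio = gr.Audio(visible=False)
    transcription = gr.TextArea(lines=2, visible=False)
    choice.change(toggle, inputs=[choice], outputs=[audio, transcription])

# sketch.launch()
```

Because the new `run` is a generator (it yields `results_md, fig`), wiring it into `eval_btn.click` lets Gradio stream the results into the Markdown and Plot components as soon as they are produced.
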
dataset.py DELETED
@@ -1,93 +0,0 @@
-from datasets import load_dataset
-from datasets import Audio
-
-
-
-class Dataset:
-
-    def __init__(self, n:int = 100):
-
-        self.n = n
-        self.options = ['LibriSpeech Clean', 'LibriSpeech Other', 'Common Voice', 'VoxPopuli', 'TEDLIUM', 'GigaSpeech', 'SPGISpeech', 'AMI', 'OWN']
-        self.selected = None
-        self.dataset = None
-        self.text = None
-
-    def get_options(self):
-        return self.options
-
-    def _check_text(self):
-        sample = next(iter(self.dataset))
-        print(sample)
-        self._get_text(sample)
-
-    def _get_text(self, sample):
-        if "text" in sample:
-            self.text = "text"
-            return sample["text"]
-        elif "sentence" in sample:
-            self.text = "sentence"
-            return sample["sentence"]
-        elif "normalized_text" in sample:
-            self.text = "normalized_text"
-            return sample["normalized_text"]
-        elif "transcript" in sample:
-            self.text = "transcript"
-            return sample["transcript"]
-        else:
-            raise ValueError(f"Sample: {sample.keys()} has no transcript.")
-
-    def filter(self, input_column:str = None):
-
-        if input_column is None:
-            if self.text is not None:
-                input_column = self.text
-            else:
-                input_column = self._check_text()
-
-        def is_target_text_in_range(ref):
-            if ref.strip() == "ignore time segment in scoring":
-                return False
-            else:
-                return ref.strip() != ""
-
-        self.dataset = self.dataset.filter(is_target_text_in_range, input_columns=[input_column])
-        return self.dataset
-
-    def normalised(self, normalise):
-        self.dataset = self.dataset.map(normalise)
-
-    def _select(self, option:str):
-        if option not in self.options:
-            raise ValueError(f"This value is not an option, please see: {self.options}")
-        self.selected = option
-
-    def _preprocess(self):
-
-        self.dataset = self.dataset.take(self.n)
-        self.dataset = self.dataset.cast_column("audio", Audio(sampling_rate=16000))
-
-    def load(self, option:str = None):
-
-        self._select(option)
-
-        if option == "OWN":
-            pass
-        elif option == "LibriSpeech Clean":
-            self.dataset = load_dataset("librispeech_asr", "all", split="test.clean", streaming=True)
-        elif option == "LibriSpeech Other":
-            self.dataset = load_dataset("librispeech_asr", "all", split="test.other", streaming=True)
-        elif option == "Common Voice":
-            self.dataset = load_dataset("mozilla-foundation/common_voice_11_0", "en", revision="streaming", split="test", streaming=True, token=True, trust_remote_code=True)
-        elif option == "VoxPopuli":
-            self.dataset = load_dataset("facebook/voxpopuli", "en", split="test", streaming=True, trust_remote_code=True)
-        elif option == "TEDLIUM":
-            self.dataset = load_dataset("LIUM/tedlium", "release3", split="test", streaming=True, trust_remote_code=True)
-        elif option == "GigaSpeech":
-            self.dataset = load_dataset("speechcolab/gigaspeech", "xs", split="test", streaming=True, token=True, trust_remote_code=True)
-        elif option == "SPGISpeech":
-            self.dataset = load_dataset("kensho/spgispeech", "S", split="test", streaming=True, token=True, trust_remote_code=True)
-        elif option == "AMI":
-            self.dataset = load_dataset("edinburghcstr/ami", "ihm", split="test", streaming=True, trust_remote_code=True)
-
-        self._preprocess()
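
The deleted `Dataset` wrapper is superseded by the inline loaders in the new processing.py. The streaming pattern it encapsulated reduces to a few `datasets` calls; a self-contained sketch using the VoxPopuli values from the code above:

```python
from datasets import load_dataset, Audio

def load_subset(n: int = 100):
    # Stream the split so only the first n samples are downloaded.
    ds = load_dataset("facebook/voxpopuli", "en", split="test",
                      streaming=True, trust_remote_code=True)
    ds = ds.take(n)
    ds = ds.cast_column("audio", Audio(sampling_rate=16000))  # resample to 16 kHz
    return list(ds)  # materialize the n samples

# samples = load_subset(100); print(samples[0]["raw_text"])
```
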
model.py DELETED
@@ -1,122 +0,0 @@
-# from transformers import WhisperProcessor, WhisperForConditionalGeneration
-from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor
-from transformers import pipeline
-
-# import nemo.collections.asr as nemo_asr
-
-from dataset import Dataset
-from utils import data
-
-
-
-class Model:
-
-
-    def __init__(self):
-
-        self.options = [
-            "openai/whisper-tiny.en",
-            "facebook/s2t-medium-librispeech-asr",
-            #"nvidia/stt_en_fastconformer_ctc_large"
-        ]
-        self.selected = None
-        self.pipeline = None
-        self.normalize = None
-
-    def get_options(self):
-        return self.options
-
-    def load(self, option:str = None):
-
-        if option is None:
-            if self.selected is None:
-                raise ValueError("No model selected. Please first select a model")
-            option = self.selected
-
-        if option not in self.options:
-            raise ValueError(f"Selected Option is not a valid value, see: {self.options}")
-
-        if option == "openai/whisper-tiny.en":
-            self.pipeline = pipeline("automatic-speech-recognition", model="openai/whisper-tiny.en")
-            self.normalize = self.pipeline.tokenizer.normalize
-
-        elif option == "facebook/s2t-medium-librispeech-asr":
-            self.model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-medium-librispeech-asr")
-            self.processor = Speech2TextProcessor.from_pretrained("facebook/s2t-medium-librispeech-asr", do_upper_case=True)
-
-        # elif option == "nvidia/stt_en_fastconformer_ctc_large":
-        #     self.model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(model_name="nvidia/stt_en_fastconformer_ctc_large")
-
-    def select(self, option:str=None):
-        if option not in self.options:
-            raise ValueError(f"This value is not an option, please see: {self.options}")
-        self.selected = option
-
-    def process(self, dataset:Dataset):
-
-        if self.selected is None:
-            raise ValueError("No Model is yet selected. Please select a model first")
-
-        if self.selected == "openai/whisper-tiny.en":
-            references, predictions = self._process_openai_whisper_tiny_en(dataset)
-        elif self.selected == "facebook/s2t-medium-librispeech-asr":
-            references, predictions = self._process_facebook_s2t_medium(dataset)
-        # elif self.selected == "nvidia/stt_en_fastconformer_ctc_large":
-        #     references, predictions = self._process_facebook_s2t_medium(dataset)
-
-        return references, predictions
-
-    def _process_openai_whisper_tiny_en(self, DaTaSeT:Dataset):
-
-        def normalise(batch):
-            batch["norm_text"] = self.normalize(DaTaSeT._get_text(batch))
-            return batch
-
-        DaTaSeT.normalised(normalise)
-        dataset = DaTaSeT.filter("norm_text")
-
-        predictions = []
-        references = []
-
-        # run streamed inference
-        for out in self.pipeline(data(dataset), batch_size=16):
-            predictions.append(self.normalize(out["text"]))
-            references.append(out["reference"][0])
-
-        return references, predictions
-
-    def _process_facebook_s2t_medium(self, DaTaSeT:Dataset):
-
-        def map_to_pred(batch):
-            features = self.processor(batch["audio"]["array"], sampling_rate=16000, padding=True, return_tensors="pt")
-            input_features = features.input_features
-            attention_mask = features.attention_mask
-
-            gen_tokens = self.model.generate(input_features=input_features, attention_mask=attention_mask)
-            batch["transcription"] = self.processor.batch_decode(gen_tokens, skip_special_tokens=True)[0]
-            return batch
-
-        DaTaSeT.dataset = DaTaSeT.dataset.take(100)
-        result = DaTaSeT.dataset.map(map_to_pred, remove_columns=["audio"])
-
-        predictions = []
-        references = []
-
-        DaTaSeT._check_text()
-        text_column = DaTaSeT.text
-
-        for sample in result:
-            predictions.append(sample['transcription'])
-            references.append(sample[text_column])
-
-        return references, predictions
-
-    def _process_stt_en_fastconformer_ctc_large(self, DaTaSeT:Dataset):
-
-        self.model.transcribe(['2086-149220-0033.wav'])
-
-        predictions = []
-        references = []
-
-        return references, predictions
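
Before this commit, the Whisper path went through the high-level ASR `pipeline`, whose tokenizer exposes the English text normalizer that was applied to references and predictions before scoring; the new processing.py calls `WhisperForConditionalGeneration` directly and skips normalization. A sketch of the old pattern; the audio path is a placeholder:

```python
from transformers import pipeline

# The ASR pipeline bundles feature extraction, generation, and decoding.
asr = pipeline("automatic-speech-recognition", model="openai/whisper-tiny.en")
normalize = asr.tokenizer.normalize  # Whisper's English text normalizer

out = asr("sample.wav")          # placeholder path; arrays and dicts also work
print(normalize(out["text"]))    # lowercased, punctuation-stripped transcript
```
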
processing.py ADDED
@@ -0,0 +1,194 @@
+from transformers import WhisperProcessor, WhisperForConditionalGeneration
+from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor
+import plotly.graph_objs as go
+from datasets import load_dataset
+from datasets import Audio
+from transformers import pipeline
+import evaluate
+import librosa
+import numpy as np
+
+wer_metric = evaluate.load("wer")
+
+def run(data_subset:str, model_1:str, model_2:str, own_audio, own_transcription:str):
+
+    if data_subset is None:
+        raise ValueError("No Dataset selected")
+    if model_1 is None:
+        raise ValueError("No Model 1 selected")
+    if model_2 is None:
+        raise ValueError("No Model 2 selected")
+
+    if data_subset == "Common Voice":
+        dataset, text_column = load_Common_Voice()
+    elif data_subset == "VoxPopuli":
+        dataset, text_column = load_Vox_Populi()
+    elif data_subset == "OWN Recording/Sample":
+        sr, audio = own_audio
+        audio = audio.astype(np.float32) / 32768.0
+        print("AUDIO: ", type(audio), audio)
+        audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)
+    else:
+        # if data_subset is somehow unset, fall back to Common Voice
+        dataset, text_column = load_Common_Voice()
+    print("Dataset Loaded")
+
+    # check if models are the same
+    model1, processor1 = load_model(model_1)
+    model2, processor2 = load_model(model_2)
+    print("Models Loaded")
+
+    if data_subset == "OWN Recording/Sample":
+        sample = {"audio": {"array": audio, "sampling_rate": 16000}}
+        transcription1 = model_compute(model1, processor1, sample, model_1)
+        transcription2 = model_compute(model2, processor2, sample, model_2)
+
+        transcriptions1 = [transcription1]
+        transcriptions2 = [transcription2]
+        references = [own_transcription]
+
+        wer1 = compute_wer(references, transcriptions1)
+        wer2 = compute_wer(references, transcriptions2)
+
+        results_md = f"""#### {model_1}
+        - WER Score: {wer1}
+
+        #### {model_2}
+        - WER Score: {wer2}"""
+
+        # Create the bar plot
+        fig = go.Figure(
+            data=[
+                go.Bar(x=[f"{model_1}"], y=[wer1]),
+                go.Bar(x=[f"{model_2}"], y=[wer2]),
+            ]
+        )
+        # Update the layout for better visualization
+        fig.update_layout(
+            title="Comparison of Two Models",
+            xaxis_title="Models",
+            yaxis_title="Value",
+            barmode="group",
+        )
+
+        yield results_md, fig
+
+    else:
+        references = []
+        transcriptions1 = []
+        transcriptions2 = []
+        counter = 0
+        for sample in dataset:
+            print(counter)
+            counter += 1
+
+            references.append(sample[text_column])
+
+            if model_1 == model_2:
+                transcription = model_compute(model1, processor1, sample, model_1)
+
+                transcriptions1.append(transcription)
+                transcriptions2.append(transcription)
+            else:
+                transcriptions1.append(model_compute(model1, processor1, sample, model_1))
+                transcriptions2.append(model_compute(model2, processor2, sample, model_2))
+
+        wer1 = compute_wer(references, transcriptions1)
+        wer2 = compute_wer(references, transcriptions2)
+
+        results_md = f"""#### {model_1}
+        - WER Score: {wer1}
+
+        #### {model_2}
+        - WER Score: {wer2}"""
+
+        # Create the bar plot
+        fig = go.Figure(
+            data=[
+                go.Bar(x=[f"{model_1}"], y=[wer1]),
+                go.Bar(x=[f"{model_2}"], y=[wer2]),
+            ]
+        )
+
+        # Update the layout for better visualization
+        fig.update_layout(
+            title="Comparison of Two Models",
+            xaxis_title="Models",
+            yaxis_title="Value",
+            barmode="group",
+        )
+
+        yield results_md, fig
+
+
+# DATASET LOADERS
+def load_Common_Voice():
+    dataset = load_dataset("mozilla-foundation/common_voice_11_0", "en", revision="streaming", split="test", streaming=True, token=True, trust_remote_code=True)
+    text_column = "sentence"
+    dataset = dataset.take(100)
+    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
+    dataset = list(dataset)
+    return dataset, text_column
+
+def load_Vox_Populi():
+    dataset = load_dataset("facebook/voxpopuli", "en", split="test", streaming=True, trust_remote_code=True)
+    print(next(iter(dataset)))
+    text_column = "raw_text"
+    dataset = dataset.take(100)
+    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
+    dataset = list(dataset)
+    return dataset, text_column
+
+
+# MODEL LOADERS
+def load_model(model_id:str):
+    if model_id == "openai/whisper-tiny.en":
+        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
+        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
+    elif model_id == "facebook/s2t-medium-librispeech-asr":
+        model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-medium-librispeech-asr")
+        processor = Speech2TextProcessor.from_pretrained("facebook/s2t-medium-librispeech-asr", do_upper_case=True)
+    else:
+        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
+        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
+
+    return model, processor
+
+
+# MODEL INFERENCE
+def model_compute(model, processor, sample, model_id):
+
+    if model_id == "openai/whisper-tiny.en":
+        sample = sample["audio"]
+        input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").input_features
+        predicted_ids = model.generate(input_features)
+        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
+        return transcription[0]
+    elif model_id == "facebook/s2t-medium-librispeech-asr":
+        sample = sample["audio"]
+        features = processor(sample["array"], sampling_rate=16000, padding=True, return_tensors="pt")
+        input_features = features.input_features
+        attention_mask = features.attention_mask
+        gen_tokens = model.generate(input_features=input_features, attention_mask=attention_mask)
+        transcription = processor.batch_decode(gen_tokens, skip_special_tokens=True)[0]
+        return transcription
+    else:
+        return model(sample)
+
+# UTILS
+def compute_wer(references, predictions):
+    wer = wer_metric.compute(references=references, predictions=predictions)
+    wer = round(100 * wer, 2)
+    return wer
+
+
+# print(load_Vox_Populi())
+# print(run("Common Voice", "openai/whisper-tiny.en", "openai/whisper-tiny.en", None, None))
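
In the "OWN Recording/Sample" branch, `run` receives the Gradio microphone value as a `(sample_rate, int16 numpy array)` tuple, scales it to float32 in [-1, 1), and resamples to the 16 kHz the processors expect. A standalone sketch of that conversion; the mono-mixdown guard is an addition not present in the original:

```python
import numpy as np
import librosa

def prepare_mic_audio(own_audio):
    # Gradio's Audio component yields (sample_rate, int16 array).
    sr, audio = own_audio
    audio = audio.astype(np.float32) / 32768.0  # int16 range -> [-1.0, 1.0)
    if audio.ndim > 1:                          # assumption: average stereo to mono
        audio = audio.mean(axis=1)
    return librosa.resample(audio, orig_sr=sr, target_sr=16000)

# sample = {"audio": {"array": prepare_mic_audio(mic_value), "sampling_rate": 16000}}
```
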
utils.py CHANGED
@@ -22,3 +22,14 @@ def compute_wer(references, predictions):
     return wer
 
 
+# def run_tests (dataset_choice:str, model:str):
+
+#     MoDeL = Model()
+#     MoDeL.select(model)
+#     MoDeL.load()
+#     DaTaSeT = Dataset(100)
+#     DaTaSeT.load(dataset_choice)
+#     references, predictions = MoDeL.process(DaTaSeT)
+#     wer = compute_wer(references=references, predictions=predictions)
+#     return wer
+
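
`compute_wer` (kept in utils.py and duplicated in processing.py) wraps the `evaluate` WER metric and reports a percentage. A quick sanity check of the metric's argument order and scale:

```python
import evaluate

# WER = (substitutions + deletions + insertions) / words in the reference.
wer_metric = evaluate.load("wer")
wer = wer_metric.compute(
    references=["the cat sat on the mat"],
    predictions=["the cat sat on mat"],  # one deletion out of six reference words
)
print(round(100 * wer, 2))  # 16.67
```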