j-tobias committed
Commit 8414736 · 1 Parent(s): d521dce

new model + new dataset

Files changed (4):
  1. __pycache__/processing.cpython-310.pyc +0 -0
  2. app.py +12 -12
  3. cards.txt +9 -0
  4. processing.py +43 -16
__pycache__/processing.cpython-310.pyc ADDED
Binary file (6.05 kB).
 
app.py CHANGED
@@ -11,12 +11,12 @@ import os
 hf_token = os.getenv("HF_Token")
 login(hf_token)
 
- def hf_login():
- hf_token = os.getenv("HF_Token")
- if hf_token is None:
- with open("credentials.json", "r") as f:
- hf_token = json.load(f)["token"]
- login(token=hf_token, add_to_git_credential=True)
+ # def hf_login():
+ # hf_token = os.getenv("HF_Token")
+ # if hf_token is None:
+ # with open("credentials.json", "r") as f:
+ # hf_token = json.load(f)["token"]
+ # login(token=hf_token, add_to_git_credential=True)
 
 # hf_login()
 
@@ -25,8 +25,8 @@ def hf_login():
 
 
 # GENERAL OPTIONS FOR MODELS AND DATASETS
- MODEL_OPTIONS = ["openai/whisper-tiny.en", "facebook/s2t-medium-librispeech-asr"]
- DATASET_OPTIONS = ["Common Voice", "OWN Recoding/Sample"]
+ MODEL_OPTIONS = ["openai/whisper-tiny.en", "facebook/s2t-medium-librispeech-asr", "facebook/wav2vec2-base-960h"]
+ DATASET_OPTIONS = ["Common Voice", "Librispeech ASR clean", "OWN Recoding/Sample"]
 
 # HELPER FUNCTIONS
 def get_card(selected_model:str)->str:
@@ -48,7 +48,7 @@ def is_own(selected_option):
 return gr.update(visible=False), gr.update(visible=False)
 
 def make_visible():
- return gr.update(visible=True), gr.update(visible=True)
+ return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
 
 
 
@@ -83,7 +83,7 @@ Happy experimenting and comparing! 🚀""")
 choices=DATASET_OPTIONS,
 label="Data subset / Own Sample",
 )
- own_audio = gr.Audio(visible=False)
+ own_audio = gr.Audio(sources=['microphone'], visible=False)
 own_transcription = gr.TextArea(lines=2, visible=False)
 data_subset.change(is_own, inputs=[data_subset], outputs=[own_audio, own_transcription])
 with gr.Column(scale=1):
@@ -116,7 +116,7 @@ Happy experimenting and comparing! 🚀""")
 variant="primary",
 size="sm")
 
- gr.Markdown('## <p style="text-align: center;">Results</p>')
+ results_title = gr.Markdown('## <p style="text-align: center;">Results</p>', visible=False)
 results_md = gr.Markdown("")
 results_plot = gr.Plot(show_label=False, visible=False)
 results_df = gr.DataFrame(
@@ -125,7 +125,7 @@ Happy experimenting and comparing! 🚀""")
 interactive=False, # Allow users to interact with the DataFrame
 wrap=True, # Ensure text wraps to multiple lines
 )
- eval_btn.click(make_visible, outputs=[results_plot, results_df])
+ eval_btn.click(make_visible, outputs=[results_plot, results_df, results_title])
 eval_btn.click(run, [data_subset, model_1, model_2, own_audio, own_transcription], [results_md, results_plot, results_df], show_progress=False)
 
 demo.launch(debug=True)
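
Note on the Gradio wiring above: make_visible must return one gr.update per component listed in outputs, which is why adding results_title to the click outputs also requires a third gr.update(visible=True). A minimal, self-contained sketch of that pattern, with illustrative component labels rather than the exact ones from app.py:

import gradio as gr

def make_visible():
    # one gr.update per output component, in the same order as outputs=[...]
    return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)

with gr.Blocks() as demo:
    eval_btn = gr.Button("Evaluate")
    results_title = gr.Markdown("## Results", visible=False)
    results_plot = gr.Plot(visible=False)
    results_df = gr.DataFrame(visible=False)
    # clicking the button reveals all three hidden components at once
    eval_btn.click(make_visible, outputs=[results_plot, results_df, results_title])

demo.launch()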
cards.txt CHANGED
@@ -15,4 +15,13 @@
 - Model Size: 71.2 M Parameters
 - Model Paper: [fairseq S2T: Fast Speech-to-Text Modeling with fairseq](https://arxiv.org/abs/2010.05171)
 - Training Data: [LibriSpeech ASR Corpus](https://www.openslr.org/12)
+ @@
+ ####
+ - ID: facebook/wav2vec2-base-960h
+ - Hugging Face: [model](https://huggingface.co/facebook/wav2vec2-base-960h)
+ - Creator: facebook
+ - Finetuned: No
+ - Model Size: 94.4 M Parameters
+ - Model Paper: [Wav2vec 2.0: Learning the structure of speech from raw audio](https://ai.meta.com/blog/wav2vec-20-learning-the-structure-of-speech-from-raw-audio/)
+ - Training Data: ?
 @@
processing.py CHANGED
@@ -1,10 +1,12 @@
 from transformers import WhisperProcessor, WhisperForConditionalGeneration
 from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
 import plotly.graph_objs as go
 from datasets import load_dataset
 from datasets import Audio
 import evaluate
 import librosa
+ import torch
 import numpy as np
 import pandas as pd
 
@@ -25,6 +27,8 @@ def run(data_subset:str, model_1:str, model_2:str, own_audio, own_transcription:
 dataset, text_column = load_Common_Voice()
 elif data_subset == "VoxPopuli":
 dataset, text_column = load_Vox_Populi()
+ elif data_subset == "Librispeech ASR clean":
+ dataset, text_column = load_Librispeech_ASR_clean()
 elif data_subset == "OWN Recoding/Sample":
 sr, audio = own_audio
 audio = audio.astype(np.float32) / 32768.0
@@ -50,8 +54,8 @@ def run(data_subset:str, model_1:str, model_2:str, own_audio, own_transcription:
 transcriptions2 = [transcription2]
 references = [own_transcription]
 
- wer1 = compute_wer(references, transcriptions1)
- wer2 = compute_wer(references, transcriptions2)
+ wer1 = round(N_SAMPLES * compute_wer(references, transcriptions1), 2)
+ wer2 = round(N_SAMPLES * compute_wer(references, transcriptions2), 2)
 
 results_md = f"""
 #### {model_1}
@@ -113,16 +117,16 @@ def run(data_subset:str, model_1:str, model_2:str, own_audio, own_transcription:
 {i}/{len(dataset)}-{'#'*i}{'_'*(N_SAMPLES-i)}
 
 #### {model_1}
- - WER Score: {sum(WER1s)/N_SAMPLES}
+ - WER Score: {round(sum(WER1s)/len(WER1s), 2)}
 
 #### {model_2}
- - WER Score: {sum(WER2s)/N_SAMPLES}"""
-
+ - WER Score: {round(sum(WER2s)/len(WER2s), 2)}"""
+
 # Create the bar plot
 fig = go.Figure(
 data=[
- go.Bar(x=[f"{model_1}"], y=[sum(WER1s)/N_SAMPLES], showlegend=False),
- go.Bar(x=[f"{model_2}"], y=[sum(WER2s)/N_SAMPLES], showlegend=False),
+ go.Bar(x=[f"{model_1}"], y=[sum(WER1s)/len(WER1s)], showlegend=False),
+ go.Bar(x=[f"{model_2}"], y=[sum(WER2s)/len(WER2s)], showlegend=False),
 ]
 )
 
@@ -148,6 +152,8 @@ def load_Common_Voice():
 dataset = dataset.take(N_SAMPLES)
 dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
 dataset = list(dataset)
+ for sample in dataset:
+ sample["text"] = sample["text"].lower()
 return dataset, text_column
 
 def load_Vox_Populi():
@@ -174,6 +180,17 @@ def load_Vox_Populi():
 dataset = list(dataset)
 return dataset, text_column
 
+ def load_Librispeech_ASR_clean():
+ dataset = load_dataset("librispeech_asr", "clean", split="test", streaming=True, token=True, trust_remote_code=True)
+ print(next(iter(dataset)))
+ text_column = "text"
+ dataset = dataset.take(N_SAMPLES)
+ dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
+ dataset = list(dataset)
+ for sample in dataset:
+ sample["text"] = sample["text"].lower()
+ return dataset, text_column
+
 def is_valid_sample(text, audio):
 # Check if 'normalized_text' is valid
 text = text.strip()
@@ -200,6 +217,9 @@ def load_model(model_id:str):
 elif model_id == "facebook/s2t-medium-librispeech-asr":
 model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-medium-librispeech-asr")
 processor = Speech2TextProcessor.from_pretrained("facebook/s2t-medium-librispeech-asr")
+ elif model_id == "facebook/wav2vec2-base-960h":
+ processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
+ model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
 else:
 model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
 processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
@@ -215,25 +235,32 @@ def model_compute(model, processor, sample, model_id):
 input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").input_features
 predicted_ids = model.generate(input_features)
 transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
- return transcription[0]
+ transcription = processor.tokenizer.normalize(transcription[0])
+ return transcription
 elif model_id == "facebook/s2t-medium-librispeech-asr":
 sample = sample["audio"]
 features = processor(sample["array"], sampling_rate=16000, padding=True, return_tensors="pt")
 input_features = features.input_features
 attention_mask = features.attention_mask
 gen_tokens = model.generate(input_features=input_features, attention_mask=attention_mask)
- transcription= processor.batch_decode(gen_tokens, skip_special_tokens=True)[0]
- return transcription
-
+ transcription= processor.batch_decode(gen_tokens, skip_special_tokens=True)
+ return transcription[0]
+ elif model_id == "facebook/wav2vec2-base-960h":
+ sample = sample["audio"]
+ input_values = processor(sample["array"], sampling_rate=16000, return_tensors="pt", padding="longest").input_values # Batch size 1
+ logits = model(input_values).logits
+ predicted_ids = torch.argmax(logits, dim=-1)
+ transcription = processor.batch_decode(predicted_ids)
+ return transcription[0].lower()
 else:
- return model(sample)
+ sample = sample["audio"]
+ input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").input_features
+ predicted_ids = model.generate(input_features)
+ transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
+ return transcription[0]
 
 # UTILS
 def compute_wer(references, predictions):
 wer = wer_metric.compute(references=references, predictions=predictions)
- wer = round(N_SAMPLES * wer, 2)
 return wer
 
-
- # print(load_Vox_Populi())
- # print(run("Common Voice", "openai/whisper-tiny.en", "openai/whisper-tiny.en", None, None))
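
For context on the new facebook/wav2vec2-base-960h branch in model_compute: wav2vec2 is a CTC model, so there is no generate() call; the model emits per-frame logits that are greedily argmax-decoded and then detokenized. A minimal standalone sketch of that path, mirroring the added code (the placeholder audio array is an assumption; replace it with a real 16 kHz float32 waveform):

import numpy as np
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

# placeholder input: one second of silence at 16 kHz (assumption for this sketch)
audio = np.zeros(16000, dtype=np.float32)

inputs = processor(audio, sampling_rate=16000, return_tensors="pt", padding="longest")
with torch.no_grad():
    logits = model(inputs.input_values).logits   # CTC logits, shape (batch, time, vocab)
predicted_ids = torch.argmax(logits, dim=-1)     # greedy decoding
transcription = processor.batch_decode(predicted_ids)[0].lower()
print(transcription)

The .lower() mirrors the diff: this checkpoint emits uppercase text, while the references in processing.py are lower-cased before computing WER.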