j-tobias committed · Commit 8414736 · 1 Parent(s): d521dce

new model + new dataset

Files changed:
- __pycache__/processing.cpython-310.pyc +0 -0
- app.py +12 -12
- cards.txt +9 -0
- processing.py +43 -16
__pycache__/processing.cpython-310.pyc ADDED

Binary file (6.05 kB).
app.py CHANGED

@@ -11,12 +11,12 @@ import os
 hf_token = os.getenv("HF_Token")
 login(hf_token)
 
-def hf_login():
-    hf_token = os.getenv("HF_Token")
-    if hf_token is None:
-        with open("credentials.json", "r") as f:
-            hf_token = json.load(f)["token"]
-    login(token=hf_token, add_to_git_credential=True)
+# def hf_login():
+#     hf_token = os.getenv("HF_Token")
+#     if hf_token is None:
+#         with open("credentials.json", "r") as f:
+#             hf_token = json.load(f)["token"]
+#     login(token=hf_token, add_to_git_credential=True)
 
 # hf_login()
 
@@ -25,8 +25,8 @@ def hf_login():
 
 
 # GENERAL OPTIONS FOR MODELS AND DATASETS
-MODEL_OPTIONS = ["openai/whisper-tiny.en", "facebook/s2t-medium-librispeech-asr"]
-DATASET_OPTIONS = ["Common Voice", "OWN Recoding/Sample"]
+MODEL_OPTIONS = ["openai/whisper-tiny.en", "facebook/s2t-medium-librispeech-asr", "facebook/wav2vec2-base-960h"]
+DATASET_OPTIONS = ["Common Voice", "Librispeech ASR clean", "OWN Recoding/Sample"]
 
 # HELPER FUNCTIONS
 def get_card(selected_model:str)->str:
@@ -48,7 +48,7 @@ def is_own(selected_option):
         return gr.update(visible=False), gr.update(visible=False)
 
 def make_visible():
-    return gr.update(visible=True), gr.update(visible=True)
+    return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
 
 
 
@@ -83,7 +83,7 @@ Happy experimenting and comparing! π"""
                 choices=DATASET_OPTIONS,
                 label="Data subset / Own Sample",
             )
-            own_audio = gr.Audio(visible=False)
+            own_audio = gr.Audio(sources=['microphone'], visible=False)
             own_transcription = gr.TextArea(lines=2, visible=False)
             data_subset.change(is_own, inputs=[data_subset], outputs=[own_audio, own_transcription])
         with gr.Column(scale=1):
@@ -116,7 +116,7 @@ Happy experimenting and comparing! π"""
         variant="primary",
         size="sm")
 
-    gr.Markdown('## <p style="text-align: center;">Results</p>')
+    results_title = gr.Markdown('## <p style="text-align: center;">Results</p>', visible=False)
     results_md = gr.Markdown("")
     results_plot = gr.Plot(show_label=False, visible=False)
     results_df = gr.DataFrame(
@@ -125,7 +125,7 @@ Happy experimenting and comparing! π"""
         interactive=False,  # Allow users to interact with the DataFrame
         wrap=True,  # Ensure text wraps to multiple lines
     )
-    eval_btn.click(make_visible, outputs=[results_plot, results_df])
+    eval_btn.click(make_visible, outputs=[results_plot, results_df, results_title])
    eval_btn.click(run, [data_subset, model_1, model_2, own_audio, own_transcription], [results_md, results_plot, results_df], show_progress=False)
 
 demo.launch(debug=True)
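Note on the app.py change: make_visible now returns three gr.update objects so that the first eval_btn.click call can un-hide the plot, the results table, and the new results_title together. A minimal, self-contained sketch of that Gradio pattern (component names follow the diff, but the layout is simplified and this is not the full app):

import gradio as gr

def make_visible():
    # One gr.update(...) per component listed in `outputs=`, in the same order.
    return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)

with gr.Blocks() as demo:
    eval_btn = gr.Button("Evaluate", variant="primary", size="sm")
    results_title = gr.Markdown("## Results", visible=False)   # hidden until evaluation runs
    results_plot = gr.Plot(show_label=False, visible=False)
    results_df = gr.DataFrame(interactive=False, wrap=True, visible=False)
    # Clicking the button un-hides the plot, the table and the title.
    eval_btn.click(make_visible, outputs=[results_plot, results_df, results_title])

if __name__ == "__main__":
    demo.launch()

Because results_title is now created with visible=False, the "Results" heading only appears once the button has been pressed.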
    	
cards.txt CHANGED

@@ -15,4 +15,13 @@
 - Model Size: 71.2 M Parameters
 - Model Paper: [fairseq S2T: Fast Speech-to-Text Modeling with fairseq](https://arxiv.org/abs/2010.05171)
 - Training Data: [LibriSpeech ASR Corpus](https://www.openslr.org/12)
+@@
+####
+- ID: facebook/wav2vec2-base-960h
+- Hugging Face: [model](https://huggingface.co/facebook/wav2vec2-base-960h)
+- Creator: facebook
+- Finetuned: No
+- Model Size: 94.4 M Parameters
+- Model Paper: [Wav2vec 2.0: Learning the structure of speech from raw audio](https://ai.meta.com/blog/wav2vec-20-learning-the-structure-of-speech-from-raw-audio/)
+- Training Data: ?
 @@
processing.py CHANGED

@@ -1,10 +1,12 @@
 from transformers import WhisperProcessor, WhisperForConditionalGeneration
 from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor
+from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
 import plotly.graph_objs as go
 from datasets import load_dataset
 from datasets import Audio
 import evaluate
 import librosa
+import torch
 import numpy as np
 import pandas as pd
 
@@ -25,6 +27,8 @@ def run(data_subset:str, model_1:str, model_2:str, own_audio, own_transcription:
         dataset, text_column = load_Common_Voice()
     elif data_subset == "VoxPopuli":
         dataset, text_column = load_Vox_Populi()
+    elif data_subset == "Librispeech ASR clean":
+        dataset, text_column = load_Librispeech_ASR_clean()
     elif data_subset == "OWN Recoding/Sample":
         sr, audio = own_audio
         audio = audio.astype(np.float32) / 32768.0
@@ -50,8 +54,8 @@ def run(data_subset:str, model_1:str, model_2:str, own_audio, own_transcription:
         transcriptions2 = [transcription2]
         references = [own_transcription]
 
-        wer1 = compute_wer(references, transcriptions1)
-        wer2 = compute_wer(references, transcriptions2)
+        wer1 = round(N_SAMPLES * compute_wer(references, transcriptions1), 2)
+        wer2 = round(N_SAMPLES * compute_wer(references, transcriptions2), 2)
 
         results_md = f"""
         #### {model_1} 
@@ -113,16 +117,16 @@ def run(data_subset:str, model_1:str, model_2:str, own_audio, own_transcription:
             {i}/{len(dataset)}-{'#'*i}{'_'*(N_SAMPLES-i)}
 
             #### {model_1} 
-            - WER Score: {sum(WER1s)/
+            - WER Score: {round(sum(WER1s)/len(WER1s), 2)}
 
             #### {model_2} 
-            - WER Score: {sum(WER2s)/
-
+            - WER Score: {round(sum(WER2s)/len(WER2s), 2)}"""
+
             # Create the bar plot
             fig = go.Figure(
                 data=[
-                    go.Bar(x=[f"{model_1}"], y=[sum(WER1s)/
-                    go.Bar(x=[f"{model_2}"], y=[sum(WER2s)/
+                    go.Bar(x=[f"{model_1}"], y=[sum(WER1s)/len(WER1s)], showlegend=False),
+                    go.Bar(x=[f"{model_2}"], y=[sum(WER2s)/len(WER2s)], showlegend=False),
                 ]
             )
 
@@ -148,6 +152,8 @@ def load_Common_Voice():
     dataset = dataset.take(N_SAMPLES)
     dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
     dataset = list(dataset)
+    for sample in dataset:
+        sample["text"] = sample["text"].lower()
     return dataset, text_column
 
 def load_Vox_Populi():
@@ -174,6 +180,17 @@ def load_Vox_Populi():
     dataset = list(dataset)
     return dataset, text_column
 
+def load_Librispeech_ASR_clean():
+    dataset = load_dataset("librispeech_asr", "clean", split="test", streaming=True, token=True, trust_remote_code=True)
+    print(next(iter(dataset)))
+    text_column = "text"
+    dataset = dataset.take(N_SAMPLES)
+    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
+    dataset = list(dataset)
+    for sample in dataset:
+        sample["text"] = sample["text"].lower()
+    return dataset, text_column
+
 def is_valid_sample(text, audio):
     # Check if 'normalized_text' is valid
     text = text.strip()
@@ -200,6 +217,9 @@ def load_model(model_id:str):
     elif model_id == "facebook/s2t-medium-librispeech-asr":
         model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-medium-librispeech-asr")
         processor = Speech2TextProcessor.from_pretrained("facebook/s2t-medium-librispeech-asr")
+    elif model_id == "facebook/wav2vec2-base-960h":
+        processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
+        model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
     else:
         model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
         processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
@@ -215,25 +235,32 @@ def model_compute(model, processor, sample, model_id):
         input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").input_features
         predicted_ids = model.generate(input_features)
         transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
-
+        transcription = processor.tokenizer.normalize(transcription[0])
+        return transcription
     elif model_id == "facebook/s2t-medium-librispeech-asr":
         sample = sample["audio"]
         features = processor(sample["array"], sampling_rate=16000, padding=True, return_tensors="pt")
         input_features = features.input_features
         attention_mask = features.attention_mask
         gen_tokens = model.generate(input_features=input_features, attention_mask=attention_mask)
-        transcription= processor.batch_decode(gen_tokens, skip_special_tokens=True)
-        return transcription
-
+        transcription= processor.batch_decode(gen_tokens, skip_special_tokens=True)
+        return transcription[0]
+    elif model_id == "facebook/wav2vec2-base-960h":
+        sample = sample["audio"]
+        input_values = processor(sample["array"], sampling_rate=16000, return_tensors="pt", padding="longest").input_values  # Batch size 1
+        logits = model(input_values).logits
+        predicted_ids = torch.argmax(logits, dim=-1)
+        transcription = processor.batch_decode(predicted_ids)
+        return transcription[0].lower()
     else:
-
+        sample = sample["audio"]
+        input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").input_features
+        predicted_ids = model.generate(input_features)
+        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
+        return transcription[0]
 
 # UTILS
 def compute_wer(references, predictions):
     wer = wer_metric.compute(references=references, predictions=predictions)
-    wer = round(N_SAMPLES * wer, 2)
     return wer
 
-
-# print(load_Vox_Populi())
-# print(run("Common Voice", "openai/whisper-tiny.en", "openai/whisper-tiny.en", None, None))
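Note on the processing.py change: the new facebook/wav2vec2-base-960h branch of model_compute follows the standard greedy CTC decoding recipe for Wav2Vec2 in transformers. A self-contained sketch of that path, using a dummy one-second silent clip in place of a dataset sample (the clip and the reference string below are made up for illustration):

import evaluate
import numpy as np
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

# Dummy 16 kHz mono clip (one second of silence) standing in for sample["audio"]["array"].
audio = np.zeros(16000, dtype=np.float32)

inputs = processor(audio, sampling_rate=16000, return_tensors="pt", padding="longest")
with torch.no_grad():
    logits = model(inputs.input_values).logits   # (batch, time, vocab)
predicted_ids = torch.argmax(logits, dim=-1)     # greedy CTC decoding
transcription = processor.batch_decode(predicted_ids)[0].lower()

# WER against an illustrative reference, mirroring compute_wer in processing.py.
wer_metric = evaluate.load("wer")
print(transcription, wer_metric.compute(references=["hello world"], predictions=[transcription]))

The commit also moves the round(N_SAMPLES * wer, 2) scaling out of compute_wer and applies the rounding at the call sites, so compute_wer now returns the unscaled word error rate from evaluate.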