eolang committed on
Commit
28aa278
·
verified ·
1 Parent(s): 8b2fc06

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +187 -0
app.py CHANGED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import warnings

# The logging helper must be imported before it is used below; calling
# transformers_logging.set_verbosity_error() ahead of its import raises
# NameError as soon as the module is loaded.
from transformers.utils import logging as transformers_logging

# Silence all transformers warnings
transformers_logging.set_verbosity_error()
warnings.filterwarnings("ignore", category=UserWarning)
6
+
7
+ import gradio as gr
8
+ import torch
9
+ from transformers import (
10
+ SpeechT5Processor,
11
+ SpeechT5ForTextToSpeech,
12
+ SpeechT5HifiGan,
13
+ pipeline
14
+ )
15
+ import json
16
+ import soundfile as sf
17
+ import numpy as np
18
+ from huggingface_hub import login
19
+ from jiwer import wer
20
+ from transformers.utils import logging as transformers_logging
21
+ from sklearn.feature_extraction.text import CountVectorizer
22
+ from sklearn.metrics.pairwise import cosine_similarity
23
+ import os
24
+
25
+ # -------------------------------------------------------------------------------------------------------------------
26
+
27
# Authentication & Env Setup
# The Hugging Face access token is read from the HF_Key environment variable.
HF_Key = os.environ.get("HF_Key")
if HF_Key:
    login(token=HF_Key)
else:
    # NOTE(review): without a token, huggingface_hub.login() is documented to
    # fall back to an interactive prompt, which cannot be answered in a
    # headless deployment. Skip it and continue unauthenticated instead.
    print("Warning: HF_Key is not set; skipping Hugging Face login")

# Use the first CUDA device when one is available, otherwise the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Half precision on GPU, full precision on CPU.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Silence all transformers warnings
transformers_logging.set_verbosity_error()
warnings.filterwarnings("ignore", category=UserWarning)
37
+
38
+ # -------------------------------------------------------------------------------------------------------------------
39
+
40
def cosine_sim_wer_single(reference, prediction):
    """
    Calculate a WER-like metric based on cosine similarity for a single
    reference-prediction pair.

    Both strings are vectorised as counts of character 2- and 3-grams
    (word-boundary aware) and compared with cosine similarity.

    Args:
        reference: Single reference transcript (string, may be None)
        prediction: Single model prediction (string, may be None)

    Returns:
        Error rate based on cosine similarity (100% - similarity%).
        100.0 is returned when either input is empty/None or when the
        similarity computation fails.
    """
    # Clean inputs
    ref = reference.strip() if reference else ""
    pred = prediction.strip() if prediction else ""

    # Handle empty inputs
    if not ref or not pred:
        print("Warning: Empty reference or prediction")
        return 100.0  # Return 100% error for invalid input

    try:
        # Use character n-grams to handle morphological variations better
        vectorizer = CountVectorizer(analyzer='char_wb', ngram_range=(2, 3))

        # Row 0 is the reference, row 1 is the prediction
        vectors = vectorizer.fit_transform([ref, pred])

        # Cosine similarity expressed as a percentage
        similarity = float(cosine_similarity(vectors[0:1], vectors[1:2])[0][0]) * 100
    except Exception as e:
        # Deliberate best-effort: a failed metric must not abort the caller.
        print(f"Error calculating similarity: {e}")
        return 100.0  # Return 100% error in case of calculation failure

    # Convert to error rate (100% - similarity%)
    error_rate = 100.0 - similarity

    print(f"Similarity: {similarity:.2f}%")
    print(f"Error rate: {error_rate:.2f}%")

    # Previously the success path fell off the end and returned None.
    return error_rate
79
+
80
+ # -------------------------------------------------------------------------------------------------------------------
81
+
82
## TTS Module
speaker_file_path = 'speaker2.json'
model_id = 'eolang/speecht5_v4-2'

# Load the speaker data from JSON.
# NOTE(review): presumably a flat list of floats (one speaker embedding) - confirm.
with open(speaker_file_path, 'r', encoding='utf-8') as file:
    example = json.load(file)

# unsqueeze(0) adds a leading dimension in front of the loaded values.
speaker_embeddings = torch.tensor(example).unsqueeze(0)

# Model and processor come from the same repository, so both use model_id
# rather than repeating the repository id as a string literal.
l_model = SpeechT5ForTextToSpeech.from_pretrained(model_id)
l_processor = SpeechT5Processor.from_pretrained(model_id)
l_vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
97
+
98
def synthesize(input_text, output_path='test_output.wav'):
    """
    Synthesize speech for ``input_text`` and write it to a WAV file.

    Args:
        input_text: Text to convert to speech.
        output_path: Destination of the WAV file. Defaults to
            'test_output.wav', the path used before this parameter existed.

    Returns:
        The path the audio was written to.
    """
    inputs = l_processor(text=input_text, return_tensors="pt")
    speech = l_model.generate_speech(
        inputs["input_ids"], speaker_embeddings, vocoder=l_vocoder
    )

    # The waveform is written with a 16000 Hz sample rate.
    sf.write(output_path, speech.numpy(), 16000)
    return output_path
106
+
107
+ # -------------------------------------------------------------------------------------------------------------------
108
+
109
## STT Module
### Custom/Tuned Whisper
# Speech-recognition pipeline backed by the fine-tuned checkpoint.
tuned_pipeline = pipeline(
    task="automatic-speech-recognition",
    model="eolang/whisper-small-sw-WER-13-zindi",
    device=device,
    return_timestamps=True,
    generate_kwargs=dict(
        no_repeat_ngram_size=3,   # never emit the same 3-gram twice
        repetition_penalty=1.5,   # values above 1.0 discourage repetition
    ),
)
121
+
122
+
123
def tunned_transcribe(filepath):
    """Transcribe the audio at ``filepath`` with the tuned pipeline; return only the text."""
    return tuned_pipeline(filepath, return_timestamps=True)["text"]
126
+
127
+
128
+
129
### OpenAI Whisper (un-tuned)
# Baseline speech-recognition pipeline using the stock checkpoint.
openai_pipeline = pipeline(
    task="automatic-speech-recognition",
    model="openai/whisper-small",
    device=device,
    return_timestamps=True,
    generate_kwargs=dict(
        no_repeat_ngram_size=3,   # never emit the same 3-gram twice
        repetition_penalty=1.5,   # values above 1.0 discourage repetition
    ),
)
140
+
141
+
142
def openai_transcribe(filepath):
    """Transcribe the audio at ``filepath`` with the base pipeline; return only the text."""
    return openai_pipeline(filepath, return_timestamps=True)["text"]
145
+
146
+ # -------------------------------------------------------------------------------------------------------------------
147
+
148
+ ## Full Loop module
149
def full_loop(ref_text):
    """
    Run the text -> speech -> text round trip and report word error rates.

    The reference text is synthesized to 'test_output.wav', transcribed by
    both the tuned and the base speech-recognition wrappers, and each
    transcription is scored against the reference with ``wer``.

    Args:
        ref_text: Reference text to synthesize and evaluate against.

    Returns:
        Tuple ``(audio_path, report)``. ``audio_path`` is 'test_output.wav',
        or None when ``ref_text`` is empty or whitespace-only, in which case
        ``report`` is a short message asking for input.
    """
    # Guard clause: an empty reference gives nothing to synthesize.
    # NOTE(review): jiwer's wer() appears to reject an empty reference - confirm.
    if not ref_text or not ref_text.strip():
        return None, "Please enter some reference text to synthesize."

    # synthesize
    synthesize(ref_text)

    # Get transcriptions USING THE WRAPPER FUNCTIONS that return just text
    tunned_transcription = tunned_transcribe('test_output.wav')
    openai_trancsription = openai_transcribe('test_output.wav')

    tunned_WER = wer(ref_text, tunned_transcription)
    base_WER = wer(ref_text, openai_trancsription)

    result = f'Tunned Model transciption: {tunned_transcription}\n'
    result += f"Word error rate for the tunned model: {round(tunned_WER, 2)}\n"

    # Call cosine sim for tuned model (this will print results)
    cosine_sim_wer_single(ref_text, tunned_transcription)

    result += f'\nBase Model transciption: {openai_trancsription}\n'
    result += f"Word error rate for base-untunned model: {round(base_WER, 2)}\n"

    # Call cosine sim for base model (this will print results)
    cosine_sim_wer_single(ref_text, openai_trancsription)

    return 'test_output.wav', result
173
+
174
+ # -------------------------------------------------------------------------------------------------------------------
175
# Minimal Gradio front end: one text box in, synthesized audio plus a text
# report out. full_loop is used as-is as the handler.
demo = gr.Interface(
    title="TTS-STT Evaluation",
    fn=full_loop,
    inputs=gr.Textbox(value="Kuna mambo kadhaa yanayoitajika kuzingatiwa wakati wa kufundisha modeli."),
    outputs=[gr.Audio(), gr.Textbox()],
)

# Start the web UI only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()