davidmeikle commited on
Commit
cc34edf
·
verified ·
1 Parent(s): 34b214d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +237 -0
app.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import gradio as gr
3
+ import torch
4
+ import numpy as np
5
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
6
+ import platform
7
+ import librosa
8
+ import multiprocessing
9
+ from dataclasses import dataclass
10
+ from typing import Dict, Tuple, List
11
+
12
+ @dataclass
13
+ class ModelConfig:
14
+ name: str
15
+ processor: Wav2Vec2Processor
16
+ model: Wav2Vec2ForCTC
17
+ description: str
18
+
19
+ class PhoneticEnhancer:
20
+ def __init__(self):
21
+ # Vowel length rules
22
+ self.long_vowels = {
23
+ 'i': 'iː',
24
+ 'u': 'uː',
25
+ 'a': 'ɑː',
26
+ 'ɑ': 'ɑː',
27
+ 'e': 'eː',
28
+ 'o': 'oː'
29
+ }
30
+
31
+ # Common diphthongs
32
+ self.diphthongs = {
33
+ 'ei': 'eɪ',
34
+ 'ai': 'aɪ',
35
+ 'oi': 'ɔɪ',
36
+ 'ou': 'əʊ',
37
+ 'au': 'aʊ'
38
+ }
39
+
40
+ # Vowel quality adjustments
41
+ self.vowel_quality = {
42
+ 'ə': 'æ', # In stressed positions
43
+ 'ɐ': 'æ' # Common substitution
44
+ }
45
+
46
+ # Stress pattern rules
47
+ self.stress_patterns = [
48
+ # (pattern, position) - position is index from start
49
+ (['CV', 'CV'], 1), # For words like "piage"
50
+ (['CVV', 'CV'], 0), # For words with long first vowel
51
+ ]
52
+
53
+ def _is_vowel(self, phoneme: str) -> bool:
54
+ vowels = set('aeiouɑɐəæɛɪʊʌɔ')
55
+ return any(char in vowels for char in phoneme)
56
+
57
+ def _split_into_syllables(self, phonemes: List[str]) -> List[List[str]]:
58
+ syllables = []
59
+ current_syllable = []
60
+
61
+ for phoneme in phonemes:
62
+ current_syllable.append(phoneme)
63
+ if self._is_vowel(phoneme) and len(current_syllable) > 0:
64
+ syllables.append(current_syllable)
65
+ current_syllable = []
66
+
67
+ if current_syllable:
68
+ if len(syllables) > 0:
69
+ syllables[-1].extend(current_syllable)
70
+ else:
71
+ syllables.append(current_syllable)
72
+
73
+ return syllables
74
+
75
+ def enhance_transcription(self, raw_phonemes: str, enhancements: List[str] = None) -> str:
76
+ if enhancements is None:
77
+ enhancements = ['length', 'quality', 'stress', 'diphthongs']
78
+
79
+ # Split into individual phonemes
80
+ phonemes = raw_phonemes.split()
81
+ enhanced_phonemes = phonemes.copy()
82
+
83
+ if 'length' in enhancements:
84
+ # Apply vowel length rules
85
+ for i, phoneme in enumerate(enhanced_phonemes):
86
+ if phoneme in self.long_vowels:
87
+ enhanced_phonemes[i] = self.long_vowels[phoneme]
88
+
89
+ if 'quality' in enhancements:
90
+ # Apply vowel quality adjustments
91
+ for i, phoneme in enumerate(enhanced_phonemes):
92
+ if phoneme in self.vowel_quality:
93
+ enhanced_phonemes[i] = self.vowel_quality[phoneme]
94
+
95
+ if 'diphthongs' in enhancements:
96
+ # Apply diphthong rules
97
+ i = 0
98
+ while i < len(enhanced_phonemes) - 1:
99
+ pair = enhanced_phonemes[i] + enhanced_phonemes[i + 1]
100
+ if pair in self.diphthongs:
101
+ enhanced_phonemes[i] = self.diphthongs[pair]
102
+ enhanced_phonemes.pop(i + 1)
103
+ i += 1
104
+
105
+ if 'stress' in enhancements:
106
+ # Add stress marks based on syllable structure
107
+ syllables = self._split_into_syllables(enhanced_phonemes)
108
+ if len(syllables) > 1:
109
+ # Add stress to the syllable containing 'æ' if present
110
+ for i, syll in enumerate(syllables):
111
+ if any('æ' in p for p in syll):
112
+ syllables[i].insert(0, 'ˈ')
113
+ break
114
+ # If no 'æ', add stress to first syllable by default
115
+ else:
116
+ syllables[0].insert(0, 'ˈ')
117
+
118
+ # Flatten syllables back to phonemes
119
+ enhanced_phonemes = [p for syll in syllables for p in syll]
120
+
121
+ return ' '.join(enhanced_phonemes)
122
+
123
+ class PhonemeTranscriber:
124
+ def __init__(self):
125
+ self.device = self._get_optimal_device()
126
+ print(f"Using device: {self.device}")
127
+
128
+ self.model_config = self._initialize_model()
129
+ self.target_sample_rate = 16_000
130
+ self.enhancer = PhoneticEnhancer()
131
+
132
+ def _get_optimal_device(self):
133
+ if torch.cuda.is_available():
134
+ return "cuda"
135
+ elif torch.backends.mps.is_available() and platform.system() == 'Darwin':
136
+ return "mps"
137
+ return "cpu"
138
+
139
+ def _initialize_model(self) -> ModelConfig:
140
+ model_name = "facebook/wav2vec2-lv-60-espeak-cv-ft"
141
+ processor = Wav2Vec2Processor.from_pretrained(model_name)
142
+ model = Wav2Vec2ForCTC.from_pretrained(model_name).to(self.device)
143
+ model.eval()
144
+
145
+ return ModelConfig(
146
+ name=model_name,
147
+ processor=processor,
148
+ model=model,
149
+ description="LV-60 + CommonVoice (26 langs) + eSpeak"
150
+ )
151
+
152
+ def preprocess_audio(self, audio):
153
+ """Preprocess audio data for model input."""
154
+ if isinstance(audio, tuple):
155
+ sample_rate, audio_data = audio
156
+ else:
157
+ return None
158
+
159
+ if audio_data.dtype != np.float32:
160
+ audio_data = audio_data.astype(np.float32)
161
+
162
+ if audio_data.max() > 1.0 or audio_data.min() < -1.0:
163
+ audio_data = audio_data / 32768.0
164
+
165
+ if len(audio_data.shape) > 1:
166
+ audio_data = audio_data.mean(axis=1)
167
+
168
+ if sample_rate != self.target_sample_rate:
169
+ audio_data = librosa.resample(
170
+ y=audio_data,
171
+ orig_sr=sample_rate,
172
+ target_sr=self.target_sample_rate
173
+ )
174
+
175
+ return audio_data
176
+
177
+ @spaces.GPU
178
+ def transcribe_to_phonemes(self, audio, enhancements):
179
+ """Transcribe audio to phonemes with enhancements."""
180
+ try:
181
+ audio_data = self.preprocess_audio(audio)
182
+ if audio_data is None:
183
+ return "Please provide valid audio input"
184
+ selected_enhancements = enhancements.split(',') if enhancements else []
185
+ inputs = self.model_config.processor(
186
+ audio_data,
187
+ sampling_rate=self.target_sample_rate,
188
+ return_tensors="pt",
189
+ padding=True
190
+ ).input_values.to(self.device)
191
+
192
+ with torch.no_grad():
193
+ logits = self.model_config.model(inputs).logits
194
+
195
+ predicted_ids = torch.argmax(logits, dim=-1)
196
+ transcription = self.model_config.processor.batch_decode(predicted_ids)[0]
197
+
198
+ enhanced = self.enhancer.enhance_transcription(
199
+ transcription,
200
+ selected_enhancements
201
+ )
202
+
203
+ return f"""Raw IPA: {transcription}
204
+ Enhanced IPA: {enhanced}
205
+ Applied enhancements: {', '.join(selected_enhancements) or 'none'}"""
206
+
207
+ except Exception as e:
208
+ import traceback
209
+ return f"Error processing audio: {str(e)}\n{traceback.format_exc()}"
210
+
211
+ if __name__ == "__main__":
212
+ multiprocessing.freeze_support()
213
+ transcriber = PhonemeTranscriber()
214
+ iface = gr.Interface(
215
+ fn=transcriber.transcribe_to_phonemes,
216
+ inputs=[
217
+ gr.Audio(sources=["microphone", "upload"], type="numpy"),
218
+ gr.Textbox(
219
+ label="Enhancements (comma-separated)",
220
+ value="length,quality,stress,diphthongs",
221
+ placeholder="e.g., length,quality,stress,diphthongs"
222
+ )
223
+ ],
224
+ outputs="text",
225
+ title="Speech to Phoneme Converter - Enhanced IPA",
226
+ description=f"""Convert speech to phonemes with customizable IPA enhancements.
227
+ Currently using device: {transcriber.device}
228
+
229
+ Available enhancements:
230
+ - length: Add vowel length markers (ː)
231
+ - quality: Adjust vowel quality (e.g., ə → æ)
232
+ - stress: Add stress marks (ˈ)
233
+ - diphthongs: Combine vowels into diphthongs (e.g., ei → eɪ)
234
+ """
235
+ )
236
+
237
+ iface.launch()