Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -125,7 +125,9 @@ class PhonemeTranscriber:
|
|
125 |
self.device = self._get_optimal_device()
|
126 |
print(f"Using device: {self.device}")
|
127 |
|
128 |
-
|
|
|
|
|
129 |
self.target_sample_rate = 16_000
|
130 |
self.enhancer = PhoneticEnhancer()
|
131 |
|
@@ -135,18 +137,6 @@ class PhonemeTranscriber:
|
|
135 |
elif torch.backends.mps.is_available() and platform.system() == 'Darwin':
|
136 |
return "mps"
|
137 |
return "cpu"
|
138 |
-
|
139 |
-
def _initialize_model(self) -> ModelConfig:
|
140 |
-
model_name = "facebook/wav2vec2-lv-60-espeak-cv-ft"
|
141 |
-
processor = Wav2Vec2Processor.from_pretrained(model_name)
|
142 |
-
model = Wav2Vec2ForCTC.from_pretrained(model_name)
|
143 |
-
|
144 |
-
return ModelConfig(
|
145 |
-
name=model_name,
|
146 |
-
processor=processor,
|
147 |
-
model=model,
|
148 |
-
description="LV-60 + CommonVoice (26 langs) + eSpeak"
|
149 |
-
)
|
150 |
|
151 |
def preprocess_audio(self, audio):
|
152 |
"""Preprocess audio data for model input."""
|
@@ -180,8 +170,13 @@ class PhonemeTranscriber:
|
|
180 |
audio_data = self.preprocess_audio(audio)
|
181 |
if audio_data is None:
|
182 |
return "Please provide valid audio input"
|
|
|
|
|
|
|
|
|
|
|
183 |
selected_enhancements = enhancements.split(',') if enhancements else []
|
184 |
-
inputs = self.
|
185 |
audio_data,
|
186 |
sampling_rate=self.target_sample_rate,
|
187 |
return_tensors="pt",
|
@@ -189,19 +184,24 @@ class PhonemeTranscriber:
|
|
189 |
).input_values.to(self.device)
|
190 |
|
191 |
with torch.no_grad():
|
192 |
-
logits =
|
193 |
|
194 |
predicted_ids = torch.argmax(logits, dim=-1)
|
195 |
-
transcription = self.
|
196 |
|
197 |
enhanced = self.enhancer.enhance_transcription(
|
198 |
transcription,
|
199 |
selected_enhancements
|
200 |
)
|
201 |
|
|
|
|
|
|
|
|
|
|
|
202 |
return f"""Raw IPA: {transcription}
|
203 |
-
|
204 |
-
|
205 |
|
206 |
except Exception as e:
|
207 |
import traceback
|
|
|
125 |
self.device = self._get_optimal_device()
|
126 |
print(f"Using device: {self.device}")
|
127 |
|
128 |
+
# Store model name and initialize processor only
|
129 |
+
self.model_name = "facebook/wav2vec2-lv-60-espeak-cv-ft"
|
130 |
+
self.processor = Wav2Vec2Processor.from_pretrained(self.model_name)
|
131 |
self.target_sample_rate = 16_000
|
132 |
self.enhancer = PhoneticEnhancer()
|
133 |
|
|
|
137 |
elif torch.backends.mps.is_available() and platform.system() == 'Darwin':
|
138 |
return "mps"
|
139 |
return "cpu"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
140 |
|
141 |
def preprocess_audio(self, audio):
|
142 |
"""Preprocess audio data for model input."""
|
|
|
170 |
audio_data = self.preprocess_audio(audio)
|
171 |
if audio_data is None:
|
172 |
return "Please provide valid audio input"
|
173 |
+
|
174 |
+
# Load model inside GPU context
|
175 |
+
model = Wav2Vec2ForCTC.from_pretrained(self.model_name).to(self.device)
|
176 |
+
model.eval()
|
177 |
+
|
178 |
selected_enhancements = enhancements.split(',') if enhancements else []
|
179 |
+
inputs = self.processor(
|
180 |
audio_data,
|
181 |
sampling_rate=self.target_sample_rate,
|
182 |
return_tensors="pt",
|
|
|
184 |
).input_values.to(self.device)
|
185 |
|
186 |
with torch.no_grad():
|
187 |
+
logits = model(inputs).logits
|
188 |
|
189 |
predicted_ids = torch.argmax(logits, dim=-1)
|
190 |
+
transcription = self.processor.batch_decode(predicted_ids)[0]
|
191 |
|
192 |
enhanced = self.enhancer.enhance_transcription(
|
193 |
transcription,
|
194 |
selected_enhancements
|
195 |
)
|
196 |
|
197 |
+
# Clean up to free GPU memory
|
198 |
+
del model
|
199 |
+
if torch.cuda.is_available():
|
200 |
+
torch.cuda.empty_cache()
|
201 |
+
|
202 |
return f"""Raw IPA: {transcription}
|
203 |
+
Enhanced IPA: {enhanced}
|
204 |
+
Applied enhancements: {', '.join(selected_enhancements) or 'none'}"""
|
205 |
|
206 |
except Exception as e:
|
207 |
import traceback
|