Upload 3 files
Browse files

- app.py         +186 -878
- evaluate.py    +341 -0
- translator.py  +435 -0
app.py
CHANGED
@@ -1,879 +1,187 @@
```diff
- # … (old lines 1-36: header comment, initial imports, and logging setup; content not captured in this view) …
- import glob
- import numpy as np
- import torch
- from pydub import AudioSegment
- import tempfile
- # … (old lines 42-186: remaining imports, model loading, and the home/health route headers; content not captured in this view) …
-     # Initialize direct language pair statuses based on loaded models
-     translation_status = {}
-
-     # Add status for direct model pairs
-     for lang_pair in ["pam-eng", "eng-pam", "tgl-eng", "eng-tgl"]:
-         translation_status[lang_pair] = "loaded" if lang_pair in translation_models and translation_models[
-             lang_pair] is not None else "failed"
-
-     # Add special phi model status
-     phi_status = "loaded" if "phi" in translation_models and translation_models["phi"] is not None else "failed"
-     translation_status["pam-fil"] = phi_status
-     translation_status["fil-pam"] = phi_status
-     translation_status["pam-tgl"] = phi_status  # Using phi model but replacing tgl with fil
-     translation_status["tgl-pam"] = phi_status  # Using phi model but replacing tgl with fil
-
-     health_status = {
-         "api_status": "online",
-         "asr_model": "loaded" if asr_model is not None else "failed",
-         "tts_models": {lang: "loaded" if model is not None else "failed"
-                        for lang, model in tts_models.items()},
-         "translation_models": translation_status,
-         "device": device
-     }
-     return jsonify(health_status)
-
-
- @app.route("/check_references", methods=["GET"])
- def check_references():
-     """Endpoint to check if reference files exist and are accessible"""
-     ref_patterns = ["mayap_a_abak", "mayap_a_ugtu", "mayap_a_gatpanapun",
-                     "mayap_a_bengi", "komusta_ka"]
-     results = {}
-
-     for pattern in ref_patterns:
-         pattern_dir = os.path.join(REFERENCE_AUDIO_DIR, pattern)
-         if os.path.exists(pattern_dir):
-             wav_files = glob.glob(os.path.join(pattern_dir, "*.wav"))
-             results[pattern] = {
-                 "exists": True,
-                 "path": pattern_dir,
-                 "file_count": len(wav_files),
-                 "files": [os.path.basename(f) for f in wav_files]
-             }
-         else:
-             results[pattern] = {
-                 "exists": False,
-                 "path": pattern_dir
-             }
-
-     return jsonify({
-         "reference_audio_dir": REFERENCE_AUDIO_DIR,
-         "directory_exists": os.path.exists(REFERENCE_AUDIO_DIR),
-         "patterns": results
-     })
-
-
- @app.route("/asr", methods=["POST"])
- def transcribe_audio():
-     if asr_model is None or asr_processor is None:
-         logger.error("❌ ASR endpoint called but models aren't loaded")
-         return jsonify({"error": "ASR model not available"}), 503
-
-     try:
-         if "audio" not in request.files:
-             logger.warning("⚠️ ASR request missing audio file")
-             return jsonify({"error": "No audio file uploaded"}), 400
-
-         audio_file = request.files["audio"]
-         language = request.form.get("language", "english").lower()
-
-         if language not in LANGUAGE_CODES:
-             logger.warning(f"⚠️ Unsupported language requested: {language}")
-             return jsonify(
-                 {"error": f"Unsupported language: {language}. Available: {list(LANGUAGE_CODES.keys())}"}), 400
-
-         lang_code = LANGUAGE_CODES[language]
-         logger.info(f"Processing {language} audio for ASR")
-
-         # Save the uploaded file temporarily
-         with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(audio_file.filename)[-1]) as temp_audio:
-             temp_audio.write(audio_file.read())
-             temp_audio_path = temp_audio.name
-             logger.debug(f"Temporary audio saved to {temp_audio_path}")
-
-         # Convert to WAV if necessary
-         wav_path = temp_audio_path
-         if not audio_file.filename.lower().endswith(".wav"):
-             wav_path = os.path.join(OUTPUT_DIR, "converted_audio.wav")
-             logger.info(f"Converting audio to WAV format: {wav_path}")
-             try:
-                 audio = AudioSegment.from_file(temp_audio_path)
-                 audio = audio.set_frame_rate(SAMPLE_RATE).set_channels(1)
-                 audio.export(wav_path, format="wav")
-             except Exception as e:
-                 logger.error(f"❌ Audio conversion failed: {str(e)}")
-                 return jsonify({"error": f"Audio conversion failed: {str(e)}"}), 500
-
-         # Load and process the WAV file
-         try:
-             waveform, sr = torchaudio.load(wav_path)
-             logger.debug(f"✅ Audio loaded: {wav_path} (Sample rate: {sr}Hz)")
-
-             # Resample if needed
-             if sr != SAMPLE_RATE:
-                 logger.info(f"Resampling audio from {sr}Hz to {SAMPLE_RATE}Hz")
-                 waveform = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(waveform)
-
-             waveform = waveform / torch.max(torch.abs(waveform))
-         except Exception as e:
-             logger.error(f"❌ Failed to load or process audio: {str(e)}")
-             return jsonify({"error": f"Audio processing failed: {str(e)}"}), 500
-
-         # Process audio for ASR
-         try:
-             inputs = asr_processor(
-                 waveform.squeeze().numpy(),
-                 sampling_rate=SAMPLE_RATE,
-                 return_tensors="pt",
-                 language=lang_code
-             )
-             inputs = {k: v.to(device) for k, v in inputs.items()}
-         except Exception as e:
-             logger.error(f"❌ ASR preprocessing failed: {str(e)}")
-             return jsonify({"error": f"ASR preprocessing failed: {str(e)}"}), 500
-
-         # Perform ASR
-         try:
-             with torch.no_grad():
-                 logits = asr_model(**inputs).logits
-             ids = torch.argmax(logits, dim=-1)[0]
-             transcription = asr_processor.decode(ids)
-
-             logger.info(f"✅ Transcription ({language}): {transcription}")
-
-             # Clean up temp files
-             try:
-                 os.unlink(temp_audio_path)
-                 if wav_path != temp_audio_path:
-                     os.unlink(wav_path)
-             except Exception as e:
-                 logger.warning(f"⚠️ Failed to clean up temp files: {str(e)}")
-
-             return jsonify({
-                 "transcription": transcription,
-                 "language": language,
-                 "language_code": lang_code
-             })
-         except Exception as e:
-             logger.error(f"❌ ASR inference failed: {str(e)}")
-             logger.debug(f"Stack trace: {traceback.format_exc()}")
-             return jsonify({"error": f"ASR inference failed: {str(e)}"}), 500
-
-     except Exception as e:
-         logger.error(f"❌ Unhandled exception in ASR endpoint: {str(e)}")
-         logger.debug(f"Stack trace: {traceback.format_exc()}")
-         return jsonify({"error": f"Internal server error: {str(e)}"}), 500
-
-
- @app.route("/tts", methods=["POST"])
- def generate_tts():
-     try:
-         data = request.get_json()
-         if not data:
-             logger.warning("⚠️ TTS endpoint called with no JSON data")
-             return jsonify({"error": "No JSON data provided"}), 400
-
-         text_input = data.get("text", "").strip()
-         language = data.get("language", "kapampangan").lower()
-
-         if not text_input:
-             logger.warning("⚠️ TTS request with empty text")
-             return jsonify({"error": "No text provided"}), 400
-
-         if language not in TTS_MODELS:
-             logger.warning(f"⚠️ TTS requested for unsupported language: {language}")
-             return jsonify({"error": f"Invalid language. Available options: {list(TTS_MODELS.keys())}"}), 400
-
-         if tts_models[language] is None:
-             logger.error(f"❌ TTS model for {language} not loaded")
-             return jsonify({"error": f"TTS model for {language} not available"}), 503
-
-         logger.info(f"Generating TTS for language: {language}, text: '{text_input}'")
-
-         try:
-             processor = tts_processors[language]
-             model = tts_models[language]
-             inputs = processor(text_input, return_tensors="pt")
-             inputs = {k: v.to(device) for k, v in inputs.items()}
-         except Exception as e:
-             logger.error(f"❌ TTS preprocessing failed: {str(e)}")
-             return jsonify({"error": f"TTS preprocessing failed: {str(e)}"}), 500
-
-         # Generate speech
-         try:
-             with torch.no_grad():
-                 output = model(**inputs).waveform
-             waveform = output.squeeze().cpu().numpy()
-         except Exception as e:
-             logger.error(f"❌ TTS inference failed: {str(e)}")
-             logger.debug(f"Stack trace: {traceback.format_exc()}")
-             return jsonify({"error": f"TTS inference failed: {str(e)}"}), 500
-
-         # Save to file
-         try:
-             output_filename = os.path.join(OUTPUT_DIR, f"{language}_output.wav")
-             sampling_rate = model.config.sampling_rate
-             sf.write(output_filename, waveform, sampling_rate)
-             logger.info(f"✅ Speech generated! File saved: {output_filename}")
-         except Exception as e:
-             logger.error(f"❌ Failed to save audio file: {str(e)}")
-             return jsonify({"error": f"Failed to save audio file: {str(e)}"}), 500
-
-         return jsonify({
-             "message": "TTS audio generated",
-             "file_url": f"/download/{os.path.basename(output_filename)}",
-             "language": language,
-             "text_length": len(text_input)
-         })
-     except Exception as e:
-         logger.error(f"❌ Unhandled exception in TTS endpoint: {str(e)}")
-         logger.debug(f"Stack trace: {traceback.format_exc()}")
-         return jsonify({"error": f"Internal server error: {str(e)}"}), 500
-
-
- @app.route("/download/<filename>", methods=["GET"])
- def download_audio(filename):
-     file_path = os.path.join(OUTPUT_DIR, filename)
-     if os.path.exists(file_path):
-         logger.info(f"Serving audio file: {file_path}")
-         return send_file(file_path, mimetype="audio/wav", as_attachment=True)
-
-     logger.warning(f"⚠️ Requested file not found: {file_path}")
-     return jsonify({"error": "File not found"}), 404
-
-
- @app.route("/translate", methods=["POST"])
- def translate_text():
-     try:
-         data = request.get_json()
-         if not data:
-             logger.warning("⚠️ Translation endpoint called with no JSON data")
-             return jsonify({"error": "No JSON data provided"}), 400
-
-         source_text = data.get("text", "").strip()
-         source_language = data.get("source_language", "").lower()
-         target_language = data.get("target_language", "").lower()
-
-         if not source_text:
-             logger.warning("⚠️ Translation request with empty text")
-             return jsonify({"error": "No text provided"}), 400
-
-         # Map language names to codes
-         source_code = LANGUAGE_CODES.get(source_language, source_language)
-         target_code = LANGUAGE_CODES.get(target_language, target_language)
-
-         logger.info(f"Translating from {source_language} to {target_language}: '{source_text}'")
-
-         # Special handling for pam-fil, fil-pam, pam-tgl and tgl-pam using the phi model
-         use_phi_model = False
-         actual_source_code = source_code
-         actual_target_code = target_code
-
-         # Check if we need to use the phi model with fil replacement
-         if (source_code == "pam" and target_code == "fil") or (source_code == "fil" and target_code == "pam"):
-             use_phi_model = True
-         elif (source_code == "pam" and target_code == "tgl"):
-             use_phi_model = True
-             actual_target_code = "fil"  # Replace tgl with fil for the phi model
-         elif (source_code == "tgl" and target_code == "pam"):
-             use_phi_model = True
-             actual_source_code = "fil"  # Replace tgl with fil for the phi model
-
-         if use_phi_model:
-             model_key = "phi"
-
-             # Check if we have the phi model
-             if model_key not in translation_models or translation_models[model_key] is None:
-                 logger.error(f"❌ Translation model for {model_key} not loaded")
-                 return jsonify({"error": "Translation model not available"}), 503
-
-             try:
-                 # Get the phi model and tokenizer
-                 model = translation_models[model_key]
-                 tokenizer = translation_tokenizers[model_key]
-
-                 # Prepend target language token to input
-                 input_text = f">>{actual_target_code}<< {source_text}"
-
-                 logger.info(f"Using phi model with input: '{input_text}'")
-
-                 # Tokenize the text
-                 tokenized = tokenizer(input_text, return_tensors="pt", padding=True)
-                 tokenized = {k: v.to(device) for k, v in tokenized.items()}
-
-                 # Generate translation
-                 with torch.no_grad():
-                     translated = model.generate(**tokenized)
-
-                 # Decode the translation
-                 result = tokenizer.decode(translated[0], skip_special_tokens=True)
-
-                 logger.info(f"✅ Translation result: '{result}'")
-
-                 return jsonify({
-                     "translated_text": result,
-                     "source_language": source_language,
-                     "target_language": target_language
-                 })
-             except Exception as e:
-                 logger.error(f"❌ Translation processing failed: {str(e)}")
-                 logger.debug(f"Stack trace: {traceback.format_exc()}")
-                 return jsonify({"error": f"Translation processing failed: {str(e)}"}), 500
-         else:
-             # Create the regular language pair key for other language pairs
-             lang_pair = f"{source_code}-{target_code}"
-
-             # Check if we have a model for this language pair
-             if lang_pair not in translation_models:
-                 logger.warning(f"⚠️ No translation model available for {lang_pair}")
-                 return jsonify(
-                     {"error": f"Translation from {source_language} to {target_language} is not supported yet"}), 400
-
-             if translation_models[lang_pair] is None or translation_tokenizers[lang_pair] is None:
-                 logger.error(f"❌ Translation model for {lang_pair} not loaded")
-                 return jsonify({"error": "Translation model not available"}), 503
-
-             try:
-                 # Regular translation process for other language pairs
-                 model = translation_models[lang_pair]
-                 tokenizer = translation_tokenizers[lang_pair]
-
-                 # Tokenize the text
-                 tokenized = tokenizer(source_text, return_tensors="pt", padding=True)
-                 tokenized = {k: v.to(device) for k, v in tokenized.items()}
-
-                 # Generate translation
-                 with torch.no_grad():
-                     translated = model.generate(**tokenized)
-
-                 # Decode the translation
-                 result = tokenizer.decode(translated[0], skip_special_tokens=True)
-
-                 logger.info(f"✅ Translation result: '{result}'")
-
-                 return jsonify({
-                     "translated_text": result,
-                     "source_language": source_language,
-                     "target_language": target_language
-                 })
-             except Exception as e:
-                 logger.error(f"❌ Translation processing failed: {str(e)}")
-                 logger.debug(f"Stack trace: {traceback.format_exc()}")
-                 return jsonify({"error": f"Translation processing failed: {str(e)}"}), 500
-
-     except Exception as e:
-         logger.error(f"❌ Unhandled exception in translation endpoint: {str(e)}")
-         logger.debug(f"Stack trace: {traceback.format_exc()}")
-         return jsonify({"error": f"Internal server error: {str(e)}"}), 500
-
-
- # Add this function to your app.py
- def calculate_similarity(text1, text2):
-     """Calculate text similarity percentage."""
-
-     def clean_text(text):
-         return text.lower()
-
-     clean1 = clean_text(text1)
-     clean2 = clean_text(text2)
-
-     matcher = SequenceMatcher(None, clean1, clean2)
-     return matcher.ratio() * 100
-
- @app.route("/evaluate", methods=["POST"])
- def evaluate_pronunciation():
-     request_id = f"req-{id(request)}"  # Create unique ID for this request
-     logger.info(f"[{request_id}] Starting new pronunciation evaluation request")
-
-     if asr_model is None or asr_processor is None:
-         logger.error(f"[{request_id}] ❌ Evaluation endpoint called but ASR models aren't loaded")
-         return jsonify({"error": "ASR model not available"}), 503
-
-     try:
-         if "audio" not in request.files:
-             logger.warning(f"[{request_id}] ⚠️ Evaluation request missing audio file")
-             return jsonify({"error": "No audio file uploaded"}), 400
-
-         audio_file = request.files["audio"]
-         reference_locator = request.form.get("reference_locator", "").strip()
-         language = request.form.get("language", "kapampangan").lower()
-
-         # Validate reference locator
-         if not reference_locator:
-             logger.warning(f"[{request_id}] ⚠️ No reference locator provided")
-             return jsonify({"error": "Reference locator is required"}), 400
-
-         # Construct full reference directory path
-         reference_dir = os.path.join(REFERENCE_AUDIO_DIR, reference_locator)
-         logger.info(f"[{request_id}] Reference directory path: {reference_dir}")
-
-         if not os.path.exists(reference_dir):
-             logger.warning(f"[{request_id}] ⚠️ Reference directory not found: {reference_dir}")
-             return jsonify({"error": f"Reference audio directory not found: {reference_locator}"}), 404
-
-         reference_files = glob.glob(os.path.join(reference_dir, "*.wav"))
-         logger.info(f"[{request_id}] Found {len(reference_files)} reference files")
-
-         if not reference_files:
-             logger.warning(f"[{request_id}] ⚠️ No reference audio files found in {reference_dir}")
-             return jsonify({"error": f"No reference audio found for {reference_locator}"}), 404
-
-         lang_code = LANGUAGE_CODES.get(language, language)
-         logger.info(f"[{request_id}] Evaluating pronunciation for reference: {reference_locator} with language code: {lang_code}")
-
-         # Create a request-specific temp directory to avoid conflicts
-         temp_dir = os.path.join(OUTPUT_DIR, f"temp_{request_id}")
-         os.makedirs(temp_dir, exist_ok=True)
-
-         # Process user audio
-         user_audio_path = os.path.join(temp_dir, "user_audio_input.wav")
-         with open(user_audio_path, 'wb') as f:
-             f.write(audio_file.read())
-
-         try:
-             logger.info(f"[{request_id}] Processing user audio file")
-             audio = AudioSegment.from_file(user_audio_path)
-             audio = audio.set_frame_rate(SAMPLE_RATE).set_channels(1)
-
-             processed_path = os.path.join(temp_dir, "processed_user_audio.wav")
-             audio.export(processed_path, format="wav")
-
-             user_waveform, sr = torchaudio.load(processed_path)
-             user_waveform = user_waveform.squeeze().numpy()
-             logger.info(f"[{request_id}] ✅ User audio processed: {sr}Hz, length: {len(user_waveform)} samples")
-
-             user_audio_path = processed_path
-         except Exception as e:
-             logger.error(f"[{request_id}] ❌ Audio processing failed: {str(e)}")
-             return jsonify({"error": f"Audio processing failed: {str(e)}"}), 500
-
-         # Transcribe user audio
-         try:
-             logger.info(f"[{request_id}] Transcribing user audio")
-             inputs = asr_processor(
-                 user_waveform,
-                 sampling_rate=SAMPLE_RATE,
-                 return_tensors="pt",
-                 language=lang_code
-             )
-             inputs = {k: v.to(device) for k, v in inputs.items()}
-
-             with torch.no_grad():
-                 logits = asr_model(**inputs).logits
-             ids = torch.argmax(logits, dim=-1)[0]
-             user_transcription = asr_processor.decode(ids)
-
-             logger.info(f"[{request_id}] ✅ User transcription: '{user_transcription}'")
-         except Exception as e:
-             logger.error(f"[{request_id}] ❌ ASR inference failed: {str(e)}")
-             return jsonify({"error": f"ASR inference failed: {str(e)}"}), 500
-
-         # Process reference files in batches
-         batch_size = 2  # Process 2 files at a time - adjust based on your hardware
-         results = []
-         best_score = 0
-         best_reference = None
-         best_transcription = None
-
-         # Use this if you want to limit the number of files to process
-         max_files_to_check = min(5, len(reference_files))  # Check at most 5 files
-         reference_files = reference_files[:max_files_to_check]
-
-         logger.info(f"[{request_id}] Processing {len(reference_files)} reference files in batches of {batch_size}")
-
-         # Function to process a single reference file
-         def process_reference_file(ref_file):
-             ref_filename = os.path.basename(ref_file)
-             try:
-                 # Load and resample reference audio
-                 ref_waveform, ref_sr = torchaudio.load(ref_file)
-                 if ref_sr != SAMPLE_RATE:
-                     ref_waveform = torchaudio.transforms.Resample(ref_sr, SAMPLE_RATE)(ref_waveform)
-                 ref_waveform = ref_waveform.squeeze().numpy()
-
-                 # Transcribe reference audio
-                 inputs = asr_processor(
-                     ref_waveform,
-                     sampling_rate=SAMPLE_RATE,
-                     return_tensors="pt",
-                     language=lang_code
-                 )
-                 inputs = {k: v.to(device) for k, v in inputs.items()}
-
-                 with torch.no_grad():
-                     logits = asr_model(**inputs).logits
-                 ids = torch.argmax(logits, dim=-1)[0]
-                 ref_transcription = asr_processor.decode(ids)
-
-                 # Calculate similarity
-                 similarity = calculate_similarity(user_transcription, ref_transcription)
-
-                 logger.info(f"[{request_id}] Similarity with {ref_filename}: {similarity:.2f}%, transcription: '{ref_transcription}'")
-
-                 return {
-                     "reference_file": ref_filename,
-                     "reference_text": ref_transcription,
-                     "similarity_score": similarity
-                 }
-             except Exception as e:
-                 logger.error(f"[{request_id}] ❌ Error processing {ref_filename}: {str(e)}")
-                 return {
-                     "reference_file": ref_filename,
-                     "reference_text": "Error",
-                     "similarity_score": 0,
-                     "error": str(e)
-                 }
-
-         # Process files in batches using ThreadPoolExecutor
-         from concurrent.futures import ThreadPoolExecutor
-
-         with ThreadPoolExecutor(max_workers=batch_size) as executor:
-             batch_results = list(executor.map(process_reference_file, reference_files))
-             results.extend(batch_results)
-
-             # Find the best result
-             for result in batch_results:
-                 if result["similarity_score"] > best_score:
-                     best_score = result["similarity_score"]
-                     best_reference = result["reference_file"]
-                     best_transcription = result["reference_text"]
-
-                 # Exit early if we found a very good match (optional)
-                 if best_score > 80.0:
-                     logger.info(f"[{request_id}] Found excellent match: {best_score:.2f}%")
-                     break
-
-         # Clean up temp files
-         try:
-             import shutil
-             shutil.rmtree(temp_dir)
-             logger.debug(f"[{request_id}] 🧹 Cleaned up temporary directory")
-         except Exception as e:
-             logger.warning(f"[{request_id}] ⚠️ Failed to clean up temp files: {str(e)}")
-
-         # Determine feedback based on score
-         is_correct = best_score >= 70.0
-
-         if best_score >= 90.0:
-             feedback = "Perfect pronunciation! Excellent job!"
-         elif best_score >= 80.0:
-             feedback = "Great pronunciation! Your accent is very good."
-         elif best_score >= 70.0:
-             feedback = "Good pronunciation. Keep practicing!"
-         elif best_score >= 50.0:
-             feedback = "Fair attempt. Try focusing on the syllables that differ from the sample."
-         else:
-             feedback = "Try again. Listen carefully to the sample pronunciation."
-
-         logger.info(f"[{request_id}] Final evaluation results: score={best_score:.2f}%, is_correct={is_correct}")
-         logger.info(f"[{request_id}] Feedback: '{feedback}'")
-         logger.info(f"[{request_id}] ✅ Evaluation complete")
-
-         # Sort results by score descending
-         results.sort(key=lambda x: x["similarity_score"], reverse=True)
-
-         return jsonify({
-             "is_correct": is_correct,
-             "score": best_score,
-             "feedback": feedback,
-             "user_transcription": user_transcription,
-             "best_reference_transcription": best_transcription,
-             "reference_locator": reference_locator,
-             "details": results
-         })
-
-     except Exception as e:
-         logger.error(f"[{request_id}] ❌ Unhandled exception in evaluation endpoint: {str(e)}")
-         logger.debug(f"[{request_id}] Stack trace: {traceback.format_exc()}")
-
-         # Clean up on error
-         try:
-             import shutil
-             shutil.rmtree(temp_dir)
-         except:
-             pass
-
-         return jsonify({"error": f"Internal server error: {str(e)}"}), 500
-
-
- @app.route("/upload_reference", methods=["POST"])
- def upload_reference_audio():
-     try:
-         if "audio" not in request.files:
-             logger.warning("⚠️ Reference upload missing audio file")
-             return jsonify({"error": "No audio file uploaded"}), 400
-
-         reference_word = request.form.get("reference_word", "").strip()
-         if not reference_word:
-             logger.warning("⚠️ Reference upload missing reference word")
-             return jsonify({"error": "No reference word provided"}), 400
-
-         # Validate reference word
-         reference_patterns = [
-             "mayap_a_abak", "mayap_a_ugtu", "mayap_a_gatpanapun", "mayap_a_bengi", "komusta_ka", "malaus_ko_pu","malaus_kayu","agaganaka_da_ka", "pagdulapan_da_ka","kaluguran_da_ka","dakal_a_salamat","panapaya_mu_ku"
-         ]
-
-         if reference_word not in reference_patterns:
-             logger.warning(f"⚠️ Invalid reference word: {reference_word}")
-             return jsonify({"error": f"Invalid reference word. Available: {reference_patterns}"}), 400
-
-         # Create directory for reference pattern if it doesn't exist
-         pattern_dir = os.path.join(REFERENCE_AUDIO_DIR, reference_word)
-         os.makedirs(pattern_dir, exist_ok=True)
-
-         # Save the reference audio file
-         audio_file = request.files["audio"]
-         file_path = os.path.join(pattern_dir, secure_filename(audio_file.filename))
-         audio_file.save(file_path)
-
-         # Convert to WAV if not already in that format
-         if not file_path.lower().endswith('.wav'):
-             base_path = os.path.splitext(file_path)[0]
-             wav_path = f"{base_path}.wav"
-             try:
-                 audio = AudioSegment.from_file(file_path)
-                 audio = audio.set_frame_rate(SAMPLE_RATE).set_channels(1)
-                 audio.export(wav_path, format="wav")
-                 # Remove original file if conversion successful
-                 os.unlink(file_path)
-                 file_path = wav_path
-             except Exception as e:
-                 logger.error(f"❌ Reference audio conversion failed: {str(e)}")
-                 return jsonify({"error": f"Audio conversion failed: {str(e)}"}), 500
-
-         logger.info(f"✅ Reference audio saved successfully for {reference_word}: {file_path}")
-
-         # Count how many references we have now
-         references = glob.glob(os.path.join(pattern_dir, "*.wav"))
-         return jsonify({
-             "message": "Reference audio uploaded successfully",
-             "reference_word": reference_word,
-             "file": os.path.basename(file_path),
-             "total_references": len(references)
-         })
-
-     except Exception as e:
-         logger.error(f"❌ Unhandled exception in reference upload: {str(e)}")
-         logger.debug(f"Stack trace: {traceback.format_exc()}")
-         return jsonify({"error": f"Internal server error: {str(e)}"}), 500
-
-
- def init_reference_audio():
-     try:
-         # Create the output directory first
-         os.makedirs(OUTPUT_DIR, exist_ok=True)
-         logger.info(f"Created output directory: {OUTPUT_DIR}")
-
-         # Check if the reference audio directory exists in the repository
-         if os.path.exists(REFERENCE_AUDIO_DIR):
-             logger.info(f"✅ Found reference audio directory: {REFERENCE_AUDIO_DIR}")
-
-             # Log the contents to verify
-             pattern_dirs = [d for d in os.listdir(REFERENCE_AUDIO_DIR)
-                             if os.path.isdir(os.path.join(REFERENCE_AUDIO_DIR, d))]
-             logger.info(f"Found reference patterns: {pattern_dirs}")
-
-             # Check each pattern directory for wav files
-             for pattern_dir_name in pattern_dirs:
-                 pattern_path = os.path.join(REFERENCE_AUDIO_DIR, pattern_dir_name)
-                 wav_files = glob.glob(os.path.join(pattern_path, "*.wav"))
-                 logger.info(f"Found {len(wav_files)} wav files in {pattern_dir_name}")
-         else:
-             logger.warning(f"⚠️ Reference audio directory not found: {REFERENCE_AUDIO_DIR}")
-     except Exception as e:
-         logger.error(f"❌ Failed to set up reference audio directory: {str(e)}")
-
-
- # Add an initialization route that will be called before the first request
- @app.before_request
- def before_request():
-     if not hasattr(g, 'initialized'):
-         init_reference_audio()
-         g.initialized = True
-
-
- if __name__ == "__main__":
-     init_reference_audio()
-     logger.info("Starting Speech API server")
-     logger.info(f"System status: ASR model: {'✅' if asr_model else '❌'}")
-     for lang, model in tts_models.items():
-         logger.info(f"TTS model {lang}: {'✅' if model else '❌'}")
-
+ # app.py - Main application file
+
+ import os
+ import sys
+ import logging
+ import traceback
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(levelname)s - %(message)s',
+     datefmt='%Y-%m-%d %H:%M:%S'
+ )
+ logger = logging.getLogger("speech_api")
+
+ # Set all cache directories to locations within /tmp
+ cache_dirs = {
+     "HF_HOME": "/tmp/hf_home",
+     "TRANSFORMERS_CACHE": "/tmp/transformers_cache",
+     "HUGGINGFACE_HUB_CACHE": "/tmp/huggingface_hub_cache",
+     "TORCH_HOME": "/tmp/torch_home",
+     "XDG_CACHE_HOME": "/tmp/xdg_cache"
+ }
+
+ # Set environment variables and create directories
+ for env_var, path in cache_dirs.items():
+     os.environ[env_var] = path
+     try:
+         os.makedirs(path, exist_ok=True)
+         logger.info(f"Created cache directory: {path}")
+     except Exception as e:
+         logger.error(f"❌ Failed to create directory {path}: {str(e)}")
+
+ # Now import the rest of the libraries
+ try:
+     import librosa
+     import glob
+     import numpy as np
+     import torch
+     from pydub import AudioSegment
+     import tempfile
+     import soundfile as sf
+     from flask import Flask, request, jsonify, send_file, g
+     from flask_cors import CORS
+     from werkzeug.utils import secure_filename
+
+     # Import functionality from other modules
+     from translator import (
+         init_models, check_model_status, handle_asr_request,
+         handle_tts_request, handle_translation_request
+     )
+     from evaluate import (
+         handle_evaluation_request, handle_upload_reference,
+         init_reference_audio, calculate_similarity
+     )
+
+     logger.info("✅ All required libraries imported successfully")
+ except ImportError as e:
+     logger.critical(f"❌ Failed to import necessary libraries: {str(e)}")
+     sys.exit(1)
+
+ # Check CUDA availability
+ if torch.cuda.is_available():
+     logger.info(f"CUDA available: {torch.cuda.get_device_name(0)}")
+     device = "cuda"
+ else:
+     logger.info("⚠️ CUDA not available, using CPU")
+     device = "cpu"
+
+ # Constants
+ SAMPLE_RATE = 16000
+ OUTPUT_DIR = "/tmp/audio_outputs"
+ REFERENCE_AUDIO_DIR = "./reference_audio"
+
+ try:
+     os.makedirs(OUTPUT_DIR, exist_ok=True)
+     logger.info(f"Created output directory: {OUTPUT_DIR}")
+ except Exception as e:
+     logger.error(f"❌ Failed to create output directory: {str(e)}")
+
+ # Initialize Flask app
+ app = Flask(__name__)
+ CORS(app)
+
+ # Load models
+ init_models(device)
+
+
+ # Define routes
+ @app.route("/", methods=["GET"])
+ def home():
+     return jsonify({"message": "Speech API is running", "status": "active"})
+
+
+ @app.route("/health", methods=["GET"])
+ def health_check():
+     health_status = check_model_status()
+     health_status["api_status"] = "online"
+     health_status["device"] = device
+     return jsonify(health_status)
+
+
+ @app.route("/asr", methods=["POST"])
+ def transcribe_audio():
+     return handle_asr_request(request, OUTPUT_DIR, SAMPLE_RATE)
+
+
+ @app.route("/tts", methods=["POST"])
+ def generate_tts():
+     return handle_tts_request(request, OUTPUT_DIR)
+
+
+ @app.route("/translate", methods=["POST"])
+ def translate_text():
+     return handle_translation_request(request)
+
+
+ @app.route("/download/<filename>", methods=["GET"])
+ def download_audio(filename):
+     file_path = os.path.join(OUTPUT_DIR, filename)
+     if os.path.exists(file_path):
+         logger.info(f"Serving audio file: {file_path}")
+         return send_file(file_path, mimetype="audio/wav", as_attachment=True)
+
+     logger.warning(f"⚠️ Requested file not found: {file_path}")
+     return jsonify({"error": "File not found"}), 404
+
+
+ @app.route("/evaluate", methods=["POST"])
+ def evaluate_pronunciation():
+     return handle_evaluation_request(request, REFERENCE_AUDIO_DIR, OUTPUT_DIR, SAMPLE_RATE)
+
+
+ @app.route("/check_references", methods=["GET"])
+ def check_references():
+     """Endpoint to check if reference files exist and are accessible"""
+     ref_patterns = ["mayap_a_abak", "mayap_a_ugtu", "mayap_a_gatpanapun",
+                     "mayap_a_bengi", "komusta_ka"]
+     results = {}
+
+     for pattern in ref_patterns:
+         pattern_dir = os.path.join(REFERENCE_AUDIO_DIR, pattern)
+         if os.path.exists(pattern_dir):
+             wav_files = glob.glob(os.path.join(pattern_dir, "*.wav"))
+             results[pattern] = {
+                 "exists": True,
+                 "path": pattern_dir,
+                 "file_count": len(wav_files),
+                 "files": [os.path.basename(f) for f in wav_files]
+             }
+         else:
+             results[pattern] = {
+                 "exists": False,
+                 "path": pattern_dir
+             }
+
+     return jsonify({
+         "reference_audio_dir": REFERENCE_AUDIO_DIR,
+         "directory_exists": os.path.exists(REFERENCE_AUDIO_DIR),
+         "patterns": results
+     })
+
+
+ @app.route("/upload_reference", methods=["POST"])
+ def upload_reference_audio():
+     return handle_upload_reference(request, REFERENCE_AUDIO_DIR, SAMPLE_RATE)
+
+
+ # Add an initialization route that will be called before the first request
+ @app.before_request
+ def before_request():
+     if not hasattr(g, 'initialized'):
+         init_reference_audio(REFERENCE_AUDIO_DIR, OUTPUT_DIR)
+         g.initialized = True
+
+
+ if __name__ == "__main__":
+     init_reference_audio(REFERENCE_AUDIO_DIR, OUTPUT_DIR)
+     logger.info("Starting Speech API server")
+
+     # Get the status for logging
+     status = check_model_status()
+     logger.info(f"System status: ASR model: {'✅' if status['asr_model'] == 'loaded' else '❌'}")
+     for lang, model_status in status['tts_models'].items():
+         logger.info(f"TTS model {lang}: {'✅' if model_status == 'loaded' else '❌'}")
+
      app.run(host="0.0.0.0", port=7860, debug=True)
```
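Since the routes above now only delegate to the two new modules, a quick way to sanity-check the refactor is to hit the endpoints directly. Below is a minimal client sketch, not part of this commit: it assumes the server is running locally on port 7860 (matching the `app.run` call) and uses the third-party `requests` package; the audio filename is hypothetical.

```python
# client_check.py - minimal smoke test for the refactored endpoints
# (assumes a local server on port 7860 and `pip install requests`;
# my_recording.wav is a hypothetical local file)
import requests

BASE = "http://localhost:7860"

# /health is now assembled by translator.check_model_status() plus api_status/device
print(requests.get(f"{BASE}/health").json())

# /translate takes the same JSON body the old in-app handler expected
resp = requests.post(f"{BASE}/translate", json={
    "text": "komusta ka",
    "source_language": "kapampangan",
    "target_language": "english",
})
print(resp.status_code, resp.json())

# /evaluate takes a multipart form: an audio file plus a reference_locator
# naming one of the reference_audio/ subdirectories
with open("my_recording.wav", "rb") as f:
    resp = requests.post(
        f"{BASE}/evaluate",
        files={"audio": f},
        data={"reference_locator": "komusta_ka", "language": "kapampangan"},
    )
print(resp.status_code, resp.json())
```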
evaluate.py
ADDED
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# evaluate.py - Handles evaluation and comparing tasks
|
2 |
+
|
3 |
+
import os
|
4 |
+
import glob
|
5 |
+
import logging
|
6 |
+
import traceback
|
7 |
+
import tempfile
|
8 |
+
import shutil
|
9 |
+
from difflib import SequenceMatcher
|
10 |
+
import torch
|
11 |
+
import torchaudio
|
12 |
+
from pydub import AudioSegment
|
13 |
+
from flask import jsonify
|
14 |
+
from werkzeug.utils import secure_filename
|
15 |
+
from concurrent.futures import ThreadPoolExecutor
|
16 |
+
|
17 |
+
# Import necessary functions from translator.py
|
18 |
+
from translator import asr_model, asr_processor, LANGUAGE_CODES
|
19 |
+
|
20 |
+
# Configure logging
|
21 |
+
logger = logging.getLogger("speech_api")
|
22 |
+
|
23 |
+
def calculate_similarity(text1, text2):
|
24 |
+
"""Calculate text similarity percentage."""
|
25 |
+
def clean_text(text):
|
26 |
+
return text.lower()
|
27 |
+
|
28 |
+
clean1 = clean_text(text1)
|
29 |
+
clean2 = clean_text(text2)
|
30 |
+
|
31 |
+
matcher = SequenceMatcher(None, clean1, clean2)
|
32 |
+
return matcher.ratio() * 100
|
33 |
+
|
34 |
+
def init_reference_audio(reference_dir, output_dir):
|
35 |
+
try:
|
36 |
+
# Create the output directory first
|
37 |
+
os.makedirs(output_dir, exist_ok=True)
|
38 |
+
logger.info(f"π Created output directory: {output_dir}")
|
39 |
+
|
40 |
+
# Check if the reference audio directory exists in the repository
|
41 |
+
if os.path.exists(reference_dir):
|
42 |
+
logger.info(f"β
Found reference audio directory: {reference_dir}")
|
43 |
+
|
44 |
+
# Log the contents to verify
|
45 |
+
pattern_dirs = [d for d in os.listdir(reference_dir)
|
46 |
+
if os.path.isdir(os.path.join(reference_dir, d))]
|
47 |
+
logger.info(f"π Found reference patterns: {pattern_dirs}")
|
48 |
+
|
49 |
+
# Check each pattern directory for wav files
|
50 |
+
for pattern_dir_name in pattern_dirs:
|
51 |
+
pattern_path = os.path.join(reference_dir, pattern_dir_name)
|
52 |
+
wav_files = glob.glob(os.path.join(pattern_path, "*.wav"))
|
53 |
+
logger.info(f"π Found {len(wav_files)} wav files in {pattern_dir_name}")
|
54 |
+
else:
|
55 |
+
logger.warning(f"β οΈ Reference audio directory not found: {reference_dir}")
|
56 |
+
# Create the directory if it doesn't exist
|
57 |
+
os.makedirs(reference_dir, exist_ok=True)
|
58 |
+
logger.info(f"π Created reference audio directory: {reference_dir}")
|
59 |
+
except Exception as e:
|
60 |
+
logger.error(f"β Failed to set up reference audio directory: {str(e)}")
|
61 |
+
|
62 |
+
def handle_upload_reference(request, reference_dir, sample_rate):
|
63 |
+
"""Handle upload of reference audio files"""
|
64 |
+
try:
|
65 |
+
if "audio" not in request.files:
|
66 |
+
logger.warning("β οΈ Reference upload missing audio file")
|
67 |
+
return jsonify({"error": "No audio file uploaded"}), 400
|
68 |
+
|
69 |
+
reference_word = request.form.get("reference_word", "").strip()
|
70 |
+
if not reference_word:
|
71 |
+
logger.warning("β οΈ Reference upload missing reference word")
|
72 |
+
return jsonify({"error": "No reference word provided"}), 400
|
73 |
+
|
74 |
+
# Validate reference word
|
75 |
+
reference_patterns = [
|
76 |
+
"mayap_a_abak", "mayap_a_ugtu", "mayap_a_gatpanapun", "mayap_a_bengi",
|
77 |
+
"komusta_ka", "malaus_ko_pu", "malaus_kayu", "agaganaka_da_ka",
|
78 |
+
"pagdulapan_da_ka", "kaluguran_da_ka", "dakal_a_salamat", "panapaya_mu_ku"
|
79 |
+
]
|
80 |
+
|
81 |
+
if reference_word not in reference_patterns:
|
82 |
+
logger.warning(f"β οΈ Invalid reference word: {reference_word}")
|
83 |
+
return jsonify({"error": f"Invalid reference word. Available: {reference_patterns}"}), 400
|
84 |
+
|
85 |
+
# Create directory for reference pattern if it doesn't exist
|
86 |
+
pattern_dir = os.path.join(reference_dir, reference_word)
|
87 |
+
os.makedirs(pattern_dir, exist_ok=True)
|
88 |
+
|
89 |
+
# Save the reference audio file
|
90 |
+
audio_file = request.files["audio"]
|
91 |
+
file_path = os.path.join(pattern_dir, secure_filename(audio_file.filename))
|
92 |
+
audio_file.save(file_path)
|
93 |
+
|
94 |
+
# Convert to WAV if not already in that format
|
95 |
+
if not file_path.lower().endswith('.wav'):
|
96 |
+
base_path = os.path.splitext(file_path)[0]
|
97 |
+
wav_path = f"{base_path}.wav"
|
98 |
+
try:
|
99 |
+
audio = AudioSegment.from_file(file_path)
|
100 |
+
audio = audio.set_frame_rate(sample_rate).set_channels(1)
|
101 |
+
audio.export(wav_path, format="wav")
|
102 |
+
# Remove original file if conversion successful
|
103 |
+
os.unlink(file_path)
|
104 |
+
file_path = wav_path
|
105 |
+
except Exception as e:
|
106 |
+
logger.error(f"β Reference audio conversion failed: {str(e)}")
|
107 |
+
return jsonify({"error": f"Audio conversion failed: {str(e)}"}), 500
|
108 |
+
|
109 |
+
logger.info(f"β
Reference audio saved successfully for {reference_word}: {file_path}")
|
110 |
+
|
111 |
+
# Count how many references we have now
|
112 |
+
references = glob.glob(os.path.join(pattern_dir, "*.wav"))
|
113 |
+
return jsonify({
|
114 |
+
"message": "Reference audio uploaded successfully",
|
115 |
+
"reference_word": reference_word,
|
116 |
+
"file": os.path.basename(file_path),
|
117 |
+
"total_references": len(references)
|
118 |
+
})
|
119 |
+
|
120 |
+
except Exception as e:
|
121 |
+
logger.error(f"β Unhandled exception in reference upload: {str(e)}")
|
122 |
+
logger.debug(f"Stack trace: {traceback.format_exc()}")
|
123 |
+
return jsonify({"error": f"Internal server error: {str(e)}"}), 500
|
124 |
+
|
125 |
+
|
126 |
+
def handle_evaluation_request(request, reference_dir, output_dir, sample_rate):
|
127 |
+
"""Handle pronunciation evaluation requests"""
|
128 |
+
request_id = f"req-{id(request)}" # Create unique ID for this request
|
129 |
+
logger.info(f"[{request_id}] π Starting new pronunciation evaluation request")
|
130 |
+
|
131 |
+
temp_dir = None
|
132 |
+
|
133 |
+
if asr_model is None or asr_processor is None:
|
134 |
+
logger.error(f"[{request_id}] β Evaluation endpoint called but ASR models aren't loaded")
|
135 |
+
return jsonify({"error": "ASR model not available"}), 503
|
136 |
+
|
137 |
+
try:
|
138 |
+
if "audio" not in request.files:
|
139 |
+
logger.warning(f"[{request_id}] β οΈ Evaluation request missing audio file")
|
140 |
+
return jsonify({"error": "No audio file uploaded"}), 400
|
141 |
+
|
142 |
+
audio_file = request.files["audio"]
|
143 |
+
reference_locator = request.form.get("reference_locator", "").strip()
|
144 |
+
language = request.form.get("language", "kapampangan").lower()
|
145 |
+
|
146 |
+
# Validate reference locator
|
147 |
+
if not reference_locator:
|
148 |
+
logger.warning(f"[{request_id}] β οΈ No reference locator provided")
|
149 |
+
return jsonify({"error": "Reference locator is required"}), 400
|
150 |
+
|
151 |
+
# Construct full reference directory path
|
152 |
+
reference_dir_path = os.path.join(reference_dir, reference_locator)
|
153 |
+
logger.info(f"[{request_id}] π Reference directory path: {reference_dir_path}")
|
154 |
+
|
155 |
+
if not os.path.exists(reference_dir_path):
|
156 |
+
logger.warning(f"[{request_id}] β οΈ Reference directory not found: {reference_dir_path}")
|
157 |
+
return jsonify({"error": f"Reference audio directory not found: {reference_locator}"}), 404
|
158 |
+
|
159 |
+
reference_files = glob.glob(os.path.join(reference_dir_path, "*.wav"))
|
160 |
+
logger.info(f"[{request_id}] π Found {len(reference_files)} reference files")
|
161 |
+
|
162 |
+
if not reference_files:
|
163 |
+
logger.warning(f"[{request_id}] β οΈ No reference audio files found in {reference_dir_path}")
|
164 |
+
return jsonify({"error": f"No reference audio found for {reference_locator}"}), 404
|
165 |
+
|
166 |
+
lang_code = LANGUAGE_CODES.get(language, language)
|
167 |
+
logger.info(
|
168 |
+
f"[{request_id}] π Evaluating pronunciation for reference: {reference_locator} with language code: {lang_code}")
|
169 |
+
|
170 |
+
# Create a request-specific temp directory to avoid conflicts
|
171 |
+
temp_dir = os.path.join(output_dir, f"temp_{request_id}")
|
172 |
+
os.makedirs(temp_dir, exist_ok=True)
|
173 |
+
|
174 |
+
# Process user audio
|
175 |
+
user_audio_path = os.path.join(temp_dir, "user_audio_input.wav")
|
176 |
+
with open(user_audio_path, 'wb') as f:
|
177 |
+
f.write(audio_file.read())
|
178 |
+
|
179 |
+
try:
|
180 |
+
logger.info(f"[{request_id}] π Processing user audio file")
|
181 |
+
audio = AudioSegment.from_file(user_audio_path)
|
182 |
+
audio = audio.set_frame_rate(sample_rate).set_channels(1)
|
183 |
+
|
184 |
+
processed_path = os.path.join(temp_dir, "processed_user_audio.wav")
|
185 |
+
audio.export(processed_path, format="wav")
|
186 |
+
|
187 |
+
user_waveform, sr = torchaudio.load(processed_path)
|
188 |
+
user_waveform = user_waveform.squeeze().numpy()
|
189 |
+
logger.info(f"[{request_id}] β
User audio processed: {sr}Hz, length: {len(user_waveform)} samples")
|
190 |
+
|
191 |
+
user_audio_path = processed_path
|
192 |
+
except Exception as e:
|
193 |
+
logger.error(f"[{request_id}] β Audio processing failed: {str(e)}")
|
194 |
+
return jsonify({"error": f"Audio processing failed: {str(e)}"}), 500
|
195 |
+
|
196 |
+
# Transcribe user audio
|
197 |
+
try:
|
198 |
+
logger.info(f"[{request_id}] π Transcribing user audio")
|
199 |
+
inputs = asr_processor(
|
200 |
+
user_waveform,
|
201 |
+
sampling_rate=sample_rate,
|
202 |
+
return_tensors="pt",
|
203 |
+
language=lang_code
|
204 |
+
)
|
205 |
+
inputs = {k: v.to(asr_model.device) for k, v in inputs.items()}
|
206 |
+
|
207 |
+
            with torch.no_grad():
                logits = asr_model(**inputs).logits
            ids = torch.argmax(logits, dim=-1)[0]
            user_transcription = asr_processor.decode(ids)

            logger.info(f"[{request_id}] ✅ User transcription: '{user_transcription}'")
        except Exception as e:
            logger.error(f"[{request_id}] ❌ ASR inference failed: {str(e)}")
            return jsonify({"error": f"ASR inference failed: {str(e)}"}), 500

        # Process reference files in batches
        batch_size = 2  # Process 2 files at a time - adjust based on your hardware
        results = []
        best_score = 0
        best_reference = None
        best_transcription = None

        # Use this if you want to limit the number of files to process
        max_files_to_check = min(5, len(reference_files))  # Check at most 5 files
        reference_files = reference_files[:max_files_to_check]

        logger.info(f"[{request_id}] Processing {len(reference_files)} reference files in batches of {batch_size}")

        # Function to process a single reference file
        def process_reference_file(ref_file):
            ref_filename = os.path.basename(ref_file)
            try:
                # Load and resample reference audio
                ref_waveform, ref_sr = torchaudio.load(ref_file)
                if ref_sr != sample_rate:
                    ref_waveform = torchaudio.transforms.Resample(ref_sr, sample_rate)(ref_waveform)
                ref_waveform = ref_waveform.squeeze().numpy()

                # Transcribe reference audio
                inputs = asr_processor(
                    ref_waveform,
                    sampling_rate=sample_rate,
                    return_tensors="pt",
                    language=lang_code
                )
                inputs = {k: v.to(asr_model.device) for k, v in inputs.items()}

                with torch.no_grad():
                    logits = asr_model(**inputs).logits
                ids = torch.argmax(logits, dim=-1)[0]
                ref_transcription = asr_processor.decode(ids)

                # Calculate similarity
                similarity = calculate_similarity(user_transcription, ref_transcription)

                logger.info(
                    f"[{request_id}] Similarity with {ref_filename}: {similarity:.2f}%, transcription: '{ref_transcription}'")

                return {
                    "reference_file": ref_filename,
                    "reference_text": ref_transcription,
                    "similarity_score": similarity
                }
            except Exception as e:
                logger.error(f"[{request_id}] ❌ Error processing {ref_filename}: {str(e)}")
                return {
                    "reference_file": ref_filename,
                    "reference_text": "Error",
                    "similarity_score": 0,
                    "error": str(e)
                }

        # Process files in batches using ThreadPoolExecutor
        with ThreadPoolExecutor(max_workers=batch_size) as executor:
            batch_results = list(executor.map(process_reference_file, reference_files))
            results.extend(batch_results)

            # Find the best result
            for result in batch_results:
                if result["similarity_score"] > best_score:
                    best_score = result["similarity_score"]
                    best_reference = result["reference_file"]
                    best_transcription = result["reference_text"]

                # Exit early if we found a very good match (optional)
                if best_score > 80.0:
                    logger.info(f"[{request_id}] Found excellent match: {best_score:.2f}%")
                    break

        # Clean up temp files
        try:
            if temp_dir and os.path.exists(temp_dir):
                shutil.rmtree(temp_dir)
                logger.debug(f"[{request_id}] 🧹 Cleaned up temporary directory")
        except Exception as e:
            logger.warning(f"[{request_id}] ⚠️ Failed to clean up temp files: {str(e)}")

        # Determine feedback based on score
        is_correct = best_score >= 70.0

        if best_score >= 90.0:
            feedback = "Perfect pronunciation! Excellent job!"
        elif best_score >= 80.0:
            feedback = "Great pronunciation! Your accent is very good."
        elif best_score >= 70.0:
            feedback = "Good pronunciation. Keep practicing!"
        elif best_score >= 50.0:
            feedback = "Fair attempt. Try focusing on the syllables that differ from the sample."
        else:
            feedback = "Try again. Listen carefully to the sample pronunciation."

        logger.info(f"[{request_id}] Final evaluation results: score={best_score:.2f}%, is_correct={is_correct}")
        logger.info(f"[{request_id}] Feedback: '{feedback}'")
        logger.info(f"[{request_id}] ✅ Evaluation complete")

        # Sort results by score descending
        results.sort(key=lambda x: x["similarity_score"], reverse=True)

        return jsonify({
            "is_correct": is_correct,
            "score": best_score,
            "feedback": feedback,
            "user_transcription": user_transcription,
            "best_reference_transcription": best_transcription,
            "reference_locator": reference_locator,
            "details": results
        })

    except Exception as e:
        logger.error(f"[{request_id}] ❌ Unhandled exception in evaluation endpoint: {str(e)}")
        logger.debug(f"[{request_id}] Stack trace: {traceback.format_exc()}")

        # Clean up on error
        try:
            if temp_dir and os.path.exists(temp_dir):
                shutil.rmtree(temp_dir)
        except:
            pass

        return jsonify({"error": f"Internal server error: {str(e)}"}), 500
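For orientation, here is a minimal client-side sketch of how this evaluation endpoint could be called. The route path, host, and the exact form-field names are assumptions inferred from the response payload above, not something this hunk pins down; the response keys do match the code.

# Hypothetical client for the pronunciation-evaluation endpoint above.
# URL, route, and form fields are assumptions; adjust to the real app.
import requests

with open("attempt.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:7860/evaluate",  # assumed route
        files={"audio": f},
        data={"reference_locator": "mayap_a_abak", "language": "kapampangan"},  # assumed fields
    )

result = resp.json()
print(result["score"], result["is_correct"], result["feedback"])
for detail in result["details"]:  # per-reference similarity, sorted best-first
    print(detail["reference_file"], detail["similarity_score"])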
translator.py
ADDED
@@ -0,0 +1,435 @@
# translator.py - Handles ASR, TTS, and translation tasks

import os
import sys
import logging
import traceback
import torch
import torchaudio
import tempfile
import soundfile as sf
from pydub import AudioSegment
from flask import jsonify
from transformers import Wav2Vec2ForCTC, AutoProcessor, VitsModel, AutoTokenizer
from transformers import MarianMTModel, MarianTokenizer

# Configure logging
logger = logging.getLogger("speech_api")

# Global variables to store models and processors
asr_model = None
asr_processor = None
tts_models = {}
tts_processors = {}
translation_models = {}
translation_tokenizers = {}

# Language-specific configurations
LANGUAGE_CODES = {
    "kapampangan": "pam",
    "filipino": "fil",
    "english": "eng",
    "tagalog": "tgl",
}

# TTS Models (Kapampangan, Tagalog, English)
TTS_MODELS = {
    "kapampangan": "facebook/mms-tts-pam",
    "tagalog": "facebook/mms-tts-tgl",
    "english": "facebook/mms-tts-eng"
}

# Translation Models
TRANSLATION_MODELS = {
    "pam-eng": "Coco-18/opus-mt-pam-en",
    "eng-pam": "Coco-18/opus-mt-en-pam",
    "tgl-eng": "Helsinki-NLP/opus-mt-tl-en",
    "eng-tgl": "Helsinki-NLP/opus-mt-en-tl",
    "phi": "Coco-18/opus-mt-phi"
}


def init_models(device):
    """Initialize all models required for the API"""
    global asr_model, asr_processor, tts_models, tts_processors, translation_models, translation_tokenizers

    # Initialize ASR model
    ASR_MODEL_ID = "Coco-18/mms-asr-tgl-en-safetensor"
    logger.info(f"Loading ASR model: {ASR_MODEL_ID}")

    try:
        asr_processor = AutoProcessor.from_pretrained(
            ASR_MODEL_ID,
            cache_dir=os.environ.get("TRANSFORMERS_CACHE")
        )
        logger.info("✅ ASR processor loaded successfully")

        asr_model = Wav2Vec2ForCTC.from_pretrained(
            ASR_MODEL_ID,
            cache_dir=os.environ.get("TRANSFORMERS_CACHE")
        )
        asr_model.to(device)
        logger.info(f"✅ ASR model loaded successfully on {device}")
    except Exception as e:
        logger.error(f"❌ Error loading ASR model: {str(e)}")
        logger.debug(f"Stack trace: {traceback.format_exc()}")

    # Initialize TTS models
    for lang, model_id in TTS_MODELS.items():
        logger.info(f"Loading TTS model for {lang}: {model_id}")
        try:
            tts_processors[lang] = AutoTokenizer.from_pretrained(
                model_id,
                cache_dir=os.environ.get("TRANSFORMERS_CACHE")
            )
            logger.info(f"✅ {lang} TTS processor loaded")

            tts_models[lang] = VitsModel.from_pretrained(
                model_id,
                cache_dir=os.environ.get("TRANSFORMERS_CACHE")
            )
            tts_models[lang].to(device)
            logger.info(f"✅ {lang} TTS model loaded on {device}")
        except Exception as e:
            logger.error(f"❌ Failed to load {lang} TTS model: {str(e)}")
            logger.debug(f"Stack trace: {traceback.format_exc()}")
            tts_models[lang] = None

    # Initialize translation models
    for model_key, model_id in TRANSLATION_MODELS.items():
        logger.info(f"Loading Translation model: {model_id}")

        try:
            translation_tokenizers[model_key] = MarianTokenizer.from_pretrained(
                model_id,
                cache_dir=os.environ.get("TRANSFORMERS_CACHE")
            )
            logger.info(f"✅ Translation tokenizer loaded successfully for {model_key}")

            translation_models[model_key] = MarianMTModel.from_pretrained(
                model_id,
                cache_dir=os.environ.get("TRANSFORMERS_CACHE")
            )
            translation_models[model_key].to(device)
            logger.info(f"✅ Translation model loaded successfully on {device} for {model_key}")
        except Exception as e:
            logger.error(f"❌ Error loading Translation model for {model_key}: {str(e)}")
            logger.debug(f"Stack trace: {traceback.format_exc()}")
            translation_models[model_key] = None
            translation_tokenizers[model_key] = None
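Because translator.py keeps its models in module-level globals, the application must call init_models() once at startup before any handler runs. A minimal wiring sketch; the device-selection line is an assumption about how app.py picks its device:

# Hypothetical startup wiring; device selection is an assumption.
import torch
import translator

device = "cuda" if torch.cuda.is_available() else "cpu"
translator.init_models(device)  # populates the module-level registries above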
def check_model_status():
    """Check and return the status of all models"""
    # Initialize direct language pair statuses based on loaded models
    translation_status = {}

    # Add status for direct model pairs
    for lang_pair in ["pam-eng", "eng-pam", "tgl-eng", "eng-tgl"]:
        translation_status[lang_pair] = "loaded" if lang_pair in translation_models and translation_models[
            lang_pair] is not None else "failed"

    # Add special phi model status
    phi_status = "loaded" if "phi" in translation_models and translation_models["phi"] is not None else "failed"
    translation_status["pam-fil"] = phi_status
    translation_status["fil-pam"] = phi_status
    translation_status["pam-tgl"] = phi_status  # Using phi model but replacing tgl with fil
    translation_status["tgl-pam"] = phi_status  # Using phi model but replacing tgl with fil

    return {
        "asr_model": "loaded" if asr_model is not None else "failed",
        "tts_models": {lang: "loaded" if model is not None else "failed"
                       for lang, model in tts_models.items()},
        "translation_models": translation_status
    }
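check_model_status() reports the four phi-backed pairs (pam-fil, fil-pam, pam-tgl, tgl-pam) with one shared status, since they all route through the single phi checkpoint. For illustration, a fully-loaded deployment would return a dict shaped like this (the statuses shown are examples only):

# Illustrative return value of check_model_status(); values vary with what loaded.
{
    "asr_model": "loaded",
    "tts_models": {"kapampangan": "loaded", "tagalog": "loaded", "english": "loaded"},
    "translation_models": {
        "pam-eng": "loaded", "eng-pam": "loaded",
        "tgl-eng": "loaded", "eng-tgl": "loaded",
        "pam-fil": "loaded", "fil-pam": "loaded",
        "pam-tgl": "loaded", "tgl-pam": "loaded"
    }
}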
def handle_asr_request(request, output_dir, sample_rate):
    """Handle ASR (Automatic Speech Recognition) requests"""
    if asr_model is None or asr_processor is None:
        logger.error("❌ ASR endpoint called but models aren't loaded")
        return jsonify({"error": "ASR model not available"}), 503

    try:
        if "audio" not in request.files:
            logger.warning("⚠️ ASR request missing audio file")
            return jsonify({"error": "No audio file uploaded"}), 400

        audio_file = request.files["audio"]
        language = request.form.get("language", "english").lower()

        if language not in LANGUAGE_CODES:
            logger.warning(f"⚠️ Unsupported language requested: {language}")
            return jsonify(
                {"error": f"Unsupported language: {language}. Available: {list(LANGUAGE_CODES.keys())}"}), 400

        lang_code = LANGUAGE_CODES[language]
        logger.info(f"Processing {language} audio for ASR")

        # Save the uploaded file temporarily
        with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(audio_file.filename)[-1]) as temp_audio:
            temp_audio.write(audio_file.read())
            temp_audio_path = temp_audio.name
        logger.debug(f"Temporary audio saved to {temp_audio_path}")

        # Convert to WAV if necessary
        wav_path = temp_audio_path
        if not audio_file.filename.lower().endswith(".wav"):
            wav_path = os.path.join(output_dir, "converted_audio.wav")
            logger.info(f"Converting audio to WAV format: {wav_path}")
            try:
                audio = AudioSegment.from_file(temp_audio_path)
                audio = audio.set_frame_rate(sample_rate).set_channels(1)
                audio.export(wav_path, format="wav")
            except Exception as e:
                logger.error(f"❌ Audio conversion failed: {str(e)}")
                return jsonify({"error": f"Audio conversion failed: {str(e)}"}), 500

        # Load and process the WAV file
        try:
            waveform, sr = torchaudio.load(wav_path)
            logger.debug(f"✅ Audio loaded: {wav_path} (Sample rate: {sr}Hz)")

            # Resample if needed
            if sr != sample_rate:
                logger.info(f"Resampling audio from {sr}Hz to {sample_rate}Hz")
                waveform = torchaudio.transforms.Resample(sr, sample_rate)(waveform)

            waveform = waveform / torch.max(torch.abs(waveform))
        except Exception as e:
            logger.error(f"❌ Failed to load or process audio: {str(e)}")
            return jsonify({"error": f"Audio processing failed: {str(e)}"}), 500

        # Process audio for ASR
        try:
            inputs = asr_processor(
                waveform.squeeze().numpy(),
                sampling_rate=sample_rate,
                return_tensors="pt",
                language=lang_code
            )
            inputs = {k: v.to(asr_model.device) for k, v in inputs.items()}
        except Exception as e:
            logger.error(f"❌ ASR preprocessing failed: {str(e)}")
            return jsonify({"error": f"ASR preprocessing failed: {str(e)}"}), 500

        # Perform ASR
        try:
            with torch.no_grad():
                logits = asr_model(**inputs).logits
            ids = torch.argmax(logits, dim=-1)[0]
            transcription = asr_processor.decode(ids)

            logger.info(f"✅ Transcription ({language}): {transcription}")

            # Clean up temp files
            try:
                os.unlink(temp_audio_path)
                if wav_path != temp_audio_path:
                    os.unlink(wav_path)
            except Exception as e:
                logger.warning(f"⚠️ Failed to clean up temp files: {str(e)}")

            return jsonify({
                "transcription": transcription,
                "language": language,
                "language_code": lang_code
            })
        except Exception as e:
            logger.error(f"❌ ASR inference failed: {str(e)}")
            logger.debug(f"Stack trace: {traceback.format_exc()}")
            return jsonify({"error": f"ASR inference failed: {str(e)}"}), 500

    except Exception as e:
        logger.error(f"❌ Unhandled exception in ASR endpoint: {str(e)}")
        logger.debug(f"Stack trace: {traceback.format_exc()}")
        return jsonify({"error": f"Internal server error: {str(e)}"}), 500
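handle_asr_request expects a multipart upload with an audio file under the "audio" key and a "language" form field; non-WAV uploads are converted server-side. A minimal client sketch, assuming the Flask app mounts the handler at /asr (the route path and host are assumptions; the fields match the handler above):

# Hypothetical client for an /asr route backed by handle_asr_request.
import requests

with open("sample.mp3", "rb") as f:  # non-WAV input is converted server-side
    resp = requests.post(
        "http://localhost:7860/asr",  # assumed route
        files={"audio": f},
        data={"language": "tagalog"},
    )

print(resp.json())  # {"transcription": ..., "language": "tagalog", "language_code": "tgl"}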
def handle_tts_request(request, output_dir):
    """Handle TTS (Text-to-Speech) requests"""
    try:
        data = request.get_json()
        if not data:
            logger.warning("⚠️ TTS endpoint called with no JSON data")
            return jsonify({"error": "No JSON data provided"}), 400

        text_input = data.get("text", "").strip()
        language = data.get("language", "kapampangan").lower()

        if not text_input:
            logger.warning("⚠️ TTS request with empty text")
            return jsonify({"error": "No text provided"}), 400

        if language not in TTS_MODELS:
            logger.warning(f"⚠️ TTS requested for unsupported language: {language}")
            return jsonify({"error": f"Invalid language. Available options: {list(TTS_MODELS.keys())}"}), 400

        if tts_models[language] is None:
            logger.error(f"❌ TTS model for {language} not loaded")
            return jsonify({"error": f"TTS model for {language} not available"}), 503

        logger.info(f"Generating TTS for language: {language}, text: '{text_input}'")

        try:
            processor = tts_processors[language]
            model = tts_models[language]
            inputs = processor(text_input, return_tensors="pt")
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
        except Exception as e:
            logger.error(f"❌ TTS preprocessing failed: {str(e)}")
            return jsonify({"error": f"TTS preprocessing failed: {str(e)}"}), 500

        # Generate speech
        try:
            with torch.no_grad():
                output = model(**inputs).waveform
            waveform = output.squeeze().cpu().numpy()
        except Exception as e:
            logger.error(f"❌ TTS inference failed: {str(e)}")
            logger.debug(f"Stack trace: {traceback.format_exc()}")
            return jsonify({"error": f"TTS inference failed: {str(e)}"}), 500

        # Save to file
        try:
            output_filename = os.path.join(output_dir, f"{language}_output.wav")
            sampling_rate = model.config.sampling_rate
            sf.write(output_filename, waveform, sampling_rate)
            logger.info(f"✅ Speech generated! File saved: {output_filename}")
        except Exception as e:
            logger.error(f"❌ Failed to save audio file: {str(e)}")
            return jsonify({"error": f"Failed to save audio file: {str(e)}"}), 500

        return jsonify({
            "message": "TTS audio generated",
            "file_url": f"/download/{os.path.basename(output_filename)}",
            "language": language,
            "text_length": len(text_input)
        })
    except Exception as e:
        logger.error(f"❌ Unhandled exception in TTS endpoint: {str(e)}")
        logger.debug(f"Stack trace: {traceback.format_exc()}")
        return jsonify({"error": f"Internal server error: {str(e)}"}), 500
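handle_tts_request takes a JSON body with "text" and "language", writes a WAV into output_dir, and answers with a file_url pointing at a separate /download route. A client sketch under the same route-name assumption:

# Hypothetical client for a /tts route backed by handle_tts_request.
import requests

BASE = "http://localhost:7860"  # assumed host

resp = requests.post(BASE + "/tts", json={"text": "Mayap a abak", "language": "kapampangan"})
payload = resp.json()
print(payload["file_url"])  # e.g. /download/kapampangan_output.wav

# Fetch the generated audio through the download URL the response points at.
audio = requests.get(BASE + payload["file_url"])
with open("tts_output.wav", "wb") as f:
    f.write(audio.content)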
def handle_translation_request(request):
    """Handle translation requests"""
    try:
        data = request.get_json()
        if not data:
            logger.warning("⚠️ Translation endpoint called with no JSON data")
            return jsonify({"error": "No JSON data provided"}), 400

        source_text = data.get("text", "").strip()
        source_language = data.get("source_language", "").lower()
        target_language = data.get("target_language", "").lower()

        if not source_text:
            logger.warning("⚠️ Translation request with empty text")
            return jsonify({"error": "No text provided"}), 400

        # Map language names to codes
        source_code = LANGUAGE_CODES.get(source_language, source_language)
        target_code = LANGUAGE_CODES.get(target_language, target_language)

        logger.info(f"Translating from {source_language} to {target_language}: '{source_text}'")

        # Special handling for pam-fil, fil-pam, pam-tgl and tgl-pam using the phi model
        use_phi_model = False
        actual_source_code = source_code
        actual_target_code = target_code

        # Check if we need to use the phi model with fil replacement
        if (source_code == "pam" and target_code == "fil") or (source_code == "fil" and target_code == "pam"):
            use_phi_model = True
        elif (source_code == "pam" and target_code == "tgl"):
            use_phi_model = True
            actual_target_code = "fil"  # Replace tgl with fil for the phi model
        elif (source_code == "tgl" and target_code == "pam"):
            use_phi_model = True
            actual_source_code = "fil"  # Replace tgl with fil for the phi model

        if use_phi_model:
            model_key = "phi"

            # Check if we have the phi model
            if model_key not in translation_models or translation_models[model_key] is None:
                logger.error(f"❌ Translation model for {model_key} not loaded")
                return jsonify({"error": f"Translation model not available"}), 503

            try:
                # Get the phi model and tokenizer
                model = translation_models[model_key]
                tokenizer = translation_tokenizers[model_key]

                # Prepend target language token to input
                input_text = f">>{actual_target_code}<< {source_text}"

                logger.info(f"Using phi model with input: '{input_text}'")

                # Tokenize the text
                tokenized = tokenizer(input_text, return_tensors="pt", padding=True)
                tokenized = {k: v.to(model.device) for k, v in tokenized.items()}

                # Generate translation
                with torch.no_grad():
                    translated = model.generate(**tokenized)

                # Decode the translation
                result = tokenizer.decode(translated[0], skip_special_tokens=True)

                logger.info(f"✅ Translation result: '{result}'")

                return jsonify({
                    "translated_text": result,
                    "source_language": source_language,
                    "target_language": target_language
                })
            except Exception as e:
                logger.error(f"❌ Translation processing failed: {str(e)}")
                logger.debug(f"Stack trace: {traceback.format_exc()}")
                return jsonify({"error": f"Translation processing failed: {str(e)}"}), 500
        else:
            # Create the regular language pair key for other language pairs
            lang_pair = f"{source_code}-{target_code}"

            # Check if we have a model for this language pair
            if lang_pair not in translation_models:
                logger.warning(f"⚠️ No translation model available for {lang_pair}")
                return jsonify(
                    {"error": f"Translation from {source_language} to {target_language} is not supported yet"}), 400

            if translation_models[lang_pair] is None or translation_tokenizers[lang_pair] is None:
                logger.error(f"❌ Translation model for {lang_pair} not loaded")
                return jsonify({"error": f"Translation model not available"}), 503

            try:
                # Regular translation process for other language pairs
                model = translation_models[lang_pair]
                tokenizer = translation_tokenizers[lang_pair]

                # Tokenize the text
                tokenized = tokenizer(source_text, return_tensors="pt", padding=True)
                tokenized = {k: v.to(model.device) for k, v in tokenized.items()}

                # Generate translation
                with torch.no_grad():
                    translated = model.generate(**tokenized)

                # Decode the translation
                result = tokenizer.decode(translated[0], skip_special_tokens=True)

                logger.info(f"✅ Translation result: '{result}'")

                return jsonify({
                    "translated_text": result,
                    "source_language": source_language,
                    "target_language": target_language
                })
            except Exception as e:
                logger.error(f"❌ Translation processing failed: {str(e)}")
                logger.debug(f"Stack trace: {traceback.format_exc()}")
                return jsonify({"error": f"Translation processing failed: {str(e)}"}), 500

    except Exception as e:
        logger.error(f"❌ Unhandled exception in translation endpoint: {str(e)}")
        logger.debug(f"Stack trace: {traceback.format_exc()}")
        return jsonify({"error": f"Internal server error: {str(e)}"}), 500
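The phi checkpoint serves several language pairs from one MarianMT model, which is why handle_translation_request prepends a >>fil<<-style target-language token before tokenizing. A condensed, standalone sketch of just that code path; the model ID comes from TRANSLATION_MODELS above, everything else is illustrative:

# Condensed sketch of the phi-model path in handle_translation_request.
import torch
from transformers import MarianMTModel, MarianTokenizer

tokenizer = MarianTokenizer.from_pretrained("Coco-18/opus-mt-phi")
model = MarianMTModel.from_pretrained("Coco-18/opus-mt-phi")

def translate(text: str, target_code: str) -> str:
    # The >>lang<< prefix tells the multilingual checkpoint which target to emit.
    batch = tokenizer(f">>{target_code}<< {text}", return_tensors="pt", padding=True)
    with torch.no_grad():
        generated = model.generate(**batch)
    return tokenizer.decode(generated[0], skip_special_tokens=True)

# e.g. translate("Mayap a abak", "fil") for Kapampangan -> Filipino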