Update translator.py

translator.py  (+173 −336)  CHANGED
@@ -1,4 +1,4 @@
-# translator.py - Handles ASR, TTS, and translation tasks

 import os
 import sys
@@ -12,11 +12,6 @@ from pydub import AudioSegment
 from flask import jsonify
 from transformers import Wav2Vec2ForCTC, AutoProcessor, VitsModel, AutoTokenizer
 from transformers import MarianMTModel, MarianTokenizer
-import concurrent.futures
-import functools
-import threading
-from concurrent.futures import ThreadPoolExecutor
-from functools import lru_cache

 # Configure logging
 logger = logging.getLogger("speech_api")
@@ -29,16 +24,6 @@ tts_processors = {}
 translation_models = {}
 translation_tokenizers = {}

-# Caching dictionaries
-asr_cache = {}
-tts_cache = {}
-translation_cache = {}
-
-# Mutex locks for thread safety
-asr_lock = threading.Lock()
-tts_lock = threading.Lock()
-translation_lock = threading.Lock()
-
 # Language-specific configurations
 LANGUAGE_CODES = {
     "kapampangan": "pam",
@@ -63,114 +48,74 @@ TRANSLATION_MODELS = {
     "phi": "Coco-18/opus-mt-phi"
 }

-# Cache settings
-MAX_CACHE_SIZE = 100  # Maximum number of items to cache
-CACHE_TTL = 3600  # Time to live in seconds (1 hour)
-
 def init_models(device):
-    """Initialize all models required for the API
     global asr_model, asr_processor, tts_models, tts_processors, translation_models, translation_tokenizers
-    …
-    def init_tts(lang, model_id):
         try:
-            …
                 model_id,
                 cache_dir=os.environ.get("TRANSFORMERS_CACHE")
             )
-            …
                 model_id,
                 cache_dir=os.environ.get("TRANSFORMERS_CACHE")
             )
-            …
             logger.info(f"✅ {lang} TTS model loaded on {device}")
-            return lang, processor, model
         except Exception as e:
             logger.error(f"❌ Failed to load {lang} TTS model: {str(e)}")
             logger.debug(f"Stack trace: {traceback.format_exc()}")
-            …
         try:
-            …
                 model_id,
                 cache_dir=os.environ.get("TRANSFORMERS_CACHE")
             )
-            …
                 model_id,
                 cache_dir=os.environ.get("TRANSFORMERS_CACHE")
             )
-            …
             logger.info(f"✅ Translation model loaded successfully on {device} for {model_key}")
-            return model_key, tokenizer, model
         except Exception as e:
             logger.error(f"❌ Error loading Translation model for {model_key}: {str(e)}")
             logger.debug(f"Stack trace: {traceback.format_exc()}")
-            …
-
-    # Use ThreadPoolExecutor to initialize models in parallel
-    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
-        # Start ASR model initialization
-        asr_future = executor.submit(init_asr)
-
-        # Start TTS model initialization in parallel
-        tts_futures = {
-            executor.submit(init_tts, lang, model_id): lang
-            for lang, model_id in TTS_MODELS.items()
-        }
-
-        # Start translation model initialization in parallel
-        translation_futures = {
-            executor.submit(init_translation, model_key, model_id): model_key
-            for model_key, model_id in TRANSLATION_MODELS.items()
-        }
-
-        # Wait for all futures to complete and process results
-
-        # Process TTS results
-        for future in concurrent.futures.as_completed(tts_futures):
-            lang, processor, model = future.result()
-            if processor is not None and model is not None:
-                tts_processors[lang] = processor
-                tts_models[lang] = model
-
-        # Process translation results
-        for future in concurrent.futures.as_completed(translation_futures):
-            model_key, tokenizer, model = future.result()
-            if tokenizer is not None and model is not None:
-                translation_tokenizers[model_key] = tokenizer
-                translation_models[model_key] = model
-
-    # Log summary of loaded models
-    logger.info("π Model initialization summary:")
-    logger.info(f" - ASR model: {'loaded' if asr_model is not None else 'failed'}")
-    logger.info(f" - TTS models loaded: {sum(1 for m in tts_models.values() if m is not None)}/{len(TTS_MODELS)}")
-    logger.info(f" - Translation models loaded: {sum(1 for m in translation_models.values() if m is not None)}/{len(TRANSLATION_MODELS)}")


 def check_model_status():
@@ -197,50 +142,9 @@ def check_model_status():
         "translation_models": translation_status
     }

-# Cache for ASR results
-@lru_cache(maxsize=MAX_CACHE_SIZE)
-def get_cached_transcription(file_hash, language_code):
-    """Retrieve cached transcription result if available"""
-    return asr_cache.get((file_hash, language_code))
-
-def process_audio_file(audio_data, temp_audio_path, output_dir, sample_rate):
-    """Process audio file for ASR (separate from ASR logic)"""
-    wav_path = temp_audio_path
-
-    if not temp_audio_path.lower().endswith(".wav"):
-        wav_path = os.path.join(output_dir, "converted_audio.wav")
-        logger.info(f"π Converting audio to WAV format: {wav_path}")
-        try:
-            audio = AudioSegment.from_file(temp_audio_path)
-            audio = audio.set_frame_rate(sample_rate).set_channels(1)
-            audio.export(wav_path, format="wav")
-        except Exception as e:
-            logger.error(f"❌ Audio conversion failed: {str(e)}")
-            raise Exception(f"Audio conversion failed: {str(e)}")
-
-    # Load and process the WAV file
-    try:
-        waveform, sr = torchaudio.load(wav_path)
-
-        # Resample if needed
-        if sr != sample_rate:
-            waveform = torchaudio.transforms.Resample(sr, sample_rate)(waveform)
-
-        # Normalize waveform
-        waveform = waveform / torch.max(torch.abs(waveform))
-
-        return waveform.squeeze().numpy(), wav_path
-    except Exception as e:
-        logger.error(f"❌ Failed to load or process audio: {str(e)}")
-        raise Exception(f"Audio processing failed: {str(e)}")
-
-def compute_audio_hash(audio_data):
-    """Compute a hash of audio data for caching purposes"""
-    import hashlib
-    return hashlib.md5(audio_data).hexdigest()

 def handle_asr_request(request, output_dir, sample_rate):
-    """Handle ASR (Automatic Speech Recognition) requests
     if asr_model is None or asr_processor is None:
         logger.error("❌ ASR endpoint called but models aren't loaded")
         return jsonify({"error": "ASR model not available"}), 503
@@ -261,40 +165,44 @@ def handle_asr_request(request, output_dir, sample_rate):
         lang_code = LANGUAGE_CODES[language]
         logger.info(f"π Processing {language} audio for ASR")

-        # Read the file content for hashing
-        audio_content = audio_file.read()
-        audio_hash = compute_audio_hash(audio_content)
-
-        # Check cache first
-        with asr_lock:
-            cached_result = asr_cache.get((audio_hash, lang_code))
-            if cached_result:
-                logger.info(f"✅ Using cached ASR result for {language}")
-                return jsonify({
-                    "transcription": cached_result,
-                    "language": language,
-                    "language_code": lang_code,
-                    "from_cache": True
-                })
-
         # Save the uploaded file temporarily
         with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(audio_file.filename)[-1]) as temp_audio:
-            temp_audio.write(
             temp_audio_path = temp_audio.name
         logger.debug(f"π Temporary audio saved to {temp_audio_path}")

-        # …
         try:
-            …
         except Exception as e:
-            …

         # Process audio for ASR
         try:
             inputs = asr_processor(
-                waveform,
                 sampling_rate=sample_rate,
                 return_tensors="pt",
                 language=lang_code
@@ -312,14 +220,6 @@ def handle_asr_request(request, output_dir, sample_rate):
             transcription = asr_processor.decode(ids)

             logger.info(f"✅ Transcription ({language}): {transcription}")
-
-            # Cache the result
-            with asr_lock:
-                asr_cache[(audio_hash, lang_code)] = transcription
-                # Implement cache size limitation if needed
-                if len(asr_cache) > MAX_CACHE_SIZE:
-                    # Remove oldest entry (simplified approach)
-                    asr_cache.pop(next(iter(asr_cache)))

             # Clean up temp files
             try:
@@ -332,8 +232,7 @@ def handle_asr_request(request, output_dir, sample_rate):
             return jsonify({
                 "transcription": transcription,
                 "language": language,
-                "language_code": lang_code
-                "from_cache": False
             })
         except Exception as e:
             logger.error(f"❌ ASR inference failed: {str(e)}")
@@ -345,14 +244,8 @@ def handle_asr_request(request, output_dir, sample_rate):
         logger.debug(f"Stack trace: {traceback.format_exc()}")
         return jsonify({"error": f"Internal server error: {str(e)}"}), 500

-# Cache key generator for TTS
-def tts_cache_key(text, language):
-    """Generate a cache key for TTS results"""
-    import hashlib
-    return hashlib.md5(f"{text}:{language}".encode()).hexdigest()
-
 def handle_tts_request(request, output_dir):
-    """Handle TTS (Text-to-Speech) requests
     try:
         data = request.get_json()
         if not data:
@@ -375,57 +268,7 @@ def handle_tts_request(request, output_dir):
             return jsonify({"error": f"TTS model for {language} not available"}), 503

         logger.info(f"π Generating TTS for language: {language}, text: '{text_input}'")
-
-        # Generate cache key
-        cache_key = tts_cache_key(text_input, language)
-
-        # Check cache
-        with tts_lock:
-            cached_file = tts_cache.get(cache_key)
-            if cached_file and os.path.exists(cached_file):
-                logger.info(f"✅ Using cached TTS audio for: '{text_input}'")
-                return jsonify({
-                    "message": "TTS audio retrieved from cache",
-                    "file_url": f"/download/{os.path.basename(cached_file)}",
-                    "language": language,
-                    "text_length": len(text_input),
-                    "from_cache": True
-                })

-        # Chunk text if too long (optional optimization for very long texts)
-        MAX_TEXT_LENGTH = 200  # Maximum text length to process in one go
-
-        if len(text_input) > MAX_TEXT_LENGTH:
-            # Simple chunking by splitting on periods
-            chunks = []
-            current_chunk = ""
-
-            for sentence in text_input.split("."):
-                if len(current_chunk) + len(sentence) < MAX_TEXT_LENGTH:
-                    current_chunk += sentence + "."
-                else:
-                    if current_chunk:
-                        chunks.append(current_chunk)
-                    current_chunk = sentence + "."
-
-            if current_chunk:
-                chunks.append(current_chunk)
-
-            logger.info(f"π Text chunked into {len(chunks)} parts for processing")
-
-            # Process chunks and combine results
-            try:
-                processor = tts_processors[language]
-                model = tts_models[language]
-
-                # For simplicity, we'll just use the first chunk in this example
-                # A full implementation would process all chunks and concatenate audio
-                text_input = chunks[0]
-                logger.info(f"⚠️ Using only the first chunk for demonstration: '{text_input}'")
-            except Exception as e:
-                logger.error(f"❌ TTS chunking failed: {str(e)}")
-                return jsonify({"error": f"TTS chunking failed: {str(e)}"}), 500
-
         try:
             processor = tts_processors[language]
             model = tts_models[language]
@@ -447,22 +290,10 @@ def handle_tts_request(request, output_dir):

             # Save to file
             try:
-                output_filename = os.path.join(output_dir, f"{language}
                 sampling_rate = model.config.sampling_rate
                 sf.write(output_filename, waveform, sampling_rate)
                 logger.info(f"✅ Speech generated! File saved: {output_filename}")
-
-                # Cache the result
-                with tts_lock:
-                    tts_cache[cache_key] = output_filename
-                    # Implement cache size limitation if needed
-                    if len(tts_cache) > MAX_CACHE_SIZE:
-                        oldest_key = next(iter(tts_cache))
-                        try:
-                            os.remove(tts_cache[oldest_key])
-                        except:
-                            pass
-                        tts_cache.pop(oldest_key)
             except Exception as e:
                 logger.error(f"❌ Failed to save audio file: {str(e)}")
                 return jsonify({"error": f"Failed to save audio file: {str(e)}"}), 500
@@ -471,22 +302,15 @@ def handle_tts_request(request, output_dir):
                 "message": "TTS audio generated",
                 "file_url": f"/download/{os.path.basename(output_filename)}",
                 "language": language,
-                "text_length": len(text_input)
-                "from_cache": False
             })
         except Exception as e:
             logger.error(f"❌ Unhandled exception in TTS endpoint: {str(e)}")
             logger.debug(f"Stack trace: {traceback.format_exc()}")
             return jsonify({"error": f"Internal server error: {str(e)}"}), 500

-# Cache key generator for translation
-def translation_cache_key(text, source_lang, target_lang):
-    """Generate a cache key for translation results"""
-    import hashlib
-    return hashlib.md5(f"{text}:{source_lang}:{target_lang}".encode()).hexdigest()
-
 def handle_translation_request(request):
-    """Handle translation requests
     try:
         data = request.get_json()
         if not data:
@@ -506,97 +330,110 @@ def handle_translation_request(request):
         target_code = LANGUAGE_CODES.get(target_language, target_language)

         logger.info(f"π Translating from {source_language} to {target_language}: '{source_text}'")
-
-        # Generate cache key
-        cache_key = translation_cache_key(source_text, source_code, target_code)
-
-        # Check cache
-        with translation_lock:
-            cached_result = translation_cache.get(cache_key)
-            if cached_result:
-                logger.info(f"✅ Using cached translation result")
-                return jsonify({
-                    "translated_text": cached_result,
-                    "source_language": source_language,
-                    "target_language": target_language,
-                    "from_cache": True
-                })

-        # …
         actual_source_code = source_code
         actual_target_code = target_code
-        …
-            model_key = f"{source_code}-{target_code}"
-            use_phi_model = False
-        elif (source_code in ["pam", "fil", "tgl"] and target_code in ["pam", "fil", "tgl"]):
-            # Use phi model with appropriate substitutions
-            model_key = "phi"
             use_phi_model = True
-            # Replace tgl with fil for the phi model
-            …
         else:
-            …
-                {"error": f"Translation from {source_language} to {target_language} is not supported yet"}), 400
-
-        # Check if model exists and is loaded
-        if model_key not in translation_models or translation_models[model_key] is None:
-            logger.error(f"❌ Translation model for {model_key} not loaded")
-            return jsonify({"error": f"Translation model not available"}), 503

-        …

-        …
-            translated = model.generate(
-                **tokenized,
-                max_length=max_length,
-                num_beams=4,
-                length_penalty=0.6,
-                early_stopping=True,
-                repetition_penalty=1.5,
-                no_repeat_ngram_size=3
-            )
-
-            # Decode the translation
-            result = tokenizer.decode(translated[0], skip_special_tokens=True)
-
-            logger.info(f"✅ Translation result: '{result}'")
-
-            # Cache the result
-            with translation_lock:
-                translation_cache[cache_key] = result
-                # Implement cache size limitation if needed
-                if len(translation_cache) > MAX_CACHE_SIZE:
-                    translation_cache.pop(next(iter(translation_cache)))

-            …

     except Exception as e:
         logger.error(f"❌ Unhandled exception in translation endpoint: {str(e)}")
+# translator.py - Handles ASR, TTS, and translation tasks

 import os
 import sys
…
 from flask import jsonify
 from transformers import Wav2Vec2ForCTC, AutoProcessor, VitsModel, AutoTokenizer
 from transformers import MarianMTModel, MarianTokenizer

 # Configure logging
 logger = logging.getLogger("speech_api")
…
 translation_models = {}
 translation_tokenizers = {}

 # Language-specific configurations
 LANGUAGE_CODES = {
     "kapampangan": "pam",
…
     "phi": "Coco-18/opus-mt-phi"
 }

 def init_models(device):
+    """Initialize all models required for the API"""
     global asr_model, asr_processor, tts_models, tts_processors, translation_models, translation_tokenizers
+
+    # Initialize ASR model
+    ASR_MODEL_ID = "Coco-18/mms-asr-tgl-en-safetensor"
+    logger.info(f"π Loading ASR model: {ASR_MODEL_ID}")
+
+    try:
+        asr_processor = AutoProcessor.from_pretrained(
+            ASR_MODEL_ID,
+            cache_dir=os.environ.get("TRANSFORMERS_CACHE")
+        )
+        logger.info("✅ ASR processor loaded successfully")
+
+        asr_model = Wav2Vec2ForCTC.from_pretrained(
+            ASR_MODEL_ID,
+            cache_dir=os.environ.get("TRANSFORMERS_CACHE")
+        )
+        asr_model.to(device)
+        logger.info(f"✅ ASR model loaded successfully on {device}")
+    except Exception as e:
+        logger.error(f"❌ Error loading ASR model: {str(e)}")
+        logger.debug(f"Stack trace: {traceback.format_exc()}")
+
+    # Initialize TTS models
+    for lang, model_id in TTS_MODELS.items():
+        logger.info(f"π Loading TTS model for {lang}: {model_id}")
         try:
+            tts_processors[lang] = AutoTokenizer.from_pretrained(
                 model_id,
                 cache_dir=os.environ.get("TRANSFORMERS_CACHE")
             )
+            logger.info(f"✅ {lang} TTS processor loaded")
+
+            tts_models[lang] = VitsModel.from_pretrained(
                 model_id,
                 cache_dir=os.environ.get("TRANSFORMERS_CACHE")
             )
+            tts_models[lang].to(device)
             logger.info(f"✅ {lang} TTS model loaded on {device}")
         except Exception as e:
             logger.error(f"❌ Failed to load {lang} TTS model: {str(e)}")
             logger.debug(f"Stack trace: {traceback.format_exc()}")
+            tts_models[lang] = None
+
+    # Initialize translation models
+    for model_key, model_id in TRANSLATION_MODELS.items():
+        logger.info(f"π Loading Translation model: {model_id}")
+
         try:
+            translation_tokenizers[model_key] = MarianTokenizer.from_pretrained(
                 model_id,
                 cache_dir=os.environ.get("TRANSFORMERS_CACHE")
             )
+            logger.info(f"✅ Translation tokenizer loaded successfully for {model_key}")
+
+            translation_models[model_key] = MarianMTModel.from_pretrained(
                 model_id,
                 cache_dir=os.environ.get("TRANSFORMERS_CACHE")
             )
+            translation_models[model_key].to(device)
             logger.info(f"✅ Translation model loaded successfully on {device} for {model_key}")
         except Exception as e:
             logger.error(f"❌ Error loading Translation model for {model_key}: {str(e)}")
             logger.debug(f"Stack trace: {traceback.format_exc()}")
+            translation_models[model_key] = None
+            translation_tokenizers[model_key] = None


 def check_model_status():
…
         "translation_models": translation_status
     }


 def handle_asr_request(request, output_dir, sample_rate):
+    """Handle ASR (Automatic Speech Recognition) requests"""
     if asr_model is None or asr_processor is None:
         logger.error("❌ ASR endpoint called but models aren't loaded")
         return jsonify({"error": "ASR model not available"}), 503
…
         lang_code = LANGUAGE_CODES[language]
         logger.info(f"π Processing {language} audio for ASR")

         # Save the uploaded file temporarily
         with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(audio_file.filename)[-1]) as temp_audio:
+            temp_audio.write(audio_file.read())
             temp_audio_path = temp_audio.name
         logger.debug(f"π Temporary audio saved to {temp_audio_path}")

+        # Convert to WAV if necessary
+        wav_path = temp_audio_path
+        if not audio_file.filename.lower().endswith(".wav"):
+            wav_path = os.path.join(output_dir, "converted_audio.wav")
+            logger.info(f"π Converting audio to WAV format: {wav_path}")
+            try:
+                audio = AudioSegment.from_file(temp_audio_path)
+                audio = audio.set_frame_rate(sample_rate).set_channels(1)
+                audio.export(wav_path, format="wav")
+            except Exception as e:
+                logger.error(f"❌ Audio conversion failed: {str(e)}")
+                return jsonify({"error": f"Audio conversion failed: {str(e)}"}), 500
+
+        # Load and process the WAV file
         try:
+            waveform, sr = torchaudio.load(wav_path)
+            logger.debug(f"✅ Audio loaded: {wav_path} (Sample rate: {sr}Hz)")
+
+            # Resample if needed
+            if sr != sample_rate:
+                logger.info(f"π Resampling audio from {sr}Hz to {sample_rate}Hz")
+                waveform = torchaudio.transforms.Resample(sr, sample_rate)(waveform)
+
+            waveform = waveform / torch.max(torch.abs(waveform))
         except Exception as e:
+            logger.error(f"❌ Failed to load or process audio: {str(e)}")
+            return jsonify({"error": f"Audio processing failed: {str(e)}"}), 500

         # Process audio for ASR
         try:
             inputs = asr_processor(
+                waveform.squeeze().numpy(),
                 sampling_rate=sample_rate,
                 return_tensors="pt",
                 language=lang_code
…
             transcription = asr_processor.decode(ids)

             logger.info(f"✅ Transcription ({language}): {transcription}")

             # Clean up temp files
             try:
…
             return jsonify({
                 "transcription": transcription,
                 "language": language,
+                "language_code": lang_code
             })
         except Exception as e:
             logger.error(f"❌ ASR inference failed: {str(e)}")
…
         logger.debug(f"Stack trace: {traceback.format_exc()}")
         return jsonify({"error": f"Internal server error: {str(e)}"}), 500

 def handle_tts_request(request, output_dir):
+    """Handle TTS (Text-to-Speech) requests"""
     try:
         data = request.get_json()
         if not data:
…
             return jsonify({"error": f"TTS model for {language} not available"}), 503

         logger.info(f"π Generating TTS for language: {language}, text: '{text_input}'")

         try:
             processor = tts_processors[language]
             model = tts_models[language]
…

             # Save to file
             try:
+                output_filename = os.path.join(output_dir, f"{language}_output.wav")
                 sampling_rate = model.config.sampling_rate
                 sf.write(output_filename, waveform, sampling_rate)
                 logger.info(f"✅ Speech generated! File saved: {output_filename}")
             except Exception as e:
                 logger.error(f"❌ Failed to save audio file: {str(e)}")
                 return jsonify({"error": f"Failed to save audio file: {str(e)}"}), 500
…
                 "message": "TTS audio generated",
                 "file_url": f"/download/{os.path.basename(output_filename)}",
                 "language": language,
+                "text_length": len(text_input)
             })
         except Exception as e:
             logger.error(f"❌ Unhandled exception in TTS endpoint: {str(e)}")
             logger.debug(f"Stack trace: {traceback.format_exc()}")
             return jsonify({"error": f"Internal server error: {str(e)}"}), 500

 def handle_translation_request(request):
+    """Handle translation requests"""
     try:
         data = request.get_json()
         if not data:
…
         target_code = LANGUAGE_CODES.get(target_language, target_language)

         logger.info(f"π Translating from {source_language} to {target_language}: '{source_text}'")

+        # Special handling for pam-fil, fil-pam, pam-tgl and tgl-pam using the phi model
+        use_phi_model = False
         actual_source_code = source_code
         actual_target_code = target_code
+
+        # Check if we need to use the phi model with fil replacement
+        if (source_code == "pam" and target_code == "fil") or (source_code == "fil" and target_code == "pam"):
+            use_phi_model = True
+        elif (source_code == "pam" and target_code == "tgl"):
             use_phi_model = True
+            actual_target_code = "fil"  # Replace tgl with fil for the phi model
+        elif (source_code == "tgl" and target_code == "pam"):
+            use_phi_model = True
+            actual_source_code = "fil"  # Replace tgl with fil for the phi model
+
+        if use_phi_model:
+            model_key = "phi"
+
+            # Check if we have the phi model
+            if model_key not in translation_models or translation_models[model_key] is None:
+                logger.error(f"❌ Translation model for {model_key} not loaded")
+                return jsonify({"error": f"Translation model not available"}), 503
+
+            try:
+                # Get the phi model and tokenizer
+                model = translation_models[model_key]
+                tokenizer = translation_tokenizers[model_key]
+
+                # Prepend target language token to input
+                input_text = f">>{actual_target_code}<< {source_text}"
+
+                logger.info(f"π Using phi model with input: '{input_text}'")
+
+                # Tokenize the text
+                tokenized = tokenizer(input_text, return_tensors="pt", padding=True)
+                tokenized = {k: v.to(model.device) for k, v in tokenized.items()}
+
+                with torch.no_grad():
+                    translated = model.generate(
+                        **tokenized,
+                        max_length=100,  # Reasonable output length
+                        num_beams=4,  # Same as in training
+                        length_penalty=0.6,  # Same as in training
+                        early_stopping=True,  # Same as in training
+                        repetition_penalty=1.5,  # Add this to prevent repetition
+                        no_repeat_ngram_size=3  # Add this to prevent repetition
+                    )
+
+                # Decode the translation
+                result = tokenizer.decode(translated[0], skip_special_tokens=True)
+
+                logger.info(f"✅ Translation result: '{result}'")
+
+                return jsonify({
+                    "translated_text": result,
+                    "source_language": source_language,
+                    "target_language": target_language
+                })
+            except Exception as e:
+                logger.error(f"❌ Translation processing failed: {str(e)}")
+                logger.debug(f"Stack trace: {traceback.format_exc()}")
+                return jsonify({"error": f"Translation processing failed: {str(e)}"}), 500
         else:
+            # Create the regular language pair key for other language pairs
+            lang_pair = f"{source_code}-{target_code}"

+            # Check if we have a model for this language pair
+            if lang_pair not in translation_models:
+                logger.warning(f"⚠️ No translation model available for {lang_pair}")
+                return jsonify(
+                    {"error": f"Translation from {source_language} to {target_language} is not supported yet"}), 400

+            if translation_models[lang_pair] is None or translation_tokenizers[lang_pair] is None:
+                logger.error(f"❌ Translation model for {lang_pair} not loaded")
+                return jsonify({"error": f"Translation model not available"}), 503

+            try:
+                # Regular translation process for other language pairs
+                model = translation_models[lang_pair]
+                tokenizer = translation_tokenizers[lang_pair]

+                # Tokenize the text
+                tokenized = tokenizer(source_text, return_tensors="pt", padding=True)
+                tokenized = {k: v.to(model.device) for k, v in tokenized.items()}
+
+                # Generate translation
+                with torch.no_grad():
+                    translated = model.generate(**tokenized)
+
+                # Decode the translation
+                result = tokenizer.decode(translated[0], skip_special_tokens=True)
+
+                logger.info(f"✅ Translation result: '{result}'")
+
+                return jsonify({
+                    "translated_text": result,
+                    "source_language": source_language,
+                    "target_language": target_language
+                })
+            except Exception as e:
+                logger.error(f"❌ Translation processing failed: {str(e)}")
+                logger.debug(f"Stack trace: {traceback.format_exc()}")
+                return jsonify({"error": f"Translation processing failed: {str(e)}"}), 500

     except Exception as e:
         logger.error(f"❌ Unhandled exception in translation endpoint: {str(e)}")
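For context, a minimal sketch of how the refactored handlers are typically wired into the Flask app. The endpoint paths, OUTPUT_DIR, and SAMPLE_RATE values below are illustrative assumptions; only the translator.py function names come from the change above.

# app.py - hypothetical wiring of translator.py into a Flask app (not part of this commit)
import torch
from flask import Flask, request

import translator

app = Flask(__name__)
OUTPUT_DIR = "/tmp/speech_api_output"  # assumed output directory
SAMPLE_RATE = 16000                    # assumed ASR sample rate

# Load ASR, TTS, and translation models once at startup
# (now loaded sequentially, since the thread-pool initialization was removed)
translator.init_models("cuda" if torch.cuda.is_available() else "cpu")

@app.route("/health", methods=["GET"])
def health():
    # check_model_status() reports which models loaded successfully
    return translator.check_model_status()

@app.route("/asr", methods=["POST"])
def asr():
    return translator.handle_asr_request(request, OUTPUT_DIR, SAMPLE_RATE)

@app.route("/tts", methods=["POST"])
def tts():
    return translator.handle_tts_request(request, OUTPUT_DIR)

@app.route("/translate", methods=["POST"])
def translate():
    return translator.handle_translation_request(request)

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=8000)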