# translator.py - Handles ASR, TTS, and translation tasks

import hashlib
import logging
import os
import tempfile
import time
import traceback

import soundfile as sf
import torch
import torchaudio
from flask import jsonify
from pydub import AudioSegment
from transformers import (
    AutoProcessor,
    AutoTokenizer,
    MarianMTModel,
    MarianTokenizer,
    VitsModel,
    Wav2Vec2ForCTC,
)

# Configure logging
logger = logging.getLogger("speech_api")

# Global variables to store models and processors
asr_model = None
asr_processor = None
tts_models = {}
tts_processors = {}
translation_models = {}
translation_tokenizers = {}

# Language-specific configurations
LANGUAGE_CODES = {
    "kapampangan": "pam",
    "filipino": "fil",
    "english": "eng",
    "tagalog": "tgl",
}

# TTS models (Kapampangan, Tagalog, English)
TTS_MODELS = {
    "kapampangan": "facebook/mms-tts-pam",
    "tagalog": "facebook/mms-tts-tgl",
    "english": "facebook/mms-tts-eng",
}

# Translation models. The direct pairs use dedicated MarianMT checkpoints;
# "phi" is a shared multilingual model that serves pam<->fil (and pam<->tgl
# by substitution, see handle_translation_request).
TRANSLATION_MODELS = {
    "pam-eng": "Coco-18/opus-mt-pam-en",
    "eng-pam": "Coco-18/opus-mt-en-pam",
    "tgl-eng": "Helsinki-NLP/opus-mt-tl-en",
    "eng-tgl": "Helsinki-NLP/opus-mt-en-tl",
    "phi": "Coco-18/opus-mt-phi",
}


def init_models(device):
    """Initialize all models required for the API."""
    global asr_model, asr_processor, tts_models, tts_processors, translation_models, translation_tokenizers

    # Initialize the ASR model
    ASR_MODEL_ID = "Coco-18/mms-asr-tgl-en-safetensor"
    logger.info(f"🔄 Loading ASR model: {ASR_MODEL_ID}")

    try:
        asr_processor = AutoProcessor.from_pretrained(
            ASR_MODEL_ID,
            cache_dir=os.environ.get("TRANSFORMERS_CACHE")
        )
        logger.info("✅ ASR processor loaded successfully")

        asr_model = Wav2Vec2ForCTC.from_pretrained(
            ASR_MODEL_ID,
            cache_dir=os.environ.get("TRANSFORMERS_CACHE")
        )
        asr_model.to(device)
        logger.info(f"✅ ASR model loaded successfully on {device}")
    except Exception as e:
        logger.error(f"❌ Error loading ASR model: {str(e)}")
        logger.debug(f"Stack trace: {traceback.format_exc()}")

    # Initialize TTS models
    for lang, model_id in TTS_MODELS.items():
        logger.info(f"🔄 Loading TTS model for {lang}: {model_id}")
        try:
            tts_processors[lang] = AutoTokenizer.from_pretrained(
                model_id,
                cache_dir=os.environ.get("TRANSFORMERS_CACHE")
            )
            logger.info(f"✅ {lang} TTS processor loaded")

            tts_models[lang] = VitsModel.from_pretrained(
                model_id,
                cache_dir=os.environ.get("TRANSFORMERS_CACHE")
            )
            tts_models[lang].to(device)
            logger.info(f"✅ {lang} TTS model loaded on {device}")
        except Exception as e:
            logger.error(f"❌ Failed to load {lang} TTS model: {str(e)}")
            logger.debug(f"Stack trace: {traceback.format_exc()}")
            tts_models[lang] = None

    # Initialize translation models
    for model_key, model_id in TRANSLATION_MODELS.items():
        logger.info(f"🔄 Loading translation model: {model_id}")
        try:
            translation_tokenizers[model_key] = MarianTokenizer.from_pretrained(
                model_id,
                cache_dir=os.environ.get("TRANSFORMERS_CACHE")
            )
            logger.info(f"✅ Translation tokenizer loaded successfully for {model_key}")

            translation_models[model_key] = MarianMTModel.from_pretrained(
                model_id,
                cache_dir=os.environ.get("TRANSFORMERS_CACHE")
            )
            translation_models[model_key].to(device)
            logger.info(f"✅ Translation model loaded successfully on {device} for {model_key}")
        except Exception as e:
            logger.error(f"❌ Error loading translation model for {model_key}: {str(e)}")
            logger.debug(f"Stack trace: {traceback.format_exc()}")
            translation_models[model_key] = None
            translation_tokenizers[model_key] = None
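
# Usage sketch (assumption: init_models is called once at process startup,
# before the first request is served; the calling module is not part of this file):
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   init_models(device)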

def check_model_status():
    """Check and return the status of all models."""
    # Direct language-pair statuses, based on which models actually loaded
    translation_status = {}
    for lang_pair in ["pam-eng", "eng-pam", "tgl-eng", "eng-tgl"]:
        translation_status[lang_pair] = (
            "loaded"
            if lang_pair in translation_models and translation_models[lang_pair] is not None
            else "failed"
        )

    # The shared phi model serves pam<->fil; pam<->tgl requests are also
    # routed through it by substituting "fil" for "tgl"
    phi_status = "loaded" if "phi" in translation_models and translation_models["phi"] is not None else "failed"
    translation_status["pam-fil"] = phi_status
    translation_status["fil-pam"] = phi_status
    translation_status["pam-tgl"] = phi_status  # Uses the phi model with tgl replaced by fil
    translation_status["tgl-pam"] = phi_status  # Uses the phi model with tgl replaced by fil

    return {
        "asr_model": "loaded" if asr_model is not None else "failed",
        "tts_models": {lang: "loaded" if model is not None else "failed"
                       for lang, model in tts_models.items()},
        "translation_models": translation_status,
    }


def handle_asr_request(request, output_dir, sample_rate):
    """Handle ASR (Automatic Speech Recognition) requests."""
    if asr_model is None or asr_processor is None:
        logger.error("❌ ASR endpoint called but models aren't loaded")
        return jsonify({"error": "ASR model not available"}), 503

    try:
        if "audio" not in request.files:
            logger.warning("⚠️ ASR request missing audio file")
            return jsonify({"error": "No audio file uploaded"}), 400

        audio_file = request.files["audio"]
        language = request.form.get("language", "english").lower()

        if language not in LANGUAGE_CODES:
            logger.warning(f"⚠️ Unsupported language requested: {language}")
            return jsonify(
                {"error": f"Unsupported language: {language}. Available: {list(LANGUAGE_CODES.keys())}"}), 400

        lang_code = LANGUAGE_CODES[language]
        logger.info(f"🔄 Processing {language} audio for ASR")

        # Save the uploaded file temporarily
        with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(audio_file.filename)[-1]) as temp_audio:
            temp_audio.write(audio_file.read())
            temp_audio_path = temp_audio.name

        logger.debug(f"📁 Temporary audio saved to {temp_audio_path}")

        # Convert to WAV if necessary
        wav_path = temp_audio_path
        if not audio_file.filename.lower().endswith(".wav"):
            wav_path = os.path.join(output_dir, "converted_audio.wav")
            logger.info(f"🔄 Converting audio to WAV format: {wav_path}")
            try:
                audio = AudioSegment.from_file(temp_audio_path)
                audio = audio.set_frame_rate(sample_rate).set_channels(1)
                audio.export(wav_path, format="wav")
            except Exception as e:
                logger.error(f"❌ Audio conversion failed: {str(e)}")
                return jsonify({"error": f"Audio conversion failed: {str(e)}"}), 500

        # Load and process the WAV file
        try:
            waveform, sr = torchaudio.load(wav_path)
            logger.debug(f"✅ Audio loaded: {wav_path} (Sample rate: {sr}Hz)")

            # Resample if needed
            if sr != sample_rate:
                logger.info(f"🔄 Resampling audio from {sr}Hz to {sample_rate}Hz")
                waveform = torchaudio.transforms.Resample(sr, sample_rate)(waveform)

            # Peak-normalize, guarding against silent (all-zero) input
            peak = torch.max(torch.abs(waveform))
            if peak > 0:
                waveform = waveform / peak
        except Exception as e:
            logger.error(f"❌ Failed to load or process audio: {str(e)}")
            return jsonify({"error": f"Audio processing failed: {str(e)}"}), 500

        # Process audio for ASR
        try:
            inputs = asr_processor(
                waveform.squeeze().numpy(),
                sampling_rate=sample_rate,
                return_tensors="pt",
                language=lang_code
            )
            inputs = {k: v.to(asr_model.device) for k, v in inputs.items()}
        except Exception as e:
            logger.error(f"❌ ASR preprocessing failed: {str(e)}")
            return jsonify({"error": f"ASR preprocessing failed: {str(e)}"}), 500

        # Perform ASR
        try:
            with torch.no_grad():
                logits = asr_model(**inputs).logits
            ids = torch.argmax(logits, dim=-1)[0]
            transcription = asr_processor.decode(ids)

            logger.info(f"✅ Transcription ({language}): {transcription}")

            # Clean up temp files
            try:
                os.unlink(temp_audio_path)
                if wav_path != temp_audio_path:
                    os.unlink(wav_path)
            except Exception as e:
                logger.warning(f"⚠️ Failed to clean up temp files: {str(e)}")

            return jsonify({
                "transcription": transcription,
                "language": language,
                "language_code": lang_code
            })
        except Exception as e:
            logger.error(f"❌ ASR inference failed: {str(e)}")
            logger.debug(f"Stack trace: {traceback.format_exc()}")
            return jsonify({"error": f"ASR inference failed: {str(e)}"}), 500

    except Exception as e:
        logger.error(f"❌ Unhandled exception in ASR endpoint: {str(e)}")
        logger.debug(f"Stack trace: {traceback.format_exc()}")
        return jsonify({"error": f"Internal server error: {str(e)}"}), 500
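
# Example client call (assumption: the app mounts this handler at POST /asr;
# the host, port, and route are illustrative, not defined in this file):
#
#   curl -X POST http://localhost:5000/asr \
#        -F "audio=@sample.wav" \
#        -F "language=tagalog"
#
# On success the response body looks like:
#   {"transcription": "...", "language": "tagalog", "language_code": "tgl"}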

def handle_tts_request(request, output_dir):
    """Handle TTS (Text-to-Speech) requests."""
    try:
        data = request.get_json()
        if not data:
            logger.warning("⚠️ TTS endpoint called with no JSON data")
            return jsonify({"error": "No JSON data provided"}), 400

        text_input = data.get("text", "").strip()
        language = data.get("language", "kapampangan").lower()

        if not text_input:
            logger.warning("⚠️ TTS request with empty text")
            return jsonify({"error": "No text provided"}), 400

        if language not in TTS_MODELS:
            logger.warning(f"⚠️ TTS requested for unsupported language: {language}")
            return jsonify({"error": f"Invalid language. Available options: {list(TTS_MODELS.keys())}"}), 400

        if tts_models[language] is None:
            logger.error(f"❌ TTS model for {language} not loaded")
            return jsonify({"error": f"TTS model for {language} not available"}), 503

        logger.info(f"🔄 Generating TTS for language: {language}, text: '{text_input}'")

        try:
            processor = tts_processors[language]
            model = tts_models[language]
            inputs = processor(text_input, return_tensors="pt")
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
        except Exception as e:
            logger.error(f"❌ TTS preprocessing failed: {str(e)}")
            return jsonify({"error": f"TTS preprocessing failed: {str(e)}"}), 500

        # Generate speech
        try:
            with torch.no_grad():
                output = model(**inputs).waveform
            waveform = output.squeeze().cpu().numpy()
        except Exception as e:
            logger.error(f"❌ TTS inference failed: {str(e)}")
            logger.debug(f"Stack trace: {traceback.format_exc()}")
            return jsonify({"error": f"TTS inference failed: {str(e)}"}), 500

        # Save to file with a unique name (timestamp + text hash) to prevent overwriting
        try:
            text_hash = hashlib.md5(text_input.encode()).hexdigest()[:8]
            timestamp = int(time.time())
            output_filename = os.path.join(output_dir, f"{language}_{text_hash}_{timestamp}.wav")

            sampling_rate = model.config.sampling_rate
            sf.write(output_filename, waveform, sampling_rate)
            logger.info(f"✅ Speech generated! File saved: {output_filename}")
        except Exception as e:
            logger.error(f"❌ Failed to save audio file: {str(e)}")
            return jsonify({"error": f"Failed to save audio file: {str(e)}"}), 500

        # Add a cache-busting parameter to the URL
        return jsonify({
            "message": "TTS audio generated",
            "file_url": f"/download/{os.path.basename(output_filename)}?t={timestamp}",
            "language": language,
            "text_length": len(text_input)
        })
    except Exception as e:
        logger.error(f"❌ Unhandled exception in TTS endpoint: {str(e)}")
        logger.debug(f"Stack trace: {traceback.format_exc()}")
        return jsonify({"error": f"Internal server error: {str(e)}"}), 500
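
# Example client call (assumption: the app mounts this handler at POST /tts;
# the host, port, and route are illustrative, not defined in this file):
#
#   curl -X POST http://localhost:5000/tts \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Kumusta", "language": "kapampangan"}'
#
# On success the response includes a cache-busted download URL:
#   {"message": "TTS audio generated",
#    "file_url": "/download/kapampangan_<hash>_<ts>.wav?t=<ts>", ...}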
File saved: {output_filename}") except Exception as e: logger.error(f"❌ Failed to save audio file: {str(e)}") return jsonify({"error": f"Failed to save audio file: {str(e)}"}), 500 # Add cache-busting parameter to URL return jsonify({ "message": "TTS audio generated", "file_url": f"/download/{os.path.basename(output_filename)}?t={timestamp}", "language": language, "text_length": len(text_input) }) except Exception as e: logger.error(f"❌ Unhandled exception in TTS endpoint: {str(e)}") logger.debug(f"Stack trace: {traceback.format_exc()}") return jsonify({"error": f"Internal server error: {str(e)}"}), 500 def handle_translation_request(request): """Handle translation requests""" try: data = request.get_json() if not data: logger.warning("⚠️ Translation endpoint called with no JSON data") return jsonify({"error": "No JSON data provided"}), 400 source_text = data.get("text", "").strip() source_language = data.get("source_language", "").lower() target_language = data.get("target_language", "").lower() if not source_text: logger.warning("⚠️ Translation request with empty text") return jsonify({"error": "No text provided"}), 400 # Map language names to codes source_code = LANGUAGE_CODES.get(source_language, source_language) target_code = LANGUAGE_CODES.get(target_language, target_language) logger.info(f"🔄 Translating from {source_language} to {target_language}: '{source_text}'") # Special handling for pam-fil, fil-pam, pam-tgl and tgl-pam using the phi model use_phi_model = False actual_source_code = source_code actual_target_code = target_code # Check if we need to use the phi model with fil replacement if (source_code == "pam" and target_code == "fil") or (source_code == "fil" and target_code == "pam"): use_phi_model = True elif (source_code == "pam" and target_code == "tgl"): use_phi_model = True actual_target_code = "fil" # Replace tgl with fil for the phi model elif (source_code == "tgl" and target_code == "pam"): use_phi_model = True actual_source_code = "fil" # Replace tgl with fil for the phi model if use_phi_model: model_key = "phi" # Check if we have the phi model if model_key not in translation_models or translation_models[model_key] is None: logger.error(f"❌ Translation model for {model_key} not loaded") return jsonify({"error": f"Translation model not available"}), 503 try: # Get the phi model and tokenizer model = translation_models[model_key] tokenizer = translation_tokenizers[model_key] # Prepend target language token to input input_text = f">>{actual_target_code}<< {source_text}" logger.info(f"🔄 Using phi model with input: '{input_text}'") # Tokenize the text tokenized = tokenizer(input_text, return_tensors="pt", padding=True) tokenized = {k: v.to(model.device) for k, v in tokenized.items()} with torch.no_grad(): translated = model.generate( **tokenized, max_length=100, # Reasonable output length num_beams=4, # Same as in training length_penalty=0.6, # Same as in training early_stopping=True, # Same as in training repetition_penalty=1.5, # Add this to prevent repetition no_repeat_ngram_size=3 # Add this to prevent repetition ) # Decode the translation result = tokenizer.decode(translated[0], skip_special_tokens=True) logger.info(f"✅ Translation result: '{result}'") return jsonify({ "translated_text": result, "source_language": source_language, "target_language": target_language }) except Exception as e: logger.error(f"❌ Translation processing failed: {str(e)}") logger.debug(f"Stack trace: {traceback.format_exc()}") return jsonify({"error": f"Translation processing failed: 
{str(e)}"}), 500 else: # Create the regular language pair key for other language pairs lang_pair = f"{source_code}-{target_code}" # Check if we have a model for this language pair if lang_pair not in translation_models: logger.warning(f"⚠️ No translation model available for {lang_pair}") return jsonify( {"error": f"Translation from {source_language} to {target_language} is not supported yet"}), 400 if translation_models[lang_pair] is None or translation_tokenizers[lang_pair] is None: logger.error(f"❌ Translation model for {lang_pair} not loaded") return jsonify({"error": f"Translation model not available"}), 503 try: # Regular translation process for other language pairs model = translation_models[lang_pair] tokenizer = translation_tokenizers[lang_pair] # Tokenize the text tokenized = tokenizer(source_text, return_tensors="pt", padding=True) tokenized = {k: v.to(model.device) for k, v in tokenized.items()} # Generate translation with torch.no_grad(): translated = model.generate(**tokenized) # Decode the translation result = tokenizer.decode(translated[0], skip_special_tokens=True) logger.info(f"✅ Translation result: '{result}'") return jsonify({ "translated_text": result, "source_language": source_language, "target_language": target_language }) except Exception as e: logger.error(f"❌ Translation processing failed: {str(e)}") logger.debug(f"Stack trace: {traceback.format_exc()}") return jsonify({"error": f"Translation processing failed: {str(e)}"}), 500 except Exception as e: logger.error(f"❌ Unhandled exception in translation endpoint: {str(e)}") logger.debug(f"Stack trace: {traceback.format_exc()}") return jsonify({"error": f"Internal server error: {str(e)}"}), 500 def get_asr_model(): return asr_model def get_asr_processor(): return asr_processor