import os
import sys
import logging
import traceback
import torch
import torchaudio
import tempfile
import soundfile as sf
from pydub import AudioSegment
from flask import jsonify
from transformers import Wav2Vec2ForCTC, AutoProcessor, VitsModel, AutoTokenizer
from transformers import MarianMTModel, MarianTokenizer
import concurrent.futures
import functools
import threading
from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache

logger = logging.getLogger("speech_api")

asr_model = None
asr_processor = None
tts_models = {}
tts_processors = {}
translation_models = {}
translation_tokenizers = {}

asr_cache = {}
tts_cache = {}
translation_cache = {}

asr_lock = threading.Lock()
tts_lock = threading.Lock()
translation_lock = threading.Lock()
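
# NOTE: the caches above are plain in-process dictionaries guarded by the corresponding
# locks, so entries are per worker process and are lost on restart. They are bounded by
# MAX_CACHE_SIZE below; CACHE_TTL appears intended as the entry lifetime in seconds.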

LANGUAGE_CODES = {
    "kapampangan": "pam",
    "filipino": "fil",
    "english": "eng",
    "tagalog": "tgl",
}

TTS_MODELS = {
    "kapampangan": "facebook/mms-tts-pam",
    "tagalog": "facebook/mms-tts-tgl",
    "english": "facebook/mms-tts-eng",
}
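
# Dedicated MarianMT checkpoints handle the direct pairs below; the "phi" entry is a single
# multilingual model that handle_translation_request() falls back to for Kapampangan <->
# Filipino/Tagalog requests, selecting the output language with a ">>lang<<" prefix token.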

TRANSLATION_MODELS = {
    "pam-eng": "Coco-18/opus-mt-pam-en",
    "eng-pam": "Coco-18/opus-mt-en-pam",
    "tgl-eng": "Helsinki-NLP/opus-mt-tl-en",
    "eng-tgl": "Helsinki-NLP/opus-mt-en-tl",
    "phi": "Coco-18/opus-mt-phi",
}

MAX_CACHE_SIZE = 100
CACHE_TTL = 3600


def init_models(device):
    """Initialize all models required for the API with parallelization"""
    global asr_model, asr_processor, tts_models, tts_processors, translation_models, translation_tokenizers

    logger.info("Starting parallel model initialization")

    def init_asr():
        global asr_model, asr_processor
        ASR_MODEL_ID = "Coco-18/mms-asr-tgl-en-safetensor"
        try:
            asr_processor = AutoProcessor.from_pretrained(
                ASR_MODEL_ID,
                cache_dir=os.environ.get("TRANSFORMERS_CACHE")
            )

            asr_model = Wav2Vec2ForCTC.from_pretrained(
                ASR_MODEL_ID,
                cache_dir=os.environ.get("TRANSFORMERS_CACHE")
            )
            asr_model.to(device)
            logger.info(f"✅ ASR model loaded successfully on {device}")
            return True
        except Exception as e:
            logger.error(f"❌ Error loading ASR model: {str(e)}")
            logger.debug(f"Stack trace: {traceback.format_exc()}")
            return False

    def init_tts(lang, model_id):
        try:
            processor = AutoTokenizer.from_pretrained(
                model_id,
                cache_dir=os.environ.get("TRANSFORMERS_CACHE")
            )

            model = VitsModel.from_pretrained(
                model_id,
                cache_dir=os.environ.get("TRANSFORMERS_CACHE")
            )
            model.to(device)
            logger.info(f"✅ {lang} TTS model loaded on {device}")
            return lang, processor, model
        except Exception as e:
            logger.error(f"❌ Failed to load {lang} TTS model: {str(e)}")
            logger.debug(f"Stack trace: {traceback.format_exc()}")
            return lang, None, None

    def init_translation(model_key, model_id):
        try:
            tokenizer = MarianTokenizer.from_pretrained(
                model_id,
                cache_dir=os.environ.get("TRANSFORMERS_CACHE")
            )

            model = MarianMTModel.from_pretrained(
                model_id,
                cache_dir=os.environ.get("TRANSFORMERS_CACHE")
            )
            model.to(device)
            logger.info(f"✅ Translation model loaded successfully on {device} for {model_key}")
            return model_key, tokenizer, model
        except Exception as e:
            logger.error(f"❌ Error loading Translation model for {model_key}: {str(e)}")
            logger.debug(f"Stack trace: {traceback.format_exc()}")
            return model_key, None, None

    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
        asr_future = executor.submit(init_asr)

        # Submit TTS model loading tasks
        tts_futures = {
            executor.submit(init_tts, lang, model_id): lang
            for lang, model_id in TTS_MODELS.items()
        }

        # Submit translation model loading tasks
        translation_futures = {
            executor.submit(init_translation, model_key, model_id): model_key
            for model_key, model_id in TRANSLATION_MODELS.items()
        }

        # Collect TTS results as they finish
        for future in concurrent.futures.as_completed(tts_futures):
            lang, processor, model = future.result()
            if processor is not None and model is not None:
                tts_processors[lang] = processor
                tts_models[lang] = model

        # Collect translation results as they finish
        for future in concurrent.futures.as_completed(translation_futures):
            model_key, tokenizer, model = future.result()
            if tokenizer is not None and model is not None:
                translation_tokenizers[model_key] = tokenizer
                translation_models[model_key] = model

        # init_asr() assigns the ASR globals itself; leaving this block joins the pool,
        # so asr_future has finished before the summary below.

    logger.info("Model initialization summary:")
    logger.info(f" - ASR model: {'loaded' if asr_model is not None else 'failed'}")
    logger.info(f" - TTS models loaded: {sum(1 for m in tts_models.values() if m is not None)}/{len(TTS_MODELS)}")
    logger.info(f" - Translation models loaded: {sum(1 for m in translation_models.values() if m is not None)}/{len(TRANSLATION_MODELS)}")
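

# Typical startup wiring (a sketch, not part of this module): the hosting Flask app is
# expected to pick a device and call init_models() once before serving requests, e.g.
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   init_models(device)
#
# check_model_status() below can then back a health-check endpoint.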


def check_model_status():
    """Check and return the status of all models"""
    translation_status = {}

    # Status of the dedicated pair models
    for lang_pair in ["pam-eng", "eng-pam", "tgl-eng", "eng-tgl"]:
        translation_status[lang_pair] = (
            "loaded" if lang_pair in translation_models and translation_models[lang_pair] is not None
            else "failed"
        )

    # The multilingual "phi" model covers the Kapampangan <-> Filipino/Tagalog pairs
    phi_status = "loaded" if "phi" in translation_models and translation_models["phi"] is not None else "failed"
    translation_status["pam-fil"] = phi_status
    translation_status["fil-pam"] = phi_status
    translation_status["pam-tgl"] = phi_status
    translation_status["tgl-pam"] = phi_status

    return {
        "asr_model": "loaded" if asr_model is not None else "failed",
        "tts_models": {lang: "loaded" if model is not None else "failed"
                       for lang, model in tts_models.items()},
        "translation_models": translation_status
    }


@lru_cache(maxsize=MAX_CACHE_SIZE)
def get_cached_transcription(file_hash, language_code):
    """Retrieve cached transcription result if available"""
    return asr_cache.get((file_hash, language_code))
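

# process_audio_file() accepts any input format that pydub/ffmpeg can decode, converts it to
# a mono WAV at the requested sample rate when necessary, and returns a peak-normalized
# numpy waveform together with the path of the WAV file it actually used.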

def process_audio_file(audio_data, temp_audio_path, output_dir, sample_rate):
    """Process audio file for ASR (separate from ASR logic)"""
    wav_path = temp_audio_path

    if not temp_audio_path.lower().endswith(".wav"):
        wav_path = os.path.join(output_dir, "converted_audio.wav")
        logger.info(f"Converting audio to WAV format: {wav_path}")
        try:
            audio = AudioSegment.from_file(temp_audio_path)
            audio = audio.set_frame_rate(sample_rate).set_channels(1)
            audio.export(wav_path, format="wav")
        except Exception as e:
            logger.error(f"❌ Audio conversion failed: {str(e)}")
            raise Exception(f"Audio conversion failed: {str(e)}")

    try:
        waveform, sr = torchaudio.load(wav_path)

        # Resample to the target rate if needed
        if sr != sample_rate:
            waveform = torchaudio.transforms.Resample(sr, sample_rate)(waveform)

        # Peak-normalize to [-1, 1]
        waveform = waveform / torch.max(torch.abs(waveform))

        return waveform.squeeze().numpy(), wav_path
    except Exception as e:
        logger.error(f"❌ Failed to load or process audio: {str(e)}")
        raise Exception(f"Audio processing failed: {str(e)}")


def compute_audio_hash(audio_data):
    """Compute a hash of audio data for caching purposes"""
    import hashlib
    return hashlib.md5(audio_data).hexdigest()


def handle_asr_request(request, output_dir, sample_rate):
    """Handle ASR (Automatic Speech Recognition) requests with optimization"""
    if asr_model is None or asr_processor is None:
        logger.error("❌ ASR endpoint called but models aren't loaded")
        return jsonify({"error": "ASR model not available"}), 503

    try:
        if "audio" not in request.files:
            logger.warning("⚠️ ASR request missing audio file")
            return jsonify({"error": "No audio file uploaded"}), 400

        audio_file = request.files["audio"]
        language = request.form.get("language", "english").lower()

        if language not in LANGUAGE_CODES:
            logger.warning(f"⚠️ Unsupported language requested: {language}")
            return jsonify(
                {"error": f"Unsupported language: {language}. Available: {list(LANGUAGE_CODES.keys())}"}), 400

        lang_code = LANGUAGE_CODES[language]
        logger.info(f"Processing {language} audio for ASR")

        # Read the upload once so it can be hashed for the cache lookup
        audio_content = audio_file.read()
        audio_hash = compute_audio_hash(audio_content)

        # Return a cached transcription if this exact audio/language pair was seen before
        with asr_lock:
            cached_result = asr_cache.get((audio_hash, lang_code))
            if cached_result:
                logger.info(f"✅ Using cached ASR result for {language}")
                return jsonify({
                    "transcription": cached_result,
                    "language": language,
                    "language_code": lang_code,
                    "from_cache": True
                })

        # Persist the upload to a temporary file for conversion
        with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(audio_file.filename)[-1]) as temp_audio:
            temp_audio.write(audio_content)
            temp_audio_path = temp_audio.name
            logger.debug(f"Temporary audio saved to {temp_audio_path}")

        # Convert/load the audio off the request thread
        try:
            with ThreadPoolExecutor(max_workers=2) as executor:
                future = executor.submit(process_audio_file, audio_content, temp_audio_path, output_dir, sample_rate)
                waveform, wav_path = future.result()
        except Exception as e:
            return jsonify({"error": str(e)}), 500

        # Prepare model inputs
        try:
            inputs = asr_processor(
                waveform,
                sampling_rate=sample_rate,
                return_tensors="pt",
                language=lang_code
            )
            inputs = {k: v.to(asr_model.device) for k, v in inputs.items()}
        except Exception as e:
            logger.error(f"❌ ASR preprocessing failed: {str(e)}")
            return jsonify({"error": f"ASR preprocessing failed: {str(e)}"}), 500

        # Run inference and decode
        try:
            with torch.no_grad():
                logits = asr_model(**inputs).logits
            ids = torch.argmax(logits, dim=-1)[0]
            transcription = asr_processor.decode(ids)

            logger.info(f"✅ Transcription ({language}): {transcription}")

            # Cache the result, evicting the oldest entry when the cache is full
            with asr_lock:
                asr_cache[(audio_hash, lang_code)] = transcription

                if len(asr_cache) > MAX_CACHE_SIZE:
                    asr_cache.pop(next(iter(asr_cache)))

            # Clean up temporary files
            try:
                os.unlink(temp_audio_path)
                if wav_path != temp_audio_path:
                    os.unlink(wav_path)
            except Exception as e:
                logger.warning(f"⚠️ Failed to clean up temp files: {str(e)}")

            return jsonify({
                "transcription": transcription,
                "language": language,
                "language_code": lang_code,
                "from_cache": False
            })
        except Exception as e:
            logger.error(f"❌ ASR inference failed: {str(e)}")
            logger.debug(f"Stack trace: {traceback.format_exc()}")
            return jsonify({"error": f"ASR inference failed: {str(e)}"}), 500

    except Exception as e:
        logger.error(f"❌ Unhandled exception in ASR endpoint: {str(e)}")
        logger.debug(f"Stack trace: {traceback.format_exc()}")
        return jsonify({"error": f"Internal server error: {str(e)}"}), 500
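

# Example client call for the ASR handler (a sketch; the route path and port are whatever the
# hosting Flask app registers, shown here as /asr on localhost:5000):
#
#   curl -X POST http://localhost:5000/asr \
#        -F "audio=@recording.wav" \
#        -F "language=english"
#
# The JSON response carries "transcription", "language", "language_code" and "from_cache".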


def tts_cache_key(text, language):
    """Generate a cache key for TTS results"""
    import hashlib
    return hashlib.md5(f"{text}:{language}".encode()).hexdigest()


def handle_tts_request(request, output_dir):
    """Handle TTS (Text-to-Speech) requests with optimization"""
    try:
        data = request.get_json()
        if not data:
            logger.warning("⚠️ TTS endpoint called with no JSON data")
            return jsonify({"error": "No JSON data provided"}), 400

        text_input = data.get("text", "").strip()
        language = data.get("language", "kapampangan").lower()

        if not text_input:
            logger.warning("⚠️ TTS request with empty text")
            return jsonify({"error": "No text provided"}), 400

        if language not in TTS_MODELS:
            logger.warning(f"⚠️ TTS requested for unsupported language: {language}")
            return jsonify({"error": f"Invalid language. Available options: {list(TTS_MODELS.keys())}"}), 400

        if tts_models.get(language) is None:
            logger.error(f"❌ TTS model for {language} not loaded")
            return jsonify({"error": f"TTS model for {language} not available"}), 503

        logger.info(f"Generating TTS for language: {language}, text: '{text_input}'")

        cache_key = tts_cache_key(text_input, language)

        # Serve a previously generated file if it is still on disk
        with tts_lock:
            cached_file = tts_cache.get(cache_key)
            if cached_file and os.path.exists(cached_file):
                logger.info(f"✅ Using cached TTS audio for: '{text_input}'")
                return jsonify({
                    "message": "TTS audio retrieved from cache",
                    "file_url": f"/download/{os.path.basename(cached_file)}",
                    "language": language,
                    "text_length": len(text_input),
                    "from_cache": True
                })

        MAX_TEXT_LENGTH = 200

        if len(text_input) > MAX_TEXT_LENGTH:
            # Split long text into sentence-based chunks under the length limit
            chunks = []
            current_chunk = ""

            for sentence in text_input.split("."):
                if len(current_chunk) + len(sentence) < MAX_TEXT_LENGTH:
                    current_chunk += sentence + "."
                else:
                    if current_chunk:
                        chunks.append(current_chunk)
                    current_chunk = sentence + "."

            if current_chunk:
                chunks.append(current_chunk)

            logger.info(f"Text chunked into {len(chunks)} parts for processing")

            try:
                processor = tts_processors[language]
                model = tts_models[language]

                # Only the first chunk is synthesized for now
                text_input = chunks[0]
                logger.info(f"⚠️ Using only the first chunk for demonstration: '{text_input}'")
            except Exception as e:
                logger.error(f"❌ TTS chunking failed: {str(e)}")
                return jsonify({"error": f"TTS chunking failed: {str(e)}"}), 500

        try:
            processor = tts_processors[language]
            model = tts_models[language]
            inputs = processor(text_input, return_tensors="pt")
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
        except Exception as e:
            logger.error(f"❌ TTS preprocessing failed: {str(e)}")
            return jsonify({"error": f"TTS preprocessing failed: {str(e)}"}), 500

        try:
            with torch.no_grad():
                output = model(**inputs).waveform
            waveform = output.squeeze().cpu().numpy()
        except Exception as e:
            logger.error(f"❌ TTS inference failed: {str(e)}")
            logger.debug(f"Stack trace: {traceback.format_exc()}")
            return jsonify({"error": f"TTS inference failed: {str(e)}"}), 500

        try:
            output_filename = os.path.join(output_dir, f"{language}_{cache_key}.wav")
            sampling_rate = model.config.sampling_rate
            sf.write(output_filename, waveform, sampling_rate)
            logger.info(f"✅ Speech generated! File saved: {output_filename}")

            # Cache the file path, deleting and evicting the oldest entry when the cache is full
            with tts_lock:
                tts_cache[cache_key] = output_filename

                if len(tts_cache) > MAX_CACHE_SIZE:
                    oldest_key = next(iter(tts_cache))
                    try:
                        os.remove(tts_cache[oldest_key])
                    except OSError:
                        pass
                    tts_cache.pop(oldest_key)
        except Exception as e:
            logger.error(f"❌ Failed to save audio file: {str(e)}")
            return jsonify({"error": f"Failed to save audio file: {str(e)}"}), 500

        return jsonify({
            "message": "TTS audio generated",
            "file_url": f"/download/{os.path.basename(output_filename)}",
            "language": language,
            "text_length": len(text_input),
            "from_cache": False
        })
    except Exception as e:
        logger.error(f"❌ Unhandled exception in TTS endpoint: {str(e)}")
        logger.debug(f"Stack trace: {traceback.format_exc()}")
        return jsonify({"error": f"Internal server error: {str(e)}"}), 500
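

# Example request body for the TTS handler (a sketch; route registration is up to the hosting app):
#
#   {"text": "Hello, welcome!", "language": "english"}
#
# On success the response includes a "file_url" under /download/ that the hosting app is
# expected to serve from output_dir.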


def translation_cache_key(text, source_lang, target_lang):
    """Generate a cache key for translation results"""
    import hashlib
    return hashlib.md5(f"{text}:{source_lang}:{target_lang}".encode()).hexdigest()


def handle_translation_request(request):
    """Handle translation requests with optimization"""
    try:
        data = request.get_json()
        if not data:
            logger.warning("⚠️ Translation endpoint called with no JSON data")
            return jsonify({"error": "No JSON data provided"}), 400

        source_text = data.get("text", "").strip()
        source_language = data.get("source_language", "").lower()
        target_language = data.get("target_language", "").lower()

        if not source_text:
            logger.warning("⚠️ Translation request with empty text")
            return jsonify({"error": "No text provided"}), 400

        # Map display names to ISO codes, falling back to whatever the client sent
        source_code = LANGUAGE_CODES.get(source_language, source_language)
        target_code = LANGUAGE_CODES.get(target_language, target_language)

        logger.info(f"Translating from {source_language} to {target_language}: '{source_text}'")

        cache_key = translation_cache_key(source_text, source_code, target_code)

        # Return a cached translation if available
        with translation_lock:
            cached_result = translation_cache.get(cache_key)
            if cached_result:
                logger.info("✅ Using cached translation result")
                return jsonify({
                    "translated_text": cached_result,
                    "source_language": source_language,
                    "target_language": target_language,
                    "from_cache": True
                })

        # Pick a model: prefer a dedicated pair model, otherwise fall back to the
        # multilingual "phi" model for Philippine-language pairs
        model_key = None
        actual_source_code = source_code
        actual_target_code = target_code
        input_text = source_text

        if f"{source_code}-{target_code}" in translation_models:
            model_key = f"{source_code}-{target_code}"
            use_phi_model = False
        elif source_code in ["pam", "fil", "tgl"] and target_code in ["pam", "fil", "tgl"]:
            model_key = "phi"
            use_phi_model = True

            # The phi model uses "fil" rather than "tgl" language codes
            if source_code == "tgl":
                actual_source_code = "fil"
            if target_code == "tgl":
                actual_target_code = "fil"

            # Multilingual Marian models select the output language via a target-token prefix
            input_text = f">>{actual_target_code}<< {source_text}"
        else:
            logger.warning(f"⚠️ No translation model available for {source_code}-{target_code}")
            return jsonify(
                {"error": f"Translation from {source_language} to {target_language} is not supported yet"}), 400

        if model_key not in translation_models or translation_models[model_key] is None:
            logger.error(f"❌ Translation model for {model_key} not loaded")
            return jsonify({"error": "Translation model not available"}), 503

        try:
            model = translation_models[model_key]
            tokenizer = translation_tokenizers[model_key]

            tokenized = tokenizer(input_text, return_tensors="pt", padding=True)
            tokenized = {k: v.to(model.device) for k, v in tokenized.items()}

            # Cap generation length relative to the input length
            max_length = min(100, len(source_text.split()) * 2)

            with torch.no_grad():
                translated = model.generate(
                    **tokenized,
                    max_length=max_length,
                    num_beams=4,
                    length_penalty=0.6,
                    early_stopping=True,
                    repetition_penalty=1.5,
                    no_repeat_ngram_size=3
                )

            result = tokenizer.decode(translated[0], skip_special_tokens=True)

            logger.info(f"✅ Translation result: '{result}'")

            # Cache the result, evicting the oldest entry when the cache is full
            with translation_lock:
                translation_cache[cache_key] = result

                if len(translation_cache) > MAX_CACHE_SIZE:
                    translation_cache.pop(next(iter(translation_cache)))

            return jsonify({
                "translated_text": result,
                "source_language": source_language,
                "target_language": target_language,
                "from_cache": False
            })
        except Exception as e:
            logger.error(f"❌ Translation processing failed: {str(e)}")
            logger.debug(f"Stack trace: {traceback.format_exc()}")
            return jsonify({"error": f"Translation processing failed: {str(e)}"}), 500

    except Exception as e:
        logger.error(f"❌ Unhandled exception in translation endpoint: {str(e)}")
        logger.debug(f"Stack trace: {traceback.format_exc()}")
        return jsonify({"error": f"Internal server error: {str(e)}"}), 500
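

# Example request body for the translation handler (a sketch; route registration is up to
# the hosting app):
#
#   {"text": "Good morning", "source_language": "english", "target_language": "kapampangan"}
#
# This routes to the dedicated "eng-pam" model; Kapampangan <-> Tagalog/Filipino requests
# fall back to the multilingual "phi" model with a ">>pam<<" or ">>fil<<" target token.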


def get_asr_model():
    return asr_model


def get_asr_processor():
    return asr_processor
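

# Minimal smoke-test wiring (a sketch, assuming Flask is installed and the model downloads
# succeed; the route paths, temporary output directory, and 16 kHz sample rate below are
# illustrative choices, not part of this module's contract):
if __name__ == "__main__":
    from flask import Flask, request, send_from_directory

    logging.basicConfig(level=logging.INFO)

    app = Flask(__name__)
    output_dir = tempfile.mkdtemp()
    sample_rate = 16000

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    init_models(device)

    @app.route("/asr", methods=["POST"])
    def asr_route():
        return handle_asr_request(request, output_dir, sample_rate)

    @app.route("/tts", methods=["POST"])
    def tts_route():
        return handle_tts_request(request, output_dir)

    @app.route("/translate", methods=["POST"])
    def translate_route():
        return handle_translation_request(request)

    @app.route("/health", methods=["GET"])
    def health_route():
        return jsonify(check_model_status())

    @app.route("/download/<path:filename>", methods=["GET"])
    def download_route(filename):
        return send_from_directory(output_dir, filename)

    app.run(host="0.0.0.0", port=5000)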