|
|
|
|
|
import os
|
|
import sys
|
|
import logging
|
|
import traceback
|
|
import torch
|
|
import torchaudio
|
|
import tempfile
|
|
import soundfile as sf
|
|
from pydub import AudioSegment
|
|
from flask import jsonify
|
|
from transformers import Wav2Vec2ForCTC, AutoProcessor, VitsModel, AutoTokenizer
|
|
from transformers import MarianMTModel, MarianTokenizer
|
|
|
|
|
|
# Module-level logger; handler/level configuration is left to the app entry point.
logger = logging.getLogger("speech_api")

# Global model registries, populated by init_models().
# A None value (or missing key) means that model failed to load;
# check_model_status() reports these as "failed".
asr_model = None        # Wav2Vec2ForCTC used for speech recognition
asr_processor = None    # processor paired with asr_model
tts_models = {}         # language name -> VitsModel (or None on load failure)
tts_processors = {}     # language name -> tokenizer for the TTS model
translation_models = {}      # model key (e.g. "pam-eng", "phi") -> MarianMTModel
translation_tokenizers = {}  # model key -> MarianTokenizer
|
|
|
|
|
|
# User-facing language names -> ISO 639-3 style codes used by the models
# and by the translation model keys below.
LANGUAGE_CODES = {
    "kapampangan": "pam",
    "filipino": "fil",
    "english": "eng",
    "tagalog": "tgl",
}
|
|
|
|
|
|
# Hugging Face model ids for text-to-speech (Facebook MMS VITS checkpoints),
# keyed by the same language names as LANGUAGE_CODES.
TTS_MODELS = {
    "kapampangan": "facebook/mms-tts-pam",
    "tagalog": "facebook/mms-tts-tgl",
    "english": "facebook/mms-tts-eng"
}
|
|
|
|
|
|
# Hugging Face model ids for translation, keyed "source-target" using the
# codes from LANGUAGE_CODES.  "phi" is a single multilingual model that
# serves the Kapampangan<->Filipino/Tagalog directions; the target language
# is selected with a ">>code<<" prefix on the input text
# (see handle_translation_request).
TRANSLATION_MODELS = {
    "pam-eng": "Coco-18/opus-mt-pam-en",
    "eng-pam": "Coco-18/opus-mt-en-pam",
    "tgl-eng": "Helsinki-NLP/opus-mt-tl-en",
    "eng-tgl": "Helsinki-NLP/opus-mt-en-tl",
    "phi": "Coco-18/opus-mt-phi"
}
|
|
|
|
|
|
def init_models(device):
    """Load the ASR, TTS, and translation models onto *device*.

    Populates the module-level registries (asr_model, asr_processor,
    tts_models, tts_processors, translation_models,
    translation_tokenizers).  Load failures are logged and recorded as
    None entries instead of raising, so the API can start with only a
    subset of models available.

    Args:
        device: torch device (e.g. "cuda" or "cpu") the models are moved to.
    """
    global asr_model, asr_processor, tts_models, tts_processors, translation_models, translation_tokenizers

    # Shared Hugging Face cache directory (may be None -> library default).
    cache_dir = os.environ.get("TRANSFORMERS_CACHE")

    # --- ASR ------------------------------------------------------------
    ASR_MODEL_ID = "Coco-18/mms-asr-tgl-en-safetensor"
    logger.info(f"Loading ASR model: {ASR_MODEL_ID}")
    try:
        asr_processor = AutoProcessor.from_pretrained(
            ASR_MODEL_ID,
            cache_dir=cache_dir
        )
        logger.info("ASR processor loaded successfully")

        asr_model = Wav2Vec2ForCTC.from_pretrained(
            ASR_MODEL_ID,
            cache_dir=cache_dir
        )
        asr_model.to(device)
        logger.info(f"ASR model loaded successfully on {device}")
    except Exception as e:
        logger.error(f"Error loading ASR model: {str(e)}")
        logger.debug(f"Stack trace: {traceback.format_exc()}")
        # Reset both halves so a half-loaded pair is never used.
        asr_model = None
        asr_processor = None

    # --- TTS (one VITS model per language) ------------------------------
    for lang, model_id in TTS_MODELS.items():
        logger.info(f"Loading TTS model for {lang}: {model_id}")
        try:
            tts_processors[lang] = AutoTokenizer.from_pretrained(
                model_id,
                cache_dir=cache_dir
            )
            logger.info(f"{lang} TTS processor loaded")

            tts_models[lang] = VitsModel.from_pretrained(
                model_id,
                cache_dir=cache_dir
            )
            tts_models[lang].to(device)
            logger.info(f"{lang} TTS model loaded on {device}")
        except Exception as e:
            logger.error(f"Failed to load {lang} TTS model: {str(e)}")
            logger.debug(f"Stack trace: {traceback.format_exc()}")
            # Record the failure for both halves of the pair (the
            # translation branch below already does this).
            tts_models[lang] = None
            tts_processors[lang] = None

    # --- Translation (MarianMT models) ----------------------------------
    for model_key, model_id in TRANSLATION_MODELS.items():
        logger.info(f"Loading Translation model: {model_id}")
        try:
            translation_tokenizers[model_key] = MarianTokenizer.from_pretrained(
                model_id,
                cache_dir=cache_dir
            )
            logger.info(f"Translation tokenizer loaded successfully for {model_key}")

            translation_models[model_key] = MarianMTModel.from_pretrained(
                model_id,
                cache_dir=cache_dir
            )
            translation_models[model_key].to(device)
            logger.info(f"Translation model loaded successfully on {device} for {model_key}")
        except Exception as e:
            logger.error(f"Error loading Translation model for {model_key}: {str(e)}")
            logger.debug(f"Stack trace: {traceback.format_exc()}")
            translation_models[model_key] = None
            translation_tokenizers[model_key] = None
|
|
|
|
|
|
def check_model_status():
    """Check and return the status of all models.

    Returns a dict with "asr_model", "tts_models", and
    "translation_models" entries, each value being "loaded" or "failed".
    """

    def status_of(model):
        # A registry entry is healthy only when it holds a real model object.
        return "loaded" if model is not None else "failed"

    # Directly-loaded Marian pairs.
    translation_status = {
        pair: status_of(translation_models.get(pair))
        for pair in ("pam-eng", "eng-pam", "tgl-eng", "eng-tgl")
    }

    # All Kapampangan <-> Filipino/Tagalog directions are served by the
    # single "phi" model, so they all share its status.
    phi_status = status_of(translation_models.get("phi"))
    for pair in ("pam-fil", "fil-pam", "pam-tgl", "tgl-pam"):
        translation_status[pair] = phi_status

    return {
        "asr_model": status_of(asr_model),
        "tts_models": {lang: status_of(model)
                       for lang, model in tts_models.items()},
        "translation_models": translation_status
    }
|
|
|
|
|
|
def handle_asr_request(request, output_dir, sample_rate):
    """Handle ASR (Automatic Speech Recognition) requests.

    Expects a multipart form with an "audio" file and an optional
    "language" field (default "english").  The upload is converted to
    mono WAV at *sample_rate* if needed, transcribed with the ASR model,
    and the transcription is returned as JSON.

    Returns:
        A Flask JSON response (with HTTP status for error cases).
    """
    if asr_model is None or asr_processor is None:
        logger.error("ASR endpoint called but models aren't loaded")
        return jsonify({"error": "ASR model not available"}), 503

    temp_audio_path = None
    wav_path = None
    try:
        if "audio" not in request.files:
            logger.warning("ASR request missing audio file")
            return jsonify({"error": "No audio file uploaded"}), 400

        audio_file = request.files["audio"]
        language = request.form.get("language", "english").lower()

        if language not in LANGUAGE_CODES:
            logger.warning(f"Unsupported language requested: {language}")
            return jsonify(
                {"error": f"Unsupported language: {language}. Available: {list(LANGUAGE_CODES.keys())}"}), 400

        lang_code = LANGUAGE_CODES[language]
        logger.info(f"Processing {language} audio for ASR")

        # Persist the upload so pydub/torchaudio can read it from disk.
        with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(audio_file.filename)[-1]) as temp_audio:
            temp_audio.write(audio_file.read())
            temp_audio_path = temp_audio.name
        logger.debug(f"Temporary audio saved to {temp_audio_path}")

        # Convert non-WAV uploads to mono WAV at the target sample rate.
        # NOTE(review): the fixed "converted_audio.wav" name means concurrent
        # requests can clobber each other's file -- consider per-request names.
        wav_path = temp_audio_path
        if not audio_file.filename.lower().endswith(".wav"):
            wav_path = os.path.join(output_dir, "converted_audio.wav")
            logger.info(f"Converting audio to WAV format: {wav_path}")
            try:
                audio = AudioSegment.from_file(temp_audio_path)
                audio = audio.set_frame_rate(sample_rate).set_channels(1)
                audio.export(wav_path, format="wav")
            except Exception as e:
                logger.error(f"Audio conversion failed: {str(e)}")
                return jsonify({"error": f"Audio conversion failed: {str(e)}"}), 500

        # Load, resample if needed, and peak-normalize the waveform.
        try:
            waveform, sr = torchaudio.load(wav_path)
            logger.debug(f"Audio loaded: {wav_path} (Sample rate: {sr}Hz)")

            if sr != sample_rate:
                logger.info(f"Resampling audio from {sr}Hz to {sample_rate}Hz")
                waveform = torchaudio.transforms.Resample(sr, sample_rate)(waveform)

            # Guard against division by zero on all-silent input; the
            # unconditional divide produced NaNs for a zero waveform.
            peak = torch.max(torch.abs(waveform))
            if peak > 0:
                waveform = waveform / peak
        except Exception as e:
            logger.error(f"Failed to load or process audio: {str(e)}")
            return jsonify({"error": f"Audio processing failed: {str(e)}"}), 500

        # Build model inputs from the raw waveform.
        try:
            # NOTE(review): passing language= through the processor call is
            # unusual for Wav2Vec2-style processors -- confirm the loaded
            # processor actually accepts this keyword.
            inputs = asr_processor(
                waveform.squeeze().numpy(),
                sampling_rate=sample_rate,
                return_tensors="pt",
                language=lang_code
            )
            inputs = {k: v.to(asr_model.device) for k, v in inputs.items()}
        except Exception as e:
            logger.error(f"ASR preprocessing failed: {str(e)}")
            return jsonify({"error": f"ASR preprocessing failed: {str(e)}"}), 500

        # Greedy CTC decode of the model logits.
        try:
            with torch.no_grad():
                logits = asr_model(**inputs).logits
            ids = torch.argmax(logits, dim=-1)[0]
            transcription = asr_processor.decode(ids)

            logger.info(f"Transcription ({language}): {transcription}")

            return jsonify({
                "transcription": transcription,
                "language": language,
                "language_code": lang_code
            })
        except Exception as e:
            logger.error(f"ASR inference failed: {str(e)}")
            logger.debug(f"Stack trace: {traceback.format_exc()}")
            return jsonify({"error": f"ASR inference failed: {str(e)}"}), 500

    except Exception as e:
        logger.error(f"Unhandled exception in ASR endpoint: {str(e)}")
        logger.debug(f"Stack trace: {traceback.format_exc()}")
        return jsonify({"error": f"Internal server error: {str(e)}"}), 500
    finally:
        # Clean up temp files on every path; the original only removed them
        # on the success path, leaking files on each error return.
        for path in {temp_audio_path, wav_path}:
            if path:
                try:
                    os.unlink(path)
                except OSError as e:
                    logger.warning(f"Failed to clean up temp files: {str(e)}")
|
|
|
|
def handle_tts_request(request, output_dir):
    """Handle TTS (Text-to-Speech) requests.

    Expects JSON with "text" and an optional "language" field (default
    "kapampangan").  Synthesizes speech with the language's VITS model,
    writes it as a WAV file under *output_dir*, and returns its
    download URL as JSON.

    Returns:
        A Flask JSON response (with HTTP status for error cases).
    """
    try:
        data = request.get_json()
        if not data:
            logger.warning("TTS endpoint called with no JSON data")
            return jsonify({"error": "No JSON data provided"}), 400

        text_input = data.get("text", "").strip()
        language = data.get("language", "kapampangan").lower()

        if not text_input:
            logger.warning("TTS request with empty text")
            return jsonify({"error": "No text provided"}), 400

        if language not in TTS_MODELS:
            logger.warning(f"TTS requested for unsupported language: {language}")
            return jsonify({"error": f"Invalid language. Available options: {list(TTS_MODELS.keys())}"}), 400

        if tts_models[language] is None:
            logger.error(f"TTS model for {language} not loaded")
            return jsonify({"error": f"TTS model for {language} not available"}), 503

        logger.info(f"Generating TTS for language: {language}, text: '{text_input}'")

        # Tokenize the text and move tensors to the model's device.
        try:
            processor = tts_processors[language]
            model = tts_models[language]
            inputs = processor(text_input, return_tensors="pt")
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
        except Exception as e:
            logger.error(f"TTS preprocessing failed: {str(e)}")
            return jsonify({"error": f"TTS preprocessing failed: {str(e)}"}), 500

        # Run the VITS model to synthesize the waveform.
        try:
            with torch.no_grad():
                output = model(**inputs).waveform
            waveform = output.squeeze().cpu().numpy()
        except Exception as e:
            logger.error(f"TTS inference failed: {str(e)}")
            logger.debug(f"Stack trace: {traceback.format_exc()}")
            return jsonify({"error": f"TTS inference failed: {str(e)}"}), 500

        # Write the waveform to disk at the model's native sampling rate.
        # NOTE(review): one fixed filename per language means concurrent
        # requests overwrite each other's output -- consider unique names.
        try:
            output_filename = os.path.join(output_dir, f"{language}_output.wav")
            sampling_rate = model.config.sampling_rate
            sf.write(output_filename, waveform, sampling_rate)
            logger.info(f"Speech generated! File saved: {output_filename}")
        except Exception as e:
            logger.error(f"Failed to save audio file: {str(e)}")
            return jsonify({"error": f"Failed to save audio file: {str(e)}"}), 500

        return jsonify({
            "message": "TTS audio generated",
            "file_url": f"/download/{os.path.basename(output_filename)}",
            "language": language,
            "text_length": len(text_input)
        })
    except Exception as e:
        logger.error(f"Unhandled exception in TTS endpoint: {str(e)}")
        logger.debug(f"Stack trace: {traceback.format_exc()}")
        return jsonify({"error": f"Internal server error: {str(e)}"}), 500
|
|
|
|
def _run_marian_translation(model, tokenizer, text):
    """Tokenize *text*, run generation on *model*, and decode the result."""
    tokenized = tokenizer(text, return_tensors="pt", padding=True)
    tokenized = {k: v.to(model.device) for k, v in tokenized.items()}
    with torch.no_grad():
        translated = model.generate(**tokenized)
    return tokenizer.decode(translated[0], skip_special_tokens=True)


def handle_translation_request(request):
    """Handle translation requests.

    Expects JSON with "text", "source_language", and "target_language"
    (friendly names or raw codes).  Kapampangan <-> Filipino/Tagalog
    pairs are routed to the multilingual "phi" model; every other pair
    uses its dedicated Marian model from TRANSLATION_MODELS.

    Returns:
        A Flask JSON response (with HTTP status for error cases).
    """
    try:
        data = request.get_json()
        if not data:
            logger.warning("Translation endpoint called with no JSON data")
            return jsonify({"error": "No JSON data provided"}), 400

        source_text = data.get("text", "").strip()
        source_language = data.get("source_language", "").lower()
        target_language = data.get("target_language", "").lower()

        if not source_text:
            logger.warning("Translation request with empty text")
            return jsonify({"error": "No text provided"}), 400

        # Accept either friendly names ("english") or raw codes ("eng").
        source_code = LANGUAGE_CODES.get(source_language, source_language)
        target_code = LANGUAGE_CODES.get(target_language, target_language)

        logger.info(f"Translating from {source_language} to {target_language}: '{source_text}'")

        # Decide whether the multilingual "phi" model serves this pair.
        # The phi model only knows the "fil" code, so "tgl" targets are
        # mapped onto "fil" for it.
        use_phi_model = False
        phi_target_code = target_code
        if (source_code == "pam" and target_code == "fil") or (source_code == "fil" and target_code == "pam"):
            use_phi_model = True
        elif source_code == "pam" and target_code == "tgl":
            use_phi_model = True
            phi_target_code = "fil"
        elif source_code == "tgl" and target_code == "pam":
            use_phi_model = True

        if use_phi_model:
            model_key = "phi"
            if translation_models.get(model_key) is None:
                logger.error(f"Translation model for {model_key} not loaded")
                return jsonify({"error": "Translation model not available"}), 503

            try:
                # The phi model selects its output language from a
                # ">>lang<<" prefix on the input text.
                input_text = f">>{phi_target_code}<< {source_text}"
                logger.info(f"Using phi model with input: '{input_text}'")

                result = _run_marian_translation(
                    translation_models[model_key],
                    translation_tokenizers[model_key],
                    input_text,
                )

                logger.info(f"Translation result: '{result}'")
                return jsonify({
                    "translated_text": result,
                    "source_language": source_language,
                    "target_language": target_language
                })
            except Exception as e:
                logger.error(f"Translation processing failed: {str(e)}")
                logger.debug(f"Stack trace: {traceback.format_exc()}")
                return jsonify({"error": f"Translation processing failed: {str(e)}"}), 500
        else:
            # Dedicated one-direction Marian model.
            lang_pair = f"{source_code}-{target_code}"

            if lang_pair not in translation_models:
                logger.warning(f"No translation model available for {lang_pair}")
                return jsonify(
                    {"error": f"Translation from {source_language} to {target_language} is not supported yet"}), 400

            if translation_models[lang_pair] is None or translation_tokenizers[lang_pair] is None:
                logger.error(f"Translation model for {lang_pair} not loaded")
                return jsonify({"error": "Translation model not available"}), 503

            try:
                result = _run_marian_translation(
                    translation_models[lang_pair],
                    translation_tokenizers[lang_pair],
                    source_text,
                )

                logger.info(f"Translation result: '{result}'")
                return jsonify({
                    "translated_text": result,
                    "source_language": source_language,
                    "target_language": target_language
                })
            except Exception as e:
                logger.error(f"Translation processing failed: {str(e)}")
                logger.debug(f"Stack trace: {traceback.format_exc()}")
                return jsonify({"error": f"Translation processing failed: {str(e)}"}), 500

    except Exception as e:
        logger.error(f"Unhandled exception in translation endpoint: {str(e)}")
        logger.debug(f"Stack trace: {traceback.format_exc()}")
        return jsonify({"error": f"Internal server error: {str(e)}"}), 500