# Set cache directories first, before other imports
"""Flask speech API exposing ASR (/asr) and TTS (/tts) endpoints.

Hugging Face / torch cache locations are redirected to /tmp *before* the
ML libraries are imported, so the environment variables are already in
effect when those libraries may read them.
"""
import os
import sys
import logging
import tempfile
import traceback

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger("speech_api")

# Set all cache directories to locations within /tmp
cache_dirs = {
    "HF_HOME": "/tmp/hf_home",
    "TRANSFORMERS_CACHE": "/tmp/transformers_cache",
    "HUGGINGFACE_HUB_CACHE": "/tmp/huggingface_hub_cache",
    "TORCH_HOME": "/tmp/torch_home",
    "XDG_CACHE_HOME": "/tmp/xdg_cache"
}

# Set environment variables and create directories
for env_var, path in cache_dirs.items():
    os.environ[env_var] = path
    try:
        os.makedirs(path, exist_ok=True)
        logger.info(f"📁 Created cache directory: {path}")
    except OSError as e:
        logger.error(f"❌ Failed to create directory {path}: {str(e)}")

# Now import the rest of the libraries
try:
    import torch
    from pydub import AudioSegment
    import torchaudio
    import soundfile as sf
    from flask import Flask, request, jsonify, send_file
    from flask_cors import CORS
    from transformers import Wav2Vec2ForCTC, AutoProcessor, VitsModel, AutoTokenizer
    logger.info("✅ All required libraries imported successfully")
except ImportError as e:
    logger.critical(f"❌ Failed to import necessary libraries: {str(e)}")
    sys.exit(1)

# Check CUDA availability
if torch.cuda.is_available():
    logger.info(f"🚀 CUDA available: {torch.cuda.get_device_name(0)}")
    device = "cuda"
else:
    logger.info("⚠️ CUDA not available, using CPU")
    device = "cpu"

app = Flask(__name__)
CORS(app)

# ASR Model
ASR_MODEL_ID = "Coco-18/mms-asr-tgl-en-safetensor"
logger.info(f"🔄 Loading ASR model: {ASR_MODEL_ID}")
asr_processor = None
asr_model = None
try:
    asr_processor = AutoProcessor.from_pretrained(
        ASR_MODEL_ID,
        cache_dir=cache_dirs["TRANSFORMERS_CACHE"]
    )
    logger.info("✅ ASR processor loaded successfully")
    asr_model = Wav2Vec2ForCTC.from_pretrained(
        ASR_MODEL_ID,
        cache_dir=cache_dirs["TRANSFORMERS_CACHE"]
    )
    asr_model.to(device)
    logger.info(f"✅ ASR model loaded successfully on {device}")
except Exception as e:
    # A model that loaded but could not be moved to `device` is not usable;
    # reset it so /health and /asr report it as unavailable (mirrors the TTS loader).
    asr_model = None
    logger.error(f"❌ Error loading ASR model: {str(e)}")
    logger.debug(f"Stack trace: {traceback.format_exc()}")
    logger.debug(f"Python version: {sys.version}")
    logger.debug(f"Current working directory: {os.getcwd()}")
    logger.debug(f"Temp directory exists: {os.path.exists('/tmp')}")
    logger.debug(f"Temp directory writeable: {os.access('/tmp', os.W_OK)}")

# Language-specific configurations
LANGUAGE_CODES = {
    "kapampangan": "pam",
    "tagalog": "tgl",
    "english": "eng"
}

# TTS Models (Kapampangan, Tagalog, English)
TTS_MODELS = {
    "kapampangan": "facebook/mms-tts-pam",
    "tagalog": "facebook/mms-tts-tgl",
    "english": "facebook/mms-tts-eng"
}

tts_models = {}
tts_processors = {}
for lang, model_id in TTS_MODELS.items():
    logger.info(f"🔄 Loading TTS model for {lang}: {model_id}")
    try:
        tts_processors[lang] = AutoTokenizer.from_pretrained(
            model_id,
            cache_dir=cache_dirs["TRANSFORMERS_CACHE"]
        )
        logger.info(f"✅ {lang} TTS processor loaded")
        tts_models[lang] = VitsModel.from_pretrained(
            model_id,
            cache_dir=cache_dirs["TRANSFORMERS_CACHE"]
        )
        tts_models[lang].to(device)
        logger.info(f"✅ {lang} TTS model loaded on {device}")
    except Exception as e:
        logger.error(f"❌ Failed to load {lang} TTS model: {str(e)}")
        logger.debug(f"Stack trace: {traceback.format_exc()}")
        tts_models[lang] = None

# Constants
SAMPLE_RATE = 16000
OUTPUT_DIR = "/tmp/audio_outputs"
try:
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    logger.info(f"📁 Created output directory: {OUTPUT_DIR}")
except OSError as e:
    logger.error(f"❌ Failed to create output directory: {str(e)}")


def _remove_temp_files(paths):
    """Best-effort deletion of per-request temporary files.

    Files that are already gone are ignored; any other OS error is logged
    as a warning and never raised, so cleanup cannot mask the response.
    """
    for path in paths:
        try:
            os.unlink(path)
        except FileNotFoundError:
            pass
        except OSError as e:
            logger.warning(f"⚠️ Failed to clean up temp files: {str(e)}")


@app.route("/", methods=["GET"])
def home():
    """Liveness probe: report that the API process is up."""
    return jsonify({"message": "Speech API is running", "status": "active"})


@app.route("/health", methods=["GET"])
def health_check():
    """Report which models loaded successfully and the compute device in use."""
    health_status = {
        "api_status": "online",
        "asr_model": "loaded" if asr_model is not None else "failed",
        "tts_models": {
            lang: "loaded" if model is not None else "failed"
            for lang, model in tts_models.items()
        },
        "device": device
    }
    return jsonify(health_status)


@app.route("/asr", methods=["POST"])
def transcribe_audio():
    """Transcribe an uploaded audio file.

    Expects multipart form data with an ``audio`` file field and an optional
    ``language`` field (default ``"english"``; must be a key of LANGUAGE_CODES).

    Returns:
        JSON ``{"transcription", "language", "language_code"}`` on success;
        400 for a missing file, unsupported language or empty audio;
        503 when the ASR model is not loaded; 500 for conversion,
        decoding or inference failures.

    Every temporary file created for the request is removed before returning,
    on success and error paths alike.
    """
    if asr_model is None or asr_processor is None:
        logger.error("❌ ASR endpoint called but models aren't loaded")
        return jsonify({"error": "ASR model not available"}), 503

    temp_paths = []  # all temp files created for this request; removed in `finally`
    try:
        if "audio" not in request.files:
            logger.warning("⚠️ ASR request missing audio file")
            return jsonify({"error": "No audio file uploaded"}), 400

        audio_file = request.files["audio"]
        # The client may omit the filename; treat that as "no extension".
        filename = audio_file.filename or ""
        language = request.form.get("language", "english").lower()

        if language not in LANGUAGE_CODES:
            logger.warning(f"⚠️ Unsupported language requested: {language}")
            return jsonify({"error": f"Unsupported language: {language}. Available: {list(LANGUAGE_CODES.keys())}"}), 400

        lang_code = LANGUAGE_CODES[language]
        logger.info(f"🔄 Processing {language} audio for ASR")

        # Save the uploaded file temporarily (registered for cleanup before writing,
        # so a failed read still gets the file removed).
        with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(filename)[-1]) as temp_audio:
            temp_paths.append(temp_audio.name)
            temp_audio.write(audio_file.read())
            temp_audio_path = temp_audio.name
        logger.debug(f"📁 Temporary audio saved to {temp_audio_path}")

        # Convert to mono WAV at SAMPLE_RATE unless the upload is already a .wav
        wav_path = temp_audio_path
        if not filename.lower().endswith(".wav"):
            try:
                # Unique path per request: a shared fixed filename would let
                # concurrent requests overwrite each other's converted audio.
                fd, wav_path = tempfile.mkstemp(suffix=".wav")
                os.close(fd)
                temp_paths.append(wav_path)
                logger.info(f"🔄 Converting audio to WAV format: {wav_path}")
                audio = AudioSegment.from_file(temp_audio_path)
                audio = audio.set_frame_rate(SAMPLE_RATE).set_channels(1)
                audio.export(wav_path, format="wav")
            except Exception as e:
                logger.error(f"❌ Audio conversion failed: {str(e)}")
                return jsonify({"error": f"Audio conversion failed: {str(e)}"}), 500

        # Load and process the WAV file
        try:
            waveform, sr = torchaudio.load(wav_path)
            logger.debug(f"✅ Audio loaded: {wav_path} (Sample rate: {sr}Hz)")

            if waveform.numel() == 0:
                logger.warning("⚠️ ASR request contained no audio samples")
                return jsonify({"error": "Audio file contains no samples"}), 400

            # Resample if needed
            if sr != SAMPLE_RATE:
                logger.info(f"🔄 Resampling audio from {sr}Hz to {SAMPLE_RATE}Hz")
                waveform = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(waveform)

            # torchaudio.load returns (channels, frames); .wav uploads skip the
            # mono conversion above, so downmix any extra channels here.
            if waveform.size(0) > 1:
                waveform = waveform.mean(dim=0, keepdim=True)

            # Peak-normalize; a silent clip has peak 0 and would become all-NaN.
            peak = torch.max(torch.abs(waveform))
            if peak > 0:
                waveform = waveform / peak
        except Exception as e:
            logger.error(f"❌ Failed to load or process audio: {str(e)}")
            return jsonify({"error": f"Audio processing failed: {str(e)}"}), 500

        # Process audio for ASR
        try:
            # NOTE(review): `language` is forwarded as before; whether this
            # processor actually uses it is not visible here — confirm.
            inputs = asr_processor(
                waveform[0].numpy(),
                sampling_rate=SAMPLE_RATE,
                return_tensors="pt",
                language=lang_code
            )
            inputs = {k: v.to(device) for k, v in inputs.items()}
        except Exception as e:
            logger.error(f"❌ ASR preprocessing failed: {str(e)}")
            return jsonify({"error": f"ASR preprocessing failed: {str(e)}"}), 500

        # Perform ASR (greedy CTC decoding: argmax per frame, then decode)
        try:
            with torch.no_grad():
                logits = asr_model(**inputs).logits
            ids = torch.argmax(logits, dim=-1)[0]
            transcription = asr_processor.decode(ids)
            logger.info(f"✅ Transcription ({language}): {transcription}")
        except Exception as e:
            logger.error(f"❌ ASR inference failed: {str(e)}")
            logger.debug(f"Stack trace: {traceback.format_exc()}")
            return jsonify({"error": f"ASR inference failed: {str(e)}"}), 500

        return jsonify({
            "transcription": transcription,
            "language": language,
            "language_code": lang_code
        })
    except Exception as e:
        logger.error(f"❌ Unhandled exception in ASR endpoint: {str(e)}")
        logger.debug(f"Stack trace: {traceback.format_exc()}")
        return jsonify({"error": f"Internal server error: {str(e)}"}), 500
    finally:
        _remove_temp_files(temp_paths)


@app.route("/tts", methods=["POST"])
def generate_tts():
    """Synthesize speech for a JSON request ``{"text": ..., "language": ...}``.

    ``language`` defaults to ``"kapampangan"`` and must be a key of TTS_MODELS.
    The audio is written to OUTPUT_DIR and the response carries a
    ``file_url`` pointing at the /download endpoint.

    Returns:
        JSON ``{"message", "file_url", "language", "text_length"}`` on success;
        400 for a missing/invalid body, empty text or unknown language;
        503 when that language's model is not loaded; 500 on synthesis
        or file-write failures.
    """
    try:
        # silent=True: a missing or non-JSON body yields None (reported as a 400
        # below) instead of raising an HTTP error that would surface as a 500.
        data = request.get_json(silent=True)
        if not data:
            logger.warning("⚠️ TTS endpoint called with no JSON data")
            return jsonify({"error": "No JSON data provided"}), 400
        if not isinstance(data, dict):
            logger.warning("⚠️ TTS endpoint called with a non-object JSON body")
            return jsonify({"error": "JSON body must be an object"}), 400

        text_input = data.get("text", "")
        language = data.get("language", "kapampangan")
        if not isinstance(text_input, str) or not isinstance(language, str):
            logger.warning("⚠️ TTS request with non-string 'text' or 'language'")
            return jsonify({"error": "'text' and 'language' must be strings"}), 400
        text_input = text_input.strip()
        language = language.lower()

        if not text_input:
            logger.warning("⚠️ TTS request with empty text")
            return jsonify({"error": "No text provided"}), 400

        if language not in TTS_MODELS:
            logger.warning(f"⚠️ TTS requested for unsupported language: {language}")
            return jsonify({"error": f"Invalid language. Available options: {list(TTS_MODELS.keys())}"}), 400

        if tts_models[language] is None:
            logger.error(f"❌ TTS model for {language} not loaded")
            return jsonify({"error": f"TTS model for {language} not available"}), 503

        logger.info(f"🔄 Generating TTS for language: {language}, text: '{text_input}'")

        try:
            processor = tts_processors[language]
            model = tts_models[language]
            inputs = processor(text_input, return_tensors="pt")
            inputs = {k: v.to(device) for k, v in inputs.items()}
        except Exception as e:
            logger.error(f"❌ TTS preprocessing failed: {str(e)}")
            return jsonify({"error": f"TTS preprocessing failed: {str(e)}"}), 500

        # Generate speech
        try:
            with torch.no_grad():
                output = model(**inputs).waveform
            waveform = output.squeeze().cpu().numpy()
        except Exception as e:
            logger.error(f"❌ TTS inference failed: {str(e)}")
            logger.debug(f"Stack trace: {traceback.format_exc()}")
            return jsonify({"error": f"TTS inference failed: {str(e)}"}), 500

        # Save to file
        try:
            # NOTE(review): one file per language, so concurrent requests for the
            # same language overwrite each other's output before it is downloaded.
            output_filename = os.path.join(OUTPUT_DIR, f"{language}_output.wav")
            sampling_rate = model.config.sampling_rate
            sf.write(output_filename, waveform, sampling_rate)
            logger.info(f"✅ Speech generated! File saved: {output_filename}")
        except Exception as e:
            logger.error(f"❌ Failed to save audio file: {str(e)}")
            return jsonify({"error": f"Failed to save audio file: {str(e)}"}), 500

        return jsonify({
            "message": "TTS audio generated",
            "file_url": f"/download/{os.path.basename(output_filename)}",
            "language": language,
            "text_length": len(text_input)
        })
    except Exception as e:
        logger.error(f"❌ Unhandled exception in TTS endpoint: {str(e)}")
        logger.debug(f"Stack trace: {traceback.format_exc()}")
        return jsonify({"error": f"Internal server error: {str(e)}"}), 500


@app.route("/download/<filename>", methods=["GET"])
def download_audio(filename):
    """Serve a generated WAV file from OUTPUT_DIR as an attachment.

    Only bare file names that refer to a regular file directly inside
    OUTPUT_DIR are served; anything else (``..``, path separators,
    directories, missing files) gets a JSON 404.
    """
    safe_name = os.path.basename(filename)
    file_path = os.path.join(OUTPUT_DIR, safe_name)
    # Reject names that would resolve outside OUTPUT_DIR or to a directory.
    if safe_name == filename and safe_name not in ("", ".", "..") and os.path.isfile(file_path):
        logger.info(f"📤 Serving audio file: {file_path}")
        return send_file(file_path, mimetype="audio/wav", as_attachment=True)
    logger.warning(f"⚠️ Requested file not found: {file_path}")
    return jsonify({"error": "File not found"}), 404


if __name__ == "__main__":
    logger.info("🚀 Starting Speech API server")
    logger.info(f"📊 System status: ASR model: {'✅' if asr_model is not None else '❌'}")
    for lang, model in tts_models.items():
        logger.info(f"📊 TTS model {lang}: {'✅' if model is not None else '❌'}")
    # Debug mode is opt-in (FLASK_DEBUG=1): the Werkzeug debugger must not be
    # exposed on 0.0.0.0, and its reloader re-runs this module, which would
    # load every model a second time.
    debug_mode = os.environ.get("FLASK_DEBUG", "").lower() in ("1", "true", "yes")
    app.run(host="0.0.0.0", port=7860, debug=debug_mode)