"""Flask speech API exposing ASR (/asr) and TTS (/tts) endpoints.

Models are loaded once at import time from the Hugging Face hub. All caches
and generated audio live under /tmp so the service can run in environments
where only /tmp is writable.
"""
# Set cache directories first, before other imports: the Hugging Face and
# torch libraries read these variables when they are imported.
import os
import sys
import tempfile

_CACHE_DIRS = {
    "HF_HOME": "/tmp/hf_home",
    "TRANSFORMERS_CACHE": "/tmp/transformers_cache",
    "HUGGINGFACE_HUB_CACHE": "/tmp/huggingface_hub_cache",
    "TORCH_HOME": "/tmp/torch_home",
    "XDG_CACHE_HOME": "/tmp/xdg_cache",
}
for _env_name, _cache_path in _CACHE_DIRS.items():
    os.environ[_env_name] = _cache_path
    os.makedirs(_cache_path, exist_ok=True)

# Now import the rest of the libraries
import soundfile as sf
import torch
import torchaudio
from flask import Flask, jsonify, request, send_file
from flask_cors import CORS
from pydub import AudioSegment
from transformers import AutoProcessor, AutoTokenizer, VitsModel, Wav2Vec2ForCTC

app = Flask(__name__)
CORS(app)

# Passed explicitly to every from_pretrained() call in addition to the
# environment variables above.
MODEL_CACHE_DIR = "/tmp/transformers_cache"

# ASR Model
ASR_MODEL_ID = "Coco-18/mms-asr-tgl-en-safetensor"
print(f"Loading ASR model: {ASR_MODEL_ID}")

# Pre-bind to None so a failed load yields a clear "not available" response
# from /asr instead of a NameError at request time.
asr_processor = None
asr_model = None
try:
    asr_processor = AutoProcessor.from_pretrained(
        ASR_MODEL_ID,
        cache_dir=MODEL_CACHE_DIR,
    )
    asr_model = Wav2Vec2ForCTC.from_pretrained(
        ASR_MODEL_ID,
        cache_dir=MODEL_CACHE_DIR,
    )
    print("✅ ASR Model loaded successfully")
except Exception as e:
    # Never keep a half-loaded pair around.
    asr_processor = None
    asr_model = None
    print(f"❌ Error loading ASR model: {str(e)}")
    # Provide more debugging information
    print(f"Python version: {sys.version}")
    print(f"Current working directory: {os.getcwd()}")
    print(f"Temp directory exists: {os.path.exists('/tmp')}")
    print(f"Temp directory writeable: {os.access('/tmp', os.W_OK)}")
    # Continue anyway so the API (and the TTS endpoints) can still start.

# Language-specific configurations
LANGUAGE_CODES = {
    "kapampangan": "pam",
    "tagalog": "tgl",
    "english": "eng",
}

# TTS Models (Kapampangan, Tagalog, English)
TTS_MODELS = {
    "kapampangan": "facebook/mms-tts-pam",
    "tagalog": "facebook/mms-tts-tgl",
    "english": "facebook/mms-tts-eng",
}

tts_models = {}
tts_processors = {}
for lang, model_id in TTS_MODELS.items():
    try:
        tts_models[lang] = VitsModel.from_pretrained(
            model_id,
            cache_dir=MODEL_CACHE_DIR,
        )
        tts_processors[lang] = AutoTokenizer.from_pretrained(
            model_id,
            cache_dir=MODEL_CACHE_DIR,
        )
        print(f"✅ TTS Model loaded: {lang}")
    except Exception as e:
        print(f"❌ Error loading {lang} TTS model: {e}")
        # Mark the language as unavailable; /tts checks for None.
        tts_models[lang] = None
        tts_processors[lang] = None

# Constants
SAMPLE_RATE = 16000  # sampling rate (Hz) fed to the ASR processor
OUTPUT_DIR = "/tmp/audio_outputs"
os.makedirs(OUTPUT_DIR, exist_ok=True)


def _remove_quietly(paths):
    """Delete each path in *paths*, ignoring files that cannot be removed."""
    for path in paths:
        try:
            os.remove(path)
        except OSError:
            pass


def _load_mono_waveform(wav_path):
    """Load *wav_path* as a 1-D float tensor at SAMPLE_RATE, peak-normalised.

    Returns None when the file contains no samples.
    """
    waveform, sr = torchaudio.load(wav_path)
    if waveform.numel() == 0:
        return None

    # torchaudio returns (channels, samples); mix multi-channel audio down
    # so the ASR processor always receives a single 1-D signal.
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Resample if needed
    if sr != SAMPLE_RATE:
        waveform = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(waveform)

    # Skip normalisation for pure silence to avoid dividing by zero (NaNs).
    peak = torch.max(torch.abs(waveform))
    if peak > 0:
        waveform = waveform / peak

    return waveform.squeeze(0)


@app.route("/", methods=["GET"])
def home():
    """Health-check endpoint."""
    return jsonify({"message": "Speech API is running."})


@app.route("/asr", methods=["POST"])
def transcribe_audio():
    """Transcribe an uploaded audio file.

    Expects multipart form data with an ``audio`` file and an optional
    ``language`` field (a key of LANGUAGE_CODES, default ``english``).

    Returns:
        JSON ``{"transcription": ...}`` on success; ``{"error": ...}`` with
        status 400 for bad input or 500 when the model is unavailable or
        processing fails.
    """
    temp_paths = []  # every temp file created here is removed in `finally`
    try:
        if "audio" not in request.files:
            return jsonify({"error": "No audio file uploaded"}), 400

        audio_file = request.files["audio"]
        language = request.form.get("language", "english").lower()

        if language not in LANGUAGE_CODES:
            return jsonify({"error": f"Unsupported language: {language}"}), 400

        if asr_model is None or asr_processor is None:
            return jsonify({"error": "ASR model not available"}), 500

        lang_code = LANGUAGE_CODES[language]

        # The client may omit the filename entirely.
        filename = audio_file.filename or ""

        # Save the uploaded file temporarily
        suffix = os.path.splitext(filename)[-1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_audio:
            temp_audio_path = temp_audio.name
            temp_paths.append(temp_audio_path)
            temp_audio.write(audio_file.read())

        # Convert to WAV if necessary
        wav_path = temp_audio_path
        if not filename.lower().endswith(".wav"):
            # Use a unique name so concurrent requests cannot overwrite each
            # other's converted audio.
            fd, wav_path = tempfile.mkstemp(suffix=".wav", dir=OUTPUT_DIR)
            os.close(fd)
            temp_paths.append(wav_path)
            audio = AudioSegment.from_file(temp_audio_path)
            audio = audio.set_frame_rate(SAMPLE_RATE).set_channels(1)
            audio.export(wav_path, format="wav")

        # Load and process the WAV file
        waveform = _load_mono_waveform(wav_path)
        if waveform is None:
            return jsonify({"error": "Audio file is empty"}), 400

        # Process audio for ASR
        # NOTE(review): `language` is forwarded to the processor as in the
        # original code; confirm that the processor actually honours it.
        inputs = asr_processor(
            waveform.numpy(),
            sampling_rate=SAMPLE_RATE,
            return_tensors="pt",
            language=lang_code,
        )

        # Perform ASR
        with torch.no_grad():
            logits = asr_model(**inputs).logits
        ids = torch.argmax(logits, dim=-1)[0]
        transcription = asr_processor.decode(ids)

        print(f"Transcription ({language}): {transcription}")
        return jsonify({"transcription": transcription})
    except Exception as e:
        print(f"ASR error: {str(e)}")
        return jsonify({"error": f"ASR failed: {str(e)}"}), 500
    finally:
        _remove_quietly(temp_paths)


@app.route("/tts", methods=["POST"])
def generate_tts():
    """Synthesise speech for the given text.

    Expects a JSON body with ``text`` and an optional ``language`` (a key of
    TTS_MODELS, default ``kapampangan``).

    Returns:
        JSON ``{"file_url": "/download/<language>_tts.wav"}`` on success;
        ``{"error": ...}`` with status 400 for bad input or 500 when the
        model is unavailable or synthesis fails.
    """
    try:
        # silent=True: a missing or malformed JSON body becomes a 400 below
        # instead of an unhandled error.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}

        raw_text = data.get("text", "")
        text_input = raw_text.strip() if isinstance(raw_text, str) else ""
        language = str(data.get("language", "kapampangan")).lower()

        if language not in TTS_MODELS:
            return jsonify({"error": "Invalid language"}), 400
        if not text_input:
            return jsonify({"error": "No text provided"}), 400

        model = tts_models.get(language)
        processor = tts_processors.get(language)
        if model is None or processor is None:
            return jsonify({"error": "TTS model not available"}), 500

        inputs = processor(text_input, return_tensors="pt")

        # VitsModel synthesises through its forward pass; the audio is in
        # the `waveform` field of the output.
        with torch.no_grad():
            output = model(**inputs).waveform
        waveform = output.cpu().numpy().flatten()

        # Write at the model's own sampling rate so playback speed is right.
        sampling_rate = getattr(model.config, "sampling_rate", SAMPLE_RATE)

        output_filename = os.path.join(OUTPUT_DIR, f"{language}_tts.wav")
        sf.write(output_filename, waveform, sampling_rate)

        return jsonify({"file_url": f"/download/{language}_tts.wav"})
    except Exception as e:
        return jsonify({"error": f"TTS failed: {e}"}), 500


@app.route("/download/<filename>", methods=["GET"])
def download_audio(filename):
    """Send a generated audio file from OUTPUT_DIR as an attachment.

    Returns a JSON 404 when the name does not refer to a regular file
    directly inside OUTPUT_DIR.
    """
    # Reject anything that is not a bare file name (e.g. "..").
    if os.path.basename(filename) != filename:
        return jsonify({"error": "File not found"}), 404

    file_path = os.path.join(OUTPUT_DIR, filename)
    if os.path.isfile(file_path):
        return send_file(file_path, mimetype="audio/wav", as_attachment=True)
    return jsonify({"error": "File not found"}), 404


if __name__ == "__main__":
    # The interactive debugger must not be exposed on a public interface by
    # default; opt in explicitly with FLASK_DEBUG=1.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").lower() in {"1", "true", "yes"}
    app.run(host="0.0.0.0", port=7860, debug=debug_enabled)