# Kapamtalk / app.py: Flask speech API (MMS ASR + TTS) for Kapampangan, Tagalog, and English
# Set cache directories first, before other imports
import os
import sys
import logging
import traceback
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger("speech_api")
# Set all cache directories to locations within /tmp
cache_dirs = {
"HF_HOME": "/tmp/hf_home",
"TRANSFORMERS_CACHE": "/tmp/transformers_cache",
"HUGGINGFACE_HUB_CACHE": "/tmp/huggingface_hub_cache",
"TORCH_HOME": "/tmp/torch_home",
"XDG_CACHE_HOME": "/tmp/xdg_cache"
}
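# Note: newer transformers releases deprecate TRANSFORMERS_CACHE in favour of HF_HOME /
# HF_HUB_CACHE; it is still set here so older library versions pick up the same location.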
# Set environment variables and create directories
for env_var, path in cache_dirs.items():
os.environ[env_var] = path
try:
os.makedirs(path, exist_ok=True)
logger.info(f"πŸ“ Created cache directory: {path}")
except Exception as e:
logger.error(f"❌ Failed to create directory {path}: {str(e)}")
# Now import the rest of the libraries
try:
import torch
from pydub import AudioSegment
import tempfile
import torchaudio
import soundfile as sf
from flask import Flask, request, jsonify, send_file
from flask_cors import CORS
from transformers import Wav2Vec2ForCTC, AutoProcessor, VitsModel, AutoTokenizer
logger.info("βœ… All required libraries imported successfully")
except ImportError as e:
logger.critical(f"❌ Failed to import necessary libraries: {str(e)}")
sys.exit(1)
# Check CUDA availability
if torch.cuda.is_available():
logger.info(f"πŸš€ CUDA available: {torch.cuda.get_device_name(0)}")
device = "cuda"
else:
logger.info("⚠️ CUDA not available, using CPU")
device = "cpu"
app = Flask(__name__)
CORS(app)
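# With no arguments, flask-cors allows cross-origin requests from any origin, which is
# convenient for a public demo API; pass origins=[...] here to lock it to a known frontend.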
# ASR Model
ASR_MODEL_ID = "Coco-18/mms-asr-tgl-en-safetensor"
logger.info(f"πŸ”„ Loading ASR model: {ASR_MODEL_ID}")
asr_processor = None
asr_model = None
try:
asr_processor = AutoProcessor.from_pretrained(
ASR_MODEL_ID,
cache_dir=cache_dirs["TRANSFORMERS_CACHE"]
)
logger.info("βœ… ASR processor loaded successfully")
asr_model = Wav2Vec2ForCTC.from_pretrained(
ASR_MODEL_ID,
cache_dir=cache_dirs["TRANSFORMERS_CACHE"]
)
asr_model.to(device)
logger.info(f"βœ… ASR model loaded successfully on {device}")
except Exception as e:
logger.error(f"❌ Error loading ASR model: {str(e)}")
logger.debug(f"Stack trace: {traceback.format_exc()}")
logger.debug(f"Python version: {sys.version}")
logger.debug(f"Current working directory: {os.getcwd()}")
logger.debug(f"Temp directory exists: {os.path.exists('/tmp')}")
logger.debug(f"Temp directory writeable: {os.access('/tmp', os.W_OK)}")
# Language-specific configurations
LANGUAGE_CODES = {
"kapampangan": "pam",
"tagalog": "tgl",
"english": "eng"
}
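# ISO 639-3 codes, which is how the MMS family of models identifies languages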
# TTS Models (Kapampangan, Tagalog, English)
TTS_MODELS = {
"kapampangan": "facebook/mms-tts-pam",
"tagalog": "facebook/mms-tts-tgl",
"english": "facebook/mms-tts-eng"
}
tts_models = {}
tts_processors = {}
for lang, model_id in TTS_MODELS.items():
logger.info(f"πŸ”„ Loading TTS model for {lang}: {model_id}")
try:
tts_processors[lang] = AutoTokenizer.from_pretrained(
model_id,
cache_dir=cache_dirs["TRANSFORMERS_CACHE"]
)
logger.info(f"βœ… {lang} TTS processor loaded")
tts_models[lang] = VitsModel.from_pretrained(
model_id,
cache_dir=cache_dirs["TRANSFORMERS_CACHE"]
)
tts_models[lang].to(device)
logger.info(f"βœ… {lang} TTS model loaded on {device}")
except Exception as e:
logger.error(f"❌ Failed to load {lang} TTS model: {str(e)}")
logger.debug(f"Stack trace: {traceback.format_exc()}")
tts_models[lang] = None
# Constants
SAMPLE_RATE = 16000
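# 16 kHz mono is the input format the MMS ASR checkpoints expect; uploads are resampled to it in /asr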
OUTPUT_DIR = "/tmp/audio_outputs"
try:
os.makedirs(OUTPUT_DIR, exist_ok=True)
logger.info(f"πŸ“ Created output directory: {OUTPUT_DIR}")
except Exception as e:
logger.error(f"❌ Failed to create output directory: {str(e)}")
@app.route("/", methods=["GET"])
def home():
return jsonify({"message": "Speech API is running", "status": "active"})
@app.route("/health", methods=["GET"])
def health_check():
health_status = {
"api_status": "online",
"asr_model": "loaded" if asr_model is not None else "failed",
"tts_models": {lang: "loaded" if model is not None else "failed"
for lang, model in tts_models.items()},
"device": device
}
return jsonify(health_status)
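# Illustrative readiness check against a local run of this Space:
#   curl http://localhost:7860/health
# The response reports "loaded"/"failed" for each model plus the active device.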
@app.route("/asr", methods=["POST"])
def transcribe_audio():
if asr_model is None or asr_processor is None:
logger.error("❌ ASR endpoint called but models aren't loaded")
return jsonify({"error": "ASR model not available"}), 503
try:
if "audio" not in request.files:
logger.warning("⚠️ ASR request missing audio file")
return jsonify({"error": "No audio file uploaded"}), 400
audio_file = request.files["audio"]
language = request.form.get("language", "english").lower()
if language not in LANGUAGE_CODES:
logger.warning(f"⚠️ Unsupported language requested: {language}")
return jsonify({"error": f"Unsupported language: {language}. Available: {list(LANGUAGE_CODES.keys())}"}), 400
lang_code = LANGUAGE_CODES[language]
logger.info(f"πŸ”„ Processing {language} audio for ASR")
# Save the uploaded file temporarily
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(audio_file.filename)[-1]) as temp_audio:
temp_audio.write(audio_file.read())
temp_audio_path = temp_audio.name
logger.debug(f"πŸ“ Temporary audio saved to {temp_audio_path}")
# Convert to WAV if necessary
wav_path = temp_audio_path
if not audio_file.filename.lower().endswith(".wav"):
            # Give the converted file a per-request name so concurrent uploads don't clobber it
            wav_path = os.path.join(OUTPUT_DIR, f"converted_{os.path.basename(temp_audio_path)}.wav")
            logger.info(f"🔄 Converting audio to WAV format: {wav_path}")
try:
audio = AudioSegment.from_file(temp_audio_path)
audio = audio.set_frame_rate(SAMPLE_RATE).set_channels(1)
audio.export(wav_path, format="wav")
except Exception as e:
logger.error(f"❌ Audio conversion failed: {str(e)}")
return jsonify({"error": f"Audio conversion failed: {str(e)}"}), 500
# Load and process the WAV file
try:
waveform, sr = torchaudio.load(wav_path)
logger.debug(f"βœ… Audio loaded: {wav_path} (Sample rate: {sr}Hz)")
# Resample if needed
if sr != SAMPLE_RATE:
logger.info(f"πŸ”„ Resampling audio from {sr}Hz to {SAMPLE_RATE}Hz")
waveform = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(waveform)
            # Peak-normalize; the clamp avoids a divide-by-zero on silent input
            waveform = waveform / torch.max(torch.abs(waveform)).clamp(min=1e-9)
except Exception as e:
logger.error(f"❌ Failed to load or process audio: {str(e)}")
return jsonify({"error": f"Audio processing failed: {str(e)}"}), 500
# Process audio for ASR
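        # Note: forwarding `language` to the processor call relies on this particular
        # checkpoint's processor accepting that keyword; the stock multilingual MMS ASR
        # models switch languages via processor.tokenizer.set_target_lang(lang_code)
        # together with asr_model.load_adapter(lang_code) instead.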
try:
inputs = asr_processor(
waveform.squeeze().numpy(),
sampling_rate=SAMPLE_RATE,
return_tensors="pt",
language=lang_code
)
inputs = {k: v.to(device) for k, v in inputs.items()}
except Exception as e:
logger.error(f"❌ ASR preprocessing failed: {str(e)}")
return jsonify({"error": f"ASR preprocessing failed: {str(e)}"}), 500
# Perform ASR
try:
with torch.no_grad():
logits = asr_model(**inputs).logits
ids = torch.argmax(logits, dim=-1)[0]
transcription = asr_processor.decode(ids)
logger.info(f"βœ… Transcription ({language}): {transcription}")
# Clean up temp files
try:
os.unlink(temp_audio_path)
if wav_path != temp_audio_path:
os.unlink(wav_path)
except Exception as e:
logger.warning(f"⚠️ Failed to clean up temp files: {str(e)}")
return jsonify({
"transcription": transcription,
"language": language,
"language_code": lang_code
})
except Exception as e:
logger.error(f"❌ ASR inference failed: {str(e)}")
logger.debug(f"Stack trace: {traceback.format_exc()}")
return jsonify({"error": f"ASR inference failed: {str(e)}"}), 500
except Exception as e:
logger.error(f"❌ Unhandled exception in ASR endpoint: {str(e)}")
logger.debug(f"Stack trace: {traceback.format_exc()}")
return jsonify({"error": f"Internal server error: {str(e)}"}), 500
@app.route("/tts", methods=["POST"])
def generate_tts():
try:
data = request.get_json()
if not data:
logger.warning("⚠️ TTS endpoint called with no JSON data")
return jsonify({"error": "No JSON data provided"}), 400
text_input = data.get("text", "").strip()
language = data.get("language", "kapampangan").lower()
if not text_input:
logger.warning("⚠️ TTS request with empty text")
return jsonify({"error": "No text provided"}), 400
if language not in TTS_MODELS:
logger.warning(f"⚠️ TTS requested for unsupported language: {language}")
return jsonify({"error": f"Invalid language. Available options: {list(TTS_MODELS.keys())}"}), 400
if tts_models[language] is None:
logger.error(f"❌ TTS model for {language} not loaded")
return jsonify({"error": f"TTS model for {language} not available"}), 503
logger.info(f"πŸ”„ Generating TTS for language: {language}, text: '{text_input}'")
try:
processor = tts_processors[language]
model = tts_models[language]
inputs = processor(text_input, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}
except Exception as e:
logger.error(f"❌ TTS preprocessing failed: {str(e)}")
return jsonify({"error": f"TTS preprocessing failed: {str(e)}"}), 500
# Generate speech
try:
with torch.no_grad():
output = model(**inputs).waveform
waveform = output.squeeze().cpu().numpy()
except Exception as e:
logger.error(f"❌ TTS inference failed: {str(e)}")
logger.debug(f"Stack trace: {traceback.format_exc()}")
return jsonify({"error": f"TTS inference failed: {str(e)}"}), 500
# Save to file
try:
output_filename = os.path.join(OUTPUT_DIR, f"{language}_output.wav")
sampling_rate = model.config.sampling_rate
sf.write(output_filename, waveform, sampling_rate)
logger.info(f"βœ… Speech generated! File saved: {output_filename}")
except Exception as e:
logger.error(f"❌ Failed to save audio file: {str(e)}")
return jsonify({"error": f"Failed to save audio file: {str(e)}"}), 500
return jsonify({
"message": "TTS audio generated",
"file_url": f"/download/{os.path.basename(output_filename)}",
"language": language,
"text_length": len(text_input)
})
except Exception as e:
logger.error(f"❌ Unhandled exception in TTS endpoint: {str(e)}")
logger.debug(f"Stack trace: {traceback.format_exc()}")
return jsonify({"error": f"Internal server error: {str(e)}"}), 500
@app.route("/download/<filename>", methods=["GET"])
def download_audio(filename):
file_path = os.path.join(OUTPUT_DIR, filename)
if os.path.exists(file_path):
logger.info(f"πŸ“€ Serving audio file: {file_path}")
return send_file(file_path, mimetype="audio/wav", as_attachment=True)
logger.warning(f"⚠️ Requested file not found: {file_path}")
return jsonify({"error": "File not found"}), 404
if __name__ == "__main__":
logger.info("πŸš€ Starting Speech API server")
logger.info(f"πŸ“Š System status: ASR model: {'βœ…' if asr_model else '❌'}")
for lang, model in tts_models.items():
logger.info(f"πŸ“Š TTS model {lang}: {'βœ…' if model else '❌'}")
app.run(host="0.0.0.0", port=7860, debug=True)