import os
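
# Redirect the Hugging Face and Torch caches to /tmp (writable in this environment).
# These variables are set before importing transformers so the cache paths take effect.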
os.environ["HF_HOME"] = "/tmp/hf_home"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers_cache"
os.environ["HUGGINGFACE_HUB_CACHE"] = "/tmp/huggingface_hub_cache"
os.environ["TORCH_HOME"] = "/tmp/torch_home"
os.environ["XDG_CACHE_HOME"] = "/tmp/xdg_cache"
for path in ["/tmp/hf_home", "/tmp/transformers_cache", "/tmp/huggingface_hub_cache", "/tmp/torch_home", "/tmp/xdg_cache"]:
    os.makedirs(path, exist_ok=True)

import torch
from pydub import AudioSegment
import tempfile
import torchaudio
import soundfile as sf
from flask import Flask, request, jsonify, send_file
from flask_cors import CORS
from transformers import Wav2Vec2ForCTC, AutoProcessor, VitsModel, AutoTokenizer

app = Flask(__name__)
CORS(app)
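
# Load the ASR model and its processor once at startup
# (a fine-tuned MMS/Wav2Vec2 CTC checkpoint, covering Tagalog and English judging by the model ID).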
ASR_MODEL_ID = "Coco-18/mms-asr-tgl-en-safetensor"
print(f"Loading ASR model: {ASR_MODEL_ID}")

try:
    asr_processor = AutoProcessor.from_pretrained(
        ASR_MODEL_ID,
        cache_dir="/tmp/transformers_cache"
    )
    asr_model = Wav2Vec2ForCTC.from_pretrained(
        ASR_MODEL_ID,
        cache_dir="/tmp/transformers_cache"
    )
    print("✅ ASR Model loaded successfully")
except Exception as e:
    print(f"❌ Error loading ASR model: {str(e)}")
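
# Startup diagnostics to help debug the runtime environment.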
import sys

print(f"Python version: {sys.version}")
print(f"Current working directory: {os.getcwd()}")
print(f"Temp directory exists: {os.path.exists('/tmp')}")
print(f"Temp directory writeable: {os.access('/tmp', os.W_OK)}")
LANGUAGE_CODES = {
    "kapampangan": "pam",
    "tagalog": "tgl",
    "english": "eng"
}
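
# One MMS-TTS (VITS) checkpoint per supported language.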
TTS_MODELS = {
    "kapampangan": "facebook/mms-tts-pam",
    "tagalog": "facebook/mms-tts-tgl",
    "english": "facebook/mms-tts-eng"
}
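
# Load every TTS model at startup; a failed load is recorded as None so /tts can report it.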
tts_models = {}
tts_processors = {}
for lang, model_id in TTS_MODELS.items():
    try:
        tts_models[lang] = VitsModel.from_pretrained(
            model_id,
            cache_dir="/tmp/transformers_cache"
        )
        tts_processors[lang] = AutoTokenizer.from_pretrained(
            model_id,
            cache_dir="/tmp/transformers_cache"
        )
        print(f"✅ TTS Model loaded: {lang}")
    except Exception as e:
        print(f"❌ Error loading {lang} TTS model: {e}")
        tts_models[lang] = None

SAMPLE_RATE = 16000
OUTPUT_DIR = "/tmp/audio_outputs"
os.makedirs(OUTPUT_DIR, exist_ok=True)


@app.route("/", methods=["GET"])
def home():
    return jsonify({"message": "Speech API is running."})


@app.route("/asr", methods=["POST"])
def transcribe_audio():
    try:
        if "audio" not in request.files:
            return jsonify({"error": "No audio file uploaded"}), 400

        audio_file = request.files["audio"]
        language = request.form.get("language", "english").lower()

        if language not in LANGUAGE_CODES:
            return jsonify({"error": f"Unsupported language: {language}"}), 400

        lang_code = LANGUAGE_CODES[language]
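
        # Save the upload to a temporary file so pydub/torchaudio can read it from disk.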
        with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(audio_file.filename)[-1]) as temp_audio:
            temp_audio.write(audio_file.read())
            temp_audio_path = temp_audio.name
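
        # Convert non-WAV uploads to 16 kHz mono WAV before loading.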
        wav_path = temp_audio_path
        if not audio_file.filename.lower().endswith(".wav"):
            wav_path = os.path.join(OUTPUT_DIR, "converted_audio.wav")
            audio = AudioSegment.from_file(temp_audio_path)
            audio = audio.set_frame_rate(SAMPLE_RATE).set_channels(1)
            audio.export(wav_path, format="wav")

        waveform, sr = torchaudio.load(wav_path)
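
        # Resample to 16 kHz if needed, then peak-normalize the waveform.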
        if sr != SAMPLE_RATE:
            waveform = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(waveform)

        waveform = waveform / torch.max(torch.abs(waveform))

        inputs = asr_processor(
            waveform.squeeze().numpy(),
            sampling_rate=SAMPLE_RATE,
            return_tensors="pt",
            language=lang_code
        )
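
        # Greedy CTC decoding: take the argmax token per frame, then let the
        # processor collapse repeats and blanks into text.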
        with torch.no_grad():
            logits = asr_model(**inputs).logits
        ids = torch.argmax(logits, dim=-1)[0]
        transcription = asr_processor.decode(ids)

        print(f"Transcription ({language}): {transcription}")

        return jsonify({"transcription": transcription})

    except Exception as e:
        print(f"ASR error: {str(e)}")
        return jsonify({"error": f"ASR failed: {str(e)}"}), 500


@app.route("/tts", methods=["POST"])
def generate_tts():
    try:
        data = request.get_json()
        text_input = data.get("text", "").strip()
        language = data.get("language", "kapampangan").lower()

        if language not in TTS_MODELS:
            return jsonify({"error": "Invalid language"}), 400
        if not text_input:
            return jsonify({"error": "No text provided"}), 400
        if tts_models[language] is None:
            return jsonify({"error": "TTS model not available"}), 500

        processor = tts_processors[language]
        model = tts_models[language]
        inputs = processor(text_input, return_tensors="pt")

        # VitsModel does not support .generate(); a plain forward pass returns the waveform.
        with torch.no_grad():
            output = model(**inputs).waveform

        waveform = output.cpu().numpy().flatten()
        output_filename = os.path.join(OUTPUT_DIR, f"{language}_tts.wav")
        sf.write(output_filename, waveform, SAMPLE_RATE)

        return jsonify({"file_url": f"/download/{language}_tts.wav"})
    except Exception as e:
        return jsonify({"error": f"TTS failed: {e}"}), 500


@app.route("/download/<filename>", methods=["GET"])
def download_audio(filename):
    file_path = os.path.join(OUTPUT_DIR, filename)
    if os.path.exists(file_path):
        return send_file(file_path, mimetype="audio/wav", as_attachment=True)
    return jsonify({"error": "File not found"}), 404


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860, debug=True)
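
# Example requests (illustrative only; the audio filename, sample text, and host are assumptions):
#   curl -X POST -F "audio=@sample.wav" -F "language=tagalog" http://localhost:7860/asr
#   curl -X POST -H "Content-Type: application/json" \
#        -d '{"text": "Mayap a abak", "language": "kapampangan"}' http://localhost:7860/tts
#   curl -OJ http://localhost:7860/download/kapampangan_tts.wav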