# NOTE(review): removed extraction artifacts (file-size banner, git blame
# hashes, and a line-number dump) that were not part of the program.
import os

# Cache directories must be set BEFORE importing transformers/torch:
# both libraries read these environment variables at import time, so
# setting them after the imports (as the original code did) had no effect.
# /tmp is the only writable location in Hugging Face Spaces.
os.environ["HF_HOME"] = "/tmp/hf_home"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers_cache"
os.environ["HUGGINGFACE_HUB_CACHE"] = "/tmp/huggingface_cache"
os.environ["TORCH_HOME"] = "/tmp/torch_home"

import torch
import torchaudio
import soundfile as sf
from flask import Flask, request, jsonify, send_file
from flask_cors import CORS
from transformers import VitsModel, AutoTokenizer

app = Flask(__name__)
CORS(app)  # Allow cross-origin requests from external front-ends

# Model identifiers on the Hugging Face Hub — one MMS/VITS TTS model per language.
MODELS = {
    "kapampangan": "facebook/mms-tts-pam",
    "tagalog": "facebook/mms-tts-tgl",
    "english": "facebook/mms-tts-eng",
}

# Loaded model/tokenizer instances keyed by language; None marks a failed load
# so /tts can report it as a 500 instead of crashing at import time.
loaded_models = {}
loaded_processors = {}

for lang, path in MODELS.items():
    try:
        print(f"Loading {lang} model: {path}...")
        # cache_dir keeps downloads inside /tmp (writable in Spaces).
        loaded_models[lang] = VitsModel.from_pretrained(path, cache_dir="/tmp/huggingface_cache")
        loaded_processors[lang] = AutoTokenizer.from_pretrained(path, cache_dir="/tmp/huggingface_cache")
        print(f"{lang.capitalize()} model loaded successfully!")
    except Exception as e:
        # Keep the server up even if one language fails to load.
        print(f"Error loading {lang} model: {e}")
        loaded_models[lang] = None
        loaded_processors[lang] = None

# Directory for generated WAV files.
OUTPUT_DIR = "/tmp/"
os.makedirs(OUTPUT_DIR, exist_ok=True)
@app.route("/", methods=["GET"])
def home():
    """Health-check endpoint: confirms the API is up and points at /tts."""
    payload = {"message": "TTS API is running. Use /tts to generate speech."}
    return jsonify(payload)
@app.route("/tts", methods=["POST"])
def generate_tts():
    """Generate TTS audio from a JSON body {"text": ..., "language": ...}.

    Returns JSON containing a /download URL for the generated WAV file.
    Responds 400 on a missing/invalid body, unknown language, or empty text;
    500 when the requested model failed to load or synthesis raises.
    """
    try:
        # silent=True yields None instead of raising on a non-JSON body,
        # so a malformed request gets a clear 400 rather than a generic 500.
        data = request.get_json(silent=True)
        if data is None:
            return jsonify({"error": "Request body must be JSON"}), 400

        text_input = data.get("text", "").strip()
        language = data.get("language", "kapampangan").lower()

        # Validate inputs before touching any model.
        if language not in MODELS:
            return jsonify({"error": "Invalid language. Choose 'kapampangan', 'tagalog', or 'english'."}), 400
        if not text_input:
            return jsonify({"error": "No text provided"}), 400
        if loaded_models[language] is None:
            return jsonify({"error": f"Model for {language} failed to load"}), 500

        print(f"Generating speech for '{text_input}' in {language}...")

        processor = loaded_processors[language]
        model = loaded_models[language]
        inputs = processor(text_input, return_tensors="pt")

        # VITS produces the waveform in a single forward pass — there is
        # no autoregressive .generate() step for this architecture.
        with torch.no_grad():
            output = model(**inputs).waveform
        waveform = output.squeeze().cpu().numpy()

        # NOTE(review): one fixed file per language — concurrent requests for
        # the same language overwrite each other's output; consider unique names.
        output_filename = os.path.join(OUTPUT_DIR, f"{language}_output.wav")
        sampling_rate = model.config.sampling_rate  # model's native sample rate
        sf.write(output_filename, waveform, sampling_rate)

        print(f"Speech generated! File saved: {output_filename}")
        return jsonify({
            "message": "TTS audio generated",
            "file_url": f"/download/{language}_output.wav"
        })
    except Exception as e:
        print(f"Error generating TTS: {e}")
        return jsonify({"error": f"Internal server error: {str(e)}"}), 500
@app.route("/download/<filename>", methods=["GET"])
def download_audio(filename):
    """Serve a previously generated audio file from OUTPUT_DIR.

    The filename comes from the URL and is untrusted: it is reduced to its
    basename so path components such as '../' cannot escape OUTPUT_DIR
    (the original joined it unsanitized — a path-traversal vulnerability).
    """
    safe_name = os.path.basename(filename)
    file_path = os.path.join(OUTPUT_DIR, safe_name)
    # isfile (not exists) also rejects directories.
    if os.path.isfile(file_path):
        return send_file(file_path, mimetype="audio/wav", as_attachment=True)
    return jsonify({"error": "File not found"}), 404
if __name__ == "__main__":
    # debug=True enables the Werkzeug interactive debugger, which allows
    # arbitrary code execution — never leave it on for a publicly reachable
    # server. Opt in explicitly via the FLASK_DEBUG environment variable.
    debug = os.environ.get("FLASK_DEBUG", "").lower() in {"1", "true", "yes"}
    app.run(host="0.0.0.0", port=7860, debug=debug)