Update app.py
app.py CHANGED
@@ -6,14 +6,12 @@ from flask import Flask, request, jsonify, send_file
 from flask_cors import CORS
 from transformers import VitsModel, AutoTokenizer
 
-
 # Set ALL cache directories to /tmp (writable in Hugging Face Spaces)
 os.environ["HF_HOME"] = "/tmp/hf_home"
 os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers_cache"
 os.environ["HUGGINGFACE_HUB_CACHE"] = "/tmp/huggingface_cache"
 os.environ["TORCH_HOME"] = "/tmp/torch_home"
 
-
 app = Flask(__name__)
 CORS(app)  # Allow external requests
 
@@ -30,11 +28,9 @@ loaded_processors = {}
 for lang, path in MODELS.items():
     try:
         print(f"🔄 Loading {lang} model: {path}...")
-
         # Force models to save in /tmp
         loaded_models[lang] = VitsModel.from_pretrained(path, cache_dir="/tmp/huggingface_cache")
         loaded_processors[lang] = AutoTokenizer.from_pretrained(path, cache_dir="/tmp/huggingface_cache")
-
         print(f"✅ {lang.capitalize()} model loaded successfully!")
     except Exception as e:
         print(f"❌ Error loading {lang} model: {str(e)}")
@@ -48,12 +44,12 @@ os.makedirs(OUTPUT_DIR, exist_ok=True)
 
 @app.route("/", methods=["GET"])
 def home():
-    """
+    """Root route to check if the API is running"""
     return jsonify({"message": "TTS API is running. Use /tts to generate speech."})
 
 @app.route("/tts", methods=["POST"])
 def generate_tts():
-    """
+    """API endpoint to generate TTS audio"""
     try:
         # Get request data
         data = request.get_json()
@@ -63,10 +59,8 @@ def generate_tts():
         # Validate inputs
         if language not in MODELS:
             return jsonify({"error": "Invalid language. Choose 'kapampangan', 'tagalog', or 'english'."}), 400
-
         if not text_input:
            return jsonify({"error": "No text provided"}), 400
-
         if loaded_models[language] is None:
            return jsonify({"error": f"Model for {language} failed to load"}), 500
 
@@ -80,27 +74,29 @@ def generate_tts():
         # Generate speech
         with torch.no_grad():
             output = model.generate(**inputs)
-
-
+            # For VITS models, the output is typically a waveform
+            # Check if output is a tuple/list or a single tensor
+            if isinstance(output, tuple) or isinstance(output, list):
+                waveform = output[0].cpu().numpy().squeeze()
+            else:
+                waveform = output.cpu().numpy().squeeze()
 
         # Save to file
         output_filename = os.path.join(OUTPUT_DIR, f"{language}_output.wav")
         sf.write(output_filename, waveform, SAMPLE_RATE)
-
         print(f"✅ Speech generated! File saved: {output_filename}")
 
         return jsonify({
             "message": "TTS audio generated",
             "file_url": f"/download/{language}_output.wav"
         })
-
     except Exception as e:
         print(f"❌ Error generating TTS: {e}")
-        return jsonify({"error": "Internal server error"}), 500
+        return jsonify({"error": f"Internal server error: {str(e)}"}), 500
 
 @app.route("/download/<filename>", methods=["GET"])
 def download_audio(filename):
-    """
+    """Serve generated audio files"""
     file_path = os.path.join(OUTPUT_DIR, filename)
     if os.path.exists(file_path):
         return send_file(file_path, mimetype="audio/wav", as_attachment=True)
@@ -109,4 +105,3 @@ def download_audio(filename):
 if __name__ == "__main__":
     app.run(host="0.0.0.0", port=7860, debug=True)
 
-
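For quick manual testing of the updated endpoints, a minimal client sketch follows. It assumes the server is reachable at http://localhost:7860 (matching app.run(host="0.0.0.0", port=7860) above) and that the request JSON uses "text" and "language" keys; those key names come from the collapsed part of the diff and are an assumption to verify against the full app.py.

# Hypothetical client sketch -- the JSON field names ("text", "language") are
# assumed, since the lines that unpack request.get_json() are collapsed above.
import requests

BASE_URL = "http://localhost:7860"  # app.run(...) binds to 0.0.0.0:7860

# Health check against the root route
print(requests.get(f"{BASE_URL}/", timeout=10).json())

# Request speech synthesis
resp = requests.post(
    f"{BASE_URL}/tts",
    json={"text": "Magandang umaga", "language": "tagalog"},  # assumed keys
    timeout=120,
)
resp.raise_for_status()
file_url = resp.json()["file_url"]  # e.g. "/download/tagalog_output.wav"

# Download the generated WAV via the /download route
audio = requests.get(f"{BASE_URL}{file_url}", timeout=60)
with open("tagalog_output.wav", "wb") as f:
    f.write(audio.content)

Because /download/<filename> serves the file with as_attachment=True and mimetype="audio/wav", the client only needs to write the response body to disk.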