Update app.py
Browse files
app.py
CHANGED
@@ -1,106 +1,109 @@
|
|
1 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
-
# Set
|
4 |
os.environ["HUGGINGFACE_HUB_CACHE"] = "/tmp/huggingface_cache"
|
5 |
os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers_cache"
|
6 |
os.environ["HF_HOME"] = "/tmp/hf_home"
|
7 |
os.environ["TORCH_HOME"] = "/tmp/torch_home"
|
8 |
|
9 |
-
from flask import Flask, request, jsonify, send_file
|
10 |
-
from flask_cors import CORS
|
11 |
-
import torch
|
12 |
-
import torchaudio
|
13 |
-
import soundfile as sf
|
14 |
-
from transformers import VitsModel, AutoTokenizer
|
15 |
-
|
16 |
app = Flask(__name__)
|
17 |
-
CORS(app) # Allow
|
18 |
|
19 |
-
#
|
20 |
-
os.environ["HUGGINGFACE_HUB_CACHE"] = "/tmp/huggingface_cache"
|
21 |
-
|
22 |
-
# Model paths for different languages (Hugging Face Hub)
|
23 |
MODELS = {
|
24 |
"kapampangan": "facebook/mms-tts-pam",
|
25 |
"tagalog": "facebook/mms-tts-tgl",
|
26 |
"english": "facebook/mms-tts-eng"
|
27 |
}
|
28 |
|
29 |
-
#
|
30 |
loaded_models = {}
|
31 |
loaded_processors = {}
|
32 |
|
33 |
-
|
|
|
34 |
try:
|
35 |
-
print(f"Loading {lang} model
|
36 |
-
loaded_models[lang] = VitsModel.from_pretrained(
|
37 |
-
loaded_processors[lang] = AutoTokenizer.from_pretrained(
|
38 |
-
print(f"{lang.capitalize()} model loaded successfully!")
|
39 |
except Exception as e:
|
40 |
-
print(f"Error loading {lang} model: {
|
41 |
-
loaded_models[lang] = None
|
42 |
loaded_processors[lang] = None
|
43 |
|
44 |
-
|
45 |
-
|
|
|
|
|
46 |
|
47 |
@app.route("/", methods=["GET"])
|
48 |
def home():
|
49 |
-
""" Root route to check if the
|
50 |
return jsonify({"message": "TTS API is running. Use /tts to generate speech."})
|
51 |
|
52 |
@app.route("/tts", methods=["POST"])
|
53 |
def generate_tts():
|
54 |
-
""" API endpoint to generate
|
55 |
-
|
56 |
-
|
57 |
-
|
|
|
|
|
58 |
|
59 |
-
|
60 |
-
|
|
|
61 |
|
62 |
-
|
63 |
-
|
64 |
|
65 |
-
|
66 |
-
|
67 |
|
68 |
-
|
69 |
|
70 |
-
|
71 |
-
# Select the correct model and processor
|
72 |
-
model = loaded_models[language]
|
73 |
processor = loaded_processors[language]
|
74 |
-
|
75 |
-
# Tokenize input text
|
76 |
inputs = processor(text_input, return_tensors="pt")
|
77 |
|
78 |
-
# Generate
|
79 |
with torch.no_grad():
|
80 |
output = model.generate(**inputs)
|
81 |
|
82 |
waveform = output.cpu().numpy().flatten()
|
83 |
|
84 |
-
# Save
|
85 |
output_filename = os.path.join(OUTPUT_DIR, f"{language}_output.wav")
|
86 |
sf.write(output_filename, waveform, SAMPLE_RATE)
|
87 |
|
|
|
|
|
88 |
return jsonify({
|
89 |
"message": "TTS audio generated",
|
90 |
-
"file_url": f"/
|
91 |
})
|
|
|
92 |
except Exception as e:
|
93 |
-
print(f"Error generating TTS: {
|
94 |
return jsonify({"error": "Internal server error"}), 500
|
95 |
|
96 |
-
@app.route("/
|
97 |
-
def
|
98 |
-
""" Serve
|
99 |
file_path = os.path.join(OUTPUT_DIR, filename)
|
100 |
if os.path.exists(file_path):
|
101 |
-
return send_file(file_path, mimetype="audio/wav")
|
102 |
return jsonify({"error": "File not found"}), 404
|
103 |
|
104 |
if __name__ == "__main__":
|
105 |
app.run(host="0.0.0.0", port=7860, debug=True)
|
106 |
|
|
|
|
import os

# BUG FIX: these cache variables must be exported BEFORE importing
# torch/transformers — huggingface_hub resolves HF_HOME/HUGGINGFACE_HUB_CACHE
# at import time, so setting them after the imports has no effect.
# Pointing the caches at /tmp prevents permission errors on read-only
# home directories (e.g. Hugging Face Spaces containers).
os.environ["HUGGINGFACE_HUB_CACHE"] = "/tmp/huggingface_cache"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers_cache"
os.environ["HF_HOME"] = "/tmp/hf_home"
os.environ["TORCH_HOME"] = "/tmp/torch_home"

import torch
import torchaudio
import soundfile as sf
from flask import Flask, request, jsonify, send_file
from flask_cors import CORS
from transformers import VitsModel, AutoTokenizer
# Flask application; CORS is enabled so browser clients served from other
# origins can reach the API.
app = Flask(__name__)
CORS(app)  # allow cross-origin requests

# Hugging Face MMS-TTS checkpoints, keyed by the language names the
# /tts endpoint accepts.
MODELS = dict(
    kapampangan="facebook/mms-tts-pam",
    tagalog="facebook/mms-tts-tgl",
    english="facebook/mms-tts-eng",
)

# Per-language model/tokenizer registries, populated at startup.
loaded_models = dict()
loaded_processors = dict()
|
# Eagerly load every language's model and tokenizer at startup.
# A failure for one language (network, disk, auth) is logged and leaves a
# None placeholder, so the /tts handler can return a per-language 500
# instead of crashing the whole process at import time.
# NOTE(review): status-prefix characters in these prints appear
# extraction-garbled (likely originally emoji) — confirm against the repo.
for lang, model_path in MODELS.items():
    try:
        print(f"π Loading {lang} model: {model_path}...")
        loaded_models[lang] = VitsModel.from_pretrained(model_path)
        loaded_processors[lang] = AutoTokenizer.from_pretrained(model_path)
        print(f"β {lang.capitalize()} model loaded successfully!")
    except Exception as e:
        print(f"β Error loading {lang} model: {e}")
        loaded_models[lang] = None
        loaded_processors[lang] = None
40 |
|
# Output configuration.
OUTPUT_DIR = "/tmp/"   # generated WAV files are written here
SAMPLE_RATE = 16000    # MMS-TTS VITS checkpoints emit 16 kHz audio
os.makedirs(OUTPUT_DIR, exist_ok=True)  # idempotent; safe if it already exists
45 |
|
@app.route("/", methods=["GET"])
def home():
    """Health-check route.

    Returns a JSON hint directing callers to the /tts endpoint.
    """
    payload = {"message": "TTS API is running. Use /tts to generate speech."}
    return jsonify(payload)
50 |
|
@app.route("/tts", methods=["POST"])
def generate_tts():
    """Generate speech from a JSON body {"text": ..., "language": ...}.

    Language defaults to "kapampangan" and must be a key of MODELS.
    Returns JSON containing a /download URL for the generated WAV,
    or a 400 (bad input) / 500 (model failure) JSON error.
    """
    try:
        # silent=True: a missing/invalid JSON body yields None instead of
        # raising, so bad requests get a clean 400 rather than a 500.
        data = request.get_json(silent=True) or {}
        text_input = data.get("text", "").strip()
        language = data.get("language", "kapampangan").lower()

        # Validate inputs before touching the models.
        if language not in MODELS:
            return jsonify({"error": "Invalid language. Choose 'kapampangan', 'tagalog', or 'english'."}), 400

        if not text_input:
            return jsonify({"error": "No text provided"}), 400

        if loaded_models[language] is None:
            return jsonify({"error": f"Model for {language} failed to load"}), 500

        print(f"π Generating speech for '{text_input}' in {language}...")

        processor = loaded_processors[language]
        model = loaded_models[language]
        inputs = processor(text_input, return_tensors="pt")

        # BUG FIX: VitsModel has no .generate() method — the original
        # model.generate(**inputs) raised on every request. A plain forward
        # pass returns the audio tensor in the .waveform attribute.
        with torch.no_grad():
            output = model(**inputs).waveform

        waveform = output.cpu().numpy().flatten()

        # Prefer the checkpoint's own sampling rate; fall back to the
        # module default (MMS-TTS models are 16 kHz) for compatibility.
        sample_rate = getattr(model.config, "sampling_rate", SAMPLE_RATE)

        # Persist the audio where /download/<filename> can serve it.
        output_filename = os.path.join(OUTPUT_DIR, f"{language}_output.wav")
        sf.write(output_filename, waveform, sample_rate)

        print(f"β Speech generated! File saved: {output_filename}")

        return jsonify({
            "message": "TTS audio generated",
            "file_url": f"/download/{language}_output.wav"
        })

    except Exception as e:
        # Log the real cause server-side; keep the client message generic.
        print(f"β Error generating TTS: {e}")
        return jsonify({"error": "Internal server error"}), 500
97 |
|
@app.route("/download/<filename>", methods=["GET"])
def download_audio(filename):
    """Serve a previously generated WAV file from OUTPUT_DIR.

    Returns the file as an attachment, or a 404 JSON error when absent.
    SECURITY: `filename` is untrusted URL input; the realpath containment
    check below prevents path-traversal names (e.g. '..') from escaping
    OUTPUT_DIR, which os.path.join alone does not guarantee.
    """
    file_path = os.path.join(OUTPUT_DIR, filename)
    real_base = os.path.realpath(OUTPUT_DIR)
    real_path = os.path.realpath(file_path)
    inside_output_dir = real_path.startswith(real_base + os.sep)
    if inside_output_dir and os.path.isfile(real_path):
        return send_file(real_path, mimetype="audio/wav", as_attachment=True)
    return jsonify({"error": "File not found"}), 404
105 |
|
if __name__ == "__main__":
    # Bind on all interfaces; port 7860 is the Hugging Face Spaces default.
    # NOTE(review): debug=True enables the Werkzeug interactive debugger,
    # which permits arbitrary code execution if this port is reachable —
    # it should be disabled for any non-local deployment.
    app.run(host="0.0.0.0", port=7860, debug=True)