Coco-18 committed on
Commit
6aa3d97
·
verified ·
1 Parent(s): 6ddbbc3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -15
app.py CHANGED
@@ -6,14 +6,12 @@ from flask import Flask, request, jsonify, send_file
6
  from flask_cors import CORS
7
  from transformers import VitsModel, AutoTokenizer
8
 
9
-
10
  # Set ALL cache directories to /tmp (writable in Hugging Face Spaces)
11
  os.environ["HF_HOME"] = "/tmp/hf_home"
12
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers_cache"
13
  os.environ["HUGGINGFACE_HUB_CACHE"] = "/tmp/huggingface_cache"
14
  os.environ["TORCH_HOME"] = "/tmp/torch_home"
15
 
16
-
17
  app = Flask(__name__)
18
  CORS(app) # Allow external requests
19
 
@@ -30,11 +28,9 @@ loaded_processors = {}
30
  for lang, path in MODELS.items():
31
  try:
32
  print(f"πŸ”„ Loading {lang} model: {path}...")
33
-
34
  # Force models to save in /tmp
35
  loaded_models[lang] = VitsModel.from_pretrained(path, cache_dir="/tmp/huggingface_cache")
36
  loaded_processors[lang] = AutoTokenizer.from_pretrained(path, cache_dir="/tmp/huggingface_cache")
37
-
38
  print(f"βœ… {lang.capitalize()} model loaded successfully!")
39
  except Exception as e:
40
  print(f"❌ Error loading {lang} model: {str(e)}")
@@ -48,12 +44,12 @@ os.makedirs(OUTPUT_DIR, exist_ok=True)
48
 
49
  @app.route("/", methods=["GET"])
50
  def home():
51
- """ Root route to check if the API is running """
52
  return jsonify({"message": "TTS API is running. Use /tts to generate speech."})
53
 
54
  @app.route("/tts", methods=["POST"])
55
  def generate_tts():
56
- """ API endpoint to generate TTS audio """
57
  try:
58
  # Get request data
59
  data = request.get_json()
@@ -63,10 +59,8 @@ def generate_tts():
63
  # Validate inputs
64
  if language not in MODELS:
65
  return jsonify({"error": "Invalid language. Choose 'kapampangan', 'tagalog', or 'english'."}), 400
66
-
67
  if not text_input:
68
  return jsonify({"error": "No text provided"}), 400
69
-
70
  if loaded_models[language] is None:
71
  return jsonify({"error": f"Model for {language} failed to load"}), 500
72
 
@@ -80,27 +74,29 @@ def generate_tts():
80
  # Generate speech
81
  with torch.no_grad():
82
  output = model.generate(**inputs)
83
-
84
- waveform = output.cpu().numpy().flatten()
 
 
 
 
85
 
86
  # Save to file
87
  output_filename = os.path.join(OUTPUT_DIR, f"{language}_output.wav")
88
  sf.write(output_filename, waveform, SAMPLE_RATE)
89
-
90
  print(f"βœ… Speech generated! File saved: {output_filename}")
91
 
92
  return jsonify({
93
  "message": "TTS audio generated",
94
  "file_url": f"/download/{language}_output.wav"
95
  })
96
-
97
  except Exception as e:
98
  print(f"❌ Error generating TTS: {e}")
99
- return jsonify({"error": "Internal server error"}), 500
100
 
101
  @app.route("/download/<filename>", methods=["GET"])
102
  def download_audio(filename):
103
- """ Serve generated audio files """
104
  file_path = os.path.join(OUTPUT_DIR, filename)
105
  if os.path.exists(file_path):
106
  return send_file(file_path, mimetype="audio/wav", as_attachment=True)
@@ -109,4 +105,3 @@ def download_audio(filename):
109
  if __name__ == "__main__":
110
  app.run(host="0.0.0.0", port=7860, debug=True)
111
 
112
-
 
6
  from flask_cors import CORS
7
  from transformers import VitsModel, AutoTokenizer
8
 
 
9
  # Set ALL cache directories to /tmp (writable in Hugging Face Spaces)
10
  os.environ["HF_HOME"] = "/tmp/hf_home"
11
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers_cache"
12
  os.environ["HUGGINGFACE_HUB_CACHE"] = "/tmp/huggingface_cache"
13
  os.environ["TORCH_HOME"] = "/tmp/torch_home"
14
 
 
15
  app = Flask(__name__)
16
  CORS(app) # Allow external requests
17
 
 
28
  for lang, path in MODELS.items():
29
  try:
30
  print(f"πŸ”„ Loading {lang} model: {path}...")
 
31
  # Force models to save in /tmp
32
  loaded_models[lang] = VitsModel.from_pretrained(path, cache_dir="/tmp/huggingface_cache")
33
  loaded_processors[lang] = AutoTokenizer.from_pretrained(path, cache_dir="/tmp/huggingface_cache")
 
34
  print(f"βœ… {lang.capitalize()} model loaded successfully!")
35
  except Exception as e:
36
  print(f"❌ Error loading {lang} model: {str(e)}")
 
44
 
45
  @app.route("/", methods=["GET"])
46
  def home():
47
+ """Root route to check if the API is running"""
48
  return jsonify({"message": "TTS API is running. Use /tts to generate speech."})
49
 
50
  @app.route("/tts", methods=["POST"])
51
  def generate_tts():
52
+ """API endpoint to generate TTS audio"""
53
  try:
54
  # Get request data
55
  data = request.get_json()
 
59
  # Validate inputs
60
  if language not in MODELS:
61
  return jsonify({"error": "Invalid language. Choose 'kapampangan', 'tagalog', or 'english'."}), 400
 
62
  if not text_input:
63
  return jsonify({"error": "No text provided"}), 400
 
64
  if loaded_models[language] is None:
65
  return jsonify({"error": f"Model for {language} failed to load"}), 500
66
 
 
74
  # Generate speech
75
  with torch.no_grad():
76
  output = model.generate(**inputs)
77
+ # For VITS models, the output is typically a waveform
78
+ # Check if output is a tuple/list or a single tensor
79
+ if isinstance(output, tuple) or isinstance(output, list):
80
+ waveform = output[0].cpu().numpy().squeeze()
81
+ else:
82
+ waveform = output.cpu().numpy().squeeze()
83
 
84
  # Save to file
85
  output_filename = os.path.join(OUTPUT_DIR, f"{language}_output.wav")
86
  sf.write(output_filename, waveform, SAMPLE_RATE)
 
87
  print(f"βœ… Speech generated! File saved: {output_filename}")
88
 
89
  return jsonify({
90
  "message": "TTS audio generated",
91
  "file_url": f"/download/{language}_output.wav"
92
  })
 
93
  except Exception as e:
94
  print(f"❌ Error generating TTS: {e}")
95
+ return jsonify({"error": f"Internal server error: {str(e)}"}), 500
96
 
97
  @app.route("/download/<filename>", methods=["GET"])
98
  def download_audio(filename):
99
+ """Serve generated audio files"""
100
  file_path = os.path.join(OUTPUT_DIR, filename)
101
  if os.path.exists(file_path):
102
  return send_file(file_path, mimetype="audio/wav", as_attachment=True)
 
105
  if __name__ == "__main__":
106
  app.run(host="0.0.0.0", port=7860, debug=True)
107