Coco-18 commited on
Commit
04bb535
Β·
verified Β·
1 Parent(s): 9a69347

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +212 -81
app.py CHANGED
@@ -1,53 +1,88 @@
1
  # Set cache directories first, before other imports
2
  import os
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  # Set all cache directories to locations within /tmp
5
- os.environ["HF_HOME"] = "/tmp/hf_home"
6
- os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers_cache"
7
- os.environ["HUGGINGFACE_HUB_CACHE"] = "/tmp/huggingface_hub_cache"
8
- os.environ["TORCH_HOME"] = "/tmp/torch_home"
9
- os.environ["XDG_CACHE_HOME"] = "/tmp/xdg_cache"
 
 
10
 
11
- # Create necessary directories
12
- for path in ["/tmp/hf_home", "/tmp/transformers_cache", "/tmp/huggingface_hub_cache", "/tmp/torch_home", "/tmp/xdg_cache"]:
13
- os.makedirs(path, exist_ok=True)
 
 
 
 
 
14
 
15
  # Now import the rest of the libraries
16
- import torch
17
- from pydub import AudioSegment
18
- import tempfile
19
- import torchaudio
20
- import soundfile as sf
21
- from flask import Flask, request, jsonify, send_file
22
- from flask_cors import CORS
23
- from transformers import Wav2Vec2ForCTC, AutoProcessor, VitsModel, AutoTokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  app = Flask(__name__)
26
  CORS(app)
27
 
28
  # ASR Model
29
  ASR_MODEL_ID = "Coco-18/mms-asr-tgl-en-safetensor"
30
- print(f"Loading ASR model: {ASR_MODEL_ID}")
 
 
 
31
 
32
  try:
33
  asr_processor = AutoProcessor.from_pretrained(
34
  ASR_MODEL_ID,
35
- cache_dir="/tmp/transformers_cache" # Explicitly set cache_dir
36
  )
 
 
37
  asr_model = Wav2Vec2ForCTC.from_pretrained(
38
  ASR_MODEL_ID,
39
- cache_dir="/tmp/transformers_cache" # Explicitly set cache_dir
40
  )
41
- print("βœ… ASR Model loaded successfully")
 
42
  except Exception as e:
43
- print(f"❌ Error loading ASR model: {str(e)}")
44
- # Provide more debugging information
45
- import sys
46
- print(f"Python version: {sys.version}")
47
- print(f"Current working directory: {os.getcwd()}")
48
- print(f"Temp directory exists: {os.path.exists('/tmp')}")
49
- print(f"Temp directory writeable: {os.access('/tmp', os.W_OK)}")
50
- # Let's continue anyway to see if we can at least start the API
51
 
52
  # Language-specific configurations
53
  LANGUAGE_CODES = {
@@ -66,125 +101,213 @@ TTS_MODELS = {
66
  tts_models = {}
67
  tts_processors = {}
68
  for lang, model_id in TTS_MODELS.items():
 
69
  try:
70
- tts_models[lang] = VitsModel.from_pretrained(
71
  model_id,
72
- cache_dir="/tmp/transformers_cache" # Explicitly set cache_dir
73
  )
74
- tts_processors[lang] = AutoTokenizer.from_pretrained(
 
 
75
  model_id,
76
- cache_dir="/tmp/transformers_cache" # Explicitly set cache_dir
77
  )
78
- print(f"βœ… TTS Model loaded: {lang}")
 
79
  except Exception as e:
80
- print(f"❌ Error loading {lang} TTS model: {e}")
 
81
  tts_models[lang] = None
82
 
83
  # Constants
84
  SAMPLE_RATE = 16000
85
  OUTPUT_DIR = "/tmp/audio_outputs"
86
- os.makedirs(OUTPUT_DIR, exist_ok=True)
 
 
 
 
87
 
88
  @app.route("/", methods=["GET"])
89
  def home():
90
- return jsonify({"message": "Speech API is running."})
91
-
 
 
 
 
 
 
 
 
 
 
92
 
93
  @app.route("/asr", methods=["POST"])
94
  def transcribe_audio():
 
 
 
 
95
  try:
96
  if "audio" not in request.files:
 
97
  return jsonify({"error": "No audio file uploaded"}), 400
98
 
99
  audio_file = request.files["audio"]
100
  language = request.form.get("language", "english").lower()
101
 
102
  if language not in LANGUAGE_CODES:
103
- return jsonify({"error": f"Unsupported language: {language}"}), 400
 
104
 
105
  lang_code = LANGUAGE_CODES[language]
 
106
 
107
  # Save the uploaded file temporarily
108
  with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(audio_file.filename)[-1]) as temp_audio:
109
  temp_audio.write(audio_file.read())
110
  temp_audio_path = temp_audio.name
 
111
 
112
  # Convert to WAV if necessary
113
  wav_path = temp_audio_path
114
  if not audio_file.filename.lower().endswith(".wav"):
115
  wav_path = os.path.join(OUTPUT_DIR, "converted_audio.wav")
116
- audio = AudioSegment.from_file(temp_audio_path)
117
- audio = audio.set_frame_rate(SAMPLE_RATE).set_channels(1)
118
- audio.export(wav_path, format="wav")
 
 
 
 
 
119
 
120
  # Load and process the WAV file
121
- waveform, sr = torchaudio.load(wav_path)
 
 
122
 
123
- # Resample if needed
124
- if sr != SAMPLE_RATE:
125
- waveform = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(waveform)
 
126
 
127
- waveform = waveform / torch.max(torch.abs(waveform))
 
 
 
128
 
129
  # Process audio for ASR
130
- inputs = asr_processor(
131
- waveform.squeeze().numpy(),
132
- sampling_rate=SAMPLE_RATE,
133
- return_tensors="pt",
134
- language=lang_code
135
- )
 
 
 
 
 
136
 
137
  # Perform ASR
138
- with torch.no_grad():
139
- logits = asr_model(**inputs).logits
140
- ids = torch.argmax(logits, dim=-1)[0]
141
- transcription = asr_processor.decode(ids)
142
-
143
- print(f"Transcription ({language}): {transcription}")
144
-
145
- return jsonify({"transcription": transcription})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
  except Exception as e:
148
- print(f"ASR error: {str(e)}")
149
- return jsonify({"error": f"ASR failed: {str(e)}"}), 500
 
150
 
151
 
152
  @app.route("/tts", methods=["POST"])
153
  def generate_tts():
154
  try:
155
  data = request.get_json()
 
 
 
 
156
  text_input = data.get("text", "").strip()
157
  language = data.get("language", "kapampangan").lower()
158
 
159
- if language not in TTS_MODELS:
160
- return jsonify({"error": "Invalid language"}), 400
161
  if not text_input:
 
162
  return jsonify({"error": "No text provided"}), 400
 
 
 
 
 
163
  if tts_models[language] is None:
164
- return jsonify({"error": "TTS model not available"}), 500
165
-
166
- processor = tts_processors[language]
167
- model = tts_models[language]
168
- inputs = processor(text_input, return_tensors="pt")
169
-
170
- # Generate speech - using model(**inputs) instead of model.generate()
171
- with torch.no_grad():
172
- output = model(**inputs).waveform
173
- waveform = output.squeeze().cpu().numpy()
 
 
 
 
 
 
 
 
 
 
 
 
 
174
 
175
  # Save to file
176
- output_filename = os.path.join(OUTPUT_DIR, f"{language}_output.wav")
177
- # Use the model's sampling rate
178
- sampling_rate = model.config.sampling_rate
179
- sf.write(output_filename, waveform, sampling_rate)
180
- print(f"βœ… Speech generated! File saved: {output_filename}")
 
 
 
181
 
182
  return jsonify({
183
  "message": "TTS audio generated",
184
- "file_url": f"/download/{language}_output.wav"
 
 
185
  })
186
  except Exception as e:
187
- print(f"❌ Error generating TTS: {e}")
 
188
  return jsonify({"error": f"Internal server error: {str(e)}"}), 500
189
 
190
 
@@ -192,9 +315,17 @@ def generate_tts():
192
  def download_audio(filename):
193
  file_path = os.path.join(OUTPUT_DIR, filename)
194
  if os.path.exists(file_path):
 
195
  return send_file(file_path, mimetype="audio/wav", as_attachment=True)
 
 
196
  return jsonify({"error": "File not found"}), 404
197
 
198
 
199
  if __name__ == "__main__":
 
 
 
 
 
200
  app.run(host="0.0.0.0", port=7860, debug=True)
 
1
  # Set cache directories first, before other imports
2
  import os
3
+ import sys
4
+ import logging
5
+ import traceback
6
+
7
+ # Configure logging
8
+ logging.basicConfig(
9
+ level=logging.INFO,
10
+ format='%(asctime)s - %(levelname)s - %(message)s',
11
+ datefmt='%Y-%m-%d %H:%M:%S'
12
+ )
13
+ logger = logging.getLogger("speech_api")
14
 
15
  # Set all cache directories to locations within /tmp
16
+ cache_dirs = {
17
+ "HF_HOME": "/tmp/hf_home",
18
+ "TRANSFORMERS_CACHE": "/tmp/transformers_cache",
19
+ "HUGGINGFACE_HUB_CACHE": "/tmp/huggingface_hub_cache",
20
+ "TORCH_HOME": "/tmp/torch_home",
21
+ "XDG_CACHE_HOME": "/tmp/xdg_cache"
22
+ }
23
 
24
+ # Set environment variables and create directories
25
+ for env_var, path in cache_dirs.items():
26
+ os.environ[env_var] = path
27
+ try:
28
+ os.makedirs(path, exist_ok=True)
29
+ logger.info(f"πŸ“ Created cache directory: {path}")
30
+ except Exception as e:
31
+ logger.error(f"❌ Failed to create directory {path}: {str(e)}")
32
 
33
  # Now import the rest of the libraries
34
+ try:
35
+ import torch
36
+ from pydub import AudioSegment
37
+ import tempfile
38
+ import torchaudio
39
+ import soundfile as sf
40
+ from flask import Flask, request, jsonify, send_file
41
+ from flask_cors import CORS
42
+ from transformers import Wav2Vec2ForCTC, AutoProcessor, VitsModel, AutoTokenizer
43
+ logger.info("βœ… All required libraries imported successfully")
44
+ except ImportError as e:
45
+ logger.critical(f"❌ Failed to import necessary libraries: {str(e)}")
46
+ sys.exit(1)
47
+
48
+ # Check CUDA availability
49
+ if torch.cuda.is_available():
50
+ logger.info(f"πŸš€ CUDA available: {torch.cuda.get_device_name(0)}")
51
+ device = "cuda"
52
+ else:
53
+ logger.info("⚠️ CUDA not available, using CPU")
54
+ device = "cpu"
55
 
56
  app = Flask(__name__)
57
  CORS(app)
58
 
59
  # ASR Model
60
  ASR_MODEL_ID = "Coco-18/mms-asr-tgl-en-safetensor"
61
+ logger.info(f"πŸ”„ Loading ASR model: {ASR_MODEL_ID}")
62
+
63
+ asr_processor = None
64
+ asr_model = None
65
 
66
  try:
67
  asr_processor = AutoProcessor.from_pretrained(
68
  ASR_MODEL_ID,
69
+ cache_dir=cache_dirs["TRANSFORMERS_CACHE"]
70
  )
71
+ logger.info("βœ… ASR processor loaded successfully")
72
+
73
  asr_model = Wav2Vec2ForCTC.from_pretrained(
74
  ASR_MODEL_ID,
75
+ cache_dir=cache_dirs["TRANSFORMERS_CACHE"]
76
  )
77
+ asr_model.to(device)
78
+ logger.info(f"βœ… ASR model loaded successfully on {device}")
79
  except Exception as e:
80
+ logger.error(f"❌ Error loading ASR model: {str(e)}")
81
+ logger.debug(f"Stack trace: {traceback.format_exc()}")
82
+ logger.debug(f"Python version: {sys.version}")
83
+ logger.debug(f"Current working directory: {os.getcwd()}")
84
+ logger.debug(f"Temp directory exists: {os.path.exists('/tmp')}")
85
+ logger.debug(f"Temp directory writeable: {os.access('/tmp', os.W_OK)}")
 
 
86
 
87
  # Language-specific configurations
88
  LANGUAGE_CODES = {
 
101
  tts_models = {}
102
  tts_processors = {}
103
  for lang, model_id in TTS_MODELS.items():
104
+ logger.info(f"πŸ”„ Loading TTS model for {lang}: {model_id}")
105
  try:
106
+ tts_processors[lang] = AutoTokenizer.from_pretrained(
107
  model_id,
108
+ cache_dir=cache_dirs["TRANSFORMERS_CACHE"]
109
  )
110
+ logger.info(f"βœ… {lang} TTS processor loaded")
111
+
112
+ tts_models[lang] = VitsModel.from_pretrained(
113
  model_id,
114
+ cache_dir=cache_dirs["TRANSFORMERS_CACHE"]
115
  )
116
+ tts_models[lang].to(device)
117
+ logger.info(f"βœ… {lang} TTS model loaded on {device}")
118
  except Exception as e:
119
+ logger.error(f"❌ Failed to load {lang} TTS model: {str(e)}")
120
+ logger.debug(f"Stack trace: {traceback.format_exc()}")
121
  tts_models[lang] = None
122
 
123
  # Constants
124
  SAMPLE_RATE = 16000
125
  OUTPUT_DIR = "/tmp/audio_outputs"
126
+ try:
127
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
128
+ logger.info(f"πŸ“ Created output directory: {OUTPUT_DIR}")
129
+ except Exception as e:
130
+ logger.error(f"❌ Failed to create output directory: {str(e)}")
131
 
132
  @app.route("/", methods=["GET"])
133
  def home():
134
+ return jsonify({"message": "Speech API is running", "status": "active"})
135
+
136
+ @app.route("/health", methods=["GET"])
137
+ def health_check():
138
+ health_status = {
139
+ "api_status": "online",
140
+ "asr_model": "loaded" if asr_model is not None else "failed",
141
+ "tts_models": {lang: "loaded" if model is not None else "failed"
142
+ for lang, model in tts_models.items()},
143
+ "device": device
144
+ }
145
+ return jsonify(health_status)
146
 
147
  @app.route("/asr", methods=["POST"])
148
  def transcribe_audio():
149
+ if asr_model is None or asr_processor is None:
150
+ logger.error("❌ ASR endpoint called but models aren't loaded")
151
+ return jsonify({"error": "ASR model not available"}), 503
152
+
153
  try:
154
  if "audio" not in request.files:
155
+ logger.warning("⚠️ ASR request missing audio file")
156
  return jsonify({"error": "No audio file uploaded"}), 400
157
 
158
  audio_file = request.files["audio"]
159
  language = request.form.get("language", "english").lower()
160
 
161
  if language not in LANGUAGE_CODES:
162
+ logger.warning(f"⚠️ Unsupported language requested: {language}")
163
+ return jsonify({"error": f"Unsupported language: {language}. Available: {list(LANGUAGE_CODES.keys())}"}), 400
164
 
165
  lang_code = LANGUAGE_CODES[language]
166
+ logger.info(f"πŸ”„ Processing {language} audio for ASR")
167
 
168
  # Save the uploaded file temporarily
169
  with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(audio_file.filename)[-1]) as temp_audio:
170
  temp_audio.write(audio_file.read())
171
  temp_audio_path = temp_audio.name
172
+ logger.debug(f"πŸ“ Temporary audio saved to {temp_audio_path}")
173
 
174
  # Convert to WAV if necessary
175
  wav_path = temp_audio_path
176
  if not audio_file.filename.lower().endswith(".wav"):
177
  wav_path = os.path.join(OUTPUT_DIR, "converted_audio.wav")
178
+ logger.info(f"πŸ”„ Converting audio to WAV format: {wav_path}")
179
+ try:
180
+ audio = AudioSegment.from_file(temp_audio_path)
181
+ audio = audio.set_frame_rate(SAMPLE_RATE).set_channels(1)
182
+ audio.export(wav_path, format="wav")
183
+ except Exception as e:
184
+ logger.error(f"❌ Audio conversion failed: {str(e)}")
185
+ return jsonify({"error": f"Audio conversion failed: {str(e)}"}), 500
186
 
187
  # Load and process the WAV file
188
+ try:
189
+ waveform, sr = torchaudio.load(wav_path)
190
+ logger.debug(f"βœ… Audio loaded: {wav_path} (Sample rate: {sr}Hz)")
191
 
192
+ # Resample if needed
193
+ if sr != SAMPLE_RATE:
194
+ logger.info(f"πŸ”„ Resampling audio from {sr}Hz to {SAMPLE_RATE}Hz")
195
+ waveform = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(waveform)
196
 
197
+ waveform = waveform / torch.max(torch.abs(waveform))
198
+ except Exception as e:
199
+ logger.error(f"❌ Failed to load or process audio: {str(e)}")
200
+ return jsonify({"error": f"Audio processing failed: {str(e)}"}), 500
201
 
202
  # Process audio for ASR
203
+ try:
204
+ inputs = asr_processor(
205
+ waveform.squeeze().numpy(),
206
+ sampling_rate=SAMPLE_RATE,
207
+ return_tensors="pt",
208
+ language=lang_code
209
+ )
210
+ inputs = {k: v.to(device) for k, v in inputs.items()}
211
+ except Exception as e:
212
+ logger.error(f"❌ ASR preprocessing failed: {str(e)}")
213
+ return jsonify({"error": f"ASR preprocessing failed: {str(e)}"}), 500
214
 
215
  # Perform ASR
216
+ try:
217
+ with torch.no_grad():
218
+ logits = asr_model(**inputs).logits
219
+ ids = torch.argmax(logits, dim=-1)[0]
220
+ transcription = asr_processor.decode(ids)
221
+
222
+ logger.info(f"βœ… Transcription ({language}): {transcription}")
223
+
224
+ # Clean up temp files
225
+ try:
226
+ os.unlink(temp_audio_path)
227
+ if wav_path != temp_audio_path:
228
+ os.unlink(wav_path)
229
+ except Exception as e:
230
+ logger.warning(f"⚠️ Failed to clean up temp files: {str(e)}")
231
+
232
+ return jsonify({
233
+ "transcription": transcription,
234
+ "language": language,
235
+ "language_code": lang_code
236
+ })
237
+ except Exception as e:
238
+ logger.error(f"❌ ASR inference failed: {str(e)}")
239
+ logger.debug(f"Stack trace: {traceback.format_exc()}")
240
+ return jsonify({"error": f"ASR inference failed: {str(e)}"}), 500
241
 
242
  except Exception as e:
243
+ logger.error(f"❌ Unhandled exception in ASR endpoint: {str(e)}")
244
+ logger.debug(f"Stack trace: {traceback.format_exc()}")
245
+ return jsonify({"error": f"Internal server error: {str(e)}"}), 500
246
 
247
 
248
  @app.route("/tts", methods=["POST"])
249
  def generate_tts():
250
  try:
251
  data = request.get_json()
252
+ if not data:
253
+ logger.warning("⚠️ TTS endpoint called with no JSON data")
254
+ return jsonify({"error": "No JSON data provided"}), 400
255
+
256
  text_input = data.get("text", "").strip()
257
  language = data.get("language", "kapampangan").lower()
258
 
 
 
259
  if not text_input:
260
+ logger.warning("⚠️ TTS request with empty text")
261
  return jsonify({"error": "No text provided"}), 400
262
+
263
+ if language not in TTS_MODELS:
264
+ logger.warning(f"⚠️ TTS requested for unsupported language: {language}")
265
+ return jsonify({"error": f"Invalid language. Available options: {list(TTS_MODELS.keys())}"}), 400
266
+
267
  if tts_models[language] is None:
268
+ logger.error(f"❌ TTS model for {language} not loaded")
269
+ return jsonify({"error": f"TTS model for {language} not available"}), 503
270
+
271
+ logger.info(f"πŸ”„ Generating TTS for language: {language}, text: '{text_input}'")
272
+
273
+ try:
274
+ processor = tts_processors[language]
275
+ model = tts_models[language]
276
+ inputs = processor(text_input, return_tensors="pt")
277
+ inputs = {k: v.to(device) for k, v in inputs.items()}
278
+ except Exception as e:
279
+ logger.error(f"❌ TTS preprocessing failed: {str(e)}")
280
+ return jsonify({"error": f"TTS preprocessing failed: {str(e)}"}), 500
281
+
282
+ # Generate speech
283
+ try:
284
+ with torch.no_grad():
285
+ output = model(**inputs).waveform
286
+ waveform = output.squeeze().cpu().numpy()
287
+ except Exception as e:
288
+ logger.error(f"❌ TTS inference failed: {str(e)}")
289
+ logger.debug(f"Stack trace: {traceback.format_exc()}")
290
+ return jsonify({"error": f"TTS inference failed: {str(e)}"}), 500
291
 
292
  # Save to file
293
+ try:
294
+ output_filename = os.path.join(OUTPUT_DIR, f"{language}_output.wav")
295
+ sampling_rate = model.config.sampling_rate
296
+ sf.write(output_filename, waveform, sampling_rate)
297
+ logger.info(f"βœ… Speech generated! File saved: {output_filename}")
298
+ except Exception as e:
299
+ logger.error(f"❌ Failed to save audio file: {str(e)}")
300
+ return jsonify({"error": f"Failed to save audio file: {str(e)}"}), 500
301
 
302
  return jsonify({
303
  "message": "TTS audio generated",
304
+ "file_url": f"/download/{os.path.basename(output_filename)}",
305
+ "language": language,
306
+ "text_length": len(text_input)
307
  })
308
  except Exception as e:
309
+ logger.error(f"❌ Unhandled exception in TTS endpoint: {str(e)}")
310
+ logger.debug(f"Stack trace: {traceback.format_exc()}")
311
  return jsonify({"error": f"Internal server error: {str(e)}"}), 500
312
 
313
 
 
315
  def download_audio(filename):
316
  file_path = os.path.join(OUTPUT_DIR, filename)
317
  if os.path.exists(file_path):
318
+ logger.info(f"πŸ“€ Serving audio file: {file_path}")
319
  return send_file(file_path, mimetype="audio/wav", as_attachment=True)
320
+
321
+ logger.warning(f"⚠️ Requested file not found: {file_path}")
322
  return jsonify({"error": "File not found"}), 404
323
 
324
 
325
  if __name__ == "__main__":
326
+ logger.info("πŸš€ Starting Speech API server")
327
+ logger.info(f"πŸ“Š System status: ASR model: {'βœ…' if asr_model else '❌'}")
328
+ for lang, model in tts_models.items():
329
+ logger.info(f"πŸ“Š TTS model {lang}: {'βœ…' if model else '❌'}")
330
+
331
  app.run(host="0.0.0.0", port=7860, debug=True)