Coco-18 committed on
Commit
f083352
Β·
verified Β·
1 Parent(s): 1a61e31

Update translator.py

Browse files
Files changed (1) hide show
  1. translator.py +440 -435
translator.py CHANGED
@@ -1,435 +1,440 @@
1
- # translator.py - Handles ASR, TTS, and translation tasks
2
-
3
- import os
4
- import sys
5
- import logging
6
- import traceback
7
- import torch
8
- import torchaudio
9
- import tempfile
10
- import soundfile as sf
11
- from pydub import AudioSegment
12
- from flask import jsonify
13
- from transformers import Wav2Vec2ForCTC, AutoProcessor, VitsModel, AutoTokenizer
14
- from transformers import MarianMTModel, MarianTokenizer
15
-
16
- # Configure logging
17
- logger = logging.getLogger("speech_api")
18
-
19
- # Global variables to store models and processors
20
- asr_model = None
21
- asr_processor = None
22
- tts_models = {}
23
- tts_processors = {}
24
- translation_models = {}
25
- translation_tokenizers = {}
26
-
27
- # Language-specific configurations
28
- LANGUAGE_CODES = {
29
- "kapampangan": "pam",
30
- "filipino": "fil",
31
- "english": "eng",
32
- "tagalog": "tgl",
33
- }
34
-
35
- # TTS Models (Kapampangan, Tagalog, English)
36
- TTS_MODELS = {
37
- "kapampangan": "facebook/mms-tts-pam",
38
- "tagalog": "facebook/mms-tts-tgl",
39
- "english": "facebook/mms-tts-eng"
40
- }
41
-
42
- # Translation Models
43
- TRANSLATION_MODELS = {
44
- "pam-eng": "Coco-18/opus-mt-pam-en",
45
- "eng-pam": "Coco-18/opus-mt-en-pam",
46
- "tgl-eng": "Helsinki-NLP/opus-mt-tl-en",
47
- "eng-tgl": "Helsinki-NLP/opus-mt-en-tl",
48
- "phi": "Coco-18/opus-mt-phi"
49
- }
50
-
51
-
52
- def init_models(device):
53
- """Initialize all models required for the API"""
54
- global asr_model, asr_processor, tts_models, tts_processors, translation_models, translation_tokenizers
55
-
56
- # Initialize ASR model
57
- ASR_MODEL_ID = "Coco-18/mms-asr-tgl-en-safetensor"
58
- logger.info(f"πŸ”„ Loading ASR model: {ASR_MODEL_ID}")
59
-
60
- try:
61
- asr_processor = AutoProcessor.from_pretrained(
62
- ASR_MODEL_ID,
63
- cache_dir=os.environ.get("TRANSFORMERS_CACHE")
64
- )
65
- logger.info("βœ… ASR processor loaded successfully")
66
-
67
- asr_model = Wav2Vec2ForCTC.from_pretrained(
68
- ASR_MODEL_ID,
69
- cache_dir=os.environ.get("TRANSFORMERS_CACHE")
70
- )
71
- asr_model.to(device)
72
- logger.info(f"βœ… ASR model loaded successfully on {device}")
73
- except Exception as e:
74
- logger.error(f"❌ Error loading ASR model: {str(e)}")
75
- logger.debug(f"Stack trace: {traceback.format_exc()}")
76
-
77
- # Initialize TTS models
78
- for lang, model_id in TTS_MODELS.items():
79
- logger.info(f"πŸ”„ Loading TTS model for {lang}: {model_id}")
80
- try:
81
- tts_processors[lang] = AutoTokenizer.from_pretrained(
82
- model_id,
83
- cache_dir=os.environ.get("TRANSFORMERS_CACHE")
84
- )
85
- logger.info(f"βœ… {lang} TTS processor loaded")
86
-
87
- tts_models[lang] = VitsModel.from_pretrained(
88
- model_id,
89
- cache_dir=os.environ.get("TRANSFORMERS_CACHE")
90
- )
91
- tts_models[lang].to(device)
92
- logger.info(f"βœ… {lang} TTS model loaded on {device}")
93
- except Exception as e:
94
- logger.error(f"❌ Failed to load {lang} TTS model: {str(e)}")
95
- logger.debug(f"Stack trace: {traceback.format_exc()}")
96
- tts_models[lang] = None
97
-
98
- # Initialize translation models
99
- for model_key, model_id in TRANSLATION_MODELS.items():
100
- logger.info(f"πŸ”„ Loading Translation model: {model_id}")
101
-
102
- try:
103
- translation_tokenizers[model_key] = MarianTokenizer.from_pretrained(
104
- model_id,
105
- cache_dir=os.environ.get("TRANSFORMERS_CACHE")
106
- )
107
- logger.info(f"βœ… Translation tokenizer loaded successfully for {model_key}")
108
-
109
- translation_models[model_key] = MarianMTModel.from_pretrained(
110
- model_id,
111
- cache_dir=os.environ.get("TRANSFORMERS_CACHE")
112
- )
113
- translation_models[model_key].to(device)
114
- logger.info(f"βœ… Translation model loaded successfully on {device} for {model_key}")
115
- except Exception as e:
116
- logger.error(f"❌ Error loading Translation model for {model_key}: {str(e)}")
117
- logger.debug(f"Stack trace: {traceback.format_exc()}")
118
- translation_models[model_key] = None
119
- translation_tokenizers[model_key] = None
120
-
121
-
122
- def check_model_status():
123
- """Check and return the status of all models"""
124
- # Initialize direct language pair statuses based on loaded models
125
- translation_status = {}
126
-
127
- # Add status for direct model pairs
128
- for lang_pair in ["pam-eng", "eng-pam", "tgl-eng", "eng-tgl"]:
129
- translation_status[lang_pair] = "loaded" if lang_pair in translation_models and translation_models[
130
- lang_pair] is not None else "failed"
131
-
132
- # Add special phi model status
133
- phi_status = "loaded" if "phi" in translation_models and translation_models["phi"] is not None else "failed"
134
- translation_status["pam-fil"] = phi_status
135
- translation_status["fil-pam"] = phi_status
136
- translation_status["pam-tgl"] = phi_status # Using phi model but replacing tgl with fil
137
- translation_status["tgl-pam"] = phi_status # Using phi model but replacing tgl with fil
138
-
139
- return {
140
- "asr_model": "loaded" if asr_model is not None else "failed",
141
- "tts_models": {lang: "loaded" if model is not None else "failed"
142
- for lang, model in tts_models.items()},
143
- "translation_models": translation_status
144
- }
145
-
146
-
147
- def handle_asr_request(request, output_dir, sample_rate):
148
- """Handle ASR (Automatic Speech Recognition) requests"""
149
- if asr_model is None or asr_processor is None:
150
- logger.error("❌ ASR endpoint called but models aren't loaded")
151
- return jsonify({"error": "ASR model not available"}), 503
152
-
153
- try:
154
- if "audio" not in request.files:
155
- logger.warning("⚠️ ASR request missing audio file")
156
- return jsonify({"error": "No audio file uploaded"}), 400
157
-
158
- audio_file = request.files["audio"]
159
- language = request.form.get("language", "english").lower()
160
-
161
- if language not in LANGUAGE_CODES:
162
- logger.warning(f"⚠️ Unsupported language requested: {language}")
163
- return jsonify(
164
- {"error": f"Unsupported language: {language}. Available: {list(LANGUAGE_CODES.keys())}"}), 400
165
-
166
- lang_code = LANGUAGE_CODES[language]
167
- logger.info(f"πŸ”„ Processing {language} audio for ASR")
168
-
169
- # Save the uploaded file temporarily
170
- with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(audio_file.filename)[-1]) as temp_audio:
171
- temp_audio.write(audio_file.read())
172
- temp_audio_path = temp_audio.name
173
- logger.debug(f"πŸ“ Temporary audio saved to {temp_audio_path}")
174
-
175
- # Convert to WAV if necessary
176
- wav_path = temp_audio_path
177
- if not audio_file.filename.lower().endswith(".wav"):
178
- wav_path = os.path.join(output_dir, "converted_audio.wav")
179
- logger.info(f"πŸ”„ Converting audio to WAV format: {wav_path}")
180
- try:
181
- audio = AudioSegment.from_file(temp_audio_path)
182
- audio = audio.set_frame_rate(sample_rate).set_channels(1)
183
- audio.export(wav_path, format="wav")
184
- except Exception as e:
185
- logger.error(f"❌ Audio conversion failed: {str(e)}")
186
- return jsonify({"error": f"Audio conversion failed: {str(e)}"}), 500
187
-
188
- # Load and process the WAV file
189
- try:
190
- waveform, sr = torchaudio.load(wav_path)
191
- logger.debug(f"βœ… Audio loaded: {wav_path} (Sample rate: {sr}Hz)")
192
-
193
- # Resample if needed
194
- if sr != sample_rate:
195
- logger.info(f"πŸ”„ Resampling audio from {sr}Hz to {sample_rate}Hz")
196
- waveform = torchaudio.transforms.Resample(sr, sample_rate)(waveform)
197
-
198
- waveform = waveform / torch.max(torch.abs(waveform))
199
- except Exception as e:
200
- logger.error(f"❌ Failed to load or process audio: {str(e)}")
201
- return jsonify({"error": f"Audio processing failed: {str(e)}"}), 500
202
-
203
- # Process audio for ASR
204
- try:
205
- inputs = asr_processor(
206
- waveform.squeeze().numpy(),
207
- sampling_rate=sample_rate,
208
- return_tensors="pt",
209
- language=lang_code
210
- )
211
- inputs = {k: v.to(asr_model.device) for k, v in inputs.items()}
212
- except Exception as e:
213
- logger.error(f"❌ ASR preprocessing failed: {str(e)}")
214
- return jsonify({"error": f"ASR preprocessing failed: {str(e)}"}), 500
215
-
216
- # Perform ASR
217
- try:
218
- with torch.no_grad():
219
- logits = asr_model(**inputs).logits
220
- ids = torch.argmax(logits, dim=-1)[0]
221
- transcription = asr_processor.decode(ids)
222
-
223
- logger.info(f"βœ… Transcription ({language}): {transcription}")
224
-
225
- # Clean up temp files
226
- try:
227
- os.unlink(temp_audio_path)
228
- if wav_path != temp_audio_path:
229
- os.unlink(wav_path)
230
- except Exception as e:
231
- logger.warning(f"⚠️ Failed to clean up temp files: {str(e)}")
232
-
233
- return jsonify({
234
- "transcription": transcription,
235
- "language": language,
236
- "language_code": lang_code
237
- })
238
- except Exception as e:
239
- logger.error(f"❌ ASR inference failed: {str(e)}")
240
- logger.debug(f"Stack trace: {traceback.format_exc()}")
241
- return jsonify({"error": f"ASR inference failed: {str(e)}"}), 500
242
-
243
- except Exception as e:
244
- logger.error(f"❌ Unhandled exception in ASR endpoint: {str(e)}")
245
- logger.debug(f"Stack trace: {traceback.format_exc()}")
246
- return jsonify({"error": f"Internal server error: {str(e)}"}), 500
247
-
248
- def handle_tts_request(request, output_dir):
249
- """Handle TTS (Text-to-Speech) requests"""
250
- try:
251
- data = request.get_json()
252
- if not data:
253
- logger.warning("⚠️ TTS endpoint called with no JSON data")
254
- return jsonify({"error": "No JSON data provided"}), 400
255
-
256
- text_input = data.get("text", "").strip()
257
- language = data.get("language", "kapampangan").lower()
258
-
259
- if not text_input:
260
- logger.warning("⚠️ TTS request with empty text")
261
- return jsonify({"error": "No text provided"}), 400
262
-
263
- if language not in TTS_MODELS:
264
- logger.warning(f"⚠️ TTS requested for unsupported language: {language}")
265
- return jsonify({"error": f"Invalid language. Available options: {list(TTS_MODELS.keys())}"}), 400
266
-
267
- if tts_models[language] is None:
268
- logger.error(f"❌ TTS model for {language} not loaded")
269
- return jsonify({"error": f"TTS model for {language} not available"}), 503
270
-
271
- logger.info(f"πŸ”„ Generating TTS for language: {language}, text: '{text_input}'")
272
-
273
- try:
274
- processor = tts_processors[language]
275
- model = tts_models[language]
276
- inputs = processor(text_input, return_tensors="pt")
277
- inputs = {k: v.to(model.device) for k, v in inputs.items()}
278
- except Exception as e:
279
- logger.error(f"❌ TTS preprocessing failed: {str(e)}")
280
- return jsonify({"error": f"TTS preprocessing failed: {str(e)}"}), 500
281
-
282
- # Generate speech
283
- try:
284
- with torch.no_grad():
285
- output = model(**inputs).waveform
286
- waveform = output.squeeze().cpu().numpy()
287
- except Exception as e:
288
- logger.error(f"❌ TTS inference failed: {str(e)}")
289
- logger.debug(f"Stack trace: {traceback.format_exc()}")
290
- return jsonify({"error": f"TTS inference failed: {str(e)}"}), 500
291
-
292
- # Save to file
293
- try:
294
- output_filename = os.path.join(output_dir, f"{language}_output.wav")
295
- sampling_rate = model.config.sampling_rate
296
- sf.write(output_filename, waveform, sampling_rate)
297
- logger.info(f"βœ… Speech generated! File saved: {output_filename}")
298
- except Exception as e:
299
- logger.error(f"❌ Failed to save audio file: {str(e)}")
300
- return jsonify({"error": f"Failed to save audio file: {str(e)}"}), 500
301
-
302
- return jsonify({
303
- "message": "TTS audio generated",
304
- "file_url": f"/download/{os.path.basename(output_filename)}",
305
- "language": language,
306
- "text_length": len(text_input)
307
- })
308
- except Exception as e:
309
- logger.error(f"❌ Unhandled exception in TTS endpoint: {str(e)}")
310
- logger.debug(f"Stack trace: {traceback.format_exc()}")
311
- return jsonify({"error": f"Internal server error: {str(e)}"}), 500
312
-
313
- def handle_translation_request(request):
314
- """Handle translation requests"""
315
- try:
316
- data = request.get_json()
317
- if not data:
318
- logger.warning("⚠️ Translation endpoint called with no JSON data")
319
- return jsonify({"error": "No JSON data provided"}), 400
320
-
321
- source_text = data.get("text", "").strip()
322
- source_language = data.get("source_language", "").lower()
323
- target_language = data.get("target_language", "").lower()
324
-
325
- if not source_text:
326
- logger.warning("⚠️ Translation request with empty text")
327
- return jsonify({"error": "No text provided"}), 400
328
-
329
- # Map language names to codes
330
- source_code = LANGUAGE_CODES.get(source_language, source_language)
331
- target_code = LANGUAGE_CODES.get(target_language, target_language)
332
-
333
- logger.info(f"πŸ”„ Translating from {source_language} to {target_language}: '{source_text}'")
334
-
335
- # Special handling for pam-fil, fil-pam, pam-tgl and tgl-pam using the phi model
336
- use_phi_model = False
337
- actual_source_code = source_code
338
- actual_target_code = target_code
339
-
340
- # Check if we need to use the phi model with fil replacement
341
- if (source_code == "pam" and target_code == "fil") or (source_code == "fil" and target_code == "pam"):
342
- use_phi_model = True
343
- elif (source_code == "pam" and target_code == "tgl"):
344
- use_phi_model = True
345
- actual_target_code = "fil" # Replace tgl with fil for the phi model
346
- elif (source_code == "tgl" and target_code == "pam"):
347
- use_phi_model = True
348
- actual_source_code = "fil" # Replace tgl with fil for the phi model
349
-
350
- if use_phi_model:
351
- model_key = "phi"
352
-
353
- # Check if we have the phi model
354
- if model_key not in translation_models or translation_models[model_key] is None:
355
- logger.error(f"❌ Translation model for {model_key} not loaded")
356
- return jsonify({"error": f"Translation model not available"}), 503
357
-
358
- try:
359
- # Get the phi model and tokenizer
360
- model = translation_models[model_key]
361
- tokenizer = translation_tokenizers[model_key]
362
-
363
- # Prepend target language token to input
364
- input_text = f">>{actual_target_code}<< {source_text}"
365
-
366
- logger.info(f"πŸ”„ Using phi model with input: '{input_text}'")
367
-
368
- # Tokenize the text
369
- tokenized = tokenizer(input_text, return_tensors="pt", padding=True)
370
- tokenized = {k: v.to(model.device) for k, v in tokenized.items()}
371
-
372
- # Generate translation
373
- with torch.no_grad():
374
- translated = model.generate(**tokenized)
375
-
376
- # Decode the translation
377
- result = tokenizer.decode(translated[0], skip_special_tokens=True)
378
-
379
- logger.info(f"βœ… Translation result: '{result}'")
380
-
381
- return jsonify({
382
- "translated_text": result,
383
- "source_language": source_language,
384
- "target_language": target_language
385
- })
386
- except Exception as e:
387
- logger.error(f"❌ Translation processing failed: {str(e)}")
388
- logger.debug(f"Stack trace: {traceback.format_exc()}")
389
- return jsonify({"error": f"Translation processing failed: {str(e)}"}), 500
390
- else:
391
- # Create the regular language pair key for other language pairs
392
- lang_pair = f"{source_code}-{target_code}"
393
-
394
- # Check if we have a model for this language pair
395
- if lang_pair not in translation_models:
396
- logger.warning(f"⚠️ No translation model available for {lang_pair}")
397
- return jsonify(
398
- {"error": f"Translation from {source_language} to {target_language} is not supported yet"}), 400
399
-
400
- if translation_models[lang_pair] is None or translation_tokenizers[lang_pair] is None:
401
- logger.error(f"❌ Translation model for {lang_pair} not loaded")
402
- return jsonify({"error": f"Translation model not available"}), 503
403
-
404
- try:
405
- # Regular translation process for other language pairs
406
- model = translation_models[lang_pair]
407
- tokenizer = translation_tokenizers[lang_pair]
408
-
409
- # Tokenize the text
410
- tokenized = tokenizer(source_text, return_tensors="pt", padding=True)
411
- tokenized = {k: v.to(model.device) for k, v in tokenized.items()}
412
-
413
- # Generate translation
414
- with torch.no_grad():
415
- translated = model.generate(**tokenized)
416
-
417
- # Decode the translation
418
- result = tokenizer.decode(translated[0], skip_special_tokens=True)
419
-
420
- logger.info(f"βœ… Translation result: '{result}'")
421
-
422
- return jsonify({
423
- "translated_text": result,
424
- "source_language": source_language,
425
- "target_language": target_language
426
- })
427
- except Exception as e:
428
- logger.error(f"❌ Translation processing failed: {str(e)}")
429
- logger.debug(f"Stack trace: {traceback.format_exc()}")
430
- return jsonify({"error": f"Translation processing failed: {str(e)}"}), 500
431
-
432
- except Exception as e:
433
- logger.error(f"❌ Unhandled exception in translation endpoint: {str(e)}")
434
- logger.debug(f"Stack trace: {traceback.format_exc()}")
435
- return jsonify({"error": f"Internal server error: {str(e)}"}), 500
 
 
 
 
 
 
1
+ # translator.py - Handles ASR, TTS, and translation tasks
2
+
3
+ import os
4
+ import sys
5
+ import logging
6
+ import traceback
7
+ import torch
8
+ import torchaudio
9
+ import tempfile
10
+ import soundfile as sf
11
+ from pydub import AudioSegment
12
+ from flask import jsonify
13
+ from transformers import Wav2Vec2ForCTC, AutoProcessor, VitsModel, AutoTokenizer
14
+ from transformers import MarianMTModel, MarianTokenizer
15
+
16
+ # Configure logging
17
+ logger = logging.getLogger("speech_api")
18
+
19
+ # Global variables to store models and processors
20
+ asr_model = None
21
+ asr_processor = None
22
+ tts_models = {}
23
+ tts_processors = {}
24
+ translation_models = {}
25
+ translation_tokenizers = {}
26
+
27
+ # Language-specific configurations
28
+ LANGUAGE_CODES = {
29
+ "kapampangan": "pam",
30
+ "filipino": "fil",
31
+ "english": "eng",
32
+ "tagalog": "tgl",
33
+ }
34
+
35
+ # TTS Models (Kapampangan, Tagalog, English)
36
+ TTS_MODELS = {
37
+ "kapampangan": "facebook/mms-tts-pam",
38
+ "tagalog": "facebook/mms-tts-tgl",
39
+ "english": "facebook/mms-tts-eng"
40
+ }
41
+
42
+ # Translation Models
43
+ TRANSLATION_MODELS = {
44
+ "pam-eng": "Coco-18/opus-mt-pam-en",
45
+ "eng-pam": "Coco-18/opus-mt-en-pam",
46
+ "tgl-eng": "Helsinki-NLP/opus-mt-tl-en",
47
+ "eng-tgl": "Helsinki-NLP/opus-mt-en-tl",
48
+ "phi": "Coco-18/opus-mt-phi"
49
+ }
50
+
51
def init_models(device):
    """Load the ASR, TTS, and translation models onto *device*.

    Populates the module-level registries (asr_model/asr_processor,
    tts_models/tts_processors, translation_models/translation_tokenizers).
    Load failures are logged and recorded as ``None`` entries instead of
    being raised, so the API can start with a partial model set.
    """
    global asr_model, asr_processor, tts_models, tts_processors, translation_models, translation_tokenizers

    # Same cache location is used for every download.
    hf_cache = os.environ.get("TRANSFORMERS_CACHE")

    # ---- ASR ----------------------------------------------------------
    ASR_MODEL_ID = "Coco-18/mms-asr-tgl-en-safetensor"
    logger.info(f"πŸ”„ Loading ASR model: {ASR_MODEL_ID}")
    try:
        asr_processor = AutoProcessor.from_pretrained(ASR_MODEL_ID, cache_dir=hf_cache)
        logger.info("βœ… ASR processor loaded successfully")

        asr_model = Wav2Vec2ForCTC.from_pretrained(ASR_MODEL_ID, cache_dir=hf_cache)
        asr_model.to(device)
        logger.info(f"βœ… ASR model loaded successfully on {device}")
    except Exception as e:
        logger.error(f"❌ Error loading ASR model: {str(e)}")
        logger.debug(f"Stack trace: {traceback.format_exc()}")

    # ---- TTS (one VITS model per language) ----------------------------
    for language, repo_id in TTS_MODELS.items():
        logger.info(f"πŸ”„ Loading TTS model for {language}: {repo_id}")
        try:
            tts_processors[language] = AutoTokenizer.from_pretrained(repo_id, cache_dir=hf_cache)
            logger.info(f"βœ… {language} TTS processor loaded")

            tts_models[language] = VitsModel.from_pretrained(repo_id, cache_dir=hf_cache)
            tts_models[language].to(device)
            logger.info(f"βœ… {language} TTS model loaded on {device}")
        except Exception as e:
            logger.error(f"❌ Failed to load {language} TTS model: {str(e)}")
            logger.debug(f"Stack trace: {traceback.format_exc()}")
            tts_models[language] = None

    # ---- Translation (Marian models, keyed by language pair) ---------
    for pair, repo_id in TRANSLATION_MODELS.items():
        logger.info(f"πŸ”„ Loading Translation model: {repo_id}")
        try:
            translation_tokenizers[pair] = MarianTokenizer.from_pretrained(repo_id, cache_dir=hf_cache)
            logger.info(f"βœ… Translation tokenizer loaded successfully for {pair}")

            translation_models[pair] = MarianMTModel.from_pretrained(repo_id, cache_dir=hf_cache)
            translation_models[pair].to(device)
            logger.info(f"βœ… Translation model loaded successfully on {device} for {pair}")
        except Exception as e:
            logger.error(f"❌ Error loading Translation model for {pair}: {str(e)}")
            logger.debug(f"Stack trace: {traceback.format_exc()}")
            translation_models[pair] = None
            translation_tokenizers[pair] = None
119
+
120
+
121
def check_model_status():
    """Report load status ("loaded"/"failed") for every model the API uses.

    Returns a dict with keys "asr_model", "tts_models" (per language), and
    "translation_models" (per language pair, including the phi-backed
    aliases).
    """
    # Direct Marian pairs map one-to-one onto entries in translation_models.
    translation_status = {
        pair: "loaded" if translation_models.get(pair) is not None else "failed"
        for pair in ("pam-eng", "eng-pam", "tgl-eng", "eng-tgl")
    }

    # pam<->fil and pam<->tgl all ride on the single multilingual "phi"
    # model (tgl is internally treated as fil), so they share its status.
    phi_status = "loaded" if translation_models.get("phi") is not None else "failed"
    for alias in ("pam-fil", "fil-pam", "pam-tgl", "tgl-pam"):
        translation_status[alias] = phi_status

    return {
        "asr_model": "loaded" if asr_model is not None else "failed",
        "tts_models": {language: "loaded" if model is not None else "failed"
                       for language, model in tts_models.items()},
        "translation_models": translation_status,
    }
144
+
145
+
146
def handle_asr_request(request, output_dir, sample_rate):
    """Handle ASR (Automatic Speech Recognition) requests.

    Expects a multipart form with an "audio" file and an optional
    "language" field (default "english"). The upload is converted to a
    mono WAV at *sample_rate* if needed, peak-normalized, and run through
    the Wav2Vec2 CTC model.

    Returns a flask JSON response with the transcription, or an
    (error-response, status-code) tuple on failure.
    """
    if asr_model is None or asr_processor is None:
        logger.error("❌ ASR endpoint called but models aren't loaded")
        return jsonify({"error": "ASR model not available"}), 503

    try:
        if "audio" not in request.files:
            logger.warning("⚠️ ASR request missing audio file")
            return jsonify({"error": "No audio file uploaded"}), 400

        audio_file = request.files["audio"]
        language = request.form.get("language", "english").lower()

        if language not in LANGUAGE_CODES:
            logger.warning(f"⚠️ Unsupported language requested: {language}")
            return jsonify(
                {"error": f"Unsupported language: {language}. Available: {list(LANGUAGE_CODES.keys())}"}), 400

        lang_code = LANGUAGE_CODES[language]
        logger.info(f"πŸ”„ Processing {language} audio for ASR")

        # Save the uploaded file temporarily
        with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(audio_file.filename)[-1]) as temp_audio:
            temp_audio.write(audio_file.read())
            temp_audio_path = temp_audio.name
            logger.debug(f"πŸ“ Temporary audio saved to {temp_audio_path}")

        # Convert to WAV if necessary
        wav_path = temp_audio_path
        if not audio_file.filename.lower().endswith(".wav"):
            # FIX: use a unique temp file instead of a fixed
            # "converted_audio.wav" in the shared output_dir — with the
            # fixed name, concurrent requests overwrote each other's audio.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav", dir=output_dir) as converted:
                wav_path = converted.name
            logger.info(f"πŸ”„ Converting audio to WAV format: {wav_path}")
            try:
                audio = AudioSegment.from_file(temp_audio_path)
                audio = audio.set_frame_rate(sample_rate).set_channels(1)
                audio.export(wav_path, format="wav")
            except Exception as e:
                logger.error(f"❌ Audio conversion failed: {str(e)}")
                return jsonify({"error": f"Audio conversion failed: {str(e)}"}), 500

        # Load and process the WAV file
        try:
            waveform, sr = torchaudio.load(wav_path)
            logger.debug(f"βœ… Audio loaded: {wav_path} (Sample rate: {sr}Hz)")

            # Resample if needed
            if sr != sample_rate:
                logger.info(f"πŸ”„ Resampling audio from {sr}Hz to {sample_rate}Hz")
                waveform = torchaudio.transforms.Resample(sr, sample_rate)(waveform)

            # FIX: peak-normalize only when the signal is non-silent.
            # Dividing by a zero peak (all-zero/silent audio) produced NaN
            # samples and garbage model input.
            peak = torch.max(torch.abs(waveform))
            if peak > 0:
                waveform = waveform / peak
        except Exception as e:
            logger.error(f"❌ Failed to load or process audio: {str(e)}")
            return jsonify({"error": f"Audio processing failed: {str(e)}"}), 500

        # Process audio for ASR
        try:
            # NOTE(review): passing language= to the processor call assumes the
            # custom AutoProcessor accepts it — confirm against the model repo.
            inputs = asr_processor(
                waveform.squeeze().numpy(),
                sampling_rate=sample_rate,
                return_tensors="pt",
                language=lang_code
            )
            inputs = {k: v.to(asr_model.device) for k, v in inputs.items()}
        except Exception as e:
            logger.error(f"❌ ASR preprocessing failed: {str(e)}")
            return jsonify({"error": f"ASR preprocessing failed: {str(e)}"}), 500

        # Perform ASR
        try:
            with torch.no_grad():
                logits = asr_model(**inputs).logits
            ids = torch.argmax(logits, dim=-1)[0]
            transcription = asr_processor.decode(ids)

            logger.info(f"βœ… Transcription ({language}): {transcription}")

            # Clean up temp files (best effort — failure is non-fatal)
            try:
                os.unlink(temp_audio_path)
                if wav_path != temp_audio_path:
                    os.unlink(wav_path)
            except Exception as e:
                logger.warning(f"⚠️ Failed to clean up temp files: {str(e)}")

            return jsonify({
                "transcription": transcription,
                "language": language,
                "language_code": lang_code
            })
        except Exception as e:
            logger.error(f"❌ ASR inference failed: {str(e)}")
            logger.debug(f"Stack trace: {traceback.format_exc()}")
            return jsonify({"error": f"ASR inference failed: {str(e)}"}), 500

    except Exception as e:
        logger.error(f"❌ Unhandled exception in ASR endpoint: {str(e)}")
        logger.debug(f"Stack trace: {traceback.format_exc()}")
        return jsonify({"error": f"Internal server error: {str(e)}"}), 500
246
+
247
def handle_tts_request(request, output_dir):
    """Handle TTS (Text-to-Speech) requests.

    Expects JSON with "text" and an optional "language" (default
    "kapampangan"). Synthesizes speech with the matching VITS model,
    writes a WAV into *output_dir*, and returns its download URL.
    """
    try:
        payload = request.get_json()
        if not payload:
            logger.warning("⚠️ TTS endpoint called with no JSON data")
            return jsonify({"error": "No JSON data provided"}), 400

        text_input = payload.get("text", "").strip()
        language = payload.get("language", "kapampangan").lower()

        if not text_input:
            logger.warning("⚠️ TTS request with empty text")
            return jsonify({"error": "No text provided"}), 400

        if language not in TTS_MODELS:
            logger.warning(f"⚠️ TTS requested for unsupported language: {language}")
            return jsonify({"error": f"Invalid language. Available options: {list(TTS_MODELS.keys())}"}), 400

        if tts_models[language] is None:
            logger.error(f"❌ TTS model for {language} not loaded")
            return jsonify({"error": f"TTS model for {language} not available"}), 503

        logger.info(f"πŸ”„ Generating TTS for language: {language}, text: '{text_input}'")

        # Tokenize the text and move the tensors onto the model's device.
        try:
            model = tts_models[language]
            encoded = tts_processors[language](text_input, return_tensors="pt")
            encoded = {key: tensor.to(model.device) for key, tensor in encoded.items()}
        except Exception as e:
            logger.error(f"❌ TTS preprocessing failed: {str(e)}")
            return jsonify({"error": f"TTS preprocessing failed: {str(e)}"}), 500

        # Run the VITS model to synthesize the waveform.
        try:
            with torch.no_grad():
                speech = model(**encoded).waveform
            samples = speech.squeeze().cpu().numpy()
        except Exception as e:
            logger.error(f"❌ TTS inference failed: {str(e)}")
            logger.debug(f"Stack trace: {traceback.format_exc()}")
            return jsonify({"error": f"TTS inference failed: {str(e)}"}), 500

        # Persist the audio; the filename is fixed per language and maps to
        # the /download/ route.
        try:
            output_filename = os.path.join(output_dir, f"{language}_output.wav")
            sf.write(output_filename, samples, model.config.sampling_rate)
            logger.info(f"βœ… Speech generated! File saved: {output_filename}")
        except Exception as e:
            logger.error(f"❌ Failed to save audio file: {str(e)}")
            return jsonify({"error": f"Failed to save audio file: {str(e)}"}), 500

        return jsonify({
            "message": "TTS audio generated",
            "file_url": f"/download/{os.path.basename(output_filename)}",
            "language": language,
            "text_length": len(text_input)
        })
    except Exception as e:
        logger.error(f"❌ Unhandled exception in TTS endpoint: {str(e)}")
        logger.debug(f"Stack trace: {traceback.format_exc()}")
        return jsonify({"error": f"Internal server error: {str(e)}"}), 500
311
+
312
def _run_marian_translation(model, tokenizer, text):
    """Tokenize *text*, run generation on *model*, and decode the first output."""
    tokenized = tokenizer(text, return_tensors="pt", padding=True)
    tokenized = {k: v.to(model.device) for k, v in tokenized.items()}
    with torch.no_grad():
        translated = model.generate(**tokenized)
    return tokenizer.decode(translated[0], skip_special_tokens=True)


def handle_translation_request(request):
    """Handle translation requests.

    Expects JSON with "text", "source_language", and "target_language".
    Routes pam<->fil and pam<->tgl through the multilingual "phi" model
    (tgl is mapped to fil, the code that model was trained with); every
    other pair uses its dedicated Marian model from TRANSLATION_MODELS.

    Returns a flask JSON response with the translation, or an
    (error-response, status-code) tuple on failure.
    """
    try:
        data = request.get_json()
        if not data:
            logger.warning("⚠️ Translation endpoint called with no JSON data")
            return jsonify({"error": "No JSON data provided"}), 400

        source_text = data.get("text", "").strip()
        source_language = data.get("source_language", "").lower()
        target_language = data.get("target_language", "").lower()

        if not source_text:
            logger.warning("⚠️ Translation request with empty text")
            return jsonify({"error": "No text provided"}), 400

        # Map language names to codes; unknown values pass through unchanged.
        source_code = LANGUAGE_CODES.get(source_language, source_language)
        target_code = LANGUAGE_CODES.get(target_language, target_language)

        logger.info(f"πŸ”„ Translating from {source_language} to {target_language}: '{source_text}'")

        # Special handling for pam-fil, fil-pam, pam-tgl and tgl-pam using the phi model
        use_phi_model = False
        actual_source_code = source_code
        actual_target_code = target_code

        if (source_code == "pam" and target_code == "fil") or (source_code == "fil" and target_code == "pam"):
            use_phi_model = True
        elif source_code == "pam" and target_code == "tgl":
            use_phi_model = True
            actual_target_code = "fil"  # Replace tgl with fil for the phi model
        elif source_code == "tgl" and target_code == "pam":
            use_phi_model = True
            actual_source_code = "fil"  # Replace tgl with fil for the phi model

        if use_phi_model:
            model_key = "phi"

            # FIX: also check the tokenizer — a failed load leaves both the
            # model and tokenizer entries as None, and the original code let
            # a None tokenizer surface as a confusing 500 instead of 503.
            if (translation_models.get(model_key) is None
                    or translation_tokenizers.get(model_key) is None):
                logger.error(f"❌ Translation model for {model_key} not loaded")
                return jsonify({"error": f"Translation model not available"}), 503

            try:
                model = translation_models[model_key]
                tokenizer = translation_tokenizers[model_key]

                # Prepend the Marian target-language token so the multilingual
                # model knows which direction to translate.
                input_text = f">>{actual_target_code}<< {source_text}"
                logger.info(f"πŸ”„ Using phi model with input: '{input_text}'")

                result = _run_marian_translation(model, tokenizer, input_text)
                logger.info(f"βœ… Translation result: '{result}'")

                return jsonify({
                    "translated_text": result,
                    "source_language": source_language,
                    "target_language": target_language
                })
            except Exception as e:
                logger.error(f"❌ Translation processing failed: {str(e)}")
                logger.debug(f"Stack trace: {traceback.format_exc()}")
                return jsonify({"error": f"Translation processing failed: {str(e)}"}), 500
        else:
            # Regular language pairs use a dedicated model keyed "src-tgt".
            lang_pair = f"{source_code}-{target_code}"

            if lang_pair not in translation_models:
                logger.warning(f"⚠️ No translation model available for {lang_pair}")
                return jsonify(
                    {"error": f"Translation from {source_language} to {target_language} is not supported yet"}), 400

            if translation_models[lang_pair] is None or translation_tokenizers[lang_pair] is None:
                logger.error(f"❌ Translation model for {lang_pair} not loaded")
                return jsonify({"error": f"Translation model not available"}), 503

            try:
                result = _run_marian_translation(
                    translation_models[lang_pair],
                    translation_tokenizers[lang_pair],
                    source_text,
                )
                logger.info(f"βœ… Translation result: '{result}'")

                return jsonify({
                    "translated_text": result,
                    "source_language": source_language,
                    "target_language": target_language
                })
            except Exception as e:
                logger.error(f"❌ Translation processing failed: {str(e)}")
                logger.debug(f"Stack trace: {traceback.format_exc()}")
                return jsonify({"error": f"Translation processing failed: {str(e)}"}), 500

    except Exception as e:
        logger.error(f"❌ Unhandled exception in translation endpoint: {str(e)}")
        logger.debug(f"Stack trace: {traceback.format_exc()}")
        return jsonify({"error": f"Internal server error: {str(e)}"}), 500
435
+
436
def get_asr_model():
    """Return the module-level ASR model (None until init_models succeeds)."""
    return asr_model
438
+
439
def get_asr_processor():
    """Return the module-level ASR processor (None until init_models succeeds)."""
    return asr_processor