Coco-18 committed · verified
Commit d39ccee · Parent(s): 203cc78

Update translator.py

Files changed (1):
  1. translator.py +336 -173

translator.py CHANGED
@@ -1,4 +1,4 @@
-# translator.py - Handles ASR, TTS, and translation tasks
+# translator.py - Handles ASR, TTS, and translation tasks (OPTIMIZED)
 
 import os
 import sys
@@ -12,6 +12,11 @@ from pydub import AudioSegment
 from flask import jsonify
 from transformers import Wav2Vec2ForCTC, AutoProcessor, VitsModel, AutoTokenizer
 from transformers import MarianMTModel, MarianTokenizer
+import concurrent.futures
+import functools
+import threading
+from concurrent.futures import ThreadPoolExecutor
+from functools import lru_cache
 
 # Configure logging
 logger = logging.getLogger("speech_api")
@@ -24,6 +29,16 @@ tts_processors = {}
 translation_models = {}
 translation_tokenizers = {}
 
+# Caching dictionaries
+asr_cache = {}
+tts_cache = {}
+translation_cache = {}
+
+# Mutex locks for thread safety
+asr_lock = threading.Lock()
+tts_lock = threading.Lock()
+translation_lock = threading.Lock()
+
 # Language-specific configurations
 LANGUAGE_CODES = {
     "kapampangan": "pam",
@@ -48,74 +63,114 @@ TRANSLATION_MODELS = {
     "phi": "Coco-18/opus-mt-phi"
 }
 
+# Cache settings
+MAX_CACHE_SIZE = 100  # Maximum number of items to cache
+CACHE_TTL = 3600  # Time to live in seconds (1 hour)
+
 def init_models(device):
-    """Initialize all models required for the API"""
+    """Initialize all models required for the API with parallelization"""
     global asr_model, asr_processor, tts_models, tts_processors, translation_models, translation_tokenizers
-
-    # Initialize ASR model
-    ASR_MODEL_ID = "Coco-18/mms-asr-tgl-en-safetensor"
-    logger.info(f"🔄 Loading ASR model: {ASR_MODEL_ID}")
-
-    try:
-        asr_processor = AutoProcessor.from_pretrained(
-            ASR_MODEL_ID,
-            cache_dir=os.environ.get("TRANSFORMERS_CACHE")
-        )
-        logger.info("✅ ASR processor loaded successfully")
-
-        asr_model = Wav2Vec2ForCTC.from_pretrained(
-            ASR_MODEL_ID,
-            cache_dir=os.environ.get("TRANSFORMERS_CACHE")
-        )
-        asr_model.to(device)
-        logger.info(f"✅ ASR model loaded successfully on {device}")
-    except Exception as e:
-        logger.error(f"❌ Error loading ASR model: {str(e)}")
-        logger.debug(f"Stack trace: {traceback.format_exc()}")
-
-    # Initialize TTS models
-    for lang, model_id in TTS_MODELS.items():
-        logger.info(f"🔄 Loading TTS model for {lang}: {model_id}")
+
+    logger.info("🔄 Starting parallel model initialization")
+
+    # Define model initialization functions
+    def init_asr():
+        global asr_model, asr_processor
+        ASR_MODEL_ID = "Coco-18/mms-asr-tgl-en-safetensor"
+        try:
+            asr_processor = AutoProcessor.from_pretrained(
+                ASR_MODEL_ID,
+                cache_dir=os.environ.get("TRANSFORMERS_CACHE")
+            )
+
+            asr_model = Wav2Vec2ForCTC.from_pretrained(
+                ASR_MODEL_ID,
+                cache_dir=os.environ.get("TRANSFORMERS_CACHE")
+            )
+            asr_model.to(device)
+            logger.info(f"✅ ASR model loaded successfully on {device}")
+            return True
+        except Exception as e:
+            logger.error(f"❌ Error loading ASR model: {str(e)}")
+            logger.debug(f"Stack trace: {traceback.format_exc()}")
+            return False
+
+    def init_tts(lang, model_id):
         try:
-            tts_processors[lang] = AutoTokenizer.from_pretrained(
+            processor = AutoTokenizer.from_pretrained(
                 model_id,
                 cache_dir=os.environ.get("TRANSFORMERS_CACHE")
             )
-            logger.info(f"✅ {lang} TTS processor loaded")
-
-            tts_models[lang] = VitsModel.from_pretrained(
+
+            model = VitsModel.from_pretrained(
                 model_id,
                 cache_dir=os.environ.get("TRANSFORMERS_CACHE")
            )
-            tts_models[lang].to(device)
+            model.to(device)
             logger.info(f"✅ {lang} TTS model loaded on {device}")
+            return lang, processor, model
         except Exception as e:
             logger.error(f"❌ Failed to load {lang} TTS model: {str(e)}")
             logger.debug(f"Stack trace: {traceback.format_exc()}")
-            tts_models[lang] = None
-
-    # Initialize translation models
-    for model_key, model_id in TRANSLATION_MODELS.items():
-        logger.info(f"🔄 Loading Translation model: {model_id}")
-
+            return lang, None, None
+
+    def init_translation(model_key, model_id):
         try:
-            translation_tokenizers[model_key] = MarianTokenizer.from_pretrained(
+            tokenizer = MarianTokenizer.from_pretrained(
                 model_id,
                 cache_dir=os.environ.get("TRANSFORMERS_CACHE")
            )
-            logger.info(f"✅ Translation tokenizer loaded successfully for {model_key}")
-
-            translation_models[model_key] = MarianMTModel.from_pretrained(
+
+            model = MarianMTModel.from_pretrained(
                 model_id,
                 cache_dir=os.environ.get("TRANSFORMERS_CACHE")
            )
-            translation_models[model_key].to(device)
+            model.to(device)
             logger.info(f"✅ Translation model loaded successfully on {device} for {model_key}")
+            return model_key, tokenizer, model
         except Exception as e:
             logger.error(f"❌ Error loading Translation model for {model_key}: {str(e)}")
             logger.debug(f"Stack trace: {traceback.format_exc()}")
-            translation_models[model_key] = None
-            translation_tokenizers[model_key] = None
+            return model_key, None, None
+
+    # Use ThreadPoolExecutor to initialize models in parallel
+    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
+        # Start ASR model initialization
+        asr_future = executor.submit(init_asr)
+
+        # Start TTS model initialization in parallel
+        tts_futures = {
+            executor.submit(init_tts, lang, model_id): lang
+            for lang, model_id in TTS_MODELS.items()
+        }
+
+        # Start translation model initialization in parallel
+        translation_futures = {
+            executor.submit(init_translation, model_key, model_id): model_key
+            for model_key, model_id in TRANSLATION_MODELS.items()
+        }
+
+        # Wait for all futures to complete and process results
+
+        # Process TTS results
+        for future in concurrent.futures.as_completed(tts_futures):
+            lang, processor, model = future.result()
+            if processor is not None and model is not None:
+                tts_processors[lang] = processor
+                tts_models[lang] = model
+
+        # Process translation results
+        for future in concurrent.futures.as_completed(translation_futures):
+            model_key, tokenizer, model = future.result()
+            if tokenizer is not None and model is not None:
+                translation_tokenizers[model_key] = tokenizer
+                translation_models[model_key] = model
+
+    # Log summary of loaded models
+    logger.info("📊 Model initialization summary:")
+    logger.info(f" - ASR model: {'loaded' if asr_model is not None else 'failed'}")
+    logger.info(f" - TTS models loaded: {sum(1 for m in tts_models.values() if m is not None)}/{len(TTS_MODELS)}")
+    logger.info(f" - Translation models loaded: {sum(1 for m in translation_models.values() if m is not None)}/{len(TRANSLATION_MODELS)}")
 
 
 def check_model_status():
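A note on the executor block above: threads are a reasonable fit here because from_pretrained is dominated by network and disk I/O, which release the GIL. One loose end is that asr_future is never read; since init_asr catches its own exceptions and the with-block joins all tasks on exit, nothing is lost, but reading the future would surface unexpected failures. A sketch (the names come from the diff; the check itself is not part of the commit):

# Sketch: surface unexpected init_asr failures instead of discarding the future.
asr_ok = asr_future.result()  # blocks until init_asr returns; re-raises uncaught errors
if not asr_ok:
    logger.warning("⚠️ ASR initialization reported failure")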
@@ -142,9 +197,50 @@ def check_model_status():
         "translation_models": translation_status
     }
 
+# Cache for ASR results
+@lru_cache(maxsize=MAX_CACHE_SIZE)
+def get_cached_transcription(file_hash, language_code):
+    """Retrieve cached transcription result if available"""
+    return asr_cache.get((file_hash, language_code))
+
+def process_audio_file(audio_data, temp_audio_path, output_dir, sample_rate):
+    """Process audio file for ASR (separate from ASR logic)"""
+    wav_path = temp_audio_path
+
+    if not temp_audio_path.lower().endswith(".wav"):
+        wav_path = os.path.join(output_dir, "converted_audio.wav")
+        logger.info(f"🔄 Converting audio to WAV format: {wav_path}")
+        try:
+            audio = AudioSegment.from_file(temp_audio_path)
+            audio = audio.set_frame_rate(sample_rate).set_channels(1)
+            audio.export(wav_path, format="wav")
+        except Exception as e:
+            logger.error(f"❌ Audio conversion failed: {str(e)}")
+            raise Exception(f"Audio conversion failed: {str(e)}")
+
+    # Load and process the WAV file
+    try:
+        waveform, sr = torchaudio.load(wav_path)
+
+        # Resample if needed
+        if sr != sample_rate:
+            waveform = torchaudio.transforms.Resample(sr, sample_rate)(waveform)
+
+        # Normalize waveform
+        waveform = waveform / torch.max(torch.abs(waveform))
+
+        return waveform.squeeze().numpy(), wav_path
+    except Exception as e:
+        logger.error(f"❌ Failed to load or process audio: {str(e)}")
+        raise Exception(f"Audio processing failed: {str(e)}")
+
+def compute_audio_hash(audio_data):
+    """Compute a hash of audio data for caching purposes"""
+    import hashlib
+    return hashlib.md5(audio_data).hexdigest()
 
 def handle_asr_request(request, output_dir, sample_rate):
-    """Handle ASR (Automatic Speech Recognition) requests"""
+    """Handle ASR (Automatic Speech Recognition) requests with optimization"""
     if asr_model is None or asr_processor is None:
         logger.error("❌ ASR endpoint called but models aren't loaded")
         return jsonify({"error": "ASR model not available"}), 503
@@ -165,44 +261,40 @@ def handle_asr_request(request, output_dir, sample_rate):
     lang_code = LANGUAGE_CODES[language]
     logger.info(f"🔄 Processing {language} audio for ASR")
 
+    # Read the file content for hashing
+    audio_content = audio_file.read()
+    audio_hash = compute_audio_hash(audio_content)
+
+    # Check cache first
+    with asr_lock:
+        cached_result = asr_cache.get((audio_hash, lang_code))
+        if cached_result:
+            logger.info(f"✅ Using cached ASR result for {language}")
+            return jsonify({
+                "transcription": cached_result,
+                "language": language,
+                "language_code": lang_code,
+                "from_cache": True
+            })
+
     # Save the uploaded file temporarily
     with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(audio_file.filename)[-1]) as temp_audio:
-        temp_audio.write(audio_file.read())
+        temp_audio.write(audio_content)
         temp_audio_path = temp_audio.name
     logger.debug(f"📁 Temporary audio saved to {temp_audio_path}")
 
-    # Convert to WAV if necessary
-    wav_path = temp_audio_path
-    if not audio_file.filename.lower().endswith(".wav"):
-        wav_path = os.path.join(output_dir, "converted_audio.wav")
-        logger.info(f"🔄 Converting audio to WAV format: {wav_path}")
-        try:
-            audio = AudioSegment.from_file(temp_audio_path)
-            audio = audio.set_frame_rate(sample_rate).set_channels(1)
-            audio.export(wav_path, format="wav")
-        except Exception as e:
-            logger.error(f"❌ Audio conversion failed: {str(e)}")
-            return jsonify({"error": f"Audio conversion failed: {str(e)}"}), 500
-
-    # Load and process the WAV file
+    # Process audio in a separate thread/process
     try:
-        waveform, sr = torchaudio.load(wav_path)
-        logger.debug(f"✅ Audio loaded: {wav_path} (Sample rate: {sr}Hz)")
-
-        # Resample if needed
-        if sr != sample_rate:
-            logger.info(f"🔄 Resampling audio from {sr}Hz to {sample_rate}Hz")
-            waveform = torchaudio.transforms.Resample(sr, sample_rate)(waveform)
-
-        waveform = waveform / torch.max(torch.abs(waveform))
+        with ThreadPoolExecutor(max_workers=2) as executor:
+            future = executor.submit(process_audio_file, audio_content, temp_audio_path, output_dir, sample_rate)
+            waveform, wav_path = future.result()
     except Exception as e:
-        logger.error(f"❌ Failed to load or process audio: {str(e)}")
-        return jsonify({"error": f"Audio processing failed: {str(e)}"}), 500
+        return jsonify({"error": str(e)}), 500
 
     # Process audio for ASR
     try:
         inputs = asr_processor(
-            waveform.squeeze().numpy(),
+            waveform,
             sampling_rate=sample_rate,
             return_tensors="pt",
             language=lang_code
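One caveat on the audio-processing block above: building a fresh ThreadPoolExecutor per request and immediately calling future.result() keeps the request fully synchronous, so the pool adds thread-creation overhead without concurrency. If offloading is the goal, the usual shape is a shared module-level pool; a sketch (audio_executor and process_audio_async are hypothetical, not part of the commit; process_audio_file is from the diff):

from concurrent.futures import ThreadPoolExecutor

audio_executor = ThreadPoolExecutor(max_workers=2)  # created once, reused by all requests

def process_audio_async(audio_content, temp_audio_path, output_dir, sample_rate):
    """Submit decoding to the shared pool; the caller decides when to block."""
    return audio_executor.submit(
        process_audio_file, audio_content, temp_audio_path, output_dir, sample_rate
    )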
@@ -220,6 +312,14 @@ def handle_asr_request(request, output_dir, sample_rate):
         transcription = asr_processor.decode(ids)
 
         logger.info(f"✅ Transcription ({language}): {transcription}")
+
+        # Cache the result
+        with asr_lock:
+            asr_cache[(audio_hash, lang_code)] = transcription
+            # Implement cache size limitation if needed
+            if len(asr_cache) > MAX_CACHE_SIZE:
+                # Remove oldest entry (simplified approach)
+                asr_cache.pop(next(iter(asr_cache)))
 
         # Clean up temp files
         try:
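The eviction above leans on the fact that dicts preserve insertion order (guaranteed since Python 3.7), so next(iter(asr_cache)) removes the oldest insertion. That makes it FIFO rather than LRU, because cache hits never refresh an entry's position. If true LRU behavior is wanted, a sketch with collections.OrderedDict (not part of the commit; the same idea applies to tts_cache and translation_cache):

from collections import OrderedDict

lru = OrderedDict()
MAX_SIZE = 100

def lru_get(key):
    """Return a cached value and mark it most recently used."""
    value = lru.get(key)
    if value is not None:
        lru.move_to_end(key)  # refresh recency on every hit
    return value

def lru_put(key, value):
    lru[key] = value
    lru.move_to_end(key)
    if len(lru) > MAX_SIZE:
        lru.popitem(last=False)  # drop the least recently used entry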
@@ -232,7 +332,8 @@
         return jsonify({
             "transcription": transcription,
             "language": language,
-            "language_code": lang_code
+            "language_code": lang_code,
+            "from_cache": False
         })
     except Exception as e:
         logger.error(f"❌ ASR inference failed: {str(e)}")
@@ -244,8 +345,14 @@
         logger.debug(f"Stack trace: {traceback.format_exc()}")
         return jsonify({"error": f"Internal server error: {str(e)}"}), 500
 
+# Cache key generator for TTS
+def tts_cache_key(text, language):
+    """Generate a cache key for TTS results"""
+    import hashlib
+    return hashlib.md5(f"{text}:{language}".encode()).hexdigest()
+
 def handle_tts_request(request, output_dir):
-    """Handle TTS (Text-to-Speech) requests"""
+    """Handle TTS (Text-to-Speech) requests with optimization"""
     try:
         data = request.get_json()
         if not data:
@@ -268,7 +375,57 @@
         return jsonify({"error": f"TTS model for {language} not available"}), 503
 
    logger.info(f"🔄 Generating TTS for language: {language}, text: '{text_input}'")
+
+    # Generate cache key
+    cache_key = tts_cache_key(text_input, language)
+
+    # Check cache
+    with tts_lock:
+        cached_file = tts_cache.get(cache_key)
+        if cached_file and os.path.exists(cached_file):
+            logger.info(f"✅ Using cached TTS audio for: '{text_input}'")
+            return jsonify({
+                "message": "TTS audio retrieved from cache",
+                "file_url": f"/download/{os.path.basename(cached_file)}",
+                "language": language,
+                "text_length": len(text_input),
+                "from_cache": True
+            })
 
+    # Chunk text if too long (optional optimization for very long texts)
+    MAX_TEXT_LENGTH = 200  # Maximum text length to process in one go
+
+    if len(text_input) > MAX_TEXT_LENGTH:
+        # Simple chunking by splitting on periods
+        chunks = []
+        current_chunk = ""
+
+        for sentence in text_input.split("."):
+            if len(current_chunk) + len(sentence) < MAX_TEXT_LENGTH:
+                current_chunk += sentence + "."
+            else:
+                if current_chunk:
+                    chunks.append(current_chunk)
+                current_chunk = sentence + "."
+
+        if current_chunk:
+            chunks.append(current_chunk)
+
+        logger.info(f"🔄 Text chunked into {len(chunks)} parts for processing")
+
+        # Process chunks and combine results
+        try:
+            processor = tts_processors[language]
+            model = tts_models[language]
+
+            # For simplicity, we'll just use the first chunk in this example
+            # A full implementation would process all chunks and concatenate audio
+            text_input = chunks[0]
+            logger.info(f"⚠️ Using only the first chunk for demonstration: '{text_input}'")
+        except Exception as e:
+            logger.error(f"❌ TTS chunking failed: {str(e)}")
+            return jsonify({"error": f"TTS chunking failed: {str(e)}"}), 500
+
     try:
         processor = tts_processors[language]
         model = tts_models[language]
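The chunking branch above deliberately synthesizes only chunks[0], and its comment promises a full version that concatenates audio. A sketch of that path, assuming the VITS interface used elsewhere in this file (the processor produces token tensors and model(...).waveform carries the audio; synthesize_long_text is a hypothetical helper, not part of the commit):

import numpy as np
import torch

def synthesize_long_text(chunks, processor, model):
    """Hypothetical helper: synthesize each chunk and splice the waveforms."""
    pieces = []
    for chunk in chunks:
        inputs = processor(chunk, return_tensors="pt").to(model.device)
        with torch.no_grad():
            output = model(**inputs)
        pieces.append(output.waveform.squeeze().cpu().numpy())
    # Every piece shares the model's sampling rate, so plain concatenation is valid.
    return np.concatenate(pieces)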
@@ -290,10 +447,22 @@
 
         # Save to file
         try:
-            output_filename = os.path.join(output_dir, f"{language}_output.wav")
+            output_filename = os.path.join(output_dir, f"{language}_{cache_key}.wav")
             sampling_rate = model.config.sampling_rate
             sf.write(output_filename, waveform, sampling_rate)
             logger.info(f"✅ Speech generated! File saved: {output_filename}")
+
+            # Cache the result
+            with tts_lock:
+                tts_cache[cache_key] = output_filename
+                # Implement cache size limitation if needed
+                if len(tts_cache) > MAX_CACHE_SIZE:
+                    oldest_key = next(iter(tts_cache))
+                    try:
+                        os.remove(tts_cache[oldest_key])
+                    except:
+                        pass
+                    tts_cache.pop(oldest_key)
         except Exception as e:
             logger.error(f"❌ Failed to save audio file: {str(e)}")
             return jsonify({"error": f"Failed to save audio file: {str(e)}"}), 500
@@ -302,15 +471,22 @@
             "message": "TTS audio generated",
             "file_url": f"/download/{os.path.basename(output_filename)}",
             "language": language,
-            "text_length": len(text_input)
+            "text_length": len(text_input),
+            "from_cache": False
         })
     except Exception as e:
         logger.error(f"❌ Unhandled exception in TTS endpoint: {str(e)}")
         logger.debug(f"Stack trace: {traceback.format_exc()}")
         return jsonify({"error": f"Internal server error: {str(e)}"}), 500
 
+# Cache key generator for translation
+def translation_cache_key(text, source_lang, target_lang):
+    """Generate a cache key for translation results"""
+    import hashlib
+    return hashlib.md5(f"{text}:{source_lang}:{target_lang}".encode()).hexdigest()
+
 def handle_translation_request(request):
-    """Handle translation requests"""
+    """Handle translation requests with optimization"""
     try:
         data = request.get_json()
         if not data:
@@ -330,110 +506,97 @@ def handle_translation_request(request):
         target_code = LANGUAGE_CODES.get(target_language, target_language)
 
         logger.info(f"🔄 Translating from {source_language} to {target_language}: '{source_text}'")
+
+        # Generate cache key
+        cache_key = translation_cache_key(source_text, source_code, target_code)
+
+        # Check cache
+        with translation_lock:
+            cached_result = translation_cache.get(cache_key)
+            if cached_result:
+                logger.info(f"✅ Using cached translation result")
+                return jsonify({
+                    "translated_text": cached_result,
+                    "source_language": source_language,
+                    "target_language": target_language,
+                    "from_cache": True
+                })
 
-        # Special handling for pam-fil, fil-pam, pam-tgl and tgl-pam using the phi model
-        use_phi_model = False
+        # OPTIMIZED: Simplified language pair determination logic
+        model_key = None
         actual_source_code = source_code
         actual_target_code = target_code
-
-        # Check if we need to use the phi model with fil replacement
-        if (source_code == "pam" and target_code == "fil") or (source_code == "fil" and target_code == "pam"):
-            use_phi_model = True
-        elif (source_code == "pam" and target_code == "tgl"):
-            use_phi_model = True
-            actual_target_code = "fil"  # Replace tgl with fil for the phi model
-        elif (source_code == "tgl" and target_code == "pam"):
-            use_phi_model = True
-            actual_source_code = "fil"  # Replace tgl with fil for the phi model
-
-        if use_phi_model:
+        input_text = source_text
+
+        # Determine which model to use with simplified logic
+        if f"{source_code}-{target_code}" in translation_models:
+            # Direct model exists
+            model_key = f"{source_code}-{target_code}"
+            use_phi_model = False
+        elif (source_code in ["pam", "fil", "tgl"] and target_code in ["pam", "fil", "tgl"]):
+            # Use phi model with appropriate substitutions
             model_key = "phi"
-
-            # Check if we have the phi model
-            if model_key not in translation_models or translation_models[model_key] is None:
-                logger.error(f"❌ Translation model for {model_key} not loaded")
-                return jsonify({"error": f"Translation model not available"}), 503
-
-            try:
-                # Get the phi model and tokenizer
-                model = translation_models[model_key]
-                tokenizer = translation_tokenizers[model_key]
-
-                # Prepend target language token to input
-                input_text = f">>{actual_target_code}<< {source_text}"
-
-                logger.info(f"🔄 Using phi model with input: '{input_text}'")
-
-                # Tokenize the text
-                tokenized = tokenizer(input_text, return_tensors="pt", padding=True)
-                tokenized = {k: v.to(model.device) for k, v in tokenized.items()}
-
-                with torch.no_grad():
-                    translated = model.generate(
-                        **tokenized,
-                        max_length=100,  # Reasonable output length
-                        num_beams=4,  # Same as in training
-                        length_penalty=0.6,  # Same as in training
-                        early_stopping=True,  # Same as in training
-                        repetition_penalty=1.5,  # Add this to prevent repetition
-                        no_repeat_ngram_size=3  # Add this to prevent repetition
-                    )
-
-                # Decode the translation
-                result = tokenizer.decode(translated[0], skip_special_tokens=True)
-
-                logger.info(f"✅ Translation result: '{result}'")
-
-                return jsonify({
-                    "translated_text": result,
-                    "source_language": source_language,
-                    "target_language": target_language
-                })
-            except Exception as e:
-                logger.error(f"❌ Translation processing failed: {str(e)}")
-                logger.debug(f"Stack trace: {traceback.format_exc()}")
-                return jsonify({"error": f"Translation processing failed: {str(e)}"}), 500
+            use_phi_model = True
+            # Replace tgl with fil for the phi model if needed
+            if source_code == "tgl": actual_source_code = "fil"
+            if target_code == "tgl": actual_target_code = "fil"
+            # Prepare input text for phi model
+            input_text = f">>{actual_target_code}<< {source_text}"
         else:
-            # Create the regular language pair key for other language pairs
-            lang_pair = f"{source_code}-{target_code}"
-
-            # Check if we have a model for this language pair
-            if lang_pair not in translation_models:
-                logger.warning(f"⚠️ No translation model available for {lang_pair}")
-                return jsonify(
-                    {"error": f"Translation from {source_language} to {target_language} is not supported yet"}), 400
-
-            if translation_models[lang_pair] is None or translation_tokenizers[lang_pair] is None:
-                logger.error(f"❌ Translation model for {lang_pair} not loaded")
-                return jsonify({"error": f"Translation model not available"}), 503
-
-            try:
-                # Regular translation process for other language pairs
-                model = translation_models[lang_pair]
-                tokenizer = translation_tokenizers[lang_pair]
-
-                # Tokenize the text
-                tokenized = tokenizer(source_text, return_tensors="pt", padding=True)
-                tokenized = {k: v.to(model.device) for k, v in tokenized.items()}
-
-                # Generate translation
-                with torch.no_grad():
-                    translated = model.generate(**tokenized)
-
-                # Decode the translation
-                result = tokenizer.decode(translated[0], skip_special_tokens=True)
-
-                logger.info(f"✅ Translation result: '{result}'")
-
-                return jsonify({
-                    "translated_text": result,
-                    "source_language": source_language,
-                    "target_language": target_language
-                })
-            except Exception as e:
-                logger.error(f"❌ Translation processing failed: {str(e)}")
-                logger.debug(f"Stack trace: {traceback.format_exc()}")
-                return jsonify({"error": f"Translation processing failed: {str(e)}"}), 500
+            logger.warning(f"⚠️ No translation model available for {source_code}-{target_code}")
+            return jsonify(
+                {"error": f"Translation from {source_language} to {target_language} is not supported yet"}), 400
+
+        # Check if model exists and is loaded
+        if model_key not in translation_models or translation_models[model_key] is None:
+            logger.error(f"❌ Translation model for {model_key} not loaded")
+            return jsonify({"error": f"Translation model not available"}), 503
+
+        try:
+            # Get the model and tokenizer
+            model = translation_models[model_key]
+            tokenizer = translation_tokenizers[model_key]
+
+            # Tokenize the text
+            tokenized = tokenizer(input_text, return_tensors="pt", padding=True)
+            tokenized = {k: v.to(model.device) for k, v in tokenized.items()}
+
+            # Apply length-based optimizations
+            max_length = min(100, len(source_text.split()) * 2)  # Adaptive length
+
+            with torch.no_grad():
+                translated = model.generate(
+                    **tokenized,
+                    max_length=max_length,
+                    num_beams=4,
+                    length_penalty=0.6,
+                    early_stopping=True,
+                    repetition_penalty=1.5,
+                    no_repeat_ngram_size=3
+                )
+
+            # Decode the translation
+            result = tokenizer.decode(translated[0], skip_special_tokens=True)
+
+            logger.info(f"✅ Translation result: '{result}'")
+
+            # Cache the result
+            with translation_lock:
+                translation_cache[cache_key] = result
+                # Implement cache size limitation if needed
+                if len(translation_cache) > MAX_CACHE_SIZE:
+                    translation_cache.pop(next(iter(translation_cache)))
+
+            return jsonify({
+                "translated_text": result,
+                "source_language": source_language,
+                "target_language": target_language,
+                "from_cache": False
+            })
+        except Exception as e:
+            logger.error(f"❌ Translation processing failed: {str(e)}")
+            logger.debug(f"Stack trace: {traceback.format_exc()}")
+            return jsonify({"error": f"Translation processing failed: {str(e)}"}), 500
 
     except Exception as e:
         logger.error(f"❌ Unhandled exception in translation endpoint: {str(e)}")