Coco-18 committed on
Commit aa906c3 · verified · 1 Parent(s): fec0be4

Update translator.py

Files changed (1)
  1. translator.py +173 -336
translator.py CHANGED
@@ -1,4 +1,4 @@
-# translator.py - Handles ASR, TTS, and translation tasks (OPTIMIZED)
+# translator.py - Handles ASR, TTS, and translation tasks
 
 import os
 import sys
@@ -12,11 +12,6 @@ from pydub import AudioSegment
 from flask import jsonify
 from transformers import Wav2Vec2ForCTC, AutoProcessor, VitsModel, AutoTokenizer
 from transformers import MarianMTModel, MarianTokenizer
-import concurrent.futures
-import functools
-import threading
-from concurrent.futures import ThreadPoolExecutor
-from functools import lru_cache
 
 # Configure logging
 logger = logging.getLogger("speech_api")
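
The module logs all of its ✅/❌ status messages through this named logger and leaves handler configuration to the hosting application. A minimal host-side setup, assumed here for illustration and not part of this commit, would be:

    import logging

    # Hypothetical host-side configuration: route the "speech_api" logger
    # to stderr at INFO level so the model-loading messages are visible.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(name)s %(levelname)s: %(message)s",
    )
    logging.getLogger("speech_api").info("logging configured")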
@@ -29,16 +24,6 @@ tts_processors = {}
 translation_models = {}
 translation_tokenizers = {}
 
-# Caching dictionaries
-asr_cache = {}
-tts_cache = {}
-translation_cache = {}
-
-# Mutex locks for thread safety
-asr_lock = threading.Lock()
-tts_lock = threading.Lock()
-translation_lock = threading.Lock()
-
 # Language-specific configurations
 LANGUAGE_CODES = {
     "kapampangan": "pam",
@@ -63,114 +48,74 @@ TRANSLATION_MODELS = {
     "phi": "Coco-18/opus-mt-phi"
 }
 
-# Cache settings
-MAX_CACHE_SIZE = 100  # Maximum number of items to cache
-CACHE_TTL = 3600  # Time to live in seconds (1 hour)
-
 def init_models(device):
-    """Initialize all models required for the API with parallelization"""
+    """Initialize all models required for the API"""
     global asr_model, asr_processor, tts_models, tts_processors, translation_models, translation_tokenizers
-
-    logger.info("🔄 Starting parallel model initialization")
-
-    # Define model initialization functions
-    def init_asr():
-        global asr_model, asr_processor
-        ASR_MODEL_ID = "Coco-18/mms-asr-tgl-en-safetensor"
-        try:
-            asr_processor = AutoProcessor.from_pretrained(
-                ASR_MODEL_ID,
-                cache_dir=os.environ.get("TRANSFORMERS_CACHE")
-            )
-
-            asr_model = Wav2Vec2ForCTC.from_pretrained(
-                ASR_MODEL_ID,
-                cache_dir=os.environ.get("TRANSFORMERS_CACHE")
-            )
-            asr_model.to(device)
-            logger.info(f"✅ ASR model loaded successfully on {device}")
-            return True
-        except Exception as e:
-            logger.error(f"❌ Error loading ASR model: {str(e)}")
-            logger.debug(f"Stack trace: {traceback.format_exc()}")
-            return False
-
-    def init_tts(lang, model_id):
+
+    # Initialize ASR model
+    ASR_MODEL_ID = "Coco-18/mms-asr-tgl-en-safetensor"
+    logger.info(f"🔄 Loading ASR model: {ASR_MODEL_ID}")
+
+    try:
+        asr_processor = AutoProcessor.from_pretrained(
+            ASR_MODEL_ID,
+            cache_dir=os.environ.get("TRANSFORMERS_CACHE")
+        )
+        logger.info("✅ ASR processor loaded successfully")
+
+        asr_model = Wav2Vec2ForCTC.from_pretrained(
+            ASR_MODEL_ID,
+            cache_dir=os.environ.get("TRANSFORMERS_CACHE")
+        )
+        asr_model.to(device)
+        logger.info(f"✅ ASR model loaded successfully on {device}")
+    except Exception as e:
+        logger.error(f"❌ Error loading ASR model: {str(e)}")
+        logger.debug(f"Stack trace: {traceback.format_exc()}")
+
+    # Initialize TTS models
+    for lang, model_id in TTS_MODELS.items():
+        logger.info(f"🔄 Loading TTS model for {lang}: {model_id}")
         try:
-            processor = AutoTokenizer.from_pretrained(
+            tts_processors[lang] = AutoTokenizer.from_pretrained(
                 model_id,
                 cache_dir=os.environ.get("TRANSFORMERS_CACHE")
             )
-
-            model = VitsModel.from_pretrained(
+            logger.info(f"✅ {lang} TTS processor loaded")
+
+            tts_models[lang] = VitsModel.from_pretrained(
                 model_id,
                 cache_dir=os.environ.get("TRANSFORMERS_CACHE")
             )
-            model.to(device)
+            tts_models[lang].to(device)
             logger.info(f"✅ {lang} TTS model loaded on {device}")
-            return lang, processor, model
         except Exception as e:
             logger.error(f"❌ Failed to load {lang} TTS model: {str(e)}")
             logger.debug(f"Stack trace: {traceback.format_exc()}")
-            return lang, None, None
-
-    def init_translation(model_key, model_id):
+            tts_models[lang] = None
+
+    # Initialize translation models
+    for model_key, model_id in TRANSLATION_MODELS.items():
+        logger.info(f"🔄 Loading Translation model: {model_id}")
+
         try:
-            tokenizer = MarianTokenizer.from_pretrained(
+            translation_tokenizers[model_key] = MarianTokenizer.from_pretrained(
                 model_id,
                 cache_dir=os.environ.get("TRANSFORMERS_CACHE")
            )
-
-            model = MarianMTModel.from_pretrained(
+            logger.info(f"✅ Translation tokenizer loaded successfully for {model_key}")
+
+            translation_models[model_key] = MarianMTModel.from_pretrained(
                 model_id,
                 cache_dir=os.environ.get("TRANSFORMERS_CACHE")
             )
-            model.to(device)
+            translation_models[model_key].to(device)
             logger.info(f"✅ Translation model loaded successfully on {device} for {model_key}")
-            return model_key, tokenizer, model
         except Exception as e:
             logger.error(f"❌ Error loading Translation model for {model_key}: {str(e)}")
             logger.debug(f"Stack trace: {traceback.format_exc()}")
-            return model_key, None, None
-
-    # Use ThreadPoolExecutor to initialize models in parallel
-    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
-        # Start ASR model initialization
-        asr_future = executor.submit(init_asr)
-
-        # Start TTS model initialization in parallel
-        tts_futures = {
-            executor.submit(init_tts, lang, model_id): lang
-            for lang, model_id in TTS_MODELS.items()
-        }
-
-        # Start translation model initialization in parallel
-        translation_futures = {
-            executor.submit(init_translation, model_key, model_id): model_key
-            for model_key, model_id in TRANSLATION_MODELS.items()
-        }
-
-        # Wait for all futures to complete and process results
-
-        # Process TTS results
-        for future in concurrent.futures.as_completed(tts_futures):
-            lang, processor, model = future.result()
-            if processor is not None and model is not None:
-                tts_processors[lang] = processor
-                tts_models[lang] = model
-
-        # Process translation results
-        for future in concurrent.futures.as_completed(translation_futures):
-            model_key, tokenizer, model = future.result()
-            if tokenizer is not None and model is not None:
-                translation_tokenizers[model_key] = tokenizer
-                translation_models[model_key] = model
-
-    # Log summary of loaded models
-    logger.info("📊 Model initialization summary:")
-    logger.info(f"  - ASR model: {'loaded' if asr_model is not None else 'failed'}")
-    logger.info(f"  - TTS models loaded: {sum(1 for m in tts_models.values() if m is not None)}/{len(TTS_MODELS)}")
-    logger.info(f"  - Translation models loaded: {sum(1 for m in translation_models.values() if m is not None)}/{len(TRANSLATION_MODELS)}")
+            translation_models[model_key] = None
+            translation_tokenizers[model_key] = None
 
 
 def check_model_status():
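
With the thread pool removed, init_models now loads everything sequentially on one device. A hypothetical caller (the device-selection logic is an assumption for illustration; init_models and check_model_status are from this file):

    import torch
    from translator import init_models, check_model_status

    # Pick a GPU when available, otherwise fall back to CPU, then load all
    # ASR/TTS/translation models up front and report what actually loaded.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    init_models(device)
    print(check_model_status())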
@@ -197,50 +142,9 @@ def check_model_status():
         "translation_models": translation_status
     }
 
-# Cache for ASR results
-@lru_cache(maxsize=MAX_CACHE_SIZE)
-def get_cached_transcription(file_hash, language_code):
-    """Retrieve cached transcription result if available"""
-    return asr_cache.get((file_hash, language_code))
-
-def process_audio_file(audio_data, temp_audio_path, output_dir, sample_rate):
-    """Process audio file for ASR (separate from ASR logic)"""
-    wav_path = temp_audio_path
-
-    if not temp_audio_path.lower().endswith(".wav"):
-        wav_path = os.path.join(output_dir, "converted_audio.wav")
-        logger.info(f"🔄 Converting audio to WAV format: {wav_path}")
-        try:
-            audio = AudioSegment.from_file(temp_audio_path)
-            audio = audio.set_frame_rate(sample_rate).set_channels(1)
-            audio.export(wav_path, format="wav")
-        except Exception as e:
-            logger.error(f"❌ Audio conversion failed: {str(e)}")
-            raise Exception(f"Audio conversion failed: {str(e)}")
-
-    # Load and process the WAV file
-    try:
-        waveform, sr = torchaudio.load(wav_path)
-
-        # Resample if needed
-        if sr != sample_rate:
-            waveform = torchaudio.transforms.Resample(sr, sample_rate)(waveform)
-
-        # Normalize waveform
-        waveform = waveform / torch.max(torch.abs(waveform))
-
-        return waveform.squeeze().numpy(), wav_path
-    except Exception as e:
-        logger.error(f"❌ Failed to load or process audio: {str(e)}")
-        raise Exception(f"Audio processing failed: {str(e)}")
-
-def compute_audio_hash(audio_data):
-    """Compute a hash of audio data for caching purposes"""
-    import hashlib
-    return hashlib.md5(audio_data).hexdigest()
 
 def handle_asr_request(request, output_dir, sample_rate):
-    """Handle ASR (Automatic Speech Recognition) requests with optimization"""
+    """Handle ASR (Automatic Speech Recognition) requests"""
     if asr_model is None or asr_processor is None:
         logger.error("❌ ASR endpoint called but models aren't loaded")
         return jsonify({"error": "ASR model not available"}), 503
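
The conversion and normalization logic from the deleted process_audio_file helper moves inline into handle_asr_request in the next hunk. One detail about the peak-normalization line it carries over, sketched standalone here (the zero-peak guard is an addition for illustration, not in the file, which would divide by zero on all-silent input):

    import torch

    def peak_normalize(waveform: torch.Tensor) -> torch.Tensor:
        # Scale so the loudest sample sits at +/-1.0, as the handler's
        # inline `waveform / torch.max(torch.abs(waveform))` does, but
        # guard against all-silence input.
        peak = torch.max(torch.abs(waveform))
        return waveform if peak == 0 else waveform / peak

    print(peak_normalize(torch.tensor([0.1, -0.5, 0.25])))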
@@ -261,40 +165,44 @@ def handle_asr_request(request, output_dir, sample_rate):
         lang_code = LANGUAGE_CODES[language]
         logger.info(f"🔄 Processing {language} audio for ASR")
 
-        # Read the file content for hashing
-        audio_content = audio_file.read()
-        audio_hash = compute_audio_hash(audio_content)
-
-        # Check cache first
-        with asr_lock:
-            cached_result = asr_cache.get((audio_hash, lang_code))
-            if cached_result:
-                logger.info(f"✅ Using cached ASR result for {language}")
-                return jsonify({
-                    "transcription": cached_result,
-                    "language": language,
-                    "language_code": lang_code,
-                    "from_cache": True
-                })
-
         # Save the uploaded file temporarily
         with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(audio_file.filename)[-1]) as temp_audio:
-            temp_audio.write(audio_content)
+            temp_audio.write(audio_file.read())
             temp_audio_path = temp_audio.name
         logger.debug(f"📝 Temporary audio saved to {temp_audio_path}")
 
-        # Process audio in a separate thread/process
+        # Convert to WAV if necessary
+        wav_path = temp_audio_path
+        if not audio_file.filename.lower().endswith(".wav"):
+            wav_path = os.path.join(output_dir, "converted_audio.wav")
+            logger.info(f"🔄 Converting audio to WAV format: {wav_path}")
+            try:
+                audio = AudioSegment.from_file(temp_audio_path)
+                audio = audio.set_frame_rate(sample_rate).set_channels(1)
+                audio.export(wav_path, format="wav")
+            except Exception as e:
+                logger.error(f"❌ Audio conversion failed: {str(e)}")
+                return jsonify({"error": f"Audio conversion failed: {str(e)}"}), 500
+
+        # Load and process the WAV file
         try:
-            with ThreadPoolExecutor(max_workers=2) as executor:
-                future = executor.submit(process_audio_file, audio_content, temp_audio_path, output_dir, sample_rate)
-                waveform, wav_path = future.result()
+            waveform, sr = torchaudio.load(wav_path)
+            logger.debug(f"✅ Audio loaded: {wav_path} (Sample rate: {sr}Hz)")
+
+            # Resample if needed
+            if sr != sample_rate:
+                logger.info(f"🔄 Resampling audio from {sr}Hz to {sample_rate}Hz")
+                waveform = torchaudio.transforms.Resample(sr, sample_rate)(waveform)
+
+            waveform = waveform / torch.max(torch.abs(waveform))
         except Exception as e:
-            return jsonify({"error": str(e)}), 500
+            logger.error(f"❌ Failed to load or process audio: {str(e)}")
+            return jsonify({"error": f"Audio processing failed: {str(e)}"}), 500
 
         # Process audio for ASR
         try:
             inputs = asr_processor(
-                waveform,
+                waveform.squeeze().numpy(),
                 sampling_rate=sample_rate,
                 return_tensors="pt",
                 language=lang_code
@@ -312,14 +220,6 @@ def handle_asr_request(request, output_dir, sample_rate):
             transcription = asr_processor.decode(ids)
 
             logger.info(f"✅ Transcription ({language}): {transcription}")
-
-            # Cache the result
-            with asr_lock:
-                asr_cache[(audio_hash, lang_code)] = transcription
-                # Implement cache size limitation if needed
-                if len(asr_cache) > MAX_CACHE_SIZE:
-                    # Remove oldest entry (simplified approach)
-                    asr_cache.pop(next(iter(asr_cache)))
 
             # Clean up temp files
             try:
@@ -332,8 +232,7 @@ def handle_asr_request(request, output_dir, sample_rate):
             return jsonify({
                 "transcription": transcription,
                 "language": language,
-                "language_code": lang_code,
-                "from_cache": False
+                "language_code": lang_code
             })
         except Exception as e:
             logger.error(f"❌ ASR inference failed: {str(e)}")
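
With the cache gone, the success response carries only the three fields above. A hypothetical client call follows; the endpoint path and form field names are assumptions, since this diff only shows the handler, not the route registration:

    import requests

    # Assumed route and field names for illustration; the handler returns
    # {"transcription": ..., "language": ..., "language_code": ...}.
    resp = requests.post(
        "http://localhost:5000/asr",
        files={"audio": open("sample.wav", "rb")},
        data={"language": "kapampangan"},
    )
    print(resp.json())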
@@ -345,14 +244,8 @@ def handle_asr_request(request, output_dir, sample_rate):
         logger.debug(f"Stack trace: {traceback.format_exc()}")
         return jsonify({"error": f"Internal server error: {str(e)}"}), 500
 
-# Cache key generator for TTS
-def tts_cache_key(text, language):
-    """Generate a cache key for TTS results"""
-    import hashlib
-    return hashlib.md5(f"{text}:{language}".encode()).hexdigest()
-
 def handle_tts_request(request, output_dir):
-    """Handle TTS (Text-to-Speech) requests with optimization"""
+    """Handle TTS (Text-to-Speech) requests"""
     try:
         data = request.get_json()
         if not data:
@@ -375,57 +268,7 @@ def handle_tts_request(request, output_dir):
         return jsonify({"error": f"TTS model for {language} not available"}), 503
 
         logger.info(f"🔄 Generating TTS for language: {language}, text: '{text_input}'")
-
-        # Generate cache key
-        cache_key = tts_cache_key(text_input, language)
-
-        # Check cache
-        with tts_lock:
-            cached_file = tts_cache.get(cache_key)
-            if cached_file and os.path.exists(cached_file):
-                logger.info(f"✅ Using cached TTS audio for: '{text_input}'")
-                return jsonify({
-                    "message": "TTS audio retrieved from cache",
-                    "file_url": f"/download/{os.path.basename(cached_file)}",
-                    "language": language,
-                    "text_length": len(text_input),
-                    "from_cache": True
-                })
 
-        # Chunk text if too long (optional optimization for very long texts)
-        MAX_TEXT_LENGTH = 200  # Maximum text length to process in one go
-
-        if len(text_input) > MAX_TEXT_LENGTH:
-            # Simple chunking by splitting on periods
-            chunks = []
-            current_chunk = ""
-
-            for sentence in text_input.split("."):
-                if len(current_chunk) + len(sentence) < MAX_TEXT_LENGTH:
-                    current_chunk += sentence + "."
-                else:
-                    if current_chunk:
-                        chunks.append(current_chunk)
-                    current_chunk = sentence + "."
-
-            if current_chunk:
-                chunks.append(current_chunk)
-
-            logger.info(f"🔄 Text chunked into {len(chunks)} parts for processing")
-
-            # Process chunks and combine results
-            try:
-                processor = tts_processors[language]
-                model = tts_models[language]
-
-                # For simplicity, we'll just use the first chunk in this example
-                # A full implementation would process all chunks and concatenate audio
-                text_input = chunks[0]
-                logger.info(f"⚠️ Using only the first chunk for demonstration: '{text_input}'")
-            except Exception as e:
-                logger.error(f"❌ TTS chunking failed: {str(e)}")
-                return jsonify({"error": f"TTS chunking failed: {str(e)}"}), 500
-
         try:
             processor = tts_processors[language]
             model = tts_models[language]
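
This hunk deletes the period-based chunking, which in any case only ever synthesized the first chunk. The removed splitting logic, extracted as a standalone function for reference (a direct reconstruction of the deleted loop):

    MAX_TEXT_LENGTH = 200  # same cap the deleted code used

    def chunk_text(text):
        # Greedily pack sentences (split on periods) into chunks below the
        # cap, mirroring the removed loop line for line.
        chunks, current_chunk = [], ""
        for sentence in text.split("."):
            if len(current_chunk) + len(sentence) < MAX_TEXT_LENGTH:
                current_chunk += sentence + "."
            else:
                if current_chunk:
                    chunks.append(current_chunk)
                current_chunk = sentence + "."
        if current_chunk:
            chunks.append(current_chunk)
        return chunks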
@@ -447,22 +290,10 @@ def handle_tts_request(request, output_dir):
 
         # Save to file
         try:
-            output_filename = os.path.join(output_dir, f"{language}_{cache_key}.wav")
+            output_filename = os.path.join(output_dir, f"{language}_output.wav")
             sampling_rate = model.config.sampling_rate
             sf.write(output_filename, waveform, sampling_rate)
             logger.info(f"✅ Speech generated! File saved: {output_filename}")
-
-            # Cache the result
-            with tts_lock:
-                tts_cache[cache_key] = output_filename
-                # Implement cache size limitation if needed
-                if len(tts_cache) > MAX_CACHE_SIZE:
-                    oldest_key = next(iter(tts_cache))
-                    try:
-                        os.remove(tts_cache[oldest_key])
-                    except:
-                        pass
-                    tts_cache.pop(oldest_key)
         except Exception as e:
             logger.error(f"❌ Failed to save audio file: {str(e)}")
             return jsonify({"error": f"Failed to save audio file: {str(e)}"}), 500
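
Note the new fixed {language}_output.wav name means each request overwrites the previous file for that language. The waveform generation itself sits in context lines elided between these hunks; for reference, the standard transformers VITS flow that produces the array passed to sf.write looks like this (a sketch using a public MMS checkpoint, not this file's exact code):

    import torch
    import soundfile as sf
    from transformers import AutoTokenizer, VitsModel

    tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
    model = VitsModel.from_pretrained("facebook/mms-tts-eng")

    inputs = tokenizer("Hello world", return_tensors="pt")
    with torch.no_grad():
        waveform = model(**inputs).waveform  # shape: (batch, samples)
    sf.write("english_output.wav", waveform.squeeze().numpy(), model.config.sampling_rate)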
@@ -471,22 +302,15 @@ def handle_tts_request(request, output_dir):
             "message": "TTS audio generated",
             "file_url": f"/download/{os.path.basename(output_filename)}",
             "language": language,
-            "text_length": len(text_input),
-            "from_cache": False
+            "text_length": len(text_input)
         })
     except Exception as e:
         logger.error(f"❌ Unhandled exception in TTS endpoint: {str(e)}")
         logger.debug(f"Stack trace: {traceback.format_exc()}")
         return jsonify({"error": f"Internal server error: {str(e)}"}), 500
 
-# Cache key generator for translation
-def translation_cache_key(text, source_lang, target_lang):
-    """Generate a cache key for translation results"""
-    import hashlib
-    return hashlib.md5(f"{text}:{source_lang}:{target_lang}".encode()).hexdigest()
-
 def handle_translation_request(request):
-    """Handle translation requests with optimization"""
+    """Handle translation requests"""
     try:
         data = request.get_json()
         if not data:
@@ -506,97 +330,110 @@ def handle_translation_request(request):
         target_code = LANGUAGE_CODES.get(target_language, target_language)
 
         logger.info(f"🔄 Translating from {source_language} to {target_language}: '{source_text}'")
-
-        # Generate cache key
-        cache_key = translation_cache_key(source_text, source_code, target_code)
-
-        # Check cache
-        with translation_lock:
-            cached_result = translation_cache.get(cache_key)
-            if cached_result:
-                logger.info(f"✅ Using cached translation result")
-                return jsonify({
-                    "translated_text": cached_result,
-                    "source_language": source_language,
-                    "target_language": target_language,
-                    "from_cache": True
-                })
 
-        # OPTIMIZED: Simplified language pair determination logic
-        model_key = None
+        # Special handling for pam-fil, fil-pam, pam-tgl and tgl-pam using the phi model
+        use_phi_model = False
         actual_source_code = source_code
         actual_target_code = target_code
-        input_text = source_text
-
-        # Determine which model to use with simplified logic
-        if f"{source_code}-{target_code}" in translation_models:
-            # Direct model exists
-            model_key = f"{source_code}-{target_code}"
-            use_phi_model = False
-        elif (source_code in ["pam", "fil", "tgl"] and target_code in ["pam", "fil", "tgl"]):
-            # Use phi model with appropriate substitutions
-            model_key = "phi"
+
+        # Check if we need to use the phi model with fil replacement
+        if (source_code == "pam" and target_code == "fil") or (source_code == "fil" and target_code == "pam"):
+            use_phi_model = True
+        elif (source_code == "pam" and target_code == "tgl"):
             use_phi_model = True
-            # Replace tgl with fil for the phi model if needed
-            if source_code == "tgl": actual_source_code = "fil"
-            if target_code == "tgl": actual_target_code = "fil"
-            # Prepare input text for phi model
-            input_text = f">>{actual_target_code}<< {source_text}"
+            actual_target_code = "fil"  # Replace tgl with fil for the phi model
+        elif (source_code == "tgl" and target_code == "pam"):
+            use_phi_model = True
+            actual_source_code = "fil"  # Replace tgl with fil for the phi model
+
+        if use_phi_model:
+            model_key = "phi"
+
+            # Check if we have the phi model
+            if model_key not in translation_models or translation_models[model_key] is None:
+                logger.error(f"❌ Translation model for {model_key} not loaded")
+                return jsonify({"error": f"Translation model not available"}), 503
+
+            try:
+                # Get the phi model and tokenizer
+                model = translation_models[model_key]
+                tokenizer = translation_tokenizers[model_key]
+
+                # Prepend target language token to input
+                input_text = f">>{actual_target_code}<< {source_text}"
+
+                logger.info(f"🔄 Using phi model with input: '{input_text}'")
+
+                # Tokenize the text
+                tokenized = tokenizer(input_text, return_tensors="pt", padding=True)
+                tokenized = {k: v.to(model.device) for k, v in tokenized.items()}
+
+                with torch.no_grad():
+                    translated = model.generate(
+                        **tokenized,
+                        max_length=100,  # Reasonable output length
+                        num_beams=4,  # Same as in training
+                        length_penalty=0.6,  # Same as in training
+                        early_stopping=True,  # Same as in training
+                        repetition_penalty=1.5,  # Add this to prevent repetition
+                        no_repeat_ngram_size=3  # Add this to prevent repetition
+                    )
+
+                # Decode the translation
+                result = tokenizer.decode(translated[0], skip_special_tokens=True)
+
+                logger.info(f"✅ Translation result: '{result}'")
+
+                return jsonify({
+                    "translated_text": result,
+                    "source_language": source_language,
+                    "target_language": target_language
+                })
+            except Exception as e:
+                logger.error(f"❌ Translation processing failed: {str(e)}")
+                logger.debug(f"Stack trace: {traceback.format_exc()}")
+                return jsonify({"error": f"Translation processing failed: {str(e)}"}), 500
         else:
-            logger.warning(f"⚠️ No translation model available for {source_code}-{target_code}")
-            return jsonify(
-                {"error": f"Translation from {source_language} to {target_language} is not supported yet"}), 400
-
-        # Check if model exists and is loaded
-        if model_key not in translation_models or translation_models[model_key] is None:
-            logger.error(f"❌ Translation model for {model_key} not loaded")
-            return jsonify({"error": f"Translation model not available"}), 503
+            # Create the regular language pair key for other language pairs
+            lang_pair = f"{source_code}-{target_code}"
 
-        try:
-            # Get the model and tokenizer
-            model = translation_models[model_key]
-            tokenizer = translation_tokenizers[model_key]
+            # Check if we have a model for this language pair
+            if lang_pair not in translation_models:
+                logger.warning(f"⚠️ No translation model available for {lang_pair}")
+                return jsonify(
+                    {"error": f"Translation from {source_language} to {target_language} is not supported yet"}), 400
 
-            # Tokenize the text
-            tokenized = tokenizer(input_text, return_tensors="pt", padding=True)
-            tokenized = {k: v.to(model.device) for k, v in tokenized.items()}
+            if translation_models[lang_pair] is None or translation_tokenizers[lang_pair] is None:
+                logger.error(f"❌ Translation model for {lang_pair} not loaded")
+                return jsonify({"error": f"Translation model not available"}), 503
 
-            # Apply length-based optimizations
-            max_length = min(100, len(source_text.split()) * 2)  # Adaptive length
-
-            with torch.no_grad():
-                translated = model.generate(
-                    **tokenized,
-                    max_length=max_length,
-                    num_beams=4,
-                    length_penalty=0.6,
-                    early_stopping=True,
-                    repetition_penalty=1.5,
-                    no_repeat_ngram_size=3
-                )
-
-            # Decode the translation
-            result = tokenizer.decode(translated[0], skip_special_tokens=True)
-
-            logger.info(f"✅ Translation result: '{result}'")
-
-            # Cache the result
-            with translation_lock:
-                translation_cache[cache_key] = result
-                # Implement cache size limitation if needed
-                if len(translation_cache) > MAX_CACHE_SIZE:
-                    translation_cache.pop(next(iter(translation_cache)))
+            try:
+                # Regular translation process for other language pairs
+                model = translation_models[lang_pair]
+                tokenizer = translation_tokenizers[lang_pair]
 
-            return jsonify({
-                "translated_text": result,
-                "source_language": source_language,
-                "target_language": target_language,
-                "from_cache": False
-            })
-        except Exception as e:
-            logger.error(f"❌ Translation processing failed: {str(e)}")
-            logger.debug(f"Stack trace: {traceback.format_exc()}")
-            return jsonify({"error": f"Translation processing failed: {str(e)}"}), 500
+                # Tokenize the text
+                tokenized = tokenizer(source_text, return_tensors="pt", padding=True)
+                tokenized = {k: v.to(model.device) for k, v in tokenized.items()}
+
+                # Generate translation
+                with torch.no_grad():
+                    translated = model.generate(**tokenized)
+
+                # Decode the translation
+                result = tokenizer.decode(translated[0], skip_special_tokens=True)
+
+                logger.info(f"✅ Translation result: '{result}'")
+
+                return jsonify({
+                    "translated_text": result,
+                    "source_language": source_language,
+                    "target_language": target_language
+                })
+            except Exception as e:
+                logger.error(f"❌ Translation processing failed: {str(e)}")
+                logger.debug(f"Stack trace: {traceback.format_exc()}")
+                return jsonify({"error": f"Translation processing failed: {str(e)}"}), 500
 
     except Exception as e:
         logger.error(f"❌ Unhandled exception in translation endpoint: {str(e)}")
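
The phi model covers several Philippine-language directions in one checkpoint, so the handler selects the target with a ">>lang<<" prefix before tokenizing, mapping "tgl" to "fil" first. A standalone sketch of that convention (model id and generation settings are from this diff; the sample sentence and the exact set of supported prefixes are assumptions):

    from transformers import MarianMTModel, MarianTokenizer

    model_id = "Coco-18/opus-mt-phi"
    tokenizer = MarianTokenizer.from_pretrained(model_id)
    model = MarianMTModel.from_pretrained(model_id)

    # ">>fil<<" asks the multilingual model for Filipino output, the same
    # prefix the handler builds from actual_target_code.
    batch = tokenizer(">>fil<< Kumusta ka?", return_tensors="pt", padding=True)
    out = model.generate(**batch, num_beams=4, max_length=100,
                         length_penalty=0.6, early_stopping=True)
    print(tokenizer.decode(out[0], skip_special_tokens=True))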
 