Coco-18 committed on
Commit
4edd3da
·
verified ·
1 Parent(s): aa906c3

Update translator.py

Browse files
Files changed (1) hide show
  1. translator.py +54 -113
translator.py CHANGED
@@ -309,134 +309,75 @@ def handle_tts_request(request, output_dir):
309
  logger.debug(f"Stack trace: {traceback.format_exc()}")
310
  return jsonify({"error": f"Internal server error: {str(e)}"}), 500
311
 
312
- def handle_translation_request(request):
313
- """Handle translation requests"""
314
  try:
315
  data = request.get_json()
316
  if not data:
317
- logger.warning("⚠️ Translation endpoint called with no JSON data")
318
  return jsonify({"error": "No JSON data provided"}), 400
319
 
320
- source_text = data.get("text", "").strip()
321
- source_language = data.get("source_language", "").lower()
322
- target_language = data.get("target_language", "").lower()
323
 
324
- if not source_text:
325
- logger.warning("⚠️ Translation request with empty text")
326
  return jsonify({"error": "No text provided"}), 400
327
 
328
- # Map language names to codes
329
- source_code = LANGUAGE_CODES.get(source_language, source_language)
330
- target_code = LANGUAGE_CODES.get(target_language, target_language)
331
-
332
- logger.info(f"πŸ”„ Translating from {source_language} to {target_language}: '{source_text}'")
333
-
334
- # Special handling for pam-fil, fil-pam, pam-tgl and tgl-pam using the phi model
335
- use_phi_model = False
336
- actual_source_code = source_code
337
- actual_target_code = target_code
338
-
339
- # Check if we need to use the phi model with fil replacement
340
- if (source_code == "pam" and target_code == "fil") or (source_code == "fil" and target_code == "pam"):
341
- use_phi_model = True
342
- elif (source_code == "pam" and target_code == "tgl"):
343
- use_phi_model = True
344
- actual_target_code = "fil" # Replace tgl with fil for the phi model
345
- elif (source_code == "tgl" and target_code == "pam"):
346
- use_phi_model = True
347
- actual_source_code = "fil" # Replace tgl with fil for the phi model
348
-
349
- if use_phi_model:
350
- model_key = "phi"
351
-
352
- # Check if we have the phi model
353
- if model_key not in translation_models or translation_models[model_key] is None:
354
- logger.error(f"❌ Translation model for {model_key} not loaded")
355
- return jsonify({"error": f"Translation model not available"}), 503
356
-
357
- try:
358
- # Get the phi model and tokenizer
359
- model = translation_models[model_key]
360
- tokenizer = translation_tokenizers[model_key]
361
-
362
- # Prepend target language token to input
363
- input_text = f">>{actual_target_code}<< {source_text}"
364
-
365
- logger.info(f"πŸ”„ Using phi model with input: '{input_text}'")
366
-
367
- # Tokenize the text
368
- tokenized = tokenizer(input_text, return_tensors="pt", padding=True)
369
- tokenized = {k: v.to(model.device) for k, v in tokenized.items()}
370
-
371
- with torch.no_grad():
372
- translated = model.generate(
373
- **tokenized,
374
- max_length=100, # Reasonable output length
375
- num_beams=4, # Same as in training
376
- length_penalty=0.6, # Same as in training
377
- early_stopping=True, # Same as in training
378
- repetition_penalty=1.5, # Add this to prevent repetition
379
- no_repeat_ngram_size=3 # Add this to prevent repetition
380
- )
381
-
382
- # Decode the translation
383
- result = tokenizer.decode(translated[0], skip_special_tokens=True)
384
-
385
- logger.info(f"βœ… Translation result: '{result}'")
386
-
387
- return jsonify({
388
- "translated_text": result,
389
- "source_language": source_language,
390
- "target_language": target_language
391
- })
392
- except Exception as e:
393
- logger.error(f"❌ Translation processing failed: {str(e)}")
394
- logger.debug(f"Stack trace: {traceback.format_exc()}")
395
- return jsonify({"error": f"Translation processing failed: {str(e)}"}), 500
396
- else:
397
- # Create the regular language pair key for other language pairs
398
- lang_pair = f"{source_code}-{target_code}"
399
-
400
- # Check if we have a model for this language pair
401
- if lang_pair not in translation_models:
402
- logger.warning(f"⚠️ No translation model available for {lang_pair}")
403
- return jsonify(
404
- {"error": f"Translation from {source_language} to {target_language} is not supported yet"}), 400
405
-
406
- if translation_models[lang_pair] is None or translation_tokenizers[lang_pair] is None:
407
- logger.error(f"❌ Translation model for {lang_pair} not loaded")
408
- return jsonify({"error": f"Translation model not available"}), 503
409
-
410
- try:
411
- # Regular translation process for other language pairs
412
- model = translation_models[lang_pair]
413
- tokenizer = translation_tokenizers[lang_pair]
414
 
415
- # Tokenize the text
416
- tokenized = tokenizer(source_text, return_tensors="pt", padding=True)
417
- tokenized = {k: v.to(model.device) for k, v in tokenized.items()}
418
 
419
- # Generate translation
420
- with torch.no_grad():
421
- translated = model.generate(**tokenized)
422
 
423
- # Decode the translation
424
- result = tokenizer.decode(translated[0], skip_special_tokens=True)
 
 
 
 
 
 
425
 
426
- logger.info(f"βœ… Translation result: '{result}'")
 
 
 
 
 
 
 
 
427
 
428
- return jsonify({
429
- "translated_text": result,
430
- "source_language": source_language,
431
- "target_language": target_language
432
- })
433
- except Exception as e:
434
- logger.error(f"❌ Translation processing failed: {str(e)}")
435
- logger.debug(f"Stack trace: {traceback.format_exc()}")
436
- return jsonify({"error": f"Translation processing failed: {str(e)}"}), 500
 
 
 
 
 
 
437
 
 
 
 
 
 
 
 
438
  except Exception as e:
439
- logger.error(f"❌ Unhandled exception in translation endpoint: {str(e)}")
440
  logger.debug(f"Stack trace: {traceback.format_exc()}")
441
  return jsonify({"error": f"Internal server error: {str(e)}"}), 500
442
 
 
309
  logger.debug(f"Stack trace: {traceback.format_exc()}")
310
  return jsonify({"error": f"Internal server error: {str(e)}"}), 500
311
 
312
+ def handle_tts_request(request, output_dir):
313
+ """Handle TTS (Text-to-Speech) requests"""
314
  try:
315
  data = request.get_json()
316
  if not data:
317
+ logger.warning("⚠️ TTS endpoint called with no JSON data")
318
  return jsonify({"error": "No JSON data provided"}), 400
319
 
320
+ text_input = data.get("text", "").strip()
321
+ language = data.get("language", "kapampangan").lower()
 
322
 
323
+ if not text_input:
324
+ logger.warning("⚠️ TTS request with empty text")
325
  return jsonify({"error": "No text provided"}), 400
326
 
327
+ if language not in TTS_MODELS:
328
+ logger.warning(f"⚠️ TTS requested for unsupported language: {language}")
329
+ return jsonify({"error": f"Invalid language. Available options: {list(TTS_MODELS.keys())}"}), 400
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
330
 
331
+ if tts_models[language] is None:
332
+ logger.error(f"❌ TTS model for {language} not loaded")
333
+ return jsonify({"error": f"TTS model for {language} not available"}), 503
334
 
335
+ logger.info(f"πŸ”„ Generating TTS for language: {language}, text: '{text_input}'")
 
 
336
 
337
+ try:
338
+ processor = tts_processors[language]
339
+ model = tts_models[language]
340
+ inputs = processor(text_input, return_tensors="pt")
341
+ inputs = {k: v.to(model.device) for k, v in inputs.items()}
342
+ except Exception as e:
343
+ logger.error(f"❌ TTS preprocessing failed: {str(e)}")
344
+ return jsonify({"error": f"TTS preprocessing failed: {str(e)}"}), 500
345
 
346
+ # Generate speech
347
+ try:
348
+ with torch.no_grad():
349
+ output = model(**inputs).waveform
350
+ waveform = output.squeeze().cpu().numpy()
351
+ except Exception as e:
352
+ logger.error(f"❌ TTS inference failed: {str(e)}")
353
+ logger.debug(f"Stack trace: {traceback.format_exc()}")
354
+ return jsonify({"error": f"TTS inference failed: {str(e)}"}), 500
355
 
356
+ # Save to file with a unique name to prevent overwriting
357
+ try:
358
+ # Create a unique filename using timestamp and text hash
359
+ import hashlib
360
+ import time
361
+ text_hash = hashlib.md5(text_input.encode()).hexdigest()[:8]
362
+ timestamp = int(time.time())
363
+
364
+ output_filename = os.path.join(output_dir, f"{language}_{text_hash}_{timestamp}.wav")
365
+ sampling_rate = model.config.sampling_rate
366
+ sf.write(output_filename, waveform, sampling_rate)
367
+ logger.info(f"βœ… Speech generated! File saved: {output_filename}")
368
+ except Exception as e:
369
+ logger.error(f"❌ Failed to save audio file: {str(e)}")
370
+ return jsonify({"error": f"Failed to save audio file: {str(e)}"}), 500
371
 
372
+ # Add cache-busting parameter to URL
373
+ return jsonify({
374
+ "message": "TTS audio generated",
375
+ "file_url": f"/download/{os.path.basename(output_filename)}?t={timestamp}",
376
+ "language": language,
377
+ "text_length": len(text_input)
378
+ })
379
  except Exception as e:
380
+ logger.error(f"❌ Unhandled exception in TTS endpoint: {str(e)}")
381
  logger.debug(f"Stack trace: {traceback.format_exc()}")
382
  return jsonify({"error": f"Internal server error: {str(e)}"}), 500
383