Coco-18 committed
Commit 5d62ccb · verified · 1 Parent(s): 4edd3da

Update translator.py

Files changed (1)
  1. translator.py +123 -57
translator.py CHANGED
@@ -288,9 +288,15 @@ def handle_tts_request(request, output_dir):
             logger.debug(f"Stack trace: {traceback.format_exc()}")
             return jsonify({"error": f"TTS inference failed: {str(e)}"}), 500
 
-        # Save to file
+        # Save to file with a unique name to prevent overwriting
         try:
-            output_filename = os.path.join(output_dir, f"{language}_output.wav")
+            # Create a unique filename using timestamp and text hash
+            import hashlib
+            import time
+            text_hash = hashlib.md5(text_input.encode()).hexdigest()[:8]
+            timestamp = int(time.time())
+
+            output_filename = os.path.join(output_dir, f"{language}_{text_hash}_{timestamp}.wav")
             sampling_rate = model.config.sampling_rate
             sf.write(output_filename, waveform, sampling_rate)
             logger.info(f"✅ Speech generated! File saved: {output_filename}")
@@ -298,9 +304,10 @@ def handle_tts_request(request, output_dir):
             logger.error(f"❌ Failed to save audio file: {str(e)}")
             return jsonify({"error": f"Failed to save audio file: {str(e)}"}), 500
 
+        # Add cache-busting parameter to URL
         return jsonify({
             "message": "TTS audio generated",
-            "file_url": f"/download/{os.path.basename(output_filename)}",
+            "file_url": f"/download/{os.path.basename(output_filename)}?t={timestamp}",
            "language": language,
             "text_length": len(text_input)
         })
@@ -309,75 +316,134 @@ def handle_tts_request(request, output_dir):
         logger.debug(f"Stack trace: {traceback.format_exc()}")
         return jsonify({"error": f"Internal server error: {str(e)}"}), 500
 
-def handle_tts_request(request, output_dir):
-    """Handle TTS (Text-to-Speech) requests"""
+def handle_translation_request(request):
+    """Handle translation requests"""
     try:
         data = request.get_json()
         if not data:
-            logger.warning("⚠️ TTS endpoint called with no JSON data")
+            logger.warning("⚠️ Translation endpoint called with no JSON data")
             return jsonify({"error": "No JSON data provided"}), 400
 
-        text_input = data.get("text", "").strip()
-        language = data.get("language", "kapampangan").lower()
+        source_text = data.get("text", "").strip()
+        source_language = data.get("source_language", "").lower()
+        target_language = data.get("target_language", "").lower()
 
-        if not text_input:
-            logger.warning("⚠️ TTS request with empty text")
+        if not source_text:
+            logger.warning("⚠️ Translation request with empty text")
            return jsonify({"error": "No text provided"}), 400
 
-        if language not in TTS_MODELS:
-            logger.warning(f"⚠️ TTS requested for unsupported language: {language}")
-            return jsonify({"error": f"Invalid language. Available options: {list(TTS_MODELS.keys())}"}), 400
+        # Map language names to codes
+        source_code = LANGUAGE_CODES.get(source_language, source_language)
+        target_code = LANGUAGE_CODES.get(target_language, target_language)
 
-        if tts_models[language] is None:
-            logger.error(f"❌ TTS model for {language} not loaded")
-            return jsonify({"error": f"TTS model for {language} not available"}), 503
+        logger.info(f"🔄 Translating from {source_language} to {target_language}: '{source_text}'")
 
-        logger.info(f"🔄 Generating TTS for language: {language}, text: '{text_input}'")
+        # Special handling for pam-fil, fil-pam, pam-tgl and tgl-pam using the phi model
+        use_phi_model = False
+        actual_source_code = source_code
+        actual_target_code = target_code
 
-        try:
-            processor = tts_processors[language]
-            model = tts_models[language]
-            inputs = processor(text_input, return_tensors="pt")
-            inputs = {k: v.to(model.device) for k, v in inputs.items()}
-        except Exception as e:
-            logger.error(f"❌ TTS preprocessing failed: {str(e)}")
-            return jsonify({"error": f"TTS preprocessing failed: {str(e)}"}), 500
+        # Check if we need to use the phi model with fil replacement
+        if (source_code == "pam" and target_code == "fil") or (source_code == "fil" and target_code == "pam"):
+            use_phi_model = True
+        elif (source_code == "pam" and target_code == "tgl"):
+            use_phi_model = True
+            actual_target_code = "fil"  # Replace tgl with fil for the phi model
+        elif (source_code == "tgl" and target_code == "pam"):
+            use_phi_model = True
+            actual_source_code = "fil"  # Replace tgl with fil for the phi model
 
-        # Generate speech
-        try:
-            with torch.no_grad():
-                output = model(**inputs).waveform
-            waveform = output.squeeze().cpu().numpy()
-        except Exception as e:
-            logger.error(f"❌ TTS inference failed: {str(e)}")
-            logger.debug(f"Stack trace: {traceback.format_exc()}")
-            return jsonify({"error": f"TTS inference failed: {str(e)}"}), 500
+        if use_phi_model:
+            model_key = "phi"
 
-        # Save to file with a unique name to prevent overwriting
-        try:
-            # Create a unique filename using timestamp and text hash
-            import hashlib
-            import time
-            text_hash = hashlib.md5(text_input.encode()).hexdigest()[:8]
-            timestamp = int(time.time())
-
-            output_filename = os.path.join(output_dir, f"{language}_{text_hash}_{timestamp}.wav")
-            sampling_rate = model.config.sampling_rate
-            sf.write(output_filename, waveform, sampling_rate)
-            logger.info(f"✅ Speech generated! File saved: {output_filename}")
-        except Exception as e:
-            logger.error(f"❌ Failed to save audio file: {str(e)}")
-            return jsonify({"error": f"Failed to save audio file: {str(e)}"}), 500
+            # Check if we have the phi model
+            if model_key not in translation_models or translation_models[model_key] is None:
+                logger.error(f"❌ Translation model for {model_key} not loaded")
+                return jsonify({"error": f"Translation model not available"}), 503
+
+            try:
+                # Get the phi model and tokenizer
+                model = translation_models[model_key]
+                tokenizer = translation_tokenizers[model_key]
+
+                # Prepend target language token to input
+                input_text = f">>{actual_target_code}<< {source_text}"
+
+                logger.info(f"🔄 Using phi model with input: '{input_text}'")
+
+                # Tokenize the text
+                tokenized = tokenizer(input_text, return_tensors="pt", padding=True)
+                tokenized = {k: v.to(model.device) for k, v in tokenized.items()}
+
+                with torch.no_grad():
+                    translated = model.generate(
+                        **tokenized,
+                        max_length=100,          # Reasonable output length
+                        num_beams=4,             # Same as in training
+                        length_penalty=0.6,      # Same as in training
+                        early_stopping=True,     # Same as in training
+                        repetition_penalty=1.5,  # Add this to prevent repetition
+                        no_repeat_ngram_size=3   # Add this to prevent repetition
+                    )
+
+                # Decode the translation
+                result = tokenizer.decode(translated[0], skip_special_tokens=True)
+
+                logger.info(f"✅ Translation result: '{result}'")
+
+                return jsonify({
+                    "translated_text": result,
+                    "source_language": source_language,
+                    "target_language": target_language
+                })
+            except Exception as e:
+                logger.error(f"❌ Translation processing failed: {str(e)}")
+                logger.debug(f"Stack trace: {traceback.format_exc()}")
+                return jsonify({"error": f"Translation processing failed: {str(e)}"}), 500
+        else:
+            # Create the regular language pair key for other language pairs
+            lang_pair = f"{source_code}-{target_code}"
+
+            # Check if we have a model for this language pair
+            if lang_pair not in translation_models:
+                logger.warning(f"⚠️ No translation model available for {lang_pair}")
+                return jsonify(
+                    {"error": f"Translation from {source_language} to {target_language} is not supported yet"}), 400
+
+            if translation_models[lang_pair] is None or translation_tokenizers[lang_pair] is None:
+                logger.error(f"❌ Translation model for {lang_pair} not loaded")
+                return jsonify({"error": f"Translation model not available"}), 503
+
+            try:
+                # Regular translation process for other language pairs
+                model = translation_models[lang_pair]
+                tokenizer = translation_tokenizers[lang_pair]
+
+                # Tokenize the text
+                tokenized = tokenizer(source_text, return_tensors="pt", padding=True)
+                tokenized = {k: v.to(model.device) for k, v in tokenized.items()}
+
+                # Generate translation
+                with torch.no_grad():
+                    translated = model.generate(**tokenized)
+
+                # Decode the translation
+                result = tokenizer.decode(translated[0], skip_special_tokens=True)
+
+                logger.info(f"✅ Translation result: '{result}'")
+
+                return jsonify({
+                    "translated_text": result,
+                    "source_language": source_language,
+                    "target_language": target_language
+                })
+            except Exception as e:
+                logger.error(f"❌ Translation processing failed: {str(e)}")
+                logger.debug(f"Stack trace: {traceback.format_exc()}")
+                return jsonify({"error": f"Translation processing failed: {str(e)}"}), 500
 
-        # Add cache-busting parameter to URL
-        return jsonify({
-            "message": "TTS audio generated",
-            "file_url": f"/download/{os.path.basename(output_filename)}?t={timestamp}",
-            "language": language,
-            "text_length": len(text_input)
-        })
     except Exception as e:
-        logger.error(f"❌ Unhandled exception in TTS endpoint: {str(e)}")
+        logger.error(f"❌ Unhandled exception in translation endpoint: {str(e)}")
         logger.debug(f"Stack trace: {traceback.format_exc()}")
         return jsonify({"error": f"Internal server error: {str(e)}"}), 500