Update translator.py
Browse files- translator.py +54 -113
    	
        translator.py
    CHANGED
    
    | @@ -309,134 +309,75 @@ def handle_tts_request(request, output_dir): | |
| 309 | 
             
                    logger.debug(f"Stack trace: {traceback.format_exc()}")
         | 
| 310 | 
             
                    return jsonify({"error": f"Internal server error: {str(e)}"}), 500
         | 
| 311 |  | 
| 312 | 
            -
            def  | 
| 313 | 
            -
                """Handle  | 
| 314 | 
             
                try:
         | 
| 315 | 
             
                    data = request.get_json()
         | 
| 316 | 
             
                    if not data:
         | 
| 317 | 
            -
                        logger.warning("β οΈ  | 
| 318 | 
             
                        return jsonify({"error": "No JSON data provided"}), 400
         | 
| 319 |  | 
| 320 | 
            -
                     | 
| 321 | 
            -
                     | 
| 322 | 
            -
                    target_language = data.get("target_language", "").lower()
         | 
| 323 |  | 
| 324 | 
            -
                    if not  | 
| 325 | 
            -
                        logger.warning("β οΈ  | 
| 326 | 
             
                        return jsonify({"error": "No text provided"}), 400
         | 
| 327 |  | 
| 328 | 
            -
                     | 
| 329 | 
            -
             | 
| 330 | 
            -
             | 
| 331 | 
            -
             | 
| 332 | 
            -
                    logger.info(f"π Translating from {source_language} to {target_language}: '{source_text}'")
         | 
| 333 | 
            -
             | 
| 334 | 
            -
                    # Special handling for pam-fil, fil-pam, pam-tgl and tgl-pam using the phi model
         | 
| 335 | 
            -
                    use_phi_model = False
         | 
| 336 | 
            -
                    actual_source_code = source_code
         | 
| 337 | 
            -
                    actual_target_code = target_code
         | 
| 338 | 
            -
             | 
| 339 | 
            -
                    # Check if we need to use the phi model with fil replacement
         | 
| 340 | 
            -
                    if (source_code == "pam" and target_code == "fil") or (source_code == "fil" and target_code == "pam"):
         | 
| 341 | 
            -
                        use_phi_model = True
         | 
| 342 | 
            -
                    elif (source_code == "pam" and target_code == "tgl"):
         | 
| 343 | 
            -
                        use_phi_model = True
         | 
| 344 | 
            -
                        actual_target_code = "fil"  # Replace tgl with fil for the phi model
         | 
| 345 | 
            -
                    elif (source_code == "tgl" and target_code == "pam"):
         | 
| 346 | 
            -
                        use_phi_model = True
         | 
| 347 | 
            -
                        actual_source_code = "fil"  # Replace tgl with fil for the phi model
         | 
| 348 | 
            -
             | 
| 349 | 
            -
                    if use_phi_model:
         | 
| 350 | 
            -
                        model_key = "phi"
         | 
| 351 | 
            -
             | 
| 352 | 
            -
                        # Check if we have the phi model
         | 
| 353 | 
            -
                        if model_key not in translation_models or translation_models[model_key] is None:
         | 
| 354 | 
            -
                            logger.error(f"β Translation model for {model_key} not loaded")
         | 
| 355 | 
            -
                            return jsonify({"error": f"Translation model not available"}), 503
         | 
| 356 | 
            -
             | 
| 357 | 
            -
                        try:
         | 
| 358 | 
            -
                            # Get the phi model and tokenizer
         | 
| 359 | 
            -
                            model = translation_models[model_key]
         | 
| 360 | 
            -
                            tokenizer = translation_tokenizers[model_key]
         | 
| 361 | 
            -
             | 
| 362 | 
            -
                            # Prepend target language token to input
         | 
| 363 | 
            -
                            input_text = f">>{actual_target_code}<< {source_text}"
         | 
| 364 | 
            -
             | 
| 365 | 
            -
                            logger.info(f"π Using phi model with input: '{input_text}'")
         | 
| 366 | 
            -
             | 
| 367 | 
            -
                            # Tokenize the text
         | 
| 368 | 
            -
                            tokenized = tokenizer(input_text, return_tensors="pt", padding=True)
         | 
| 369 | 
            -
                            tokenized = {k: v.to(model.device) for k, v in tokenized.items()}
         | 
| 370 | 
            -
             | 
| 371 | 
            -
                            with torch.no_grad():
         | 
| 372 | 
            -
                                translated = model.generate(
         | 
| 373 | 
            -
                                    **tokenized,
         | 
| 374 | 
            -
                                    max_length=100,              # Reasonable output length
         | 
| 375 | 
            -
                                    num_beams=4,                 # Same as in training
         | 
| 376 | 
            -
                                    length_penalty=0.6,          # Same as in training
         | 
| 377 | 
            -
                                    early_stopping=True,         # Same as in training
         | 
| 378 | 
            -
                                    repetition_penalty=1.5,      # Add this to prevent repetition
         | 
| 379 | 
            -
                                    no_repeat_ngram_size=3       # Add this to prevent repetition
         | 
| 380 | 
            -
                                )
         | 
| 381 | 
            -
             | 
| 382 | 
            -
                            # Decode the translation
         | 
| 383 | 
            -
                            result = tokenizer.decode(translated[0], skip_special_tokens=True)
         | 
| 384 | 
            -
             | 
| 385 | 
            -
                            logger.info(f"✅ Translation result: '{result}'")
         | 
| 386 | 
            -
             | 
| 387 | 
            -
                            return jsonify({
         | 
| 388 | 
            -
                                "translated_text": result,
         | 
| 389 | 
            -
                                "source_language": source_language,
         | 
| 390 | 
            -
                                "target_language": target_language
         | 
| 391 | 
            -
                            })
         | 
| 392 | 
            -
                        except Exception as e:
         | 
| 393 | 
            -
                            logger.error(f"β Translation processing failed: {str(e)}")
         | 
| 394 | 
            -
                            logger.debug(f"Stack trace: {traceback.format_exc()}")
         | 
| 395 | 
            -
                            return jsonify({"error": f"Translation processing failed: {str(e)}"}), 500
         | 
| 396 | 
            -
                    else:
         | 
| 397 | 
            -
                        # Create the regular language pair key for other language pairs
         | 
| 398 | 
            -
                        lang_pair = f"{source_code}-{target_code}"
         | 
| 399 | 
            -
             | 
| 400 | 
            -
                        # Check if we have a model for this language pair
         | 
| 401 | 
            -
                        if lang_pair not in translation_models:
         | 
| 402 | 
            -
                            logger.warning(f"β οΈ No translation model available for {lang_pair}")
         | 
| 403 | 
            -
                            return jsonify(
         | 
| 404 | 
            -
                                {"error": f"Translation from {source_language} to {target_language} is not supported yet"}), 400
         | 
| 405 | 
            -
             | 
| 406 | 
            -
                        if translation_models[lang_pair] is None or translation_tokenizers[lang_pair] is None:
         | 
| 407 | 
            -
                            logger.error(f"β Translation model for {lang_pair} not loaded")
         | 
| 408 | 
            -
                            return jsonify({"error": f"Translation model not available"}), 503
         | 
| 409 | 
            -
             | 
| 410 | 
            -
                        try:
         | 
| 411 | 
            -
                            # Regular translation process for other language pairs
         | 
| 412 | 
            -
                            model = translation_models[lang_pair]
         | 
| 413 | 
            -
                            tokenizer = translation_tokenizers[lang_pair]
         | 
| 414 |  | 
| 415 | 
            -
             | 
| 416 | 
            -
             | 
| 417 | 
            -
             | 
| 418 |  | 
| 419 | 
            -
             | 
| 420 | 
            -
                            with torch.no_grad():
         | 
| 421 | 
            -
                                translated = model.generate(**tokenized)
         | 
| 422 |  | 
| 423 | 
            -
             | 
| 424 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 425 |  | 
| 426 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 427 |  | 
| 428 | 
            -
             | 
| 429 | 
            -
             | 
| 430 | 
            -
             | 
| 431 | 
            -
             | 
| 432 | 
            -
             | 
| 433 | 
            -
                         | 
| 434 | 
            -
             | 
| 435 | 
            -
             | 
| 436 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 437 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 438 | 
             
                except Exception as e:
         | 
| 439 | 
            -
                    logger.error(f"β Unhandled exception in  | 
| 440 | 
             
                    logger.debug(f"Stack trace: {traceback.format_exc()}")
         | 
| 441 | 
             
                    return jsonify({"error": f"Internal server error: {str(e)}"}), 500
         | 
| 442 |  | 
|  | |
| 309 | 
             
                    logger.debug(f"Stack trace: {traceback.format_exc()}")
         | 
| 310 | 
             
                    return jsonify({"error": f"Internal server error: {str(e)}"}), 500
         | 
| 311 |  | 
| 312 | 
            +
            def handle_tts_request(request, output_dir):
                """Handle a TTS (Text-to-Speech) request and return a JSON response.

                Reads ``text`` and ``language`` from the JSON body, runs the TTS
                model registered for that language, writes the waveform to a
                ``.wav`` file inside ``output_dir`` and returns a download URL.

                Args:
                    request: Request object exposing ``get_json()``; presumably a
                        Flask request, since the response is built with ``jsonify``
                        (confirm against the caller).
                    output_dir: Directory the generated ``.wav`` file is written to.

                Returns:
                    A ``jsonify`` response on success, otherwise a
                    ``(response, status)`` tuple:
                    400 for a missing/invalid body, empty text or unknown language,
                    503 when the model or processor for the language is not loaded,
                    500 when preprocessing, inference or saving fails.
                """
                try:
                    # silent=True makes a malformed or non-JSON body come back as
                    # None so it is reported as a 400 below, instead of raising into
                    # the catch-all handler and surfacing as a 500.
                    data = request.get_json(silent=True)
                    if not data:
                        logger.warning("⚠️ TTS endpoint called with no JSON data")
                        return jsonify({"error": "No JSON data provided"}), 400

                    # A JSON array or scalar has no .get(); reject it as a client error.
                    if not isinstance(data, dict):
                        logger.warning("⚠️ TTS endpoint called with a non-object JSON body")
                        return jsonify({"error": "JSON body must be an object"}), 400

                    # Non-string values (null, numbers, lists) are treated as
                    # missing/invalid input rather than crashing on .strip()/.lower().
                    raw_text = data.get("text", "")
                    text_input = raw_text.strip() if isinstance(raw_text, str) else ""
                    raw_language = data.get("language", "kapampangan")
                    language = raw_language.lower() if isinstance(raw_language, str) else None

                    if not text_input:
                        logger.warning("⚠️ TTS request with empty text")
                        return jsonify({"error": "No text provided"}), 400

                    if language not in TTS_MODELS:
                        logger.warning(f"⚠️ TTS requested for unsupported language: {language}")
                        return jsonify({"error": f"Invalid language. Available options: {list(TTS_MODELS.keys())}"}), 400

                    # .get() so a language that is configured but absent from the
                    # runtime registries yields 503 instead of a KeyError -> 500.
                    # Assumes both registries are dicts keyed by language name.
                    model = tts_models.get(language)
                    processor = tts_processors.get(language)
                    if model is None or processor is None:
                        logger.error(f"❌ TTS model for {language} not loaded")
                        return jsonify({"error": f"TTS model for {language} not available"}), 503

                    # NOTE(review): the leading glyph of this message is the residue of
                    # a mis-decoded 4-byte emoji; restore it from the upstream file.
                    logger.info(f"π Generating TTS for language: {language}, text: '{text_input}'")

                    # Tokenize the text and move the tensors to the model's device.
                    try:
                        inputs = processor(text_input, return_tensors="pt")
                        inputs = {k: v.to(model.device) for k, v in inputs.items()}
                    except Exception as e:
                        logger.error(f"❌ TTS preprocessing failed: {str(e)}")
                        return jsonify({"error": f"TTS preprocessing failed: {str(e)}"}), 500

                    # Generate speech
                    try:
                        with torch.no_grad():
                            output = model(**inputs).waveform
                            waveform = output.squeeze().cpu().numpy()
                    except Exception as e:
                        logger.error(f"❌ TTS inference failed: {str(e)}")
                        logger.debug(f"Stack trace: {traceback.format_exc()}")
                        return jsonify({"error": f"TTS inference failed: {str(e)}"}), 500

                    # Save to file with a unique name to prevent overwriting
                    try:
                        # Local imports: the file's import header is not visible here.
                        import hashlib
                        import time

                        # The digest is only a filename component, not a security measure.
                        text_hash = hashlib.md5(text_input.encode()).hexdigest()[:8]
                        timestamp = int(time.time())

                        output_filename = os.path.join(output_dir, f"{language}_{text_hash}_{timestamp}.wav")
                        sampling_rate = model.config.sampling_rate
                        sf.write(output_filename, waveform, sampling_rate)
                        logger.info(f"✅ Speech generated! File saved: {output_filename}")
                    except Exception as e:
                        logger.error(f"❌ Failed to save audio file: {str(e)}")
                        return jsonify({"error": f"Failed to save audio file: {str(e)}"}), 500

                    # Add cache-busting parameter to URL
                    return jsonify({
                        "message": "TTS audio generated",
                        "file_url": f"/download/{os.path.basename(output_filename)}?t={timestamp}",
                        "language": language,
                        "text_length": len(text_input)
                    })
                except Exception as e:
                    # Last-resort boundary: anything not handled above becomes a 500.
                    logger.error(f"❌ Unhandled exception in TTS endpoint: {str(e)}")
                    logger.debug(f"Stack trace: {traceback.format_exc()}")
                    return jsonify({"error": f"Internal server error: {str(e)}"}), 500
         | 
| 383 |  | 
