Update app.py
Browse files
app.py
CHANGED
@@ -121,30 +121,42 @@ for lang, model_id in TTS_MODELS.items():
|
|
121 |
logger.debug(f"Stack trace: {traceback.format_exc()}")
|
122 |
tts_models[lang] = None
|
123 |
|
124 |
-
#
|
125 |
-
|
126 |
-
|
|
|
|
|
|
|
|
|
|
|
127 |
|
128 |
-
|
129 |
-
translation_model = None
|
130 |
-
translation_tokenizer = None
|
131 |
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
logger.info("
|
138 |
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
|
149 |
# Constants
|
150 |
SAMPLE_RATE = 16000
|
@@ -350,10 +362,6 @@ def download_audio(filename):
|
|
350 |
|
351 |
@app.route("/translate", methods=["POST"])
|
352 |
def translate_text():
|
353 |
-
if translation_model is None or translation_tokenizer is None:
|
354 |
-
logger.error("β Translation endpoint called but models aren't loaded")
|
355 |
-
return jsonify({"error": "Translation model not available"}), 503
|
356 |
-
|
357 |
try:
|
358 |
data = request.get_json()
|
359 |
if not data:
|
@@ -367,20 +375,40 @@ def translate_text():
|
|
367 |
if not source_text:
|
368 |
logger.warning("β οΈ Translation request with empty text")
|
369 |
return jsonify({"error": "No text provided"}), 400
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
370 |
|
371 |
logger.info(f"π Translating from {source_language} to {target_language}: '{source_text}'")
|
372 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
373 |
try:
|
|
|
|
|
|
|
|
|
374 |
# Tokenize the text
|
375 |
-
tokenized =
|
376 |
tokenized = {k: v.to(device) for k, v in tokenized.items()}
|
377 |
|
378 |
# Generate translation
|
379 |
with torch.no_grad():
|
380 |
-
translated =
|
381 |
|
382 |
# Decode the translation
|
383 |
-
result =
|
384 |
|
385 |
logger.info(f"β
Translation result: '{result}'")
|
386 |
|
@@ -399,7 +427,6 @@ def translate_text():
|
|
399 |
logger.debug(f"Stack trace: {traceback.format_exc()}")
|
400 |
return jsonify({"error": f"Internal server error: {str(e)}"}), 500
|
401 |
|
402 |
-
|
403 |
if __name__ == "__main__":
|
404 |
logger.info("π Starting Speech API server")
|
405 |
logger.info(f"π System status: ASR model: {'β
' if asr_model else 'β'}")
|
|
|
121 |
logger.debug(f"Stack trace: {traceback.format_exc()}")
|
122 |
tts_models[lang] = None
|
123 |
|
124 |
+
# Replace the single translation model with a dictionary of models
|
125 |
+
TRANSLATION_MODELS = {
|
126 |
+
"pam-eng": "Coco-18/opus-mt-pam-en",
|
127 |
+
"eng-pam": "Coco-18/opus-mt-en-pam",
|
128 |
+
"tgl-eng": "Helsinki-NLP/opus-mt-tl-en",
|
129 |
+
"eng-tgl": "Helsinki-NLP/opus-mt-en-tl"
|
130 |
+
# pam-tgl and tgl-pam will be added later
|
131 |
+
}
|
132 |
|
133 |
+
logger.info(f"π Loading Translation model: {TRANSLATION_MODELS}")
|
|
|
|
|
134 |
|
135 |
+
# Replace the single model initialization with:
|
136 |
+
translation_models = {}
|
137 |
+
translation_tokenizers = {}
|
138 |
+
|
139 |
+
for lang_pair, model_id in TRANSLATION_MODELS.items():
|
140 |
+
logger.info(f"π Loading Translation model: {model_id}")
|
141 |
|
142 |
+
try:
|
143 |
+
translation_tokenizers[lang_pair] = MarianTokenizer.from_pretrained(
|
144 |
+
model_id,
|
145 |
+
cache_dir=cache_dirs["TRANSFORMERS_CACHE"]
|
146 |
+
)
|
147 |
+
logger.info(f"β
Translation tokenizer loaded successfully for {lang_pair}")
|
148 |
+
|
149 |
+
translation_models[lang_pair] = MarianMTModel.from_pretrained(
|
150 |
+
model_id,
|
151 |
+
cache_dir=cache_dirs["TRANSFORMERS_CACHE"]
|
152 |
+
)
|
153 |
+
translation_models[lang_pair].to(device)
|
154 |
+
logger.info(f"β
Translation model loaded successfully on {device} for {lang_pair}")
|
155 |
+
except Exception as e:
|
156 |
+
logger.error(f"β Error loading Translation model for {lang_pair}: {str(e)}")
|
157 |
+
logger.debug(f"Stack trace: {traceback.format_exc()}")
|
158 |
+
|
159 |
+
|
160 |
|
161 |
# Constants
|
162 |
SAMPLE_RATE = 16000
|
|
|
362 |
|
363 |
@app.route("/translate", methods=["POST"])
|
364 |
def translate_text():
|
|
|
|
|
|
|
|
|
365 |
try:
|
366 |
data = request.get_json()
|
367 |
if not data:
|
|
|
375 |
if not source_text:
|
376 |
logger.warning("β οΈ Translation request with empty text")
|
377 |
return jsonify({"error": "No text provided"}), 400
|
378 |
+
|
379 |
+
# Map language names to codes
|
380 |
+
source_code = LANGUAGE_CODES.get(source_language, source_language)
|
381 |
+
target_code = LANGUAGE_CODES.get(target_language, target_language)
|
382 |
+
|
383 |
+
# Create the language pair key
|
384 |
+
lang_pair = f"{source_code}-{target_code}"
|
385 |
|
386 |
logger.info(f"π Translating from {source_language} to {target_language}: '{source_text}'")
|
387 |
|
388 |
+
# Check if we have a model for this language pair
|
389 |
+
if lang_pair not in translation_models:
|
390 |
+
logger.warning(f"β οΈ No translation model available for {lang_pair}")
|
391 |
+
return jsonify({"error": f"Translation from {source_language} to {target_language} is not supported yet"}), 400
|
392 |
+
|
393 |
+
if translation_models[lang_pair] is None or translation_tokenizers[lang_pair] is None:
|
394 |
+
logger.error(f"β Translation model for {lang_pair} not loaded")
|
395 |
+
return jsonify({"error": f"Translation model not available"}), 503
|
396 |
+
|
397 |
try:
|
398 |
+
# Get the appropriate model and tokenizer
|
399 |
+
model = translation_models[lang_pair]
|
400 |
+
tokenizer = translation_tokenizers[lang_pair]
|
401 |
+
|
402 |
# Tokenize the text
|
403 |
+
tokenized = tokenizer(source_text, return_tensors="pt", padding=True)
|
404 |
tokenized = {k: v.to(device) for k, v in tokenized.items()}
|
405 |
|
406 |
# Generate translation
|
407 |
with torch.no_grad():
|
408 |
+
translated = model.generate(**tokenized)
|
409 |
|
410 |
# Decode the translation
|
411 |
+
result = tokenizer.decode(translated[0], skip_special_tokens=True)
|
412 |
|
413 |
logger.info(f"β
Translation result: '{result}'")
|
414 |
|
|
|
427 |
logger.debug(f"Stack trace: {traceback.format_exc()}")
|
428 |
return jsonify({"error": f"Internal server error: {str(e)}"}), 500
|
429 |
|
|
|
430 |
if __name__ == "__main__":
|
431 |
logger.info("π Starting Speech API server")
|
432 |
logger.info(f"π System status: ASR model: {'β
' if asr_model else 'β'}")
|