Coco-18 committed on
Commit
a70fb66
Β·
verified Β·
1 Parent(s): ef9fa31

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -29
app.py CHANGED
@@ -121,30 +121,42 @@ for lang, model_id in TTS_MODELS.items():
121
  logger.debug(f"Stack trace: {traceback.format_exc()}")
122
  tts_models[lang] = None
123
 
124
- # Add this with your other model configurations
125
- TRANSLATION_MODEL_ID = "Coco-18/opus-mt-pam-en"
126
- logger.info(f"πŸ”„ Loading Translation model: {TRANSLATION_MODEL_ID}")
 
 
 
 
 
127
 
128
- # Initialize translation model and tokenizer (add this after your other model initializations)
129
- translation_model = None
130
- translation_tokenizer = None
131
 
132
- try:
133
- translation_tokenizer = MarianTokenizer.from_pretrained(
134
- TRANSLATION_MODEL_ID,
135
- cache_dir=cache_dirs["TRANSFORMERS_CACHE"]
136
- )
137
- logger.info("βœ… Translation tokenizer loaded successfully")
138
 
139
- translation_model = MarianMTModel.from_pretrained(
140
- TRANSLATION_MODEL_ID,
141
- cache_dir=cache_dirs["TRANSFORMERS_CACHE"]
142
- )
143
- translation_model.to(device)
144
- logger.info(f"βœ… Translation model loaded successfully on {device}")
145
- except Exception as e:
146
- logger.error(f"❌ Error loading Translation model: {str(e)}")
147
- logger.debug(f"Stack trace: {traceback.format_exc()}")
 
 
 
 
 
 
 
 
 
148
 
149
  # Constants
150
  SAMPLE_RATE = 16000
@@ -350,10 +362,6 @@ def download_audio(filename):
350
 
351
  @app.route("/translate", methods=["POST"])
352
  def translate_text():
353
- if translation_model is None or translation_tokenizer is None:
354
- logger.error("❌ Translation endpoint called but models aren't loaded")
355
- return jsonify({"error": "Translation model not available"}), 503
356
-
357
  try:
358
  data = request.get_json()
359
  if not data:
@@ -367,20 +375,40 @@ def translate_text():
367
  if not source_text:
368
  logger.warning("⚠️ Translation request with empty text")
369
  return jsonify({"error": "No text provided"}), 400
 
 
 
 
 
 
 
370
 
371
  logger.info(f"πŸ”„ Translating from {source_language} to {target_language}: '{source_text}'")
372
 
 
 
 
 
 
 
 
 
 
373
  try:
 
 
 
 
374
  # Tokenize the text
375
- tokenized = translation_tokenizer(source_text, return_tensors="pt", padding=True)
376
  tokenized = {k: v.to(device) for k, v in tokenized.items()}
377
 
378
  # Generate translation
379
  with torch.no_grad():
380
- translated = translation_model.generate(**tokenized)
381
 
382
  # Decode the translation
383
- result = translation_tokenizer.decode(translated[0], skip_special_tokens=True)
384
 
385
  logger.info(f"βœ… Translation result: '{result}'")
386
 
@@ -399,7 +427,6 @@ def translate_text():
399
  logger.debug(f"Stack trace: {traceback.format_exc()}")
400
  return jsonify({"error": f"Internal server error: {str(e)}"}), 500
401
 
402
-
403
  if __name__ == "__main__":
404
  logger.info("πŸš€ Starting Speech API server")
405
  logger.info(f"πŸ“Š System status: ASR model: {'βœ…' if asr_model else '❌'}")
 
121
  logger.debug(f"Stack trace: {traceback.format_exc()}")
122
  tts_models[lang] = None
123
 
124
+ # Replace the single translation model with a dictionary of models
125
+ TRANSLATION_MODELS = {
126
+ "pam-eng": "Coco-18/opus-mt-pam-en",
127
+ "eng-pam": "Coco-18/opus-mt-en-pam",
128
+ "tgl-eng": "Helsinki-NLP/opus-mt-tl-en",
129
+ "eng-tgl": "Helsinki-NLP/opus-mt-en-tl"
130
+ # pam-tgl and tgl-pam will be added later
131
+ }
132
 
133
+ logger.info(f"πŸ”„ Loading Translation model: {TRANSLATION_MODELS}")
 
 
134
 
135
+ # Replace the single model initialization with:
136
+ translation_models = {}
137
+ translation_tokenizers = {}
138
+
139
+ for lang_pair, model_id in TRANSLATION_MODELS.items():
140
+ logger.info(f"πŸ”„ Loading Translation model: {model_id}")
141
 
142
+ try:
143
+ translation_tokenizers[lang_pair] = MarianTokenizer.from_pretrained(
144
+ model_id,
145
+ cache_dir=cache_dirs["TRANSFORMERS_CACHE"]
146
+ )
147
+ logger.info(f"βœ… Translation tokenizer loaded successfully for {lang_pair}")
148
+
149
+ translation_models[lang_pair] = MarianMTModel.from_pretrained(
150
+ model_id,
151
+ cache_dir=cache_dirs["TRANSFORMERS_CACHE"]
152
+ )
153
+ translation_models[lang_pair].to(device)
154
+ logger.info(f"βœ… Translation model loaded successfully on {device} for {lang_pair}")
155
+ except Exception as e:
156
+ logger.error(f"❌ Error loading Translation model for {lang_pair}: {str(e)}")
157
+ logger.debug(f"Stack trace: {traceback.format_exc()}")
158
+
159
+
160
 
161
  # Constants
162
  SAMPLE_RATE = 16000
 
362
 
363
  @app.route("/translate", methods=["POST"])
364
  def translate_text():
 
 
 
 
365
  try:
366
  data = request.get_json()
367
  if not data:
 
375
  if not source_text:
376
  logger.warning("⚠️ Translation request with empty text")
377
  return jsonify({"error": "No text provided"}), 400
378
+
379
+ # Map language names to codes
380
+ source_code = LANGUAGE_CODES.get(source_language, source_language)
381
+ target_code = LANGUAGE_CODES.get(target_language, target_language)
382
+
383
+ # Create the language pair key
384
+ lang_pair = f"{source_code}-{target_code}"
385
 
386
  logger.info(f"πŸ”„ Translating from {source_language} to {target_language}: '{source_text}'")
387
 
388
+ # Check if we have a model for this language pair
389
+ if lang_pair not in translation_models:
390
+ logger.warning(f"⚠️ No translation model available for {lang_pair}")
391
+ return jsonify({"error": f"Translation from {source_language} to {target_language} is not supported yet"}), 400
392
+
393
+ if translation_models[lang_pair] is None or translation_tokenizers[lang_pair] is None:
394
+ logger.error(f"❌ Translation model for {lang_pair} not loaded")
395
+ return jsonify({"error": f"Translation model not available"}), 503
396
+
397
  try:
398
+ # Get the appropriate model and tokenizer
399
+ model = translation_models[lang_pair]
400
+ tokenizer = translation_tokenizers[lang_pair]
401
+
402
  # Tokenize the text
403
+ tokenized = tokenizer(source_text, return_tensors="pt", padding=True)
404
  tokenized = {k: v.to(device) for k, v in tokenized.items()}
405
 
406
  # Generate translation
407
  with torch.no_grad():
408
+ translated = model.generate(**tokenized)
409
 
410
  # Decode the translation
411
+ result = tokenizer.decode(translated[0], skip_special_tokens=True)
412
 
413
  logger.info(f"βœ… Translation result: '{result}'")
414
 
 
427
  logger.debug(f"Stack trace: {traceback.format_exc()}")
428
  return jsonify({"error": f"Internal server error: {str(e)}"}), 500
429
 
 
430
  if __name__ == "__main__":
431
  logger.info("πŸš€ Starting Speech API server")
432
  logger.info(f"πŸ“Š System status: ASR model: {'βœ…' if asr_model else '❌'}")