Update app.py
Browse files
app.py
CHANGED
@@ -40,6 +40,7 @@ try:
|
|
40 |
from flask import Flask, request, jsonify, send_file
|
41 |
from flask_cors import CORS
|
42 |
from transformers import Wav2Vec2ForCTC, AutoProcessor, VitsModel, AutoTokenizer
|
|
|
43 |
logger.info("β
All required libraries imported successfully")
|
44 |
except ImportError as e:
|
45 |
logger.critical(f"β Failed to import necessary libraries: {str(e)}")
|
@@ -120,6 +121,31 @@ for lang, model_id in TTS_MODELS.items():
|
|
120 |
logger.debug(f"Stack trace: {traceback.format_exc()}")
|
121 |
tts_models[lang] = None
|
122 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
# Constants
|
124 |
SAMPLE_RATE = 16000
|
125 |
OUTPUT_DIR = "/tmp/audio_outputs"
|
@@ -140,6 +166,7 @@ def health_check():
|
|
140 |
"asr_model": "loaded" if asr_model is not None else "failed",
|
141 |
"tts_models": {lang: "loaded" if model is not None else "failed"
|
142 |
for lang, model in tts_models.items()},
|
|
|
143 |
"device": device
|
144 |
}
|
145 |
return jsonify(health_status)
|
@@ -321,6 +348,57 @@ def download_audio(filename):
|
|
321 |
logger.warning(f"β οΈ Requested file not found: {file_path}")
|
322 |
return jsonify({"error": "File not found"}), 404
|
323 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
324 |
|
325 |
if __name__ == "__main__":
|
326 |
logger.info("π Starting Speech API server")
|
|
|
40 |
from flask import Flask, request, jsonify, send_file
|
41 |
from flask_cors import CORS
|
42 |
from transformers import Wav2Vec2ForCTC, AutoProcessor, VitsModel, AutoTokenizer
|
43 |
+
from transformers import MarianMTModel, MarianTokenizer
|
44 |
logger.info("β
All required libraries imported successfully")
|
45 |
except ImportError as e:
|
46 |
logger.critical(f"β Failed to import necessary libraries: {str(e)}")
|
|
|
121 |
logger.debug(f"Stack trace: {traceback.format_exc()}")
|
122 |
tts_models[lang] = None
|
123 |
|
124 |
+
# Marian translation checkpoint served by the /translate endpoint.
TRANSLATION_MODEL_ID = "Helsinki-NLP/opus-mt-tc-bible-big-phi-en"
logger.info(f"🔄 Loading Translation model: {TRANSLATION_MODEL_ID}")

# Both globals stay None unless the whole load sequence succeeds; the health
# check and the /translate endpoint use "is None" to detect a failed load.
translation_model = None
translation_tokenizer = None

try:
    translation_tokenizer = MarianTokenizer.from_pretrained(
        TRANSLATION_MODEL_ID,
        cache_dir=cache_dirs["TRANSFORMERS_CACHE"],
    )
    logger.info("✅ Translation tokenizer loaded successfully")

    translation_model = MarianMTModel.from_pretrained(
        TRANSLATION_MODEL_ID,
        cache_dir=cache_dirs["TRANSFORMERS_CACHE"],
    )
    translation_model.to(device)
    logger.info(f"✅ Translation model loaded successfully on {device}")
except Exception as e:
    logger.error(f"❌ Error loading Translation model: {str(e)}")
    logger.debug(f"Stack trace: {traceback.format_exc()}")
    # A failure part-way through (e.g. the tokenizer loaded but the model did
    # not, or the move to `device` raised) must not leave a half-initialised
    # pair behind: reset both so the service reports the model as unavailable
    # instead of failing on every request.
    translation_model = None
    translation_tokenizer = None
|
148 |
+
|
149 |
# Constants
|
150 |
SAMPLE_RATE = 16000
|
151 |
OUTPUT_DIR = "/tmp/audio_outputs"
|
|
|
166 |
"asr_model": "loaded" if asr_model is not None else "failed",
|
167 |
"tts_models": {lang: "loaded" if model is not None else "failed"
|
168 |
for lang, model in tts_models.items()},
|
169 |
+
"translation_model": "loaded" if translation_model is not None else "failed",
|
170 |
"device": device
|
171 |
}
|
172 |
return jsonify(health_status)
|
|
|
348 |
logger.warning(f"β οΈ Requested file not found: {file_path}")
|
349 |
return jsonify({"error": "File not found"}), 404
|
350 |
|
351 |
+
@app.route("/translate", methods=["POST"])
def translate_text():
    """Translate the posted text with the loaded translation model.

    Expects a JSON object body with:
        text: non-empty string to translate (required).
        source_language / target_language: optional values that are
            lower-cased, logged and echoed back in the response. They are
            not passed to the tokenizer or the model.

    Returns:
        A JSON response with ``translated_text``, ``source_language`` and
        ``target_language`` on success. Errors are returned as
        ``{"error": ...}`` with status 400 (missing/invalid body or text),
        503 (model or tokenizer not loaded) or 500 (translation failed).
    """
    if translation_model is None or translation_tokenizer is None:
        logger.error("❌ Translation endpoint called but models aren't loaded")
        return jsonify({"error": "Translation model not available"}), 503

    try:
        # silent=True returns None for a missing, malformed or non-JSON body
        # instead of raising, so the client gets the 400 below rather than a
        # 500 from the catch-all handler at the bottom.
        data = request.get_json(silent=True)
        if not data:
            logger.warning("⚠️ Translation endpoint called with no JSON data")
            return jsonify({"error": "No JSON data provided"}), 400
        if not isinstance(data, dict):
            # A JSON array or scalar has no .get(); reject it explicitly.
            logger.warning("⚠️ Translation endpoint called with a non-object JSON body")
            return jsonify({"error": "JSON body must be an object"}), 400

        raw_text = data.get("text", "")
        if raw_text is None:
            raw_text = ""
        if not isinstance(raw_text, str):
            logger.warning("⚠️ Translation request with non-string text")
            return jsonify({"error": "'text' must be a string"}), 400

        source_text = raw_text.strip()
        # `or ""` guards against explicit nulls; str() guards against
        # non-string values so .lower() cannot raise.
        source_language = str(data.get("source_language") or "").lower()
        target_language = str(data.get("target_language") or "").lower()

        if not source_text:
            logger.warning("⚠️ Translation request with empty text")
            return jsonify({"error": "No text provided"}), 400

        logger.info(f"🔄 Translating from {source_language} to {target_language}: '{source_text}'")

        try:
            # Tokenize the text. truncation=True caps the input at the
            # tokenizer's configured maximum length so over-long requests
            # do not overflow the model's input limit.
            tokenized = translation_tokenizer(
                source_text,
                return_tensors="pt",
                padding=True,
                truncation=True,
            )
            # Inputs must live on the same device as the model.
            tokenized = {k: v.to(device) for k, v in tokenized.items()}

            # Generate translation (inference only, no gradients needed)
            with torch.no_grad():
                translated = translation_model.generate(**tokenized)

            # Decode the single generated sequence
            result = translation_tokenizer.decode(translated[0], skip_special_tokens=True)

            logger.info(f"✅ Translation result: '{result}'")

            return jsonify({
                "translated_text": result,
                "source_language": source_language,
                "target_language": target_language
            })
        except Exception as e:
            logger.error(f"❌ Translation processing failed: {str(e)}")
            logger.debug(f"Stack trace: {traceback.format_exc()}")
            return jsonify({"error": f"Translation processing failed: {str(e)}"}), 500

    except Exception as e:
        # Last-resort boundary so the endpoint always answers with JSON.
        logger.error(f"❌ Unhandled exception in translation endpoint: {str(e)}")
        logger.debug(f"Stack trace: {traceback.format_exc()}")
        return jsonify({"error": f"Internal server error: {str(e)}"}), 500
|
401 |
+
|
402 |
|
403 |
if __name__ == "__main__":
|
404 |
logger.info("π Starting Speech API server")
|