Update translator.py
translator.py (+440 −435)
# translator.py - Handles ASR, TTS, and translation tasks

import os
import sys
import logging
import traceback
import torch
import torchaudio
import tempfile
import soundfile as sf
from pydub import AudioSegment
from flask import jsonify
from transformers import Wav2Vec2ForCTC, AutoProcessor, VitsModel, AutoTokenizer
from transformers import MarianMTModel, MarianTokenizer

# Configure logging
logger = logging.getLogger("speech_api")

# Global variables to store models and processors
asr_model = None
asr_processor = None
tts_models = {}
tts_processors = {}
translation_models = {}
translation_tokenizers = {}

# Language-specific configurations
LANGUAGE_CODES = {
    "kapampangan": "pam",
    "filipino": "fil",
    "english": "eng",
    "tagalog": "tgl",
}

# TTS Models (Kapampangan, Tagalog, English)
TTS_MODELS = {
    "kapampangan": "facebook/mms-tts-pam",
    "tagalog": "facebook/mms-tts-tgl",
    "english": "facebook/mms-tts-eng"
}

# Translation Models
TRANSLATION_MODELS = {
    "pam-eng": "Coco-18/opus-mt-pam-en",
    "eng-pam": "Coco-18/opus-mt-en-pam",
    "tgl-eng": "Helsinki-NLP/opus-mt-tl-en",
    "eng-tgl": "Helsinki-NLP/opus-mt-en-tl",
    "phi": "Coco-18/opus-mt-phi"
}
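
# Note: unlike the directional pairs above, "phi" is a single multilingual Marian
# model shared by the Kapampangan<->Filipino (and tgl-mapped) directions. Marian
# multilingual models select the output language through a target-language token
# prepended to the source text (e.g. ">>fil<< ..."), which is exactly what
# handle_translation_request() does below when it builds its input string.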

def init_models(device):
    """Initialize all models required for the API"""
    global asr_model, asr_processor, tts_models, tts_processors, translation_models, translation_tokenizers

    # Initialize ASR model
    ASR_MODEL_ID = "Coco-18/mms-asr-tgl-en-safetensor"
    logger.info(f"🔄 Loading ASR model: {ASR_MODEL_ID}")

    try:
        asr_processor = AutoProcessor.from_pretrained(
            ASR_MODEL_ID,
            cache_dir=os.environ.get("TRANSFORMERS_CACHE")
        )
        logger.info("✅ ASR processor loaded successfully")

        asr_model = Wav2Vec2ForCTC.from_pretrained(
            ASR_MODEL_ID,
            cache_dir=os.environ.get("TRANSFORMERS_CACHE")
        )
        asr_model.to(device)
        logger.info(f"✅ ASR model loaded successfully on {device}")
    except Exception as e:
        logger.error(f"❌ Error loading ASR model: {str(e)}")
        logger.debug(f"Stack trace: {traceback.format_exc()}")

    # Initialize TTS models
    for lang, model_id in TTS_MODELS.items():
        logger.info(f"🔄 Loading TTS model for {lang}: {model_id}")
        try:
            tts_processors[lang] = AutoTokenizer.from_pretrained(
                model_id,
                cache_dir=os.environ.get("TRANSFORMERS_CACHE")
            )
            logger.info(f"✅ {lang} TTS processor loaded")

            tts_models[lang] = VitsModel.from_pretrained(
                model_id,
                cache_dir=os.environ.get("TRANSFORMERS_CACHE")
            )
            tts_models[lang].to(device)
            logger.info(f"✅ {lang} TTS model loaded on {device}")
        except Exception as e:
            logger.error(f"❌ Failed to load {lang} TTS model: {str(e)}")
            logger.debug(f"Stack trace: {traceback.format_exc()}")
            tts_models[lang] = None

    # Initialize translation models
    for model_key, model_id in TRANSLATION_MODELS.items():
        logger.info(f"🔄 Loading Translation model: {model_id}")

        try:
            translation_tokenizers[model_key] = MarianTokenizer.from_pretrained(
                model_id,
                cache_dir=os.environ.get("TRANSFORMERS_CACHE")
            )
            logger.info(f"✅ Translation tokenizer loaded successfully for {model_key}")

            translation_models[model_key] = MarianMTModel.from_pretrained(
                model_id,
                cache_dir=os.environ.get("TRANSFORMERS_CACHE")
            )
            translation_models[model_key].to(device)
            logger.info(f"✅ Translation model loaded successfully on {device} for {model_key}")
        except Exception as e:
            logger.error(f"❌ Error loading Translation model for {model_key}: {str(e)}")
            logger.debug(f"Stack trace: {traceback.format_exc()}")
            translation_models[model_key] = None
            translation_tokenizers[model_key] = None


def check_model_status():
    """Check and return the status of all models"""
    # Initialize direct language pair statuses based on loaded models
    translation_status = {}

    # Add status for direct model pairs
    for lang_pair in ["pam-eng", "eng-pam", "tgl-eng", "eng-tgl"]:
        translation_status[lang_pair] = "loaded" if lang_pair in translation_models and translation_models[
            lang_pair] is not None else "failed"

    # Add special phi model status
    phi_status = "loaded" if "phi" in translation_models and translation_models["phi"] is not None else "failed"
    translation_status["pam-fil"] = phi_status
    translation_status["fil-pam"] = phi_status
    translation_status["pam-tgl"] = phi_status  # Using phi model but replacing tgl with fil
    translation_status["tgl-pam"] = phi_status  # Using phi model but replacing tgl with fil

    return {
        "asr_model": "loaded" if asr_model is not None else "failed",
        "tts_models": {lang: "loaded" if model is not None else "failed"
                       for lang, model in tts_models.items()},
        "translation_models": translation_status
    }


def handle_asr_request(request, output_dir, sample_rate):
    """Handle ASR (Automatic Speech Recognition) requests"""
    if asr_model is None or asr_processor is None:
        logger.error("❌ ASR endpoint called but models aren't loaded")
        return jsonify({"error": "ASR model not available"}), 503

    try:
        if "audio" not in request.files:
            logger.warning("⚠️ ASR request missing audio file")
            return jsonify({"error": "No audio file uploaded"}), 400

        audio_file = request.files["audio"]
        language = request.form.get("language", "english").lower()

        if language not in LANGUAGE_CODES:
            logger.warning(f"⚠️ Unsupported language requested: {language}")
            return jsonify(
                {"error": f"Unsupported language: {language}. Available: {list(LANGUAGE_CODES.keys())}"}), 400

        lang_code = LANGUAGE_CODES[language]
        logger.info(f"🔄 Processing {language} audio for ASR")

        # Save the uploaded file temporarily
        with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(audio_file.filename)[-1]) as temp_audio:
            temp_audio.write(audio_file.read())
            temp_audio_path = temp_audio.name
        logger.debug(f"🔄 Temporary audio saved to {temp_audio_path}")

        # Convert to WAV if necessary
        wav_path = temp_audio_path
        if not audio_file.filename.lower().endswith(".wav"):
            wav_path = os.path.join(output_dir, "converted_audio.wav")
            logger.info(f"🔄 Converting audio to WAV format: {wav_path}")
            try:
                audio = AudioSegment.from_file(temp_audio_path)
                audio = audio.set_frame_rate(sample_rate).set_channels(1)
                audio.export(wav_path, format="wav")
            except Exception as e:
                logger.error(f"❌ Audio conversion failed: {str(e)}")
                return jsonify({"error": f"Audio conversion failed: {str(e)}"}), 500

        # Load and process the WAV file
        try:
            waveform, sr = torchaudio.load(wav_path)
            logger.debug(f"✅ Audio loaded: {wav_path} (Sample rate: {sr}Hz)")

            # Resample if needed
            if sr != sample_rate:
                logger.info(f"🔄 Resampling audio from {sr}Hz to {sample_rate}Hz")
                waveform = torchaudio.transforms.Resample(sr, sample_rate)(waveform)

            waveform = waveform / torch.max(torch.abs(waveform))
        except Exception as e:
            logger.error(f"❌ Failed to load or process audio: {str(e)}")
            return jsonify({"error": f"Audio processing failed: {str(e)}"}), 500

        # Process audio for ASR
        try:
            inputs = asr_processor(
                waveform.squeeze().numpy(),
                sampling_rate=sample_rate,
                return_tensors="pt",
                language=lang_code
            )
            inputs = {k: v.to(asr_model.device) for k, v in inputs.items()}
        except Exception as e:
            logger.error(f"❌ ASR preprocessing failed: {str(e)}")
            return jsonify({"error": f"ASR preprocessing failed: {str(e)}"}), 500

        # Perform ASR
        try:
            with torch.no_grad():
                logits = asr_model(**inputs).logits
            ids = torch.argmax(logits, dim=-1)[0]
            transcription = asr_processor.decode(ids)

            logger.info(f"✅ Transcription ({language}): {transcription}")

            # Clean up temp files
            try:
                os.unlink(temp_audio_path)
                if wav_path != temp_audio_path:
                    os.unlink(wav_path)
            except Exception as e:
                logger.warning(f"⚠️ Failed to clean up temp files: {str(e)}")

            return jsonify({
                "transcription": transcription,
                "language": language,
                "language_code": lang_code
            })
        except Exception as e:
            logger.error(f"❌ ASR inference failed: {str(e)}")
            logger.debug(f"Stack trace: {traceback.format_exc()}")
            return jsonify({"error": f"ASR inference failed: {str(e)}"}), 500

    except Exception as e:
        logger.error(f"❌ Unhandled exception in ASR endpoint: {str(e)}")
        logger.debug(f"Stack trace: {traceback.format_exc()}")
        return jsonify({"error": f"Internal server error: {str(e)}"}), 500


def handle_tts_request(request, output_dir):
    """Handle TTS (Text-to-Speech) requests"""
    try:
        data = request.get_json()
        if not data:
            logger.warning("⚠️ TTS endpoint called with no JSON data")
            return jsonify({"error": "No JSON data provided"}), 400

        text_input = data.get("text", "").strip()
        language = data.get("language", "kapampangan").lower()

        if not text_input:
            logger.warning("⚠️ TTS request with empty text")
            return jsonify({"error": "No text provided"}), 400

        if language not in TTS_MODELS:
            logger.warning(f"⚠️ TTS requested for unsupported language: {language}")
            return jsonify({"error": f"Invalid language. Available options: {list(TTS_MODELS.keys())}"}), 400

        if tts_models[language] is None:
            logger.error(f"❌ TTS model for {language} not loaded")
            return jsonify({"error": f"TTS model for {language} not available"}), 503

        logger.info(f"🔄 Generating TTS for language: {language}, text: '{text_input}'")

        try:
            processor = tts_processors[language]
            model = tts_models[language]
            inputs = processor(text_input, return_tensors="pt")
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
        except Exception as e:
            logger.error(f"❌ TTS preprocessing failed: {str(e)}")
            return jsonify({"error": f"TTS preprocessing failed: {str(e)}"}), 500

        # Generate speech
        try:
            with torch.no_grad():
                output = model(**inputs).waveform
            waveform = output.squeeze().cpu().numpy()
        except Exception as e:
            logger.error(f"❌ TTS inference failed: {str(e)}")
            logger.debug(f"Stack trace: {traceback.format_exc()}")
            return jsonify({"error": f"TTS inference failed: {str(e)}"}), 500

        # Save to file
        try:
            output_filename = os.path.join(output_dir, f"{language}_output.wav")
            sampling_rate = model.config.sampling_rate
            sf.write(output_filename, waveform, sampling_rate)
            logger.info(f"✅ Speech generated! File saved: {output_filename}")
        except Exception as e:
            logger.error(f"❌ Failed to save audio file: {str(e)}")
            return jsonify({"error": f"Failed to save audio file: {str(e)}"}), 500

        return jsonify({
            "message": "TTS audio generated",
            "file_url": f"/download/{os.path.basename(output_filename)}",
            "language": language,
            "text_length": len(text_input)
        })
    except Exception as e:
        logger.error(f"❌ Unhandled exception in TTS endpoint: {str(e)}")
        logger.debug(f"Stack trace: {traceback.format_exc()}")
        return jsonify({"error": f"Internal server error: {str(e)}"}), 500


def handle_translation_request(request):
    """Handle translation requests"""
    try:
        data = request.get_json()
        if not data:
            logger.warning("⚠️ Translation endpoint called with no JSON data")
            return jsonify({"error": "No JSON data provided"}), 400

        source_text = data.get("text", "").strip()
        source_language = data.get("source_language", "").lower()
        target_language = data.get("target_language", "").lower()

        if not source_text:
            logger.warning("⚠️ Translation request with empty text")
            return jsonify({"error": "No text provided"}), 400

        # Map language names to codes
        source_code = LANGUAGE_CODES.get(source_language, source_language)
        target_code = LANGUAGE_CODES.get(target_language, target_language)

        logger.info(f"🔄 Translating from {source_language} to {target_language}: '{source_text}'")

        # Special handling for pam-fil, fil-pam, pam-tgl and tgl-pam using the phi model
        use_phi_model = False
        actual_source_code = source_code
        actual_target_code = target_code

        # Check if we need to use the phi model with fil replacement
        if (source_code == "pam" and target_code == "fil") or (source_code == "fil" and target_code == "pam"):
            use_phi_model = True
        elif (source_code == "pam" and target_code == "tgl"):
            use_phi_model = True
            actual_target_code = "fil"  # Replace tgl with fil for the phi model
        elif (source_code == "tgl" and target_code == "pam"):
            use_phi_model = True
            actual_source_code = "fil"  # Replace tgl with fil for the phi model

        if use_phi_model:
            model_key = "phi"

            # Check if we have the phi model
            if model_key not in translation_models or translation_models[model_key] is None:
                logger.error(f"❌ Translation model for {model_key} not loaded")
                return jsonify({"error": "Translation model not available"}), 503

            try:
                # Get the phi model and tokenizer
                model = translation_models[model_key]
                tokenizer = translation_tokenizers[model_key]

                # Prepend target language token to input
                input_text = f">>{actual_target_code}<< {source_text}"

                logger.info(f"🔄 Using phi model with input: '{input_text}'")

                # Tokenize the text
                tokenized = tokenizer(input_text, return_tensors="pt", padding=True)
                tokenized = {k: v.to(model.device) for k, v in tokenized.items()}

                # Generate translation
                with torch.no_grad():
                    translated = model.generate(**tokenized)

                # Decode the translation
                result = tokenizer.decode(translated[0], skip_special_tokens=True)

                logger.info(f"✅ Translation result: '{result}'")

                return jsonify({
                    "translated_text": result,
                    "source_language": source_language,
                    "target_language": target_language
                })
            except Exception as e:
                logger.error(f"❌ Translation processing failed: {str(e)}")
                logger.debug(f"Stack trace: {traceback.format_exc()}")
                return jsonify({"error": f"Translation processing failed: {str(e)}"}), 500
        else:
            # Create the regular language pair key for other language pairs
            lang_pair = f"{source_code}-{target_code}"

            # Check if we have a model for this language pair
            if lang_pair not in translation_models:
                logger.warning(f"⚠️ No translation model available for {lang_pair}")
                return jsonify(
                    {"error": f"Translation from {source_language} to {target_language} is not supported yet"}), 400

            if translation_models[lang_pair] is None or translation_tokenizers[lang_pair] is None:
                logger.error(f"❌ Translation model for {lang_pair} not loaded")
                return jsonify({"error": "Translation model not available"}), 503

            try:
                # Regular translation process for other language pairs
                model = translation_models[lang_pair]
                tokenizer = translation_tokenizers[lang_pair]

                # Tokenize the text
                tokenized = tokenizer(source_text, return_tensors="pt", padding=True)
                tokenized = {k: v.to(model.device) for k, v in tokenized.items()}

                # Generate translation
                with torch.no_grad():
                    translated = model.generate(**tokenized)

                # Decode the translation
                result = tokenizer.decode(translated[0], skip_special_tokens=True)

                logger.info(f"✅ Translation result: '{result}'")

                return jsonify({
                    "translated_text": result,
                    "source_language": source_language,
                    "target_language": target_language
                })
            except Exception as e:
                logger.error(f"❌ Translation processing failed: {str(e)}")
                logger.debug(f"Stack trace: {traceback.format_exc()}")
                return jsonify({"error": f"Translation processing failed: {str(e)}"}), 500

    except Exception as e:
        logger.error(f"❌ Unhandled exception in translation endpoint: {str(e)}")
        logger.debug(f"Stack trace: {traceback.format_exc()}")
        return jsonify({"error": f"Internal server error: {str(e)}"}), 500


def get_asr_model():
    return asr_model


def get_asr_processor():
    return asr_processor
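
For context, here is a minimal sketch of how these handlers might be wired into a Flask app. Nothing below is part of this commit: the app.py module, the route paths, OUTPUT_DIR, and SAMPLE_RATE are illustrative assumptions; only the translator functions themselves come from the file above.

# app.py - hypothetical wiring for translator.py (illustrative only; route
# names, OUTPUT_DIR, and SAMPLE_RATE are assumptions, not part of the commit)
import os

import torch
from flask import Flask, request, send_from_directory

import translator

app = Flask(__name__)
OUTPUT_DIR = "output"   # directory the handlers write converted/generated WAVs into
SAMPLE_RATE = 16000     # target rate passed to handle_asr_request

os.makedirs(OUTPUT_DIR, exist_ok=True)
translator.init_models("cuda" if torch.cuda.is_available() else "cpu")

@app.route("/health")
def health():
    # check_model_status() returns a plain dict; Flask serializes it as JSON
    return translator.check_model_status()

@app.route("/asr", methods=["POST"])
def asr():
    return translator.handle_asr_request(request, OUTPUT_DIR, SAMPLE_RATE)

@app.route("/tts", methods=["POST"])
def tts():
    return translator.handle_tts_request(request, OUTPUT_DIR)

@app.route("/translate", methods=["POST"])
def translate():
    return translator.handle_translation_request(request)

@app.route("/download/<path:filename>")
def download(filename):
    # matches the file_url ("/download/<name>") returned by handle_tts_request
    return send_from_directory(OUTPUT_DIR, filename)

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=5000)

A request such as curl -F "audio=@clip.wav" -F "language=kapampangan" http://localhost:5000/asr would then exercise the ASR path end to end.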