Update evaluate.py

evaluate.py  CHANGED  (+10 -6)
@@ -15,7 +15,7 @@ from werkzeug.utils import secure_filename
 from concurrent.futures import ThreadPoolExecutor
 
 # Import necessary functions from translator.py
-from translator import
+from translator import get_asr_model, get_asr_processor, LANGUAGE_CODES
 
 # Configure logging
 logger = logging.getLogger("speech_api")
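The import change is the crux of this commit: the old code evidently imported `asr_model` and `asr_processor` directly (the handler references them without defining them), and a `from translator import asr_model` style import binds the value at import time, so a model that `translator` loads later never becomes visible to `evaluate`. Getter functions read the module globals at call time instead. A minimal sketch of what `translator.py` might expose, assuming lazy module-level singletons (hypothetical; that file is not part of this diff):

# translator.py -- hypothetical sketch; this module is not shown in the diff.
from transformers import AutoModelForCTC, AutoProcessor

_asr_model = None
_asr_processor = None

LANGUAGE_CODES = {"eng": "English"}  # placeholder; the real mapping is unknown

def load_models(model_id="facebook/mms-1b-all"):  # model id is a guess
    """Load the ASR model/processor into the module-level singletons."""
    global _asr_model, _asr_processor
    _asr_processor = AutoProcessor.from_pretrained(model_id)
    _asr_model = AutoModelForCTC.from_pretrained(model_id)

def get_asr_model():
    """Return the current ASR model, or None if loading hasn't run or failed."""
    return _asr_model

def get_asr_processor():
    """Return the current ASR processor, or None likewise."""
    return _asr_processor

Because the handler now calls the getters on every request, a model loaded (or reloaded) after `evaluate` is imported gets picked up without restarting anything.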
@@ -153,12 +153,15 @@ def handle_upload_reference(request, reference_dir, sample_rate):
 
 
 def handle_evaluation_request(request, reference_dir, output_dir, sample_rate):
-    """Handle pronunciation evaluation requests"""
     request_id = f"req-{id(request)}"  # Create unique ID for this request
     logger.info(f"[{request_id}] 🚀 Starting new pronunciation evaluation request")
-
+
     temp_dir = None
-
+
+    # Get the ASR model and processor using the getter functions
+    asr_model = get_asr_model()
+    asr_processor = get_asr_processor()
+
     if asr_model is None or asr_processor is None:
         logger.error(f"[{request_id}] ❌ Evaluation endpoint called but ASR models aren't loaded")
         return jsonify({"error": "ASR model not available"}), 503
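For orientation, the handler's signature and its `jsonify(...), 503` return suggest it is mounted on a Flask route roughly like this (all names, paths, and the sample rate below are illustrative, not from the commit):

# Hypothetical route wiring; the actual app setup is not part of this diff.
from flask import Flask, request
from evaluate import handle_evaluation_request

app = Flask(__name__)

@app.route("/evaluate", methods=["POST"])
def evaluate_endpoint():
    # reference_dir, output_dir, and sample_rate are assumed config values
    return handle_evaluation_request(
        request, reference_dir="./reference", output_dir="./output", sample_rate=16000
    )

Fetching the model inside the handler also means each request gets a clean 503 while models are still loading, rather than the service failing at import time.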
@@ -265,8 +268,8 @@ def handle_evaluation_request(request, reference_dir, output_dir, sample_rate):
     if ref_sr != sample_rate:
         ref_waveform = torchaudio.transforms.Resample(ref_sr, sample_rate)(ref_waveform)
     ref_waveform = ref_waveform.squeeze().numpy()
-
-    # Transcribe reference audio
+
+    # Transcribe reference audio - use the local asr_model and asr_processor
     inputs = asr_processor(
         ref_waveform,
         sampling_rate=sample_rate,
@@ -275,6 +278,7 @@ def handle_evaluation_request(request, reference_dir, output_dir, sample_rate):
     )
     inputs = {k: v.to(asr_model.device) for k, v in inputs.items()}
 
+
     with torch.no_grad():
         logits = asr_model(**inputs).logits
     ids = torch.argmax(logits, dim=-1)[0]
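The hunk stops right after the argmax; for a CTC-style model (which the logits/argmax pattern implies) the ids are typically decoded back to text with the processor, roughly:

# Sketch of the step that presumably follows the diff (not shown in it):
# decode CTC argmax ids to text. Assumes a wav2vec2/MMS-style processor whose
# tokenizer collapses repeated CTC tokens during decoding.
reference_transcription = asr_processor.decode(ids, skip_special_tokens=True)
logger.info(f"[{request_id}] Reference transcription: '{reference_transcription}'")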