Coco-18 committed
Commit 1a61e31 · verified · 1 parent: e7b87ef

Update evaluate.py

Files changed (1):
  1. evaluate.py +10 -6

evaluate.py CHANGED
@@ -15,7 +15,7 @@ from werkzeug.utils import secure_filename
 from concurrent.futures import ThreadPoolExecutor
 
 # Import necessary functions from translator.py
-from translator import asr_model, asr_processor, LANGUAGE_CODES
+from translator import get_asr_model, get_asr_processor, LANGUAGE_CODES
 
 # Configure logging
 logger = logging.getLogger("speech_api")
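
For context, a minimal sketch of the translator.py side of this indirection. Only the names get_asr_model, get_asr_processor, and LANGUAGE_CODES are visible in the diff; the module-level cache below is an assumption about how translator.py might expose them, not the actual implementation:

# translator.py (hypothetical sketch) - the getters do a live lookup of the
# module-level objects, so callers see whatever is loaded *now* rather than
# a snapshot taken when evaluate.py was imported.
asr_model = None        # set by translator.py's loading routine once it finishes
asr_processor = None

def get_asr_model():
    return asr_model    # None until loading completes, then the loaded model

def get_asr_processor():
    return asr_processor
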
@@ -153,12 +153,15 @@ def handle_upload_reference(request, reference_dir, sample_rate):
 
 
 def handle_evaluation_request(request, reference_dir, output_dir, sample_rate):
-    """Handle pronunciation evaluation requests"""
     request_id = f"req-{id(request)}"  # Create unique ID for this request
     logger.info(f"[{request_id}] 🆕 Starting new pronunciation evaluation request")
-
+
     temp_dir = None
-
+
+    # Get the ASR model and processor using the getter functions
+    asr_model = get_asr_model()
+    asr_processor = get_asr_processor()
+
     if asr_model is None or asr_processor is None:
         logger.error(f"[{request_id}] ❌ Evaluation endpoint called but ASR models aren't loaded")
         return jsonify({"error": "ASR model not available"}), 503
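
The commit message doesn't spell out the rationale, but the pitfall this pattern avoids is that `from translator import asr_model` copies the binding once, at import time: if translator.py finishes loading its models after evaluate.py has been imported, the handler's asr_model would stay None forever. A self-contained illustration (the module here is synthetic, built in-place for the demo):

import types

# Stand-in for translator.py before its models have loaded (synthetic demo).
translator = types.ModuleType("translator")
translator.asr_model = None

asr_model = translator.asr_model   # what "from translator import asr_model" does

translator.asr_model = object()    # model loading finishes later

print(asr_model)                   # None - the imported name is a stale snapshot
print(translator.asr_model)        # <object ...> - attribute access sees the live value
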
@@ -265,8 +268,8 @@ def handle_evaluation_request(request, reference_dir, output_dir, sample_rate):
     if ref_sr != sample_rate:
         ref_waveform = torchaudio.transforms.Resample(ref_sr, sample_rate)(ref_waveform)
     ref_waveform = ref_waveform.squeeze().numpy()
-
-    # Transcribe reference audio
+
+    # Transcribe reference audio - use the local asr_model and asr_processor
     inputs = asr_processor(
         ref_waveform,
         sampling_rate=sample_rate,
@@ -275,6 +278,7 @@ def handle_evaluation_request(request, reference_dir, output_dir, sample_rate):
     )
     inputs = {k: v.to(asr_model.device) for k, v in inputs.items()}
 
+
     with torch.no_grad():
         logits = asr_model(**inputs).logits
         ids = torch.argmax(logits, dim=-1)[0]
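
For reference, the transcription step touched by the last two hunks, assembled into one runnable sketch. Everything up to torch.argmax comes from the diff; the final asr_processor.decode(ids) call and the return_tensors="pt" argument are assumptions (the hunks cut off before them), following the usual greedy-CTC pattern for Wav2Vec2/MMS-style processors:

import torch

def transcribe_reference(ref_waveform, sample_rate):
    # Fetch the live model/processor, as the updated handler now does.
    asr_model = get_asr_model()
    asr_processor = get_asr_processor()

    inputs = asr_processor(ref_waveform, sampling_rate=sample_rate, return_tensors="pt")
    inputs = {k: v.to(asr_model.device) for k, v in inputs.items()}

    with torch.no_grad():
        logits = asr_model(**inputs).logits
    ids = torch.argmax(logits, dim=-1)[0]

    # Greedy CTC decode (assumed continuation, not shown in the diff).
    return asr_processor.decode(ids)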