Update evaluate.py
evaluate.py  CHANGED  (+82, -24)
@@ -20,6 +20,9 @@ from translator import get_asr_model, get_asr_processor, LANGUAGE_CODES
 # Configure logging
 logger = logging.getLogger("speech_api")
 
+if not hasattr(handle_evaluation_request, 'cache'):
+    handle_evaluation_request.cache = {}
+
 def calculate_similarity(text1, text2):
     """Calculate text similarity percentage."""
     def clean_text(text):
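Note on the module-level guard added in this hunk: at this point in the file, handle_evaluation_request is not defined yet (it appears at line 314 on the new side), so evaluating hasattr(handle_evaluation_request, 'cache') would raise NameError when the module is imported. A minimal sketch of an alternative, assuming a plain module-level dict; the EVALUATION_CACHE name is hypothetical, echoing a comment that appears later in this diff:

# Sketch only, not part of the commit: a module-level dict avoids referencing
# handle_evaluation_request before it exists.
EVALUATION_CACHE = {}

def cache_get(key):
    # Return the cached response, or None on a miss (hypothetical helper).
    return EVALUATION_CACHE.get(key)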
@@ -309,9 +312,9 @@ def handle_upload_reference(request, reference_dir, sample_rate):
         return jsonify({"error": f"Internal server error: {str(e)}"}), 500
 
 def handle_evaluation_request(request, reference_dir, output_dir, sample_rate):
-    """Handle pronunciation evaluation requests with
-    request_id = f"req-{id(request)}"
-    logger.info(f"[{request_id}] Starting
+    """Handle pronunciation evaluation requests with speed optimizations"""
+    request_id = f"req-{id(request)}"
+    logger.info(f"[{request_id}] Starting pronunciation evaluation request")
 
     temp_dir = None
 
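The request ID in this hunk is derived from id(request). CPython can reuse an object's id once the object is garbage collected, so two different requests may end up logging the same req- prefix. If unique IDs matter for tracing, a uuid-based sketch (an alternative, not what the commit does):

import uuid

# uuid4 is collision-resistant; id(request) can repeat across requests.
request_id = f"req-{uuid.uuid4().hex[:8]}"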
@@ -324,6 +327,7 @@ def handle_evaluation_request(request, reference_dir, output_dir, sample_rate):
         return jsonify({"error": "ASR model not available"}), 503
 
     try:
+        # OPTIMIZATION 1: Check cache first for identical audio
         if "audio" not in request.files:
             logger.warning(f"[{request_id}] Evaluation request missing audio file")
             return jsonify({"error": "No audio file uploaded"}), 400
@@ -337,6 +341,19 @@ def handle_evaluation_request(request, reference_dir, output_dir, sample_rate):
             logger.warning(f"[{request_id}] No reference locator provided")
             return jsonify({"error": "Reference locator is required"}), 400
 
+        # OPTIMIZATION 2: Simple caching based on audio content hash + reference_locator
+        audio_content = audio_file.read()
+        audio_file.seek(0)  # Reset file pointer after reading
+
+        import hashlib
+        audio_hash = hashlib.md5(audio_content).hexdigest()
+        cache_key = f"{audio_hash}_{reference_locator}_{language}"
+
+        # Check in-memory cache (define EVALUATION_CACHE at module level)
+        if hasattr(handle_evaluation_request, 'cache') and cache_key in handle_evaluation_request.cache:
+            logger.info(f"[{request_id}] Using cached evaluation result")
+            return handle_evaluation_request.cache[cache_key]
+
         # Construct full reference directory path
         reference_dir_path = os.path.join(reference_dir, reference_locator)
         logger.info(f"[{request_id}] Reference directory path: {reference_dir_path}")
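The cache key in this hunk is the MD5 digest of the uploaded bytes combined with the reference locator and language, so the cache only hits on a byte-identical re-upload of the same word in the same language. A standalone sketch of that keying scheme; the function name and example values are illustrative, not from the source:

import hashlib

def make_cache_key(audio_bytes: bytes, reference_locator: str, language: str) -> str:
    # MD5 is acceptable here because the digest is only a cache key,
    # not an integrity or security check.
    audio_hash = hashlib.md5(audio_bytes).hexdigest()
    return f"{audio_hash}_{reference_locator}_{language}"

# e.g. make_cache_key(wav_bytes, "some-locator", "en")  # illustrative values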
@@ -366,7 +383,7 @@ def handle_evaluation_request(request, reference_dir, output_dir, sample_rate):
             }), 404
 
         lang_code = LANGUAGE_CODES.get(language, language)
-        logger.info(f"[{request_id}] Evaluating pronunciation for reference: {reference_locator}
+        logger.info(f"[{request_id}] Evaluating pronunciation for reference: {reference_locator}")
 
         # Create a request-specific temp directory to avoid conflicts
         temp_dir = os.path.join(output_dir, f"temp_{request_id}")
@@ -375,7 +392,7 @@ def handle_evaluation_request(request, reference_dir, output_dir, sample_rate):
         # Process user audio
         user_audio_path = os.path.join(temp_dir, "user_audio_input.wav")
         with open(user_audio_path, 'wb') as f:
-            f.write(
+            f.write(audio_content)  # Use the content we already read
 
         try:
             logger.info(f"[{request_id}] Processing user audio file")
@@ -415,18 +432,23 @@ def handle_evaluation_request(request, reference_dir, output_dir, sample_rate):
             logger.error(f"[{request_id}] ASR inference failed: {str(e)}")
             return jsonify({"error": f"ASR inference failed: {str(e)}"}), 500
 
-        # OPTIMIZATION:
+        # OPTIMIZATION 3: Use a smaller sample of reference files
         import multiprocessing
+        import random
 
-        #
-        logger.info(f"[{request_id}]
+        # OPTIMIZATION 4: Limit to just a few files for initial comparison
+        # If we have many reference files, randomly sample some for quick evaluation
+        if len(reference_files) > 3:
+            # Randomly select 3 files for faster comparison
+            reference_files_sample = random.sample(reference_files, 3)
+        else:
+            reference_files_sample = reference_files
+
+        # Determine optimal number of workers based on CPU count (but keep it small)
+        max_workers = min(multiprocessing.cpu_count(), len(reference_files_sample), 3)
+        initial_results = []
 
+        logger.info(f"[{request_id}] Quick scan: processing {len(reference_files_sample)} reference files")
 
         # Function to process a single reference file
         def process_reference_file(ref_file):
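Optimizations 3 and 4 trade accuracy for speed: only three randomly chosen reference files are scored in the first pass, and the thread pool is capped at three workers, or fewer if there are fewer files or CPU cores. A self-contained sketch of the same sampling and pool-sizing logic, with hypothetical helper names:

import multiprocessing
import random

def pick_quick_sample(reference_files, k=3):
    # random.sample raises ValueError when k exceeds the population size,
    # hence the explicit length check mirrored from the diff.
    if len(reference_files) > k:
        return random.sample(reference_files, k)
    return list(reference_files)

def pool_size(sample, cap=3):
    # Never more workers than files or CPU cores, and never more than the cap.
    return min(multiprocessing.cpu_count(), len(sample), cap)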
@@ -438,8 +460,7 @@ def handle_evaluation_request(request, reference_dir, output_dir, sample_rate):
                 ref_waveform = torchaudio.transforms.Resample(ref_sr, sample_rate)(ref_waveform)
                 ref_waveform = ref_waveform.squeeze().numpy()
 
-                # Transcribe reference audio
-                # Remove language parameter if causing warnings
+                # Transcribe reference audio
                 inputs = asr_processor(
                     ref_waveform,
                     sampling_rate=sample_rate,
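This hunk feeds the resampled reference waveform through asr_processor and the ASR model, but the decoding half is outside the diff. The sketch below assumes a Wav2Vec2-style CTC setup (for example MMS), which is consistent with the sampling_rate= processor call and the LANGUAGE_CODES import; the actual internals of translator.get_asr_model are not shown here, so treat this as an assumption:

import torch

def transcribe(waveform, asr_model, asr_processor, sample_rate=16000):
    # Assumed Wav2Vec2/MMS-style CTC decoding; adapt if translator.py differs.
    inputs = asr_processor(waveform, sampling_rate=sample_rate, return_tensors="pt")
    with torch.no_grad():
        logits = asr_model(**inputs).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    return asr_processor.batch_decode(predicted_ids)[0]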
@@ -472,20 +493,40 @@ def handle_evaluation_request(request, reference_dir, output_dir, sample_rate):
                     "error": str(e)
                 }
 
-        #
+        # Process the sample files in parallel
         with ThreadPoolExecutor(max_workers=max_workers) as executor:
+            initial_results = list(executor.map(process_reference_file, reference_files_sample))
 
-        # Find the best result
+        # Find the best result from the initial sample
         best_score = 0
         best_reference = None
         best_transcription = None
 
-        for result in
+        for result in initial_results:
            if result["similarity_score"] > best_score:
                best_score = result["similarity_score"]
                best_reference = result["reference_file"]
                best_transcription = result["reference_text"]
+
+        # OPTIMIZATION 5: If we already found a very good match, don't process more files
+        all_results = initial_results.copy()
+        remaining_files = [f for f in reference_files if f not in reference_files_sample]
+
+        # Only process more files if our best score isn't already very good
+        if best_score < 80.0 and remaining_files:
+            logger.info(f"[{request_id}] Score {best_score:.2f}% not high enough, checking {len(remaining_files)} more references")
+
+            # Process remaining files
+            with ThreadPoolExecutor(max_workers=max_workers) as executor:
+                additional_results = list(executor.map(process_reference_file, remaining_files[:5]))  # Process max 5 more
+            all_results.extend(additional_results)
+
+            # Update best result if we found a better one
+            for result in additional_results:
+                if result["similarity_score"] > best_score:
+                    best_score = result["similarity_score"]
+                    best_reference = result["reference_file"]
+                    best_transcription = result["reference_text"]
 
         # Clean up temp files
         try:
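The two-pass scan above first scores the sampled files in parallel, then touches at most five of the remaining references, and only when the best quick-scan score falls below 80%. A compact sketch of that control flow, independent of the Flask handler; the function and parameter names are illustrative, while the 80% threshold and the five-file cap come from the diff:

from concurrent.futures import ThreadPoolExecutor

def two_pass_scan(score_file, sample, remaining, good_enough=80.0, max_extra=5, max_workers=3):
    # Pass 1: score only the sampled reference files in parallel.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        results = list(pool.map(score_file, sample))
    best = max((r["similarity_score"] for r in results), default=0.0)
    # Pass 2: only if the quick scan did not already find a strong match.
    if best < good_enough and remaining:
        with ThreadPoolExecutor(max_workers=max_workers) as pool:
            results.extend(pool.map(score_file, remaining[:max_extra]))
    return results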
@@ -514,17 +555,34 @@ def handle_evaluation_request(request, reference_dir, output_dir, sample_rate):
         logger.info(f"[{request_id}] Evaluation complete")
 
         # Sort results by score descending
+        all_results.sort(key=lambda x: x["similarity_score"], reverse=True)
+
+        # Create response
+        response = jsonify({
             "is_correct": is_correct,
             "score": best_score,
             "feedback": feedback,
             "user_transcription": user_transcription,
             "best_reference_transcription": best_transcription,
             "reference_locator": reference_locator,
-            "details":
+            "details": all_results,
+            "total_references_compared": len(all_results),
+            "total_available_references": len(reference_files),
+            "quick_evaluation": True
         })
+
+        # OPTIMIZATION 6: Cache the result for future requests
+        if not hasattr(handle_evaluation_request, 'cache'):
+            handle_evaluation_request.cache = {}
+
+        # Store in cache (limit cache size to avoid memory issues)
+        MAX_CACHE_SIZE = 50
+        handle_evaluation_request.cache[cache_key] = response
+        if len(handle_evaluation_request.cache) > MAX_CACHE_SIZE:
+            # Remove oldest entry (simplified approach)
+            handle_evaluation_request.cache.pop(next(iter(handle_evaluation_request.cache)))
+
+        return response
 
     except Exception as e:
         logger.error(f"[{request_id}] Unhandled exception in evaluation endpoint: {str(e)}")